In [None]:
import tensorflow as tf
from tensorflow import matmul,math
from numpy import random
from math import sqrt
import random
from keras.backend import softmax
#from tensorflow.keras.layers import Layer
import numpy as np
from tensorflow.keras.layers import LayerNormalization, Layer, Dense, ReLU, Dropout

class AttentionDotProduct(Layer):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""

    def call(self, queries, keys, values, d_k, attention_mask=None):
        """Weight ``values`` by the softmax-normalised query/key similarity.

        Args:
            queries: Query tensor; its last axis must match the last axis of
                ``keys`` so the transposed matmul is defined.
            keys: Key tensor.
            values: Value tensor, combined using the attention weights.
            d_k: Number used to scale the raw scores (divided by ``sqrt(d_k)``).
            attention_mask: Optional tensor broadcastable to the score matrix.
                Entries equal to 1 are pushed towards -inf before the softmax so
                they receive (almost) zero weight; entries equal to 0 are left
                untouched. ``None`` (the default) applies no masking.
                NOTE(review): confirm that callers use 1 = "ignore this position".
                The parameter is deliberately not called ``mask`` because Keras
                gives that argument name special treatment in ``Layer.__call__``.

        Returns:
            The attention-weighted combination of ``values``.
        """
        scores = matmul(queries, keys, transpose_b=True) / sqrt(d_k)

        if attention_mask is not None:
            # A large negative offset drives masked scores to ~0 after softmax.
            scores += -1e9 * tf.cast(attention_mask, scores.dtype)

        # Keras softmax normalises over the last axis by default.
        weights = tf.keras.activations.softmax(scores)
        return matmul(weights, values)
        

In [None]:
# Smoke test: run AttentionDotProduct once on reproducible random inputs.
d_q = 64  # dim of query vectors (must equal d_k for the Q.K^T product)
d_k = 64  # dim of key matrix
d_v = 64  # dim of value matrix
input_len = 5  # length of input sequence
batch_size = 64

# Seeded generator so a fresh "Restart & Run All" prints the same numbers.
rng = np.random.default_rng(42)
queries = rng.random((batch_size, input_len, d_q))
keys = rng.random((batch_size, input_len, d_k))
values = rng.random((batch_size, input_len, d_v))

assert d_q == d_k, "queries and keys must share their last dimension"

attention = AttentionDotProduct()
result = attention(queries, keys, values, d_k)

# One d_v-sized output vector is expected per query position.
assert tuple(result.shape) == (batch_size, input_len, d_v)
print(result)

In [63]:
# Show which TensorFlow version this kernel has loaded.
print(tf.__version__)

2.13.0


In [None]:
class FeedForward(Layer):
    """Two stacked Dense layers with a ReLU between them.

    Args:
        d_ff: Number of units of the first (inner) Dense layer.
        d_model: Number of units of the second (output) Dense layer.
        **kwargs: Forwarded to the Keras ``Layer`` constructor (e.g. ``name``).
    """

    def __init__(self, d_ff, d_model, **kwargs):
        # The base Layer must be initialised before any sub-layer is assigned;
        # without this call Keras' internal state is never set up and the
        # keyword arguments would be silently discarded.
        super(FeedForward, self).__init__(**kwargs)
        self.fullyconnected1 = Dense(d_ff)
        self.fullyconnected2 = Dense(d_model)
        self.activation = ReLU()

    def call(self, x):
        """Return ``fullyconnected2(relu(fullyconnected1(x)))``."""
        x_fc = self.fullyconnected1(x)
        return self.fullyconnected2(self.activation(x_fc))

In [None]:
class AddNormalization(Layer):
    """Sums two inputs and applies layer normalization to the result."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Default-configured Keras layer normalization.
        self.layer_norm = LayerNormalization()

    def call(self, x, sublayer_x):
        """Return the layer-normalised value of ``x + sublayer_x``."""
        return self.layer_norm(x + sublayer_x)

In [None]:
class EncoderLayer(Layer):
    """One encoder block: self-attention and feed-forward sub-blocks.

    Each sub-block output goes through dropout and is then combined with the
    sub-block's input by an ``AddNormalization`` layer.
    """

    def __init__(self, h, d_k, d_v, d_model, d_ff, rate, **kwargs):
        super().__init__(**kwargs)

        # Attention sub-block.
        self.multihead_attention = MultiHeadAttention(h, d_k, d_v, d_model)
        self.dropout1 = Dropout(rate)
        self.add_norm1 = AddNormalization()

        # Feed-forward sub-block.
        self.feed_forward = FeedForward(d_ff, d_model)
        self.dropout2 = Dropout(rate)
        self.add_norm2 = AddNormalization()

    def call(self, x, padding_mask, training):
        """Run ``x`` through both sub-blocks.

        Args:
            x: Input to the block; used as queries, keys and values alike.
            padding_mask: Passed unchanged to the multi-head attention layer.
            training: Forwarded to both dropout layers.

        Returns:
            Output of the second add-and-normalise step. Every intermediate
            result is expected to be (batch_size, sequence_length, d_model).
        """
        # Self-attention, then dropout, then residual add & norm.
        attention_out = self.dropout1(
            self.multihead_attention(x, x, x, padding_mask),
            training=training,
        )
        attention_normed = self.add_norm1(x, attention_out)

        # Feed-forward, then dropout, then residual add & norm.
        ff_out = self.dropout2(
            self.feed_forward(attention_normed),
            training=training,
        )
        return self.add_norm2(attention_normed, ff_out)

In [None]:
class Encoder(Layer):
    """Positional embedding, dropout, then a stack of ``n`` encoder layers."""

    def __init__(self, vocab_size, sequence_length, h, d_k, d_v, d_model, d_ff, n, rate, **kwargs):
        super().__init__(**kwargs)
        self.pos_encoding = PositionEmbeddingFixedWeights(sequence_length, vocab_size, d_model)
        self.dropout = Dropout(rate)

        # Build the n identically configured encoder layers.
        stack = []
        for _ in range(n):
            stack.append(EncoderLayer(h, d_k, d_v, d_model, d_ff, rate))
        self.encoder_layer = stack

    def call(self, input_sentence, padding_mask, training):
        """Encode ``input_sentence``.

        Args:
            input_sentence: Input handed to the positional-embedding layer.
            padding_mask: Forwarded to every encoder layer.
            training: Forwarded to the dropout layer and to every encoder layer.

        Returns:
            Output of the last encoder layer; the embedding output is expected
            to be (batch_size, sequence_length, d_model).
        """
        # Embed with positional information, then apply dropout.
        hidden = self.dropout(self.pos_encoding(input_sentence), training=training)

        # Feed the running representation through each encoder layer in order.
        for block in self.encoder_layer:
            hidden = block(hidden, padding_mask, training)

        return hidden