In [1]:
import numpy as np
import tensorflow as tf

### Transformer encoder building blocks (attention, positional encoding, feed-forward, encoder)

In [46]:


def scaled_dot_product(Q, K, V, mask=None):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k) + mask) V.

    Args:
        Q: query tensor of shape (..., seq_q, d_k).
        K: key tensor of shape (..., seq_k, d_k).
        V: value tensor of shape (..., seq_k, d_v).
        mask: optional tensor broadcastable to (..., seq_q, seq_k) that is
            ADDED to the scores before the softmax (use large negative values
            to block positions).

    Returns:
        Tuple ``(attention_weights, attention)`` with shapes
        (..., seq_q, seq_k) and (..., seq_q, d_v).
    """
    print(f"K Shape: {K.shape}")
    # transpose_b swaps only the last two axes, so any number of leading
    # batch/head axes is supported (a fixed perm only works for rank 4).
    attention_scores = tf.matmul(Q, K, transpose_b=True)
    # Build the scale in the scores' own dtype; sqrt of a Python float yields
    # float32 and would fail against float64 inputs.
    dk = tf.cast(tf.shape(K)[-1], attention_scores.dtype)
    attention_scores = attention_scores / tf.sqrt(dk)
    if mask is not None:
        attention_scores += mask
    attention_weights = tf.nn.softmax(attention_scores, axis=-1)
    attention = tf.matmul(attention_weights, V)

    return attention_weights, attention

class PositionalEncoding(tf.Module):
    """Sinusoidal positional encoding (Vaswani et al., 2017).

    Column ``2i`` holds ``sin(pos / 10000**(2i / token_dim))`` and column
    ``2i + 1`` holds ``cos(pos / 10000**(2i / token_dim))``.
    """

    def __init__(self, max_seq_len, token_dim):
        super().__init__(name=None)
        # sin/cos columns are produced in pairs, so the width must be even.
        if token_dim % 2 != 0:
            raise ValueError(f"token_dim must be even, got {token_dim}")
        self.max_seq_len = max_seq_len
        self.token_dim = token_dim

    def __call__(self):
        """Return the float64 encoding table of shape (max_seq_len, token_dim)."""
        # Use the instance's dimension, never a notebook-level global.
        two_i = tf.range(0, self.token_dim, 2, dtype=tf.float64)
        # Both members of each (2i, 2i+1) pair share the exponent 2i / d.
        denom = tf.math.pow(tf.constant(10000.0, dtype=tf.float64), two_i / self.token_dim)
        position = tf.expand_dims(tf.range(0, self.max_seq_len, 1, dtype=tf.float64), -1)
        angles = position / denom  # (max_seq_len, token_dim // 2)
        # Stack to (max_seq_len, token_dim // 2, 2) and flatten the last two
        # axes so sin and cos values interleave column-wise.
        pe = tf.stack([tf.sin(angles), tf.cos(angles)], axis=2)
        pe = tf.reshape(pe, [self.max_seq_len, self.token_dim])

        return pe

class MultiheadAttention(tf.Module):
    """Multi-head self-attention over an input of shape (batch, seq, seq_dim).

    One Dense layer projects the input to concatenated Q/K/V, which is split
    into ``nheads`` heads of width ``seq_dim // nheads``. The per-head outputs
    are merged back to ``seq_dim`` and passed through a final Dense layer.
    """

    def __init__(self, nheads, sequence_len, seq_dim):
        super().__init__(name=None)
        if seq_dim % nheads != 0:
            raise ValueError(
                f"seq_dim ({seq_dim}) must be divisible by nheads ({nheads})"
            )
        self.nheads = nheads
        # Kept for backward compatibility; the actual sequence length is
        # read from the input at call time.
        self.sequence_len = sequence_len
        self.seq_dim = seq_dim
        self.head_dim = self.seq_dim // self.nheads
        self.qkv_net = tf.keras.layers.Dense(3 * self.seq_dim)
        self.fcnn = tf.keras.layers.Dense(self.seq_dim)
        # Attention weights from the most recent call, (batch, nheads, seq, seq).
        self.attention_weights = None

    def __call__(self, X, mask=None):
        """Apply self-attention to ``X`` and return a (batch, seq, seq_dim) tensor.

        ``mask`` is forwarded unchanged to ``scaled_dot_product`` (additive).
        """
        assert self.seq_dim == X.shape[-1], "seaquence dimension given and sequence dimension in the input data is not matching "
        print(f"X shape: {X.shape}")
        batch_size = tf.shape(X)[0]
        seq_len = tf.shape(X)[1]

        QKV = self.qkv_net(X)
        print(f"QKV shape from qkv net: {QKV.shape}")
        # (batch, seq, 3*seq_dim) -> (batch, seq, nheads, 3*head_dim)
        QKV = tf.reshape(QKV, [batch_size, seq_len, self.nheads, 3 * self.head_dim])
        # -> (batch, nheads, seq, 3*head_dim) so each head attends independently.
        QKV = tf.transpose(QKV, perm=[0, 2, 1, 3])
        print(f"QKV shape after heads added: {QKV.shape}")
        Q, K, V = tf.split(QKV, 3, axis=-1)
        print(f"QKV shape individually: {(Q.shape, K.shape, V.shape)}")
        self.attention_weights, attention_embeddings = scaled_dot_product(Q, K, V, mask)
        print(f"Attention weights and embeddings shape: {(self.attention_weights.shape, attention_embeddings.shape)}")
        # Bring the head axis back next to the feature axis BEFORE merging:
        # reshaping (batch, nheads, seq, head_dim) directly would mix tokens
        # from different positions into the same output row.
        attention_embeddings = tf.transpose(attention_embeddings, perm=[0, 2, 1, 3])
        attention_embeddings = tf.reshape(attention_embeddings, [batch_size, seq_len, self.nheads * self.head_dim])
        print(f"Attention embeddings shape before NN: {(self.attention_weights.shape, attention_embeddings.shape)}")
        attention_out = self.fcnn(attention_embeddings)
        print(f"final attention block output shape: {attention_out.shape}")
        return attention_out
        
        
class PositionwiseFeedForward(tf.Module):
    """Position-wise feed-forward net: Dense(hidden) -> ReLU -> Dropout -> Dense(d_model)."""

    def __init__(self, d_model, hidden, drop_prob):
        super().__init__(name=None)
        self.linear1 = tf.keras.layers.Dense(hidden)
        self.linear2 = tf.keras.layers.Dense(d_model)
        self.relu = tf.keras.layers.ReLU()
        self.dropout = tf.keras.layers.Dropout(drop_prob)

    def __call__(self, x, training=None):
        """Apply the feed-forward stack to ``x``.

        Args:
            x: input tensor whose last axis is projected.
            training: forwarded to the Dropout layer; pass ``True`` to
                actually drop units (``None`` keeps Keras' default behaviour).
        """
        x = self.linear1(x)
        print(f"x after first linear layer: {x.shape}")
        x = self.relu(x)
        print(f"x after activation: {x.shape}")
        x = self.dropout(x, training=training)
        print(f"x after dropout: {x.shape}")
        x = self.linear2(x)
        print(f"x after 2nd linear layer: {x.shape}")
        return x
    
class EncoderBlock(tf.Module):
    """One pre-norm Transformer encoder layer.

    ``X_A = X + MHA(LN1(X))`` followed by ``X_out = X_A + FFN(LN2(X_A))``.
    """

    def __init__(self, nheads, sequence_len, seq_dim, hidden_units, drop_prob):
        super().__init__(name=None)
        # Separate norms so the attention and feed-forward sub-layers each
        # learn their own scale/offset parameters.
        self.layer_norm = tf.keras.layers.LayerNormalization()
        self.layer_norm2 = tf.keras.layers.LayerNormalization()
        self.mha = MultiheadAttention(nheads, sequence_len, seq_dim)
        self.pff = PositionwiseFeedForward(seq_dim, hidden_units, drop_prob)

    def __call__(self, X, mask=None):
        """Run the block on ``X``; ``mask`` is passed through to attention."""
        # Use this block's own sub-layers (not notebook-level globals).
        X_A = X + self.mha(self.layer_norm(X), mask)
        X_out = X_A + self.pff(self.layer_norm2(X_A))

        return X_out
        
class Encoder(tf.Module):
    """A stack of ``nblocks`` EncoderBlock layers applied in sequence."""

    def __init__(self, nblocks, nheads, sequence_len, seq_dim, hidden_units, drop_prob=0.1):
        # Initialise tf.Module so `name` / `name_scope` work like the other modules.
        super().__init__(name=None)
        self.encoder_blocks = [
            EncoderBlock(nheads, sequence_len, seq_dim, hidden_units, drop_prob)
            for _ in range(nblocks)
        ]

    def __call__(self, X):
        """Feed ``X`` through every block in order and return the final output."""
        for encoder_block in self.encoder_blocks:
            X = encoder_block(X)
        return X
        

In [47]:


# Hyper-parameters for a quick shape smoke test of the encoder.
batch_size = 16
seq_len = 10
token_dim = 512
heads = 8
num_blocks = 6
hidden_units = 1024
drop_prob = 0.1

tf.random.set_seed(42)  # make the random input reproducible
X = tf.random.normal(shape=(batch_size, seq_len, token_dim))

encoder = Encoder(num_blocks, heads, seq_len, token_dim, hidden_units, drop_prob)
encoder_out = encoder(X)
print(f"encoder output shape: {encoder_out.shape}")
assert encoder_out.shape == (batch_size, seq_len, token_dim), "encoder must preserve (batch, seq, dim)"

X shape: (16, 10, 512)
QKV shape from qkv net: (16, 10, 1536)
QKV shape after heads added: (16, 8, 10, 192)
QKV shape individually: (TensorShape([16, 8, 10, 64]), TensorShape([16, 8, 10, 64]), TensorShape([16, 8, 10, 64]))
K Shape: (16, 8, 10, 64)
Attention weights and embeddings shape: (TensorShape([16, 8, 10, 10]), TensorShape([16, 8, 10, 64]))
Attention embeddings shape before NN: (TensorShape([16, 8, 10, 10]), TensorShape([16, 10, 512]))
final attention block output shape: (16, 10, 512)
x after first linear layer: (16, 10, 1024)
x after activation: (16, 10, 1024)
x after dropout: (16, 10, 1024)
x after 2nd linear layer: (16, 10, 512)
X shape: (16, 10, 512)
QKV shape from qkv net: (16, 10, 1536)
QKV shape after heads added: (16, 8, 10, 192)
QKV shape individually: (TensorShape([16, 8, 10, 64]), TensorShape([16, 8, 10, 64]), TensorShape([16, 8, 10, 64]))
K Shape: (16, 8, 10, 64)
Attention weights and embeddings shape: (TensorShape([16, 8, 10, 10]), TensorShape([16, 8, 10, 64]))
Atten

In [23]:
# X_out is only a local inside EncoderBlock.__call__; inspect the encoder's output instead.
encoder_out.shape

TensorShape([16, 10, 512])