In [176]:
import torch
import torch.nn as nn
import math

In [177]:
class inputEmbeddingLayer(nn.Module):
    def __init__(self,vocab_size,emb_dim):
        super().__init__()
        self.vocab_size=vocab_size
        self.emb_dim=emb_dim
        self.embedding_layer=nn.Embedding(self.vocab_size,self.emb_dim,dtype=torch.float16)
    def forward(self,x):
        embeddings=self.embedding_layer(x)
        return embeddings*math.sqrt(self.emb_dim)

In [178]:
embedding_layer=inputEmbeddingLayer(10,6)

In [179]:
test_input=torch.tensor([[1,2,3],[4,5,6]])
input_embeddings=embedding_layer(test_input)
input_embeddings

tensor([[[-1.1660,  5.8320, -1.0488, -2.8320,  0.0495, -1.2549],
         [-2.5664,  0.9014,  1.9170, -4.7734,  1.9697, -0.9199],
         [ 0.1134, -1.7256, -1.6123,  2.3223,  2.4082, -0.6030]],

        [[-2.0723, -1.4678,  1.7715, -3.6016,  0.2427, -3.8945],
         [ 1.1396,  2.1289, -0.4321,  0.6802,  1.5361, -0.9561],
         [-1.2676, -0.5278, -0.1243, -0.2330,  2.8613, -0.7212]]],
       dtype=torch.float16, grad_fn=<MulBackward0>)

In [180]:
class positionalEncodingLayer(nn.Module):
    def __init__(self,max_seq_len,emb_dim,dropout):
        super().__init__()
        self.max_seq_len=max_seq_len
        self.emb_dim=emb_dim
        self.dropout=nn.Dropout(dropout)
        static_positional_info=torch.zeros((self.max_seq_len,self.emb_dim),dtype=torch.float16)
        positions=torch.arange(0,max_seq_len,dtype=torch.float16).reshape(max_seq_len,-1)
        indices_for_denominator=torch.arange(0,emb_dim,2,dtype=torch.float16) ### 2i
        denominators=torch.exp((-2*indices_for_denominator*math.log(1e4))/emb_dim)
        static_positional_info[:,0::2]=torch.sin(positions*denominators)
        static_positional_info[:,1::2]=torch.cos(positions*denominators)
        self.register_buffer('static_positional_info',static_positional_info)
    def forward(self,x):
        position_encoded_embedding=x+self.static_positional_info[:x.shape[1],:]
        dropped_embeddings=self.dropout(position_encoded_embedding)
        return dropped_embeddings


In [181]:
positional_encoding_layer=positionalEncodingLayer(3,6,0.3)

In [182]:
position_encoded_embedding=positional_encoding_layer(input_embeddings)
position_encoded_embedding

tensor([[[-0.0000,  9.7578, -1.4980, -2.6172,  0.0707, -0.3643],
         [-2.4648,  0.0000,  2.7422, -5.3906,  2.8145,  0.1144],
         [ 1.4609, -3.0605, -2.2988,  4.7461,  3.4414,  0.0000]],

        [[-2.9609, -0.6685,  2.5312, -3.7168,  0.3467, -4.1367],
         [ 0.0000,  3.8145, -0.0000,  0.0000,  0.0000,  0.0628],
         [-0.5122, -1.3496, -0.1714,  1.0957,  4.0898,  0.3984]]],
       dtype=torch.float16, grad_fn=<MulBackward0>)

In [183]:
class multiHeadAttentionBlock(nn.Module):
    def __init__(self,emb_dim,n_heads,dropout):
        super().__init__()
        assert emb_dim%n_heads==0 ### checking if multi head splitting is possible.
        self.w_q=nn.Linear(emb_dim,emb_dim,dtype=torch.float16)
        self.w_k=nn.Linear(emb_dim,emb_dim,dtype=torch.float16)
        self.w_v=nn.Linear(emb_dim,emb_dim,dtype=torch.float16)
        self.w_o=nn.Linear(emb_dim,emb_dim,dtype=torch.float16) ### multi-head-projection-layer
        self.emb_dim=emb_dim
        self.dropout=nn.Dropout(dropout)
        self.single_head_dim=self.emb_dim//n_heads
        self.n_heads=n_heads

    @staticmethod
    def contextual_embedding(m_q,m_k,m_v,per_head_emb_dim,mask):
        ### return contexual embedding and attention scores
        attention_scores=m_q@m_k.transpose(2,3)/math.sqrt(per_head_emb_dim)
        ##batch,head,seq,dim @ batch,head,dim,seq==batch,head,seq,seq
        if mask is not None:
            attention_scores.masked_fill_(mask,value=float('-inf'))
        normalized_attention_scores=torch.softmax(attention_scores,dim=-1)
        ### batch,head,seq,seq @ batch,head,seq,dim=batch,head,seq,dim
        contexual_embeddings=normalized_attention_scores@m_v
        return normalized_attention_scores,contexual_embeddings
    
    def forward(self,q,k,v,mask):
        query=self.w_q(q) ### batch, seqeunce, dim
        key=self.w_k(k)
        value=self.w_v(v)

        multihead_query=query.view(query.shape[0],query.shape[1],self.n_heads,self.single_head_dim).transpose(1,2)
        multihead_key=key.view(key.shape[0],key.shape[1],self.n_heads,self.single_head_dim).transpose(1,2)
        multihead_value=value.view(value.shape[0],value.shape[1],self.n_heads,self.single_head_dim).transpose(1,2)
        _,contextual_embeddings=multiHeadAttentionBlock.contextual_embedding(multihead_query,multihead_key,multihead_value,self.single_head_dim,mask)
        final_contextual_embeddings=contextual_embeddings.transpose(1,2).contiguous().view(value.shape[0],value.shape[1],self.n_heads*self.single_head_dim)
        multihead_final_contextual_embedding_proj=self.w_o(final_contextual_embeddings)
        dropped_multihead_final_contextual_embedding_proj=self.dropout(multihead_final_contextual_embedding_proj)
        return dropped_multihead_final_contextual_embedding_proj

In [185]:
Mlab= multiHeadAttentionBlock(6,2,0.3)

In [186]:
a=lambda x: Mlab(x,x,x,None)
mha_out=a(position_encoded_embedding)
mha_out

tensor([[[ 0.4644, -1.9199,  0.2600,  0.8218,  1.6543, -0.0000],
         [-0.1791, -1.7666,  0.6660,  0.7456,  1.8135,  0.0154],
         [-2.3477,  1.7959,  0.0000,  1.1494, -0.0000,  0.0000]],

        [[-0.2021,  0.5527, -0.1633,  0.0631, -0.4094,  0.0034],
         [-0.0000,  0.2959,  0.5791,  0.3767,  0.0000,  0.0000],
         [ 0.0000,  0.2181, -0.0000,  0.0187, -0.6797, -0.9395]]],
       dtype=torch.float16, grad_fn=<MulBackward0>)

In [187]:
class layerNormalizationBlock(nn.Module):
    def __init__(self,emb_dim,eps=1e-5):
        super().__init__()
        self.scale=nn.Parameter(torch.ones(emb_dim,dtype=torch.float16))
        self.shift=nn.Parameter(torch.zeros(emb_dim,dtype=torch.float16))
        self.eps=eps

    def forward(self,x):
        mean=x.mean(dim=-1,keepdim=True)
        standard_deviation=x.std(dim=-1,keepdim=True,unbiased=False)
        normalized_x=(x-mean)/(standard_deviation+self.eps)
        scale_n_shift=self.scale*normalized_x+self.shift
        return scale_n_shift

In [188]:
lnb=layerNormalizationBlock(6)

In [189]:
layer_normalized_out=lnb(mha_out)
layer_normalized_out

tensor([[[ 0.2305, -1.9590,  0.0428,  0.5586,  1.3242, -0.1959],
         [-0.3616, -1.8145,  0.4119,  0.4849,  1.4619, -0.1835],
         [-1.8955,  1.3135, -0.0771,  0.8130, -0.0771, -0.0771]],

        [[-0.5869,  1.9268, -0.4578,  0.2966, -1.2773,  0.0978],
         [-0.9272,  0.3879,  1.6475,  0.7471, -0.9272, -0.9272],
         [ 0.5444,  1.0596,  0.5444,  0.5884, -1.0615, -1.6758]]],
       dtype=torch.float16, grad_fn=<AddBackward0>)

In [190]:
class skipConnection(nn.Module):
    def __init__(self,dropout):
        super().__init__()
        self.dropout=nn.Dropout(dropout)

    def forward(self,x,sublayer):
        output=x+sublayer(x)
        dropped_output=self.dropout(output)
        return dropped_output

In [191]:
skipConnection=skipConnection(0.3)

In [196]:
skip_connections_output=skipConnection(position_encoded_embedding,a)
skip_connections_output

tensor([[[ 0.0000, 13.9375, -2.1406, -0.0000,  0.0000, -0.5205],
         [-3.7773, -2.5234,  3.9180, -6.6367,  6.6133,  0.1854],
         [-1.2666, -1.8066, -3.1191,  0.0000,  4.9180,  0.4897]],

        [[-4.2305, -0.1653,  3.3828, -5.2227,  0.0000, -0.0000],
         [-0.5723,  0.0000,  0.0000,  0.5381,  0.0000,  0.0000],
         [-0.7319, -1.6172, -1.1230,  1.5918,  4.8711, -0.7729]]],
       dtype=torch.float16, grad_fn=<MulBackward0>)

In [202]:
class feed_forward_block(nn.Module):
    ### Expansion Contraction layer.....
    def __init__(self,emb_dim,expand_dim,dropout):
        super().__init__()
        self.emb_dim=emb_dim
        self.expand_dim=expand_dim
        self.dropout=dropout
        self.network=nn.Sequential(
            nn.Linear(emb_dim,expand_dim,dtype=torch.float16),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(expand_dim,emb_dim,dtype=torch.float16),
        )
    def forward(self,x):
        output=self.network(x)
        return output

In [203]:
ffb=feed_forward_block(6,12,0.2)

In [204]:
ffb_output=ffb(skip_connections_output)
ffb_output

tensor([[[ 2.6777e+00,  1.0498e+00,  1.9746e+00, -1.1768e+00, -9.3408e-01,
           1.6172e+00],
         [ 4.2773e-01,  2.5195e-01, -2.7832e-01, -3.5864e-01,  3.8452e-01,
          -8.3301e-01],
         [ 3.5352e-01,  6.7578e-01, -4.8633e-01, -5.7959e-01,  1.0527e+00,
           3.7427e-01]],

        [[ 1.9946e-01, -1.4912e+00, -1.7651e-01, -9.8926e-01,  8.4863e-01,
          -8.2471e-01],
         [-2.0898e-01, -3.0420e-01,  3.7964e-02, -9.2316e-04,  4.3854e-02,
           1.6431e-01],
         [-2.4658e-01, -1.9458e-01, -8.0762e-01, -2.5586e-01,  1.1094e+00,
           3.7573e-01]]], dtype=torch.float16, grad_fn=<ViewBackward0>)

In [205]:
class encoderBlock(nn.Module):
    def __init__(self,emb_dim,n_heads,mha_dropout,expand_dim,ff_dropout,sk_dropout,mask):
        super().__init__()
        self.emb_dim=emb_dim
        self.n_heads=n_heads
        self.ff_dropout=ff_dropout
        self.expand_dim=expand_dim
        self.mha_dropout=mha_dropout
        self.sk_dropout=sk_dropout
        self.mask=mask
        self.layerNormalizationBlocks=[layerNormalizationBlock(self.emb_dim) for _ in range(2)]
        self.mha_block=multiHeadAttentionBlock(self.emb_dim,self.n_heads,self.mha_dropout)
        self.feed_forward_block=feed_forward_block(self.emb_dim,self.expand_dim,self.ff_dropout)
        self.skip_connections=[skipConnection(self.sk_dropout) for _ in range(2)]
    def forward(self,x):
        output1=self.skip_connections[0](x,lambda x: self.mha_block(x,x,x,self.mask))
        layer_normalized_output1=self.layerNormalizationBlocks[0](output1)
        output2=self.skip_connections[1](output1,self.feed_forward_block)
        layer_normalized_output2=self.layerNormalizationBlocks[1](output2)
        return layer_normalized_output2