In [374]:
import torch
import torch.nn as nn
import math

In [375]:
class inputEmbeddingLayer(nn.Module):
    def __init__(self,vocab_size,emb_dim):
        super().__init__()
        self.vocab_size=vocab_size
        self.emb_dim=emb_dim
        self.embedding_layer=nn.Embedding(self.vocab_size,self.emb_dim,dtype=torch.float16)
    def forward(self,x):
        embeddings=self.embedding_layer(x)
        return embeddings*math.sqrt(self.emb_dim)

In [376]:
embedding_layer=inputEmbeddingLayer(10,6)

In [377]:
test_input=torch.tensor([[1,2,3],[4,5,6]])
input_embeddings=embedding_layer(test_input)
input_embeddings

tensor([[[-1.0947, -4.6719, -2.4785,  1.9209, -0.7163,  1.7715],
         [ 1.7021, -1.1729,  1.3672, -0.1871, -0.7930, -4.7969],
         [ 2.8633,  2.8418, -1.0898, -1.4854,  1.1201, -0.5195]],

        [[ 0.7690, -1.3906, -4.7969,  5.4727, -2.4434, -1.7764],
         [-0.1520, -4.5078, -0.4961, -1.3623,  0.7603, -1.7178],
         [ 0.7769,  0.7603,  2.0234, -1.7139,  2.1465, -3.6973]]],
       dtype=torch.float16, grad_fn=<MulBackward0>)

In [342]:
class positionalEncodingLayer(nn.Module):
    def __init__(self,max_seq_len,emb_dim,dropout):
        super().__init__()
        self.max_seq_len=max_seq_len
        self.emb_dim=emb_dim
        self.dropout=nn.Dropout(dropout)
        static_positional_info=torch.zeros((self.max_seq_len,self.emb_dim),dtype=torch.float16)
        positions=torch.arange(0,max_seq_len,dtype=torch.float16).reshape(max_seq_len,-1)
        indices_for_denominator=torch.arange(0,emb_dim,2,dtype=torch.float16) ### 2i
        denominators=torch.exp((-2*indices_for_denominator*math.log(1e4))/emb_dim)
        static_positional_info[:,0::2]=torch.sin(positions*denominators)
        static_positional_info[:,1::2]=torch.cos(positions*denominators)
        self.register_buffer('static_positional_info',static_positional_info)
    def forward(self,x):
        position_encoded_embedding=x+self.static_positional_info[:x.shape[1],:]
        dropped_embeddings=self.dropout(position_encoded_embedding)
        return dropped_embeddings


In [343]:
positional_encoding_layer=positionalEncodingLayer(3,6,0.3)

In [344]:
position_encoded_embedding=positional_encoding_layer(input_embeddings)
position_encoded_embedding

tensor([[[ 6.7930, -0.0000, -3.4707, -0.1786,  0.0000, -0.6113],
         [ 1.5234,  0.0000, -0.0000,  4.4062,  3.3047, -2.2070],
         [ 7.6914, -2.3262, -2.8711, -4.4297,  1.3916,  3.4570]],

        [[ 0.0000, -0.3279,  0.0000,  0.0000,  4.3633,  2.4062],
         [ 2.5273,  0.0000, -1.1494,  0.0000, -0.0000,  2.8711],
         [-2.0020, -0.0000,  0.0000, -0.2791, -0.0000, -4.6094]]],
       dtype=torch.float16, grad_fn=<MulBackward0>)

In [345]:
class multiHeadAttentionBlock(nn.Module):
    def __init__(self,emb_dim,n_heads,dropout):
        super().__init__()
        assert emb_dim%n_heads==0 ### checking if multi head splitting is possible.
        self.w_q=nn.Linear(emb_dim,emb_dim,dtype=torch.float16)
        self.w_k=nn.Linear(emb_dim,emb_dim,dtype=torch.float16)
        self.w_v=nn.Linear(emb_dim,emb_dim,dtype=torch.float16)
        self.w_o=nn.Linear(emb_dim,emb_dim,dtype=torch.float16) ### multi-head-projection-layer
        self.emb_dim=emb_dim
        self.dropout=nn.Dropout(dropout)
        self.single_head_dim=self.emb_dim//n_heads
        self.n_heads=n_heads

    @staticmethod
    def contextual_embedding(m_q,m_k,m_v,per_head_emb_dim,mask):
        ### return contexual embedding and attention scores
        attention_scores=m_q@m_k.transpose(2,3)/math.sqrt(per_head_emb_dim)
        ##batch,head,seq,dim @ batch,head,dim,seq==batch,head,seq,seq
        if mask is not None:
            attention_scores.masked_fill_(mask,value=float('-inf'))
        normalized_attention_scores=torch.softmax(attention_scores,dim=-1)
        ### batch,head,seq,seq @ batch,head,seq,dim=batch,head,seq,dim
        contexual_embeddings=normalized_attention_scores@m_v
        return normalized_attention_scores,contexual_embeddings
    
    def forward(self,q,k,v,mask):
        query=self.w_q(q) ### batch, seqeunce, dim
        key=self.w_k(k)
        value=self.w_v(v)

        multihead_query=query.view(query.shape[0],query.shape[1],self.n_heads,self.single_head_dim).transpose(1,2)
        multihead_key=key.view(key.shape[0],key.shape[1],self.n_heads,self.single_head_dim).transpose(1,2)
        multihead_value=value.view(value.shape[0],value.shape[1],self.n_heads,self.single_head_dim).transpose(1,2)
        _,contextual_embeddings=multiHeadAttentionBlock.contextual_embedding(multihead_query,multihead_key,multihead_value,self.single_head_dim,mask)
        final_contextual_embeddings=contextual_embeddings.transpose(1,2).contiguous().view(value.shape[0],value.shape[1],self.n_heads*self.single_head_dim)
        multihead_final_contextual_embedding_proj=self.w_o(final_contextual_embeddings)
        dropped_multihead_final_contextual_embedding_proj=self.dropout(multihead_final_contextual_embedding_proj)
        return dropped_multihead_final_contextual_embedding_proj

In [346]:
Mlab= multiHeadAttentionBlock(6,2,0.3)

In [347]:
a=lambda x: Mlab(x,x,x,None)
mha_out=a(position_encoded_embedding)
mha_out

tensor([[[-1.2090,  0.6333,  1.3828,  0.9443, -0.8652,  0.0000],
         [-1.4961,  0.8101,  0.0000, -0.4099, -0.4036, -0.5015],
         [-1.1279,  0.1962,  1.0977,  1.0703, -0.5791,  0.5308]],

        [[-0.0000, -0.0000, -0.1019, -0.0274,  0.0000,  0.6685],
         [-0.2084, -0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
         [-0.0000, -0.0000,  0.0000, -0.6294,  0.0000,  0.0000]]],
       dtype=torch.float16, grad_fn=<MulBackward0>)

In [348]:
class layerNormalizationBlock(nn.Module):
    def __init__(self,emb_dim,eps=1e-5):
        super().__init__()
        self.scale=nn.Parameter(torch.ones(emb_dim,dtype=torch.float16))
        self.shift=nn.Parameter(torch.zeros(emb_dim,dtype=torch.float16))
        self.eps=eps

    def forward(self,x):
        mean=x.mean(dim=-1,keepdim=True)
        standard_deviation=x.std(dim=-1,keepdim=True,unbiased=False)
        normalized_x=(x-mean)/(standard_deviation+self.eps)
        scale_n_shift=self.scale*normalized_x+self.shift
        return scale_n_shift

In [349]:
lnb=layerNormalizationBlock(6)

In [350]:
layer_normalized_out=lnb(mha_out)
layer_normalized_out

tensor([[[-1.4453e+00,  5.1758e-01,  1.3164e+00,  8.4912e-01, -1.0791e+00,
          -1.5735e-01],
         [-1.6992e+00,  1.6719e+00,  4.8755e-01, -1.1169e-01, -1.0242e-01,
          -2.4548e-01],
         [-1.6152e+00, -2.2316e-03,  1.0957e+00,  1.0625e+00, -9.4727e-01,
           4.0552e-01]],

        [[-3.4399e-01, -3.4399e-01, -7.3389e-01, -4.4873e-01, -3.4399e-01,
           2.2148e+00],
         [-2.2363e+00,  4.4727e-01,  4.4727e-01,  4.4727e-01,  4.4727e-01,
           4.4727e-01],
         [ 4.4727e-01,  4.4727e-01,  4.4727e-01, -2.2344e+00,  4.4727e-01,
           4.4727e-01]]], dtype=torch.float16, grad_fn=<AddBackward0>)

In [351]:
class skipConnection(nn.Module):
    def __init__(self,dropout):
        super().__init__()
        self.dropout=nn.Dropout(dropout)

    def forward(self,x,sublayer):
        output=x+sublayer(x)
        dropped_output=self.dropout(output)
        return dropped_output

In [352]:
skipConnectionLayer=skipConnection(0.3)

In [353]:
skip_connections_output=skipConnectionLayer(position_encoded_embedding,a)
skip_connections_output

tensor([[[ 9.7031,  0.9048, -2.9824,  0.0000, -1.2363, -0.0000],
         [ 2.1758,  0.0000,  1.3047,  0.0000,  4.1445, -3.8711],
         [ 9.3750, -3.0449, -4.1016, -4.8008,  1.1611,  0.0000]],

        [[-0.3081, -0.0000, -0.0000,  0.0000,  6.2344,  4.3906],
         [ 3.3125,  0.0000, -1.5723, -0.0180,  0.0000,  5.0938],
         [-4.2383, -0.0000,  0.1888, -0.3987,  0.0000, -0.0000]]],
       dtype=torch.float16, grad_fn=<MulBackward0>)

In [354]:
class feed_forward_block(nn.Module):
    ### Expansion Contraction layer.....
    def __init__(self,emb_dim,expand_dim,dropout):
        super().__init__()
        self.emb_dim=emb_dim
        self.expand_dim=expand_dim
        self.dropout=dropout
        self.network=nn.Sequential(
            nn.Linear(emb_dim,expand_dim,dtype=torch.float16),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(expand_dim,emb_dim,dtype=torch.float16),
        )
    def forward(self,x):
        output=self.network(x)
        return output

In [355]:
ffb=feed_forward_block(6,12,0.2)

In [356]:
ffb_output=ffb(skip_connections_output)
ffb_output

tensor([[[ 2.2266e+00,  1.9902e+00, -1.6514e+00,  5.9131e-01, -2.5757e-01,
           4.2920e-01],
         [-2.1045e-01,  5.7764e-01, -8.7402e-01, -7.4316e-01,  6.8555e-01,
          -8.4717e-01],
         [ 2.4238e+00,  2.4414e+00, -2.1816e+00,  1.4893e+00, -5.0244e-01,
          -4.4580e-01]],

        [[-8.7109e-01,  1.0000e+00, -6.9824e-01, -7.1924e-01,  3.5742e-01,
          -1.1904e+00],
         [ 4.6191e-01,  1.4766e+00, -2.1191e-01, -2.3718e-01,  5.7220e-02,
           2.4768e-01],
         [-7.5531e-04, -7.1594e-02, -6.1670e-01, -4.2114e-01, -6.5625e-01,
          -6.0840e-01]]], dtype=torch.float16, grad_fn=<ViewBackward0>)

In [357]:
class encoderBlock(nn.Module):
    def __init__(self,emb_dim,n_heads,mha_dropout,expand_dim,ff_dropout,sk_dropout):
        super().__init__()
        self.emb_dim=emb_dim
        self.n_heads=n_heads
        self.ff_dropout=ff_dropout
        self.expand_dim=expand_dim
        self.mha_dropout=mha_dropout
        self.sk_dropout=sk_dropout
        self.mha_block=multiHeadAttentionBlock(self.emb_dim,self.n_heads,self.mha_dropout)
        self.feed_forward_block=feed_forward_block(self.emb_dim,self.expand_dim,self.ff_dropout)
        self.skip_connections=nn.ModuleList([skipConnection(self.sk_dropout) for _ in range(2)])
        self.layerNormalizationBlocks=nn.ModuleList([layerNormalizationBlock(self.emb_dim) for _ in range(2)])
    def forward(self,x,mask=None):
        output1=self.skip_connections[0](x,lambda x: self.mha_block(x,x,x,mask))
        layer_normalized_output1=self.layerNormalizationBlocks[0](output1)
        output2=self.skip_connections[1](layer_normalized_output1,self.feed_forward_block)
        layer_normalized_output2=self.layerNormalizationBlocks[1](output2)
        return layer_normalized_output2

In [358]:
enc_blk=encoderBlock(6,2,0.3,12,0.3,0.3)
enc_blk

encoderBlock(
  (mha_block): multiHeadAttentionBlock(
    (w_q): Linear(in_features=6, out_features=6, bias=True)
    (w_k): Linear(in_features=6, out_features=6, bias=True)
    (w_v): Linear(in_features=6, out_features=6, bias=True)
    (w_o): Linear(in_features=6, out_features=6, bias=True)
    (dropout): Dropout(p=0.3, inplace=False)
  )
  (feed_forward_block): feed_forward_block(
    (network): Sequential(
      (0): Linear(in_features=6, out_features=12, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.3, inplace=False)
      (3): Linear(in_features=12, out_features=6, bias=True)
    )
  )
  (skip_connections): ModuleList(
    (0-1): 2 x skipConnection(
      (dropout): Dropout(p=0.3, inplace=False)
    )
  )
  (layerNormalizationBlocks): ModuleList(
    (0-1): 2 x layerNormalizationBlock()
  )
)

In [359]:
enc_blk(ffb_output)

tensor([[[ 0.2754, -0.5693, -1.8574,  1.2549,  0.1554,  0.7417],
         [ 0.6211,  0.4089, -1.6514, -0.9795,  1.2910,  0.3101],
         [ 1.3115, -0.9668, -1.4072,  1.1611, -0.1322,  0.0338]],

        [[-0.8374,  1.6689, -1.4648,  0.0141,  0.6060,  0.0141],
         [-0.4473,  2.2363, -0.4473, -0.4473, -0.4473, -0.4473],
         [-1.4053, -0.7671, -0.0127,  1.8301,  0.2888,  0.0662]]],
       dtype=torch.float16, grad_fn=<AddBackward0>)

In [360]:
class encoder(nn.Module):
    def __init__(self,no_of_enc_blk,emb_dim,n_heads,mha_dropout,expand_dim,ff_dropout,sk_dropout):
        super().__init__()
        self.enc_blks=nn.ModuleList([encoderBlock(emb_dim,n_heads,mha_dropout,expand_dim,ff_dropout,sk_dropout) for _ in range(no_of_enc_blk)])
    def forward(self,x,mask=None):
        for blk in self.enc_blks:
            x=blk(x,mask)
        return x

In [361]:
enc=encoder(12,6,2,0.3,12,0.3,0.3)
enc

encoder(
  (enc_blks): ModuleList(
    (0-11): 12 x encoderBlock(
      (mha_block): multiHeadAttentionBlock(
        (w_q): Linear(in_features=6, out_features=6, bias=True)
        (w_k): Linear(in_features=6, out_features=6, bias=True)
        (w_v): Linear(in_features=6, out_features=6, bias=True)
        (w_o): Linear(in_features=6, out_features=6, bias=True)
        (dropout): Dropout(p=0.3, inplace=False)
      )
      (feed_forward_block): feed_forward_block(
        (network): Sequential(
          (0): Linear(in_features=6, out_features=12, bias=True)
          (1): ReLU()
          (2): Dropout(p=0.3, inplace=False)
          (3): Linear(in_features=12, out_features=6, bias=True)
        )
      )
      (skip_connections): ModuleList(
        (0-1): 2 x skipConnection(
          (dropout): Dropout(p=0.3, inplace=False)
        )
      )
      (layerNormalizationBlocks): ModuleList(
        (0-1): 2 x layerNormalizationBlock()
      )
    )
  )
)

In [362]:
enc_output=enc(ffb_output)
enc_output

tensor([[[-0.1231, -0.1997, -1.8223,  1.2754, -0.1231,  0.9932],
         [ 0.1295,  1.3604, -1.9346, -0.0135, -0.1492,  0.6079],
         [-0.1536, -1.5244,  0.0087,  1.8965, -0.2368,  0.0087]],

        [[-1.4580, -0.3865, -0.3865, -0.3865,  1.3184,  1.3008],
         [-0.4028,  2.2168, -0.2942, -0.3960, -0.4028, -0.7202],
         [-1.6631,  0.7266,  0.3240,  1.0186,  0.6572, -1.0625]]],
       dtype=torch.float16, grad_fn=<AddBackward0>)

In [363]:
class decoderBlock(nn.Module):
    def __init__(self,emb_dim,n_heads,mha_dropout,expand_dim,ff_dropout,sk_dropout):
        super().__init__()
        self.emb_dim=emb_dim
        self.n_heads=n_heads
        self.ff_dropout=ff_dropout
        self.expand_dim=expand_dim
        self.mha_dropout=mha_dropout
        self.sk_dropout=sk_dropout
        self.mha_block1=multiHeadAttentionBlock(self.emb_dim,self.n_heads,self.mha_dropout) ### casual attention block
        self.mha_block2=multiHeadAttentionBlock(self.emb_dim,self.n_heads,self.mha_dropout) #### cross attention block
        self.feed_forward_block=feed_forward_block(self.emb_dim,self.expand_dim,self.ff_dropout)
        self.skip_connections=nn.ModuleList([skipConnection(self.sk_dropout) for _ in range(3)])
        self.layerNormalizationBlocks=nn.ModuleList([layerNormalizationBlock(self.emb_dim) for _ in range(3)])
    def forward(self,x,enc_out,mask1=None,mask2=None):
        output1=self.skip_connections[0](x,lambda x: self.mha_block1(x,x,x,mask1))
        layer_normalized_output1=self.layerNormalizationBlocks[0](output1)
        output2=self.skip_connections[1](layer_normalized_output1,lambda x: self.mha_block2(x,enc_out,enc_out,mask2))
        layer_normalized_output2=self.layerNormalizationBlocks[1](output2)
        output3=self.skip_connections[2](layer_normalized_output2,self.feed_forward_block)
        layer_normalized_output3=self.layerNormalizationBlocks[2](output3)
        return layer_normalized_output3

In [364]:
dec_blk=decoderBlock(6,2,0.3,12,0.3,0.3)
dec_blk

decoderBlock(
  (mha_block1): multiHeadAttentionBlock(
    (w_q): Linear(in_features=6, out_features=6, bias=True)
    (w_k): Linear(in_features=6, out_features=6, bias=True)
    (w_v): Linear(in_features=6, out_features=6, bias=True)
    (w_o): Linear(in_features=6, out_features=6, bias=True)
    (dropout): Dropout(p=0.3, inplace=False)
  )
  (mha_block2): multiHeadAttentionBlock(
    (w_q): Linear(in_features=6, out_features=6, bias=True)
    (w_k): Linear(in_features=6, out_features=6, bias=True)
    (w_v): Linear(in_features=6, out_features=6, bias=True)
    (w_o): Linear(in_features=6, out_features=6, bias=True)
    (dropout): Dropout(p=0.3, inplace=False)
  )
  (feed_forward_block): feed_forward_block(
    (network): Sequential(
      (0): Linear(in_features=6, out_features=12, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.3, inplace=False)
      (3): Linear(in_features=12, out_features=6, bias=True)
    )
  )
  (skip_connections): ModuleList(
    (0-2): 3 x skipConnection(

In [365]:
dec_blk_out=dec_blk(ffb_output,ffb_output)
dec_blk_out

tensor([[[ 1.6221, -0.4219, -1.6514,  0.6655, -0.1063, -0.1063],
         [-0.8452,  2.1895, -0.3579, -0.2430, -0.5000, -0.2430],
         [ 1.8447,  0.1548, -0.1047, -0.1047, -1.5840, -0.2065]],

        [[ 0.1368,  0.8843, -1.2129,  0.2515,  1.3232, -1.3828],
         [ 1.9453, -0.0217, -0.2496, -0.3293,  0.0839, -1.4287],
         [-0.1083,  0.3723, -1.7051,  1.6074,  0.3306, -0.4980]]],
       dtype=torch.float16, grad_fn=<AddBackward0>)

In [366]:
class decoder(nn.Module):
    def __init__(self,no_of_dec_blk,emb_dim,n_heads,mha_dropout,expand_dim,ff_dropout,sk_dropout):
        super().__init__()
        self.dec_blks=nn.ModuleList([decoderBlock(emb_dim,n_heads,mha_dropout,expand_dim,ff_dropout,sk_dropout) for _ in range(no_of_dec_blk)])
    def forward(self,x,mask=None):
        for blk in self.dec_blks:
            x=blk(x,mask)
        return x

In [367]:
dec=decoder(12,6,2,0.3,12,0.3,0.3)
dec

decoder(
  (dec_blks): ModuleList(
    (0-11): 12 x decoderBlock(
      (mha_block1): multiHeadAttentionBlock(
        (w_q): Linear(in_features=6, out_features=6, bias=True)
        (w_k): Linear(in_features=6, out_features=6, bias=True)
        (w_v): Linear(in_features=6, out_features=6, bias=True)
        (w_o): Linear(in_features=6, out_features=6, bias=True)
        (dropout): Dropout(p=0.3, inplace=False)
      )
      (mha_block2): multiHeadAttentionBlock(
        (w_q): Linear(in_features=6, out_features=6, bias=True)
        (w_k): Linear(in_features=6, out_features=6, bias=True)
        (w_v): Linear(in_features=6, out_features=6, bias=True)
        (w_o): Linear(in_features=6, out_features=6, bias=True)
        (dropout): Dropout(p=0.3, inplace=False)
      )
      (feed_forward_block): feed_forward_block(
        (network): Sequential(
          (0): Linear(in_features=6, out_features=12, bias=True)
          (1): ReLU()
          (2): Dropout(p=0.3, inplace=False)
       

In [368]:
dec_output=dec(ffb_output,ffb_output)
dec_output

tensor([[[-0.5117,  1.6250, -1.4395, -0.0368,  0.8745, -0.5117],
         [ 1.5742,  0.5850, -1.7383, -0.0119, -0.0119, -0.3965],
         [-0.3464, -1.3574, -0.3464,  1.8770, -0.3464,  0.5195]],

        [[ 0.2954,  1.3076, -1.8916,  0.4050,  0.4192, -0.5347],
         [ 0.2145,  0.2145, -1.5996, -0.7158,  1.6719,  0.2145],
         [ 1.3438,  0.6050,  0.0032,  0.0032, -1.9580,  0.0032]]],
       dtype=torch.float16, grad_fn=<AddBackward0>)

In [369]:
class finalProjectionLayer(nn.Module):
    def __init__(self,emb_dim,vocab_size):
        super().__init__()
        self.linear=nn.Linear(emb_dim,vocab_size)
    
    def forward(self,x):
        output=self.linear(x)
        return output
###batch,seq,vocab

In [None]:
class transformers(nn.Module):
    def __init__(self,model_config,tokenizer_config):
        super().__init__()
        self.encoder_emb_layer=inputEmbeddingLayer(tokenizer_config['vocab_size'],model_config['enc_cfg']['emb_dim'])
        self.enc_positional_emb_layer=positionalEncodingLayer(model_config['enc_max_seq_len'],model_config['enc_cfg']['emb_dim'],model_config['enc_cfg']['pos_emb_dropout'])
        self.encoder=encoder(
            no_of_enc_blk=model_config['enc_cfg']['no_of_enc_blk'],
            emb_dim=model_config['enc_cfg']['emb_dim'],
            n_heads=model_config['enc_cfg']['n_heads'],
            mha_dropout=model_config['enc_cfg']['mha_dropout'],
            expand_dim=model_config['enc_cfg']['expand_dim'],
            ff_dropout=model_config['enc_cfg']['ff_dropout'],
            sk_dropout=model_config['enc_cfg']['sk_dropout']
        )
        
        self.decoder_emb_layer=inputEmbeddingLayer(tokenizer_config['vocab_size'],model_config['dec_cfg']['emb_dim'])
        self.dec_positional_emb_layer=positionalEncodingLayer(model_config['dec_max_seq_len'],model_config['dec_cfg']['emb_dim'],model_config['dec_cfg']['pos_emb_dropout'])
        self.decoder=decoder(
            no_of_dec_blk=model_config['dec_cfg']['no_of_dec_blk'],
            emb_dim=model_config['dec_cfg']['emb_dim'],
            n_heads=model_config['dec_cfg']['n_heads'],
            mha_dropout=model_config['dec_cfg']['mha_dropout'],
            expand_dim=model_config['dec_cfg']['expand_dim'],
            ff_dropout=model_config['dec_cfg']['ff_dropout'],
            sk_dropout=model_config['dec_cfg']['sk_dropout']
        )
        self.decoder_final_projection=finalProjectionLayer(model_config['dec_cfg']['emb_dim'],tokenizer_config['vocab_size'])

    def encode(self,x,mask=None):
        encoder_input_embedding=self.encoder_emb_layer(x)
        positional_encoded_input_embedding=self.enc_positional_emb_layer(encoder_input_embedding)
        encoder_contexual_embedding=self.encoder(positional_encoded_input_embedding,mask)
        return encoder_contexual_embedding

    def decode(self,x,encoder_output,mask1=None,mask2=None):
        decoder_input_embedding=self.decoder_emb_layer(x)
        positional_encoded_input_embedding=self.dec_positional_emb_layer(decoder_input_embedding)
        decoder_contexual_embedding=self.decoder(positional_encoded_input_embedding,encoder_output,mask1,mask2)
        final_output=self.decoder_final_projection(decoder_contexual_embedding)
        return final_output
    
    ###forward will be used during training.
    def forward(self,encoder_input,decoder_input,src_mask,tgt_mask):
        encoder_output=self.encode(encoder_input,src_mask)
        decoder_output=self.decode(decoder_input,encoder_output,src_mask,tgt_mask)
        return decoder_output

In [371]:
# Model configuration dictionary
model_config = {
    "enc_max_seq_len": 128,   # Max source sequence length
    "dec_max_seq_len": 128,   # Max target sequence length
    "enc_cfg": {
        "emb_dim": 512,
        "no_of_enc_blk": 6,
        "n_heads": 8,
        "pos_emb_dropout":0.1,
        "mha_dropout": 0.1,
        "expand_dim": 2048,
        "ff_dropout": 0.1,
        "sk_dropout": 0.1
    },
    "dec_cfg": {
        "emb_dim": 512,
        "no_of_dec_blk": 6,
        "pos_emb_dropout":0.1,
        "n_heads": 8,
        "mha_dropout": 0.1,
        "expand_dim": 2048,
        "ff_dropout": 0.1,
        "sk_dropout": 0.1
    }
}

# Tokenizer configuration dictionary
tokenizer_config = {
    "vocab_size": 32000,  # Vocabulary size of source & target tokenizer
    "pad_token_id": 0,
    "bos_token_id": 1,
    "eos_token_id": 2
}

In [372]:
import json

# Save
with open("model_config.json", "w") as f:
    json.dump(model_config, f, indent=4)

with open("tokenizer_config.json", "w") as f:
    json.dump(tokenizer_config, f, indent=4)

# Load
with open("model_config.json", "r") as f:
    model_config = json.load(f)

with open("tokenizer_config.json", "r") as f:
    tokenizer_config = json.load(f)


In [378]:
Transformer=transformers(model_config,tokenizer_config)
Transformer

transformers(
  (encoder_emb_layer): inputEmbeddingLayer(
    (embedding_layer): Embedding(32000, 512)
  )
  (enc_positional_emb_layer): positionalEncodingLayer(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): encoder(
    (enc_blks): ModuleList(
      (0-5): 6 x encoderBlock(
        (mha_block): multiHeadAttentionBlock(
          (w_q): Linear(in_features=512, out_features=512, bias=True)
          (w_k): Linear(in_features=512, out_features=512, bias=True)
          (w_v): Linear(in_features=512, out_features=512, bias=True)
          (w_o): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward_block): feed_forward_block(
          (network): Sequential(
            (0): Linear(in_features=512, out_features=2048, bias=True)
            (1): ReLU()
            (2): Dropout(p=0.1, inplace=False)
            (3): Linear(in_features=2048, out_features=512, bias=True)
          )
        )
   