In [151]:
import torch
import torch.nn as nn
import math

In [152]:
class inputEmbeddingLayer(nn.Module):
    def __init__(self,vocab_size,emb_dim):
        super().__init__()
        self.vocab_size=vocab_size
        self.emb_dim=emb_dim
        self.embedding_layer=nn.Embedding(self.vocab_size,self.emb_dim,dtype=torch.float16)
    def forward(self,x):
        embeddings=self.embedding_layer(x)
        return embeddings*math.sqrt(self.emb_dim)

In [153]:
embedding_layer=inputEmbeddingLayer(10,6)

In [154]:
test_input=torch.tensor([[1,2,3],[4,5,6]])
input_embeddings=embedding_layer(test_input)
input_embeddings

tensor([[[-0.4856, -3.7363, -0.5425,  0.2228, -4.3164,  2.9336],
         [-2.4258, -0.3982, -5.5391, -0.9116, -1.2793, -2.8398],
         [ 2.9238, -1.1357, -1.4229,  2.4688,  0.2939,  3.7324]],

        [[ 4.7891, -6.4414,  0.9263, -2.6855,  0.6948,  1.1211],
         [ 1.1445, -1.2451,  2.8418,  0.6812,  1.4258,  0.9746],
         [ 3.5156, -1.3770, -1.2559,  0.4175, -0.7314,  2.2715]]],
       dtype=torch.float16, grad_fn=<MulBackward0>)

In [155]:
class positionalEncodingLayer(nn.Module):
    def __init__(self,max_seq_len,emb_dim,dropout):
        super().__init__()
        self.max_seq_len=max_seq_len
        self.emb_dim=emb_dim
        self.dropout=nn.Dropout(dropout)
        static_positional_info=torch.zeros((self.max_seq_len,self.emb_dim),dtype=torch.float16)
        positions=torch.arange(0,max_seq_len,dtype=torch.float16).reshape(max_seq_len,-1)
        indices_for_denominator=torch.arange(0,emb_dim,2,dtype=torch.float16) ### 2i
        denominators=torch.exp((-2*indices_for_denominator*math.log(1e4))/emb_dim)
        static_positional_info[:,0::2]=torch.sin(positions*denominators)
        static_positional_info[:,1::2]=torch.cos(positions*denominators)
        self.register_buffer('static_positional_info',static_positional_info)
    def forward(self,x):
        position_encoded_embedding=x+self.static_positional_info[:x.shape[1],:]
        dropped_embeddings=self.dropout(position_encoded_embedding)
        return dropped_embeddings


In [156]:
positional_encoding_layer=positionalEncodingLayer(3,6,0.3)

In [157]:
position_encoded_embedding=positional_encoding_layer(input_embeddings)
position_encoded_embedding

tensor([[[-0.6938, -3.9102, -0.0000,  0.0000, -0.0000,  5.6211],
         [-2.2637,  0.2034, -7.9062,  0.0000, -1.8281, -2.6289],
         [ 0.0000, -0.0000, -0.0000,  4.9570,  0.4199,  6.7656]],

        [[ 0.0000, -7.7734,  1.3232, -0.0000,  0.9927,  3.0312],
         [ 2.8379, -1.0068,  4.0625,  2.4023,  0.0000,  0.0000],
         [ 6.3242, -2.5625, -1.7891,  2.0254, -1.0449,  4.6758]]],
       dtype=torch.float16, grad_fn=<MulBackward0>)

In [158]:
class multiHeadAttentionBlock(nn.Module):
    def __init__(self,emb_dim,n_heads):
        super().__init__()
        assert emb_dim%n_heads==0 ### checking if multi head splitting is possible.
        self.w_q=nn.Linear(emb_dim,emb_dim,dtype=torch.float16)
        self.w_k=nn.Linear(emb_dim,emb_dim,dtype=torch.float16)
        self.w_v=nn.Linear(emb_dim,emb_dim,dtype=torch.float16)
        self.w_o=nn.Linear(emb_dim,emb_dim,dtype=torch.float16) ### multi-head-projection-layer
        self.emb_dim=emb_dim
        self.single_head_dim=self.emb_dim//n_heads
        self.n_heads=n_heads

    @staticmethod
    def contextual_embedding(m_q,m_k,m_v,per_head_emb_dim,mask):
        ### return contexual embedding and attention scores
        attention_scores=m_q@m_k.transpose(2,3)/math.sqrt(per_head_emb_dim)
        ##batch,head,seq,dim @ batch,head,dim,seq==batch,head,seq,seq
        if mask is not None:
            attention_scores.masked_fill_(mask,value=float('-inf'))
        normalized_attention_scores=torch.softmax(attention_scores,dim=-1)
        ### batch,head,seq,seq @ batch,head,seq,dim=batch,head,seq,dim
        contexual_embeddings=normalized_attention_scores@m_v
        return normalized_attention_scores,contexual_embeddings
    
    def forward(self,q,k,v,mask):
        query=self.w_q(q) ### batch, seqeunce, dim
        key=self.w_k(k)
        value=self.w_v(v)

        multihead_query=query.view(query.shape[0],query.shape[1],self.n_heads,self.single_head_dim).transpose(1,2)
        multihead_key=key.view(key.shape[0],key.shape[1],self.n_heads,self.single_head_dim).transpose(1,2)
        multihead_value=value.view(value.shape[0],value.shape[1],self.n_heads,self.single_head_dim).transpose(1,2)
        _,contextual_embeddings=multiHeadAttentionBlock.contextual_embedding(multihead_query,multihead_key,multihead_value,self.single_head_dim,mask)
        final_contextual_embeddings=contextual_embeddings.transpose(1,2).contiguous().view(value.shape[0],value.shape[1],self.n_heads*self.single_head_dim)
        multihead_final_contextual_embedding_proj=self.w_o(final_contextual_embeddings)
        return multihead_final_contextual_embedding_proj

In [159]:
Mlab= multiHeadAttentionBlock(6,2)

In [160]:
a=lambda x: Mlab(x,x,x,None)
mha_out=a(position_encoded_embedding)
mha_out

tensor([[[ 1.8105, -0.6196,  2.2715, -1.0117,  2.3164, -1.4922],
         [ 0.3481, -0.6372,  1.6074, -1.5215,  0.5420, -0.4106],
         [ 1.5264, -0.2495,  2.4961, -0.9922,  2.2207, -0.9580]],

        [[ 2.6348, -1.2627,  2.3828, -1.1592,  2.6543, -0.4070],
         [ 2.0059, -1.2197,  2.0586, -1.3438,  1.6064, -0.4792],
         [ 2.7188, -1.1006,  3.2754, -1.4922,  2.5469, -1.1504]]],
       dtype=torch.float16, grad_fn=<ViewBackward0>)

In [161]:
class layerNormalizationBlock(nn.Module):
    def __init__(self,emb_dim,eps=1e-5):
        super().__init__()
        self.scale=nn.Parameter(torch.ones(emb_dim,dtype=torch.float16))
        self.shift=nn.Parameter(torch.zeros(emb_dim,dtype=torch.float16))
        self.eps=eps

    def forward(self,x):
        mean=x.mean(dim=-1,keepdim=True)
        standard_deviation=x.std(dim=-1,keepdim=True,unbiased=False)
        normalized_x=(x-mean)/(standard_deviation+self.eps)
        scale_n_shift=self.scale*normalized_x+self.shift
        return scale_n_shift

In [162]:
lnb=layerNormalizationBlock(6)

In [163]:
layer_normalized_out=lnb(mha_out)
layer_normalized_out

tensor([[[ 0.7827, -0.7217,  1.0684, -0.9644,  1.0957, -1.2627],
         [ 0.3635, -0.6309,  1.6348, -1.5234,  0.5591, -0.4023],
         [ 0.5850, -0.6338,  1.2510, -1.1436,  1.0615, -1.1201]],

        [[ 1.0312, -1.1680,  0.8892, -1.1094,  1.0420, -0.6851],
         [ 1.0566, -1.1172,  1.0918, -1.2002,  0.7866, -0.6177],
         [ 0.9302, -0.9214,  1.2012, -1.1113,  0.8472, -0.9453]]],
       dtype=torch.float16, grad_fn=<AddBackward0>)

In [None]:
class skipConnection(nn.Module):
    def __init__(self,dropout):
        super().__init__()
        self.dropout=nn.Dropout(dropout)

    def forward(self,x,sublayer):
        output=x+sublayer(x)
        dropped_output=self.dropout(output)
        return dropped_output

In [166]:
skipConnection=skipConnection(0.3)

In [172]:
skipConnection(position_encoded_embedding,a)

tensor([[[  0.0000,  -6.4727,   3.2461,  -0.0000,   3.3086,   5.8984],
         [ -2.7383,  -0.0000,  -0.0000,  -2.1738,  -1.8379,  -4.3438],
         [  2.1816,  -0.0000,   3.5664,   0.0000,   3.7734,   8.2969]],

        [[  3.7637, -12.9141,   5.2969,  -1.6562,   5.2109,   0.0000],
         [  0.0000,  -3.1816,   8.7422,   1.5127,   2.2949,  -0.6846],
         [ 12.9219,  -5.2344,   2.1230,   0.7617,   2.1465,   0.0000]]],
       dtype=torch.float16, grad_fn=<MulBackward0>)