In [18]:
import torch
import torch.nn.functional as F
from torch import nn
import math

def scaled_dot_product(q,k,v,mask=None):
    scaled=torch.matmul(q,k.transpose(-1,-2))/math.sqrt(k.shape[-1])
    if mask==True:
        mask=torch.full(scaled.size(),float('-inf'))
        mask=torch.triu(mask,diagonal=1)
        scaled=scaled+mask
        # print(mask)
        # print(scaled)
    attention=F.softmax(scaled,dim=-1)
    values=torch.matmul(attention,v)
    return values,attention
class MultiHeadAttention(nn.Module):
    def __init__(self,input_dim,model_dim,num_head):
        super(MultiHeadAttention, self).__init__() 
        self.input_dim=input_dim
        self.model_dim=model_dim
        self.num_head=num_head
        self.head_dim=model_dim//num_head
        self.qkv_layer=nn.Linear(input_dim,3*model_dim)
        self.linear=nn.Linear(model_dim,input_dim)
    def forward(self,x,mask=None):
        self.mask=mask
        batch_dim,seq_length,input_dim=x.size()
        qkv=self.qkv_layer(x)
        qkv=qkv.reshape(batch_dim,seq_length,self.num_head,3*self.head_dim)
        qkv=qkv.permute(0,2,1,3)
        q,k,v=qkv.chunk(3,dim=-1)
        values,attention=scaled_dot_product(q,k,v,mask=mask)
        values=values.reshape(batch_size,seq_length,self.num_head*self.head_dim)
        out=self.linear(values)
        return out

In [28]:
class LayerNormalisation(nn.Module):
    def __init__(self,parameter_shape,epsilon):
        super().__init__()
        self.parameter_shape=parameter_shape
        self.eps=epsilon
        self.gamma=nn.Parameter(torch.ones(parameter_shape))
        self.beta=nn.Parameter(torch.zeros(parameter_shape))
    def forward(self,input):
        dim=[-i-1 for i in range(len(self.parameter_shape))]
        mean=input.mean(dim=dim,keepdim=True)
        var=((input-mean)**2).mean(dim=dim,keepdim=True)
        std=(var+self.eps).sqrt()
        y=(input-mean)/std
        out=self.gamma*y + self.beta
        return out

In [31]:
class FeedForwardLayer(nn.Module):
    def __init__(self,model_dim,hidden_dim,dropout):
        super(FeedForwardLayer,self).__init__()
        self.linear1=nn.Linear(model_dim,hidden_dim)
        self.relu=nn.ReLU()
        self.linear2=nn.Linear(hidden_dim,model_dim)
        self.dropout=nn.Dropout(p=dropout)
    def forward(self,x):
        x=self.linear1(x)
        x=self.relu(x)
        x=self.linear2(x)
        x=self.dropout(x)
        return x
    

In [32]:
class EncoderLayer(nn.Module):
    
    def __init__(self,model_dim,hidden_dim,num_head,dropout):
        super(EncoderLayer,self).__init__()
        self.model_dim=model_dim
        self.hidden_dim=hidden_dim
        self.num_head=num_head
        self.dropout=dropout
        self.attention=MultiHeadAttention(model_dim,model_dim,num_head)
        self.norm1=LayerNormalisation([model_dim],1e-5)
        self.dropout1=nn.Dropout(p=dropout)
        self.ffn=FeedForwardLayer(model_dim,hidden_dim,dropout)
        self.norm2=LayerNormalisation([model_dim],1e-5)
        self.dropout2=nn.Dropout(p=dropout)
    def forward(self,x):
        x_res=x
        print("------- ATTENTION 1 ------")
        x=self.attention(x)
        print("------- DROPOUT 1 ------")
        x=self.dropout1(x)
        print("------- ADD AND LAYER NORMALIZATION 1 ------")
        x=self.norm1(x+x_res)
        x_res=x
        x=self.ffn(x)
        print("------- DROPOUT 2 ------")
        x=self.dropout2(x)
        print("------- ADD AND LAYER NORMALIZATION 2 ------")
        x=self.norm2(x+x_res)
        return x
        
        

In [33]:
class Encoder(nn.Module):
    def __init__(self,model_dim,hidden_dim,num_head,dropout,num_layers):
        super().__init__()
        self.layers=nn.Sequential(*[EncoderLayer(model_dim,hidden_dim,num_head,dropout) for _ in range(num_layers)])
    def forward(self,x):
        x=self.layers(x)
        return x

In [83]:
class MultiheadCrossAttention(nn.Module):
    def __init__(self,model_dim,num_heads):
        super().__init__()
        self.model_dim=model_dim
        self.num_heads=num_heads
        self.head_dim=model_dim//num_heads
        self.kv_layer=nn.Linear(model_dim,2*model_dim)
        self.q_layer=nn.Linear(model_dim,model_dim)
        self.linear=nn.Linear(model_dim,model_dim)
    def forward(self,x,y):
        batch_size,seq_length,model_dim=x.size()
        kv=self.kv_layer(x)
        q=self.q_layer(y)
        kv=kv.reshape(batch_size,seq_length,self.num_heads,2*self.head_dim)
        q=q.reshape(batch_size,seq_length,self.num_heads,self.head_dim)
        kv=kv.permute(0,2,1,3)
        q=q.permute(0,2,1,3)
        k,v=kv.chunk(2,dim=-1)
        values,attention=scaled_dot_product(q,k,v,mask=True)
        values=values.reshape(batch_size,seq_length,model_dim)
        out=self.linear(values)
        return out

In [89]:
class DecoderLayer(nn.Module):
    def __init__(self,model_dim,hidden_dim,num_heads,dropout):
        super(DecoderLayer,self).__init__()
        self.model_dim=model_dim
        self.hidden_dim=hidden_dim
        self.num_heads=num_head
        self.dropout=dropout
        self.attention=MultiHeadAttention(model_dim,model_dim,num_heads)
        self.dropout1=nn.Dropout(p=dropout)
        self.norm1=LayerNormalisation([model_dim],1e-5)
        self.encoder_decoder_attention=MultiheadCrossAttention(model_dim,num_heads)
        self.dropout2=nn.Dropout(p=dropout)
        self.norm2=LayerNormalisation([model_dim],1e-5)
        self.ffn=FeedForwardLayer(model_dim,model_dim,dropout)
        self.dropout3=nn.Dropout(p=dropout)
        self.norm3=LayerNormalisation([model_dim],1e-5)
    def forward(self,x,y):
        y_res=y
        
        print("MASKED SELF ATTENTION")
        y=self.attention(y,mask=True)
        print("DROP OUT 1")
        y=self.dropout1(y)
        print("ADD + LAYER NORMALIZATION 1")
        y=self.norm1(y+y_res)
        
        y_res=y
        
        print("CROSS ATTENTION")
        y=self.encoder_decoder_attention(x,y)
        print("DROP OUT 2") 
        y=self.dropout2(y)
        print("ADD + LAYER NORMALIZATION 2")
        y=self.norm2(y+y_res)
        
        print("FEED FORWARD 1")
        y=self.ffn(y)
        print("DROP OUT 3")
        y=self.dropout3(y)
        print("ADD + LAYER NORMALIZATION 3")
        y=self.norm3(y+y_res)
        return y


        

In [90]:
class SequentialDecoder(nn.Sequential):
    def forward(self,*inputs):
        x,y=inputs
        for module in self._modules.values():
            y=module(x,y)
            return y
        

In [91]:
class Decoder(nn.Module):
    def __init__(self,model_dim,hidden_dim,num_head,dropout,num_layers=1):
        super().__init__()
        self.layers=SequentialDecoder(*[DecoderLayer(model_dim,hidden_dim,num_head,dropout) for _ in range(num_layers)])
    def forward(self,x,y):
        y = self.layers(x, y)
        return y
    

In [55]:
model_dim=512
hidden_dim=2048
num_head=8
dropout=0.1
num_layers=5
seq_length=200
batch_size=30

In [34]:
encoder=Encoder(model_dim,hidden_dim,num_head,dropout,num_layers)

In [45]:
x=torch.randn((batch_size,seq_length,model_dim))
y=torch.randn((batch_size,seq_length,model_dim))

In [39]:
x.shape

torch.Size([30, 200, 512])

In [35]:
out=encoder(x)

------- ATTENTION 1 ------
------- DROPOUT 1 ------
------- ADD AND LAYER NORMALIZATION 1 ------
------- DROPOUT 2 ------
------- ADD AND LAYER NORMALIZATION 2 ------
------- ATTENTION 1 ------
------- DROPOUT 1 ------
------- ADD AND LAYER NORMALIZATION 1 ------
------- DROPOUT 2 ------
------- ADD AND LAYER NORMALIZATION 2 ------
------- ATTENTION 1 ------
------- DROPOUT 1 ------
------- ADD AND LAYER NORMALIZATION 1 ------
------- DROPOUT 2 ------
------- ADD AND LAYER NORMALIZATION 2 ------
------- ATTENTION 1 ------
------- DROPOUT 1 ------
------- ADD AND LAYER NORMALIZATION 1 ------
------- DROPOUT 2 ------
------- ADD AND LAYER NORMALIZATION 2 ------
------- ATTENTION 1 ------
------- DROPOUT 1 ------
------- ADD AND LAYER NORMALIZATION 1 ------
------- DROPOUT 2 ------
------- ADD AND LAYER NORMALIZATION 2 ------


In [36]:
out

tensor([[[-9.5982e-01,  3.6257e+00, -2.9835e-01,  ..., -9.8806e-01,
          -2.0667e-01,  6.8550e-01],
         [-9.9231e-01, -6.8306e-01,  8.2421e-01,  ...,  1.5110e+00,
           5.9254e-01,  8.6756e-01],
         [-5.4171e-01,  1.0487e+00,  8.6307e-01,  ...,  1.5231e+00,
           8.1216e-01,  2.8084e-01],
         ...,
         [-1.3760e+00,  2.0386e-01,  1.1781e+00,  ...,  5.0603e-01,
          -1.3255e+00, -1.5862e+00],
         [-1.1459e-01,  1.3386e+00,  9.9571e-01,  ..., -8.8185e-01,
           7.6200e-01,  5.4307e-01],
         [ 3.4250e-01,  9.9287e-01,  5.0129e-01,  ..., -8.7733e-01,
          -1.3725e+00,  1.1900e+00]],

        [[ 2.1631e-01,  3.9607e-01,  3.0075e+00,  ...,  4.5226e-01,
          -9.1189e-01,  1.4128e+00],
         [ 6.4642e-01,  3.5981e-01,  7.7689e-01,  ..., -1.4002e-01,
           6.4884e-01, -6.3119e-01],
         [-1.9533e-01,  9.8765e-01,  1.5618e-01,  ...,  3.0139e-01,
          -8.6092e-01, -3.5307e-01],
         ...,
         [-1.3435e+00,  1

In [37]:
out.shape

torch.Size([30, 200, 512])

In [92]:
decoder=Decoder(model_dim,hidden_dim,num_head,dropout,num_layers)

In [93]:
out=decoder(x,y)

MASKED SELF ATTENTION
DROP OUT 1
ADD + LAYER NORMALIZATION 1
CROSS ATTENTION
DROP OUT 2
ADD + LAYER NORMALIZATION 2
FEED FORWARD 1
DROP OUT 3
ADD + LAYER NORMALIZATION 3


In [95]:
out.shape

torch.Size([30, 200, 512])