In [7]:
import math
import torch 
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable

In [8]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [9]:
# constants: Transformer hyper-parameters
N = 6  # number of stacked layers
d_model = 512  # output size of every sub-layer and of the embedding layer
h = 8  # number of heads for multi-head attention
assert d_model % h == 0, "d_model must be divisible by the number of heads"
# Integer division: these are tensor sizes, and `d_model / h` would give the float 64.0.
dk = d_model // h  # per-head query/key size
dv = d_model // h  # per-head value size

# Layer Norm

In [43]:
class layer_norm(nn.Module):
    """Layer normalization over the last (feature) dimension.

    Each feature vector is shifted to zero mean and scaled to unit variance, then
    multiplied by the learnable ``gamma`` and shifted by the learnable ``beta``.
    Semantics match ``nn.LayerNorm(n_f, eps=eps)`` (biased variance, eps inside the sqrt).

    Args:
        n_f: size of the last dimension of the input.
        eps: small constant added to the variance for numerical stability.
    """

    def __init__(self, n_f, eps=1e-9):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(n_f))
        self.beta = nn.Parameter(torch.zeros(n_f))

    def forward(self, inputs):
        """Normalize ``inputs`` along dim -1; the output has the same shape as the input."""
        mean = inputs.mean(dim=-1, keepdim=True)
        centered = inputs - mean
        # Divide by the standard deviation (the previous code divided by the mean).
        var = (centered ** 2).mean(dim=-1, keepdim=True)
        normalized = centered / torch.sqrt(var + self.eps)
        return normalized * self.gamma + self.beta

### Testing

In [44]:
# torch.FloatTensor(*sizes) returns *uninitialised* memory (may hold NaN/inf),
# so build a seeded random input instead.
torch.manual_seed(0)
a = torch.randn(64, 20, 512)
model = layer_norm(a.size(-1))
c = model(a)
assert c.shape == a.shape
c.size()

torch.Size([64, 20, 512])

# Scaled Dot Product Attention

In [45]:
class scaled_dot_product_attention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(dk)) V.

    Args:
        dk: query/key feature size; only used for the 1/sqrt(dk) scaling.
        dv: value feature size; stored for reference, not used in the computation.
    """

    def __init__(self, dk=64, dv=64):
        super().__init__()
        self.dk = dk
        self.dv = dv

    def forward(self, q, k, v):
        """Attend ``q`` over ``k``/``v``.

        Shapes: q ``(..., len_q, dk)``, k ``(..., len_k, dk)``, v ``(..., len_k, dv)``.
        Any number of leading (batch / head) dims is allowed as long as they broadcast.
        Returns a tensor of shape ``(..., len_q, dv)``.
        """
        # transpose(-2, -1) / dim=-1 instead of permute(0, 2, 1) / dim=2 so that
        # per-head 4-D inputs (batch, heads, seq, d) work as well as 3-D ones.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.dk)
        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, v)
        

### Testing

In [46]:
# Seeded random input: torch.FloatTensor(*sizes) would be uninitialised memory.
torch.manual_seed(0)
a = torch.randn(64, 20, 64)
model = scaled_dot_product_attention()
out = model(a, a, a)
assert out.shape == a.shape
out.size()

torch.Size([64, 20, 64])

In [47]:
# utility module

class linear_activated(nn.Module):
    """``nn.Linear`` followed by an activation, optionally chunked along the feature axis.

    Args:
        in_out: ``(in_features, out_features)`` tuple forwarded to ``nn.Linear``.
        activation: activation *class* (not instance), instantiated with no arguments.
        split: if not None, the activated output is split into this many equal chunks
            along the last dimension and returned as a tuple.
    """

    def __init__(self, in_out, activation=nn.ReLU, split=None):
        super().__init__()
        self.linear = nn.Linear(*in_out)
        self.activation = activation()
        self.split = split

    def forward(self, inputs):
        out = self.activation(self.linear(inputs))
        if self.split is None:
            return out
        # Chunk the last (feature) axis so inputs of any rank work, not only 3-D ones.
        return out.chunk(self.split, dim=-1)
        

In [48]:
# Seeded random input: torch.FloatTensor(*sizes) would be uninitialised memory.
torch.manual_seed(0)
a = torch.randn(64, 20, 512)
model = linear_activated((512, 64))
c = model(a)
assert c.shape == (64, 20, 64)
c.size()

torch.Size([64, 20, 64])

In [49]:
# Seeded random input: torch.FloatTensor(*sizes) would be uninitialised memory.
torch.manual_seed(0)
a = torch.randn(64, 20, 512)
model = linear_activated((512, 512 * 3), split=3)
c = model(a)
assert len(c) == 3
for chunk in c:
    print(chunk.size())

torch.Size([64, 20, 512])
torch.Size([64, 20, 512])
torch.Size([64, 20, 512])


# Multi-head attention

In [55]:
class multi_head_attention(nn.Module):
    """Multi-head attention.

    Queries, keys and values are linearly projected into ``h`` heads (``dk``/``dk``/``dv``
    features per head). Scaled dot-product attention runs independently in each head,
    and the concatenated head outputs are linearly projected back to ``dmodel``.

    Args:
        h: number of attention heads.
        dmodel: input/output feature size.
        dk: per-head query/key size.
        dv: per-head value size.
        self_attention: if True, a single fused projection computes Q, K and V from the
            ``k`` argument of ``forward``; ``q`` and ``v`` are then ignored.
    """

    def __init__(self, h=8, dmodel=512, dk=64, dv=64, self_attention=False):
        super().__init__()
        self.self_attention = self_attention
        self.h = h
        self.dk = dk
        self.dv = dv
        # The Q/K/V/O projections are purely linear, so replace linear_activated's
        # default ReLU with nn.Identity.
        if self_attention:
            # method 1: one fused projection produces Q, K and V together. It is split
            # manually in forward() so that dk != dv is also handled.
            self.linear_kqv = linear_activated((dmodel, h * (2 * dk + dv)), activation=nn.Identity)
        else:
            # method 2: separate projections. The output width is multiplied by h so all
            # heads are produced by a single matmul.
            self.linear_k = linear_activated((dmodel, dk * h), activation=nn.Identity)
            self.linear_q = linear_activated((dmodel, dk * h), activation=nn.Identity)
            self.linear_v = linear_activated((dmodel, dv * h), activation=nn.Identity)

        self.linear_o = linear_activated((h * dv, dmodel), activation=nn.Identity)

        # Attention is computed per head, so the 1/sqrt(dk) scaling uses the per-head size.
        self.sdpa = scaled_dot_product_attention(dk, dv)

    def _split_heads(self, x, d):
        """(batch, seq, h*d) -> (batch*h, seq, d): fold heads into the batch axis."""
        batch, seq, _ = x.size()
        return x.reshape(batch, seq, self.h, d).transpose(1, 2).reshape(batch * self.h, seq, d)

    def _merge_heads(self, x, batch):
        """(batch*h, seq, d) -> (batch, seq, h*d): concatenate the head outputs."""
        _, seq, d = x.size()
        return x.reshape(batch, self.h, seq, d).transpose(1, 2).reshape(batch, seq, self.h * d)

    def forward(self, q, k, v):
        """Inputs are ``(batch, seq, dmodel)``; returns ``(batch, len_q, dmodel)``."""
        if self.self_attention:
            q_p, k_p, v_p = torch.split(
                self.linear_kqv(k),
                [self.h * self.dk, self.h * self.dk, self.h * self.dv],
                dim=-1,
            )
        else:
            # Keep the (q, k, v) order; the previous code passed k and q swapped.
            q_p = self.linear_q(q)
            k_p = self.linear_k(k)
            v_p = self.linear_v(v)

        batch = q_p.size(0)
        # Heads are folded into the batch axis so the 3-D attention module handles all heads at once.
        heads_out = self.sdpa(
            self._split_heads(q_p, self.dk),
            self._split_heads(k_p, self.dk),
            self._split_heads(v_p, self.dv),
        )
        return self.linear_o(self._merge_heads(heads_out, batch))

### Testing

In [54]:
# Seeded random input: torch.FloatTensor(*sizes) would be uninitialised memory.
torch.manual_seed(0)
a = torch.randn(64, 20, 512)
model = multi_head_attention()
c = model(a, a, a)
assert c.shape == a.shape
c.size()

torch.Size([64, 20, 512])

In [56]:
# Seeded random input: torch.FloatTensor(*sizes) would be uninitialised memory.
torch.manual_seed(0)
a = torch.randn(64, 20, 512)
model = multi_head_attention(self_attention=True)
c = model(a, a, a)
assert c.shape == a.shape
c.size()

torch.Size([64, 20, 512])