### My implementation of Transformer elements ([Attention Is All You Need](https://arxiv.org/pdf/1706.03762.pdf))

![attention](attention.png)


In [92]:
import torch
import torch.nn as nn
import numpy as np

In [93]:
class FFN(nn.Module):
    """Feed-forward sub-layer: ``Linear -> ReLU -> Linear``.

    Args:
        d_model: Feature size of the input and of the output.
        d_hidden: Feature size of the intermediate representation.
    """

    def __init__(self, d_model, d_hidden):
        super().__init__()
        # Layers are created in the order linear1, linear2 so that seeded
        # parameter initialisation and parameter ordering stay the same.
        self.linear1 = nn.Linear(d_model, d_hidden)
        self.linear2 = nn.Linear(d_hidden, d_model)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Expand ``x`` to ``d_hidden``, apply ReLU, project back to ``d_model``."""
        expanded = self.linear1(x)
        rectified = self.activation(expanded)
        return self.linear2(rectified)

In [94]:
class AddAndNorm(nn.Module):
    """Residual addition followed by layer normalisation.

    Args:
        d_model: Size of the last dimension that ``nn.LayerNorm`` normalises.
    """

    def __init__(self, d_model):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, y):
        """Return ``LayerNorm(x + y)``."""
        residual_sum = x + y
        return self.norm(residual_sum)

In [95]:
def Attention(query, key, values):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    The last two dimensions of every argument are treated as
    ``(sequence, features)``. Any leading dimensions (for example a batch
    dimension) are handled by ``torch.matmul`` broadcasting, so plain 2-D
    inputs and batched inputs are both accepted. For 2-D inputs the result
    is the same as indexing with ``size(1)``, ``key.T`` and ``dim=1``.

    Args:
        query: Tensor of shape ``(..., len_q, d_k)``.
        key: Tensor of shape ``(..., len_k, d_k)``.
        values: Tensor of shape ``(..., len_k, d_v)``.

    Returns:
        Tensor of shape ``(..., len_q, d_v)``.
    """
    dk = query.size(-1)
    # Swap only the last two axes; ``.T`` would reverse *all* axes of a
    # batched tensor.
    logits = torch.matmul(query, key.transpose(-2, -1)) / (dk ** 0.5)
    # Normalise over the key positions (last axis).
    weights = nn.functional.softmax(logits, dim=-1)

    return torch.matmul(weights, values)

In [96]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention with one bias-free projection triple per head.

    Each head projects query/key/values from ``d_model`` down to
    ``d_model // heads`` features, runs ``Attention`` on the projections,
    and the head outputs are concatenated along the feature axis before the
    final ``WO`` projection.

    Args:
        d_model: Feature size of the inputs and of the output.
        heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide
            ``d_model``. Without this check the concatenated head outputs
            would not have ``d_model`` features and ``WO`` would fail with a
            shape error at forward time.
    """

    def __init__(self, d_model, heads):
        super().__init__()
        if heads < 1 or d_model % heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive number of heads ({heads})"
            )
        self.d_model = d_model
        self.heads = heads

        d_head = d_model // heads
        # Creation order WQ -> WK -> WV -> WO is kept so seeded initialisation
        # and parameter ordering are unchanged.
        self.WQ = nn.ModuleList([nn.Linear(d_model, d_head, bias=False) for _ in range(heads)])
        self.WK = nn.ModuleList([nn.Linear(d_model, d_head, bias=False) for _ in range(heads)])
        self.WV = nn.ModuleList([nn.Linear(d_model, d_head, bias=False) for _ in range(heads)])
        self.WO = nn.Linear(d_model, d_model, bias=False)

    def forward(self, query, key, values):
        """Run every head on the inputs and merge the results with ``WO``."""
        head_outputs = [
            Attention(w_q(query), w_k(key), w_v(values))
            for w_q, w_k, w_v in zip(self.WQ, self.WK, self.WV)
        ]

        # Concatenate on the feature (last) axis. For 2-D ``(seq, features)``
        # inputs this is the same axis as ``dim=1``; whether inputs with extra
        # leading dimensions work also depends on ``Attention`` accepting them.
        merged = torch.cat(head_outputs, dim=-1)

        return self.WO(merged)
        

            

In [97]:
class CreateQKV(nn.Module):
    """Produce query, key and value projections of a single input.

    Args:
        d_model: Feature size of the input and of each projection.
    """

    def __init__(self, d_model):
        super().__init__()

        # Three independent bias-free square projections.
        self.WQ = nn.Linear(d_model, d_model, bias=False)
        self.WK = nn.Linear(d_model, d_model, bias=False)
        self.WV = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        """Return the tuple ``(WQ(x), WK(x), WV(x))``."""
        projections = (self.WQ, self.WK, self.WV)
        return tuple(project(x) for project in projections)

In [98]:
class CreateCrossQKV(nn.Module):
    """Produce cross-attention projections from two inputs.

    The query is projected from ``x``; the key and the value are both
    projected from ``encoder_hidden``.

    Args:
        d_model: Feature size of the inputs and of each projection.
    """

    def __init__(self, d_model):
        super().__init__()

        # Three independent bias-free square projections.
        self.WQ = nn.Linear(d_model, d_model, bias=False)
        self.WK = nn.Linear(d_model, d_model, bias=False)
        self.WV = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, encoder_hidden):
        """Return ``(WQ(x), WK(encoder_hidden), WV(encoder_hidden))``."""
        query = self.WQ(x)
        key = self.WK(encoder_hidden)
        value = self.WV(encoder_hidden)
        return query, key, value

In [99]:
class EncoderModule(nn.Module):
    """One encoder layer: self-attention and feed-forward, each wrapped in
    dropout + residual + layer norm.

    Args:
        d_model: Feature size of the layer's input and output.
        heads: Number of attention heads passed to ``MultiHeadAttention``.
        d_hidden: Hidden size passed to ``FFN``.
        dropout: Dropout probability applied to each sub-layer output before
            the residual addition. Defaults to ``0.1``, the value that was
            previously hard-coded.
    """

    def __init__(self, d_model, heads, d_hidden, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.heads = heads
        self.d_hidden = d_hidden

        self.multiheadattention = MultiHeadAttention(self.d_model, self.heads)
        self.feedforward = FFN(self.d_model, self.d_hidden)
        self.addnorm1 = AddAndNorm(self.d_model)
        self.addnorm2 = AddAndNorm(self.d_model)

        self.qkv = CreateQKV(self.d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply self-attention then the feed-forward network to ``x``."""
        # ``self.qkv`` returns (query, key, value) in that order; the locals
        # are named accordingly (they used to be unpacked as ``q, v, k``,
        # which worked positionally but mislabelled key and value).
        q, k, v = self.qkv(x)
        attn = self.multiheadattention(q, k, v)
        attn_out = self.addnorm1(x, self.dropout(attn))

        ffn_out = self.feedforward(attn_out)

        return self.addnorm2(attn_out, self.dropout(ffn_out))
        



        
        


In [100]:
def MaskedAttention(query, key, value):
    """Causally masked scaled dot-product attention.

    Query position ``i`` may only attend to key positions ``j <= i``; every
    score with ``j > i`` is set to ``-inf`` before the softmax. The last two
    dimensions of each argument are treated as ``(sequence, features)`` and
    any leading dimensions are broadcast by ``torch.matmul``, so 2-D and
    batched inputs are both accepted.

    Args:
        query: Tensor of shape ``(..., len_q, d_k)``.
        key: Tensor of shape ``(..., len_k, d_k)``.
        value: Tensor of shape ``(..., len_k, d_v)``.

    Returns:
        Tensor of shape ``(..., len_q, d_v)``.
    """
    dk = query.size(-1)
    logits = torch.matmul(query, key.transpose(-2, -1)) / (dk ** 0.5)

    # Build the "future position" mask in one shot instead of looping over
    # rows in Python and writing into the score matrix in place.
    len_q, len_k = logits.shape[-2:]
    row_idx = torch.arange(len_q, device=logits.device).unsqueeze(1)
    col_idx = torch.arange(len_k, device=logits.device).unsqueeze(0)
    future = col_idx > row_idx  # (len_q, len_k), broadcasts over leading dims
    logits = logits.masked_fill(future, float("-inf"))

    weights = nn.functional.softmax(logits, dim=-1)

    return torch.matmul(weights, value)
    

In [101]:
class MaskedMultiHeadAttention(nn.Module):
    """Multi-head attention whose heads use ``MaskedAttention``.

    Each head projects query/key/values from ``d_model`` down to
    ``d_model // heads`` features with its own bias-free linear layers, runs
    ``MaskedAttention`` on the projections, and the head outputs are
    concatenated along the feature axis before the final ``WO`` projection.

    Args:
        d_model: Feature size of the inputs and of the output.
        heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide
            ``d_model``. Without this check the concatenated head outputs
            would not have ``d_model`` features and ``WO`` would fail with a
            shape error at forward time.
    """

    def __init__(self, d_model, heads):
        super().__init__()
        if heads < 1 or d_model % heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive number of heads ({heads})"
            )
        self.d_model = d_model
        self.heads = heads

        d_head = d_model // heads
        # Creation order WQ -> WK -> WV -> WO is kept so seeded initialisation
        # and parameter ordering are unchanged.
        self.WQ = nn.ModuleList([nn.Linear(d_model, d_head, bias=False) for _ in range(heads)])
        self.WK = nn.ModuleList([nn.Linear(d_model, d_head, bias=False) for _ in range(heads)])
        self.WV = nn.ModuleList([nn.Linear(d_model, d_head, bias=False) for _ in range(heads)])
        self.WO = nn.Linear(d_model, d_model, bias=False)

    def forward(self, query, key, values):
        """Run every masked head on the inputs and merge the results with ``WO``."""
        head_outputs = [
            MaskedAttention(w_q(query), w_k(key), w_v(values))
            for w_q, w_k, w_v in zip(self.WQ, self.WK, self.WV)
        ]

        # Concatenate on the feature (last) axis. For 2-D ``(seq, features)``
        # inputs this is the same axis as ``dim=1``; whether inputs with extra
        # leading dimensions work also depends on ``MaskedAttention``.
        merged = torch.cat(head_outputs, dim=-1)

        return self.WO(merged)

In [102]:
class DecoderModule(nn.Module):
    """One decoder layer: masked self-attention, cross-attention over the
    encoder state, and a feed-forward network, each wrapped in
    dropout + residual + layer norm.

    Args:
        d_model: Feature size of the layer's input and output.
        heads: Number of attention heads for both attention sub-layers.
        d_hidden: Hidden size passed to ``FFN``.
        dropout: Dropout probability applied to each sub-layer output before
            the residual addition. Defaults to ``0.1``, the value that was
            previously hard-coded.
    """

    def __init__(self, d_model, heads, d_hidden, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.heads = heads
        self.d_hidden = d_hidden

        self.maskedattention = MaskedMultiHeadAttention(self.d_model, self.heads)
        self.addnorm1 = AddAndNorm(self.d_model)

        # NOTE: the attribute name is misspelled ("crossattetnion") but is
        # kept as-is because renaming it would change state_dict keys.
        self.crossattetnion = MultiHeadAttention(self.d_model, self.heads)
        self.addnorm2 = AddAndNorm(self.d_model)

        self.ffn = FFN(self.d_model, self.d_hidden)
        self.addnorm3 = AddAndNorm(self.d_model)

        self.dropout = nn.Dropout(dropout)

        self.qkv = CreateQKV(self.d_model)
        self.crossqkv = CreateCrossQKV(self.d_model)

    def forward(self, x, encoder_state):
        """Run the three decoder sub-layers on ``x`` using ``encoder_state``.

        Args:
            x: Decoder-side input.
            encoder_state: Encoder output; source of the cross-attention
                keys and values.

        Returns:
            The output of the final add-and-norm step.
        """
        # 1) Masked self-attention over the decoder input.
        q, k, v = self.qkv(x)
        maskedattn = self.maskedattention(q, k, v)
        mask_output = self.addnorm1(x, self.dropout(maskedattn))

        # 2) Cross-attention: queries from the decoder, keys/values from the
        #    encoder state. (Debug shape prints that used to live here were
        #    removed so the forward pass no longer writes to stdout.)
        q2, k2, v2 = self.crossqkv(mask_output, encoder_state)
        crossattn = self.crossattetnion(q2, k2, v2)
        cross_output = self.addnorm2(mask_output, self.dropout(crossattn))

        # 3) Position-wise feed-forward network.
        ffn_output = self.ffn(cross_output)

        return self.addnorm3(cross_output, self.dropout(ffn_output))
            
