In [1]:
import numpy as np

#DL Framework imports
import torch
import torch.nn as nn
import torch.nn.functional as F

import math,copy,time
import matplotlib.pyplot as plt



In [3]:
class EncoderDecoder(nn.Module):
    """Generic encoder-decoder wrapper.

    Embeds the source sequence and encodes it into a "memory". Then it embeds
    the target sequence and decodes it against that memory.

    Args:
        encoder: callable ``(embedded_src, src_mask) -> memory``.
        decoder: callable ``(embedded_tgt, memory, src_mask, tgt_mask) -> output``.
        cur_embed: embedding applied to the source (``curr``) sequence.
        target_embed: embedding applied to the target sequence.
        generator: stored as ``self.generator``; not applied by this class.
            Call it on the decoder output when projecting to the vocabulary.
    """
    
    def __init__(self,encoder,decoder,cur_embed,target_embed,generator):
        super().__init__()
        
        self.encoder=encoder
        self.decoder=decoder
        
        self.cur_embed=cur_embed
        self.target_embed=target_embed
        self.generator=generator
        
    def forward(self,curr,trg,curr_mask,trg_mask):
        """Encode ``curr``, then decode ``trg`` against it; returns the decoder output.

        Bug fix: this previously referenced the undefined names ``tgt`` and
        ``tgt_mask`` instead of the ``trg`` / ``trg_mask`` parameters.
        """
        memory=self.encode(curr,curr_mask)
        return self.decode(memory,curr_mask,trg,trg_mask)
    
    def encode(self,curr,curr_mask):
        """Embed the source sequence and run it through the encoder."""
        return self.encoder(self.cur_embed(curr),curr_mask)
    
    def decode(self,memory, curr_mask,tgt,tgt_mask):
        """Embed the target sequence and run the decoder over it and ``memory``."""
        # Bug fix: the embedding attribute is ``target_embed`` and the source
        # mask parameter is ``curr_mask``. The code previously used
        # ``self.tgt_embed`` / ``src_mask``, and neither exists.
        return self.decoder(self.target_embed(tgt),memory,curr_mask,tgt_mask)
    

In [4]:
class Generator(nn.Module):
    """Linear projection to the vocabulary followed by log-softmax.

    Args:
        d_model: size of the last dimension of the input.
        vocab: number of output classes (vocabulary size).
    """
    
    def __init__(self,d_model,vocab):
        super().__init__()
        # Bug fix: the layer class is ``nn.Linear``; ``nn.linear`` does not exist.
        self.proj=nn.Linear(d_model,vocab)
        
    def forward(self,inp):
        """Return log-probabilities over the last dimension of ``proj(inp)``.

        Bug fix: this previously projected the undefined name ``x`` instead of ``inp``.
        """
        return F.log_softmax(self.proj(inp),dim=-1)
    
    

In [5]:
def clones(module,N):
    """Build an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``.

    Each copy owns its own parameters. None are shared with each other or with
    ``module``.
    """
    copies=[]
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return nn.ModuleList(copies)

In [6]:
class Encoder(nn.Module):
    """Stack of ``N`` copies of an encoder layer followed by a final ``LayerNorm``.

    Args:
        layer: prototype layer, copied ``N`` times via ``clones``. It must expose
            ``.size``, which sets the width of the final normalisation.
        N: number of stacked layers.
    """
    
    def __init__(self,layer,N):
        super(Encoder,self).__init__()
        
        self.layers=clones(layer,N)
        self.norm=LayerNorm(layer.size)
    
    def forward(self,x,mask):
        """Feed ``x`` through each layer in order (each receives ``mask``), then normalise."""
        hidden=x
        for enc_layer in self.layers:
            hidden=enc_layer(hidden,mask)
        return self.norm(hidden)

In [7]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension, with a learnable gain and bias.

    Computes ``a_2 * (x - mean) / (std + eps) + b_2``. ``mean`` and ``std`` are
    taken over the last dimension.

    This differs slightly from ``torch.nn.LayerNorm`` in two ways:
    - ``x.std`` here is the unbiased (Bessel-corrected) standard deviation.
    - ``eps`` is added to ``std`` rather than to the variance.

    Args:
        features: size of the last dimension (length of the gain/bias vectors).
        eps: small constant added to the denominator for numerical stability.
    """
    
    def __init__(self,features,eps=1e-6):
        super().__init__()
        self.a_2=nn.Parameter(torch.ones(features))
        self.b_2=nn.Parameter(torch.zeros(features))
        self.eps=eps
        
    def forward(self,x):
        mean=x.mean(-1,keepdim=True)
        std=x.std(-1,keepdim=True)
        # Bug fix: the denominator is std + eps. It was (x + std), which divided
        # by the input itself and never used eps.
        return self.a_2*(x-mean)/(std+self.eps)+self.b_2

In [8]:
class SubLayerConnection(nn.Module):
    """Pre-norm residual connection: ``x + dropout(sublayer(norm(x)))``.

    Args:
        size: feature size passed to ``LayerNorm``.
        dropout: dropout probability applied to the sub-layer output.
    """
    
    def __init__(self,size,dropout):
        super().__init__()
        
        # Bug fix: this attribute was misspelled ``droput``, so ``forward``
        # raised AttributeError when it read ``self.dropout``.
        self.dropout=nn.Dropout(dropout)
        self.norm=LayerNorm(size)
        
    def forward(self,x,sublayer):
        """Apply ``sublayer`` to the normalised input and add the result back to ``x``.

        ``sublayer`` is any callable taking one tensor. Its output must be
        shape-compatible with ``x`` for the residual addition.
        """
        return x+self.dropout(sublayer(self.norm(x)))
    

In [9]:
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention, then a feed-forward sub-layer.

    Each sub-layer is wrapped in its own ``SubLayerConnection``.

    Args:
        size: feature size. Stored as ``self.size``, which ``Encoder`` reads.
        self_attn: callable ``(query, key, value, mask) -> tensor``.
        feed_forward: callable ``(tensor) -> tensor``.
        dropout: dropout probability for the sub-layer connections.
    """
    
    def __init__(self,size,self_attn,feed_forward,dropout):
        super().__init__()
        
        # Bug fix: stored as ``self_attn`` (was ``attn``), matching what
        # ``forward`` reads.
        self.self_attn=self_attn
        self.feed_forward=feed_forward
        # Bug fix: the class is ``SubLayerConnection``; the old
        # ``SublayerConnection`` spelling raised NameError.
        self.sublayer=clones(SubLayerConnection(size,dropout),2)
        self.size=size
        
    def forward(self,x,mask):
        """Apply masked self-attention, then the feed-forward sub-layer."""
        x=self.sublayer[0](x,lambda y: self.self_attn(y,y,y,mask))
        # Bug fix: the attribute is ``feed_forward``, not ``feedforward``.
        return self.sublayer[1](x,self.feed_forward)
        

In [10]:
class Decoder(nn.Module):
    """Stack of ``N`` copies of a decoder layer followed by a final ``LayerNorm``.

    Args:
        layer: prototype layer, copied ``N`` times via ``clones``. It must expose
            ``.size``, which sets the width of the final normalisation.
        N: number of stacked layers.
    """
    
    def __init__(self,layer,N):
        super(Decoder,self).__init__()
        
        self.layers=clones(layer,N)
        self.norm=LayerNorm(layer.size)
    
    def forward(self,x,memory,curr_mask,tgt_mask):
        """Feed ``x`` through each layer in order, then normalise.

        Every layer receives the same ``memory``, ``curr_mask`` and ``tgt_mask``.
        """
        hidden=x
        for dec_layer in self.layers:
            hidden=dec_layer(hidden,memory,curr_mask,tgt_mask)
        return self.norm(hidden)
    