In [378]:
import numpy as np

#DL Framework imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import math,copy,time
import matplotlib.pyplot as plt



In [414]:
class EncoderDecoder(nn.Module):
    
    def __init__(self,encoder,decoder,cur_embed,target_embed,generator):
        super().__init__()
        
        self.encoder=encoder
        self.decoder=decoder
        
        self.cur_embed=cur_embed
        self.target_embed=target_embed
        self.generator=generator
        
    def forward(self,curr,trg,curr_mask,trg_mask):
        return self.generator(self.decode(self.encode(curr,curr_mask),curr_mask,trg,trg_mask))
    
    def encode(self,curr,curr_mask):
        return self.encoder(self.cur_embed(curr),curr_mask)
    
    def decode(self,memory, curr_mask,tgt,tgt_mask):
        return self.decoder(self.target_embed(tgt),memory,curr_mask,tgt_mask)
    

In [415]:
class Generator(nn.Module):
    
    def __init__(self,d_model,vocab):
        super().__init__()
        self.proj=nn.Linear(d_model,vocab,bias=False)
        
    def forward(self,inp):
        x.type(torch.LongTensor)
        print(x.dtype)
        return self.proj(x)
    
    

In [416]:
def clones(module,N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

In [417]:
class Encoder(nn.Module):
    
    def __init__(self,layer,N):
        super().__init__()
        
        self.layers=clones(layer,N)
        self.norm=LayerNorm(layer.size)
    
    def forward(self,x,mask):
        
        for layer in self.layers:
            x=layer(x,mask)
        
        return self.norm(x)

In [418]:
class LayerNorm(nn.Module):
    
    def __init__(self,features,eps=1e-6):
        super().__init__()
        self.a_2=nn.Parameter(torch.ones(features))
        self.b_2=nn.Parameter(torch.zeros(features))
        self.eps=eps
        
    def forward(self,x):
        mean=x.mean(-1,keepdim=True)
        std=x.std(-1,keepdim=True)
        return self.a_2*(x-mean)/(x+std)+self.b_2

In [419]:
class SublayerConnection(nn.Module):
    
    def __init__(self,size,dropout):
        super().__init__()
        
        self.dropout=nn.Dropout(dropout)
        self.norm=LayerNorm(size)
        
    def forward(self,x,sublayer):
        return x+self.dropout(sublayer(self.norm(x)))
    

In [420]:
class EncoderLayer(nn.Module):
    def __init__(self,size,self_attn,feed_forward,dropout):
        super().__init__()
        
        self.attn=self_attn
        self.feed_forward=feed_forward
        self.sublayer=clones(SublayerConnection(size,dropout),2)
        self.size=size
        
    def forward(self,x,mask):
        
        x=self.sublayer[0](x,lambda x: self.attn(x,x,x,mask))
        return self.sublayer[1](x,self.feed_forward)
        

In [421]:
class Decoder(nn.Module):
    
    def __init__(self,layer,N):
        super().__init__()
        
        self.layers=clones(layer,N)
        self.norm=LayerNorm(layer.size)
    
    def forward(self,x,memory,curr_mask,tgt_mask):
        
        for layer in self.layers:
            x=layer(x,memory,curr_mask,tgt_mask)
            
        return self.norm(x)
    

In [422]:
class DecoderLayer(nn.Module):
    
    def __init__(self,size,self_attn,src_attn,feed_forward,dropout):
        super().__init__()
        
        self.size=size
        self.self_attn=self_attn
        self.src_attn=src_attn
        self.feed_forward=feed_forward
        
        self.sublayer=clones(SublayerConnection(size,dropout),3)
        
    def forward(self,x,memory,src_mask,tgt_mask):
        
        m=memory
        x=self.sublayer[0](x,lambda x:self.self_attn(x,x,x,tgt_mask))
        x=self.sublayer[1](x,lambda x: self.src_attn(x,m,m,src_mask))
        return self.sublayer[2](x,self.feed_forward)
        

In [423]:
def attention(query,key,value,mask=None,dropout=None):
    
    d_k=query.size(-1)
    print("query: "+str(query.size()))
    print("key: "+str(key.size()))
    scores=torch.matmul(query,key.transpose(-2,-1))/math.sqrt(d_k)
    
    if mask is not None:
        print(scores.size())
        print(mask.size())
        scores=scores.masked_fill(mask==0,-1e9)
        
    p_attn=F.softmax(scores,dim=-1)
    
    if dropout is not None:
        p_attn=dropout(p_attn)
        
    return torch.matmul(p_attn,value),p_attn
    

In [424]:
class MultiHeadedAttention(nn.Module):
    
    def __init__(self,h,d_model,dropout=0.1):
        super().__init__()
        
        assert d_model%h==0
        
        self.d_k=d_model//h
        self.h=h
        self.linears=clones(nn.Linear(d_model,d_model),4)
        self.attn=None
        self.dropout=nn.Dropout(dropout)
        
    def forward(self,query,key,values,mask=None):
        
        if mask is not None:
            mask=mask.unsqueeze(1)
            
        nbatches=query.size(0)
        
        query,key,values=[l(x).view(nbatches,-1,self.h,self.d_k).transpose(1,2) for l, x in zip(self.linears,(query,key,values))]
        
        x,self.attn=attention(query,key,values,mask=mask,dropout=self.dropout)
        
        x=x.transpose(1,2).contiguous().view(nbatches,-1,self.h*self.d_k)
        
        return self.linears[-1](x)
        

In [425]:
class PositionwiseFeedForward(nn.Module):
    
    def __init__(self,d_model,d_ff,dropout=0.1):
        super().__init__()
        
        self.w_1=nn.Linear(d_model,d_ff)
        self.w_2=nn.Linear(d_ff,d_model)
        self.dropout=nn.Dropout(dropout)
        
    def forward(self,x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))
    

In [426]:
class Embeddings(nn.Module):
    
    def __init__(self,d_model,vocab):
        super().__init__()
        
        self.embed=nn.Embedding(vocab,d_model)
        self.d_model=d_model
    
    def forward(self,x):
        return self.embed(x)*math.sqrt(self.d_model)


In [427]:
class PositionalEncoding(nn.Module):
    
    def __init__(self,d_model,dropout,max_len=5000):
        super().__init__()
        
        self.dropout=nn.Dropout(dropout)
        pe=torch.zeros(max_len,d_model,dtype=torch.float)
        position=torch.arange(0.,max_len).unsqueeze(1)
        div_term=torch.exp(torch.arange(0.,d_model,2)*-(math.log(10000.0)/d_model))
        
        pe[:,0::2]=torch.sin(position*div_term)
        pe[:,1::2]=torch.cos(position*div_term)
        
        pe=pe.unsqueeze(0)
        self.register_buffer('pe',pe)
        
    def forward(self,x):
        print(x.size())
        x=x+Variable(self.pe[:,:x.size(1)],requires_grad=False)
        return self.dropout(x)
        

In [428]:
def make_model(src_vocab,tgt_vocab,N=6,d_model=512,d_ff=2048,h=8,dropout=0.1):
    
    c=copy.deepcopy
    attn=MultiHeadedAttention(h,d_model)
    ff=PositionwiseFeedForward(d_model,d_ff,dropout)
    position=PositionalEncoding(d_model,dropout)
    model=EncoderDecoder(Encoder(EncoderLayer(d_model,c(attn),c(ff),dropout),N),
                        Decoder(DecoderLayer(d_model,c(attn),c(attn),c(ff),dropout),N),
                        nn.Sequential(Embeddings(d_model,src_vocab),c(position)),
                        nn.Sequential(Embeddings(d_model,tgt_vocab),c(position)),
                        Generator(d_model,tgt_vocab))
    
    for p in model.parameters():
        if p.dim()>1:
            nn.init.xavier_uniform_(p)
    
    return model

In [431]:
tmp_model=make_model(10,10)

TypeError: only floating-point types are supported as the default type

In [430]:
x=torch.ones(10,5,dtype=torch.long)
y=torch.ones(10,5,dtype=torch.long)
x_mask=torch.zeros(10,5,5,dtype=torch.long)
y_mask=torch.zeros(10,5,5,dtype=torch.long)

print(x.size())
out=tmp_model(x,y,x_mask,y_mask)
print("output size: "+str(out.size()))

torch.Size([10, 5])
torch.Size([10, 5, 512])
query: torch.Size([10, 8, 5, 64])
key: torch.Size([10, 8, 5, 64])
torch.Size([10, 8, 5, 5])
torch.Size([10, 1, 5, 5])
query: torch.Size([10, 8, 5, 64])
key: torch.Size([10, 8, 5, 64])
torch.Size([10, 8, 5, 5])
torch.Size([10, 1, 5, 5])
query: torch.Size([10, 8, 5, 64])
key: torch.Size([10, 8, 5, 64])
torch.Size([10, 8, 5, 5])
torch.Size([10, 1, 5, 5])
query: torch.Size([10, 8, 5, 64])
key: torch.Size([10, 8, 5, 64])
torch.Size([10, 8, 5, 5])
torch.Size([10, 1, 5, 5])
query: torch.Size([10, 8, 5, 64])
key: torch.Size([10, 8, 5, 64])
torch.Size([10, 8, 5, 5])
torch.Size([10, 1, 5, 5])
query: torch.Size([10, 8, 5, 64])
key: torch.Size([10, 8, 5, 64])
torch.Size([10, 8, 5, 5])
torch.Size([10, 1, 5, 5])
torch.Size([10, 5, 512])
query: torch.Size([10, 8, 5, 64])
key: torch.Size([10, 8, 5, 64])
torch.Size([10, 8, 5, 5])
torch.Size([10, 1, 5, 5])
query: torch.Size([10, 8, 5, 64])
key: torch.Size([10, 8, 5, 64])
torch.Size([10, 8, 5, 5])
torch.Size([

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'mat2'