# In [1]:
import torch
import torch.nn as nn

# In [5]:
class SelfAttention(nn.Module):
  """Multi-head dot-product attention.

  The embedding dimension is split evenly across `heads`; each head gets a
  slice of size `embed_size // heads`. The same per-head linear projections
  are shared by all heads, and the concatenated head outputs are passed
  through a final linear layer.

  Args:
    embed_size: Size of the last dimension of the inputs and of the output.
    heads: Number of attention heads. Must divide `embed_size`.

  Raises:
    ValueError: If `embed_size` is not divisible by `heads`.
  """

  def __init__(self,embed_size,heads):
    super(SelfAttention, self).__init__()
    if embed_size % heads != 0:
      raise ValueError(
          f"embed_size ({embed_size}) must be divisible by heads ({heads})"
      )
    self.embed_size = embed_size
    self.heads = heads
    self.heads_dim = embed_size // heads

    #Value, Key and Query projections, applied per head (last dim = heads_dim)
    self.value = nn.Linear(self.heads_dim,self.heads_dim, bias=False)
    self.key = nn.Linear(self.heads_dim,self.heads_dim, bias=False)
    self.query = nn.Linear(self.heads_dim,self.heads_dim, bias=False)
    self.fc_out = nn.Linear(heads*self.heads_dim, embed_size)

  def forward(self,value,key,query,mask):
    """Compute attention of `query` over `key`/`value`.

    Args:
      value: Tensor of shape (N, value_len, embed_size).
      key: Tensor of shape (N, key_len, embed_size); key_len must equal
        value_len for the weighted sum to be defined.
      query: Tensor of shape (N, query_len, embed_size).
      mask: Optional tensor broadcastable to (N, heads, query_len, key_len).
        Positions where `mask == 0` are excluded from attention. Pass
        `None` to attend everywhere.

    Returns:
      Tensor of shape (N, query_len, embed_size).
    """
    N = query.shape[0]
    value_len, key_len, query_len = value.shape[1], key.shape[1], query.shape[1]

    # Split the embedding into heads, then project each head's slice.
    value = self.value(value.reshape(N, value_len, self.heads, self.heads_dim))
    key = self.key(key.reshape(N, key_len, self.heads, self.heads_dim))
    query = self.query(query.reshape(N, query_len, self.heads, self.heads_dim))

    #query shape : N, query_len, heads, heads_dim
    #key shape : N, key_len, heads, heads_dim
    #energy shape : N, heads, query_len, key_len
    energy = torch.einsum('nqhd,nkhd->nhqk', query, key)

    if mask is not None:
      # A very large negative score becomes a zero weight after softmax.
      energy = energy.masked_fill(mask == 0, float('-1e20'))

    # NOTE: scaling by sqrt(embed_size) is kept from the original code; the
    # "Attention Is All You Need" formulation scales by sqrt(heads_dim).
    attention = torch.softmax(energy / (self.embed_size ** 0.5), dim=3)

    #attention shape : N, heads, query_len, key_len
    #value shape : N, value_len, heads, heads_dim
    #out shape : N, query_len, heads, heads_dim -> heads are concatenated
    out = torch.einsum('nhql,nlhd->nqhd', attention, value).reshape(
        N, query_len, self.heads*self.heads_dim
    )
    out = self.fc_out(out)
    return out


class TransformerBlock(nn.Module):
  """Attention sub-layer followed by a position-wise feed-forward sub-layer.

  Each sub-layer output is added to its input, layer-normalised, and then
  passed through dropout.

  Args:
    embed_size: Feature size handled by the layer norms and feed-forward net.
    heads: Number of heads handed to `SelfAttention`.
    dropout: Dropout probability applied after each layer norm.
    forward_expansion: Multiplier for the hidden width of the feed-forward
      network (hidden width = forward_expansion * embed_size).
  """

  def __init__(self,embed_size,heads,dropout,forward_expansion):
    super().__init__()
    hidden_size = forward_expansion*embed_size

    self.attention = SelfAttention(embed_size,heads)
    self.norm1 = nn.LayerNorm(embed_size)
    self.norm2 = nn.LayerNorm(embed_size)
    self.feed_forward = nn.Sequential(
        nn.Linear(embed_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, embed_size),
    )
    self.dropout = nn.Dropout(dropout)

  def forward(self,value,key,query,mask):
    """Run attention and the feed-forward net, each with a residual path.

    Args:
      value, key, query: Inputs forwarded unchanged to the attention module.
        `query` is also the residual added to the attention output, so it
        must be addable to that output.
      mask: Forwarded unchanged to the attention module.

    Returns:
      The result of the second residual + layer norm + dropout step.
    """
    # Sub-layer 1: attention with a residual connection back to the query.
    attended = self.attention(value,key,query,mask)
    x = self.norm1(attended + query)
    x = self.dropout(x)

    # Sub-layer 2: feed-forward with a residual connection back to x.
    transformed = self.feed_forward(x)
    return self.dropout(self.norm2(transformed + x))


class Encoder(nn.Module):
  """Token + position embedding followed by a stack of `TransformerBlock`s.

  Args:
    src_vocab_size: Number of rows in the token embedding table.
    embed_size: Embedding / model dimension.
    num_layers: Number of `TransformerBlock`s to stack.
    heads: Number of attention heads per block.
    device: Device on which the position indices are created.
    forward_expansion: Feed-forward expansion factor passed to each block.
    dropout: Dropout probability (embeddings and blocks).
    max_length: Number of rows in the position embedding table; input
      sequences must not be longer than this.
  """

  def __init__(self,src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length):
    super(Encoder,self).__init__()
    self.embed_size = embed_size
    self.device = device
    self.word_embedding = nn.Embedding(src_vocab_size,embed_size)
    self.position_embedding = nn.Embedding(max_length,embed_size)
    # One independent block per layer (not one block shared or a single block).
    self.layers = nn.ModuleList([
        TransformerBlock(embed_size,heads,dropout=dropout,forward_expansion=forward_expansion)
        for _ in range(num_layers)
    ])
    self.dropout = nn.Dropout(dropout)

  def forward(self,x,mask):
    """Encode a batch of token indices.

    Args:
      x: Integer tensor of shape (N, seq_length) holding token indices.
      mask: Mask forwarded to every block's attention (may be None).

    Returns:
      Tensor of shape (N, seq_length, embed_size).
    """
    N,seq_length = x.shape
    # Position ids 0..seq_length-1, repeated for every item in the batch.
    positions = torch.arange(0,seq_length,device=self.device).expand(N,seq_length)
    out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
    for layer in self.layers:
      # Self-attention: value, key and query are all the running output.
      out = layer(out,out,out,mask)
    return out
  
class DecoderBlock(nn.Module):
  """Masked self-attention on the target, then a `TransformerBlock`.

  The self-attention output (with residual, layer norm and dropout) becomes
  the query that the inner `TransformerBlock` uses against the supplied
  `value` and `key`.

  Args:
    embed_size: Feature size for the attention module and layer norm.
    heads: Number of attention heads.
    forward_expansion: Feed-forward expansion factor for the inner block.
    dropout: Dropout probability.
    device: Accepted but not used by this block.
  """

  def __init__(self,embed_size,heads,forward_expansion,dropout,device):
    super().__init__()
    self.attention = SelfAttention(embed_size,heads)
    self.norm = nn.LayerNorm(embed_size)
    self.transformer_block = TransformerBlock(embed_size,heads,dropout,forward_expansion)
    self.dropout = nn.Dropout(dropout)

  def forward(self,x,value,key,query,src_mask,trg_mask):
    """Run one decoder layer.

    Args:
      x: Target-side input; used as value, key and query of the first
        attention and as its residual.
      value, key: Passed to the inner `TransformerBlock`.
      query: Ignored. The incoming value is replaced by the normalised
        self-attention output before it is ever read.
      src_mask: Mask for the inner `TransformerBlock`.
      trg_mask: Mask for the self-attention over `x`.

    Returns:
      Output of the inner `TransformerBlock`.
    """
    # Step 1: self-attention over the target sequence.
    self_attended = self.attention(x,x,x,trg_mask)
    normed = self.norm(self_attended + x)
    query = self.dropout(normed)

    # Step 2: attend from that query over the supplied value/key.
    return self.transformer_block(value,key,query,src_mask)


class Decoder(nn.Module):
  """Token + position embedding, a stack of `DecoderBlock`s, and a final
  linear projection to vocabulary size.

  Args:
    trg_vocab_size: Number of rows in the token embedding table and size of
      the output projection.
    embed_size: Embedding / model dimension.
    num_layers: Number of `DecoderBlock`s to stack.
    heads: Number of attention heads per block.
    forward_expansion: Feed-forward expansion factor passed to each block.
    dropout: Dropout probability (embeddings and blocks).
    device: Device on which the position indices are created.
    max_length: Number of rows in the position embedding table; input
      sequences must not be longer than this.
  """

  def __init__(self,trg_vocab_size,embed_size,num_layers,heads,forward_expansion,dropout,device,max_length):
    super(Decoder,self).__init__()
    self.device = device
    self.word_embedding = nn.Embedding(trg_vocab_size,embed_size)
    self.position_embedding = nn.Embedding(max_length,embed_size)
    self.layers = nn.ModuleList([
        DecoderBlock(embed_size,heads,forward_expansion,dropout,device)
        for _ in range(num_layers)
    ])
    self.fc_out = nn.Linear(embed_size,trg_vocab_size)
    # Must be a module: forward() calls it on the embedded input.
    self.dropout = nn.Dropout(dropout)

  def forward(self,x,enc_out,src_mask,trg_mask):
    """Decode a batch of target token indices against the encoder output.

    Args:
      x: Integer tensor of shape (N, seq_length) holding target token indices.
      enc_out: Encoder output, used as value and key in every block.
      src_mask: Mask applied when attending over `enc_out` (may be None).
      trg_mask: Mask applied in each block's self-attention (may be None).

    Returns:
      Tensor of shape (N, seq_length, trg_vocab_size) of unnormalised scores.
    """
    N,seq_length = x.shape
    # Position ids 0..seq_length-1, repeated for every item in the batch.
    positions = torch.arange(0,seq_length,device=self.device).expand(N,seq_length)
    x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
    for layer in self.layers:
      # DecoderBlock.forward also declares a `query` parameter, which it
      # overwrites with its own self-attention output; x is passed only to
      # satisfy that signature.
      x = layer(x,value=enc_out,key=enc_out,query=x,src_mask=src_mask,trg_mask=trg_mask)
    out = self.fc_out(x)
    return out

class Transformer(nn.Module):
  """Encoder-decoder model built from `Encoder` and `Decoder`.

  Args:
    src_vocab_size: Source vocabulary size (encoder embedding rows).
    trg_vocab_size: Target vocabulary size (decoder embedding rows and
      output width).
    src_pad_idx: Token index treated as padding in the source; those
      positions are masked out.
    trg_pad_idx: Padding index of the target. Stored on the instance; the
      target mask built here is purely causal and does not use it.
    embed_size: Embedding / model dimension.
    num_layers: Number of layers in both encoder and decoder.
    forward_expansion: Feed-forward expansion factor. Defaults to 4; a value
      of 0 would give the feed-forward hidden layer zero width.
    heads: Number of attention heads.
    dropout: Dropout probability.
    device: Device the masks and position indices are placed on.
    max_length: Maximum sequence length supported by the position embeddings.
  """

  def __init__(self,src_vocab_size,trg_vocab_size,src_pad_idx,trg_pad_idx,embed_size=256,num_layers=6,forward_expansion=4,heads=8,dropout=0,device='cuda',max_length=100):
    super(Transformer,self).__init__()
    self.encoder = Encoder(src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length)
    self.decoder = Decoder(trg_vocab_size,embed_size,num_layers,heads,forward_expansion,dropout,device,max_length)
    self.src_pad_idx = src_pad_idx
    self.trg_pad_idx = trg_pad_idx
    self.device = device

  def make_src_mask(self,src):
    """Return a boolean mask that is False at source padding positions.

    Args:
      src: Tensor of shape (N, src_len) holding token indices.

    Returns:
      Boolean tensor of shape (N, 1, 1, src_len) on `self.device`.
    """
    # Two singleton dims so the mask broadcasts over heads and query positions.
    src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
    return src_mask.to(self.device)

  def make_trg_mask(self,trg):
    """Return a lower-triangular (causal) mask for the target sequence.

    Args:
      trg: Tensor of shape (N, trg_len); only its shape is used.

    Returns:
      Tensor of shape (N, 1, trg_len, trg_len) on `self.device` with ones on
      and below the diagonal and zeros above it.
    """
    N, trg_len = trg.shape
    trg_mask = torch.tril(torch.ones((trg_len,trg_len))).expand(N,1,trg_len,trg_len)
    return trg_mask.to(self.device)

  def forward(self,src,target):
    """Encode `src`, then decode `target` against the encoder output.

    Args:
      src: Tensor of shape (N, src_len) holding source token indices.
      target: Tensor of shape (N, trg_len) holding target token indices.

    Returns:
      Whatever the decoder returns for these inputs.
    """
    # Keep the token tensors intact; the masks get their own names.
    src_mask = self.make_src_mask(src)
    trg_mask = self.make_trg_mask(target)
    enc_src = self.encoder(src,src_mask)
    out = self.decoder(target,enc_src,src_mask,trg_mask)
    return out
