In [1]:
import random
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
class PositionwiseFeedforwardLayer(nn.Module):
    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        x = self.dropout(torch.relu(self.fc_1(x)))
        x = self.fc_2(x)        
        return x

In [3]:
class SelfAtt(nn.Module):
  def __init__(self, emb_size, heads,dropout):
    super(SelfAtt,self).__init__()
    self.emb_size = emb_size
    self.heads = heads
    self.head_dim = emb_size//heads
    assert(self.head_dim*heads == self.emb_size), "head_dim*heads != emb_size"

    self.query = nn.Linear(self.emb_size,self.emb_size)
    self.key = nn.Linear(self.emb_size,self.emb_size)
    self.value = nn.Linear(self.emb_size,self.emb_size)

    self.fc_out = nn.Linear(self.head_dim*heads, self.emb_size)
    self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)
    self.dropout = nn.Dropout(dropout)

  def forward(self,values,keys,query,mask=None):
    N = query.shape[0]
    values = self.value(values)
    keys = self.key(keys)  
    query1 = self.query(query) 

    values = values.view(N, -1, self.heads, self.head_dim).permute(0, 2, 1, 3)
    keys = keys.view(N, -1, self.heads, self.head_dim).permute(0, 2, 1, 3)
    query1 = query1.view(N, -1, self.heads, self.head_dim).permute(0, 2, 1, 3)   

    #dot product of keys and query
    energy = torch.matmul(query1, keys.permute(0, 1, 3, 2)) / self.scale
    #print(energy.shape)

    if mask is not None:
      energy = energy.masked_fill(mask == 0, float("-1e20"))
      #print(energy)
    
    attention = torch.softmax(energy, dim= -1)
    print(attention)
    #print(values.shape)
    x = torch.matmul(self.dropout(attention), values)    
    x = x.permute(0, 2, 1, 3).contiguous()
    out = x.view(N, -1, self.emb_size)
    
    out =  self.fc_out(out)
    return out

In [4]:
class TransformerBlock(nn.Module):
  def __init__(self,emb_size, heads, dropout, forward_expansion):
    super(TransformerBlock,self).__init__()
    self.att = SelfAtt(emb_size,heads,dropout)
    self.norm1 = nn.LayerNorm(emb_size)
    self.norm2 = nn.LayerNorm(emb_size)

    self.feed_forward = PositionwiseFeedforwardLayer(emb_size, forward_expansion*emb_size, dropout)
    self.dropout = nn.Dropout(dropout)

  def forward(self, value, key, query):
    att = self.att(value, key, query)
    x = self.norm1(query + self.dropout(att))
    forward = self.feed_forward(x)
    out  = self.norm2(x + self.dropout(forward))
    return out

In [5]:
class Encoder(nn.Module):
  def __init__(self, 
               src_vocab_size, 
               emb_size,
               num_layers,
               heads,
               device,
               forward_expansion,
               dropout,
               max_length,
               ):
    super(Encoder,self).__init__()
    self.emb_size = emb_size
    self.device = device
    self.word_embedding  = nn.Embedding(src_vocab_size,emb_size)
    self.position_embedding = nn.Embedding(max_length, emb_size)
    self.layers = nn.ModuleList(
        [
         TransformerBlock(
             emb_size,
             heads,
             dropout = dropout,
             forward_expansion = forward_expansion,
         )
         for _ in range(num_layers)
        ]
    )
    self.dropout = nn.Dropout(dropout)
    self.scale = torch.sqrt(torch.FloatTensor([emb_size])).to(device)
  def forward(self,x):
    N,seq_length = x.shape
    positions = torch.arange(0, seq_length).unsqueeze(0).repeat(N, 1).to(self.device)
    out = self.dropout(self.word_embedding(x)*self.scale + self.position_embedding(positions))
    for layer in self.layers:
      out  = layer(out,out,out)
      
    return out

In [6]:
class DecoderBlock(nn.Module):
  def __init__(self, emb_size, heads, forward_expansion, dropout, device):
    super(DecoderBlock,self).__init__()
    self.attention = SelfAtt(emb_size,heads,dropout)
    self.norm1 = nn.LayerNorm(emb_size)
    self.norm = nn.LayerNorm(emb_size)
    self.transformer_block = TransformerBlock(emb_size, heads, dropout, forward_expansion)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, value, key):
    attention = self.attention(x,x,x)
    query = self.norm1(self.dropout(attention) + x)
    out  = self.transformer_block(value, key, query)
    return out

In [7]:
class Decoder(nn.Module):
    def __init__(
        self,
        trg_vocab_size,
        embed_size,
        num_layers,
        heads,
        forward_expansion,
        dropout,
        device,
        max_length,
    ):
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
                for _ in range(num_layers)
            ]
        )
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([embed_size])).to(device)

    def forward(self, x, enc_out):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).unsqueeze(0).repeat(N, 1).to(self.device)
        x = self.dropout((self.word_embedding(x)*self.scale) + self.position_embedding(positions))
        for layer in self.layers:
            x = layer(x, enc_out, enc_out)
        out = self.fc_out(x)

        return out

In [8]:
class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        trg_vocab_size,
        embed_size=5,
        num_layers=6,
        forward_expansion=4,
        heads=1,
        dropout=0,
        device="cpu",
        max_length=10,
        src_pad_idx = 0,
        trg_pad_idx = 0,
        teacher_force = 0.5,
    ):

        super(Transformer, self).__init__()

        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length,
        )

        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length,
        )

        self.device = device
        self.trg_vocab_size = trg_vocab_size
        self.teacher_force = teacher_force

    def forward(self, src, trg):
        enc_src = self.encoder(src)
        outputs = torch.zeros((trg.shape[0],trg.shape[1],self.trg_vocab_size)).to(self.device)
        trg_dec = trg[:,0:1]
        for i in range(trg.shape[1]):
            out = self.decoder(trg_dec, enc_src)
            outputs[:,i,:] = out[:,-1,:]
            trg_dec = torch.argmax(outputs[:,0:i+2,:],dim = 2) if random.random() < self.teacher_force else trg[:,0:i+2]
        return outputs


In [9]:
if __name__ == "__main__":
    print(device)
    x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0]]).to(
        device
    )
    trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0]]).to(device)

    src_pad_idx = 1
    trg_pad_idx = 3
    src_vocab_size = 10
    trg_vocab_size = 10 
    model  = Transformer( src_vocab_size, trg_vocab_size,).to(device)
    out = model(x, trg)
    print(out.shape)

cpu
tensor([[[[1.4548e-01, 2.5826e-04, 7.8264e-03, 1.1178e-04, 1.8012e-04,
           4.9817e-04, 1.3810e-04, 8.3550e-01, 1.0006e-02],
          [2.4547e-05, 4.3949e-03, 3.4911e-05, 4.7845e-02, 9.4335e-01,
           1.9375e-04, 4.0054e-03, 5.3271e-05, 9.3907e-05],
          [1.3042e-01, 5.9124e-02, 1.4305e-02, 8.2531e-02, 3.9502e-01,
           2.0913e-02, 4.7821e-02, 1.5379e-01, 9.6078e-02],
          [3.7214e-05, 4.1774e-01, 1.2315e-03, 1.0910e-02, 2.3462e-02,
           1.3337e-01, 4.1000e-01, 1.2799e-06, 3.2424e-03],
          [1.1290e-05, 1.6732e-01, 7.1817e-03, 2.2533e-01, 3.2824e-02,
           1.5877e-01, 4.0797e-01, 1.2408e-06, 5.9048e-04],
          [9.4389e-04, 4.0365e-03, 4.1105e-05, 5.6134e-03, 9.8208e-01,
           9.0527e-05, 1.5102e-03, 4.7910e-03, 8.8982e-04],
          [2.1466e-05, 1.2921e-02, 2.3758e-05, 2.2060e-02, 9.5548e-01,
           3.9910e-04, 8.8980e-03, 1.9516e-05, 1.8077e-04],
          [2.8297e-01, 5.9892e-04, 3.3723e-01, 9.2095e-04, 1.4994e-05,
        