In [None]:
import torch
import torch.nn as nn
import numpy as np

In [None]:
class ScaledDotProductAttention(nn.Module):
  """Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.

  Works on tensors with any number of leading batch dimensions
  (e.g. [batch, len, dim] or [batch, heads, len, dim]).

  Args:
    dropout: dropout probability applied to the attention weights.
  """
  def __init__(self, dropout):
    super(ScaledDotProductAttention, self).__init__()
    self.dropout = nn.Dropout(dropout)
    self.softmax = nn.Softmax(dim=-1)

  def forward(self, query, key, value, mask=None):
    """Attend from `query` to `key`/`value`.

    Args:
      query: [..., query_len, dim]
      key:   [..., key_len, dim]
      value: [..., key_len, value_dim]  (value_len must equal key_len)
      mask:  optional tensor broadcastable to [..., query_len, key_len];
             positions where mask == 0 receive (effectively) zero weight.

    Returns:
      [..., query_len, value_dim]
    """
    dim = query.shape[-1]

    # matmul (rather than bmm) broadcasts over any leading batch dims.
    score = torch.matmul(query, key.transpose(-2, -1)) / (dim ** 0.5)  # [..., query_len, key_len]

    if mask is not None:
      # Use the most negative finite value of the score dtype: a fixed -1e10
      # cannot be represented in float16 and makes masked_fill raise.
      score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)

    attention_weight = self.softmax(score)
    return torch.matmul(self.dropout(attention_weight), value)  # [..., query_len, value_dim]


In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head attention: project Q/K/V, attend per head, concatenate, project out.

  Args:
    embed_size: model width; must be divisible by `heads`.
    heads: number of attention heads.
    dropout: dropout probability applied to the attention weights.

  Raises:
    ValueError: if `embed_size` is not divisible by `heads`.
  """
  def __init__(self, embed_size, heads, dropout=0.1):
    super(MultiHeadAttention, self).__init__()
    if embed_size % heads != 0:
      raise ValueError(f"embed_size ({embed_size}) must be divisible by heads ({heads})")
    self.heads = heads
    self.heads_dim = embed_size // heads
    self.attention = ScaledDotProductAttention(dropout)
    self.query = nn.Linear(embed_size, embed_size)
    self.key = nn.Linear(embed_size, embed_size)
    self.value = nn.Linear(embed_size, embed_size)
    self.fc = nn.Linear(embed_size, embed_size)

  def _split_heads(self, x):
    """[batch, len, embed_size] -> [batch * heads, len, heads_dim], batch-major order."""
    batch, length, _ = x.shape
    x = x.reshape(batch, length, self.heads, self.heads_dim)  # [batch, len, heads, heads_dim]
    x = x.permute(0, 2, 1, 3)                                 # [batch, heads, len, heads_dim]
    return x.reshape(batch * self.heads, length, self.heads_dim)

  def _expand_mask(self, mask, batch, query_len, key_len):
    """Flatten a 4-D mask [batch, 1|heads, q|1, k] to [batch * heads, q, k].

    Matches the batch-major head layout produced by `_split_heads`. Masks of
    other ranks (or None) are returned unchanged and must already broadcast
    against [batch * heads, query_len, key_len].
    """
    if mask is None or mask.dim() != 4:
      return mask
    mask = mask.expand(batch, self.heads, query_len, key_len)
    return mask.reshape(batch * self.heads, query_len, key_len)

  def forward(self, query, key, value, mask=None):
    """Multi-head attention from `query` to `key`/`value`.

    Args:
      query: [batch, query_len, embed_size]
      key:   [batch, key_len, embed_size]
      value: [batch, key_len, embed_size]
      mask:  optional; a 4-D mask [batch, 1|heads, query_len|1, key_len] is
             expanded across heads; positions where mask == 0 are masked out.

    Returns:
      [batch, query_len, embed_size]
    """
    batch, query_len, _ = query.shape
    key_len = key.shape[1]

    querys = self._split_heads(self.query(query))  # [batch*heads, query_len, heads_dim]
    keys = self._split_heads(self.key(key))        # [batch*heads, key_len, heads_dim]
    values = self._split_heads(self.value(value))  # [batch*heads, key_len, heads_dim]

    output = self.attention(querys, keys, values, self._expand_mask(mask, batch, query_len, key_len))
    # [batch*heads, query_len, heads_dim]

    output = output.reshape(batch, self.heads, query_len, self.heads_dim)
    output = output.permute(0, 2, 1, 3)  # [batch, query_len, heads, heads_dim]
    output = output.reshape(batch, query_len, self.heads * self.heads_dim)  # [batch, query_len, embed_size]

    # Final output projection mixing information across heads.
    return self.fc(output)



In [None]:
class TransformerBlock(nn.Module):
  """Post-norm Transformer sub-block: multi-head attention, then a feed-forward network.

  Each sub-layer output goes through dropout, is added to the sub-layer input
  (residual connection), and the sum is layer-normalized.

  Args:
    embed_size: model width (input and output feature size).
    heads: number of attention heads.
    dropout: dropout probability (attention weights and sub-layer outputs).
    dim_feedforward: hidden width of the position-wise feed-forward network.
  """
  def __init__(self, embed_size, heads, dropout=0.1, dim_feedforward=2048):
    super(TransformerBlock, self).__init__()
    self.attention = MultiHeadAttention(embed_size, heads, dropout)
    self.norm1 = nn.LayerNorm(embed_size)
    self.norm2 = nn.LayerNorm(embed_size)
    self.feed_forward = nn.Sequential(
        nn.Linear(embed_size, dim_feedforward),
        nn.ReLU(),
        nn.Linear(dim_feedforward, embed_size),
    )
    self.dropout = nn.Dropout(dropout)

  def forward(self, query, key, value, mask=None):
    """Attend from `query` to `key`/`value`, then apply the feed-forward network.

    Args:
      query: [batch, query_len, embed_size]; also the residual input.
      key:   [batch, key_len, embed_size]
      value: [batch, key_len, embed_size]
      mask:  optional attention mask forwarded to the attention module.

    Returns:
      [batch, query_len, embed_size]
    """
    attention = self.attention(query, key, value, mask)
    # Dropout only the sub-layer output so the residual path is never dropped.
    x = self.norm1(query + self.dropout(attention))
    ff_out = self.feed_forward(x)
    return self.norm2(x + self.dropout(ff_out))

In [None]:
class DecoderLayer(nn.Module):
  """One decoder layer: masked self-attention over the target, followed by a
  TransformerBlock that attends to the encoder output and applies the
  feed-forward network.

  Args:
    embed_size: model width.
    heads: number of attention heads.
    dim_feedforward: hidden width of the feed-forward network.
    dropout: dropout probability.
  """
  def __init__(self, embed_size, heads, dim_feedforward, dropout):
    super(DecoderLayer, self).__init__()
    self.attention = MultiHeadAttention(embed_size, heads, dropout)
    self.norm = nn.LayerNorm(embed_size)
    # Attribute name kept unchanged so existing state_dict keys stay valid.
    self.TransformerBlock = TransformerBlock(embed_size, heads, dropout, dim_feedforward)
    self.dropout = nn.Dropout(dropout)

  def forward(self, query, value, key, trg_mask=None, src_mask=None):
    """Run self-attention on `query`, then cross-attention to `key`/`value`.

    Args:
      query: decoder input [batch, trg_len, embed_size].
      value, key: encoder output used for cross-attention. Note the
        (value, key) parameter order; Decoder passes the same tensor for both.
      trg_mask: mask for target self-attention (e.g. a causal mask).
      src_mask: mask over source positions for cross-attention.

    Returns:
      [batch, trg_len, embed_size]
    """
    self_attention = self.attention(query, query, query, trg_mask)
    # Dropout on the sub-layer output before the residual add, then LayerNorm.
    query = self.norm(query + self.dropout(self_attention))
    return self.TransformerBlock(query, key, value, src_mask)

In [None]:
class Decoder(nn.Module):
  """Transformer decoder: token embedding + learned positional embedding,
  a stack of DecoderLayer modules, and a linear projection to vocabulary logits.

  Args:
    vocab_size: target vocabulary size (embedding rows and output logits).
    embed_size: model width.
    num_layers: number of stacked DecoderLayer modules.
    heads: number of attention heads per layer.
    dim_feedforward: hidden width of each layer's feed-forward network.
    dropout: dropout probability.
    device: accepted for interface compatibility; not used by this module.
    max_length: longest target sequence supported by the positional embedding.
  """
  def __init__(self,
               vocab_size,
               embed_size,
               num_layers,
               heads,
               dim_feedforward,
               dropout,
               device,
               max_length=100):
    super(Decoder, self).__init__()
    self.embed_size = embed_size
    self.max_length = max_length

    self.word_embedding = nn.Embedding(vocab_size, embed_size)
    # Learned positional embedding registered once, so it is trained, saved in
    # the state_dict and moved with the module by .to(device).
    self.positional_embedding = nn.Parameter(torch.randn(1, max_length, embed_size))

    self.layers = nn.ModuleList(
        [DecoderLayer(embed_size, heads, dim_feedforward, dropout) for _ in range(num_layers)]
    )

    self.fc_out = nn.Linear(embed_size, vocab_size)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, enc_out, trg_mask, src_mask):
    """Decode target token ids against the encoder output.

    Args:
      x: target token ids [batch_size, seq_length].
      enc_out: encoder output, used as key and value for cross-attention.
      trg_mask: mask for target self-attention.
      src_mask: mask over source positions for cross-attention.

    Returns:
      logits [batch_size, seq_length, vocab_size]

    Raises:
      ValueError: if seq_length exceeds max_length.
    """
    _, seq_length = x.shape
    if seq_length > self.max_length:
      raise ValueError(f"sequence length {seq_length} exceeds max_length {self.max_length}")

    x = self.dropout(self.positional_embedding[:, :seq_length] + self.word_embedding(x))

    for layer in self.layers:
      x = layer(x, enc_out, enc_out, trg_mask, src_mask)

    return self.fc_out(x)

In [None]:
# Smoke-test construction of a small decoder. Decoder's constructor arguments
# have no defaults, so calling Decoder() with no arguments raises TypeError.
A = Decoder(
    vocab_size=1000,
    embed_size=256,
    num_layers=2,
    heads=8,
    dim_feedforward=1024,
    dropout=0.1,
    device="cpu",
)

In [None]:
# Shape check for a [1, seq_len, embed_size] parameter (presumably mirroring the
# decoder's positional embedding); `.shape` makes the cell display torch.Size.
nn.Parameter(torch.randn(1, 15, 256)).shape

torch.Size([1, 15, 256])

In [None]:
class Transformer(nn.Module):
  """Encoder-decoder Transformer: builds source/target masks and runs the
  encoder followed by the decoder.

  Args:
    src_vocab_size: source vocabulary size.
    trg_vocab_size: target vocabulary size (decoder output logits).
    src_pad_idx: source token id treated as padding (masked out).
    trg_pad_idx: stored on the module; not used by the masks built here.
    embed_size: model width.
    num_layer: number of encoder and decoder layers.
    forward_expansion: decoder feed-forward hidden width is
      forward_expansion * embed_size.
    heads: number of attention heads.
    dropout: dropout probability.
    device: device the masks are moved to; also passed to encoder/decoder.
    max_length: passed to the Encoder.
  """
  def __init__(
      self,
      src_vocab_size,
      trg_vocab_size,
      src_pad_idx,
      trg_pad_idx,
      embed_size = 256,
      num_layer = 6,
      forward_expansion = 4,
      heads = 8,
      dropout = 0,
      device = "cuda",
      max_length = 100
  ):
    super(Transformer, self).__init__()

    # NOTE(review): Encoder is not defined in this notebook; this positional
    # argument order is unverified — check it against Encoder's signature.
    self.encoder = Encoder(
        src_vocab_size,
        embed_size,
        num_layer,
        heads,
        device,
        forward_expansion,
        dropout,
        max_length
    )

    # Pass by keyword so arguments cannot be misaligned with Decoder's
    # (vocab_size, embed_size, num_layers, heads, dim_feedforward, dropout,
    # device) parameters.
    self.decoder = Decoder(
        vocab_size=trg_vocab_size,
        embed_size=embed_size,
        num_layers=num_layer,
        heads=heads,
        dim_feedforward=forward_expansion * embed_size,
        dropout=dropout,
        device=device,
    )

    self.src_pad_idx = src_pad_idx
    self.trg_pad_idx = trg_pad_idx

    self.device = device

  def make_src_mask(self, src):
    """Padding mask for source ids [N, src_len] -> [N, 1, 1, src_len] (True = keep)."""
    src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
    return src_mask.to(self.device)

  def make_trg_mask(self, trg):
    """Causal (lower-triangular) mask for target ids [N, trg_len] -> [N, 1, trg_len, trg_len]."""
    N, trg_len = trg.shape
    trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len)
    return trg_mask.to(self.device)

  def forward(self, src, trg):
    """Encode `src` and decode `trg`; returns the decoder's output logits."""
    src_mask = self.make_src_mask(src)
    trg_mask = self.make_trg_mask(trg)
    enc_src = self.encoder(src, src_mask)
    # Decoder.forward takes (x, enc_out, trg_mask, src_mask): the causal target
    # mask comes first, the source padding mask second.
    return self.decoder(trg, enc_src, trg_mask, src_mask)
    