<a href="https://colab.research.google.com/github/matejkvassay/colab-notebooks/blob/master/attention_is_all_you_need.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

# Fill value for masked attention logits. It is finite (rather than -inf) so a
# row in which every position is masked still softmaxes to finite weights
# instead of NaN.
MINUS_INF = float("-1e15")


class SelfAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  The last (embedding) dimension of each input is split into ``n_heads``
  chunks of ``head_dim = emb_size // n_heads`` features. The same bias-free
  ``head_dim -> head_dim`` key/query/value projections are applied to every
  head, attention is computed per head, and the concatenated head outputs are
  mixed by a final ``emb_size -> emb_size`` linear layer (``fc_out``).

  Args:
    emb_size: size of the embedding dimension of keys, queries and values.
    n_heads: number of attention heads; must divide ``emb_size``.

  Raises:
    ValueError: if ``emb_size`` is not divisible by ``n_heads``.
  """

  def __init__(self, emb_size, n_heads):
    super(SelfAttention, self).__init__()
    self._check_input_params(emb_size, n_heads)

    self.emb_size = emb_size
    self.n_heads = n_heads
    self.head_dim = emb_size // n_heads
    self.keys, self.queries, self.values = self._init_att_matrices(self.head_dim)
    self.fc_out = nn.Linear(emb_size, emb_size)

  def forward(self, keys, queries, values, mask=None):
    """Attend from ``queries`` over ``keys``/``values``.

    Args:
      keys: tensor of shape (N, key_len, emb_size).
      queries: tensor of shape (N, query_len, emb_size).
      values: tensor of shape (N, value_len, emb_size), value_len == key_len.
      mask: optional tensor broadcastable to (N, n_heads, query_len, key_len);
        positions where ``mask == 0`` are excluded from attention.

    Returns:
      Tensor of shape (N, query_len, emb_size).
    """
    keys, queries, values = self._split_input_embeddings(keys, queries, values)
    keys, queries, values = self._send_to_linear_layers(keys, queries, values)
    attention = self._compute_masked_attention(queries, keys, mask)
    result = self._multiply_with_values(attention, queries, values)
    return self.fc_out(result)

  def _send_to_linear_layers(self, keys, queries, values):
    """Apply the key/query/value projections; the same weights serve every head."""
    return self.keys(keys), self.queries(queries), self.values(values)

  def _compute_masked_attention(self, queries, keys, mask):
    """Return softmax attention weights of shape (N, n_heads, query_len, key_len).

    queries: (N, query_len, n_heads, head_dim)
    keys:    (N, key_len, n_heads, head_dim)

    Logits are scaled by 1/sqrt(head_dim) -- the per-head key dimension d_k of
    scaled dot-product attention -- and entries where ``mask == 0`` are set to
    MINUS_INF before the softmax over the key axis.
    """
    energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys) / (self.head_dim ** 0.5)
    if mask is not None:
      energy = energy.masked_fill(mask == 0, MINUS_INF)
    return torch.softmax(energy, dim=3)

  def _multiply_with_values(self, attention, queries, values):
    """Weight the values by attention and concatenate the heads.

    attention: (N, n_heads, query_len, key_len)
    values:    (N, value_len, n_heads, head_dim), value_len == key_len, so the
               shared axis is summed over as ``l``.
    Returns a tensor of shape (N, query_len, emb_size).
    """
    output = torch.einsum("nhql,nlhd->nqhd", attention, values)
    return output.reshape(queries.shape[0], queries.shape[1], self.emb_size)

  def _split_input_embeddings(self, keys, queries, values):
    """Reshape each (N, L, emb_size) input into (N, L, n_heads, head_dim)."""
    n_samples = queries.shape[0]
    return tuple(
        mtx.reshape(n_samples, mtx.shape[1], self.n_heads, self.head_dim)
        for mtx in (keys, queries, values)
    )

  @staticmethod
  def _check_input_params(emb_size, n_heads):
    """Raise ValueError unless emb_size splits evenly into n_heads heads."""
    if emb_size % n_heads != 0:
      raise ValueError(
          f'emb_size must be divisible by n_heads, values given: '
          f'emb_size: {emb_size}, n_heads: {n_heads}')

  def _init_att_matrices(self, head_dim):
    """Build the bias-free (head_dim -> head_dim) key, query and value layers."""
    return tuple(nn.Linear(head_dim, head_dim, bias=False) for _ in range(3))

In [None]:
class TransformerBlock(nn.Module):
  """One post-norm Transformer layer.

  Multi-head attention followed by a position-wise feed-forward network; each
  sub-layer is wrapped in a residual connection, a LayerNorm and dropout.

  Args:
    emb_size: embedding dimension of the inputs and outputs.
    n_heads: number of attention heads.
    dropout: dropout probability applied after each sub-layer.
    expansion: hidden-size multiplier of the feed-forward network.
  """

  def __init__(self, emb_size, n_heads, dropout, expansion):
    super(TransformerBlock, self).__init__()
    hidden_size = expansion * emb_size
    self.attention = SelfAttention(emb_size, n_heads)
    # Despite the name, both norms are applied after their residual addition.
    self.norm_before = nn.LayerNorm(emb_size)
    self.norm_after = nn.LayerNorm(emb_size)
    self.feed_fwd = nn.Sequential(
        nn.Linear(emb_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, emb_size),
    )
    self.dropout = nn.Dropout(dropout)

  def forward(self, values, keys, queries, mask=None):
    """Run attention of ``queries`` over ``keys``/``values``, then the FFN.

    The residual of the attention sub-layer is taken from ``queries``.
    """
    attended = self.attention(keys, queries, values, mask=mask)
    hidden = self.dropout(self.norm_before(queries + attended))
    return self.dropout(self.norm_after(hidden + self.feed_fwd(hidden)))
    

In [None]:
class Encoder(nn.Module):
  """Transformer encoder.

  Token embeddings plus learned positional embeddings, followed by a stack of
  ``TransformerBlock`` self-attention layers.

  Args:
    vocab_size: number of source tokens.
    emb_size: embedding dimension.
    n_layers: number of stacked ``TransformerBlock`` layers (may be 0).
    n_heads: attention heads per layer.
    expansion: feed-forward hidden-size multiplier.
    device: device on which position indices are created.
    dropout: dropout probability.
    max_seq_len: size of the positional-embedding table; input sequences must
      not be longer than this.
  """

  def __init__(self, vocab_size, emb_size, n_layers, n_heads, expansion, device, dropout, max_seq_len):
    super(Encoder, self).__init__()
    self.emb_size = emb_size
    self.device = device
    self.word_emb = nn.Embedding(vocab_size, emb_size)
    self.pos_emb = nn.Embedding(max_seq_len, emb_size)
    self.layers = nn.ModuleList(
        [TransformerBlock(emb_size, n_heads, dropout, expansion)
         for _ in range(n_layers)]
    )
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, mask=None):
    """Encode token ids ``x`` of shape (N, seq_len) into (N, seq_len, emb_size)."""
    out = self.dropout(self._embed(x))
    # Each layer consumes the previous layer's output; with no layers the
    # embedding itself is returned.
    for layer in self.layers:
      out = layer(out, out, out, mask)
    return out

  def _embed(self, x):
    """Sum token embeddings and positional embeddings of positions 0..seq_len-1."""
    n_samples, seq_len = x.shape
    positions = torch.arange(seq_len, device=self.device).expand(n_samples, seq_len)
    # Look up the position indices, not the token ids.
    return self.word_emb(x) + self.pos_emb(positions)

In [None]:
class DecoderBlock(nn.Module):
  """One Transformer decoder layer.

  Masked self-attention over the decoder input (residual + LayerNorm +
  dropout), followed by a ``TransformerBlock`` that attends over the encoder
  output.

  Args:
    emb_size: embedding dimension.
    n_heads: attention heads.
    expansion: feed-forward hidden-size multiplier.
    dropout: dropout probability.
    device: accepted for interface compatibility; not used by this block.
  """

  def __init__(self, emb_size, n_heads, expansion, dropout, device):
    super(DecoderBlock, self).__init__()
    self.attention = SelfAttention(emb_size, n_heads)
    self.norm = nn.LayerNorm(emb_size)
    self.transformer_block = TransformerBlock(emb_size, n_heads, dropout, expansion)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, values, keys, target_mask, source_mask):
    """Decode ``x`` against encoder outputs ``values``/``keys``.

    ``target_mask`` masks the self-attention over ``x``; ``source_mask`` masks
    the attention over the encoder output.
    """
    attention = self.attention(x, x, x, mask=target_mask)
    # Residual connection, then LayerNorm, then dropout. Previously each step
    # overwrote the last with a value computed from `attention` alone, which
    # discarded both the residual and the norm.
    queries = self.dropout(self.norm(attention + x))
    return self.transformer_block(values, keys, queries, source_mask)


In [None]:
import torch
import torch.nn as nn

In [None]:
class Decoder(nn.Module):
  """Transformer decoder producing per-position target-vocabulary logits.

  Args:
    target_vocab_size: number of target tokens (also the output width).
    emb_size: embedding dimension.
    n_layers: number of stacked ``DecoderBlock`` layers.
    n_heads: attention heads per layer.
    expansion: feed-forward hidden-size multiplier.
    dropout: dropout probability.
    device: device on which position indices are created.
    max_seq_len: size of the positional-embedding table.
  """

  def __init__(self, target_vocab_size, emb_size, n_layers, n_heads, expansion, dropout, device, max_seq_len):
    super(Decoder, self).__init__()
    self.device = device
    self.word_emb = nn.Embedding(target_vocab_size, emb_size)
    self.pos_emb = nn.Embedding(max_seq_len, emb_size)
    blocks = [DecoderBlock(emb_size, n_heads, expansion, dropout, device)
              for _ in range(n_layers)]
    self.layers = nn.ModuleList(blocks)
    self.fc_out = nn.Linear(emb_size, target_vocab_size)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, enc_out, target_mask, source_mask):
    """Map target ids (N, seq_len) to logits (N, seq_len, target_vocab_size)."""
    batch, length = x.shape
    pos_ids = torch.arange(length, device=self.device).expand(batch, length)
    hidden = self.dropout(self.word_emb(x) + self.pos_emb(pos_ids))
    for block in self.layers:
      # The encoder output serves as both values and keys of cross-attention.
      hidden = block(hidden, enc_out, enc_out, target_mask, source_mask)
    return self.fc_out(hidden)

In [None]:
class Transformer(nn.Module):
  """Encoder-decoder Transformer returning target-vocabulary logits.

  Args:
    tgt_vocab_size, src_vocab_size: target and source vocabulary sizes.
    tgt_pad_idx, src_pad_idx: padding token ids; ``src_pad_idx`` is masked
      out of attention over the source, ``tgt_pad_idx`` is stored only.
    emb_size, n_layers, expansion, n_heads, dropout, max_seq_len: model
      hyper-parameters forwarded to the encoder and decoder.
    device: device the masks are moved to.
  """

  def __init__(self,
               tgt_vocab_size,
               src_vocab_size,
               tgt_pad_idx,
               src_pad_idx,
               emb_size=256,
               n_layers=5,
               expansion=4,
               n_heads=8,
               dropout=0,
               device='cuda',
               max_seq_len=100):
    super(Transformer, self).__init__()
    self.encoder = Encoder(src_vocab_size, emb_size, n_layers, n_heads,
                           expansion, device, dropout, max_seq_len)
    self.decoder = Decoder(tgt_vocab_size, emb_size, n_layers, n_heads,
                           expansion, dropout, device, max_seq_len)
    self.src_pad_idx = src_pad_idx
    self.tgt_pad_idx = tgt_pad_idx
    self.device = device

  def make_src_mask(self, src):
    """Boolean mask of shape (N, 1, 1, src_len); False at padding tokens."""
    keep = src != self.src_pad_idx
    return keep[:, None, None, :].to(self.device)

  def make_tgt_mask(self, tgt):
    """Lower-triangular float mask (N, 1, tgt_len, tgt_len) hiding future positions."""
    batch, length = tgt.shape
    causal = torch.ones((length, length)).tril()
    return causal.expand(batch, 1, length, length).to(self.device)

  def forward(self, src, tgt):
    """Return logits of shape (N, tgt_len, tgt_vocab_size)."""
    src_mask = self.make_src_mask(src)
    tgt_mask = self.make_tgt_mask(tgt)
    memory = self.encoder(src, mask=src_mask)
    return self.decoder(tgt, memory, tgt_mask, src_mask)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0],
                  [1, 8, 7, 3, 4, 5, 6, 7, 2]], device=device)
tgt = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0],
                    [1, 5, 6, 2, 4, 7, 6, 2]], device=device)

# Toy vocabulary of ids 0..9 on both sides, with id 0 used as padding.
src_pad_idx = tgt_pad_idx = 0
src_vocab_size = tgt_vocab_size = 10

# nn.Module.to returns the module itself, so `model` is the moved model.
model = Transformer(
    tgt_vocab_size, src_vocab_size, tgt_pad_idx, src_pad_idx, device=device
).to(device)
# The decoder is fed every target token except the last one.
output = model(x, tgt[:, :-1])
print(output.shape)

torch.Size([2, 7, 10])
