<a href="https://colab.research.google.com/github/mgaac/ml_repertoire/blob/main/TransformerArchitecture.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [72]:
import torch
import torch.nn as nn
import torch.nn.functional as f

import math

In [73]:
class Multi_head_attention(nn.Module):
  """Multi-head scaled dot-product attention.

  The input is projected to queries, keys and values, split into
  ``n_heads`` slices of width ``embed_size // n_heads``, and every query
  position attends over all key positions independently in each head.
  The per-head results are concatenated back to ``embed_size``.

  There is no output projection after the heads are concatenated; the
  module's only parameters are the three input projections.

  Args:
    embed_size: width of the last dimension of every input tensor.
    n_heads: number of attention heads; must divide ``embed_size``.
  """

  def __init__(self, embed_size, n_heads):
    super(Multi_head_attention, self).__init__()

    self.embed_size = embed_size
    self.n_heads = n_heads
    self.head_dim = embed_size // n_heads

    assert self.head_dim * n_heads == embed_size, (
        "embed_size must be divisible by n_heads")

    self.query = nn.Linear(embed_size, embed_size)
    self.key = nn.Linear(embed_size, embed_size)
    self.value = nn.Linear(embed_size, embed_size)

  def _split_heads(self, projected):
    # (..., length, embed_size) -> (..., n_heads, length, head_dim)
    split = projected.reshape(
        *projected.shape[:-1], self.n_heads, self.head_dim)
    return split.transpose(-3, -2)

  def forward(self, x, value=None, query=None, key=None, mask=None):
    """Attend from ``query`` positions over ``key``/``value`` positions.

    Args:
      x: tensor of shape (..., length, embed_size). Used as the source of
        queries, keys and values (self-attention) unless all three of
        ``query``, ``key`` and ``value`` are given.
      value, query, key: optional tensors of shape (..., length, embed_size).
        If any one of them is None, all three are ignored and ``x`` is used.
        ``key`` and ``value`` must have the same length.
      mask: optional tensor broadcastable to
        (..., n_heads, query_length, key_length). Positions where the mask
        equals 0 receive (numerically) zero attention weight.

    Returns:
      Tensor of shape (..., query_length, embed_size).
    """
    if query is None or key is None or value is None:
      query = key = value = x

    q = self._split_heads(self.query(query))
    k = self._split_heads(self.key(key))
    v = self._split_heads(self.value(value))

    # Similarity of every query position with every key position, per head:
    # (..., n_heads, query_length, key_length).
    scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** .5)

    if mask is not None:
      # A large negative score becomes a zero weight after the softmax. A
      # finite value is used instead of -inf so that a fully masked row
      # yields uniform weights rather than NaN.
      scores = scores.masked_fill(mask == 0, float('-1e20'))

    # Normalise over key positions.
    weights = f.softmax(scores, dim=-1)

    context = torch.matmul(weights, v)    # (..., n_heads, query_length, head_dim)
    context = context.transpose(-3, -2)   # (..., query_length, n_heads, head_dim)

    # Concatenate the heads back into a single embedding per position.
    return context.reshape(*context.shape[:-2], self.embed_size)

In [74]:
def gen_mask(input, mask_idx):
  """Build a 0/1 mask shaped like ``input`` that keeps a prefix of dim 2.

  Entries whose index along dimension 2 is below ``mask_idx`` are 1; all
  other entries are 0.

  Args:
    input: tensor whose shape (and device) the mask copies. It needs at
      least three dimensions when ``mask_idx`` is positive.
    mask_idx: number of leading indices along dimension 2 to set to 1.
      Values of 0 or less give an all-zero mask; values beyond the size of
      dimension 2 are clamped, giving an all-one mask.

  Returns:
    A tensor of ``input``'s shape in torch's default dtype, allocated on
    ``input``'s device.
  """
  # Allocate next to `input` so the mask can be combined with tensors derived
  # from it without a device mismatch.
  mask = torch.zeros(input.size(), device=input.device)

  if mask_idx > 0:
    # One slice assignment instead of one full-tensor copy per index.
    mask[:, :, :mask_idx] = 1

  return mask

In [75]:
class Encoder(nn.Module):
  """One encoder layer: self-attention, then a two-layer feed-forward net.

  Each sub-layer is wrapped in a residual connection followed by layer
  normalisation. The normalisation is applied per position over the
  embedding dimension and has no learnable scale or shift.

  Args:
    embed_size: width of the last dimension of the input.
    n_heads: number of attention heads passed to the attention module.
    ff_dim: hidden width of the feed-forward net.
  """

  def __init__(self, embed_size, n_heads, ff_dim):
    super(Encoder, self).__init__()

    self.embed_size = embed_size
    self.n_heads = n_heads
    self.ff_dim = ff_dim

    self.feed_forward = nn.Sequential(
        nn.Linear(embed_size, ff_dim),
        nn.ReLU(),
        nn.Linear(ff_dim, embed_size))

    self.attention = Multi_head_attention(embed_size, n_heads)

  def forward(self, x):
    """Return the encoded tensor, same shape as ``x`` (..., embed_size)."""
    # Normalise each position over its own embedding only. Normalising over
    # the full tensor shape would make the statistics of one position depend
    # on every other position and on the sequence length.
    norm_shape = (self.embed_size,)

    att = self.attention(x)
    norm1 = f.layer_norm(att + x, norm_shape)       # residual + norm

    ff = self.feed_forward(norm1)
    norm2 = f.layer_norm(ff + norm1, norm_shape)    # residual + norm

    return norm2

In [78]:
class Decoder(nn.Module):
  """One decoder layer: masked self-attention, a second attention step over
  externally supplied tensors, then a two-layer feed-forward net.

  Each sub-layer is wrapped in a residual connection followed by layer
  normalisation applied per position over the embedding dimension (no
  learnable scale or shift). The same attention module, and therefore the
  same projection weights, is used for both attention steps.

  Args:
    embed_size: width of the last dimension of the input.
    n_heads: number of attention heads passed to the attention module.
    ff_dim: hidden width of the feed-forward net.
  """

  def __init__(self, embed_size, n_heads, ff_dim):
    super(Decoder, self).__init__()

    self.embed_size = embed_size
    self.n_heads = n_heads
    self.ff_dim = ff_dim

    self.feed_forward = nn.Sequential(
        nn.Linear(embed_size, ff_dim),
        nn.ReLU(),
        nn.Linear(ff_dim, embed_size))

    self.attention = Multi_head_attention(embed_size, n_heads)

  def forward(self, x, query, key, value=None, mask=None):
    """Run one decoder layer.

    Args:
      x: decoder input, last dimension ``embed_size``.
      query: tensor used as the query source of the second attention step.
      key: tensor used as the key source of the second attention step.
      value: optional value source of the second attention step; defaults
        to the normalised output of the self-attention sub-layer.
      mask: optional mask forwarded to the self-attention step.

    Returns:
      Tensor with the same shape as ``x``.
    """
    # Normalise each position over its own embedding only, not over the
    # whole tensor.
    norm_shape = (self.embed_size,)

    # The mask must be passed by keyword: the attention module's second
    # positional parameter is `value`, not `mask`.
    att1 = self.attention(x, mask=mask)
    norm1 = f.layer_norm(att1 + x, norm_shape)

    if value is None:
      value = norm1

    # Keywords keep each tensor in the slot its name says; the attention
    # module's positional order is (x, value, query, key).
    # NOTE(review): in the reference Transformer the query comes from the
    # decoder state and key/value from the encoder; here the caller supplies
    # query and key. Confirm this wiring is intended.
    att2 = self.attention(norm1, query=query, key=key, value=value)
    norm2 = f.layer_norm(att2 + norm1, norm_shape)

    ff = self.feed_forward(norm2)
    norm3 = f.layer_norm(ff + norm2, norm_shape)

    return norm3


In [80]:
class Embedding(nn.Module):
  """Three-layer MLP mapping ``input_size`` features to ``embed_size``.

  Layout: Linear -> ReLU -> Linear -> ReLU -> Linear. The ``ff_dim``
  argument is accepted but not used by this module.
  """

  def __init__(self, input_size, embed_size, ff_dim):
    super().__init__()

    layers = [
        nn.Linear(input_size, embed_size),
        nn.ReLU(),
        nn.Linear(embed_size, embed_size),
        nn.ReLU(),
        nn.Linear(embed_size, embed_size),
    ]
    self.sequential = nn.Sequential(*layers)

  def forward(self, x):
    """Project ``x`` (..., input_size) to (..., embed_size)."""
    return self.sequential(x)

In [99]:
class Transformer(nn.Module):
  """Encoder stack, decoder stack and a feed-forward output head.

  Args:
    input_size: stored on the instance; not otherwise used by this class.
    embed_size: width of the last dimension of both inputs.
    n_heads: number of attention heads in every layer.
    n_encoder: number of encoder layers.
    n_decoder: number of decoder layers.
    ff_dim: hidden width of the per-layer feed-forward nets and of the
      output head.
    out_dim: width of the final output.
  """

  def __init__(self, input_size, embed_size, n_heads, n_encoder, n_decoder, ff_dim, out_dim):
    super(Transformer, self).__init__()

    self.input_size = input_size
    self.embed_size = embed_size
    self.n_heads = n_heads
    self.ff_dim = ff_dim
    self.n_encoder = n_encoder
    self.n_decoder = n_decoder

    # Build a fresh module per layer. Repeating one instance in the list
    # would make every layer share the same parameters.
    self.encoder_block = nn.ModuleList(
        [Encoder(embed_size, n_heads, ff_dim) for _ in range(n_encoder)])
    self.decoder_block = nn.ModuleList(
        [Decoder(embed_size, n_heads, ff_dim) for _ in range(n_decoder)])

    self.ff = nn.Sequential(
        nn.Linear(embed_size, ff_dim),
        nn.ReLU(),
        nn.Linear(ff_dim, ff_dim),
        nn.ReLU(),
        nn.Linear(ff_dim, ff_dim),
        nn.ReLU(),
        nn.Linear(ff_dim, out_dim))

  def forward(self, encoder_var, decoder_var):
    """Encode ``encoder_var``, decode ``decoder_var`` against it, and project.

    Args:
      encoder_var: input to the encoder stack, last dimension ``embed_size``.
      decoder_var: input to the decoder stack, last dimension ``embed_size``.

    Returns:
      Output of the feed-forward head, last dimension ``out_dim``.
    """
    for encoder in self.encoder_block:
      encoder_var = encoder(encoder_var)

    # Every decoder layer receives the final encoder output as both its
    # `query` and `key` argument.
    for decoder in self.decoder_block:
      decoder_var = decoder(decoder_var, encoder_var, encoder_var)

    return self.ff(decoder_var)