<a href="https://colab.research.google.com/github/lorrespz/Transformers-Language-Models--Pytorch-/blob/main/Transformer_Encoder_Decoder_Seq2seq.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Encoder-Decoder architecture (Seq2seq) from scratch

In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import dataset
import numpy as np
import matplotlib.pyplot as plt

# Multihead Attention Block (with Causal Mask option)

This block can optionally apply a causal mask (enabled by passing `causal = True`), which prevents each position from attending to later positions.

In [2]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention with optional causal masking.

  Queries, keys and values are each projected to n_heads heads of size d_k
  (d_v is taken equal to d_k), attention is computed per head, and the
  concatenated heads are projected back to d_model.

  Args:
    d_k: size of each head for queries, keys and values.
    d_model: feature size of the input and output sequences.
    n_heads: number of attention heads.
    max_len: side of the square causal mask; only used when causal is True.
    causal: when True, position i may only attend to positions <= i.
  """

  def __init__(self, d_k, d_model, n_heads, max_len, causal = False):
    super().__init__()
    self.d_k = d_k
    self.n_heads = n_heads

    # Every projection maps d_model features to all heads at once.
    heads_dim = d_k*n_heads
    self.key = nn.Linear(d_model, heads_dim)
    self.query = nn.Linear(d_model, heads_dim)
    self.value = nn.Linear(d_model, heads_dim)
    # Output projection back to the model dimension.
    self.fc = nn.Linear(heads_dim, d_model)

    self.causal = causal
    if causal:
      # Lower-triangular matrix of ones (zeros above the diagonal), stored
      # as a buffer of shape (1, 1, max_len, max_len) so that it follows
      # the module across devices and broadcasts over batch and heads.
      lower_triangle = torch.tril(torch.ones(max_len, max_len))
      self.register_buffer('causal_mask',
                           lower_triangle.view(1, 1, max_len, max_len))

  def forward(self, q, k, v, pad_mask = None):
    """Attend from q over k / v.

    Args:
      q: queries, (N, T_output, d_model).
      k, v: keys and values, (N, T_input, d_model). In cross-attention
        they come from the encoder, so T_input may differ from T_output.
      pad_mask: optional (N, T_input) tensor; key positions whose entry
        equals 0 are excluded from attention.

    Returns:
      Tensor of shape (N, T_output, d_model).

    Note:
      A query whose keys are all masked gets a softmax over -inf only,
      which produces NaN for that row.
    """
    queries = self.query(q)   # N x T_output x (h*d_k)
    keys = self.key(k)        # N x T_input x (h*d_k)
    values = self.value(v)    # N x T_input x (h*d_k)

    batch = queries.shape[0]
    len_out = queries.shape[1]
    len_in = keys.shape[1]

    # Split the heads: (N, T, h*d_k) --> (N, T, h, d_k) --> (N, h, T, d_k)
    def split_heads(tensor, length):
      return tensor.view(batch, length, self.n_heads, self.d_k).transpose(1, 2)

    queries = split_heads(queries, len_out)
    keys = split_heads(keys, len_in)
    values = split_heads(values, len_in)

    # Scaled dot products: (N, h, T_output, d_k) x (N, h, d_k, T_input)
    # --> (N, h, T_output, T_input)
    scores = (queries @ keys.transpose(-2, -1))/math.sqrt(self.d_k)

    minus_inf = float('-inf')   # softmax turns -inf scores into weight 0
    if pad_mask is not None:
      # (N, T_input) --> (N, 1, 1, T_input) to broadcast over heads/queries
      ignored_keys = pad_mask[:, None, None, :] == 0
      scores = scores.masked_fill(ignored_keys, minus_inf)
    if self.causal:
      # Crop the precomputed mask to the current sequence lengths.
      future = self.causal_mask[:, :, :len_out, :len_in] == 0
      scores = scores.masked_fill(future, minus_inf)
    weights = F.softmax(scores, dim = -1)

    # Weighted sum of the values: (N, h, T_output, d_k)
    context = weights @ values

    # Merge the heads back: (N, T_output, h, d_k) --> (N, T_output, h*d_k)
    context = context.transpose(1, 2).contiguous()
    context = context.view(batch, len_out, self.d_k*self.n_heads)

    # Project to d_model so the output matches the query sequence's shape.
    return self.fc(context)

# Encoder Transformer Block

This remains almost the same as the standalone encoder.

In [3]:
#This is the Transformer block of the Encoder
class EncoderBlock(nn.Module):
  """One encoder layer: self-attention, then a feed-forward network.

  Each sub-layer is wrapped in a residual connection followed by a
  LayerNorm; dropout is applied once more on the block output.

  Args:
    d_k: per-head size passed to the attention block.
    d_model: feature size of the sequence.
    n_heads: number of attention heads.
    max_len: forwarded to MultiHeadAttention (unused there, as the
      encoder attention is not causal).
    dropout_prob: dropout probability.
  """

  def __init__(self, d_k, d_model, n_heads, max_len, dropout_prob = 0.1):
    super().__init__()

    self.ln1 = nn.LayerNorm(d_model)
    self.ln2 = nn.LayerNorm(d_model)
    self.mha = MultiHeadAttention(d_k, d_model, n_heads, max_len, causal=False)
    # Position-wise feed-forward network with a 4x wider hidden layer.
    hidden_dim = d_model*4
    feed_forward = [
        nn.Linear(d_model, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, d_model),
        nn.Dropout(dropout_prob),
    ]
    self.ann = nn.Sequential(*feed_forward)
    self.dropout = nn.Dropout(p = dropout_prob)

  def forward(self, x, mask = None):
    """Transform x (N x T x D); mask (N x T) marks padding with 0."""
    # Self-attention: x serves as query, key and value.
    attended = self.mha(x, x, x, mask)
    x = self.ln1(x + attended)
    # Feed-forward sub-layer, again residual + LayerNorm.
    transformed = self.ann(x)
    x = self.ln2(x + transformed)
    return self.dropout(x)

# Decoder Transformer Block

This is the transformer block of the decoder. It contains 2 multihead attention blocks:

- One for the (causally masked) self-attention over the decoder input
- One for the cross-attention, in which the queries come from the decoder and the keys and values come from the output of the encoder

Structurally, it differs from the transformer block of the standalone decoder in that it requires 2 multihead attention blocks and 3 LayerNorms (compared to 1 multihead attention block and 2 LayerNorms).

In [4]:

class DecoderBlock(nn.Module):
  """One decoder layer: causal self-attention, cross-attention, feed-forward.

  Each of the three sub-layers is followed by a residual connection and
  its own LayerNorm; dropout is applied once more on the block output.

  Args:
    d_k: per-head size passed to both attention blocks.
    d_model: feature size of the sequence.
    n_heads: number of attention heads.
    max_len: size of the causal mask of the self-attention block.
    dropout_prob: dropout probability.
  """

  def __init__(self, d_k, d_model, n_heads, max_len, dropout_prob = 0.1):
    super().__init__()

    self.ln1 = nn.LayerNorm(d_model)
    self.ln2 = nn.LayerNorm(d_model)
    self.ln3 = nn.LayerNorm(d_model)
    # mha1: masked self-attention; mha2: attention over the encoder output.
    self.mha1 = MultiHeadAttention(d_k, d_model, n_heads, max_len, causal = True)
    self.mha2 = MultiHeadAttention(d_k, d_model, n_heads, max_len, causal = False)
    # Position-wise feed-forward network with a 4x wider hidden layer.
    hidden_dim = d_model*4
    feed_forward = [
        nn.Linear(d_model, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, d_model),
        nn.Dropout(dropout_prob),
    ]
    self.ann = nn.Sequential(*feed_forward)
    self.dropout = nn.Dropout(p = dropout_prob)

  def forward(self, enc_output, dec_input, enc_mask = None, dec_mask = None):
    """Run one decoder layer.

    Args:
      enc_output: encoder result, used as keys and values of mha2.
      dec_input: decoder sequence, used as query, key and value of mha1.
      enc_mask: optional padding mask for the encoder positions.
      dec_mask: optional padding mask for the decoder positions.
    """
    # Causal self-attention over the decoder sequence.
    self_attended = self.mha1(dec_input, dec_input, dec_input, dec_mask)
    x = self.ln1(dec_input + self_attended)
    # Cross-attention: decoder states query the encoder output.
    cross_attended = self.mha2(x, enc_output, enc_output, enc_mask)
    x = self.ln2(x + cross_attended)
    # Feed-forward sub-layer.
    transformed = self.ann(x)
    x = self.ln3(x + transformed)
    return self.dropout(x)

# Positional Encoding Block

In [5]:
#This is exactly the same as the standalone encoder or decoder
class PositionalEncoding(nn.Module):
  """Add fixed sinusoidal positional encodings to a sequence of embeddings.

  The table follows
    PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
  and is precomputed for max_len positions.

  Args:
    d_model: feature size of the embeddings (even or odd).
    max_len: number of positions to precompute.
    dropout_prob: dropout probability applied after the addition.
  """

  def __init__(self, d_model, max_len = 2048, dropout_prob = 0.1):
    super().__init__()
    self.dropout = nn.Dropout(p = dropout_prob)
    #unsqueeze(1) appends a dim of size 1, giving a 2d array of size
    #(max_len, 1); position is the 'pos' variable in the formula
    position = torch.arange(max_len).unsqueeze(1)
    #exp_term is the '2i' in the exponent of the denominator in the formula
    exp_term = torch.arange(0, d_model, 2)
    #this is just the term 10000^(-2i/d_model)
    div_term = torch.exp(exp_term*(-math.log(10000.0)/d_model))
    #argument of sin/cos, of size (max_len, ceil(d_model/2))
    angles = position*div_term
    #PE term
    pe = torch.zeros(1, max_len, d_model)
    #0::2 selects the even feature indices 0, 2, 4, ...
    pe[0, :, 0::2] = torch.sin(angles)
    #1::2 selects the odd feature indices 1, 3, 5, ...
    #For an odd d_model there is one fewer odd index than even index, so
    #the angles are cropped to d_model//2 columns (a no-op when d_model
    #is even); without the crop the assignment fails on a shape mismatch.
    pe[0, :, 1::2] = torch.cos(angles[:, :d_model//2])
    #register_buffer stores pe in the state dict and moves it with the
    #module, without making it a trainable parameter
    self.register_buffer('pe', pe)

  def forward(self, x):
    """Return dropout(x + PE) for x of shape N x T x D (D: d_model).

    T must not exceed the max_len the table was built with.
    """
    x  = x + self.pe[:,:x.size(1), :]
    return self.dropout(x)


# Encoder block

In [6]:
class Encoder(nn.Module):
  """Transformer encoder: embedding, positional encoding, encoder blocks.

  The output keeps one d_model-sized vector per input token (no pooling
  and no classification head), so that the decoder can attend over it.

  Args:
    vocab_size: number of rows of the embedding table.
    max_len: maximum sequence length (positional encoding table size).
    d_k: per-head size of the attention blocks.
    d_model: embedding / feature size.
    n_heads: number of attention heads per block.
    n_layers: number of stacked EncoderBlock layers.
    dropout_prob: dropout probability used throughout.
  """

  def __init__(self, vocab_size,
               max_len, d_k, d_model, n_heads, n_layers,  dropout_prob):
    super().__init__()

    self.embedding = nn.Embedding(vocab_size, d_model)
    self.pos_encoding = PositionalEncoding(d_model, max_len, dropout_prob)
    # EncoderBlock expects (d_k, d_model, n_heads, max_len, dropout_prob).
    # max_len has to be passed explicitly: leaving it out would put
    # dropout_prob in the max_len slot and the blocks would silently fall
    # back to their default dropout probability.
    transformers_blocks = [EncoderBlock(d_k, d_model, n_heads, max_len, dropout_prob)
                           for _ in range(n_layers)]
    # The container is only iterated over in forward (each block also needs
    # the mask), it is never called as a whole.
    self.transformer_blocks = nn.Sequential(*transformers_blocks)
    self.ln = nn.LayerNorm(d_model)

  def forward(self, x, pad_mask = None):
    """Encode a batch of token ids.

    Args:
      x: token ids of shape N x T.
      pad_mask: optional N x T mask, 0 marks positions to ignore.

    Returns:
      Tensor of shape N x T x d_model.
    """
    x = self.embedding(x)
    x = self.pos_encoding(x)
    for block in self.transformer_blocks:
      x = block(x, pad_mask)

    # All T positions are kept and normalised.
    x = self.ln(x)

    return x

# Decoder block

In [7]:
class Decoder(nn.Module):
  """Transformer decoder: embedding, positional encoding, decoder blocks
  and a final projection to vocabulary logits.

  Args:
    vocab_size: size of the target vocabulary (embedding rows and number
      of output logits).
    max_len: maximum sequence length (positional table and causal mask).
    d_k: per-head size of the attention blocks.
    d_model: embedding / feature size.
    n_heads: number of attention heads per block.
    n_layers: number of stacked DecoderBlock layers.
    dropout_prob: dropout probability used throughout.
  """

  def __init__(self, vocab_size, max_len, d_k, d_model, n_heads, n_layers, dropout_prob):
    super().__init__()

    self.embedding = nn.Embedding(vocab_size, d_model)
    self.pos_encoding = PositionalEncoding(d_model, max_len, dropout_prob)
    layers = []
    for _ in range(n_layers):
      layers.append(DecoderBlock(d_k, d_model, n_heads, max_len, dropout_prob))
    # Only iterated over in forward, since every block takes several inputs.
    self.transformer_blocks = nn.Sequential(*layers)
    self.ln = nn.LayerNorm(d_model)
    self.fc = nn.Linear(d_model, vocab_size)

  def forward(self, enc_output, dec_input, enc_mask = None, dec_mask = None):
    """Decode a batch of target token ids given the encoder output.

    Args:
      enc_output: encoder result, attended over by every block.
      dec_input: target token ids of shape N x T.
      enc_mask: optional padding mask for the encoder positions.
      dec_mask: optional padding mask for the decoder positions.

    Returns:
      Logits of shape N x T x vocab_size (one prediction per position).
    """
    hidden = self.pos_encoding(self.embedding(dec_input))
    for layer in self.transformer_blocks:
      hidden = layer(enc_output, hidden, enc_mask, dec_mask)
    # many-to-many task: every position is projected to the vocabulary
    return self.fc(self.ln(hidden))

# Transformer Class with both Encoder & Decoder

In [8]:
class Transformer(nn.Module):
  """Sequence-to-sequence model chaining an encoder and a decoder.

  Args:
    encoder: module called as encoder(enc_input, enc_mask).
    decoder: module called as
      decoder(enc_output, dec_input, enc_mask, dec_mask).
  """

  def __init__(self, encoder, decoder):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder

  def forward(self, enc_input, dec_input, enc_mask = None, dec_mask = None):
    """Encode enc_input, then decode dec_input against the encoding.

    Both masks are optional and default to None, mirroring the forward
    signatures of Encoder and Decoder. The encoder mask is passed to the
    decoder as well, where it hides padded encoder positions from the
    cross-attention.

    Returns:
      Whatever the decoder returns.
    """
    enc_output = self.encoder(enc_input, enc_mask)
    dec_output = self.decoder(enc_output, dec_input, enc_mask, dec_mask)
    return dec_output


# Test the model with a dummy input

Recall that this is a many-to-many model: for text generation, the decoder outputs a prediction (logits over the target vocabulary) for every position of its input sequence.

In [9]:
# Hyperparameters shared by the encoder and the decoder; only the
# vocabulary sizes (source: 20000, target: 10000) differ.
shared_hparams = dict(max_len = 512,
                      d_k = 16,
                      d_model = 64,
                      n_heads = 4,
                      n_layers = 2,
                      dropout_prob = 0.1)

encoder = Encoder(vocab_size = 20000, **shared_hparams)
decoder = Decoder(vocab_size = 10000, **shared_hparams)
# The full model wraps these two objects (it does not copy them).
transformer = Transformer(encoder, decoder)

In [10]:
# Use the first CUDA GPU when one is available, otherwise the CPU.
device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(device)
# Module.to moves the modules in place; transformer holds these same two
# objects, so it ends up on the device as well.
encoder.to(device)
decoder.to(device)

cuda:0


Decoder(
  (embedding): Embedding(10000, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): DecoderBlock(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha1): MultiHeadAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
      )
      (mha2): MultiHeadAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
 

In [21]:
#Dummy token ids and padding masks for a smoke test of the model
batch_size = 8
enc_len = 512   #encoder sequence length
dec_len = 256   #decoder sequence length

#Encoder input: random ids drawn from the encoder vocabulary (20000)
x_en = np.random.randint(0, 20000, size = (batch_size, enc_len))
x_ent = torch.tensor(x_en).to(device)
#Decoder input: random ids drawn from the decoder vocabulary (10000)
x_dc = np.random.randint(0, 10000, size = (batch_size, dec_len))
x_dct = torch.tensor(x_dc).to(device)

#Masks: 1 = real token, 0 = padding. The second half of every sequence
#is treated as padding.
maske = np.ones((batch_size, enc_len))
maske[:, enc_len//2:] = 0
maske_t = torch.tensor(maske).to(device)

maskd = np.ones((batch_size, dec_len))
maskd[:, dec_len//2:] = 0
maskd_t = torch.tensor(maskd).to(device)


In [22]:
#Run the encoder on its own: it takes the encoder token ids and the
#encoder padding mask (see the forward method in the Encoder class).
#The result has one d_model-sized vector per input token: (8, 512, 64).
out1 = encoder(x_ent, maske_t)
out1.shape

torch.Size([8, 512, 64])

In [23]:
# Run the decoder on its own, feeding it the encoder output computed above.
# Decoder takes as inputs: Encoder output, Decoder input, Encoder mask, Decoder mask
# forward method of decoder: (self, enc_output, dec_input, enc_mask = None, dec_mask = None)
# The result holds one logit per target-vocabulary entry and position: (8, 256, 10000).
out2 = decoder(out1, x_dct, maske_t, maskd_t)
out2.shape

torch.Size([8, 256, 10000])

In [24]:
# Same computation in a single call: the Transformer wrapper runs the
# encoder and then the decoder, so the shape matches the decoder-only run.
out = transformer(x_ent, x_dct, maske_t, maskd_t)
out.shape

torch.Size([8, 256, 10000])

In [25]:
# Display the raw output logits of the full model.
out

tensor([[[-0.8787, -0.2507, -1.2625,  ..., -0.0921,  0.3364,  1.3597],
         [ 0.0409,  0.1334,  0.3213,  ...,  0.3628, -0.1133,  0.7701],
         [-0.4443,  0.8256, -0.1874,  ..., -0.2034,  0.3263,  0.7999],
         ...,
         [ 0.5611,  0.4037,  0.3164,  ..., -0.7057, -0.0457,  0.0446],
         [ 0.3010,  0.1016, -0.3914,  ..., -0.3022,  0.1346, -0.5488],
         [ 0.6013,  0.4361, -0.3773,  ...,  0.0678, -0.4806, -0.6275]],

        [[ 0.0717,  0.4641, -0.6215,  ...,  1.0713,  0.0804,  0.4144],
         [-0.4752,  0.5061, -1.3076,  ..., -0.5369,  1.0394,  1.0067],
         [ 0.4456, -0.2973, -0.2378,  ...,  0.6716,  0.1164,  0.4415],
         ...,
         [ 0.6878,  0.5668,  0.2774,  ..., -0.3084,  0.6177,  0.7214],
         [ 0.7308,  0.5590,  0.3725,  ...,  1.1489,  0.1562,  0.2182],
         [ 1.0898,  0.4759, -0.4870,  ...,  0.1069,  0.7643,  0.2209]],

        [[-0.1841, -0.5683,  0.5857,  ...,  1.2628,  0.7154, -0.3790],
         [-0.9637, -0.0904,  0.5405,  ..., -0