<a href="https://colab.research.google.com/github/azilya/torch_tutorials/blob/main/PyTorch_Transformer_Example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Building a basic transformer model to better understand its structure

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

## Basic building blocks: multi-head attention, positional embeddings, position-wise ff

### Multi-Head Attention

The core component of transformer models: it computes attention between input tokens (how much each token influences the others).
Multiple heads compute several attention patterns for the same input, each focusing on different features.

The forward method of the torch module does the following:
1. projects the inputs into query, key and value vectors with linear layers,
1. reshapes each projection according to the number of heads,
1. calculates scaled dot-product attention,
1. recombines the outputs of the heads and applies the output projection.

In [2]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  The inputs are projected into query, key and value tensors, split into
  `num_heads` heads of size `dim_model // num_heads`, attended per head,
  recombined and passed through a final output projection.

  Args:
    num_heads: number of attention heads.
    dim_model: model dimension; must be divisible by `num_heads`.

  Raises:
    AssertionError: if `dim_model` is not divisible by `num_heads`.
  """

  def __init__(self, num_heads, dim_model):
    super().__init__()
    assert dim_model % num_heads == 0, "Model dimensions must be divisible by the number of heads"

    # init vector dimensions
    self.dim_model = dim_model  # model dimensions
    self.num_heads = num_heads  # number of heads
    self.sub_dim = dim_model // num_heads  # dimensions of Q, K, V vectors

    # init linear layers
    self.W_q = nn.Linear(dim_model, dim_model)  # Q transform
    self.W_k = nn.Linear(dim_model, dim_model)  # K transform
    self.W_v = nn.Linear(dim_model, dim_model)  # V transform
    self.W_o = nn.Linear(dim_model, dim_model)  # output transform

  def scaled_dot_prod_attention(
      self,
      Q: torch.Tensor,
      K: torch.Tensor,
      V: torch.Tensor,
      mask: torch.Tensor = None
    ):
    """Compute softmax(Q K^T / sqrt(sub_dim)) V for already split heads.

    Args:
      Q, K, V: per-head tensors as produced by `split_heads`.
      mask: optional tensor broadcastable to the attention scores;
        positions where `mask == 0` are excluded from attention.

    Returns:
      The attention-weighted values, one slice per head.
    """
    # attention scores: Q times K transposed, scaled by the root of the head size
    attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.sub_dim)

    if mask is not None:
      # Fill masked positions with the most negative finite value of the
      # score dtype. A hard-coded -1e9 is not representable in float16 and
      # breaks half-precision runs; in float32 the softmax result is the same.
      fill_value = torch.finfo(attn_scores.dtype).min
      attn_scores = attn_scores.masked_fill(mask == 0, fill_value)

    # softmax over the key axis turns scores into probabilities
    attn_probs = torch.softmax(attn_scores, dim=-1)

    # weight the values by the attention probabilities
    return torch.matmul(attn_probs, V)

  def split_heads(self, x: torch.Tensor):
    """Reshape (batch, seq_len, dim_model) into (batch, num_heads, seq_len, sub_dim)."""
    batch_size, seq_len, _ = x.size()
    return x.view(batch_size, seq_len, self.num_heads, self.sub_dim).transpose(1, 2)

  def combine_heads(self, x: torch.Tensor):
    """Inverse of `split_heads`: merge heads back into (batch, seq_len, dim_model)."""
    batch_size, _, seq_len, _ = x.size()
    # transpose yields a non-contiguous tensor; view requires a contiguous one
    return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.dim_model)

  def forward(self,
      Q: torch.Tensor,
      K: torch.Tensor,
      V: torch.Tensor,
      mask: torch.Tensor = None
    ):
    """Apply multi-head attention.

    Args:
      Q, K, V: tensors of shape (batch, seq_len, dim_model); for
        self-attention the same tensor is passed three times.
      mask: optional mask forwarded to `scaled_dot_prod_attention`.

    Returns:
      Tensor of shape (batch, query_seq_len, dim_model).
    """
    # apply transformations and split heads
    Q = self.split_heads(self.W_q(Q))
    K = self.split_heads(self.W_k(K))
    V = self.split_heads(self.W_v(V))

    # calculate attention per head
    attn_output = self.scaled_dot_prod_attention(Q, K, V, mask)
    # combine head outputs and apply final transformation
    return self.W_o(self.combine_heads(attn_output))

### Position-wise Feed-Forward Network

Transformation, applied separately to each position of attention outputs. Two FF layers with a ReLU activation between them.

In [3]:
class PositionWiseFF(nn.Module):
  """Two linear layers with a ReLU in between, applied to the last axis.

  Args:
    dim_model: size of the input and output features.
    dim_ff: size of the hidden layer.
  """

  def __init__(self, dim_model, dim_ff):
    super().__init__()
    # expand to the hidden size, then project back to the model size
    self.layer1 = nn.Linear(in_features=dim_model, out_features=dim_ff)
    self.layer2 = nn.Linear(in_features=dim_ff, out_features=dim_model)
    self.relu = nn.ReLU()

  def forward(self, x):
    """Return layer2(relu(layer1(x)))."""
    hidden = self.layer1(x)
    activated = self.relu(hidden)
    return self.layer2(activated)

### Positional Embedding

Sine/cosine function, encoding each token's position in the input. This helps the model to consider relative positions of tokens in the sequence, in addition to their attention to each other.

In [4]:
class PositionalEmbedding(nn.Module):
  """Add fixed sine/cosine positional encodings to a batch-first input.

  Even feature columns hold sines and odd columns hold cosines of the
  position scaled by geometrically decreasing frequencies.

  Args:
    dim_model: size of the feature axis of the inputs.
    max_seq_len: number of positions to precompute encodings for.
    dropout: dropout probability applied after adding the encodings.
  """

  def __init__(
      self,
      dim_model: int,
      max_seq_len: int = 4096,
      dropout: float = 0.5
    ):
    super().__init__()
    self.dropout = nn.Dropout(dropout)

    position = torch.arange(max_seq_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dim_model, 2) * -(math.log(10_000.0) / dim_model))
    angles = position * div_term  # one column per even feature index

    pe = torch.zeros(max_seq_len, dim_model)
    pe[:, 0::2] = torch.sin(angles)
    # an odd dim_model has one odd column fewer than even columns,
    # so drop the surplus cosine column instead of failing on assignment
    pe[:, 1::2] = torch.cos(angles[:, :dim_model // 2])

    # register as a buffer: stored with the module and moved across devices,
    # but not a trainable parameter
    self.register_buffer('pe', pe.unsqueeze(0))

  def forward(self, x):
    """Add encodings for the first `x.size(1)` positions, then apply dropout.

    Args:
      x: tensor of shape (batch, seq_len, dim_model), seq_len <= max_seq_len.
    """
    return self.dropout(x + self.pe[:, :x.size(1)])

## Encoder Layer

A single layer of the Transformer model, encoding the input. It encapsulates the multi-head attention block and the feed-forward network described above, adding a residual connection and a LayerNorm to each of their outputs. Thus the relations within the input can be efficiently encoded as embeddings that can be used for other tasks.

Generally, models consist of several encoder layers, with the output of each layer being fed into the next one.

Layer workings:
1. Multi-head attention block
1. Add + Layer normalization 1
1. Position-wise feed-forward network
1. Add + Layer normalization 2

Dropout is applied to layer outputs before normalization.

In [5]:
class EncoderLayer(nn.Module):
  """One encoder block: self-attention and feed-forward sub-layers.

  Each sub-layer output goes through dropout, is added to the sub-layer
  input and is then layer-normalized.

  Args:
    dim_model: size of the feature axis.
    num_heads: number of attention heads.
    dim_ff: hidden size of the feed-forward network.
    dropout: dropout probability applied to both sub-layer outputs.
  """

  def __init__(self, dim_model: int, num_heads: int, dim_ff: int, dropout: float):
    super().__init__()
    self.self_attention = MultiHeadAttention(num_heads=num_heads, dim_model=dim_model)
    self.ff = PositionWiseFF(dim_model=dim_model, dim_ff=dim_ff)
    self.norm1 = nn.LayerNorm(normalized_shape=dim_model)
    self.norm2 = nn.LayerNorm(normalized_shape=dim_model)
    self.dropout = nn.Dropout(p=dropout)

  def forward(self, x, mask=None):
    """Encode `x`; `mask` is forwarded to the self-attention block."""
    # self-attention sub-layer: queries, keys and values are all `x`
    attended = self.dropout(self.self_attention(x, x, x, mask))
    hidden = self.norm1(x + attended)

    # feed-forward sub-layer
    transformed = self.dropout(self.ff(hidden))
    return self.norm2(hidden + transformed)

## Transformer Model

Several encoder layers can be stacked together to create an encoder transformer. This model produces a more meaningful embedding of the input than a single layer, which can subsequently be used for downstream tasks, e.g. generation, classification etc.

In [6]:
class TransformerModel(nn.Module):
  """Encoder-only transformer built from the blocks defined in this file.

  Token ids are embedded, combined with positional encodings, passed
  through `n_layers` independent encoder layers and projected to
  vocabulary-sized logits.

  Args:
    vocab_size: number of tokens in the vocabulary.
    dim_model: size of the embeddings and of every layer's features.
    dim_ff: hidden size of the feed-forward network in each layer.
    n_heads: number of attention heads in each layer.
    n_layers: number of stacked encoder layers.
    dropout: dropout probability used throughout the model.
  """

  def __init__(
      self,
      vocab_size: int,
      dim_model: int,
      dim_ff: int,
      n_heads: int,
      n_layers: int,
      dropout: float = 0.5
    ):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, dim_model)
    self.pos_embedding = PositionalEmbedding(dim_model, dropout=dropout)
    # build a separate EncoderLayer per position in the stack; repeating one
    # instance would make every layer share the same weights
    self.encoder_layers = nn.ModuleList(
        [EncoderLayer(dim_model, n_heads, dim_ff, dropout) for _ in range(n_layers)]
    )
    self.dropout = nn.Dropout(dropout)
    self.linear = nn.Linear(dim_model, vocab_size)

    self.init_weights()

  def init_weights(self) -> None:
    """Initialize embedding and output weights uniformly in [-0.1, 0.1]; zero the output bias."""
    initrange = 0.1
    self.embedding.weight.data.uniform_(-initrange, initrange)
    self.linear.bias.data.zero_()
    self.linear.weight.data.uniform_(-initrange, initrange)

  def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
    """Return vocabulary logits for a batch of token ids.

    Args:
      x: integer tensor of shape (batch, seq_len) with token ids.
      mask: optional attention mask forwarded to every encoder layer.

    Returns:
      Tensor of shape (batch, seq_len, vocab_size).
    """
    # NOTE(review): pos_embedding already applies dropout, so the input is
    # dropped out twice here; kept as in the original, confirm it is intended.
    x_enc = self.dropout(self.pos_embedding(self.embedding(x)))
    for enc_layer in self.encoder_layers:
      # feed each layer the output of the previous one
      x_enc = enc_layer(x_enc, mask)
    return self.linear(x_enc)

## The best part

All Transformer parts can actually be imported directly from PyTorch.

In [None]:
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class TransformerModel(nn.Module):
  """Encoder-only transformer assembled from torch.nn building blocks.

  Args:
    vocab_size: number of tokens in the vocabulary.
    dim_model: size of the embeddings and of the encoder features.
    dim_ff: hidden size of the feed-forward network in each layer.
    n_heads: number of attention heads in each layer.
    n_layers: number of stacked encoder layers.
    dropout: dropout probability for the positional embedding and the encoder.
  """

  def __init__(self, vocab_size: int, dim_model: int, dim_ff: int,
               n_heads: int, n_layers: int, dropout: float):
    super().__init__()
    self.model_type = 'Transformer'
    # dropout must go by keyword: the second positional parameter of
    # PositionalEmbedding is max_seq_len
    self.pos_embedding = PositionalEmbedding(dim_model, dropout=dropout)
    # PositionalEmbedding treats axis 1 as the sequence axis, so the encoder
    # has to work on (batch, seq_len, dim_model) inputs as well
    encoder_layer = TransformerEncoderLayer(
        dim_model, n_heads, dim_ff, dropout, batch_first=True
    )
    self.encoder_layers = TransformerEncoder(encoder_layer, n_layers)
    self.embedding = nn.Embedding(vocab_size, dim_model)
    self.dim_model = dim_model
    self.linear = nn.Linear(dim_model, vocab_size)

    self.init_weights()

  def init_weights(self) -> None:
    """Initialize embedding and output weights uniformly in [-0.1, 0.1]; zero the output bias."""
    initrange = 0.1
    self.embedding.weight.data.uniform_(-initrange, initrange)
    self.linear.bias.data.zero_()
    self.linear.weight.data.uniform_(-initrange, initrange)

  def forward(self, x: torch.Tensor,
              mask: torch.Tensor = None) -> torch.Tensor:
    """Return vocabulary logits for a batch of token ids.

    Args:
      x: integer tensor of shape (batch, seq_len) with token ids.
      mask: optional attention mask, passed as the `mask` argument of
        `nn.TransformerEncoder`.

    Returns:
      Tensor of shape (batch, seq_len, vocab_size).
    """
    x = self.pos_embedding(self.embedding(x))
    x_enc = self.encoder_layers(x, mask)
    return self.linear(x_enc)