# Implementation of a Transformer Architecture

In [45]:
from math import sqrt

import torch
from torch import tensor, nn
import torch.nn.functional as F

In [3]:
from dataclasses import dataclass

@dataclass
class TransformerEncoderConfig:
    """Hyperparameters describing a transformer encoder.

    Attributes:
        vocab_size: Number of entries in the token embedding table.
        max_len: Maximum number of tokens in a sequence.
        dim: Dimension of the embeddings.
        num_attention_heads: Number of attention heads per encoder layer.
        num_encoder_layers: Number of stacked encoder layers.
        intermediate_dim: Dimension of the hidden layer in the feed forward
            part of an encoder layer.
        hidden_dropout_prob: Dropout probability in the feed forward part of
            an encoder layer.
        layer_norm_eps: Layer normalization epsilon in the embedding layer.
    """

    vocab_size: int
    max_len: int
    dim: int
    num_attention_heads: int
    num_encoder_layers: int
    intermediate_dim: int
    hidden_dropout_prob: float
    layer_norm_eps: float = 1e-12

In [4]:
def number_of_trainable_parameters(model: nn.Module) -> int:
    """Count the scalar entries of all parameters of ``model`` that require gradients."""
    total = 0
    for parameter in model.parameters():
        # parameters with requires_grad=False are not trained, so they are not counted
        if parameter.requires_grad:
            total += parameter.numel()
    return total


def scaled_dot_product_attention(query, key, value):
    """Compute scaled dot-product attention.

    The attention weights are ``softmax(query @ key^T / sqrt(dim_k))``, taken
    over the key axis, and are used to average the rows of ``value``.

    Args:
        query: Tensor of shape ``(..., seq_q, dim_k)``.
        key: Tensor of shape ``(..., seq_k, dim_k)``.
        value: Tensor of shape ``(..., seq_k, dim_v)``.

    Returns:
        Tensor of shape ``(..., seq_q, dim_v)``.
    """
    dim_k = query.size(-1)
    # matmul with transpose(-2, -1) batches over any number of leading
    # dimensions (e.g. batch and head axes); bmm only accepts 3-D tensors.
    scores = torch.matmul(query, key.transpose(-2, -1)) / sqrt(dim_k)
    weights = F.softmax(scores, dim=-1)

    return torch.matmul(weights, value)

In [5]:
class AttentionHead(nn.Module):
    """Single self-attention head.

    Projects the hidden state to queries, keys and values with three separate
    linear layers and applies scaled dot-product attention to the projections.
    """

    def __init__(self, embedding_dim: int, head_dim: int):
        super().__init__()

        # all three projections map embedding_dim -> head_dim
        projection_shape = (embedding_dim, head_dim)
        self.q = nn.Linear(*projection_shape)
        self.k = nn.Linear(*projection_shape)
        self.v = nn.Linear(*projection_shape)

    def forward(self, hidden_state):
        """Attend over ``hidden_state`` using its own query/key/value projections."""
        query = self.q(hidden_state)
        key = self.k(hidden_state)
        value = self.v(hidden_state)
        return scaled_dot_product_attention(query, key, value)

In [26]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention.

    Runs ``config.num_attention_heads`` independent attention heads, each of
    size ``config.dim // config.num_attention_heads``, concatenates their
    outputs along the last dimension and mixes them with a final linear layer.

    Raises:
        ValueError: If ``config.num_attention_heads`` is not positive or does
            not evenly divide ``config.dim``.
    """

    def __init__(self, config: TransformerEncoderConfig):
        super().__init__()

        # Validate with explicit exceptions: a bare assert is stripped under
        # `python -O`, which would let an inconsistent config through.
        if config.num_attention_heads <= 0:
            raise ValueError(
                f"num_attention_heads must be positive, got {config.num_attention_heads}"
            )
        head_dim, remainder = divmod(config.dim, config.num_attention_heads)
        if remainder != 0:
            raise ValueError(
                f"dim ({config.dim}) must be divisible by "
                f"num_attention_heads ({config.num_attention_heads})"
            )

        self.heads = nn.ModuleList(
            [AttentionHead(config.dim, head_dim) for _ in range(config.num_attention_heads)]
        )
        # the concatenated head outputs have size head_dim * num_heads == dim
        self.output_linear = nn.Linear(config.dim, config.dim)

    def forward(self, hidden_state):
        """Apply every head to ``hidden_state`` and project the concatenation."""
        x = torch.cat([head(hidden_state) for head in self.heads], dim=-1)
        x = self.output_linear(x)
        return x

In [27]:
class FeedForward(nn.Module):
    """Feed forward block: linear -> GELU -> linear -> dropout.

    The first linear layer expands from ``config.dim`` to
    ``config.intermediate_dim``; the second one maps back to ``config.dim``.
    """

    def __init__(self, config: TransformerEncoderConfig):
        super().__init__()

        model_dim, hidden_dim = config.dim, config.intermediate_dim
        self.linear_1 = nn.Linear(model_dim, hidden_dim)
        self.linear_2 = nn.Linear(hidden_dim, model_dim)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        """Return the feed forward transformation of ``x``."""
        activated = self.gelu(self.linear_1(x))
        return self.dropout(self.linear_2(activated))

In [32]:
class TransformerEncoderLayer(nn.Module):
    """One encoder layer: self-attention and feed forward, each with a skip connection.

    Layer normalization is applied before each sub-layer (pre layer
    normalization); the skip connections branch off before the normalization.
    """

    def __init__(self, config: TransformerEncoderConfig):
        super().__init__()

        self.layer_norm_1 = nn.LayerNorm(config.dim)
        self.layer_norm_2 = nn.LayerNorm(config.dim)
        self.attention = MultiHeadAttention(config)
        self.feed_forward = FeedForward(config)

    def forward(self, x):
        """Return ``x`` after the attention and feed forward residual updates."""
        # the residual carries the un-normalized input; only the sub-layer
        # sees the normalized one
        after_attention = x + self.attention(self.layer_norm_1(x))
        return after_attention + self.feed_forward(self.layer_norm_2(after_attention))

In [33]:
class Embeddings(nn.Module):
    """Token plus learned position embeddings, followed by layer norm and dropout."""

    def __init__(self, config: TransformerEncoderConfig):
        super().__init__()

        self.token_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.position_embeddings = nn.Embedding(config.max_len, config.dim)
        self.layer_norm = nn.LayerNorm(config.dim, eps=config.layer_norm_eps)
        # NOTE(review): nn.Dropout() uses its default p=0.5 here; the config's
        # hidden_dropout_prob is documented as feed-forward only and is not
        # used - confirm that 0.5 is intended for the embeddings.
        self.dropout = nn.Dropout()

    def forward(self, input_ids):
        """Embed ``input_ids``.

        Args:
            input_ids: Integer tensor of token ids; dimension 1 is read as the
                sequence length, which must not exceed the size of the
                position embedding table.

        Returns:
            The normalized sum of token and position embeddings, with dropout
            applied.
        """
        # create position IDs on the same device as the input so the position
        # lookup also works when the model and its inputs live on an accelerator
        seq_length = input_ids.size(1)
        position_ids = torch.arange(
            seq_length, dtype=torch.long, device=input_ids.device
        ).unsqueeze(0)

        token_embeddings = self.token_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        # position embeddings have a leading dimension of 1 and broadcast
        # against the token embeddings
        embeddings = token_embeddings + position_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings

In [34]:
class TransformerEncoder(nn.Module):
    """Embedding layer followed by a stack of ``config.num_encoder_layers`` encoder layers."""

    def __init__(self, config: TransformerEncoderConfig):
        super().__init__()

        self.embeddings = Embeddings(config)
        self.layers = nn.ModuleList()
        for _ in range(config.num_encoder_layers):
            self.layers.append(TransformerEncoderLayer(config))

    def forward(self, x):
        """Embed the token ids ``x`` and pass the result through every layer in order."""
        hidden_state = self.embeddings(x)
        for encoder_layer in self.layers:
            hidden_state = encoder_layer(hidden_state)
        return hidden_state

In [60]:
# Build an encoder from a fixed configuration and report its parameter counts.
# NOTE(review): the values presumably mirror the BERT-base hyperparameters - confirm.
BERT_config = TransformerEncoderConfig(
    vocab_size=30_522,
    max_len=512,
    dim=768,
    num_attention_heads=12,
    num_encoder_layers=12,
    intermediate_dim=3072,  # usually 4 times dim
    hidden_dropout_prob=0.1,
    layer_norm_eps=1e-12,
)

# `encoder` is reused by the example cell below.
encoder = TransformerEncoder(BERT_config)
embeddings = encoder.embeddings
# every encoder layer is built from the same config, so the first one is representative
first_encoder_layer = encoder.layers[0]
print(f"Number of trainable parameters: {number_of_trainable_parameters(encoder):,}")
print(f"Number of trainable parameters in embedding layer: {number_of_trainable_parameters(embeddings):,}")
print(f"Number of trainable parameters per encoder layer: {number_of_trainable_parameters(first_encoder_layer):,}")

Number of trainable parameters: 108,890,112
Number of trainable parameters in embedding layer: 23,835,648
Number of trainable parameters per encoder layer: 7,087,872


In [51]:
# Run a batch containing one sequence of five token ids through the encoder.
# NOTE(review): the encoder is never switched to eval() in the visible cells, so
# dropout is presumably still active and the values may differ between runs;
# only the output size is inspected here.
example_input = tensor([[2051, 10029, 2066, 2019, 8612]])
encoded_example_input = encoder(example_input)
# last expression of the cell: the notebook displays the output size
encoded_example_input.size()

torch.Size([1, 5, 768])