# Transformer Implementation
- Paper: Attention Is All You Need
- Template source: https://goyalpramod.github.io/blogs/Transformers_laid_out/


## Imports

In [1]:
import math

import torch
import torch.nn as nn
from torch.nn.functional import softmax

## Multi-Head Attention

In [2]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The model dimension is split evenly across ``num_heads`` heads, each of
    size ``d_k = d_model // num_heads``. Queries, keys and values are
    projected with separate linear layers, attended per head, concatenated
    back to ``d_model`` and passed through a final output projection.

    Args:
        d_model: model (embedding) dimension; must be divisible by num_heads
        num_heads: number of attention heads
    """
    def __init__(self, d_model, num_heads):
        super().__init__()
        # The heads partition d_model, so it has to divide evenly.
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # One d_model -> d_model projection per role; the result is later
        # reshaped into num_heads chunks of size d_k.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    @staticmethod
    def scaled_dot_product_attention(query, key, value, mask=None):
        """Compute softmax(Q K^T / sqrt(d_k)) V.

        Args:
            query: (batch_size, num_heads, seq_len_q, d_k)
            key: (batch_size, num_heads, seq_len_k, d_k)
            value: (batch_size, num_heads, seq_len_k, d_v)
            mask: Optional tensor broadcastable to
                (batch_size, num_heads, seq_len_q, seq_len_k). Positions
                where ``mask == 0`` are excluded from attention.
        Returns:
            Tuple ``(output, attention_weights)`` with shapes
            (batch_size, num_heads, seq_len_q, d_v) and
            (batch_size, num_heads, seq_len_q, seq_len_k).
        """
        d_k = query.size(-1)

        # (.., seq_len_q, d_k) @ (.., d_k, seq_len_k) -> (.., seq_len_q, seq_len_k)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

        # A large negative score becomes ~0 probability after the softmax.
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # Normalise over the key axis only.
        attention_weights = softmax(scores, dim=-1)

        return torch.matmul(attention_weights, value), attention_weights

    def forward(self, query, key, value, mask=None):
        """Apply multi-head attention.

        Args:
            query: (batch_size, seq_len_q, d_model)
            key: (batch_size, seq_len_k, d_model)
            value: (batch_size, seq_len_k, d_model)
            mask: Optional mask forwarded to scaled_dot_product_attention
        Returns:
            Tensor of shape (batch_size, seq_len_q, d_model)
        """
        batch_size, seq_len_q, _ = query.size()

        # 1. Linear projections
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        # 2. Split into heads -> (batch_size, num_heads, seq_len, d_k).
        # Key/value length is inferred (-1) because in cross-attention it
        # differs from the query length.
        Q = Q.view(batch_size, seq_len_q, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # 3. Apply attention; the weights are not needed here.
        output, _ = self.scaled_dot_product_attention(Q, K, V, mask)

        # 4. Concatenate heads back to (batch_size, seq_len_q, d_model)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len_q, self.d_model)

        # 5. Final projection
        return self.W_o(output)

## Feed-Forward Network

In [3]:
class FeedForwardNetwork(nn.Module):
    """Two-layer position-wise MLP applied to the last dimension.

    The pipeline is Linear(d_model -> d_ff), ReLU, Dropout,
    Linear(d_ff -> d_model).

    Args:
        d_model: size of the input and output feature dimension
        d_ff: size of the hidden (expansion) dimension
        dropout: dropout probability used after the ReLU (default=0.1)
    """
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Stages are created in pipeline order so the parameters keep the
        # names ff.0.* and ff.3.* inside the Sequential container.
        expand = nn.Linear(d_model, d_ff)
        activation = nn.ReLU()
        regularize = nn.Dropout(dropout)
        project = nn.Linear(d_ff, d_model)
        self.ff = nn.Sequential(expand, activation, regularize, project)

    def forward(self, x):
        """Run the MLP over the feature (last) dimension of ``x``.

        Args:
            x: Input tensor of shape (batch_size, seq_len, d_model)
        Returns:
            Output tensor of shape (batch_size, seq_len, d_model)
        """
        pipeline = self.ff
        return pipeline(x)

## Positional Encoding

In [4]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input.

    Even feature indices hold ``sin(pos * w_i)`` and odd indices hold
    ``cos(pos * w_i)`` with ``w_i = 10000^(-2i / d_model)``. The table is
    precomputed for ``max_seq_length`` positions and stored as the buffer
    ``pe`` of shape (1, max_seq_length, d_model).

    Args:
        d_model: feature dimension (even or odd)
        max_seq_length: number of positions to precompute (default=5000)
    """
    def __init__(self, d_model, max_seq_length=5000):
        super().__init__()

        # Table of shape (max_seq_length, d_model)
        pe = torch.zeros(max_seq_length, d_model)

        # Column vector of positions: (max_seq_length, 1)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)

        # One frequency per even feature index: ceil(d_model / 2) entries
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))

        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even
        # columns, so trim the angles to the number of cosine slots.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # A buffer follows the module across devices and is saved in the
        # state dict, but is not a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the encodings of the first ``seq_len`` positions to ``x``.

        Args:
            x: Tensor shape (batch_size, seq_len, d_model)
        Returns:
            Tensor of the same shape as ``x``
        """
        seq_len = x.size(1)
        # The buffer was built outside autograd, so no detach is required.
        return x + self.pe[:, :seq_len, :]

## Encoder Layer

In [5]:
class EncoderLayer(nn.Module):
    """Single encoder block: self-attention and feed-forward sub-layers.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        d_model: model dimension
        num_heads: number of attention heads
        d_ff: hidden dimension of the feed-forward network
        dropout: dropout probability (default=0.1)
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        # 1. Multi-head self-attention
        self.mha = MultiHeadAttention(d_model, num_heads)

        # 2. Layer normalization after the attention sub-layer
        self.layer_norm1 = nn.LayerNorm(d_model)

        # 3. Position-wise feed forward
        self.ff = FeedForwardNetwork(d_model, d_ff, dropout)

        # 4. Layer normalization after the feed-forward sub-layer
        self.layer_norm2 = nn.LayerNorm(d_model)

        # 5. Dropout shared by both residual branches
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: Input tensor of shape (batch_size, seq_len, d_model)
            mask: Optional mask for padding
        Returns:
            x: Output tensor of shape (batch_size, seq_len, d_model)
        """
        # 1. Self-attention. Dropout is applied to the sub-layer output
        # only, so the residual (skip) path is never zeroed out.
        att_output = self.mha(x, x, x, mask)
        x = self.layer_norm1(x + self.dropout(att_output))

        # 2. Feed forward with the same residual + norm wrapping
        ff_output = self.ff(x)
        x = self.layer_norm2(x + self.dropout(ff_output))

        return x

## Decoder Layer

In [6]:
class DecoderLayer(nn.Module):
    """Single decoder block: masked self-attention, cross-attention, FFN.

    Each of the three sub-layers is wrapped as
    ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        d_model: model dimension
        num_heads: number of attention heads
        d_ff: hidden dimension of the feed-forward network
        dropout: dropout probability (default=0.1)
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        # 1. Masked multi-head self-attention
        self.mha1 = MultiHeadAttention(d_model, num_heads)

        # 2. Layer norm for first sub-layer
        self.layer_norm1 = nn.LayerNorm(d_model)

        # 3. Cross-attention: queries come from the decoder, keys and
        # values from the encoder output
        self.mha2 = MultiHeadAttention(d_model, num_heads)

        # 4. Layer norm for second sub-layer
        self.layer_norm2 = nn.LayerNorm(d_model)

        # 5. Feed forward network
        self.ff = FeedForwardNetwork(d_model, d_ff, dropout)

        # 6. Layer norm for third sub-layer
        self.layer_norm3 = nn.LayerNorm(d_model)

        # 7. Dropout shared by all residual branches
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: Target sequence embedding (batch_size, target_seq_len, d_model)
            encoder_output: Output from encoder (batch_size, source_seq_len, d_model)
            src_mask: Mask for source padding
            tgt_mask: Mask for target padding and future positions
        Returns:
            Tensor of shape (batch_size, target_seq_len, d_model)
        """
        # 1. Masked self-attention: query, key and value are all x.
        # Dropout is applied to the sub-layer output only, so the residual
        # (skip) path is never zeroed out.
        attn_output1 = self.mha1(x, x, x, tgt_mask)
        x = self.layer_norm1(x + self.dropout(attn_output1))

        # 2. Cross-attention over the encoder output
        att_output2 = self.mha2(x, encoder_output, encoder_output, src_mask)
        x = self.layer_norm2(x + self.dropout(att_output2))

        # 3. Position-wise feed forward
        ff_output = self.ff(x)
        x = self.layer_norm3(x + self.dropout(ff_output))

        return x

## Encoder

In [7]:
class Encoder(nn.Module):
    """Token embedding, positional encoding and a stack of encoder layers.

    Args:
        vocab_size: number of entries in the embedding table
        d_model: model dimension
        num_layers: number of stacked encoder layers (default=6)
        num_heads: attention heads per layer (default=8)
        d_ff: feed-forward hidden dimension (default=2048)
        dropout: dropout probability (default=0.1)
        max_seq_length: number of precomputed positions (default=5000)
    """
    def __init__(self,
                 vocab_size,
                 d_model,
                 num_layers=6,
                 num_heads=8,
                 d_ff=2048,
                 dropout=0.1,
                 max_seq_length=5000):
        super().__init__()

        # Token lookup; embeddings are multiplied by sqrt(d_model) in forward
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.scale = math.sqrt(d_model)

        # Position information and the dropout applied right after it
        self.pe = PositionalEncoding(d_model, max_seq_length)
        self.dropout = nn.Dropout(dropout)

        # N identical layers applied one after another
        stack = [EncoderLayer(d_model, num_heads, d_ff, dropout)
                 for _ in range(num_layers)]
        self.encoder_layers = nn.ModuleList(stack)

    def forward(self, x, mask=None):
        """Encode a batch of token ids.

        Args:
            x: Input tokens (batch_size, seq_len)
            mask: Mask for padding positions
        Returns:
            encoder_output: (batch_size, seq_len, d_model)
        """
        scaled_embeddings = self.embedding(x) * self.scale
        hidden = self.dropout(self.pe(scaled_embeddings))

        for block in self.encoder_layers:
            hidden = block(hidden, mask)

        return hidden

## Decoder

In [8]:
class Decoder(nn.Module):
    """Token embedding, positional encoding and a stack of decoder layers.

    Args:
        vocab_size: number of entries in the embedding table
        d_model: model dimension
        num_layers: number of stacked decoder layers (default=6)
        num_heads: attention heads per layer (default=8)
        d_ff: feed-forward hidden dimension (default=2048)
        dropout: dropout probability (default=0.1)
        max_seq_length: number of precomputed positions (default=5000)
    """
    def __init__(self,
                 vocab_size,
                 d_model,
                 num_layers=6,
                 num_heads=8,
                 d_ff=2048,
                 dropout=0.1,
                 max_seq_length=5000):
        super().__init__()

        # Token lookup; embeddings are multiplied by sqrt(d_model) in forward
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.scale = math.sqrt(d_model)

        # Position information and the dropout applied right after it
        self.pe = PositionalEncoding(d_model, max_seq_length)
        self.dropout = nn.Dropout(dropout)

        # N identical layers applied one after another
        stack = [DecoderLayer(d_model, num_heads, d_ff, dropout)
                 for _ in range(num_layers)]
        self.decoder_layers = nn.ModuleList(stack)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """Decode a batch of target token ids against the encoder output.

        Args:
            x: Target tokens (batch_size, target_seq_len)
            encoder_output: Output from encoder (batch_size, source_seq_len, d_model)
            src_mask: Mask for source padding
            tgt_mask: Mask for target padding and future positions
        Returns:
            decoder_output: (batch_size, target_seq_len, d_model)
        """
        scaled_embeddings = self.embedding(x) * self.scale
        hidden = self.dropout(self.pe(scaled_embeddings))

        for block in self.decoder_layers:
            hidden = block(hidden, encoder_output, src_mask, tgt_mask)

        return hidden

## Utilities

In [9]:
def create_padding_mask(seq, pad_idx=0):
    """
    Create a boolean mask that is True at real tokens and False at padding.

    Positions marked False are the ones attention code that masks where
    ``mask == 0`` will exclude.

    Args:
        seq: Input sequence tensor (batch_size, seq_len)
        pad_idx: Token id used for padding (default=0)
    Returns:
        mask: Boolean padding mask (batch_size, 1, 1, seq_len)
    """
    # The two singleton axes broadcast over heads and query positions.
    return (seq != pad_idx).unsqueeze(1).unsqueeze(2)

def create_future_mask(size, device=None):
    """
    Create mask to prevent attention to future positions

    Args:
        size: Size of square mask (target_seq_len)
        device: Optional device to create the mask on (default: torch default)
    Returns:
        mask: Boolean future mask (1, 1, size, size); entry [i, j] is True
            when position i may attend to position j (j <= i)
    """
    # Ones strictly above the diagonal mark future positions; invert them.
    mask = torch.triu(torch.ones(1, 1, size, size, device=device), diagonal=1) == 0
    return mask

def create_masks(src, tgt):
    """
    Create all masks needed for training

    Args:
        src: Source sequence (batch_size, src_len)
        tgt: Target sequence (batch_size, tgt_len)
    Returns:
        src_mask: Padding mask for encoder (batch_size, 1, 1, src_len)
        tgt_mask: Combined padding and future mask for decoder
            (batch_size, 1, tgt_len, tgt_len)
    """
    # 1. Padding masks (True = real token)
    src_mask = create_padding_mask(src)
    tgt_padding_mask = create_padding_mask(tgt)

    # 2. Future mask, built on the same device as the target so the
    # bitwise AND below never mixes devices
    future_mask = create_future_mask(tgt.size(1), device=tgt.device)

    # 3. A target position is attendable only if it is a real token AND it
    # is not in the future; broadcasting expands to (batch, 1, tgt, tgt)
    tgt_mask = tgt_padding_mask & future_mask

    return src_mask, tgt_mask

## Transformer

In [10]:
class Transformer(nn.Module):
    """Encoder-decoder model with a linear projection onto the target vocabulary.

    Args:
        src_vocab_size: size of the source vocabulary
        tgt_vocab_size: size of the target vocabulary
        d_model: model dimension
        num_layers: layers in both encoder and decoder (default=6)
        num_heads: attention heads per layer (default=8)
        d_ff: feed-forward hidden dimension (default=2048)
        dropout: dropout probability (default=0.1)
        max_seq_length: number of precomputed positions (default=5000)
    """
    def __init__(self,
                 src_vocab_size,
                 tgt_vocab_size,
                 d_model,
                 num_layers=6,
                 num_heads=8,
                 d_ff=2048,
                 dropout=0.1,
                 max_seq_length=5000):
        super().__init__()

        # Encoder and decoder share every hyper-parameter except the vocabulary
        self.encoder = Encoder(
            src_vocab_size,
            d_model,
            num_layers=num_layers,
            num_heads=num_heads,
            d_ff=d_ff,
            dropout=dropout,
            max_seq_length=max_seq_length,
        )
        self.decoder = Decoder(
            tgt_vocab_size,
            d_model,
            num_layers=num_layers,
            num_heads=num_heads,
            d_ff=d_ff,
            dropout=dropout,
            max_seq_length=max_seq_length,
        )

        # Maps decoder features (d_model) to one score per target token
        self.final_layer = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt):
        """Run the full model on source and target token tensors.

        Args:
            src: source tokens, passed to create_masks and the encoder
            tgt: target tokens, passed to create_masks and the decoder
        Returns:
            Output of the final linear layer applied to the decoder output
        """
        source_mask, target_mask = create_masks(src, tgt)

        memory = self.encoder(src, source_mask)
        decoded = self.decoder(tgt, memory, source_mask, target_mask)

        return self.final_layer(decoded)

## Transformer Training Utilities

In [11]:
class TransformerLRScheduler:
    """Warmup-then-decay learning-rate schedule.

    The rate is
    ``d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5))``:
    it grows linearly for ``warmup_steps`` steps and then decays with the
    inverse square root of the step number.
    """
    def __init__(self, optimizer, d_model, warmup_steps):
        """
        Args:
            optimizer: Optimizer to adjust learning rate for
            d_model: Model dimensionality
            warmup_steps: Number of warmup steps
        """
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps

    def step(self, step_num):
        """
        Update the optimizer's learning rate for the given step number.

        Args:
            step_num: Training step, counted from 1. Values below 1 are
                treated as 1 because the formula is undefined at 0.
        Returns:
            The learning rate written to every optimizer parameter group.
        """
        # 0 ** -0.5 raises ZeroDivisionError, so clamp both quantities that
        # appear under a negative power.
        step_num = max(step_num, 1)
        warmup_steps = max(self.warmup_steps, 1)

        # lrate = d_model^(-0.5) * min(step_num^(-0.5), step_num * warmup_steps^(-1.5))
        lrate = self.d_model ** (-0.5) * min(step_num ** (-0.5), step_num * warmup_steps ** (-1.5))

        # Apply the new rate; without this the schedule has no effect.
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lrate

        return lrate

class LabelSmoothing(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    The correct class receives probability ``1 - smoothing`` and the
    remaining ``smoothing`` mass is shared equally by the other classes.
    """
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing

    def forward(self, logits, target):
        """Return the mean smoothed cross-entropy over the batch.

        Args:
            logits: Model predictions (batch_size, vocab_size), one
                unnormalised score per label
            target: True labels (batch_size), each the index of the
                correct label
        """
        num_classes = logits.size(-1)
        off_target_prob = self.smoothing / (num_classes - 1)

        # The soft targets are constants; keep them out of the autograd graph.
        with torch.no_grad():
            soft_targets = torch.empty_like(logits).fill_(off_target_prob)
            correct_index = target.unsqueeze(1)
            soft_targets.scatter_(1, correct_index, self.confidence)

        log_probs = torch.log_softmax(logits, dim=-1)
        per_example_loss = torch.sum(-soft_targets * log_probs, dim=-1)
        return torch.mean(per_example_loss)