# Building a Transformer from Scratch with PyTorch

In this notebook, I'll implement a transformer architecture using PyTorch. Instead of using pre-defined components, I'll build each part from scratch, including:

- Custom tokenizer
- Embedding layer
- Positional encoding
- Self-attention mechanism
- Encoder and decoder layers
- Complete transformer model

This approach helps to understand the inner workings of transformer models that power modern NLP.

In [1]:
import torch
import torch.nn as nn
from collections import Counter
from typing import List, Dict, Union, Optional, Any
from torch import Tensor
import math
class PyTorchTokenizer(nn.Module):
    def __init__(self, max_vocab_size: int = 10000):
        super().__init__()
        
        # Special token indices
        self.PAD_IDX: int = 0
        self.UNK_IDX: int = 1
        self.SOS_IDX: int = 2
        self.EOS_IDX: int = 3
        
        # Tokenizer attributes
        self.max_vocab_size: int = max_vocab_size
        self.word_to_index: Dict[str, int] = {}
        self.index_to_word: Dict[int, str] = {}
        
        # Special tokens
        special_tokens: List[str] = ['<PAD>', '<UNK>', '< SOS >', '<EOS>']
        for idx, token in enumerate(special_tokens):
            self.word_to_index[token] = idx
            self.index_to_word[idx] = token
        
        # Track vocabulary size
        self.vocab_size: int = len(special_tokens)
    
    def fit_on_texts(self, texts: List[str]) -> None:
        """Build vocabulary from input texts"""
        # Tokenize and count word frequencies
        words: List[str] = [word for text in texts for word in text.split()]
        word_counts: Counter = Counter(words)
        
        # Sort words by frequency, descending order
        sorted_words: List[tuple] = sorted(word_counts.items(), key=lambda x: x[1], reverse=True)
        
        # Add most frequent words to vocabulary
        for word, _ in sorted_words:
            if word not in self.word_to_index and self.vocab_size < self.max_vocab_size:
                self.word_to_index[word] = self.vocab_size
                self.index_to_word[self.vocab_size] = word
                self.vocab_size += 1
    
    def texts_to_sequences(
        self, 
        texts: List[str], 
        add_sos_eos: bool = True, 
        max_length: Optional[int] = None
    ) -> torch.Tensor:
        """
        Convert texts to tensor or list of indices
        
        Args:
            texts: List of sentences
            add_sos_eos: Add start/end tokens
            padding: Pad sequences to same length
            max_length: Maximum sequence length
        
        Returns:
            Padded sequence tensor or list of sequences
        """
        # Convert to indices
        sequences: List[List[int]] = []
        for text in texts:
            tokens: List[str] = text.split()
            
            # Add special tokens if requested
            if add_sos_eos:
                tokens = ['< SOS >'] + tokens + ['<EOS>']
            
            # Convert to indices
            sequence: List[int] = [self.word_to_index.get(word, self.UNK_IDX) for word in tokens]
            sequences.append(sequence)
        
        # Determine max length
        if max_length is None:
            max_length = max(len(seq) for seq in sequences)
        
        # Pad sequences
        
        padded_sequences: List[List[int]] = []
        for seq in sequences:
            # Truncate or pad
            seq = seq[:max_length]
            seq = seq + [self.PAD_IDX] * (max_length - len(seq))
            padded_sequences.append(seq)
        
        return torch.tensor(padded_sequences, dtype=torch.long)

    
    def sequences_to_texts(self, sequences: torch.Tensor) -> List[str]:
        """Convert sequences back to texts"""
        texts: List[str] = []
        for sequence in sequences:
            # Convert indices to words; .item to get from torch
            words: List[str] = [self.index_to_word.get(idx.item(), '<UNK>') for idx in sequence]
            # Remove special tokens and padding
            words = [w for w in words if w not in ['<PAD>', '< SOS >', '<EOS>']]
            texts.append(' '.join(words))
        return texts

    # Sample texts
texts: List[str] = [
    "hello world",
    "machine learning is awesome",
    "pytorch is great for deep learning"
]

# Create tokenizer
tokenizer: PyTorchTokenizer = PyTorchTokenizer(max_vocab_size=20)

# Fit on texts
tokenizer.fit_on_texts(texts)

# Convert to sequences
sequences: torch.Tensor = tokenizer.texts_to_sequences(texts)

print("Sequences shape:", sequences.shape)
print("Vocab size:", tokenizer.vocab_size)

# Convert back to texts
reconstructed_texts: List[str] = tokenizer.sequences_to_texts(sequences)
print("\nReconstructed texts:")
for original, reconstructed in zip(texts, reconstructed_texts):
    print(f"Original:     {original}")
    print(f"Reconstructed: {reconstructed}\n")

Sequences shape: torch.Size([3, 8])
Vocab size: 14

Reconstructed texts:
Original:     hello world
Reconstructed: hello world

Original:     machine learning is awesome
Reconstructed: machine learning is awesome

Original:     pytorch is great for deep learning
Reconstructed: pytorch is great for deep learning



## Custom Embedding Layer

This class implements a custom embedding layer without relying on PyTorch's built-in `nn.Embedding`. 
The embedding layer maps token indices to dense vectors of fixed size. 

Key features:
- Support for different initialization methods (uniform, normal, xavier)
- Special handling for padding tokens
- Manual implementation of the lookup process

In [2]:
class PyTorchEmbedding(nn.Module):
    def __init__(
        self, 
        token_size: int, 
        d_model: int, 
        padding_idx: Optional[int] = None,
        init_method: str = 'uniform'
    ):
        """
        Custom embedding layer without using nn.Embedding
        
        Args:
            token_size: Number of tokens in vocabulary
            d_model: Embedding dimension
            padding_idx: Index to set to zero
            init_method: Weight initialization method
        """
        super().__init__()
        # Initialize weights based on method
        if init_method == 'uniform':
            self.weights = torch.rand(token_size, d_model) * 2 - 1  # [-1, 1]
        elif init_method == 'normal':
            self.weights = torch.randn(token_size, d_model)
        elif init_method == 'xavier':
            self.weights = torch.nn.init.xavier_uniform_(
                torch.empty(token_size, d_model)
            )
        else:
            raise ValueError(f"Unknown init method: {init_method}")
        
        # Zero out padding index if specified
        if padding_idx is not None:
            self.weights[padding_idx].zero_()
        
        self.token_size = token_size
        self.d_model = d_model
    
    def forward(self, token_sequences: Tensor) -> Tensor:
        """
        Lookup embeddings for given indices
        
        Args:
            token_sequences: Tensor of token indices (# sentence, token_size)
        
        Returns:
            Tensor of embedded tokens
        """
        # Create output tensor
        output = torch.zeros(
            token_sequences.shape[0],  # batch size
            token_sequences.shape[1],  # sequence length 
            self.d_model,      # embedding dimension
            dtype=self.weights.dtype
        )
        print(f'Output shape: {output.shape}')
        
        # Manually lookup embeddings
        for i, sentence in enumerate(token_sequences):
            for j, token_idx in enumerate(sentence):
                output[i, j] = self.weights[token_idx]
        
        return output
    def __call__(self, indices: Tensor) -> Tensor:
        """
        Make the class callable for convenience
        """
        return self.forward(indices)
embedding = PyTorchEmbedding(token_size=20,
                             d_model=6)
outputs = embedding.forward(sequences)
print((f'There are {outputs.shape[0]} sentences/batch size;'),
    (f"{outputs.shape[1]} unique tokens;"),
    f"{outputs.shape[2]} d_model")

Output shape: torch.Size([3, 8, 6])
There are 3 sentences/batch size; 8 unique tokens; 6 d_model


# Positional Encoding

Positional encoding is a crucial component of transformer architectures. Since transformers process tokens in parallel without any inherent notion of sequence order, we need to explicitly add position information.

This implementation uses sinusoidal positional encoding as described in the "Attention is All You Need" paper:
- Uses sine and cosine functions of different frequencies
- Position information is added to the token embeddings
- Each position is encoded as a unique vector that follows a specific pattern

In [3]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_seq_length: int = 5000):
        super().__init__()
        
        # Create positional encoding matrix
        position = torch.arange(max_seq_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        
        # Initialize the positional encoding buffer
        pe = torch.zeros(1, max_seq_length, d_model)
        
        # Fill with sinusoidal pattern
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        
        # Register as buffer (not a parameter, but part of the model state)
        self.register_buffer('pe', pe)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add positional encoding to input embeddings
        
        Args:
            x: Tensor of shape [batch_size, seq_length, d_model]
            
        Returns:
            Tensor with positional encoding added
        """
        # Add positional encoding to input (only up to the sequence length)
        x = x + self.pe[:, :x.size(1), :]
        return x

## Self-Attention Mechanism

Self-attention is the core component of transformer models that allows them to model dependencies between tokens regardless of their distance in the sequence.

This implementation features:
- Multi-head attention that splits the representation into multiple parts
- Query, key, and value projections
- Scaled dot-product attention with softmax normalization
- Residual connection to preserve information flow

The self-attention mechanism allows each position to attend to all positions in the sequence, capturing complex dependencies.

In [None]:
class PyTorchSelfAttention(nn.Module):
    def __init__(self, num_head: int = 8, d_model: int = 64):
        super().__init__()
        assert (d_model % num_head == 0), 'd_model must be divisible by num_head'
        
        self.num_head = num_head
        self.d_model = d_model
        self.d_head = int(d_model / num_head)
        self.W_q = nn.Parameter(torch.randn(d_model, d_model))
        self.W_v = nn.Parameter(torch.randn(d_model, d_model))
        self.W_k = nn.Parameter(torch.randn(d_model, d_model))
        self.W_o = nn.Parameter(torch.randn(d_model, d_model))

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]
        seq_len = x.shape[1]
        return x.reshape(batch_size, seq_len, self.num_head, self.d_head).transpose(1, 2)
    
    def join_heads(self, x: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]
        seq_len = x.shape[2]
        return x.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass for self-attention with optional masking.
        
        Args:
            x: Input tensor (batch_size, seq_len, d_model)
            mask: Mask tensor (batch_size, 1, seq_len, seq_len) or (1, 1, seq_len, seq_len)
                  1s indicate positions to mask
            
        Returns:
            Output tensor (batch_size, seq_len, d_model)
        """
        # Input x: (batch_size, seq_len, d_model)
        
        # Linear projections
        q = torch.matmul(x, self.W_q)  # (batch_size, seq_len, d_model)
        v = torch.matmul(x, self.W_v)  # (batch_size, seq_len, d_model)
        k = torch.matmul(x, self.W_k)  # (batch_size, seq_len, d_model)
        
        # Split heads
        q = self.split_heads(q)  # Shape: (batch_size, num_heads, seq_len, d_head)
        k = self.split_heads(k)  # Shape: (batch_size, num_heads, seq_len, d_head)
        v = self.split_heads(v)  # Shape: (batch_size, num_heads, seq_len, d_head)
        
        # Compute attention scores
        # q.shape:   (batch_size, num_heads, seq_len, d_head)
        # k.T.shape: (batch_size, num_heads, d_head, seq_len)
        # Result:    (batch_size, num_heads, seq_len, seq_len)
        attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)
        
        # Apply mask if provided
        if mask is not None:
            # Add a large negative value to masked positions to make their softmax score ~0
            attention_scores = attention_scores.masked_fill(mask.bool(), -1e9)
        
        # Softmax to get attention weights
        # Maintains same shape: (batch_size, num_heads, seq_len, seq_len)
        attention_weights = torch.softmax(attention_scores, dim=-1)
        
        # Compute weighted sum of values
        # attention_weights: (batch_size, num_heads, seq_len, seq_len)
        # v:                 (batch_size, num_heads, seq_len, d_head)
        # Result:            (batch_size, num_heads, seq_len, d_head)
        head_output = torch.matmul(attention_weights, v)
        
        # Combine heads back to original dimension
        output = self.join_heads(head_output)
        
        # Linear projection
        output = torch.matmul(output, self.W_o)
        
        # Residual connection
        output += x
        return output

In [5]:
# Testing the self-attention implementation with different configurations
test_configs = [
        (2, 10, 64),   # batch_size=2, seq_len=10, d_model=64
        (4, 20, 128),  # batch_size=4, seq_len=20, d_model=128
    ]

for batch_size, seq_len, d_model in test_configs:
    # Create multi-head attention with default 8 heads
    sa = PyTorchSelfAttention(num_head=8, d_model=d_model)
    
    # Create zero tensor input
    x = torch.zeros(batch_size, seq_len, d_model)
    
    # Run forward pass
    output = sa(x)
    
    print(f"\nConfig: batch_size={batch_size}, seq_len={seq_len}, d_model={d_model}")
    print("Input shape:", x.shape)
    
    # Verify intermediate shapes
    q = torch.matmul(x, sa.W_q)
    q_split = sa.split_heads(q)
    print("Q after split shape:", q_split.shape)
    
    # Verifying output shape matches input shape
    print("Output shape:", output.shape)
    assert output.shape == x.shape, "Output shape must match input shape"


Config: batch_size=2, seq_len=10, d_model=64
Input shape: torch.Size([2, 10, 64])
Q after split shape: torch.Size([2, 8, 10, 8])
Output shape: torch.Size([2, 10, 64])

Config: batch_size=4, seq_len=20, d_model=128
Input shape: torch.Size([4, 20, 128])
Q after split shape: torch.Size([4, 8, 20, 16])
Output shape: torch.Size([4, 20, 128])


# Encoder Block Implementation

The encoder block combines several components of the transformer architecture:
- Positional encoding to add position information
- Multi-head self-attention to capture contextual relationships
- Feed-forward neural network for further processing
- Layer normalization and residual connections for stable training
- Dropout for regularization

This implementation includes two self-attention layers to increase the model's capacity to capture complex dependencies.

In [6]:
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model: int, num_head: int = 8, 
    d_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        self.self_attn = PyTorchSelfAttention(num_head=num_head, d_model=d_model)
        self.self_attn2 = PyTorchSelfAttention(num_head=num_head, d_model=d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        
        # Feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        
        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        
        # Dropout for regularization
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Add positional encoding
        x = self.pos_encoder(x)
        
        # Self-attention block with residual connection and layer norm
        attn_output = self.self_attn(x)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Feed-forward block with residual connection and layer norm
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        
        attn_output = self.self_attn2(x)
        x = self.norm3(x + self.dropout(attn_output))
        return x

In [7]:
# Testing the encoder layer with different configurations
test_configs = [
        (2, 10, 64),   # batch_size=2, seq_len=10, d_model=64
        (4, 20, 128),  # batch_size=4, seq_len=20, d_model=128
    ]

for batch_size, seq_len, d_model in test_configs:
    # Create multi-head attention with default 8 heads
    encoder = TransformerEncoderLayer(d_model=d_model,num_head=8)
    
    # Create zero tensor input
    x = torch.zeros(batch_size, seq_len, d_model)
    
    # Run forward pass
    output = encoder(x)
    
    print(f"\nConfig: batch_size={batch_size}, seq_len={seq_len}, d_model={d_model}")
    print("Input shape:", x.shape)
    
    # Verifying output shape matches input shape
    print("Output shape:", output.shape)
    assert output.shape == x.shape, "Output shape must match input shape"


Config: batch_size=2, seq_len=10, d_model=64
Input shape: torch.Size([2, 10, 64])
Output shape: torch.Size([2, 10, 64])

Config: batch_size=4, seq_len=20, d_model=128
Input shape: torch.Size([4, 20, 128])
Output shape: torch.Size([4, 20, 128])


# Transformer Decoder Layer

The decoder layer is a crucial component in sequence-to-sequence transformer models. It processes the target sequence while attending to the encoder's output.

Key components of the decoder layer:
1. Masked self-attention - Allows each position to attend only to previous positions (causal masking)
2. Cross-attention - Enables the decoder to attend to all positions in the encoder output
3. Feed-forward network - Further processes the attended information
4. Layer normalization and residual connections - Stabilize training

The decoder produces outputs that can be used to predict the next token in an autoregressive fashion.

In [8]:
class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model: int, num_head: int = 8, d_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        
        # Self-attention layer (with masking)
        self.self_attn = PyTorchSelfAttention(num_head=num_head, d_model=d_model)
        
        # Cross-attention layer (to attend to encoder outputs)
        self.cross_attn = PyTorchSelfAttention(num_head=num_head, d_model=d_model)
        
        # Feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        
        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        
        # Dropout for regularization
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor, 
                src_mask: Optional[torch.Tensor] = None, 
                tgt_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: Target sequence (batch_size, tgt_seq_len, d_model)
            encoder_output: Output from encoder (batch_size, src_seq_len, d_model)
            src_mask: Mask for encoder outputs (optional)
            tgt_mask: Mask for decoder inputs (optional)
        """
        # First sub-layer: Self-attention with causal mask
        # Need to modify self_attn to handle the mask
        attn_output = self.apply_attention(self.self_attn, x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Second sub-layer: Cross-attention to encoder outputs
        cross_attn_output = self.apply_attention(self.cross_attn, x, encoder_output, encoder_output, src_mask)
        x = self.norm2(x + self.dropout(cross_attn_output))
        
        # Third sub-layer: Feed-forward network
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        
        return x
    
    def apply_attention(self, attention_layer: PyTorchSelfAttention, 
                        query: torch.Tensor, 
                        key: torch.Tensor, 
                        value: torch.Tensor, 
                        mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Helper method to apply attention with masking"""
        # This needs modifications to the original self-attention to handle masks
        # For now, this is a placeholder assuming we'll extend PyTorchSelfAttention
        return attention_layer(query)  # Would need to be modified to handle key, value, mask

# Complete Encoder Stack

This class implements a full transformer encoder, which consists of:
1. Positional encoding at the input
2. A stack of N encoder layers
3. A final layer normalization

The encoder processes the source sequence and produces contextual representations that capture the relationships between tokens. These representations are then used by the decoder for generating the target sequence.

Note: In a full transformer, the query and value come from the decoder while the key comes from the encoder in the cross-attention mechanism.

In [9]:
class TransformerEncoder(nn.Module):
    def __init__(self, d_model: int, num_head: int, d_ff: int, 
                 num_layers: int, max_seq_length: int, dropout: float = 0.1):
        super().__init__()
        
        # Positional encoding
        self.pos_encoder = PositionalEncoding(d_model, max_seq_length)
        
        # Stack of encoder layers
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_head, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        # Final normalization layer
        self.norm = nn.LayerNorm(d_model)
        
    def forward(self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            src: Source sequence (batch_size, src_seq_len, d_model)
            src_mask: Mask to avoid attending to padding tokens
        """
        # Add positional encoding
        x = self.pos_encoder(src)
        
        # Apply each encoder layer in sequence
        for layer in self.layers:
            x = layer(x)  # We'd need to modify this to pass the mask
        
        # Apply final normalization
        return self.norm(x)

# Complete Decoder Stack

The complete decoder stack consists of:
1. Positional encoding for the target sequence
2. Multiple decoder layers that process the target and attend to the encoder output
3. A final layer normalization

The decoder operates in an autoregressive manner during inference, predicting one token at a time. During training, teacher forcing is typically used where the entire target sequence is provided, but with appropriate masking to prevent attending to future tokens.

The decoder combines information from the target sequence processed so far with the context provided by the encoder to generate predictions for the next tokens.

In [10]:
class TransformerDecoder(nn.Module):
    def __init__(self, d_model: int, num_head: int, d_ff: int, 
                 num_layers: int, max_seq_length: int, dropout: float = 0.1):
        super().__init__()
        
        # Positional encoding
        self.pos_encoder = PositionalEncoding(d_model, max_seq_length)
        
        # Stack of decoder layers
        self.layers = nn.ModuleList([
            TransformerDecoderLayer(d_model, num_head, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        # Final normalization layer
        self.norm = nn.LayerNorm(d_model)
        
    def forward(self, tgt: torch.Tensor, encoder_output: torch.Tensor,
                src_mask: Optional[torch.Tensor] = None,
                tgt_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            tgt: Target sequence (batch_size, tgt_seq_len, d_model)
            encoder_output: Output from encoder (batch_size, src_seq_len, d_model)
            src_mask: Mask for encoder outputs
            tgt_mask: Mask for decoder inputs
        """
        # Add positional encoding
        x = self.pos_encoder(tgt)
        
        # Apply each decoder layer in sequence
        for layer in self.layers:
            # forward
            x = layer(x, encoder_output, src_mask, tgt_mask)
        
        # Apply final normalization
        return self.norm(x)

# Complete Transformer Model

This class implements the full transformer architecture, combining:
1. Input embeddings for source and target tokens
2. The encoder stack to process the source sequence
3. The decoder stack to generate the target sequence
4. A final output projection to vocabulary size

The transformer model is designed for sequence-to-sequence tasks like machine translation, text summarization, and question answering. It captures long-range dependencies effectively thanks to the self-attention mechanism.

This implementation follows the architecture described in the "Attention is All You Need" paper, with customizable parameters for model size and configuration.

In [11]:
class Transformer(nn.Module):
    def __init__(self, 
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 d_model: int = 512, 
                 num_head: int = 8,
                 d_ff: int = 2048,
                 num_layers: int = 6,
                 max_seq_length: int = 5000,
                 dropout: float = 0.1):
        super().__init__()
        
        # Embedding layers
        self.src_embed = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab_size, d_model)
        
        # Encoder and decoder
        self.encoder = TransformerEncoder(d_model, num_head, d_ff, num_layers, max_seq_length, dropout)
        self.decoder = TransformerDecoder(d_model, num_head, d_ff, num_layers, max_seq_length, dropout)
        
        # Final output layer
        self.output_layer = nn.Linear(d_model, tgt_vocab_size)
        
        # Initialize parameters
        self._init_parameters()
        
    def _init_parameters(self):
        """Initialize parameters with Xavier uniform"""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    
    def forward(self, src: torch.Tensor, tgt: torch.Tensor,
                src_mask: Optional[torch.Tensor] = None,
                tgt_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass through the entire transformer
        
        Args:
            src: Source sequence (batch_size, src_seq_len)
            tgt: Target sequence (batch_size, tgt_seq_len)
            src_mask: Mask for source sequence
            tgt_mask: Mask for target sequence
            
        Returns:
            Output logits (batch_size, tgt_seq_len, tgt_vocab_size)
        """
        # Embed source and target sequences
        src_embedded = self.src_embed(src)
        tgt_embedded = self.tgt_embed(tgt)
        
        # Encode the source sequence
        encoder_output = self.encoder(src_embedded, src_mask)
        
        # Decode with encoder output and target sequence
        decoder_output = self.decoder(tgt_embedded, encoder_output, src_mask, tgt_mask)
        
        # Project to vocabulary size
        output = self.output_layer(decoder_output)
        
        return output

# Masking Mechanism

Masking is essential in transformer models for two main purposes:

1. **Padding Mask**: To prevent the model from attending to padding tokens in variable-length sequences.
2. **Causal/Look-ahead Mask**: To ensure that predictions for a position can only depend on known outputs at previous positions (used in decoder).

This implementation provides utility functions to create both types of masks and combine them when needed. Proper masking is crucial for both model performance and ensuring the autoregressive property during generation.

In [12]:
def create_padding_mask(seq: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
    """
    Create mask for padding tokens
    
    Args:
        seq: Input sequence (batch_size, seq_len)
        pad_idx: Index representing padding token
        
    Returns:
        Mask tensor (batch_size, 1, 1, seq_len)
        1s indicate positions to mask (padding)
    """
    # Create mask for padding: 1 for padding, 0 for non-padding
    # Shape: (batch_size, seq_len)
    mask = (seq == pad_idx).float()
    
    # Add dimensions for attention heads and query position
    # Shape: (batch_size, 1, 1, seq_len)
    return mask.unsqueeze(1).unsqueeze(2)

def create_causal_mask(seq_len: int) -> torch.Tensor:
    """
    Create causal mask to prevent attending to future positions
    
    Args:
        seq_len: Length of sequence
        
    Returns:
        Mask tensor (1, 1, seq_len, seq_len)
        1s indicate positions to mask (future positions)
    """
    # Create a triangular matrix where 1s represent future positions to mask
    # Shape: (seq_len, seq_len)
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).float()
    
    # Add batch and head dimensions
    # Shape: (1, 1, seq_len, seq_len)
    return mask.unsqueeze(0).unsqueeze(0)

# Example usage:
def get_masks(src: torch.Tensor, tgt: torch.Tensor, pad_idx: int = 0) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Get all masks needed for transformer training
    
    Args:
        src: Source sequence (batch_size, src_seq_len)
        tgt: Target sequence (batch_size, tgt_seq_len)
        pad_idx: Index representing padding token
        
    Returns:
        src_mask: Mask for source padding
        tgt_mask: Combined mask for target (padding + causal)
    """
    # Source padding mask
    src_mask = create_padding_mask(src, pad_idx)
    
    # Target padding mask
    tgt_padding_mask = create_padding_mask(tgt, pad_idx)
    
    # Target causal mask
    tgt_seq_len = tgt.size(1)
    tgt_causal_mask = create_causal_mask(tgt_seq_len)
    
    # Combine padding and causal masks for target
    # We use broadcasting to combine the two masks
    tgt_mask = torch.max(tgt_padding_mask, tgt_causal_mask)
    
    return src_mask, tgt_mask

# Sample Usage: Training the Transformer Model

Here's sample code to train the transformer with 3 epochs. We'll use a small dataset for simplicity and quick training.

In [13]:
# Sample code to train transformer with 3 epochs
import torch
import torch.nn as nn
import torch.optim as optim

# Sample parallel sentences (English -> French, very simplified example)
src_sentences = [
    "hello world",
    "how are you",
    "I am learning transformers",
    "this is interesting",
    "machine learning is fun",
    "deep learning models",
    "natural language processing",
    "attention is all you need"
]

tgt_sentences = [
    "bonjour le monde",
    "comment allez vous",
    "j'apprends les transformers",
    "c'est intéressant",
    "l'apprentissage automatique est amusant",
    "modèles d'apprentissage profond",
    "traitement du langage naturel",
    "l'attention est tout ce dont vous avez besoin"
]

# Hyperparameters
d_model = 64  # Small for demonstration
num_heads = 4
d_ff = 128
num_layers = 2
max_seq_length = 20
batch_size = 4
learning_rate = 0.001
num_epochs = 3

# Tokenize sentences
tokenizer_src = PyTorchTokenizer(max_vocab_size=100)
tokenizer_tgt = PyTorchTokenizer(max_vocab_size=100)

# Build vocabulary
tokenizer_src.fit_on_texts(src_sentences)
tokenizer_tgt.fit_on_texts(tgt_sentences)

# Convert to sequences
src_seq = tokenizer_src.texts_to_sequences(src_sentences, max_length=max_seq_length)
tgt_seq = tokenizer_tgt.texts_to_sequences(tgt_sentences, max_length=max_seq_length)

# For training, target input is the target sequence shifted right (removing the last token)
# Target output is the target sequence shifted left (removing the first token, which is SOS)
tgt_input = tgt_seq[:, :-1]
tgt_output = tgt_seq[:, 1:]

# Create masks
src_mask, tgt_mask = get_masks(src_seq, tgt_input, pad_idx=tokenizer_src.PAD_IDX)

# Create the transformer model
model = Transformer(
    src_vocab_size=tokenizer_src.vocab_size,
    tgt_vocab_size=tokenizer_tgt.vocab_size,
    d_model=d_model,
    num_head=num_heads,
    d_ff=d_ff,
    num_layers=num_layers,
    max_seq_length=max_seq_length
)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.PAD_IDX)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
print("Starting training...")
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    
    # Create mini-batches
    num_samples = len(src_sentences)
    indices = torch.randperm(num_samples)
    
    for i in range(0, num_samples, batch_size):
        batch_indices = indices[i:i+batch_size]
        
        # Get batch data
        src_batch = src_seq[batch_indices]
        tgt_input_batch = tgt_input[batch_indices]
        tgt_output_batch = tgt_output[batch_indices]
        
        # Get batch masks
        src_mask_batch = src_mask[batch_indices] if src_mask is not None else None
        tgt_mask_batch = tgt_mask[batch_indices] if tgt_mask is not None else None
        
        # Forward pass
        outputs = model(src_batch, tgt_input_batch, src_mask_batch, tgt_mask_batch)
        
        # Reshape outputs and target for loss calculation
        outputs = outputs.reshape(-1, outputs.shape[-1])
        tgt_output_batch = tgt_output_batch.reshape(-1)
        
        # Calculate loss
        loss = criterion(outputs, tgt_output_batch)
        
        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    # Print epoch statistics
    avg_loss = total_loss / (len(src_sentences) // batch_size + 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

print("Training completed!")

# Test translation
def translate(model, src_text, tokenizer_src, tokenizer_tgt, max_length=50):
    model.eval()
    
    # Tokenize source text
    src = tokenizer_src.texts_to_sequences([src_text], max_length=max_seq_length)
    src_mask = create_padding_mask(src, tokenizer_src.PAD_IDX)
    
    # Start with just the SOS token
    tgt = torch.tensor([[tokenizer_tgt.SOS_IDX]])
    
    # Generate output token by token
    for i in range(max_length):
        # Create target mask
        tgt_mask = create_causal_mask(tgt.size(1))
        
        # Make prediction
        with torch.no_grad():
            output = model(src, tgt, src_mask, tgt_mask)
        
        # Get next token prediction
        next_token = output[:, -1, :].argmax(dim=-1).unsqueeze(1)
        
        # Add predicted token to target sequence
        tgt = torch.cat([tgt, next_token], dim=1)
        
        # Stop if EOS token is generated
        if next_token.item() == tokenizer_tgt.EOS_IDX:
            break
    
    # Convert ids to text
    return tokenizer_tgt.sequences_to_texts(tgt)[0]

# Test with a sample sentence
test_sentence = "deep learning is fascinating"
translation = translate(model, test_sentence, tokenizer_src, tokenizer_tgt)
print(f"Input: {test_sentence}")
print(f"Translation: {translation}")

Starting training...
Epoch 1/3, Loss: 2.5911
Epoch 2/3, Loss: 2.3690
Epoch 3/3, Loss: 2.1439
Training completed!
Input: deep learning is fascinating
Translation: j'apprends


# Running the Decoder Independently

Here's code to run just the decoder part of the transformer, which can be useful for text generation tasks or when you want to implement a decoder-only architecture like GPT.

In [14]:
# Code to run the decoder only
import torch
import torch.nn as nn
import torch.nn.functional as F

class DecoderOnlyTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=256, num_heads=8, 
                 d_ff=1024, num_layers=4, 
                 max_seq_length=512, dropout=0.1):
        super().__init__()
        
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, d_model)
        
        # Positional encoding
        self.pos_encoder = PositionalEncoding(d_model, max_seq_length)
        
        # Stack of decoder layers (but without cross-attention)
        self.decoder_layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        # Output projection
        self.output_layer = nn.Linear(d_model, vocab_size)
        
        # For masked attention
        self.max_seq_length = max_seq_length
        
    def forward(self, x, mask=None):
        # Generate mask if not provided
        if mask is None:
            mask = create_causal_mask(x.size(1))
            
        # Embed tokens
        x = self.embedding(x)
        
        # Add positional encoding
        x = self.pos_encoder(x)
        
        # Apply each decoder layer (using self-attention only)
        for layer in self.decoder_layers:
            x = layer(x)  # Using encoder layer as our decoder-only layer
        
        # Project to vocabulary
        output = self.output_layer(x)
        
        return output
    
    def generate(self, start_tokens, max_length=100, temperature=1.0):
        self.eval()
        
        # Initialize with start tokens
        current_tokens = start_tokens.clone()
        
        # Generate tokens auto-regressively
        for _ in range(max_length):
            # Forward pass with current sequence
            with torch.no_grad():
                logits = self(current_tokens)
            
            # Get next token prediction (last position)
            next_token_logits = logits[:, -1, :] / temperature
            next_token_probs = F.softmax(next_token_logits, dim=-1)
            
            # Sample from the distribution
            next_token = torch.multinomial(next_token_probs, num_samples=1)
            
            # Append to sequence
            current_tokens = torch.cat([current_tokens, next_token], dim=1)
            
            # Break if we reach a certain length
            if current_tokens.size(1) >= self.max_seq_length:
                break
        
        return current_tokens

# Example usage of the decoder-only model
def run_decoder_only_example():
    # Simple text corpus
    corpus = [
        "the quick brown fox jumps over the lazy dog",
        "all that glitters is not gold",
        "to be or not to be that is the question",
        "a picture is worth a thousand words",
        "actions speak louder than words",
        "Rome wasn't built in a day"
    ]
    
    # Create and fit tokenizer
    tokenizer = PyTorchTokenizer(max_vocab_size=100)
    tokenizer.fit_on_texts(corpus)
    
    # Convert texts to sequences
    sequences = tokenizer.texts_to_sequences(corpus, max_length=20)
    
    # Create decoder-only model
    model = DecoderOnlyTransformer(
        vocab_size=tokenizer.vocab_size,
        d_model=64,
        num_heads=4,
        d_ff=128,
        num_layers=2,
        max_seq_length=20
    )
    
    # Train for a few steps (very basic for demonstration)
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.PAD_IDX)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    # Quick training loop (just for demonstration)
    model.train()
    for epoch in range(5):  # Very few epochs for demonstration
        total_loss = 0
        
        for seq in sequences:
            # Create input and target (shifted by 1)
            x = seq[:-1].unsqueeze(0)  # Input: all tokens except last, add batch dimension
            y = seq[1:].unsqueeze(0)   # Target: all tokens except first, add batch dimension
            
            # Forward pass
            outputs = model(x)
            
            # Calculate loss
            loss = criterion(outputs.reshape(-1, outputs.shape[-1]), y.reshape(-1))
            
            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        print(f"Epoch {epoch+1}, Loss: {total_loss/len(sequences):.4f}")
    
    # Text generation with the trained model
    prompt = "the quick"
    prompt_tokens = tokenizer.texts_to_sequences([prompt], add_sos_eos=False)
    
    # Generate continuation
    generated = model.generate(prompt_tokens, max_length=15, temperature=1.2)
    generated_text = tokenizer.sequences_to_texts(generated)[0]
    
    print(f"\nPrompt: {prompt}")
    print(f"Generated: {generated_text}")

# Run the example
run_decoder_only_example()

Epoch 1, Loss: 3.7190
Epoch 2, Loss: 3.5732
Epoch 3, Loss: 3.5058
Epoch 4, Loss: 3.5054
Epoch 5, Loss: 3.4779

Prompt: the quick
Generated: the quick quick picture actions quick question jumps a that day quick built to speak
