In [1]:
import math
import torch
from torch import nn
from torch.nn import functional as F

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [2]:
# ============================================================================
# BLOCK 1: INPUT EMBEDDINGS
# ============================================================================

class InputEmbedding(nn.Module):
    """
    Token-embedding layer: looks up a learned d_model-dimensional vector for
    each token index and scales it by sqrt(d_model), as per the paper.
    This is the first step in processing an input sequence.
    """
    def __init__(self, d_model: int, vocab_size: int):
        """
        Args:
            d_model: Embedding width (e.g. 512)
            vocab_size: Number of distinct token indices (e.g. 10000)
        """
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        # Lookup table with one learned row per vocabulary entry
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        """
        Args:
            x: (batch_size, seq_len) integer token indices
        Returns:
            (batch_size, seq_len, d_model) embeddings multiplied by sqrt(d_model)
        """
        scale = math.sqrt(self.d_model)
        vectors = self.embedding(x)  # (batch_size, seq_len, d_model)
        return vectors * scale

In [3]:
# ============================================================================
# BLOCK 2: POSITIONAL ENCODING
# ============================================================================

class PositionalEncoding(nn.Module):
    """
    Adds fixed sinusoidal positional information to embeddings, since
    transformers have no inherent notion of token position (unlike RNNs).

        PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
    """
    def __init__(
        self,
        d_model: int,
        max_seq_len: int,
        dropout: float = 0.1
    ):
        """
        Args:
            d_model: Embedding dimension (e.g., 512); odd values are supported
            max_seq_len: Maximum sequence length supported (e.g., 5000)
            dropout: Dropout probability applied after adding the encoding
        """
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        # Must be a module because forward() calls self.dropout(x); storing the
        # raw float made forward() raise "'float' object is not callable".
        self.dropout = nn.Dropout(dropout)

        # Positional encoding matrix, shape: (max_seq_len, d_model)
        pe = torch.zeros(max_seq_len, d_model)

        # Position indices [0, 1, ..., max_seq_len-1] as a column
        # Shape: (max_seq_len, 1)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)

        # 1 / 10000^(2i / d_model), computed in log space
        # Shape: (ceil(d_model / 2),)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
        )

        # Sine on even feature indices (2i)
        # pe[:, 0::2] shape: (max_seq_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(position * div_term)

        # Cosine on odd feature indices (2i+1)
        # pe[:, 1::2] shape: (max_seq_len, d_model // 2); slicing div_term keeps
        # this valid when d_model is odd (one fewer odd column than even).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Add batch dimension: (1, max_seq_len, d_model)
        pe = pe.unsqueeze(0)

        # Buffer: saved with the module state and moved by .to(), but not trained
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        Add the encoding for the first seq_len positions, then apply dropout.

        Args:
            x: (batch_size, seq_len, d_model) embeddings; seq_len must not
                exceed max_seq_len
        Returns:
            (batch_size, seq_len, d_model)
        """
        # self.pe[:, :seq_len, :] has shape (1, seq_len, d_model) and broadcasts
        # over the batch; as a buffer it carries no gradient.
        x = x + self.pe[:, :x.shape[1], :]
        return self.dropout(x)

In [4]:
# ============================================================================
# BLOCK 3: LAYER NORMALIZATION
# ============================================================================

class LayerNormalization(nn.Module):
    """
    Normalizes inputs across the last (feature) dimension, then applies a
    learnable per-feature scale (alpha) and shift (bias).
    Helps with training stability and convergence.

    Note: this uses Tensor.std (Bessel-corrected by default) and adds eps to
    the std rather than to the variance.
    """
    def __init__(
        self,
        features: int,
        eps: float = 1e-6
    ):
        """
        Args:
            features: number of features (d_model)
            eps: small constant for numerical stability
        """
        super().__init__()
        self.eps = eps
        # Learnable affine parameters, initialised to the identity transform
        self.alpha = nn.Parameter(torch.ones(features))  # Multiplicative (gamma)
        # Additive shift (beta) must start at zero; initialising it to ones
        # offset every normalized activation by +1.
        self.bias = nn.Parameter(torch.zeros(features))

    def forward(self, x):
        """
        Args:
            x: (batch_size, seq_len, d_model)
        Returns:
            (batch_size, seq_len, d_model): alpha * (x - mean) / (std + eps) + bias,
            with mean and std taken over the last dimension
        """
        # mean / std shape: (batch_size, seq_len, 1), kept for broadcasting
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x - mean) / (std + self.eps) + self.bias

In [5]:
# ============================================================================
# BLOCK 4: FEED FORWARD NETWORK
# ============================================================================

class FeedForwardBlock(nn.Module):
    """
    Position-wise feed-forward network, applied identically at every position:

        FFN(x) = Linear2(Dropout(ReLU(Linear1(x))))

    i.e. max(0, xW1 + b1)W2 + b2 with dropout on the hidden activations.
    """
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        """
        Args:
            d_model: Input/output width (e.g. 512)
            d_ff: Hidden width (e.g. 2048, typically 4x d_model)
            dropout: Dropout probability on the hidden activations
        """
        super().__init__()
        # Registration order (linear1, dropout, linear2) is preserved so that
        # seeded initialisation and state_dict keys are unchanged.
        self.linear1 = nn.Linear(d_model, d_ff)   # d_model -> d_ff
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)   # d_ff -> d_model

    def forward(self, x):
        """
        Args:
            x: (batch_size, seq_len, d_model)
        Returns:
            (batch_size, seq_len, d_model)
        """
        hidden = F.relu(self.linear1(x))          # (batch_size, seq_len, d_ff)
        return self.linear2(self.dropout(hidden))

In [6]:
# ============================================================================
# BLOCK 5: MULTI-HEAD ATTENTION (Core Component)
# ============================================================================

class MultiHeadAttentionBlock(nn.Module):
    """
    Multi-head attention mechanism.

    Projects queries, keys and values with separate linear layers, splits them
    into num_heads heads of width d_k = d_model // num_heads, runs scaled
    dot-product attention per head, concatenates the heads and applies an
    output projection. Allows the model to jointly attend to information from
    different representation subspaces at different positions.

    The attention weights of the most recent forward call are stored in
    self.attention_weights.
    """
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.1
    ):
        """
        Args:
            d_model: Model dimension (e.g., 512); must be divisible by num_heads
            num_heads: Number of attention heads (e.g., 8)
            dropout: Dropout probability applied to the attention weights
        """
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads

        # Ensure that d_model is divisible by num_heads
        assert self.d_model % self.num_heads == 0, "d_model must be divisible by num_heads"

        # Dimension of each head
        self.d_k = d_model // num_heads

        # Linear layers for Q, K, V projections
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

        # Output projection applied after concatenating the heads
        self.w_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask=None, dropout=None):
        """
        Scaled dot-product attention:
            Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

        Args:
            query: (batch_size, num_heads, q_len, d_k)
            key:   (batch_size, num_heads, k_len, d_k)
            value: (batch_size, num_heads, k_len, d_k)
            mask:  optional, broadcastable to (batch_size, num_heads, q_len, k_len);
                   positions where mask == 0 receive (effectively) zero weight
            dropout: optional module applied to the attention weights

        Returns:
            output:            (batch_size, num_heads, q_len, d_k)
            attention_weights: (batch_size, num_heads, q_len, k_len)
        """
        d_k = query.shape[-1]

        # (batch, heads, q_len, d_k) @ (batch, heads, d_k, k_len)
        # -> scores: (batch, heads, q_len, k_len)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            # Fill masked positions with the most negative finite value of the
            # scores' dtype so softmax assigns them ~0 weight. A hard-coded -1e9
            # overflows float16 (masked_fill then raises under half precision);
            # finfo.min fits every float dtype and, unlike -inf, keeps a
            # fully-masked row finite (uniform weights) instead of NaN.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # attention_weights shape: (batch, heads, q_len, k_len)
        attention_weights = F.softmax(scores, dim=-1)

        if dropout is not None:
            attention_weights = dropout(attention_weights)

        # (batch, heads, q_len, k_len) @ (batch, heads, k_len, d_k)
        # -> output: (batch, heads, q_len, d_k)
        output = torch.matmul(attention_weights, value)

        return output, attention_weights

    def forward(self, q, k, v, mask=None):
        """
        Args:
            q: (batch_size, q_len, d_model) queries
            k: (batch_size, k_len, d_model) keys
            v: (batch_size, k_len, d_model) values
            mask: optional, broadcastable to (batch_size, num_heads, q_len, k_len),
                e.g. (batch_size, 1, 1, k_len) or (batch_size, 1, q_len, k_len)

        Returns:
            (batch_size, q_len, d_model)
        """
        batch_size = q.shape[0]

        # Linear projections, each (batch_size, len, d_model)
        query = self.w_q(q)
        key = self.w_k(k)
        value = self.w_v(v)

        # Split into heads:
        # (batch_size, len, d_model) -> (batch_size, len, num_heads, d_k)
        # -> transpose -> (batch_size, num_heads, len, d_k)
        query = query.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # x: (batch_size, num_heads, q_len, d_k)
        # attention_weights: (batch_size, num_heads, q_len, k_len)
        x, self.attention_weights = self.attention(query, key, value, mask, self.dropout)

        # Concatenate heads:
        # (batch_size, num_heads, q_len, d_k) -> (batch_size, q_len, num_heads, d_k)
        # -> (batch_size, q_len, d_model)
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        # Final output projection: (batch_size, q_len, d_model)
        return self.w_o(x)

In [7]:
# ============================================================================
# BLOCK 6: RESIDUAL CONNECTION
# ============================================================================

class ResidualConnection(nn.Module):
    """
    Pre-norm residual connection around a sublayer.

    Implements: x + Dropout(Sublayer(LayerNorm(x)))

    The input is normalized *before* the sublayer, and the residual sum itself
    is not normalized. A stack of these therefore needs a final LayerNorm at
    the end, which the Encoder and Decoder in this notebook apply.
    Helps with gradient flow when training deep networks.
    """
    def __init__(
        self,
        features: int,
        dropout: float = 0.1
    ):
        """
        Args:
            features: Number of features (d_model) normalized before the sublayer
            dropout: Dropout probability applied to the sublayer output
        """
        super().__init__()
        self.norm = LayerNormalization(features=features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """
        Args:
            x: Input tensor, (batch_size, seq_len, d_model)
            sublayer: Callable mapping (batch_size, seq_len, d_model) to the same
                shape (e.g., an attention closure or a FeedForwardBlock)

        Returns:
            (batch_size, seq_len, d_model): x + Dropout(sublayer(LayerNorm(x)))
        """
        return x + self.dropout(sublayer(self.norm(x)))

In [8]:
# ============================================================================
# BLOCK 7: ENCODER BLOCK
# ============================================================================

class EncoderBlock(nn.Module):
    """
    One encoder layer. Sub-layers, each wrapped in its own ResidualConnection:
      1. multi-head self-attention (query = key = value = x)
      2. position-wise feed-forward network
    """
    def __init__(
        self,
        self_attention_block: MultiHeadAttentionBlock,
        feed_forward_block: FeedForwardBlock,
        dropout: float = 0.1
    ):
        """
        Args:
            self_attention_block: Multi-head attention used for self-attention
            feed_forward_block: Position-wise feed-forward network
            dropout: Dropout probability for both residual connections
        """
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block

        attn_features = self_attention_block.d_model
        ffn_features = feed_forward_block.linear2.out_features
        # Index 0 wraps self-attention, index 1 wraps the feed-forward network
        self.residual_connections = nn.ModuleList([
            ResidualConnection(attn_features, dropout),
            ResidualConnection(ffn_features, dropout),
        ])

    def forward(self, x, src_mask):
        """
        Args:
            x: (batch_size, seq_len, d_model) input
            src_mask: Padding mask forwarded to self-attention
        Returns:
            (batch_size, seq_len, d_model)
        """
        attn_residual, ffn_residual = self.residual_connections

        def self_attend(h):
            return self.self_attention_block(h, h, h, src_mask)

        x = attn_residual(x, self_attend)
        return ffn_residual(x, self.feed_forward_block)

In [9]:
# ============================================================================
# BLOCK 8: ENCODER (Stack of Encoder Blocks)
# ============================================================================

class Encoder(nn.Module):
    """
    Complete encoder: N stacked encoder blocks followed by a final layer
    normalization.
    """
    def __init__(self, layers: nn.ModuleList):
        """
        Args:
            layers: nn.ModuleList of EncoderBlock modules (non-empty; the first
                block's self-attention d_model sizes the final norm)
        """
        super().__init__()
        self.layers = layers
        first_block = layers[0]
        self.norm = LayerNormalization(first_block.self_attention_block.d_model)

    def forward(self, x, mask):
        """
        Run x through every block in order, then normalize.

        Args:
            x: (batch_size, seq_len, d_model)
            mask: Source padding mask passed to every block
        Returns:
            (batch_size, seq_len, d_model)
        """
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)

In [10]:
# ============================================================================
# BLOCK 9: DECODER BLOCK
# ============================================================================

class DecoderBlock(nn.Module):
    """
    One decoder layer. Sub-layers, each wrapped in its own ResidualConnection:
      1. masked multi-head self-attention over the decoder input (tgt_mask)
      2. multi-head cross-attention: queries from the decoder, keys and values
         from the encoder output (src_mask)
      3. position-wise feed-forward network
    """
    def __init__(
        self,
        self_attention_block: MultiHeadAttentionBlock,
        cross_attention_block: MultiHeadAttentionBlock,
        feed_forward_block: FeedForwardBlock,
        dropout: float = 0.1
    ):
        """
        Args:
            self_attention_block: Attention used for masked self-attention
            cross_attention_block: Attention used for encoder-decoder attention
            feed_forward_block: Position-wise feed-forward network
            dropout: Dropout probability for all three residual connections
        """
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block

        widths = (
            self_attention_block.d_model,
            cross_attention_block.d_model,
            feed_forward_block.linear2.out_features,
        )
        # Index 0: self-attention, 1: cross-attention, 2: feed-forward
        self.residual_connections = nn.ModuleList(
            [ResidualConnection(width, dropout) for width in widths]
        )

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """
        Args:
            x: (batch_size, tgt_seq_len, d_model) decoder input
            encoder_output: (batch_size, src_seq_len, d_model)
            src_mask: Source padding mask, e.g. (batch_size, 1, 1, src_seq_len)
            tgt_mask: Target mask, e.g. (batch_size, 1, tgt_seq_len, tgt_seq_len)
        Returns:
            (batch_size, tgt_seq_len, d_model)
        """
        self_res, cross_res, ffn_res = self.residual_connections

        def masked_self_attention(h):
            return self.self_attention_block(h, h, h, tgt_mask)

        def cross_attention(h):
            return self.cross_attention_block(h, encoder_output, encoder_output, src_mask)

        x = self_res(x, masked_self_attention)
        x = cross_res(x, cross_attention)
        return ffn_res(x, self.feed_forward_block)

In [11]:
# ============================================================================
# BLOCK 10: DECODER (Stack of Decoder Blocks)
# ============================================================================

class Decoder(nn.Module):
    """
    Complete decoder: N stacked decoder blocks followed by a final layer
    normalization.
    """
    # NOTE: this must be `__init__`; it was misspelled `__int__`, so the
    # constructor never ran and Decoder(layers) could not be built.
    def __init__(self, layers: nn.ModuleList):
        """
        Args:
            layers: nn.ModuleList of DecoderBlock modules (non-empty; the first
                block's self-attention d_model sizes the final norm)
        """
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(layers[0].self_attention_block.d_model)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """
        Args:
            x: (batch_size, tgt_seq_len, d_model)
            encoder_output: (batch_size, src_seq_len, d_model)
            src_mask: (batch_size, 1, 1, src_seq_len)
            tgt_mask: (batch_size, 1, tgt_seq_len, tgt_seq_len)

        Returns:
            (batch_size, tgt_seq_len, d_model)
        """
        # Pass through each decoder block; shape stays (batch_size, tgt_seq_len, d_model)
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)

        # Final layer normalization
        return self.norm(x)

In [12]:
# ============================================================================
# BLOCK 11: PROJECTION LAYER (Output Layer)
# ============================================================================

class ProjectionLayer(nn.Module):
    """
    Output layer: maps decoder features (d_model) to log-probabilities over
    the target vocabulary.
    """
    def __init__(self, d_model: int, vocab_size: int):
        """
        Args:
            d_model: Model dimension (e.g. 512)
            vocab_size: Size of the output vocabulary
        """
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """
        Args:
            x: (batch_size, seq_len, d_model)
        Returns:
            (batch_size, seq_len, vocab_size) log-probabilities; log_softmax is
            used instead of softmax followed by log for numerical stability
        """
        logits = self.proj(x)
        return F.log_softmax(logits, dim=-1)

In [13]:
# ============================================================================
# BLOCK 12: COMPLETE TRANSFORMER MODEL
# ============================================================================

class Transformer(nn.Module):
    """
    Complete encoder-decoder Transformer built from pre-constructed parts.
    Use `encode` for the source, `decode` for the target given the encoder
    output, and `project` to get vocabulary log-probabilities.
    """
    def __init__(
        self,
        encoder: Encoder,
        decoder: Decoder,
        src_embed: InputEmbedding,
        tgt_embed: InputEmbedding,
        src_pos: PositionalEncoding,
        tgt_pos: PositionalEncoding,
        projection_layer: ProjectionLayer
    ):
        super().__init__()
        # The assignment order fixes the order of .parameters(), which
        # build_transformer iterates to apply Xavier init; keep it unchanged.
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        """
        Embed, position-encode and encode the source tokens.

        Args:
            src: (batch_size, src_seq_len) source token indices
            src_mask: (batch_size, 1, 1, src_seq_len)
        Returns:
            (batch_size, src_seq_len, d_model)
        """
        embedded = self.src_pos(self.src_embed(src))
        return self.encoder(embedded, src_mask)

    def decode(self, encoder_output, src_mask, tgt, tgt_mask):
        """
        Embed, position-encode and decode the target tokens against the
        encoder output.

        Args:
            encoder_output: (batch_size, src_seq_len, d_model)
            src_mask: (batch_size, 1, 1, src_seq_len)
            tgt: (batch_size, tgt_seq_len) target token indices
            tgt_mask: (batch_size, 1, tgt_seq_len, tgt_seq_len)
        Returns:
            (batch_size, tgt_seq_len, d_model)
        """
        embedded = self.tgt_pos(self.tgt_embed(tgt))
        return self.decoder(embedded, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        """
        Args:
            x: (batch_size, tgt_seq_len, d_model) decoder output
        Returns:
            (batch_size, tgt_seq_len, vocab_size) log-probabilities
        """
        return self.projection_layer(x)

In [20]:
# ============================================================================
# BLOCK 13: MODEL BUILDER FUNCTION
# ============================================================================

def build_transformer(
    src_vocab_size: int,
    tgt_vocab_size: int,
    src_seq_len: int,
    tgt_seq_len: int,
    d_model: int = 512,
    num_encoder_blocks: int = 6,
    num_decoder_blocks: int = 6,
    num_heads: int = 8,
    dropout: float = 0.1,
    d_ff: int = 2048
) -> Transformer:
    """
    Assemble an encoder-decoder Transformer and Xavier-initialise its weights.

    Args:
        src_vocab_size: Size of the source vocabulary
        tgt_vocab_size: Size of the target vocabulary (also the output width)
        src_seq_len: Longest source sequence the positional encoding supports
        tgt_seq_len: Longest target sequence the positional encoding supports
        d_model: Model dimension (default: 512)
        num_encoder_blocks: Number of stacked encoder blocks (default: 6)
        num_decoder_blocks: Number of stacked decoder blocks (default: 6)
        num_heads: Attention heads per attention block (default: 8)
        dropout: Dropout probability used throughout (default: 0.1)
        d_ff: Hidden width of each feed-forward block (default: 2048)

    Returns:
        The assembled Transformer. Every parameter with more than one dimension
        is re-initialised with Xavier-uniform; 1-D parameters (biases, norm
        scale/shift) keep their constructor initial values.
    """
    # Small factories for the repeated components. Sub-modules are still
    # created in the original order (embeddings, positional encodings, encoder
    # blocks, decoder blocks, projection), since layers like nn.Linear draw
    # their initial values from torch's random generator when constructed.
    def new_attention():
        return MultiHeadAttentionBlock(d_model=d_model, num_heads=num_heads, dropout=dropout)

    def new_feed_forward():
        return FeedForwardBlock(d_model=d_model, d_ff=d_ff, dropout=dropout)

    # Token embeddings and sinusoidal positional encodings
    src_embed = InputEmbedding(d_model=d_model, vocab_size=src_vocab_size)
    tgt_embed = InputEmbedding(d_model=d_model, vocab_size=tgt_vocab_size)
    src_pos = PositionalEncoding(d_model=d_model, max_seq_len=src_seq_len, dropout=dropout)
    tgt_pos = PositionalEncoding(d_model=d_model, max_seq_len=tgt_seq_len, dropout=dropout)

    # Keyword arguments are evaluated left to right, so each block builds its
    # attention block(s) before its feed-forward block.
    encoder_blocks = [
        EncoderBlock(
            self_attention_block=new_attention(),
            feed_forward_block=new_feed_forward(),
            dropout=dropout
        )
        for _ in range(num_encoder_blocks)
    ]
    decoder_blocks = [
        DecoderBlock(
            self_attention_block=new_attention(),
            cross_attention_block=new_attention(),
            feed_forward_block=new_feed_forward(),
            dropout=dropout
        )
        for _ in range(num_decoder_blocks)
    ]

    transformer = Transformer(
        encoder=Encoder(nn.ModuleList(encoder_blocks)),
        decoder=Decoder(nn.ModuleList(decoder_blocks)),
        src_embed=src_embed,
        tgt_embed=tgt_embed,
        src_pos=src_pos,
        tgt_pos=tgt_pos,
        projection_layer=ProjectionLayer(d_model=d_model, vocab_size=tgt_vocab_size)
    )

    # Xavier-uniform for every weight matrix / embedding table
    matrices = (p for p in transformer.parameters() if p.dim() > 1)
    for matrix in matrices:
        nn.init.xavier_uniform_(matrix)

    return transformer

In [7]:
# ============================================================================
# BLOCK 14: MASK GENERATION FUNCTIONS
# ============================================================================

def create_padding_mask(seq, pad_idx=0):
    """
    Build a mask that hides padding tokens from attention.

    Args:
        seq: (batch_size, seq_len) token indices
        pad_idx: Padding token index (default: 0)

    Returns:
        Boolean tensor of shape (batch_size, 1, 1, seq_len): True for real
        tokens, False for padding. The two singleton axes let it broadcast over
        attention heads and query positions.
    """
    is_token = seq.ne(pad_idx)
    return is_token.unsqueeze(1).unsqueeze(1)

def create_look_ahead_mask(size):
    """
    Build the causal (look-ahead) mask that stops a decoder position from
    attending to later positions.

    Args:
        size: Sequence length

    Returns:
        Float tensor of shape (1, size, size): lower-triangular, with 1.0 where
        key position <= query position (allowed) and 0.0 above the diagonal
        (masked).
    """
    square = torch.ones(1, size, size)
    return square.tril()

def create_target_mask(tgt, pad_idx=0):
    """
    Create the combined decoder mask: padding mask AND look-ahead mask.

    Args:
        tgt: (batch_size, tgt_seq_len) target token indices
        pad_idx: Padding token index (default: 0)

    Returns:
        Boolean tensor of shape (batch_size, 1, tgt_seq_len, tgt_seq_len):
        True where a query position may attend to a key position (key is not
        padding and key position <= query position).
    """
    tgt_seq_len = tgt.shape[1]

    # Boolean padding mask: (batch_size, 1, 1, tgt_seq_len)
    tgt_padding_mask = create_padding_mask(seq=tgt, pad_idx=pad_idx)

    # create_look_ahead_mask returns a float tensor, and bitwise AND is only
    # defined for integer/bool tensors (bool & float raises), so cast to bool
    # while moving it to tgt's device. Shape: (1, tgt_seq_len, tgt_seq_len)
    tgt_look_ahead_mask = create_look_ahead_mask(size=tgt_seq_len).to(
        device=tgt.device, dtype=torch.bool
    )

    # Broadcast (batch, 1, 1, T) & (1, T, T) -> (batch, 1, T, T)
    tgt_mask = tgt_padding_mask & tgt_look_ahead_mask
    return tgt_mask

In [11]:
class Qwer:
    """Scratch check: a @staticmethod is reachable from both the class and an instance."""

    @staticmethod
    def sampel(query):
        """Ignore `query` and return a fixed marker string."""
        marker = "static method called"
        return marker

    def test(self, query):
        """Call the static method through the instance (no self is bound to it)."""
        result = self.sampel(query=query)
        return result

In [13]:
# Instance call: Qwer.test forwards to the static method via self.sampel,
# so this displays 'static method called'.
obj = Qwer()
obj.test("hello")

'static method called'

In [16]:
# Direct call on the class: a @staticmethod needs no instance.
Qwer.sampel("")

'static method called'