In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import math

In [2]:
# ============================================================================
# STEP 1: CONFIGURATION - Define all hyperparameters
# ============================================================================

class BertConfig:
    """
    Configuration classs to store BERT hyperparameters.
    These values are based on BERT-Base architecture.
    """
    def __init__(self):
        self.vocab_size = 30522                     # Size of vocabulary (WordPiece tokens)
        self.max_position_embeddings = 512          # Maximum sequence length
        self.hidden_size = 768                      # Dimension of embeddings (d_model)
        self.num_hidden_layers = 12                 # Number of transformer encoder layers
        self.num_attention_heads = 12               # Number of attention heads
        self.intermediate_size = 3072               # Dimension of feed-forward layer (4 * hidden_size)
        self.hidden_dropout_prob = 0.1              # Dropout probability
        self.attention_probs_dropout_prob = 0.1     # Attention dropout
        self.type_vocab_size = 2                    # Number of segment types (Sentence A/B)
        self.layer_norm_eps = 1e-12                 # Layer normalization epsilon
    
config = BertConfig()

In [3]:
# ============================================================================
# STEP 2: EMBEDDING LAYER - Convert tokens to embeddings
# ============================================================================

class BertEmbedding(nn.Module):
    """
    BERT has 3 types of embeddings that are summed together:
    1. Token Embeddings - represents the actual word/token.
    2. Position Embeddings - represents position in sequence.
    3. Segment Embeddings - differentiates between sentence A and B
    """

    def __init__(self, config):
        super(BertEmbedding, self).__init__()

        # Token Embedding: Maps token IDs to dense vectors
        # Input: (batch_size, seq_len) with token IDs
        # Output: (batch_size, seq_len, hidden_size)
        self.word_embeddings = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
            padding_idx=0
        )

        # Position Embedding: Learned position encodings
        # Input: (batch_size, seq_len) with position IDs [0, 1, 2,......, seq_len-1]
        # Output: (batch_size, seq_len, hidden_size)
        self.position_embeddings = nn.Embedding(
            num_embeddings=config.max_position_embeddings,
            embedding_dim=config.hidden_size
        )

        # Segment/Token Type Embedding: Differentiates sentence A from B
        # Input: (batch_size, seq_len) with segment IDs (0 or 1)
        # Output: (batch_size, seq_len, hidden_size)
        self.token_type_embeddings = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size
        )

        # Layer Normalization: Normalizes the summed embeddings
        self.layer_norm = nn.LayerNorm(
            normalized_shape=config.hidden_size,
            eps=config.layer_norm_eps
        )

        # Dropout for regularization
        self.dropout = nn.Dropout(p=config.hidden_dropout_prob)

        # Register buffer for positions_ids to avoid recreating it every forward pass
        # Shape: (1, max_positions_embeddings)
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1))
        )

    def forward(self, input_ids, token_type_ids=None):
        """
        Args:
            input_ids: Token IDs, shape: (batch_size, seq_len)
            token_type_ids: Segment IDs, shape: (batch_size, seq_len), Optional
        
        Returns:
            embeddings: Combined embeddings, shape: (batch_size, seq_len, hidden_size)
        """
        # input_ids shape: (batch_size, seq_len)
        batch_size, seq_len = input_ids.size()

        # Step 2.1: Get position IDs for the sequence
        # Shape: (1, seq_len) -> will broadcast to (batch_size, seq_len)
        position_ids = self.position_ids[:, :seq_len]

        # Step 2.2: If token_type_ids not provided, assume all tokens are from sentence A.
        if token_type_ids is None:
            token_type_ids == torch.zeros_like(input_ids) # Shape: (batch_size, seq_len)
        
        # Step 2.3: Get token embeddings
        # Input shape: (batch_size, seq_len)
        # Output shape: (batch_size, seq_len, hidden_size=768)
        word_embeddings = self.word_embeddings(input_ids)

        # Step 2.4: Get position embeddings
        # Input shape: (1, seq_len) 
        # Output shape: (1, seq_len, hidden_size=768) -> broadcasts to (batch_size, seq_len, 768)
        position_embeddings = self.position_embeddings(position_ids)

        # Step 2.5: Get Segment/Token type embeddings
        # Input shape: (batch_size, seq_len)
        # Output shape: (batch_size, seq_len, hidden_size=768)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # Step 2.7: Sum all three embeddings
        # All inputs: (batch_size, seq_len, hidden_size=768)
        # All outputs: (batch_size, seq_len, hidden_size=768)
        embeddings = word_embeddings + position_embeddings + token_type_embeddings

        # Step 2.7: Apply Layer Normalization
        # Input shape: (batch_size, seq_len, hidden_size=768)
        # Output shape: (batch_size, seq_len, hidden_size=768)
        embeddings = self.layer_norm(embeddings)

        # Step 2.8: Apply dropout
        # Input shape: (batch_size, seq_len, hidden_size=768)
        # Output shape: (batch_size, seq_len, hidden_size=768)
        return self.dropout(embeddings)

In [16]:
# ============================================================================
# STEP 3: MULTI-HEAD SELF-ATTENTION - Core attention mechanism
# ============================================================================

class MultiHeadSelfAttention(nn.Module):
    """
    Multi-Head self-attention mechanism. 
    
    The attention mechanism allows the model to focus on different parts of the input.
    Multi-head attention runs multiple attention mechanisms in parellel

    Formula: Attention(Q, K, V) = softmax(Q*K^T / sqrt(d_k)) * V
    """
    def __init__(self, config: BertConfig):
        super(MultiHeadSelfAttention, self).__init__()

        assert config.hidden_size % config.num_attention_heads == 0, "hidden_size must be divisible by num_attention_heads"

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Linear transformation for Query, Key and Value
        # Each transforms (batch_size, seq_len, hidden_size) -> (batch_size, seq_len, hidden_size)
        self.query = nn.Linear(in_features=config.hidden_size, out_features=self.all_head_size)
        self.key = nn.Linear(in_features=config.hidden_size, out_features=self.all_head_size)
        self.value = nn.Linear(in_features=config.hidden_size, out_features=self.all_head_size)

        self.dropout = nn.Dropout(p=config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """
        Reshape tensor for multi-head attention computation.

        Args:
            x: shape (batch_size, seq_len, hidden_size=768)

        Returns:
            reshaped tensor: (batch_size, num_heads=12, seq_len, head_size=64)
        """
        # Input shape: (batch_size, seq_len, hidden_size=768)
        batch_size, seq_len, hidden_size = x.size()

        # Step 3.1: Reshape to separate heads
        # (batch_size, seq_len, 768) -> (batch_size, seq_len, 12, 64)
        x = x.view(batch_size, seq_len, self.num_attention_heads, self.attention_head_size)

        # Step 3.2: Transpose to bring heads dimension forward
        # (batch_size, seq_len, 12, 64) -> (batch_size, 12, seq_len, 64)
        return x.permute(0, 2, 1, 3)
    
    def forward(self, hidden_states, attention_mask=None):
        """
        Args:
            hidden_state: shape (batch_size, seq_len, hidden_size=768)
            attention_mask: shape (batch_size, 1, 1, seq_len) - mask padded tokens
        
        Returns:
            context_layer: shape (batch_size, seq_len, hidden_size=768)
        """
        # Input shape: (batch_size, seq_len, hidden_size)
        batch_size, seq_len, hidden_size = hidden_states.size()

        # Step 3.3: Apply linear transformation to get Q, K and V
        # Each output shape: (batch_size, seq_len, hidden_size)
        query_layer = self.query(hidden_states)     # (batch_size, seq_len, hidden_size)
        key_layer = self.key(hidden_states)         # (batch_size, seq_len, hidden_size)
        value_layer = self.value(hidden_states)     # (batch_size, seq_len, hidden_size)

        # Step 3.4: Reshape Q, K and V for multi-head attention
        # Transform to (batch_size, num_heads, seq_len, attention_head_size)
        query_layer = self.transpose_for_scores(query_layer)    # (batch_size, num_heads, seq_len, attention_hidden_size)
        key_layer = self.transpose_for_scores(key_layer)        # (batch_size, num_heads, seq_len, attention_hidden_size)
        value_layer = self.transpose_for_scores(value_layer)    # (batch_size, num_heads, seq_len, attention_hidden_size)

        # Step 3.5: Calculate attention scores - Q*K^T
        # query: (batch_size, num_heads, seq_len, attention_hidden_size)
        # key transposed: (batch_size, num_heads, attention_hidden_size, seq_len)
        # Output: (batch_size, num_heads, seq_len, seq_len)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # Step 3.6: Scale by sqrt of head dimension
        # This prevents dot product from growing too large
        # Input/Output shape: (batch_size, num_heads, seq_len, seq_len)
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Step 3.7: Apply attention mask (if provided)
        # Mask padded tokens by setting there scores to very -ve value
        if attention_mask is not None:
            # attention_mask shape: (batch_size, 1, 1, seq_len)
            # attention_scores shape: (batch_size, num_heads, seq_len, seq_len)
            attention_scores = attention_scores + attention_mask

        # Step 3.8: Apply softmax to get attention probabilities
        # Input: (batch_size, num_heads, seq_len, seq_len)
        # Output: (batch_size, num_heads, seq_len, seq_len) - probability sum to 1 over last dimension
        attention_probs = F.softmax(attention_scores, dim=-1)


        # Step 3.9: Apply dropout to attention probabilities
        # Shape: (batch_size, num_heads, seq_len, seq_len)
        attention_probs = self.dropout(attention_probs)

        # Step 3.10: Multiply attention probabilities with V to get context
        # attention_probs: (batch_size, num_heads, seq_len, seq_len)
        # value_layer: (batch_size, num_heads, seq_len, attention_head_size)
        # Output: (batch_size, num_heads, seq_len, attention_head_size)
        context_layer = torch.matmul(attention_probs, value_layer)

        # Step 3.11: Transpose back to original format
        # (batch_size, num_heads, seq_len, attention_head_size) -> (batch_size, seq_len, num_heads, attention_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

        # Step 3.12: Reshape to combine all heads
        # (batch_size, seq_len, num_heads, attention_head_size) -> (batch_size, seq_len, hidden_size)
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, )
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer

In [17]:
# ============================================================================
# STEP 4: SELF-ATTENTION OUTPUT - Project attention output back
# ============================================================================

class SelfAttentionOutput(nn.Module):
    """
    Projects the attention output and applies residual connection + layer norm.
    """
    def __init__(self, config: BertConfig):
        super(SelfAttentionOutput, self).__init__()

        # Linear layer to project attention output
        # Input/output: (batch_size, seq_len, hidden_size=768)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """
        Args:
            hidden_states: Attention output, shape (batch_size, seq_len, hidden_size=768)
            input_tensor: Original input to atttention, shape (batch_size, seq_len, hidden_size=768)
        
        Returns:
            output: shape (batch_size, seq_len, hidden_size=768)
        """
        # Step 4.1: Apply linear transformation
        # Input: (batch_size, seq_len, hidden_size)
        # Output: (batch_size, seq_len, hidden_size)
        hidden_states = self.dense(hidden_states)

        # Step 4.2: Apply dropout
        # shape: (batch_size, seq_len, hidden_size)
        hidden_states = self.dropout(hidden_states)

        # Step 4.3: Add residual connection and apply layer norm
        # Both inputs: (batch_size, seq_len, hidden_size)
        # Output: (batch_size, seq_len, hidden_size)
        return self.LayerNorm(hidden_states + input_tensor)

In [18]:
# ============================================================================
# STEP 5: COMPLETE ATTENTION BLOCK - Combine attention + output projection
# ============================================================================

class BertAttention(nn.Module):
    """
    Complete attention block combining multi-head attention and output projection
    """
    def __init__(self, config: BertConfig):
        super(BertAttention, self).__init__()
        self.self_attention = MultiHeadSelfAttention(config=config)
        self.output = SelfAttentionOutput(config=config)
        
    def forward(self, hidden_states, attention_mask=None):
        """
        Args:
            hidden_states: shape (batch_size, seq_len, hidden_size)
            attention_mask: shape (batch_size, 1, 1, seq_len)
        
        Returns:
            attention_output: shape (batch_size, seq_len, hidden_size)
        """
        # Step 5.1: Apply multi-head self attention
        # Inputs: (batch_size, seq_len, hidden_size), (batch_size, 1, 1, seq_len)
        # Output: (batch_size, seq_len, hidden_size)
        self_attention_output = self.self_attention(hidden_states, attention_mask)

        # Step 5.2: Apply output projection with residual connection
        # Inputs: (batch_size, seq_len, hidden_size), (batch_size, seq_len, hidden_size)
        # Output: (batch_size, seq_len, hidden_size)
        attention_output = self.output(self_attention_output, hidden_states)
        return attention_output # Shape: (batch_size, seq_len, hidden_size)

In [19]:
# ============================================================================
# STEP 6: FEED-FORWARD NETWORK - Position-wise FFN after attention
# ============================================================================

class BertIntermediate(nn.Module):
    """
    First part of feed-forward neural network - expands hidden_size to intermediate_size
    """
    def __init__(self, config):
        super(BertIntermediate, self).__init__()

        # Expands from hidden_size (768) to intermediate_size (3072)
        # Input: (batch_size, seq_len, hidden_size)
        # Output: (batch_size, seq_len, intermediate_size)
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

        # GELU activation function (Gaussian Error Linear Unit)
        # Smoother than ReLU, used in original BERT
        self.intermediate_act_fn = nn.GELU()

    def forward(self, hidden_states):
        """
        Args:
            hidden_states: shape (batch_size, seq_len, hidden_size)
        
        Returns:
            output: shape (batch_size, seq_len, intermediate_size)
        """
        # Step 6.1: Apply linear transformation (expand dimension)
        # Input: (batch_size, seq_len, hidden_size)
        # Output: (batch_size, seq_len, intermediate_size)
        hidden_states = self.dense(hidden_states)

        # Step 6.2: Apply GELU activation
        # Input/output: (batch_size, seq_len, intermediate_size)
        return self.intermediate_act_fn(hidden_states)
    

class BertOutput(nn.Module):
    """
    Second part of feed-forward network - projects back to hidden_size
    """
    def __init__(self, config):
        super(BertOutput, self).__init__()
        # Project from intermediate_size (3072) to hidden_size (768)
        # Input: (batch_size, seq_len, intermediate_size)
        # Output: (batch_size, seq_len, hidden_size)
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """
        Args:
            hidden_states: FFN intermediate output, shape (batch_size, seq_len, intermediate_size)
            input_tensor: Input to FFN block, shape (batch_size, seq_len, hidden_size)

        Returns:
            output: shape (batch_size, seq_len, hidden_size)
        """
        # Step 6.3: Project back to hidden_size
        # Input: (batch_size, seq_len, intermediate_size)
        # Output: (batch_size, seq_len, hidden_size)
        hidden_states = self.dense(hidden_states)

        # Step 6.4: Apply dropout
        # Shape: (batch_size, seq_len, hidden_size)
        hidden_states = self.dropout(hidden_states)

        # Step 6.5: Add residual connection adn layer norm
        # Both inputs: (batch_size, seq_len, hidden_size)
        # Output: (batch_size, seq_len, hidden_size)
        return self.LayerNorm(hidden_states + input_tensor)

In [20]:
# ============================================================================
# STEP 7: ENCODER LAYER - Single transformer encoder block
# ============================================================================

class BertLayer(nn.Module):
    """
    Single BERT encoder layer consisting of:
    1. Multi-head self-attention
    2. Feed-forward network
    
    Each sub-layer has residual connection + layer normalization
    """
    def __init__(self, config: BertConfig):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config=config)
        self.intermediate = BertIntermediate(config=config)
        self.output = BertOutput(config=config)

    def forward(self, hidden_states, attention_mask=None):
        """
        Args:
            hidden_states: shape (batch_size, seq_len, hidden_size)
            attention_mask: shape (batch_size, 1, 1, seq_len)

        Returns:
            layer_output: shape (batch_size, seq_len, hidden_size)
        """
        # step 7.1: Apply attention block
        # Input: (batch_size, seq_len, hidden_size)
        # Output: (batch_size, seq_len, hidden_size)
        attention_output = self.attention(hidden_states, attention_mask)

        # Step 7.2: Apply feed-forward network - intermediate expansion
        # Input: (batch_size, seq_len, hidden_size)
        # Output: (batch_size, seq_len, intermediate_size)
        intermediate_output = self.intermediate(attention_output)

        # Step 7.3: Apply feed-forward network - project back with residual
        # Inputs: (batch_size, seq_len, intermediate_size), (batch_size, seq_len, hidden_size)
        # Output: (batch_size, seq_len, hidden_size)
        return self.output(intermediate_output, attention_output)

In [21]:
# ============================================================================
# STEP 8: ENCODER - Stack of encoder layers
# ============================================================================

class BertEncoder(nn.Module):
    """
    BERT Encoder consisting of multiple stacked encoder layers.
    For BERT-base: 12 layers
    """
    def __init__(self, config):
        super(BertEncoder, self).__init__()

        # Create a list of 12 encoder layers
        self.layer = nn.ModuleList([
            BertLayer(config=config) for _ in range(config.num_hidden_layers)
        ])
    
    def forward(self, hidden_states, attention_mask=None):
        """
        Args:
            hidden_states: Embeddings, shape (batch_size, seq_len, hidden_size)
            attention_mask: shape (batch_size, 1, 1, seq_len)
        
        Returns:
            hidden_states: Final encoder output, shape (batch_size, seq_len, hidden_size)
        """
        # Step 8.1: Pass through each encoder layer sequrentially
        # Each layer : Input (batch_size, seq_len, hidden_size) -> Output (batch_size, seq_len, hidden_size)
        for i, layer_module in enumerate(self.layer):
            # Pass through layer i (i=0 to 11 for BERT-base)
            hidden_states = layer_module(hidden_states, attention_mask)
            # After each layer shape remains (batch_size, seq_len, hidden_size)
        return hidden_states

In [22]:
# ============================================================================
# STEP 9: POOLER - Extract [CLS] token representation
# ============================================================================

class BertPooler(nn.Module):
    """
    Pools the output by taking the hidden state of [CLS] token (first_token).
    Used for classification tasks
    """
    def __init__(self, config):
        super(BertPooler, self).__init__()

        # Linear layer + tanh activation
        # Input/Output: (batch_size, hidden_size)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """
        Args:
            hidden_states: Encoder output, shape (batch_size, seq_len, hidden_size)

        Returns:
            pooled_output: [CLS] representation, shape (batch_size, hidden_size)
        """
        # Step 9.1: Extract first token ([CLS]) hidden state.
        # Input: (batch_size, seq_len, hidden_size)
        # Output: (batch_size, hidden_size)
        first_token_tensor = hidden_states[:, 0]
        
        # Step 9.2: Apply linear transformation
        # Input: (batch_size, hidden_size)
        # Output: (batch_size, hidden_size)
        pooled_output = self.dense(first_token_tensor)

        # Step 9.3: Apply tanh activation
        # Input/Output: (batch_size, hidden_size)
        return self.activation(pooled_output)

In [23]:
# ============================================================================
# STEP 10: COMPLETE BERT MODEL - Putting it all together
# ============================================================================

class BertModel(nn.Module):
    """
    Complete BERT model combining all components
    1. Embeddings (token + position + segment)
    2. Encoder (12 transformer layers)
    3. Pooler (for [CLS] token)
    """
    def __init__(self, config):
        super(BertModel, self).__init__()
        self.config = config
        
        # Initialize all components
        self.embeddings = BertEmbedding(config=config)
        self.encoder = BertEncoder(config=config)
        self.pooler = BertPooler(config=config)

    def get_extended_attention_mask(self, attention_mask):
        """
        Creates extended attention mask for multi-head attention.
        Converts 1s (attend) and 0s (don't attend) to 0s and -10000s.

        Args:
            attention_mask: shape (batch_size, seq_len) with 1s and 0s.

        Returns:
            extended_attention_mask: shape (batch_size, 1, 1, seq_len)
        """
        # Step 10.1: Add dimensions for broadcasting
        # Input: (batch_size, seq_len)
        # Output: (batch_size, 1, 1, seq_len)
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Steop 10.2: Convert to float and create mask values
        # 1 -> 0.0 (attend), 0 -> -10000.0 (don't attend)
        # Shape: (batch_size, 1, 1, seq_len)
        extended_attention_mask = extended_attention_mask.to(dtype=torch.float32)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        return extended_attention_mask
    
    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None
    ):
        """
        Forward pass through BERT model.

        Args:
            input_ids: Token IDs shape (batch_size, seq_len)
            attention_mask: Mask for padding, shape (batch_size, seq_len), Optional
            token_type_ids: Segment IDs shape (batch_size, seq_len) Optional

        Returns:
            sequence_ouptput: All token respresentations, shape (batch_size, seq_len, hidden_size)
            pooled_output: [CLS] token representation, shape (batch_size, hidden_size)
        """
        # Step 10.3: Create attention mask if not provided
        if attention_mask is None:
            # Default: attend to all tokens
            # Shape: (batch_size, seq_len)
            attention_mask = torch.ones_like(input=input_ids)
        
        # Step 10.4: Extend attention mask for mult-head attention
        # Input: (batch_size, seq_len)
        # Output: (batch_size, 1, 1, seq_len)
        extended_attention_mask = self.get_extended_attention_mask(attention_mask=attention_mask)

        # Step 10.5: Get embeddings
        # Input: input_ids (batch_size, seq_len), token_type_ids (batch_size, seq_len)
        # Output: (batch_size, seq_len, hidden_size)
        embedding_output = self.embeddings(input_ids, token_type_ids)

        # Step 10.6: Pass through encoder layers
        # Input: (batch_size, seq_len, hidden_size), attention_mask (batch_size, 1, 1, seq_len)
        # Output: (batch_size, seq_len, hidden_size)
        encoder_output = self.encoder(embedding_output, extended_attention_mask)

        # Step 10.7: Pool [CLS] token representation
        # Input: (batch_size, seq_len, hidden_size)
        # Output: (batch_size, hidden_size)
        pooled_output = self.pooler(encoder_output)

        # encoder_output: (batch_size, seq_len, hidden_size) - all token representations
        # pooled_output: (batch_size, hidden_size) - [CLS] token for classification
        return encoder_output, pooled_output

In [29]:
# ============================================================================
# STEP 11: BERT FOR MASKED LANGUAGE MODELING (Pre-training Task)
# ============================================================================

class BertForMaskedLM(nn.Module):
    """
    Bert with masked language modelling head.
    predicts masked tokens in the input sequence.
    """
    def __init__(self, config: BertConfig):
        super(BertForMaskedLM, self).__init__()
        
        self.bert = BertModel(config=config)

        # MLM prediction head
        self.mlm_dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.mlm_activation = nn.GELU()
        self.mlm_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Final layer to predict vocabulary
        # Input: (batch_size, seq_len, hidden_size)
        # Output: (batch_size, seq_len, vocab_size)
        self.mlm_classifier = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """
        Args:
            input_ids: shape (batch_size, seq_len)
            attention_mask: shape (batch_size, seq_len)
            token_type_ids: shape (batch_size, seq_len)
        
        Returns:
            prediction_scores: shape (batch_size, seq_len, vocab_size)
        """
        # Step 11.1: Get BERT outputs
        # sequence_output: (batch_size, seq_len, hidden_size)
        # pooled_output: (batch_size, hidden_size)
        sequence_output, pooled_output = self.bert(input_ids, attention_mask, token_type_ids)

        # Step 11.2: Apply MLM head transformation
        # Input: (batch_size, seq_len, hidden_size)
        # Output: (batch_size, seq_len, hidden_size)
        hidden_states = self.mlm_dense(sequence_output)

        # Step 11.3: Apply GELU activation function
        # Input: (batch_size, seq_len, hidden_size)
        hidden_states = self.mlm_activation(hidden_states)

        # Step 11.4: Apply layer normalization
        # Input: (batch_size, seq_len, hidden_size)
        hidden_states = self.mlm_layer_norm(hidden_states)

        # Step 11.5: Project to vocabulary size
        # Input: (batch_size, seq_len, hidden_size)
        # Output: (batch_size, seq_len, vocab_size=30522)
        prediction_scores = self.mlm_classifier(hidden_states)
        return prediction_scores

In [30]:
# ============================================================================
# STEP 12: BERT FOR SEQUENCE CLASSIFICATION (Fine-tuning Task)
# ============================================================================

class BertForSequenceClassification(nn.Module):
    """
    BERT for classification tasks (e.g., sentiment analysis).
    Uses [CLS] token representation for classification.
    """
    def __init__(self, config: BertConfig, num_labels):
        super(BertForSequenceClassification, self).__init__()
        self.num_labels = num_labels
        self.bert = BertModel(config=config)

        # Dropout for regularizaton
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Classification head
        # Input: (batch_size, hidden_size)
        # Output: (batch_size, num_labels)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None
    ):
        """
        Args:
            input_ids: shape (batch_size, seq_len)
            attention_mask: shape (batch_size, seq_len)
            token_type_ids: shape (batch_size, seq_len)
        
        Returns:
            logit: shape (batch-size, num_labels)
        """
        # Step 12.1: Get BERT outputs
        # sequence_output: (batch_size, seq_len, hidden_size)
        # pooled_output: (batch_size, hidden_size) - [CLS] representation.
        sequence_output, pooled_output = self.bert(input_ids, attention_mask, token_type_ids)

        # Step 12.2: Apply dropout to [CLS] representation
        # input/Output:  (batch_size, hidden_size)
        pooled_output = self.dropout(pooled_output)

        # Step 12.3: Apply classification layer
        # Input: (batch_size, hidden_size)
        # Output: (batch_size, num_labels)
        logits = self.classifier(pooled_output)
        return logits

In [35]:
# ============================================================================
# STEP 13: EXAMPLE USAGE
# ============================================================================

if __name__ == "__main__":
    print("="*80)
    print("BERT MODEL FROM SCRATCH - EXAMPLE USAGE")
    print("="*80)
    
    # Create configuration
    config = BertConfig()

    # Example 1: Basic BERT Model
    print("\n1. BASIC BERT MODEL")
    print("-" * 80)
    model = BertModel(config)

    # Create dummy input
    batch_size = 2
    seq_len = 128

    # Input IDs: (batch_size, seq_len) - random token IDs
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    print(f"Input IDs shape: {input_ids.shape}")  # (2, 128)

    # Attention mask: (batch_size, seq_len) - 1 for real tokens, 0 for padding
    attention_mask = torch.ones(batch_size, seq_len)
    print(f"Attention mask shape: {attention_mask.shape}")  # (2, 128)

    # Token type IDs: (batch_size, seq_len) - 0 for sentence A, 1 for sentence B
    token_type_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)
    token_type_ids[:, seq_len//2:] = 1  # Second half is sentence B
    print(f"Token type IDs shape: {token_type_ids.shape}")  # (2, 128)

    # Forward pass
    sequence_output, pooled_output = model(input_ids, attention_mask, token_type_ids)

    print(f"\nSequence output shape: {sequence_output.shape}")  # (2, 128, 768)
    print(f"Pooled output shape: {pooled_output.shape}")  # (2, 768)

    # Example 2: BERT for Masked Language Modeling
    print("\n2. BERT FOR MASKED LANGUAGE MODELING")
    print("-" * 80)
    mlm_model = BertForMaskedLM(config)

    prediction_scores = mlm_model(input_ids, attention_mask, token_type_ids)
    print(f"MLM prediction scores shape: {prediction_scores.shape}")  # (2, 128, 30522)

    # Get predicted token for first masked position
    predicted_token_id = torch.argmax(prediction_scores[0, 0, :])
    print(f"Predicted token ID at position 0: {predicted_token_id.item()}")

    # Example 3: BERT for Sequence Classification
    print("\n3. BERT FOR SEQUENCE CLASSIFICATION (e.g., Sentiment Analysis)")
    print("-" * 80)
    num_labels = 2  # Binary classification (positive/negative)
    classification_model = BertForSequenceClassification(config, num_labels)

    logits = classification_model(input_ids, attention_mask, token_type_ids)
    print(f"Classification logits shape: {logits.shape}")  # (2, 2)

    # Get predictions
    predictions = torch.argmax(logits, dim=-1)
    print(f"Predictions: {predictions}")  # Tensor of class indices

    # Calculate probabilities
    probabilities = F.softmax(logits, dim=-1)
    print(f"Probabilities shape: {probabilities.shape}")  # (2, 2)
    print(f"Sample probabilities: {probabilities[0]}")

    # Model statistics
    print("\n4. MODEL STATISTICS")
    print("-" * 80)
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    
    print("\n" + "="*80)
    print("COMPLETE! All components working correctly.")
    print("="*80)

BERT MODEL FROM SCRATCH - EXAMPLE USAGE

1. BASIC BERT MODEL
--------------------------------------------------------------------------------
Input IDs shape: torch.Size([2, 128])
Attention mask shape: torch.Size([2, 128])
Token type IDs shape: torch.Size([2, 128])

Sequence output shape: torch.Size([2, 128, 768])
Pooled output shape: torch.Size([2, 768])

2. BERT FOR MASKED LANGUAGE MODELING
--------------------------------------------------------------------------------
MLM prediction scores shape: torch.Size([2, 128, 30522])
Predicted token ID at position 0: 14106

3. BERT FOR SEQUENCE CLASSIFICATION (e.g., Sentiment Analysis)
--------------------------------------------------------------------------------
Classification logits shape: torch.Size([2, 2])
Predictions: tensor([0, 0])
Probabilities shape: torch.Size([2, 2])
Sample probabilities: tensor([0.6052, 0.3948], grad_fn=<SelectBackward0>)

4. MODEL STATISTICS
----------------------------------------------------------------------

In [27]:
x = torch.arange(5)

print(f"Original X:\n {x}\n\n")
print(f"After (1, -1):\n {x.expand(1, -1)}\n\n")
print(f"After (2, -1):\n {x.expand(2, -1)}\n\n")
print(f"After (3, -1):\n {x.expand(3, -1)}\n\n")

Original X:
 tensor([0, 1, 2, 3, 4])


After (1, -1):
 tensor([[0, 1, 2, 3, 4]])


After (2, -1):
 tensor([[0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4]])


After (3, -1):
 tensor([[0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4]])




In [36]:
x.size(), x.shape

(torch.Size([5]), torch.Size([5]))

In [37]:
x = torch.randn(1, 3)
x

tensor([[-0.9517, -0.9720,  0.1327]])

In [38]:
x.size(), x.shape

(torch.Size([1, 3]), torch.Size([1, 3]))

In [12]:
x.expand(3, 5, -1)

tensor([[[-2.2346,  1.5581, -0.1645],
         [-2.2346,  1.5581, -0.1645],
         [-2.2346,  1.5581, -0.1645],
         [-2.2346,  1.5581, -0.1645],
         [-2.2346,  1.5581, -0.1645]],

        [[-2.2346,  1.5581, -0.1645],
         [-2.2346,  1.5581, -0.1645],
         [-2.2346,  1.5581, -0.1645],
         [-2.2346,  1.5581, -0.1645],
         [-2.2346,  1.5581, -0.1645]],

        [[-2.2346,  1.5581, -0.1645],
         [-2.2346,  1.5581, -0.1645],
         [-2.2346,  1.5581, -0.1645],
         [-2.2346,  1.5581, -0.1645],
         [-2.2346,  1.5581, -0.1645]]])

In [27]:
x = torch.tensor([[1, 2, 3]])     # shape: (1, 3)
x

tensor([[1, 2, 3]])

In [31]:
x.expand(10, -1, -1)

tensor([[[1, 2, 3]],

        [[1, 2, 3]],

        [[1, 2, 3]],

        [[1, 2, 3]],

        [[1, 2, 3]],

        [[1, 2, 3]],

        [[1, 2, 3]],

        [[1, 2, 3]],

        [[1, 2, 3]],

        [[1, 2, 3]]])

In [16]:
x = torch.randn(2, 3)
x

tensor([[ 0.9807, -0.8164, -0.3877],
        [-1.5797,  1.1740,  0.4903]])

In [19]:
x.expand(3, 2, 3)

tensor([[[ 0.9807, -0.8164, -0.3877],
         [-1.5797,  1.1740,  0.4903]],

        [[ 0.9807, -0.8164, -0.3877],
         [-1.5797,  1.1740,  0.4903]],

        [[ 0.9807, -0.8164, -0.3877],
         [-1.5797,  1.1740,  0.4903]]])

In [22]:
torch.arange(config.max_position_embeddings).expand(2, -1)

tensor([[  0,   1,   2,  ..., 509, 510, 511],
        [  0,   1,   2,  ..., 509, 510, 511]])