In [1]:
"""
GPT-style Causal Language Model Implementation

This module implements a decoder-only transformer architecture similar to GPT
for causal language modeling tasks. The model is designed for pretraining on
large text corpora using next-token prediction.

Architecture:
    - Token + Position Embeddings
    - N Transformer Decoder Blocks
    - Layer Normalization
    - Language Modeling Head
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple

In [None]:


class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention with an optional mask.

    The hidden dimension is split evenly across ``num_attention_heads`` heads.
    Each head computes scaled dot-product attention independently; the head
    outputs are concatenated and passed through an output projection.

    Masking is driven entirely by the ``attention_mask`` argument of
    ``forward``. Causal behaviour is obtained by passing a lower-triangular
    mask; when no mask is given, every position attends to every position.

    Args:
        hidden_size (int): Dimension of input embeddings
        num_attention_heads (int): Number of parallel attention heads
        dropout (float): Dropout probability, used both for the attention
            weights and for the projected output

    Raises:
        ValueError: If ``hidden_size`` is not divisible by
            ``num_attention_heads``.
    """

    def __init__(self, hidden_size: int, num_attention_heads: int, dropout: float = 0.1):
        super().__init__()
        # Explicit exception instead of ``assert`` so the check survives ``python -O``.
        if hidden_size % num_attention_heads != 0:
            raise ValueError("hidden_size must be divisible by num_attention_heads")

        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        # Scores are divided by sqrt(head_dim) before the softmax.
        self.scale = math.sqrt(self.head_dim)

        # Query, Key, Value projections for all heads (batched into one matmul)
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)

        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Forward pass of multi-head attention.

        Args:
            hidden_states: Input tensor of shape (batch_size, seq_len, hidden_size)
            attention_mask: Optional mask broadcastable to
                (batch_size, num_heads, seq_len, seq_len), e.g. a causal mask
                of shape (1, 1, seq_len, seq_len). Entries equal to 0 mark
                key positions that must not be attended to.

        Returns:
            Output tensor of shape (batch_size, seq_len, hidden_size)
        """
        batch_size, seq_len, _ = hidden_states.size()

        # Project to Q, K, V and split into multiple heads
        qkv = self.qkv_proj(hidden_states)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_attention_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch, heads, seq_len, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Compute attention scores: (batch, heads, seq_len, seq_len)
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale

        if attention_mask is not None:
            # Fill with the most negative finite value rather than -inf: a query
            # row whose keys are all masked would otherwise be all -inf and the
            # softmax would return NaN. With a finite fill such a row becomes a
            # uniform distribution, while partially masked rows still give
            # masked keys a weight of exactly zero.
            fill_value = torch.finfo(attn_weights.dtype).min
            attn_weights = attn_weights.masked_fill(attention_mask == 0, fill_value)

        # Softmax and dropout
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)

        # Apply attention to values
        attn_output = torch.matmul(attn_weights, v)  # (batch, heads, seq_len, head_dim)

        # Concatenate heads and project back
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_len, self.hidden_size)
        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        return attn_output


In [None]:

class FeedForward(nn.Module):
    """
    Two-layer position-wise MLP.

    Every position is transformed independently: a linear expansion to
    ``intermediate_size``, a GELU non-linearity, then a linear projection back
    to ``hidden_size``. The same dropout module is applied after the
    activation and again after the output projection.

    Args:
        hidden_size (int): Input/output dimension
        intermediate_size (int): Width of the inner layer (typically 4x hidden_size)
        dropout (float): Dropout probability
    """

    def __init__(self, hidden_size: int, intermediate_size: int, dropout: float = 0.1):
        super().__init__()
        # Expansion followed by contraction back to the model width.
        self.fc1 = nn.Linear(in_features=hidden_size, out_features=intermediate_size)
        self.fc2 = nn.Linear(in_features=intermediate_size, out_features=hidden_size)
        # One Dropout instance shared by both stages of the forward pass.
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(fc2(dropout(gelu(fc1(x)))))``."""
        inner = self.dropout(F.gelu(self.fc1(x)))
        return self.dropout(self.fc2(inner))

In [None]:


class TransformerBlock(nn.Module):
    """
    One pre-norm transformer decoder block.

    The block is two residual sub-layers applied in sequence:
        1. LayerNorm -> multi-head self-attention, added back to the input
        2. LayerNorm -> feed-forward network, added back to the result of (1)

    Args:
        hidden_size (int): Dimension of embeddings
        num_attention_heads (int): Number of attention heads
        intermediate_size (int): FFN intermediate dimension
        dropout (float): Dropout probability
        layer_norm_eps (float): Epsilon passed to both LayerNorm modules
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        intermediate_size: int,
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-5
    ):
        super().__init__()
        # Attention sub-layer and its input normalization.
        self.ln_1 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.attn = MultiHeadAttention(hidden_size, num_attention_heads, dropout)
        # Feed-forward sub-layer and its input normalization.
        self.ln_2 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.ffn = FeedForward(hidden_size, intermediate_size, dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Run the attention and feed-forward sub-layers with residual connections.

        Args:
            hidden_states: Input tensor (batch_size, seq_len, hidden_size)
            attention_mask: Optional mask forwarded unchanged to the attention module

        Returns:
            Output tensor (batch_size, seq_len, hidden_size)
        """
        # The un-normalized input is the residual; only the branch sees LayerNorm.
        attn_branch = self.attn(self.ln_1(hidden_states), attention_mask)
        hidden_states = hidden_states + attn_branch

        ffn_branch = self.ffn(self.ln_2(hidden_states))
        return hidden_states + ffn_branch


In [None]:

class GPTLMModel(nn.Module):
    """
    GPT-style Causal Language Model.

    A decoder-only transformer architecture for autoregressive language modeling.
    The model predicts the next token given all previous tokens in the sequence.

    Architecture components:
        - Token embeddings
        - Learned positional embeddings
        - Stack of transformer decoder blocks
        - Final layer normalization
        - Language modeling head (projects to vocabulary; its weight is tied
          to the token embedding matrix)

    Args:
        vocab_size (int): Size of vocabulary
        hidden_size (int): Dimension of embeddings
        num_hidden_layers (int): Number of transformer blocks
        num_attention_heads (int): Number of attention heads per block
        intermediate_size (int): FFN intermediate dimension
        max_position_embeddings (int): Maximum sequence length
        dropout (float): Dropout probability
        layer_norm_eps (float): Layer norm epsilon
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        num_hidden_layers: int,
        num_attention_heads: int,
        intermediate_size: int,
        max_position_embeddings: int,
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-5
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.max_position_embeddings = max_position_embeddings

        # Embeddings
        self.token_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.dropout = nn.Dropout(dropout)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(
                hidden_size,
                num_attention_heads,
                intermediate_size,
                dropout,
                layer_norm_eps
            ) for _ in range(num_hidden_layers)
        ])

        # Final layer norm
        self.ln_f = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

        # Language modeling head
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        # Weight tying: lm_head and token_embeddings share one Parameter object
        self.lm_head.weight = self.token_embeddings.weight

        # Initialize weights (applied recursively to every submodule)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """
        Initialize one submodule's parameters.

        Linear and Embedding weights are drawn from a normal distribution with
        mean 0 and standard deviation 0.02, Linear biases are zeroed, and
        LayerNorm is reset to weight 1 / bias 0.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def get_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """
        Create causal attention mask to prevent attending to future tokens.

        Args:
            seq_len: Length of sequence
            device: Device to create mask on

        Returns:
            Lower-triangular mask of ones and zeros with shape
            (1, 1, seq_len, seq_len); entry [.., i, j] is 1 when j <= i.
        """
        mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
        mask = mask.unsqueeze(0).unsqueeze(0)  # Add batch and head dimensions
        return mask

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass of the model.

        Args:
            input_ids: Token indices of shape (batch_size, seq_len)
            labels: Optional labels for computing loss (batch_size, seq_len).
                They are shifted internally, so they are normally the same
                tensor as ``input_ids``; positions set to -100 are excluded
                from the loss.

        Returns:
            logits: Token prediction logits (batch_size, seq_len, vocab_size)
            loss: Cross-entropy loss if labels provided, else None

        Raises:
            ValueError: If seq_len exceeds ``max_position_embeddings``.
        """
        _, seq_len = input_ids.size()
        device = input_ids.device

        # Fail early with a readable message: without this check an over-long
        # sequence only fails inside the position-embedding lookup with an
        # out-of-range index error.
        if seq_len > self.max_position_embeddings:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_position_embeddings "
                f"({self.max_position_embeddings})"
            )

        # Position ids of shape (1, seq_len); the resulting embeddings
        # broadcast over the batch dimension when added to the token embeddings.
        position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device).unsqueeze(0)

        # Get embeddings
        token_embeds = self.token_embeddings(input_ids)
        position_embeds = self.position_embeddings(position_ids)
        hidden_states = token_embeds + position_embeds
        hidden_states = self.dropout(hidden_states)

        # Create causal mask (shared by every block)
        attention_mask = self.get_causal_mask(seq_len, device)

        # Apply transformer blocks
        for block in self.blocks:
            hidden_states = block(hidden_states, attention_mask)

        # Final layer norm
        hidden_states = self.ln_f(hidden_states)

        # Get logits
        logits = self.lm_head(hidden_states)

        # Compute loss if labels provided
        loss = None
        if labels is not None:
            # Shift so that the logits at position t are scored against the
            # label at position t + 1 (next-token prediction).
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Flatten and compute cross-entropy loss
            loss = F.cross_entropy(
                shift_logits.view(-1, self.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100  # Targets equal to -100 do not contribute
            )

        return logits, loss

