In [None]:
"""
CELL 0: Environment Setup and Dependency Installation
====================================================
This cell sets up the Python environment and installs all required packages for 
reproducing SmolLMv2 from scratch.

What it does:
1. Installs core dependencies:
   - rotary-embedding-torch: Implements Rotary Position Embeddings (RoPE) for positional encoding
   - transformers: HuggingFace library for tokenizers and model utilities
   - pyyaml: For reading YAML configuration files

2. Imports essential libraries:
   - PyTorch: Deep learning framework for building and training the model
   - Transformers: For tokenizer and model loading utilities
   - YAML: For configuration management
   - Rotary Embedding: For implementing RoPE in attention layers

This is the first step in reproducing SmolLMv2 - ensuring all dependencies are available.
"""

# Install required packages
%pip install rotary-embedding-torch transformers pyyaml

# Import all dependencies
import os
import math
import time
import inspect
from dataclasses import dataclass, field
from typing import Optional

# PyTorch - Core deep learning framework
import torch
import torch.nn as nn
from torch.nn import functional as F

# Transformers - HuggingFace library for tokenizers and model utilities
from transformers import AutoTokenizer, AutoModelForCausalLM

# YAML configuration - For reading model configuration from config.yaml
import yaml

# Rotary Embedding - Implements Rotary Position Embeddings (RoPE)
from rotary_embedding_torch import RotaryEmbedding

print("All dependencies imported successfully!")

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


  from .autonotebook import tqdm as notebook_tqdm


All dependencies imported successfully!


In [None]:
"""
CELL 1: Load Reference Model from HuggingFace
==============================================
This cell loads the original SmolLM2-135M model from HuggingFace to:
1. Use as a reference for architecture verification
2. Extract the tokenizer (which we'll use for our from-scratch implementation)
3. Compare our implementation against the original

What it does:
- Loads the tokenizer: Converts text to token IDs and vice versa
- Loads the pretrained model: The original HuggingFace implementation
- This helps us verify that our from-scratch implementation matches the architecture

Note: We're loading this for reference only. Our goal is to reproduce this model
from scratch with random initialization and train it ourselves.
"""
# Reference checkpoint on the HuggingFace Hub. It supplies the tokenizer and a
# known-good architecture to compare against; the from-scratch model defined
# later in this notebook does not reuse its weights.
REFERENCE_CHECKPOINT = "HuggingFaceTB/SmolLM2-135M"

tokenizer = AutoTokenizer.from_pretrained(REFERENCE_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(REFERENCE_CHECKPOINT)

In [None]:
"""
CELL 2: Inspect Reference Model Architecture
=============================================
This cell displays the structure of the reference HuggingFace model to understand:
- The exact layer structure and configuration
- How components are organized (attention, MLP, normalization)
- Parameter counts and dimensions

This inspection helps ensure our from-scratch implementation matches the original
architecture exactly. We can compare this output with our custom implementation later.
"""
# Display the reference model structure for architecture verification
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((576,), eps=1e-05)
    (rotary_emb): Lla

In [None]:
"""
CELL 3: Install PrettyTable for Parameter Counting
===================================================
This cell installs the prettytable package, which we'll use to create nicely
formatted tables showing the parameter count breakdown for each layer.

This helps us:
- Verify our model has the correct number of parameters (134.5M)
- Understand which components contribute most to the parameter count
- Debug any architecture mismatches
"""
!pip install prettytable

Defaulting to user installation because normal site-packages is not writeable


In [None]:
"""
CELL 4: Count Parameters in Reference Model
===========================================
This cell counts and displays all trainable parameters in the reference model,
organized by layer. This serves as a baseline to verify our from-scratch
implementation has the correct parameter count.

What it does:
1. Prints the full model structure
2. Calculates total trainable parameters (should be ~134.5M)
3. Creates a detailed table showing parameter count per layer/component
4. Helps identify which components use the most parameters:
   - Token embeddings: ~28M (largest component)
   - MLP layers: ~2.65M per layer × 30 = ~79.6M
   - Attention layers: ~885K per layer × 30 = ~26.5M

This breakdown is crucial for understanding the model architecture and verifying
our implementation matches the original.
"""
# Show the reference architecture, then the number of trainable parameters.
print(model)
n_trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(n_trainable)
from prettytable import PrettyTable

def count_parameters(model):
    """
    Print a per-tensor table of trainable parameter counts and return the total.

    Parameters with ``requires_grad=False`` are left out of both the table and
    the total.

    Args:
        model: Any object exposing ``named_parameters()`` (e.g. an ``nn.Module``).

    Returns:
        The total number of trainable parameter elements.
    """
    table = PrettyTable(["Modules", "Parameters"])

    # Gather (name, element count) pairs first so the table rows and the total
    # are derived from the same list.
    trainable = [
        (name, tensor.numel())
        for name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]
    for name, count in trainable:
        table.add_row([name, count])
    total_params = sum(count for _, count in trainable)

    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
    
count_parameters(model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((576,), eps=1e-05)
    (rotary_emb): Lla

134515008

In [None]:
"""
CELL 5: Load Configuration from YAML File
==========================================
This cell defines configuration classes and loads all model and tokenizer
parameters from config.yaml. This approach allows us to:
1. Centralize all hyperparameters in one file
2. Easily experiment with different configurations
3. Reproduce exact training runs by saving configs

What it does:
1. Defines ModelConfig dataclass: Contains all model architecture parameters
   - Architecture: vocab_size, hidden_size, num_layers, attention heads
   - GQA: num_key_value_heads (6) vs num_attention_heads (18) = 3:1 ratio
   - MLP: intermediate_size, activation functions
   - Normalization: RMSNorm epsilon value

2. Defines TokenizerConfig: Tokenizer settings from HuggingFace

3. Loads config.yaml: Reads all parameters and creates config objects

4. Displays configuration: Shows all loaded parameters for verification

Key parameters loaded:
- vocab_size: 49152 (vocabulary size)
- hidden_size: 576 (embedding dimension)
- num_hidden_layers: 30 (transformer layers)
- num_attention_heads: 18 (query heads)
- num_key_value_heads: 6 (key-value heads for GQA)
- intermediate_size: 1536 (MLP hidden dimension)
"""

# Load configuration from YAML file
# All imports (yaml, dataclass, field, Optional) are in the first cell

# Load defaults from config.yaml first
def load_yaml_defaults(yaml_path: str = "config.yaml"):
    """Read a YAML file and return its top-level mapping.

    Args:
        yaml_path: Path of the YAML file to read.

    Returns:
        The parsed top-level mapping, or an empty dict when the document is
        empty.

    Raises:
        OSError: If the file cannot be opened.
        TypeError: If the document's top level is not a mapping. Callers index
            the result with ``.get(...)``, so anything else would only fail
            later with a less helpful error.
    """
    # Explicit encoding: the platform default is not UTF-8 everywhere.
    with open(yaml_path, 'r', encoding="utf-8") as f:
        loaded = yaml.safe_load(f) or {}
    if not isinstance(loaded, dict):
        raise TypeError(
            f"Expected a mapping at the top level of {yaml_path!r}, "
            f"got {type(loaded).__name__}"
        )
    return loaded

# Load YAML once; the dataclasses below read their field defaults from these
# two section dictionaries at class-definition time.
_yaml_defaults = load_yaml_defaults("config.yaml")
# `or {}` also covers a section that is present but empty (`model:` parses to
# None), which would otherwise break every `.get(...)` call below.
_model_defaults = _yaml_defaults.get("model") or {}
_tokenizer_defaults = _yaml_defaults.get("tokenizer") or {}

@dataclass
class ModelConfig:
    """
    Model architecture and initialization parameters.

    Every field default is looked up in ``_model_defaults`` (the ``model``
    section of config.yaml) and falls back to the literal given as the second
    argument of ``.get``. The lookups run once, when this class body is
    executed, so editing config.yaml afterwards has no effect on the defaults
    until the cell is re-run.

    NOTE(review): the model classes in this notebook read ``config.n_head``,
    ``config.n_embd`` and ``config.block_size``, none of which are defined
    here - confirm which config object is actually passed to them.
    """
    model_type: str = _model_defaults.get("model_type", "llama")
    vocab_size: int = _model_defaults.get("vocab_size", 49152)  # vocabulary size
    hidden_size: int = _model_defaults.get("hidden_size", 576)  # embedding dimension
    intermediate_size: int = _model_defaults.get("intermediate_size", 1536)  # MLP hidden dimension
    num_hidden_layers: int = _model_defaults.get("num_hidden_layers", 30)  # number of decoder layers
    num_attention_heads: int = _model_defaults.get("num_attention_heads", 18)  # query heads
    num_key_value_heads: int = _model_defaults.get("num_key_value_heads", 6)  # key/value heads (GQA)
    max_position_embeddings: int = _model_defaults.get("max_position_embeddings", 8192)
    # May arrive as a string from YAML; LlamaRMSNorm converts it with float().
    rms_norm_eps: float = _model_defaults.get("rms_norm_eps", 1e-05)
    use_cache: bool = _model_defaults.get("use_cache", True)
    pad_token_id: Optional[int] = _model_defaults.get("pad_token_id", None)
    bos_token_id: int = _model_defaults.get("bos_token_id", 1)
    eos_token_id: int = _model_defaults.get("eos_token_id", 2)

    # MLP Configuration
    # mlp_bias is the `bias=` argument of the three LlamaMLP projections.
    mlp_bias: bool = _model_defaults.get("mlp_bias", False)
    # NOTE(review): the three *_features fields are not read by LlamaMLP, which
    # sizes its layers from hidden_size / intermediate_size - confirm they are needed.
    mlp_gate_proj_features: int = _model_defaults.get("mlp_gate_proj_features", 1536)
    mlp_up_proj_features: int = _model_defaults.get("mlp_up_proj_features", 1536)
    mlp_down_proj_features: int = _model_defaults.get("mlp_down_proj_features", 576)

    # Activation Function
    # hidden_act selects the activation in LlamaMLP ("silu", "gelu" or "relu").
    hidden_act: str = _model_defaults.get("hidden_act", "silu")
    # NOTE(review): use_silu is not read by LlamaMLP - presumably superseded by hidden_act.
    use_silu: bool = _model_defaults.get("use_silu", True)

    # Model initialization
    pretrained_model_name_or_path: str = _model_defaults.get("pretrained_model_name_or_path", "HuggingFaceTB/SmolLM2-135M")
    torch_dtype: str = _model_defaults.get("torch_dtype", "float16")
    device_map: str = _model_defaults.get("device_map", "auto")
    trust_remote_code: bool = _model_defaults.get("trust_remote_code", False)

@dataclass
class TokenizerConfig:
    """
    Tokenizer configuration parameters.

    Each default is looked up in ``_tokenizer_defaults`` (the ``tokenizer``
    section of config.yaml) when this class body is executed, falling back to
    the literal passed as the second argument of ``.get``.
    """
    # Name or path of the tokenizer to load; defaults to the reference checkpoint id.
    tokenizer_name_or_path: str = _tokenizer_defaults.get("tokenizer_name_or_path", "HuggingFaceTB/SmolLM2-135M")
    use_fast: bool = _tokenizer_defaults.get("use_fast", True)
    padding_side: str = _tokenizer_defaults.get("padding_side", "right")
    truncation_side: str = _tokenizer_defaults.get("truncation_side", "right")

@dataclass
class SmolLMV2Config:
    """
    Main configuration class combining model and tokenizer configs.

    Attributes:
        model: Architecture and initialization settings.
        tokenizer: Tokenizer settings.
    """
    model: ModelConfig = field(default_factory=ModelConfig)
    tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig)

    @staticmethod
    def _apply_overrides(target, overrides):
        """Copy onto ``target`` every ``overrides`` entry whose key it already has.

        Keys that ``target`` does not have are ignored. ``overrides`` may be
        None, which is what an empty YAML section (``model:``) parses to.
        """
        for key, value in (overrides or {}).items():
            if hasattr(target, key):
                setattr(target, key, value)
        return target

    @classmethod
    def from_yaml(cls, yaml_path: str = "config.yaml"):
        """Build a configuration from a YAML file.

        The ``model`` and ``tokenizer`` sections of the file override the
        matching fields of freshly constructed ``ModelConfig`` and
        ``TokenizerConfig`` objects. Missing or empty sections leave the
        defaults untouched, and unknown keys are ignored.

        Args:
            yaml_path: Path of the YAML file to read.

        Returns:
            A new ``SmolLMV2Config``.

        Raises:
            OSError: If the file cannot be opened.
        """
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(yaml_path, 'r', encoding="utf-8") as f:
            config_dict = yaml.safe_load(f) or {}

        model_config = cls._apply_overrides(ModelConfig(), config_dict.get("model"))
        tokenizer_config = cls._apply_overrides(TokenizerConfig(), config_dict.get("tokenizer"))
        return cls(model=model_config, tokenizer=tokenizer_config)

# Build the configuration object from config.yaml
cfg = SmolLMV2Config.from_yaml("config.yaml")

# Print both sections between horizontal rules so the loaded values can be
# checked by eye.
_rule = "=" * 60
print(_rule)
print("Loaded SmolLMV2 Configuration")
print(_rule)
for _heading, _section in (
    ("\n[Model Configuration]", cfg.model),
    ("\n[Tokenizer Configuration]", cfg.tokenizer),
):
    print(_heading)
    for key, value in vars(_section).items():
        print(f"  {key}: {value}")
print(_rule)


Loaded SmolLMV2 Configuration

[Model Configuration]
  model_type: llama
  vocab_size: 49152
  hidden_size: 576
  intermediate_size: 1536
  num_hidden_layers: 30
  num_attention_heads: 18
  num_key_value_heads: 6
  max_position_embeddings: 8192
  rms_norm_eps: 1e-05
  use_cache: True
  pad_token_id: None
  bos_token_id: 1
  eos_token_id: 2
  mlp_bias: False
  mlp_gate_proj_features: 1536
  mlp_up_proj_features: 1536
  mlp_down_proj_features: 576
  hidden_act: silu
  use_silu: True
  pretrained_model_name_or_path: HuggingFaceTB/SmolLM2-135M
  torch_dtype: float16
  device_map: auto
  trust_remote_code: False

[Tokenizer Configuration]
  tokenizer_name_or_path: HuggingFaceTB/SmolLM2-135M
  use_fast: True
  padding_side: right
  truncation_side: right


In [None]:
"""
CELL 6: Implement LlamaAttention with Grouped Query Attention (GQA)
====================================================================
This cell implements the attention mechanism with Grouped Query Attention (GQA),
which is a key optimization in SmolLMv2. GQA reduces memory usage by sharing
key-value heads across multiple query heads.

What it does:
1. Implements Grouped Query Attention (GQA):
   - Query projection: All 18 heads (full dimension)
   - Key/Value projections: Only 6 heads (shared across 3 query heads each)
   - This reduces KV cache memory by 3x while maintaining quality

2. Applies Rotary Position Embeddings (RoPE):
   - Rotates queries and keys with positional information
   - Enables the model to understand relative positions
   - Applied before attention computation

3. Implements causal masking:
   - Prevents attention to future tokens (autoregressive property)
   - Uses lower triangular mask

4. Uses PyTorch's fused attention (F.scaled_dot_product_attention):
   - Dispatches to an optimized kernel (e.g. FlashAttention) when one is available
   - Fused kernels avoid materializing the full attention matrix, saving memory
   - Applies causal masking itself via is_causal=True

Key architectural details:
- n_head = 18 (query heads)
- num_key_value_heads = 6 (key-value heads)
- GQA ratio: 18/6 = 3:1 (each KV head shared by 3 query heads)
- head_dim = 576 / 18 = 32 (dimension per head)
"""

class LlamaAttention(nn.Module):
    """
    Grouped Query Attention (GQA) with rotary position embeddings.

    Queries use ``config.n_head`` heads while keys and values use only
    ``config.num_key_value_heads`` heads; each key/value head is shared by
    ``n_head // num_key_value_heads`` query heads.

    Args:
        config: Object providing ``n_head``, ``n_embd``, ``num_key_value_heads``
            and ``block_size``.

    Raises:
        ValueError: If ``n_embd`` is not divisible by ``n_head`` or ``n_head``
            is not divisible by ``num_key_value_heads``.
    """
    def __init__(self, config):
        super().__init__()
        # Fail early with a clear message: integer division below would
        # otherwise truncate silently and only surface as a shape error in
        # forward().
        if config.n_embd % config.n_head != 0:
            raise ValueError(
                f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})"
            )
        if config.n_head % config.num_key_value_heads != 0:
            raise ValueError(
                f"n_head ({config.n_head}) must be divisible by "
                f"num_key_value_heads ({config.num_key_value_heads})"
            )

        self.n_head = config.n_head  # Number of query heads
        self.n_embd = config.n_embd  # Embedding dimension
        self.num_key_value_heads = config.num_key_value_heads  # Key/value heads
        self.head_dim = config.n_embd // config.n_head  # Dimension per head

        # Separate projections for q, k, v, o (LLaMA style).
        # q_proj covers all heads; k_proj and v_proj only the key/value heads.
        self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.k_proj = nn.Linear(config.n_embd, config.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.n_embd, config.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

        # Rotary positional embedding (RoPE) - applied to q and k in forward
        self.rotary_emb = RotaryEmbedding(self.head_dim)

        # Lower-triangular causal mask.
        # NOTE(review): forward() never reads this buffer (masking is done by
        # is_causal=True), yet it costs block_size**2 floats per layer and is
        # part of the state dict. It is kept only so state-dict keys stay the
        # same - consider dropping it.
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        """
        Apply causal grouped-query self-attention.

        Args:
            x: Input of shape (B, T, C) with C equal to ``n_embd``.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality

        # Project, then split the channel dimension into heads:
        # (B, T, heads * head_dim) -> (B, heads, T, head_dim)
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # RoPE: positions are encoded by rotating queries and keys
        q = self.rotary_emb.rotate_queries_or_keys(q)
        k = self.rotary_emb.rotate_queries_or_keys(k)

        # GQA: duplicate each key/value head so it lines up with the query
        # heads that share it.
        if self.num_key_value_heads != self.n_head:
            repeat_factor = self.n_head // self.num_key_value_heads
            k = k.repeat_interleave(repeat_factor, dim=1)  # (B, n_head, T, head_dim)
            v = v.repeat_interleave(repeat_factor, dim=1)  # (B, n_head, T, head_dim)

        # is_causal=True makes each position attend only to itself and
        # earlier positions.
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # (B, n_head, T, head_dim)

        # Re-assemble all head outputs: (B, n_head, T, head_dim) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # Output projection
        return self.o_proj(y)

In [None]:
"""
CELL 7: Implement LlamaMLP (Feed-Forward Network) with SwiGLU Activation
=========================================================================
This cell implements the MLP (Multi-Layer Perceptron) component of the transformer,
which processes the output of the attention layer. SmolLMv2 uses SwiGLU activation,
which is a gated linear unit with SiLU activation.

What it does:
1. Implements SwiGLU (SiLU-gated Linear Unit):
   - Gate projection: SiLU(gate_proj(x))
   - Up projection: up_proj(x)
   - Output: down_proj(gate * up)
   - This is more effective than standard ReLU/GELU activations

2. Three linear projections:
   - gate_proj: 576 -> 1536 (gating signal)
   - up_proj: 576 -> 1536 (value signal)
   - down_proj: 1536 -> 576 (output projection)

3. Activation function from config:
   - Default: SiLU (Sigmoid Linear Unit)
   - Can be changed to GELU or ReLU via config.yaml

Key architectural details:
- Input/Output: 576 dimensions (hidden_size)
- Intermediate: 1536 dimensions (intermediate_size)
- Expansion ratio: 1536/576 = 2.67x
- No bias terms (mlp_bias=False) for efficiency
"""

class LlamaMLP(nn.Module):
    """
    Gated feed-forward network (SwiGLU when the activation is SiLU).

    Computes ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``.

    Args:
        config: Object providing ``hidden_size``, ``intermediate_size``,
            ``mlp_bias`` and ``hidden_act`` (one of "silu", "gelu", "relu").

    Raises:
        ValueError: If ``config.hidden_act`` is not a supported activation.
    """
    # Supported activation names mapped to their module classes.
    _ACTIVATIONS = {"silu": nn.SiLU, "gelu": nn.GELU, "relu": nn.ReLU}

    def __init__(self, config):
        super().__init__()
        # Three linear projections: two expand to intermediate_size, one
        # projects back to hidden_size.
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)

        # Reject unknown names here; otherwise act_fn would stay undefined and
        # the mistake would only show up as an AttributeError in forward().
        activation_cls = self._ACTIVATIONS.get(config.hidden_act)
        if activation_cls is None:
            raise ValueError(
                f"Unsupported hidden_act {config.hidden_act!r}; "
                f"expected one of {sorted(self._ACTIVATIONS)}"
            )
        self.act_fn = activation_cls()

    def forward(self, x):
        """
        Apply the gated MLP.

        Args:
            x: Tensor whose last dimension is ``hidden_size``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        gate = self.act_fn(self.gate_proj(x))  # Gate signal with activation
        up = self.up_proj(x)  # Value signal (no activation)
        return self.down_proj(gate * up)  # Element-wise gating, then project back

In [None]:
"""
CELL 8: Implement LlamaRMSNorm (Root Mean Square Layer Normalization)
=======================================================================
This cell implements RMSNorm, which is a simplified version of LayerNorm used
in modern LLMs like LLaMA. RMSNorm is faster and simpler than LayerNorm because
it skips mean-centering and only rescales by the root mean square of the activations.

What it does:
1. Implements RMSNorm formula:
   - Normalize: x / sqrt(mean(x^2) + eps)
   - Scale: weight * normalized_x
   - No mean centering (unlike LayerNorm)

2. Learnable scale parameter:
   - weight: Learnable parameter initialized to ones
   - Allows the model to scale normalized values

3. Epsilon for numerical stability:
   - Prevents division by zero
   - Default: 1e-05

Key advantages over LayerNorm:
- Faster computation (no mean calculation)
- Simpler implementation
- Similar or better performance in practice
- Standard in LLaMA-style LLMs (the reference model above uses LlamaRMSNorm)
"""

class LlamaRMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization (RMSNorm).

    Rescales the last dimension by its root mean square, without subtracting
    the mean, then applies a learnable per-channel scale:

        output = weight * x / sqrt(mean(x^2) + eps)

    Args:
        config: Object providing ``hidden_size`` and ``rms_norm_eps``.
    """
    def __init__(self, config):
        super().__init__()
        # Learnable scale parameter (initialized to ones)
        self.weight = nn.Parameter(torch.ones(config.hidden_size))
        # Ensure rms_norm_eps is a float (YAML might load it as string)
        self.variance_epsilon = float(config.rms_norm_eps)

    def forward(self, x):
        """
        Normalize ``x`` over its last dimension.

        The statistics are computed in float32 and the result is cast back to
        the input dtype before scaling. Squaring half-precision activations
        directly can overflow to ``inf``, which would zero the output.

        Args:
            x: Tensor whose last dimension is ``hidden_size``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        input_dtype = x.dtype
        x_fp32 = x.to(torch.float32)
        mean_square = x_fp32.pow(2).mean(dim=-1, keepdim=True)
        normalized = x_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(input_dtype)

    def __repr__(self):
        return f"LlamaRMSNorm(hidden_size={self.weight.shape[0]}, eps={self.variance_epsilon})"

In [None]:
"""
CELL 9: Implement Complete Model Architecture (Decoder Layer, Decoder, and Full Model)
=======================================================================================
This cell implements the complete SmolLMv2 architecture by combining all components:
1. LlamaDecoderLayer: Single transformer layer (attention + MLP)
2. LlamaDecoder: Stack of 30 decoder layers
3. LlamaModel: Complete model with embeddings, decoder, and language model head

What it does:
1. LlamaDecoderLayer:
   - Pre-norm architecture: RMSNorm before attention and MLP
   - Residual connections: h = x + attention(ln1(x)); out = h + mlp(ln2(h))
   - Standard transformer decoder block

2. LlamaDecoder:
   - Stacks 30 LlamaDecoderLayer instances
   - Processes input through all layers sequentially

3. LlamaModel:
   - Token embeddings: Converts token IDs to dense vectors
   - Transformer decoder: 30 layers of attention + MLP
   - Final layer norm: RMSNorm before output
   - Language model head: Projects to vocabulary for next-token prediction
   - Weight tying: Embedding and output weights are shared (reduces parameters)

4. Weight initialization:
   - Normal distribution: mean=0.0, std=0.02 (standard for LLMs)
   - Ensures stable training from random initialization

5. Text generation:
   - Autoregressive generation: predicts one token at a time
   - Greedy decoding: always picks highest probability token
   - Uses only last token's logits for next token prediction
"""

class LlamaDecoderLayer(nn.Module):
    """
    Single transformer decoder layer with pre-norm architecture.

    Structure:
    - ln_1 -> attention, added back onto the residual stream
    - ln_2 -> MLP, added back onto the residual stream

    Args:
        config: Configuration object forwarded unchanged to the sub-modules.
    """
    def __init__(self, config):
        super().__init__()
        self.ln_1 = LlamaRMSNorm(config)  # Pre-attention normalization
        self.attn = LlamaAttention(config)  # Grouped Query Attention
        self.ln_2 = LlamaRMSNorm(config)  # Pre-MLP normalization
        self.mlp = LlamaMLP(config)  # Gated MLP

    def forward(self, x):
        """
        Run one pre-norm decoder block.

        Each sub-layer sees a normalized copy of the residual stream, and its
        output is added to the un-normalized stream:

            x = x + attn(ln_1(x))
            x = x + mlp(ln_2(x))

        Args:
            x: Hidden states; the output has the same shape.

        Returns:
            Updated hidden states.
        """
        # The normalized tensor is only an input to the sub-layer. It must not
        # replace the residual stream itself.
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class LlamaDecoder(nn.Module):
    """
    Sequential stack of ``config.num_hidden_layers`` decoder layers.

    Args:
        config: Configuration object; it is stored on the instance and handed
            to every ``LlamaDecoderLayer``.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Build the stack one layer at a time.
        stack = nn.ModuleList()
        for _ in range(config.num_hidden_layers):
            stack.append(LlamaDecoderLayer(config))
        self.layers = stack

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the final output."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden

class LlamaModel(nn.Module):
    """
    Complete decoder-only language model.

    Components:
    - embed_tokens: token IDs -> dense vectors
    - transformer: stack of decoder layers
    - ln_f: final RMSNorm
    - lm_head: projection to vocabulary logits; its weight is tied to
      ``embed_tokens``

    Args:
        config: Object providing ``vocab_size``, ``n_embd`` and ``block_size``,
            plus whatever the sub-modules read from it.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Token embeddings: vocab_size -> n_embd
        self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd)

        # NOTE(review): this module is never called by forward() or generate();
        # RoPE is applied inside each attention layer. It is kept only so the
        # module layout / state-dict keys stay unchanged - consider removing it.
        self.pos_embed = RotaryEmbedding(config.n_embd)

        # Language model head: n_embd -> vocab_size
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: the output projection shares the embedding matrix
        self.lm_head.weight = self.embed_tokens.weight

        # Final layer normalization
        self.ln_f = LlamaRMSNorm(config)

        # Transformer decoder stack
        self.transformer = LlamaDecoder(config)

        # Initialize all weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """
        Initialize one sub-module in place (used with ``self.apply``).

        - Linear layers: weight ~ Normal(0, 0.02), bias zeroed when present
        - Embeddings: weight ~ Normal(0, 0.02)
        - Everything else is left untouched.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """
        Forward pass through the model.

        Args:
            idx: Input token IDs, shape (B, T) with T <= ``config.block_size``.
            targets: Optional target token IDs, shape (B, T).

        Returns:
            logits: Shape (B, T, vocab_size).
            loss: Cross-entropy between logits and targets, or None when no
                targets are given.
        """
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"

        # Token embeddings (RoPE is applied in attention, not here)
        x = self.embed_tokens(idx)  # (B, T, n_embd)

        # Decoder stack, final norm, vocabulary projection
        x = self.transformer(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        loss = None
        if targets is not None:
            # Flatten for cross-entropy: (B*T, vocab_size) and (B*T,)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    def generate(self, input_ids, max_new_tokens):
        """
        Generate new tokens autoregressively using greedy decoding.

        The model is switched to evaluation mode and left there. Generation
        runs without gradient tracking.

        Args:
            input_ids: Prompt token IDs, shape (batch_size, seq_len).
            max_new_tokens: Number of tokens to append.

        Returns:
            Token IDs of shape (batch_size, seq_len + max_new_tokens), on the
            same device as the model's embedding weights.
        """
        self.eval()

        # Follow the device the weights actually live on instead of relying on
        # a separate config field that may be missing or out of date.
        device = self.embed_tokens.weight.device
        generated_ids = input_ids.to(device).clone()

        # No autograd graph is needed for sampling; building one per step
        # wastes memory that grows with the sequence length.
        with torch.no_grad():
            for _ in range(max_new_tokens):
                x = self.embed_tokens(generated_ids)
                x = self.transformer(x)

                # Only the last position predicts the next token, so normalize
                # and project just that position instead of the whole sequence.
                last_hidden = self.ln_f(x[:, -1, :])  # (batch_size, n_embd)
                logits = self.lm_head(last_hidden)  # (batch_size, vocab_size)

                # Greedy decoding: pick the highest-scoring token
                next_token = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

                generated_ids = torch.cat([generated_ids, next_token], dim=1)

        return generated_ids
    

In [None]:
"""
CELL 10: Detailed Explanation of Text Generation Process
=========================================================
This cell provides a comprehensive, step-by-step walkthrough of how the generate()
method works. It's an educational cell that explains the autoregressive text
generation process in detail.

What it explains:
1. Input preparation: How token IDs are structured
2. Embedding layer: Converting tokens to dense vectors
3. Positional encoding: How positions are handled (RoPE in attention)
4. Transformer processing: How context is built through layers
5. Logit extraction: Why we only use the last token's logits
6. Token selection: Greedy decoding strategy
7. Sequence building: How tokens are appended iteratively

This educational content helps understand:
- Why logits[:, -1, :] is used (only last token predicts next)
- How autoregressive generation works step-by-step
- The shape transformations at each step
- Why we generate one token at a time

This is crucial for understanding how language models generate text!
"""
# Walkthrough of generate(): banner, then steps 1-2 (input IDs and embeddings).
# Every tensor below is a simulated stand-in; no real model is called here.
import torch

print(
    "=" * 80,
    "STEP-BY-STEP EXPLANATION OF GENERATE METHOD",
    "=" * 80,
    "",
    sep="\n",
)

# ----------------------------------------------------------------------------
# STEP 1: Input Preparation
# ----------------------------------------------------------------------------
print("STEP 1: Input Preparation", "-" * 80, sep="\n")

# Illustrative token IDs standing in for the prompt "The weather is".
# Layout is (batch_size=1, sequence_length=3).
input_ids = torch.tensor([[1234, 5678, 9012]])
print(
    f"input_ids shape: {input_ids.shape}",
    f"input_ids: {input_ids}",
    "  - batch_size = 1 (one sequence)",
    "  - sequence_length = 3 (three tokens: 'The', 'weather', 'is')",
    "",
    sep="\n",
)

# ----------------------------------------------------------------------------
# STEP 2: Embedding Layer
# ----------------------------------------------------------------------------
print("STEP 2: Embedding Layer (embed_tokens)", "-" * 80, sep="\n")

# The embedding table maps each token ID to a dense vector
# (vocab_size = 49152 rows, hidden_size = 576 columns).
batch_size = input_ids.size(0)
seq_len = input_ids.size(1)
hidden_size = 576

# Random values with the shape an embedding lookup would produce:
# (batch, seq_len) -> (batch, seq_len, hidden_size)
x = torch.randn(batch_size, seq_len, hidden_size)
print(
    f"After embedding, x shape: {x.shape}",
    f"  - batch_size = {batch_size}",
    f"  - sequence_length = {seq_len}",
    f"  - hidden_size = {hidden_size}",
    "  - Each token is now a 576-dimensional vector",
    "",
    sep="\n",
)

# ============================================================================
# STEP 3: Positional Information (RoPE)
# ============================================================================
print("STEP 3: Positional Information (RoPE, applied inside attention)")
print("-" * 80)

# Rotary Position Embeddings (RoPE) do not add a position vector to the token
# embeddings. Position is injected later, inside each attention layer, by
# rotating the query and key vectors by an angle that depends on the token's
# position. The embeddings therefore leave this step unchanged.
# NOTE(review): this describes RoPE as stated in this cell's header
# ("RoPE in attention"); confirm against LlamaModel's own forward pass.
print(f"After positional step, x shape: {x.shape}")
print("  - Still same shape: (batch, seq_len, hidden_size)")
print("  - Nothing is added to the embeddings here")
print("  - RoPE rotates queries/keys inside each attention layer to encode position")
print()

# ============================================================================
# STEP 4: Transformer Layers
# ============================================================================
print("STEP 4: Transformer Layers (transformer)", "-" * 80, sep="\n")

# Stand-in for the decoder stack's output: random values with the same
# (batch, seq_len, hidden_size) shape as the input.
x = torch.randn(batch_size, seq_len, hidden_size)
print(
    f"After transformer, x shape: {x.shape}",
    "  - Still same shape: (batch, seq_len, hidden_size)",
    "  - Each position now has contextualized representations",
    "  - Token 'is' now knows about 'The' and 'weather'",
    "",
    sep="\n",
)

# ============================================================================
# STEP 5: Final Layer Norm
# ============================================================================
print("STEP 5: Final Layer Normalization (ln_f)", "-" * 80, sep="\n")

# Stand-in for the normalized hidden states (shape is unchanged).
x = torch.randn(batch_size, seq_len, hidden_size)
print(
    f"After layer norm, x shape: {x.shape}",
    "  - Still same shape: (batch, seq_len, hidden_size)",
    "  - Normalized for stability",
    "",
    sep="\n",
)

# ============================================================================
# STEP 6: Language Model Head (lm_head)
# ============================================================================
print("STEP 6: Language Model Head (lm_head)", "-" * 80, sep="\n")

# The head projects every hidden state onto the vocabulary, giving one score
# per vocabulary entry at every sequence position.
vocab_size = 49152
logits = torch.randn(batch_size, seq_len, vocab_size)
print(
    f"After lm_head, logits shape: {logits.shape}",
    f"  - batch_size = {batch_size}",
    f"  - sequence_length = {seq_len}",
    f"  - vocab_size = {vocab_size}",
    "",
    "What this means:",
    "  - For EACH position in the sequence, we have predictions for ALL tokens",
    "  - Position 0 ('The'): 49152 logits (probabilities for next token)",
    "  - Position 1 ('weather'): 49152 logits (probabilities for next token)",
    "  - Position 2 ('is'): 49152 logits (probabilities for next token)",
    "",
    "Example logits for position 2 ('is'):",
    sep="\n",
)

# First ten vocabulary scores of the last position in the only batch row.
example_logits = logits[0, 2, :10]
print(
    f"  First 10 logits: {example_logits}",
    "  - These are raw scores (not probabilities yet)",
    "  - Higher values = more likely token",
    "",
    sep="\n",
)

# ============================================================================
# STEP 7: Extract Last Token Logits (THE KEY LINE!)
# ============================================================================
print(
    "STEP 7: Extract Last Token Logits - logits[:, -1, :]",
    "-" * 80,
    "This is the line you asked about!",
    "",
    sep="\n",
)

print(
    "BEFORE slicing:",
    f"  logits.shape = {logits.shape}  # (batch, seq_len, vocab_size)",
    f"  logits = {logits.shape[0]} batch × {logits.shape[1]} positions × {logits.shape[2]} vocab tokens",
    "",
    "Slicing operation: logits[:, -1, :]",
    "  - ':' in first position = keep ALL batches",
    "  - '-1' in second position = take LAST position (last token)",
    "  - ':' in third position = keep ALL vocabulary tokens",
    "",
    sep="\n",
)

# Indexing the middle axis with -1 drops it: (batch, seq_len, vocab) -> (batch, vocab)
logits_last = logits[:, -1, :]
print(
    "AFTER slicing:",
    f"  logits_last.shape = {logits_last.shape}  # (batch, vocab_size)",
    f"  logits_last = {logits_last.shape[0]} batch × {logits_last.shape[1]} vocab tokens",
    "",
    "What we kept:",
    "  - We ONLY kept the logits from the LAST token position ('is')",
    "  - We DISCARDED logits from positions 0 ('The') and 1 ('weather')",
    "  - Why? Because we only need the LAST token to predict the NEXT token!",
    "",
    sep="\n",
)

# ============================================================================
# STEP 8: Get Next Token (Greedy Decoding)
# ============================================================================
print("STEP 8: Get Next Token (torch.argmax)", "-" * 80, sep="\n")

# Greedy decoding: take the vocabulary index with the largest logit.
# keepdim=True leaves a trailing axis of size 1 on the result.
next_token = logits_last.argmax(dim=-1, keepdim=True)
print(
    f"next_token shape: {next_token.shape}",
    f"next_token: {next_token}",
    "",
    "What happened:",
    "  - torch.argmax finds the index with the highest value",
    "  - dim=-1 means we search along the vocabulary dimension",
    "  - keepdim=True keeps the dimension for concatenation",
    "  - Result: The token ID of the most likely next token",
    "",
    sep="\n",
)

# ============================================================================
# VISUAL SUMMARY
# ============================================================================
print(
    "=" * 80,
    "VISUAL SUMMARY",
    "=" * 80,
    "",
    "Input sequence: ['The', 'weather', 'is']",
    "",
    "After forward pass:",
    "  logits shape: (1, 3, 49152)",
    "  ┌─────────────────────────────────────┐",
    "  │ Position 0 ('The'):    49152 logits │",
    "  │ Position 1 ('weather'): 49152 logits │",
    "  │ Position 2 ('is'):     49152 logits │ ← We need THIS one!",
    "  └─────────────────────────────────────┘",
    "",
    "After logits[:, -1, :]:",
    "  logits shape: (1, 49152)",
    "  ┌─────────────────────┐",
    "  │ Position 2 ('is'):  │ ← Only the last token's logits",
    "  │   49152 logits      │",
    "  └─────────────────────┘",
    "",
    "After argmax:",
    "  next_token: [token_id]  ← The most likely next token",
    "",
    "=" * 80,
    sep="\n",
)


STEP-BY-STEP EXPLANATION OF GENERATE METHOD

STEP 1: Input Preparation
--------------------------------------------------------------------------------
input_ids shape: torch.Size([1, 3])
input_ids: tensor([[1234, 5678, 9012]])
  - batch_size = 1 (one sequence)
  - sequence_length = 3 (three tokens: 'The', 'weather', 'is')

STEP 2: Embedding Layer (embed_tokens)
--------------------------------------------------------------------------------
After embedding, x shape: torch.Size([1, 3, 576])
  - batch_size = 1
  - sequence_length = 3
  - hidden_size = 576
  - Each token is now a 576-dimensional vector

STEP 3: Positional Embedding (pos_embed)
--------------------------------------------------------------------------------
After positional embedding, x shape: torch.Size([1, 3, 576])
  - Still same shape: (batch, seq_len, hidden_size)
  - Now each token has both semantic and positional information

STEP 4: Transformer Layers (transformer)
--------------------------------------------------

In [None]:

"""
CELL 11: Create and Test Model from Scratch
===========================================
This cell creates our from-scratch SmolLMv2 model and tests it with a forward pass.
This is the first time we use our custom implementation (not the HuggingFace reference).

What it does:
1. Loads tokenizer: Uses tokenizer from config.yaml
2. Loads training data: Reads input.txt and tokenizes it
3. Creates model config: Sets up configuration with compatibility attributes
4. Initializes model: Creates LlamaModel with random weights (from scratch!)
5. Tests forward pass: Runs a batch through the model to verify it works
6. Computes loss: Shows initial loss (should be high, ~10-11 for random model)

Key points:
- Model is initialized with random weights (not pretrained)
- This is our from-scratch implementation
- Initial loss is high because model hasn't been trained yet
- This verifies the architecture is correct before training
"""
# Use the tokenizer already loaded by an earlier cell; otherwise build it from
# the tokenizer section of config.yaml.
if 'tokenizer' not in globals():
    with open("config.yaml", 'r', encoding='utf-8') as f:
        config_dict = yaml.safe_load(f)
    tokenizer = AutoTokenizer.from_pretrained(
        config_dict["tokenizer"]["tokenizer_name_or_path"],
        use_fast=config_dict["tokenizer"]["use_fast"]
    )

# Read the training text (explicit encoding so the result does not depend on
# the platform's default locale).
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Only a small slice is needed for a single smoke-test batch.
text = text[:1000]
tokens = tokenizer.encode(text, add_special_tokens=False)  # Returns list of token IDs

# Setup batch and sequence parameters
B, T = 4, 32  # batch_size, sequence_length

# One extra token is required so the targets can be the inputs shifted by one.
# Fail with a clear message instead of an opaque reshape error further down.
if len(tokens) < B * T + 1:
    raise ValueError(
        f"Need at least {B * T + 1} tokens for a ({B}, {T}) batch, "
        f"but the first 1000 characters of input.txt produced only {len(tokens)}"
    )

# Create input/target pairs from tokens
buf = torch.tensor(tokens[:B*T + 1])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
buf = buf.to(device)

# Split into input (x) and target (y) - y is x shifted by 1
x = buf[:-1].view(B, T)  # Shape: (B, T)
y = buf[1:].view(B, T)   # Shape: (B, T)

# Prepare model config from config.yaml.
# The aliases below are written onto cfg.model itself (model_config is the same
# object), so later cells see them too.
model_config = cfg.model
model_config.n_embd = model_config.hidden_size
model_config.n_head = model_config.num_attention_heads
model_config.block_size = model_config.max_position_embeddings
model_config.device = device
# NOTE(review): self-assignment; appears to be a no-op. Kept in case the
# config object needs the attribute materialised - confirm and remove.
model_config.num_key_value_heads = model_config.num_key_value_heads

# Create LlamaModel (randomly initialised) using config from config.yaml
model = LlamaModel(model_config)
model.to(device)

# Forward pass and loss only: this is a sanity check, nothing is trained here,
# so skip building the autograd graph.
with torch.no_grad():
    logits, loss = model(x, y)
print(f"Loss: {loss.item():.4f}")

Loss: 11.2828


In [None]:
"""
CELL 12: Calculate Random Prediction Baseline Loss
===================================================
This cell calculates the theoretical loss for a model that predicts uniformly
at random. This serves as a baseline to understand our model's performance.

What it does:
1. Calculates random baseline: -log(1/vocab_size) = log(vocab_size)
2. For vocab_size=49152: log(49152) ≈ 10.80
3. This is the loss of a model with no information (uniform guessing) - a baseline, not an upper bound

Why this matters:
- Initial loss should be close to this (~10.8) for untrained model
- As model learns, loss decreases below this baseline
- Loss < 2.0 indicates the model has learned meaningful patterns
- This helps us understand if training is progressing correctly
"""
# Cross-entropy of a uniform prediction over the vocabulary:
#   -log(1/vocab_size) = log(vocab_size)
# This is the loss of a model that knows nothing about the data. It is NOT an
# upper bound: cross-entropy is unbounded, and a model that is confidently
# wrong scores higher (the freshly initialised model above printed 11.2828,
# which is already above this baseline).
vocab_size = cfg.model.vocab_size
random_loss = -torch.log(torch.tensor(1.0 / vocab_size))
print(f"Expected loss for random prediction (vocab_size={vocab_size}): {random_loss.item():.4f}")
print("This is the baseline loss of an uninformed (uniform) predictor; it is not a maximum, since confidently wrong predictions score higher.")

Expected loss for random prediction (vocab_size=49152): 10.8027
This is the theoretical maximum loss when predicting uniformly at random.


In [None]:
"""
CELL 13: Inspect Our From-Scratch Model Architecture
====================================================
This cell displays the structure of our custom from-scratch model implementation.
We can compare this with the HuggingFace reference model (Cell 2) to verify
our architecture matches.

What it shows:
- Model structure: All layers and their organization
- Component hierarchy: How decoder layers are stacked
- Parameter organization: How weights are structured

This helps verify:
- Our implementation matches the reference architecture
- All components are correctly connected
- The model is ready for training
"""
# Display our from-scratch model structure for verification.
# A bare expression on the last line of a cell is shown via its repr, which
# for this model is the nested module tree printed below the cell.
model

LlamaModel(
  (embed_tokens): Embedding(49152, 576)
  (pos_embed): RotaryEmbedding()
  (lm_head): Linear(in_features=576, out_features=49152, bias=False)
  (ln_f): LlamaRMSNorm(hidden_size=576, eps=1e-05)
  (transformer): LlamaDecoder(
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (ln_1): LlamaRMSNorm(hidden_size=576, eps=1e-05)
        (attn): LlamaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
          (rotary_emb): RotaryEmbedding()
        )
        (ln_2): LlamaRMSNorm(hidden_size=576, eps=1e-05)
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Line

In [None]:
"""
CELL 14: Count Parameters in Our From-Scratch Model
====================================================
This cell counts all trainable parameters in our from-scratch implementation
and displays them in a formatted table. This verifies our model has the
correct number of parameters (134.5M) matching the reference.

What it does:
1. Counts parameters: Iterates through all model parameters
2. Creates table: Shows parameter count per layer/component
3. Verifies total: Should match ~134,515,008 parameters
4. Compares with reference: Can compare with Cell 4 output

This is crucial for:
- Verifying architecture correctness
- Understanding parameter distribution
- Debugging any mismatches with reference model
- Confirming we have the right model size
"""
print(model)
print(sum(p.numel() for p in model.parameters() if p.requires_grad))
from prettytable import PrettyTable

def count_parameters(model):
    """
    Print a per-parameter size table for `model` and return the total.

    Only parameters with `requires_grad` set are listed and counted. The table
    has one row per entry of `model.named_parameters()`, followed by a line
    with the grand total.

    Returns:
        int: total number of trainable scalar parameters.
    """
    # Collect (name, element count) for every trainable parameter, in the
    # order named_parameters() yields them.
    trainable_sizes = [
        (param_name, param.numel())
        for param_name, param in model.named_parameters()
        if param.requires_grad
    ]

    table = PrettyTable(["Modules", "Parameters"])
    for param_name, size in trainable_sizes:
        table.add_row([param_name, size])

    total_params = sum(size for _, size in trainable_sizes)
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
    
count_parameters(model)

LlamaModel(
  (embed_tokens): Embedding(49152, 576)
  (pos_embed): RotaryEmbedding()
  (lm_head): Linear(in_features=576, out_features=49152, bias=False)
  (ln_f): LlamaRMSNorm(hidden_size=576, eps=1e-05)
  (transformer): LlamaDecoder(
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (ln_1): LlamaRMSNorm(hidden_size=576, eps=1e-05)
        (attn): LlamaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
          (rotary_emb): RotaryEmbedding()
        )
        (ln_2): LlamaRMSNorm(hidden_size=576, eps=1e-05)
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Line

134515008

In [None]:
"""
CELL 15: Training Loop - Train Model from Scratch
==================================================
This is the main training cell that trains our from-scratch SmolLMv2 model.
This is where the actual learning happens - the model starts with random weights
and learns to predict next tokens through gradient descent.

What it does:
1. Sets up training environment:
   - Device selection (CPU/GPU)
   - Random seed for reproducibility (1337)
   - Training hyperparameters (batch size, sequence length, learning rate)

2. Creates data loader:
   - Loads and tokenizes training text from input.txt
   - Creates batches of sequences for training
   - Handles data streaming and epoch boundaries

3. Initializes model:
   - Creates model with random weights (from scratch!)
   - Moves to appropriate device (GPU if available)
   - Sets up optimizer (AdamW with learning rate 3e-4)

4. Training loop:
   - Forward pass: Compute predictions and loss
   - Backward pass: Compute gradients
   - Optimizer step: Update weights
   - Logging: Print loss and metrics every 50 steps
   - Text generation: Generate sample text every 500 steps
   - Checkpointing: Save model every 500 steps

5. Optimizations used:
   - Flash Attention: Faster attention computation
   - Mixed precision (bfloat16): Faster training on GPU
   - High precision matmul: Better numerical stability

Training details:
- Batch size: 12 sequences
- Sequence length: 1024 tokens
- Learning rate: 3e-4 (standard for AdamW)
- Optimizer: AdamW (adaptive learning rate)
- Total steps: 5000 (adjustable)
- Checkpoints: Saved every 500 steps

Expected results:
- Initial loss: ~11.65 (close to random baseline)
- Loss decreases over time as model learns
- Final loss: Should be < 2.0 after sufficient training
- Training speed: ~26,000-42,000 tokens/sec on GPU
"""
# Device selection: prefer CUDA, then Apple MPS, otherwise fall back to CPU
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
print(f"Using device: {device}")

# Set seed for reproducibility
torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1337)

# Training hyperparameters
B = 12  # batch size
T = 1024  # sequence length
# Number of training steps. This used to be 10 and was silently overwritten
# with 5000 just before the loop, so the start-up banner reported "Steps: 10"
# for a 5000-step run. Keep the real value here.
num_steps = 5000
learning_rate = 3e-4
# Final checkpoint name follows the step count instead of repeating the number.
checkpoint_path = f"checkpoint_step_{num_steps}.pt"

# DataLoader class (similar to S13.ipynb but using tokenizer from config)
class DataLoaderLite:
    """Sequential next-token batch loader over a single tokenized text file.

    The whole file is tokenized once and kept in memory. Each call to
    `next_batch` returns a contiguous window of B*T tokens as inputs and the
    same window shifted by one token as targets, then advances by B*T tokens,
    wrapping to the start when a full window no longer fits.

    Args:
        B: batch size (number of sequences per batch).
        T: sequence length (tokens per sequence).
        tokenizer: object with `encode(text, add_special_tokens=False)`
            returning a list of token IDs.
        input_path: text file to load; defaults to 'input.txt'.

    Raises:
        ValueError: if the file yields fewer than B*T + 1 tokens, i.e. not
            even one batch can be formed.
    """

    def __init__(self, B, T, tokenizer, input_path='input.txt'):
        self.B = B
        self.T = T
        self.tokenizer = tokenizer

        # Load the text from disk; explicit encoding keeps the result
        # independent of the platform's default locale.
        with open(input_path, 'r', encoding='utf-8') as f:
            text = f.read()

        # Tokenize the full text once and keep the IDs in memory
        tokens = tokenizer.encode(text, add_special_tokens=False)
        self.tokens = torch.tensor(tokens)

        # A batch needs B*T inputs plus one extra token for the shifted targets.
        if len(self.tokens) < B * T + 1:
            raise ValueError(
                f"{input_path} has {len(self.tokens)} tokens, but a batch of "
                f"B={B}, T={T} needs at least {B * T + 1}"
            )

        print(f'Loaded {len(self.tokens)} tokens')
        print(f'1 epoch = {len(self.tokens) // (B * T)} batches')

        # State: index of the first token of the next batch
        self.current_position = 0

    def next_batch(self):
        """Return the next `(x, y)` pair, each of shape (B, T), and advance.

        `y` is `x` shifted left by one token, so `y[i, j]` is the token that
        follows `x[i, j]` in the source text.
        """
        B, T = self.B, self.T

        # current_position can be set from outside (e.g. restored from a
        # checkpoint written with a different batch size). If a full window
        # no longer fits from there, start over instead of failing in view().
        if self.current_position + (B * T + 1) > len(self.tokens):
            self.current_position = 0

        buf = self.tokens[self.current_position: self.current_position + B * T + 1]
        x = (buf[:-1]).view(B, T)  # inputs
        y = (buf[1:]).view(B, T)  # targets

        # Advance the position in the tensor
        self.current_position += B * T
        # If loading the next batch would be out of bounds, reset
        if self.current_position + (B * T + 1) > len(self.tokens):
            self.current_position = 0
        return x, y

# Reuse the notebook-level tokenizer when an earlier cell created one;
# otherwise build it from the tokenizer section of config.yaml.
if 'tokenizer' not in globals():
    with open("config.yaml") as f:
        config_dict = yaml.safe_load(f)
    tokenizer = AutoTokenizer.from_pretrained(
        config_dict["tokenizer"]["tokenizer_name_or_path"],
        use_fast=config_dict["tokenizer"]["use_fast"],
    )

# Batches are served sequentially from input.txt
train_loader = DataLoaderLite(B, T, tokenizer)


# Add the attribute names LlamaModel/LlamaAttention expect. model_config is
# the same object as cfg.model, so these aliases are written onto cfg itself.
model_config = cfg.model
model_config.n_embd, model_config.n_head, model_config.block_size = (
    model_config.hidden_size,
    model_config.num_attention_heads,
    model_config.max_position_embeddings,
)
model_config.device = device
# NOTE(review): self-assignment; looks like a no-op - confirm before removing.
model_config.num_key_value_heads = model_config.num_key_value_heads

# Fresh, randomly initialised model placed on the selected device
model = LlamaModel(model_config).to(device)
torch.set_float32_matmul_precision('high')

# AdamW over every model parameter
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)

# Training loop with periodic logging, sample generation and checkpointing.

# Prompt for text generation
generation_prompt = "The weather is"
max_new_tokens = 50

# Run configuration. num_steps is set BEFORE the banner is printed so the
# banner reports the number of steps that will actually run.
num_steps = 5000
log_interval = 50          # print metrics every N steps (and on the last step)
checkpoint_interval = 500  # generate a sample and save a checkpoint every N steps


def build_checkpoint(step, loss_value):
    """Collect everything needed to resume training after `step` steps.

    Reads the notebook-level `model`, `optimizer`, `model_config`, `device`
    and `train_loader` at call time.

    Args:
        step: number of optimisation steps completed so far.
        loss_value: loss of the most recent step, as a Python float.

    Returns:
        dict with model/optimizer state, a plain-dict copy of the model
        hyperparameters, the data loader position and the CPU/CUDA RNG states
        (CUDA state is None when CUDA is unavailable).
    """
    return {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss_value,
        'model_config': {
            'n_embd': model_config.n_embd,
            'n_head': model_config.n_head,
            'block_size': model_config.block_size,
            'num_key_value_heads': model_config.num_key_value_heads,
            'hidden_size': model_config.hidden_size,
            'vocab_size': model_config.vocab_size,
            'num_hidden_layers': model_config.num_hidden_layers,
            'intermediate_size': model_config.intermediate_size,
            'rms_norm_eps': model_config.rms_norm_eps,
            'hidden_act': model_config.hidden_act,
            'mlp_bias': model_config.mlp_bias,
            'device': str(device)
        },
        # Data loader state, so a resumed run continues from the same tokens
        'data_loader_position': train_loader.current_position,
        # Random states for reproducibility
        'rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None
    }


print("\n" + "="*80)
print(f"Starting Training for {num_steps} Steps")
print("="*80)
print(f"Batch size: {B}, Sequence length: {T}, Steps: {num_steps}")
print(f"Text generation every {checkpoint_interval} steps")
print("="*80 + "\n")

# The prompt never changes, so tokenize it once instead of at every sample.
prompt_ids = tokenizer.encode(generation_prompt, return_tensors='pt').to(device)

model.train()
for step in range(num_steps):
    t0 = time.time()

    # Get batch
    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)

    # Forward pass under bfloat16 autocast
    optimizer.zero_grad()
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        logits, loss = model(x, y)

    # Backward pass
    loss.backward()
    optimizer.step()

    # Wait for queued GPU work so the timing below covers the whole step
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    t1 = time.time()
    dt = (t1 - t0) * 1000  # milliseconds
    tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)

    # Print metrics every log_interval steps and on the final step
    if step % log_interval == 0 or step == num_steps - 1:
        print(f'Step {step:5d} | Loss: {loss.item():.4f} | dt: {dt:6.2f}ms | tok/sec: {tokens_per_sec:8.2f}')

    # Generate text and save checkpoint every checkpoint_interval steps
    if (step + 1) % checkpoint_interval == 0:
        model.eval()
        with torch.no_grad():
            # Generate a continuation of the fixed prompt
            generated_ids = model.generate(prompt_ids, max_new_tokens=max_new_tokens)

            # Decode generated text
            generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

            print(f"\n{'='*80}")
            print(f"Step {step + 1} - Model Output:")
            print(f"{'='*80}")
            print(f"Prompt: '{generation_prompt}'")
            print(f"Generated: '{generated_text}'")
            print(f"{'='*80}\n")

        # Save a per-step checkpoint
        checkpoint_step_path = f"checkpoint_step_{step + 1}.pt"
        checkpoint = build_checkpoint(step + 1, loss.item())
        torch.save(checkpoint, checkpoint_step_path)
        print(f"Checkpoint saved to: {checkpoint_step_path}")
        print(f"Loss at step {step + 1}: {loss.item():.4f}")
        print(f"Data loader position: {train_loader.current_position}\n")

        model.train()  # Set back to training mode

# Final checkpoint save
print("\n" + "="*80)
print("Training Complete! Saving checkpoint...")
print("="*80)
print(f"Final loss: {loss.item():.4f}")

# Save checkpoint
checkpoint = build_checkpoint(num_steps, loss.item())

torch.save(checkpoint, checkpoint_path)
print(f"Checkpoint saved to: {checkpoint_path}")
print(f"Data loader position: {train_loader.current_position}")
print("="*80)


Using device: cuda


Token indices sequence length is longer than the specified maximum sequence length for this model (341094 > 8192). Running this sequence through the model will result in indexing errors


Loaded 341094 tokens
1 epoch = 27 batches

Starting Training for 5000 Steps
Batch size: 12, Sequence length: 1024, Steps: 10
Text generation every 500 steps

Step     0 | Loss: 11.6533 | dt: 507.77ms | tok/sec: 24199.90
Step    50 | Loss: 6.1897 | dt: 288.89ms | tok/sec: 42535.20
Step   100 | Loss: 5.5847 | dt: 291.92ms | tok/sec: 42094.35
Step   150 | Loss: 4.5673 | dt: 294.79ms | tok/sec: 41684.04
Step   200 | Loss: 4.4035 | dt: 297.27ms | tok/sec: 41335.92
Step   250 | Loss: 3.9740 | dt: 307.44ms | tok/sec: 39968.55
Step   300 | Loss: 3.9188 | dt: 301.95ms | tok/sec: 40695.80
Step   350 | Loss: 3.4227 | dt: 308.90ms | tok/sec: 39780.19
Step   400 | Loss: 3.4667 | dt: 308.22ms | tok/sec: 39866.99
Step   450 | Loss: 3.4917 | dt: 303.53ms | tok/sec: 40483.87

Step 500 - Model Output:
Prompt: 'The weather is'
Generated: 'The weather is in
To whom we have done, and so much of his face.

KING RICHARD III:
I will be so, and so shall not be a word.

KING RICHARD III:
I will'

Checkpoint sav

In [None]:
"""
CELL 16: Resume Training from Checkpoint
=========================================
This cell demonstrates how to resume training from a saved checkpoint. This is
essential for:
1. Continuing training after interruption
2. Fine-tuning from a specific checkpoint
3. Experimenting with different training schedules

What it does:
1. Loads checkpoint:
   - Model weights: Restores trained parameters
   - Optimizer state: Restores optimizer momentum/state
   - Training step: Knows where we left off
   - Random states: Ensures reproducibility

2. Restores training state:
   - Model weights: Loads trained parameters
   - Optimizer: Restores AdamW state (momentum, etc.)
   - Data loader position: Continues from same data position
   - Random seeds: Ensures same random behavior

3. Continues training:
   - Trains for additional 50 steps (configurable)
   - Saves final checkpoint with updated step count
   - Generates sample text to show progress

Key features:
- Full reproducibility: Random states saved/restored
- Seamless continuation: No loss of training progress
- Flexible: Can resume from any checkpoint
- Safe: Saves final state after continuation

This is crucial for long training runs that may be interrupted!
"""
# Load checkpoint and rebuild model + optimizer so training can continue
print("\n" + "="*80)
print("Loading Checkpoint and Continuing Training")
print("="*80)

# Device selection (if not already set)
if 'device' not in globals():
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    print(f"Using device: {device}")

# Load checkpoint. map_location moves every tensor stored in the file to
# `device` - including the saved RNG states (handled below).
checkpoint_path = "checkpoint_step_5000.pt"
checkpoint = torch.load(checkpoint_path, map_location=device)

print(f"Loaded checkpoint from step {checkpoint['step']}")
print(f"Checkpoint loss: {checkpoint['loss']:.4f}")

# Same aliases the training cell adds to the shared cfg.model object
model_config = cfg.model
model_config.n_embd = model_config.hidden_size
model_config.n_head = model_config.num_attention_heads
model_config.block_size = model_config.max_position_embeddings
model_config.device = device
# NOTE(review): self-assignment; looks like a no-op - confirm before removing.
model_config.num_key_value_heads = model_config.num_key_value_heads

# Build the architecture, then overwrite its random weights with the saved ones
model = LlamaModel(model_config)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)

# Create optimizer and load state
learning_rate = 3e-4  # Use same learning rate as training
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Restore random states LAST: constructing the model above draws random
# numbers for weight initialisation, which would otherwise advance the
# generator past the saved state.
# The RNG setters require a CPU ByteTensor, but map_location may have moved
# the saved states to the GPU, hence the explicit .cpu().
if checkpoint.get('rng_state') is not None:
    torch.set_rng_state(checkpoint['rng_state'].cpu())
    print("Restored PyTorch random state")
if checkpoint.get('cuda_rng_state') is not None and torch.cuda.is_available():
    torch.cuda.set_rng_state(checkpoint['cuda_rng_state'].cpu())
    print("Restored CUDA random state")

print("Model and optimizer loaded successfully!")
print("="*80 + "\n")

# Fallback for a fresh kernel: rebuild the data loader (and, if necessary, the
# tokenizer) when the training cell has not been run in this session.
if 'train_loader' not in globals():
    # Tokenizer comes from the tokenizer section of config.yaml
    if 'tokenizer' not in globals():
        with open("config.yaml") as f:
            config_dict = yaml.safe_load(f)
        tokenizer = AutoTokenizer.from_pretrained(
            config_dict["tokenizer"]["tokenizer_name_or_path"],
            use_fast=config_dict["tokenizer"]["use_fast"],
        )

    # Training hyperparameters
    # NOTE(review): the training cell uses B = 12; this fallback uses 8, so a
    # resumed run on a fresh kernel trains with a different batch size - confirm
    # this is intended.
    B = 8  # batch size
    T = 1024  # sequence length

    class DataLoaderLite:
        """Sequential loader: tokenizes input.txt once, serves (B, T) windows."""

        def __init__(self, B, T, tokenizer):
            self.B = B
            self.T = T
            self.tokenizer = tokenizer
            with open('input.txt') as f:
                text = f.read()
            self.tokens = torch.tensor(
                tokenizer.encode(text, add_special_tokens=False)
            )
            # Index of the first token of the next batch
            self.current_position = 0

        def next_batch(self):
            """Return inputs and one-token-shifted targets, then advance."""
            start = self.current_position
            span = self.B * self.T
            # span inputs plus one extra token for the shifted targets
            window = self.tokens[start: start + span + 1]
            x = window[:-1].view(self.B, self.T)
            y = window[1:].view(self.B, self.T)
            self.current_position = start + span
            # Wrap around when the next full window would not fit
            if self.current_position + (span + 1) > len(self.tokens):
                self.current_position = 0
            return x, y

    train_loader = DataLoaderLite(B, T, tokenizer)

# Restore data loader position from checkpoint for reproducibility
if 'data_loader_position' in checkpoint:
    train_loader.current_position = checkpoint['data_loader_position']
    print(f"Restored data loader position: {train_loader.current_position}")
else:
    print("Warning: No data_loader_position in checkpoint, starting from position 0")

# Continue training for a fixed number of extra steps beyond the checkpoint.
num_additional_steps = 50
# Derive the output filename from the step that will actually be reached, so
# the name always agrees with the 'step' value stored in the saved checkpoint
# (a checkpoint at step 5000 plus 50 steps still yields
# "checkpoint_step_5050.pt").
final_checkpoint_path = f"checkpoint_step_{checkpoint['step'] + num_additional_steps}.pt"

# Banner announcing the step range about to be trained.
print("="*80)
print(f"Continuing Training for {num_additional_steps} Additional Steps")
print("="*80)
print(f"Starting from step {checkpoint['step']}, training to step {checkpoint['step'] + num_additional_steps}")
print("="*80 + "\n")

# Resume optimization: put the model in training mode and run
# num_additional_steps optimizer updates, printing loss and throughput for
# each one. The `loss` from the last iteration is read again after the loop
# when the final checkpoint is written.
model.train()
for step in range(num_additional_steps):
    # The timed region covers batch fetch, host-to-device copy, forward,
    # backward and the optimizer update.
    t0 = time.time()
    
    # Get batch
    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)
    
    # Forward pass
    # Gradients are cleared first; the model is called with inputs and
    # targets and returns a (logits, loss) pair. logits is not used here.
    optimizer.zero_grad()
    logits, loss = model(x, y)
    
    # Backward pass
    loss.backward()
    optimizer.step()
    
    # Synchronize for accurate timing
    # Waits for outstanding CUDA work so that t1 is taken after the step has
    # actually finished on the GPU.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    
    t1 = time.time()
    dt = (t1 - t0) * 1000  # milliseconds
    # Tokens processed per step is batch size times sequence length.
    tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
    
    # Print metrics
    # Step numbers continue from the checkpoint's step (1-based within this run).
    current_step = checkpoint['step'] + step + 1
    print(f'Step {current_step:5d} | Loss: {loss.item():.4f} | dt: {dt:6.2f}ms | tok/sec: {tokens_per_sec:8.2f}')

# Generate final text sample
# Qualitative check after the extra training steps: generate a continuation
# of a fixed prompt and print it.
print("\n" + "="*80)
print("Final Text Generation After Additional Training")
print("="*80)
# Switch to evaluation mode and disable gradient tracking for generation.
model.eval()
with torch.no_grad():
    generation_prompt = "The weather is"
    # Encode the prompt as a PyTorch tensor and move it to the model's device.
    prompt_ids = tokenizer.encode(generation_prompt, return_tensors='pt').to(device)
    # NOTE(review): generate() is the model class's own method (defined
    # outside this view); presumably it returns the prompt followed by up to
    # 50 new tokens -- confirm against its definition.
    generated_ids = model.generate(prompt_ids, max_new_tokens=50)
    # Decode the first (only) sequence in the returned batch.
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    
    print(f"Prompt: '{generation_prompt}'")
    print(f"Generated: '{generated_text}'")
    print("="*80)

# Save final checkpoint
print("\n" + "="*80, "Saving Final Checkpoint...", "="*80, sep="\n")

# Architecture hyperparameters copied off the live config object, in the key
# order used by the checkpoint format, with the device string appended last.
final_model_config = {
    field_name: getattr(model_config, field_name)
    for field_name in (
        'n_embd',
        'n_head',
        'block_size',
        'num_key_value_heads',
        'hidden_size',
        'vocab_size',
        'num_hidden_layers',
        'intermediate_size',
        'rms_norm_eps',
        'hidden_act',
        'mlp_bias',
    )
}
final_model_config['device'] = str(device)

# Assemble the checkpoint entry by entry: training progress, weights,
# optimizer state, last loss and the model configuration.
final_checkpoint = {}
final_checkpoint['step'] = checkpoint['step'] + num_additional_steps
final_checkpoint['model_state_dict'] = model.state_dict()
final_checkpoint['optimizer_state_dict'] = optimizer.state_dict()
final_checkpoint['loss'] = loss.item()
final_checkpoint['model_config'] = final_model_config
# Data loader offset, so a later resume can continue from the same place.
final_checkpoint['data_loader_position'] = train_loader.current_position
# Random generator states; the CUDA entry is None when CUDA is unavailable.
final_checkpoint['rng_state'] = torch.get_rng_state()
if torch.cuda.is_available():
    final_checkpoint['cuda_rng_state'] = torch.cuda.get_rng_state()
else:
    final_checkpoint['cuda_rng_state'] = None

torch.save(final_checkpoint, final_checkpoint_path)
print(f"Final checkpoint saved to: {final_checkpoint_path}")
print(f"Final loss: {loss.item():.4f}")
print("="*80)


Loading Checkpoint and Continuing Training
Loaded checkpoint from step 5000
Checkpoint loss: 0.0015
Model and optimizer loaded successfully!

Continuing Training for 50 Additional Steps
Starting from step 5000, training to step 5050

Step  5001 | Loss: 0.8500 | dt: 461.76ms | tok/sec: 26611.20
Step  5002 | Loss: 0.7351 | dt: 462.43ms | tok/sec: 26572.48
Step  5003 | Loss: 0.7321 | dt: 463.81ms | tok/sec: 26493.70
Step  5004 | Loss: 0.7072 | dt: 462.90ms | tok/sec: 26545.68
Step  5005 | Loss: 0.7477 | dt: 460.00ms | tok/sec: 26712.94
Step  5006 | Loss: 0.7876 | dt: 462.35ms | tok/sec: 26577.11
Step  5007 | Loss: 0.8368 | dt: 463.83ms | tok/sec: 26492.30
Step  5008 | Loss: 0.8565 | dt: 464.64ms | tok/sec: 26446.49
Step  5009 | Loss: 0.8898 | dt: 469.10ms | tok/sec: 26194.60
Step  5010 | Loss: 0.9158 | dt: 464.79ms | tok/sec: 26437.81
Step  5011 | Loss: 0.8765 | dt: 468.53ms | tok/sec: 26226.46
Step  5012 | Loss: 0.9574 | dt: 466.98ms | tok/sec: 26314.03
Step  5013 | Loss: 0.9932 | dt: 4