In [None]:
# Small LLM / Notebook created by Javier Ideami (ideami.com)
# Typical LLMs need many GPUs and millions of dollars to be trained
# This code trains a small LLM with a single GPU and little GPU memory 
# Of course results are not like a chatGPT, but they are good enough to see how the LLM trains to go
# from random combinations of letters to actual words and phrases that are sometimes decently coherent
# GPT-3 has 175 billion parameters. GPT-4's size has not been disclosed, but it is widely believed to be much larger.
# This model has only 19 million parameters with its default settings, which is why it's ideal for learning
# and experimenting

# Official notebook #vj30

In [None]:
#### For GOOGLE COLAB and similar platform Users:
#### Make sure to select a GPU in the online platform. Don't run this code with a CPU (it will be too slow)

# If you are running this code locally, your GPU should be selected automatically

In [None]:
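# Uncomment and run the following installation lines ONLY if you haven't already installed these libraries outside of the notebook
# (in Jupyter, `%pip install` is preferable to `!pip install` because it installs into the kernel's own environment)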
#!pip install ipdb -q
#!pip install tqdm -q
#!pip install sentencepiece -q
#!pip install wandb -q

# And if you are not in Google Colab and you didn't yet install Pytorch, make sure to do it:
# find the ideal pytorch installation command at https://pytorch.org/get-started/locally/

In [None]:
# You can use this command to view information about your GPU and the amount of free memory it has
# Make sure that you have at least 4GB of free GPU memory to do this course
# (`!` is IPython syntax: the rest of the line is run as a shell command, not as Python)
!nvidia-smi 
# If you are using Google Colab or a similar online platform, make sure to select a GPU in the menus
# In Google colab, at the moment the option is within the Runtime menus

In [None]:
### Import necessary libraries

# Standard Python libraries for file operations and system utilities
import os, sys  # Operating system interface and system-specific parameters
import ipdb  # Interactive Python debugger - allows us to pause execution and inspect variables
from tqdm import tqdm  # Progress bar library - shows training progress with visual bars
from datetime import datetime  # Date and time utilities for logging timestamps
import platform, shutil  # Platform detection and file operations

# HTTP and file handling libraries
import requests  # For downloading datasets and model files from URLs
import zipfile, io  # For extracting compressed files and handling binary data

# PyTorch - The main deep learning framework we'll use
import torch  # Core PyTorch library for tensor operations and GPU acceleration
import torch.nn as nn  # Neural network modules (layers, activations, etc.)
from torch.nn import functional as F  # Functional interface for neural network operations

# Tokenization library - converts text into numbers that neural networks can understand
import sentencepiece as spm  # Google's SentencePiece tokenizer for subword tokenization

# Allow TensorFloat-32 (TF32) math on Ampere-or-newer NVIDIA GPUs (A100, RTX 30/40 series).
# TF32 keeps FP32's numeric range but uses a shorter mantissa inside tensor-core
# operations: noticeably faster, slightly less precise.
# A single chained assignment turns it on for cuBLAS matmuls first, then cuDNN ops.
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True

# Hand cached-but-unused blocks held by PyTorch's CUDA allocator back to the driver.
# (Memory still referenced by live tensors is not freed by this call.)
torch.cuda.empty_cache()


In [None]:
# Download necessary files and create necessary folders
# wiki.txt - dataset: a tiny segment of the English Wikipedia
# wiki_tokenizer.model: trained tokenizer file (in another notebook I show you how to produce this file)
# wiki_tokenizer.vocab: trained tokenizer file (in another notebook I show you how to produce this file)
# encoded_data.pt (dataset tokenized with the tokenizer)
# I will explain how to produce encoded_data.pt - because it takes quite a bit to process, it's nice to have it in advance

# NOTE: Downloading will take a while, be patient. You can refresh your folder from time to time to see when the files
# have been created. If you have any problems downloading the files with this code, I have also added llm_train.zip
# to the downloadable resources of this lecture (however, best option is to use this code, because then you don't need
# to upload the files or do anything else)

files_url = "https://ideami.com/llm_train"

# Files that later cells read directly from the working directory.
# Checking all of them (not only encoded_data.pt) means a partially
# extracted or partially deleted download is detected and fetched again.
required_files = ["wiki.txt", "wiki_tokenizer.model", "encoded_data.pt"]
missing_files = [fn for fn in required_files if not os.path.exists(fn)]

# Downloading proceeds if any of the key files is not present
if missing_files:
    print("Downloading files using Python")
    # timeout: a stalled connection raises instead of hanging the cell forever.
    # raise_for_status: an HTTP error page becomes a clear requests exception
    # instead of a confusing "File is not a zip file" error from zipfile.
    response = requests.get(files_url, timeout=60)
    response.raise_for_status()
    zipfile.ZipFile(io.BytesIO(response.content)).extractall(".")
else:
    print("you seem to have already downloaded the files. If you wish to re-download them, delete the encoded_data.pt file")



In [None]:
# =============================================================================
# CONFIGURATION PARAMETERS - The Heart of Our LLM Training Setup
# =============================================================================
# This section defines all the key parameters that control our model architecture,
# training process, and performance. Understanding these parameters is crucial for
# building and training effective language models.

# =============================================================================
# ARCHITECTURE PARAMETERS - Model Structure and Size
# =============================================================================
# These parameters define the physical structure of our transformer model

batch_size = 8  # Number of training sequences processed together in one step
                # Higher batch sizes = faster training but more GPU memory required
                # Roughly: 8 fits a 4GB GPU, 128 fits a 24GB GPU
                # Batch size affects gradient noise/stability and training speed

context = 512    # Maximum sequence length the model can process (context window)
                 # Longer contexts = more history per prediction but more memory/computation
                 # (standard self-attention cost grows quadratically with context length)
                 # 512 is a good balance between performance and resource usage
                 # For scale: GPT-3 used 2048 tokens; GPT-4 launched with 8K and 32K variants

embed_size = 384  # Dimension of token embeddings (how "rich" each token representation is)
                  # Higher dimensions = more expressive but more parameters
                  # 384 is a good middle ground for our ~19M parameter model
                  # GPT-3 (175B) uses 12288 dimensions

n_layers = 7      # Number of transformer blocks (depth of the network)
                  # More layers = more capacity but more parameters and compute
                  # 7 layers is sufficient for our small model
                  # GPT-3 (175B) has 96 layers; GPT-4's architecture has not been published

n_heads = 7       # Number of attention heads per layer
                  # More heads = more parallel attention patterns
                  # NOTE: 384 is NOT divisible by 7 (384 / 7 ≈ 54.86). With the usual
                  # head_size = embed_size // n_heads, each head gets 54 dims and the
                  # heads together cover 7 * 54 = 378 < 384 dims. That only works if the
                  # attention code projects n_heads * head_size back to embed_size -
                  # NOTE(review): confirm this in the attention/Block implementation.
                  # A divisor of 384 (e.g. 6 or 8) avoids the remainder, but a checkpoint
                  # trained with n_heads = 7 would likely no longer load.
                  # GPT-3 (175B) uses 96 heads (12288 / 96 = 128 dims per head)

BIAS = True       # Whether to include bias terms in linear layers
                  # Bias terms let a layer learn a constant offset
                  # Some modern LLMs (e.g. LLaMA) omit them; keeping them is a simple default

# =============================================================================
# HYPERPARAMETERS - Training Behavior and Optimization
# =============================================================================
# These parameters control how the model learns and optimizes

lr = 3e-4        # Learning rate - how big a step the optimizer takes
                 # Too high = unstable training, too low = slow learning
                 # 3e-4 is a common starting point for small transformer models
                 # For reference: GPT-3 Small (125M) used 6e-4, GPT-3 175B used 0.6e-4

dropout = 0.05   # Fraction of activations randomly set to zero during training
                 # Regularizes the model (reduces overfitting) by injecting noise
                 # 0.05 = 5% dropout rate (relatively low for transformers)
                 # GPT-1/GPT-2 style models commonly use 0.1

weight_decay = 0.01  # Weight decay strength: pulls weights toward zero during training
                     # Penalizes large weights to help the model generalize
                     # If the optimizer is AdamW, this is *decoupled* weight decay,
                     # which is not the same as adding an L2 penalty to the loss
                     # For reference, GPT-3 used 0.1

grad_clip = 1.0      # Maximum gradient norm before clipping
                     # Prevents exploding gradients from destabilizing training
                     # 1.0 is a standard value (GPT-3 also clipped the global norm at 1.0)
                     # Especially helpful for stable training of deep networks

# =============================================================================
# TRAINING PARAMETERS - Training Process Control
# =============================================================================
# Knobs for the training loop and its periodic evaluation.

# Upper bound on training iterations; each iteration consumes one batch of
# `batch_size` sequences. (For scale: GPT-3 was trained on ~300B tokens.)
train_iters = 100_000

# Run an evaluation every `eval_interval` iterations, using `eval_iters`
# batches each time. Evaluating more often or on more batches gives smoother
# curves but slows training down.
eval_interval = 50
eval_iters = 3

# Toggle for PyTorch 2.0+ model compilation (torch.compile). It can speed up
# training, but is not supported everywhere - keep False if it causes errors.
# NOTE: this global shadows Python's built-in compile() for the rest of the
# notebook; the name is kept because later cells presumably read it.
compile = False

# True = resume from (or fine-tune) a saved checkpoint instead of starting
# from freshly initialized weights.
load_pretrained = False

# =============================================================================
# CHECKPOINTING PARAMETERS - Model Saving and Loading
# =============================================================================
# Folder (relative to the current directory) that holds saved checkpoints:
# model weights, optimizer state, etc.
checkpoint_dir = 'models/'

# File that is overwritten with the most recent training state.
checkpoint_fn = "latest.pt"

# File read when loading; point it at another checkpoint (e.g. "llm2.pt")
# to start from a specific saved or pre-trained model.
checkpoint_load_fn = "latest.pt"

# =============================================================================
# PRECISION, MODE AND DEVICE
# =============================================================================
# Numeric type for model computations. bfloat16 is a 16-bit float with FP32's
# exponent range: half the memory of torch.float32 and faster on GPUs with
# native support (Ampere and newer, e.g. A100, RTX 30/40 series).
# NOTE(review): older GPUs may lack fast bfloat16; torch.cuda.is_bf16_supported()
# can be used to check before relying on it.
dtype = torch.bfloat16

# True = skip training and only generate with an existing model.
inference = False

# Prefer the GPU (CUDA) when one is visible; otherwise fall back to the CPU,
# which works but is far too slow for real training.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("device: You will be using: ", device)


In [None]:
# =============================================================================
# LOGGING AND MONITORING SETUP - Track Your Training Progress
# =============================================================================
# This section sets up Weights & Biases (wandb) for monitoring training progress,
# visualizing metrics, and comparing different experiments. Wandb is a powerful
# tool for machine learning experiment tracking and visualization.

# =============================================================================
# WANDB CONFIGURATION - Experiment Tracking Setup
# =============================================================================
# Weights & Biases (wandb) is a popular platform for ML experiment tracking
# It provides real-time monitoring, visualization, and collaboration features

wandb_log = True  # Enable/disable wandb logging
                  # Set to False if you don't want to use wandb
                  # Recommended to keep True for better experiment tracking

wandb_project = "test"  # Project name in wandb dashboard
                        # Groups related experiments together
                        # Change to something descriptive like "llm-training"

# Unique run name: a fixed prefix, an "_" separator, then a timestamp,
# e.g. "test-run_2024_01_15_14_30_25".
# Helps distinguish between different training sessions.
wandb_run_name = "test-run_" + datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

# =============================================================================
# WANDB INITIALIZATION - Start Experiment Tracking
# =============================================================================
# wandb is imported only inside this branch, so the notebook does not need the
# package installed when wandb_log is False.
if wandb_log:
    import wandb  # experiment-tracking client
    # Open a new run named `wandb_run_name` inside `wandb_project`;
    # it appears in your wandb dashboard. The first use asks for an API key.
    wandb.init(name=wandb_run_name, project=wandb_project)

# =============================================================================
# WANDB API KEY SETUP - Authentication Required
# =============================================================================
# On the first run with wandb_log=True, wandb asks for an API key.
# Get one from: https://wandb.ai/settings#api
#
# Quicker alternative: click the authorization link printed by this cell;
# it opens https://wandb.ai/authorize
#
# What wandb can track during training:
# - Training and validation loss curves
# - Learning rate schedules
# - Model parameters and hyperparameters
# - System metrics (GPU usage, memory)
# - Generated text samples
# - Training time and iteration counts


In [None]:
# =============================================================================
# DATASET LOADING - Load and Inspect Our Training Data
# =============================================================================
# This section loads the Wikipedia dataset that will be used to train our language model.
# The dataset contains a small portion of English Wikipedia text, which provides
# diverse, high-quality text for training our model to understand language patterns.

# =============================================================================
# LOAD WIKIPEDIA DATASET - Read the Training Text
# =============================================================================
with open('wiki.txt', 'r', encoding='utf-8') as f:
    text = f.read()  # Load the entire Wikipedia text into memory as one string
                     # This is raw, untokenized text: a later cell converts it to
                     # token IDs (only needed when encoded_data.pt is missing)

# =============================================================================
# DATASET INSPECTION - Examine the Training Data
# =============================================================================
print(text[10000:10500])  # Display a sample of the dataset
                          # Shows characters 10000-10499 of the text
                          # This gives us a preview of what the model will learn from
                          # You should see English text from Wikipedia
                          # The text spans various topics, writing styles, and vocabulary

In [None]:
# =============================================================================
# TOKENIZER SETUP - Convert Text to Numbers
# =============================================================================
# This section sets up the SentencePiece tokenizer, which converts text into
# numerical tokens that our neural network can understand. Tokenization is a
# crucial step in natural language processing that bridges human language
# and machine learning.

# =============================================================================
# LOAD SENTENCEPIECE TOKENIZER - Initialize the Tokenizer
# =============================================================================
# SentencePiece is Google's subword tokenization library.
# It breaks text into subword units (not just whole words or single characters),
# which lets the model handle unseen words with a bounded vocabulary.

sp = spm.SentencePieceProcessor(model_file='wiki_tokenizer.model')
# Load the pre-trained tokenizer model downloaded together with the dataset.
# NOTE(review): presumably trained on this same Wikipedia text - the training
# notebook for the tokenizer is not part of this file.

# =============================================================================
# TOKENIZER VOCABULARY - Get the Size of Our Token Vocabulary
# =============================================================================
vocab_size = sp.get_piece_size()  # Total number of distinct tokens
print(f"Tokenizer vocab_size: {vocab_size}")
# This also sizes the model's token-embedding table (one row per token).
# Typical LLM vocabularies range from ~8K to 100K+ tokens.

# =============================================================================
# ENCODING AND DECODING FUNCTIONS - Text <-> Numbers Conversion
# =============================================================================
# Defined with `def` (not lambdas bound to names, per PEP 8) so they get real
# names in tracebacks and can carry docstrings.

def encode(s):
    """Convert a text string into a list of SentencePiece token IDs.

    Example (IDs are illustrative only): "Hello world" -> [154, 32, 2789]
    """
    return sp.Encode(s)


def decode(l):
    """Convert a list of token IDs back into a text string.

    Example (IDs are illustrative only): [154, 32, 2789] -> "Hello world"
    """
    return sp.Decode(l)

# =============================================================================
# TOKENIZER TESTING - Verify It Works Correctly
# =============================================================================
print(decode(encode("Encoding Decoding functions ready")))
# Round trip text -> tokens -> text; should print "Encoding Decoding functions ready".
# Anything else points to a problem with the tokenizer model file.

In [None]:
# =============================================================================
# DATASET TOKENIZATION - Convert Text to Training Data
# =============================================================================
# This section converts our raw text into numerical tokens that can be used
# for training. Tokenization is computationally expensive, so we save the
# result to avoid re-processing the same text multiple times.

# =============================================================================
# CHECK FOR PRE-TOKENIZED DATA - Avoid Re-processing
# =============================================================================
encoded_data_path = 'encoded_data.pt'  # cache file used by both the load and save branches

if os.path.exists(encoded_data_path):
    # Load pre-tokenized data if it already exists
    print("Loading saved encoded data")
    # weights_only=True restricts unpickling to tensors and plain containers.
    # This cache may have been downloaded from the internet, and an unrestricted
    # pickle load can execute arbitrary code. (It is also the default from
    # PyTorch 2.6 onward.) Much faster than re-tokenizing the whole dataset.
    data = torch.load(encoded_data_path, weights_only=True)
else:
    # =============================================================================
    # TOKENIZE THE DATASET - Convert Text to Numbers
    # =============================================================================
    print("Encoding data")
    # encode(text) returns a list of token IDs for the entire text;
    # torch.long gives the integer tensor that embedding lookups require.
    data = torch.tensor(encode(text), dtype=torch.long)

    # Cache the result so future runs can take the fast branch above.
    torch.save(data, encoded_data_path)


In [None]:
# =============================================================================
# DATASET SPLITTING - Create Training and Validation Sets
# =============================================================================
# This section splits our tokenized dataset into training and validation sets.
# This is essential for proper machine learning - we need separate data to
# evaluate our model's performance and prevent overfitting.

# =============================================================================
# DATASET SIZE AND SPLIT - Training vs. Validation
# =============================================================================
# Total number of tokens available for training and evaluation
data_size = len(data)

# Contiguous 90/10 split: the first 90% of the token stream is used for
# training and the last 10% is held out for validation. The split is not
# shuffled, so validation text comes from the end of the corpus.
spl = int(0.9 * data_size)  # index of the first validation token
train_data, val_data = data[:spl], data[spl:]
# get_batch("train") samples only from train_data, so val_data measures how
# well the model generalizes to text it was not trained on (overfitting check).

# =============================================================================
# DATASET STATISTICS - Display Data Split Information
# =============================================================================
# Sizes are reported in millions of tokens
print(f'Total data: {data_size/1e6:.2f} Million | Training: {len(train_data)/1e6:.2f} Million | Validation: {len(val_data)/1e6:.2f} Million')

# =============================================================================
# DATA PREVIEW - Examine the Tokenized Data
# =============================================================================
# Evaluate data[:30] in a cell to see the first 30 token IDs.
# Each ID is one token (word or subword unit); the model learns to predict
# the next token given the previous ones.

In [None]:
# =============================================================================
# HELPER FUNCTIONS - Essential Utilities for Training
# =============================================================================
# This section contains utility functions that are crucial for the training process.
# These functions handle data batching, which is essential for efficient training
# of neural networks.

# =============================================================================
# GET_BATCH FUNCTION - Create Training Batches
# =============================================================================
# This function creates batches of training data for our model.
# It randomly samples sequences from the dataset and prepares them for training.
# Batching is essential for efficient GPU utilization and stable training.

def get_batch(split):
    """
    Sample a random batch of (input, target) token windows for next-token prediction.

    Args:
        split (str): "train" selects train_data; any other value (e.g. "eval")
            selects val_data.

    Returns:
        x (torch.Tensor): Input windows of shape (batch_size, context), on `device`.
        y (torch.Tensor): Targets of shape (batch_size, context) - the same windows
            shifted one token to the right, so y[b, t] is the token after x[b, t].

    Raises:
        ValueError: if the selected split has no more than `context` tokens, so a
            full (context + 1)-token window cannot be drawn from it.
    """
    data = train_data if split == "train" else val_data

    # Each sample needs context + 1 consecutive tokens (the inputs plus the final
    # target). Without this check torch.randint fails with an opaque message.
    if len(data) <= context:
        raise ValueError(
            f"get_batch('{split}'): split has {len(data)} tokens, "
            f"but at least {context + 1} are needed for context={context}"
        )

    # Random start offsets. randint's upper bound is exclusive, so the largest
    # start is len(data) - context - 1 and the target window still ends in bounds.
    inds = torch.randint(len(data) - context, (batch_size,))  # (BS)

    # Gather every window with a single advanced-indexing lookup instead of a
    # Python loop over the batch: row b of `positions` is inds[b], inds[b]+1, ...
    positions = inds.unsqueeze(1) + torch.arange(context)  # (BS, SL)
    x = data[positions]      # (BS, SL) inputs
    y = data[positions + 1]  # (BS, SL) targets: same windows shifted by one token

    # Example - first 10 elements of the first input and label rows:
    # x[0][:10] -> tensor([ 664,  278, 4031, 4056, 4065, 4062, 4062, 4051,   13,   13])
    # y[0][:10] -> tensor([ 278, 4031, 4056, 4065, 4062, 4062, 4051,   13,   13, 4066])

    # Move to the device selected in the configuration cell (GPU or CPU).
    x, y = x.to(device), y.to(device)
    return x, y



In [None]:
# =============================================================================
# BATCH FUNCTION TESTING - Verify Data Batching Works Correctly
# =============================================================================
# This section provides a way to test the get_batch function to ensure
# it's working properly. Uncomment the lines below to see the batch structure
# and verify that the data shapes and content are correct.

# =============================================================================
# TEST BATCH CREATION - Examine Batch Structure
# =============================================================================
# Uncomment the following lines to test your get_batch function:
# x, y = get_batch("train")
# print(f"x.shape: {x.shape}")  # Should be (batch_size, context_length)
# print(f"y.shape: {y.shape}")  # Should be (batch_size, context_length)
# print(x[0][:10])  # First 10 tokens of the first sequence in the batch
# print(y[0][:10])  # First 10 target tokens (should be x shifted by 1)

# =============================================================================
# WHAT TO EXPECT - Understanding the Output
# =============================================================================
# x.shape: (8, 512) - 8 sequences, each 512 tokens long
# y.shape: (8, 512) - Same shape as x
# x[0][:10]: [664, 278, 4031, 4056, 4065, 4062, 4062, 4051, 13, 13]
# y[0][:10]: [278, 4031, 4056, 4065, 4062, 4062, 4051, 13, 13, 4066]
# Notice how y is x shifted by 1 position - this is the next token prediction task

In [None]:
# =============================================================================
# GPT MODEL ARCHITECTURE - The Heart of Our Language Model
# =============================================================================
# This section defines our GPT (Generative Pre-trained Transformer) model.
# It's a decoder-only transformer architecture that learns to predict the next
# token in a sequence, which is the foundation of modern language models.

# =============================================================================
# MODEL SPECIFICATIONS - Understanding Our Architecture
# =============================================================================
# 19 million parameters with the default configuration
# Can be trained with 1 single GPU
# With 8 Batch Size, should require 4 GB of GPU Memory
# With 128 Batch Size, should require 24 GB of GPU Memory
# Adjust Batch Size as needed for less or more memory and training speed
# Because of small dataset and model, results will be limited but enough to
# demonstrate good improvement during the training and understand all the
# main technology involved in building LLMs

class GPT(nn.Module):
    """
    GPT (Generative Pre-trained Transformer) Model

    A decoder-only transformer that learns to predict the next token in a
    sequence. It consists of:
    - Token embeddings
    - Position embeddings
    - A stack of `n_layers` transformer blocks
    - A final layer normalization
    - A final linear layer producing logits over the vocabulary

    Hyperparameters (`vocab_size`, `embed_size`, `context`, `n_heads`,
    `n_layers`, `BIAS`) are read from notebook-level globals defined in an
    earlier configuration cell.
    """

    def __init__(self):
        super().__init__()

        # NOTE: submodules are created in a fixed order on purpose. nn.Linear /
        # nn.Embedding consume the random number generator when constructed, and
        # self.apply() below visits modules in registration order, so reordering
        # these lines would change the initial weights for a given seed.

        # =============================================================================
        # EMBEDDING LAYERS - Convert Tokens to Rich Representations
        # =============================================================================
        # Token embeddings: lookup table of shape (vocab_size, embed_size).
        # Each token ID maps to its own learned dense vector.
        self.embeddings = nn.Embedding(vocab_size, embed_size)

        # Position embeddings: one learned vector per position, shape (context, embed_size).
        # This lets the model distinguish word order. It is also why the model
        # cannot process sequences longer than `context` tokens.
        self.positions = nn.Embedding(context, embed_size)

        # =============================================================================
        # TRANSFORMER BLOCKS - The Core Processing Units
        # =============================================================================
        # n_layers blocks applied one after another. Each block contains
        # multi-head attention and a feed-forward network.
        self.blocks = nn.Sequential(*[Block(n_heads) for _ in range(n_layers)])

        # =============================================================================
        # FINAL LAYERS - Output Processing
        # =============================================================================
        self.ln = nn.LayerNorm(embed_size)  # Final layer normalization

        # Maps each position's embed_size vector to vocab_size logits.
        # Output shape: (batch_size, sequence_length, vocab_size)
        self.final_linear = nn.Linear(embed_size, vocab_size, bias=BIAS)

        # =============================================================================
        # WEIGHT INITIALIZATION - Start with Good Weights
        # =============================================================================
        # Recursively visits every submodule and applies _init_weights to it.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """
        Initialize the weights of a single module.

        Linear and Embedding weights are drawn from N(0, 0.02^2). Linear biases
        (when present) are set to zero. Other module types (e.g. LayerNorm)
        keep PyTorch's default initialization.

        Args:
            module (nn.Module): The module being visited by `self.apply`.
        """
        if isinstance(module, nn.Linear):
            # Small standard deviation keeps initial activations from being too large.
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                # Zero bias is a neutral starting point.
                torch.nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            # Same distribution as the linear layers, for consistency.
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input, targets=None):
        """
        Forward pass through the GPT model.

        Args:
            input (torch.Tensor): Token IDs of shape (batch_size, sequence_length).
                sequence_length must not exceed `context`.
            targets (torch.Tensor, optional): Target token IDs with the same number
                of elements as `input`, used to compute the training loss.

        Returns:
            logits (torch.Tensor): Shape (batch_size, sequence_length, vocab_size)
                when `targets` is None. Flattened to
                (batch_size * sequence_length, vocab_size) when `targets` is given.
            loss (torch.Tensor or None): Mean cross-entropy loss if `targets` is
                given, otherwise None.

        Raises:
            ValueError: If sequence_length is larger than `context`. Without this
                check the position lookup would index past the end of
                `self.positions`. On a GPU that surfaces as an opaque device-side
                assert rather than a readable error.
        """

        # =============================================================================
        # INPUT PROCESSING - Convert Tokens to Embeddings
        # =============================================================================
        BS, SL = input.shape  # Batch Size, Sequence Length, e.g. (8, 512)
        if SL > context:
            raise ValueError(
                f"Sequence length {SL} exceeds the model's context length {context}; "
                "crop the input (e.g. input[:, -context:]) before calling the model."
            )

        emb = self.embeddings(input)  # (BS, SL, embed_size)

        # Position indices are created on the same device as the input, so the
        # forward pass does not depend on the notebook-level `device` variable.
        pos = self.positions(torch.arange(SL, device=input.device))  # (SL, embed_size)

        # Broadcasting adds the same position vectors to every sequence in the batch,
        # giving each token both "what" (token) and "where" (position) information.
        x = emb + pos  # (BS, SL, embed_size)

        # =============================================================================
        # TRANSFORMER PROCESSING - Apply Multiple Blocks
        # =============================================================================
        x = self.blocks(x)  # (BS, SL, embed_size)
        x = self.ln(x)      # (BS, SL, embed_size)

        # =============================================================================
        # OUTPUT GENERATION - Predict Next Tokens
        # =============================================================================
        logits = self.final_linear(x)  # (BS, SL, vocab_size)

        # =============================================================================
        # LOSS CALCULATION - Cross-Entropy Loss for Training
        # =============================================================================
        loss = None
        if targets is not None:
            # F.cross_entropy expects (N, classes) logits and (N,) class indices,
            # so batch and sequence dimensions are merged.
            BS, SL, VS = logits.shape
            logits = logits.view(BS * SL, VS)  # (BS*SL, vocab_size)
            # reshape (not view): also works if `targets` is a non-contiguous tensor.
            targets = targets.reshape(BS * SL)  # (BS*SL,)

            loss = F.cross_entropy(logits, targets)  # lower loss = better predictions

            # =============================================================================
            # MANUAL LOSS CALCULATION (Optional - for educational purposes)
            # =============================================================================
            # Uncomment the following lines to see manual cross-entropy calculation:
            # counts = logits.exp()  # (BS*SL, vocab_size)
            # prob = counts / counts.sum(-1, keepdim=True)  # (BS*SL, vocab_size)
            # loss2 = -prob[torch.arange(BS*SL), targets].log().mean()
            #
            # This shows the mathematical steps behind cross-entropy:
            # 1. Apply softmax to get probabilities
            # 2. Select probabilities for target tokens
            # 3. Take negative log (higher probability = lower loss)
            # 4. Take mean across all predictions
            # (F.cross_entropy computes the same value in a numerically safer way.)

        return logits, loss

    @torch.no_grad()
    def generate(self, input, max=500):
        """
        Generate new tokens autoregressively by sampling from the model.

        Runs under torch.no_grad(). The sampled token IDs carry no gradient, so
        building an autograd graph at every step would only waste memory.
        Dropout layers stay active unless the caller has put the model in eval
        mode (`model.eval()`).

        Args:
            input (torch.Tensor): Starting token IDs of shape (batch_size, seq_len).
            max (int): Number of tokens to generate. The name shadows the builtin
                `max` and is kept for backward compatibility with existing callers.

        Returns:
            torch.Tensor: Token IDs of shape (batch_size, L + 1) when max > 0, where
            L = min(length before the last step, context). Only the last `context`
            tokens are fed back each step, so the returned sequence is at most
            context + 1 tokens long. When max == 0, `input` is returned unchanged.
        """
        for _ in range(max):
            # Sliding window: keep only the last `context` tokens.
            input = input[:, -context:]  # (B, <=context)

            logits, _ = self(input)  # (B, T, vocab_size)
            # Only the last position predicts the next token.
            logits = logits[:, -1, :]  # (B, vocab_size)

            probs = F.softmax(logits, dim=-1)  # (B, vocab_size)
            # Multinomial sampling (not greedy argmax) adds variety to the output.
            next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)

            input = torch.cat((input, next_token), dim=1)  # sequence grows by one

        return input

In [None]:
# =============================================================================
# TRANSFORMER BLOCK CLASS - The Building Blocks of Our Model
# =============================================================================
# This section defines the Transformer Block, which is the core component
# of our GPT model. Each block combines attention mechanisms with feed-forward
# networks to process and understand relationships in the input sequence.

class Block(nn.Module):
    """
    One pre-norm transformer block: attention followed by a feed-forward network.

    Each sub-layer sees a LayerNorm-ed copy of the residual stream, and its
    output is added back onto that stream (residual connection):

        x = x + Multihead(LayerNorm1(x))      # tokens exchange information
        x = x + ForwardLayer(LayerNorm2(x))   # each token is transformed on its own

    Reads `embed_size` from the notebook-level configuration.

    Args:
        n_heads (int): Number of attention heads in the Multihead sub-layer.
    """

    def __init__(self, n_heads):
        super().__init__()

        # Each head gets embed_size // n_heads features (floor division).
        # Example: 384 // 7 = 54, so 7 heads produce 378 features. Multihead's
        # output projection maps them back to embed_size.
        per_head = embed_size // n_heads

        # Registration order (attention, feed-forward, norms) is kept fixed because
        # it determines parameter order and the RNG sequence used at init time.
        self.ma = Multihead(n_heads, per_head)           # communication between tokens
        self.feed_forward = ForwardLayer(embed_size)     # per-token computation
        self.ln1 = nn.LayerNorm(embed_size)              # normalizes input to attention
        self.ln2 = nn.LayerNorm(embed_size)              # normalizes input to feed-forward

    def forward(self, x):
        """
        Apply attention and feed-forward sub-layers, each wrapped in a residual add.

        Args:
            x (torch.Tensor): Residual stream of shape (batch_size, sequence_length, embed_size).

        Returns:
            torch.Tensor: Updated residual stream with the same shape as `x`.
        """
        attended = self.ma(self.ln1(x))
        x = x + attended

        transformed = self.feed_forward(self.ln2(x))
        return x + transformed


In [None]:
# =============================================================================
# FORWARD LAYER - Feed-Forward Network for Complex Processing
# =============================================================================
# This section defines the feed-forward network that processes each token
# independently. It increases the computational complexity and allows the
# model to learn complex transformations and patterns.

class ForwardLayer(nn.Module):
    """
    Position-wise feed-forward network used inside every transformer block.

    Every token vector is processed independently (there is no mixing across
    positions): widen -> GELU -> project back -> dropout.

    Architecture:
    - Linear: embed_size -> hidden_size  (hidden_size = int(expansion * embed_size))
    - GELU activation
    - Linear: hidden_size -> embed_size
    - Dropout at the notebook-level `dropout` rate (active only in training mode)

    Args:
        embed_size (int): Width of the token vectors entering and leaving the layer.
        expansion (int or float): Width multiplier for the hidden layer. Defaults
            to 6, the value this notebook has always used. For reference, the
            original Transformer and GPT-2 use 4. The wider 6x layer gives this
            small model a bit more capacity per block.

    Raises:
        ValueError: If `expansion * embed_size` rounds down to less than 1.
    """

    def __init__(self, embed_size, expansion=6):
        super().__init__()

        hidden_size = int(expansion * embed_size)  # e.g. 6 * 384 = 2304
        if hidden_size < 1:
            raise ValueError(
                f"expansion ({expansion}) * embed_size ({embed_size}) must be at least 1"
            )

        # The Sequential layout is kept exactly as before (indices 0..3), so
        # state_dict keys such as 'network.0.weight' stay compatible with checkpoints.
        self.network = nn.Sequential(
            # Widen: gives the network room for richer per-token transformations.
            nn.Linear(embed_size, hidden_size, bias=BIAS),
            # GELU(x) = x * Phi(x), where Phi is the standard normal CDF. It is
            # smoother than ReLU and lets small negative values through.
            nn.GELU(),
            # Project back so the output can be added to the residual stream.
            nn.Linear(hidden_size, embed_size, bias=BIAS),
            # Regularization; nn.Dropout is a no-op in eval mode.
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """
        Apply the feed-forward network to every position.

        Args:
            x (torch.Tensor): Shape (..., embed_size), typically
                (batch_size, sequence_length, embed_size).

        Returns:
            torch.Tensor: Same shape as `x`.
        """
        return self.network(x)

In [None]:
# =============================================================================
# MULTI-HEAD ATTENTION LAYER - Parallel Attention Processing
# =============================================================================
# This section defines the Multi-Head Attention mechanism, which is the core
# of the transformer architecture. It allows the model to attend to different
# parts of the sequence simultaneously and learn complex relationships.

class Multihead(nn.Module):
    """
    Multi-head causal self-attention.

    Runs `n_heads` independent `Head` modules on the same input, concatenates
    their outputs along the feature axis, projects the result back to
    `embed_size`, and applies dropout. Each head can specialize in a different
    kind of relationship between tokens.

    Reads `embed_size`, `BIAS` and `dropout` from the notebook-level configuration.

    Args:
        n_heads (int): Number of attention heads.
        head_size (int): Output width of each head. The concatenation is
            n_heads * head_size wide, which need not equal embed_size
            (e.g. 7 * 54 = 378 vs 384). The `combine` projection bridges the gap.
    """

    def __init__(self, n_heads, head_size):
        super().__init__()

        # ModuleList registers every head's parameters with this module.
        self.heads = nn.ModuleList(Head(head_size) for _ in range(n_heads))

        # Learned mixing of the concatenated head outputs, back to embed_size.
        concat_width = head_size * n_heads
        self.combine = nn.Linear(concat_width, embed_size, bias=BIAS)

        # Regularization on the block's attention output.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Attend with every head, merge the results, and project.

        Args:
            x (torch.Tensor): Shape (batch_size, sequence_length, embed_size).

        Returns:
            torch.Tensor: Shape (batch_size, sequence_length, embed_size).
        """
        # Each head yields (BS, SL, head_size); heads run in registration order.
        per_head_outputs = [attend(x) for attend in self.heads]

        # (BS, SL, head_size * n_heads), e.g. (8, 512, 378)
        merged = torch.cat(per_head_outputs, dim=-1)

        projected = self.combine(merged)  # (BS, SL, embed_size)
        return self.dropout(projected)

In [None]:
# =============================================================================
# ATTENTION HEAD - The Core of Self-Attention Mechanism
# =============================================================================
# This section defines the individual attention head, which is the fundamental
# building block of the transformer architecture. It implements the self-attention
# mechanism that allows tokens to attend to each other and learn relationships.

class Head(nn.Module):
    """
    A single causal self-attention head (scaled dot-product attention).

    Each token is projected to a query ("what am I looking for?"), a key
    ("what do I offer?") and a value ("what information do I carry?").
    Query/key similarity decides how much each earlier token's value
    contributes to the current token's output.

    Steps:
    - Query, Key, Value projections (embed_size -> head_size)
    - Scores = Q K^T / sqrt(head_size)
    - Causal mask: a token may attend only to itself and earlier tokens
    - Softmax over the key axis, then dropout on the attention weights

    Reads `embed_size`, `context`, `BIAS` and `dropout` from the notebook-level
    configuration.

    Args:
        head_size (int): Width of the query/key/value vectors.
    """

    def __init__(self, head_size):
        super().__init__()

        # Three independent projections from embed_size down to head_size
        # (e.g. 384 -> 54). Creation order is preserved for reproducible init.
        self.queries = nn.Linear(embed_size, head_size, bias=BIAS)
        self.keys = nn.Linear(embed_size, head_size, bias=BIAS)
        self.values = nn.Linear(embed_size, head_size, bias=BIAS)

        # (context, context) matrix with 1s on and below the diagonal, 0s above.
        # Registered as a buffer so it moves with the module (.to(device/dtype))
        # and is saved in the state_dict, but is not a trainable parameter.
        causal_mask = torch.ones(context, context).tril()
        self.register_buffer('tril', causal_mask)

        # Applied to the attention weights after softmax.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Compute causal self-attention for one head.

        Args:
            x (torch.Tensor): Shape (batch_size, sequence_length, embed_size);
                sequence_length must not exceed `context`.

        Returns:
            torch.Tensor: Shape (batch_size, sequence_length, head_size).
        """
        _, seq_len, _ = x.shape

        query = self.queries(x)  # (BS, SL, head_size)
        key = self.keys(x)       # (BS, SL, head_size)
        value = self.values(x)   # (BS, SL, head_size)

        # Dividing by sqrt(head_size) keeps score magnitudes from growing with the
        # head width, which would otherwise push softmax toward one-hot outputs.
        scale = key.shape[-1] ** -0.5
        scores = query @ key.transpose(-2, -1) * scale  # (BS, SL, SL)

        # Positions above the diagonal (future tokens) get -inf, so softmax gives
        # them zero weight. The diagonal is never masked, so no row is all -inf.
        future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future, float('-inf'))

        weights = F.softmax(scores, dim=-1)  # each row sums to 1
        weights = self.dropout(weights)      # (BS, SL, SL)

        # Each output row is a weighted mix of the value vectors of visible tokens.
        return weights @ value  # (BS, SL, head_size)

In [None]:
# =============================================================================
# MODEL INSTANTIATION AND SETUP - Create and Configure Our LLM
# =============================================================================
# This section creates our GPT model instance and configures it for training.
# We set up the model with the right data types, device placement, and
# optional compilation for optimal performance.

# =============================================================================
# MODEL CREATION - Instantiate Our GPT Model
# =============================================================================
# Build the model (layers are created and initialized inside GPT.__init__), then
# cast floating-point parameters and buffers to the configured precision `dtype`
# (e.g. bfloat16, which halves memory versus float32), and finally move
# everything onto the configured `device`. The dtype cast happens first,
# the device move second.
model = GPT().to(dtype).to(device)

# Optional: torch.compile (PyTorch 2.0+) traces the model into an optimized
# version. The first call is slow while compiling; later calls run faster.
# NOTE(review): `compile` is assumed to be a boolean flag set in an earlier config
# cell. If that cell was not run, the name falls back to Python's builtin
# compile() function, which is truthy, so this branch would run. Confirm the
# flag is defined.
if compile:
    print("Torch :: Compiling model")
    model = torch.compile(model)

# Report model size in millions. This sums the element counts of every parameter
# tensor, trainable or not (all of them are trainable in this model).
print(sum(param.numel() for param in model.parameters()) / 1e6, " Million parameters")

In [None]:
# =============================================================================
# LOSS CALCULATION - Evaluate Model Performance
# =============================================================================
# This section defines a function to calculate the model's loss on both
# training and validation data. This is essential for monitoring training
# progress and detecting overfitting.

@torch.no_grad()  # Prevent gradient calculation during evaluation
def calculate_loss():
    """
    Estimate the model's mean loss on the training and validation splits.

    For each split, ``eval_iters`` batches are drawn with ``get_batch`` and
    their losses are averaged. The model runs in eval mode while evaluating
    and is always put back into training mode afterwards -- even if an
    exception (for example a KeyboardInterrupt, which the training loop
    catches) is raised mid-evaluation -- because this function is called
    from inside the training loop.

    Returns:
        dict: {'train': float, 'eval': float} mean loss per split.
    """
    out = {}
    model.eval()  # disable dropout and other training-only behavior
    try:
        for split in ['train', 'eval']:
            l = torch.zeros(eval_iters)  # one slot per evaluation batch
            for i in range(eval_iters):
                x, y = get_batch(split)
                _, loss = model(x, y)
                l[i] = loss
            # Averaging over several batches gives a more stable estimate
            out[split] = l.mean().item()
    finally:
        # Restore training mode unconditionally so an interrupted evaluation
        # cannot leave the model training with dropout disabled.
        model.train()
    return out

# =============================================================================
# INITIAL LOSS CALCULATION - Baseline Before Training
# =============================================================================
# Evaluate the untrained model once and keep the result in `l`.
# Both losses should be high (the weights are still random) and close to each
# other, because nothing has been learned yet that could overfit the training split.
print(l := calculate_loss())

In [None]:
# =============================================================================
# TEXT GENERATION - Create New Text with Our Model
# =============================================================================
# This section defines a function to generate new text using our trained model.
# It demonstrates the model's ability to create coherent text by sampling
# from the learned probability distributions.

@torch.no_grad()  # Disable gradient calculation during generation
def generate_sample(input, max_tokens=64):
    """
    Generate and print a continuation of a text prompt.

    The model is temporarily switched to eval mode so training-only layers
    (e.g. dropout) do not perturb sampling. Its previous mode is restored
    afterwards, because this function is also called in the middle of the
    training loop.

    Args:
        input (str): Starting text prompt for generation.
        max_tokens (int): Value forwarded as ``max`` to ``model.generate``
            (default 64, the previously hard-coded limit).
    """
    # Tokenize the prompt and add a batch dimension: (seq_len,) -> (1, seq_len)
    t1 = torch.tensor(encode(input), dtype=torch.long, device=device)[None, :]

    was_training = model.training
    model.eval()
    try:
        # [0] selects the single sequence in the batch
        newgen = model.generate(t1, max=max_tokens)[0].tolist()
    finally:
        model.train(was_training)  # put the model back in its previous mode

    result = decode(newgen)  # token IDs -> human-readable text
    print(f"{result}")

# =============================================================================
# SAMPLE GENERATION - Smoke-Test the Untrained Model
# =============================================================================
# Before training, the continuation is essentially random characters; the same
# prompt is reused during training so the improvement can be compared over time.
generate_sample(input="The mountain in my city is")

In [None]:
# =============================================================================
# OPTIMIZER AND SCHEDULER SETUP
# =============================================================================
# Weight decay is only applied to parameters with two or more dimensions
# (weight matrices of linear/embedding layers). Parameters with fewer
# dimensions, such as biases, are left unregularized.

# name -> parameter, for every trainable parameter of the model
p_dict = {name: tensor for name, tensor in model.named_parameters() if tensor.requires_grad}

# Split by rank, preserving the model's parameter order within each group
weight_decay_p = [tensor for tensor in p_dict.values() if tensor.dim() >= 2]
no_weight_decay_p = [tensor for tensor in p_dict.values() if tensor.dim() < 2]

# One optimizer parameter group per regularization setting
optimizer_groups = [
    dict(params=weight_decay_p, weight_decay=weight_decay),
    dict(params=no_weight_decay_p, weight_decay=0.0),
]

# AdamW: Adam with decoupled weight decay. `betas` are the decay rates of the
# running averages of the gradient and of its square. PyTorch's default is
# (0.9, 0.999); the beta2 of 0.99 used here makes the squared-gradient average
# react faster to recent gradients.
optimizer = torch.optim.AdamW(optimizer_groups, lr=lr, betas=(0.9, 0.99))

# Cosine annealing over the whole run: the learning rate decays from lr down
# to lr/10 across train_iters scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=train_iters, eta_min=lr / 10)

# Training bookkeeping: first iteration to run, and best validation loss so far
# (infinity, so the first evaluation always counts as an improvement).
start_iteration, best_val_loss = 0, float('inf')


In [None]:
# =============================================================================
# CHECKPOINT LOADING - Resume Training from Saved State
# =============================================================================
# This section provides functionality to load previously saved model checkpoints.
# This allows us to resume training from where we left off or load a
# pre-trained model for further training.

# =============================================================================
# CHECKPOINT LOADING FUNCTION - Restore Model State
# =============================================================================
def load_checkpoint(path):
    """
    Load a checkpoint written by the training loop and restore training state.

    Restores the global ``model`` and ``optimizer``. When the checkpoint also
    contains a ``'scheduler_state_dict'``, it restores the global ``scheduler``
    too, so the cosine learning-rate schedule resumes at the saved step instead
    of starting over. Checkpoints without that key still load; the scheduler
    is then left unchanged.

    Args:
        path (str): Path to the checkpoint file.

    Returns:
        tuple: (iteration, loss) stored in the checkpoint.
    """
    print("LLM - Loading model")
    # map_location places the stored tensors on the current device, so a
    # checkpoint saved on a GPU can also be loaded on a CPU-only machine.
    checkpoint = torch.load(path, map_location=device)

    # Model weights/biases and optimizer internals (moments, per-group lr, ...)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    iteration = checkpoint['iteration']  # iteration at which the model was saved
    loss = checkpoint['loss']  # validation loss recorded at save time

    print(f"Loaded iter {iteration} with loss {loss}")
    return iteration, loss

# =============================================================================
# OPTIONAL CHECKPOINT LOADING - Resume from Previous Training
# =============================================================================
# Resume from a saved checkpoint when load_pretrained is enabled and the file
# exists. The path is built once, so the existence check and the actual load
# always refer to the same file. os.path.join also works whether or not
# checkpoint_dir ends with a slash.
checkpoint_load_path = os.path.join(checkpoint_dir, checkpoint_load_fn)
if load_pretrained and os.path.exists(checkpoint_load_path):
    # Continue counting iterations and comparing validation losses from where
    # the saved run stopped.
    start_iteration, loss = load_checkpoint(checkpoint_load_path)
    best_val_loss = loss
elif load_pretrained:
    print(f"No checkpoint found at {checkpoint_load_path}; training from scratch")

In [None]:
# =============================================================================
# INFERENCE MODE - Interactive Text Generation
# =============================================================================
# When `inference` is enabled, repeatedly read a prompt from the user and print
# the model's continuation. Type q (or close stdin / send EOF) to stop.
if inference:
    model.eval()  # disable dropout and other training-only behavior
    try:
        while True:
            try:
                qs = input("Enter text (q to quit) >>> ")
            except EOFError:  # no interactive stdin available
                break

            command = qs.strip()  # used only for the checks; the prompt itself is kept as typed
            if command == "":  # skip empty / whitespace-only inputs
                continue
            if command == 'q':
                break

            generate_sample(qs)  # print the model's continuation of the prompt
    finally:
        # Leave the model in training mode so a training cell run afterwards
        # does not start with dropout disabled.
        model.train()

In [None]:
# =============================================================================
# MAIN TRAINING LOOP - The Heart of Model Learning
# =============================================================================
# At selected iterations the loop evaluates, samples and checkpoints the model.
# Every iteration then performs one optimization step on a fresh training batch.
# Ctrl+C (or the notebook's interrupt button) stops training cleanly. The
# `finally` clause always closes the wandb run and releases cached GPU memory.
try:
    # The inference cell (or an interrupted evaluation) may have left the
    # model in eval mode; make sure training-only behavior is active.
    model.train()

    # torch.save does not create missing directories.
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, checkpoint_fn)

    for i in tqdm(range(start_iteration, train_iters)):
        # ---------------------------------------------------------------------
        # EVALUATION, SAMPLING AND CHECKPOINTING (before this iteration's update)
        # ---------------------------------------------------------------------
        # Run these before the training forward pass, so the autograd graph of
        # the training batch is not kept in memory during evaluation.
        if i % eval_interval == 0 or i == train_iters - 1:
            l = calculate_loss()  # mean losses on both splits
            print(f"\n{i}: train loss: {l['train']} / val loss: {l['eval']}")

            # Quick qualitative check of how generation evolves during training.
            # The dataset is small and doesn't cover all topics.
            generate_sample("The mountain in my city is")

            # Save whenever validation improves. At this point the model
            # includes the updates of all iterations before i, so resuming from
            # this checkpoint correctly restarts at iteration i.
            if l['eval'] < best_val_loss:
                best_val_loss = l['eval']
                print("[CHECKPOINT]: Saving with loss: ", best_val_loss)
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    # Lets a resumed run continue the cosine LR schedule
                    'scheduler_state_dict': scheduler.state_dict(),
                    'loss': best_val_loss,
                    'iteration': i,
                }, checkpoint_path)

            if wandb_log:
                wandb.log({
                    "loss/train": l['train'],
                    "loss/val": l['eval'],
                    "lr": scheduler.get_last_lr()[0],
                }, step=i)

        # ---------------------------------------------------------------------
        # ONE OPTIMIZATION STEP
        # ---------------------------------------------------------------------
        xb, yb = get_batch("train")  # inputs and next-token targets
        _, loss = model(xb, yb)      # logits are not needed here

        optimizer.zero_grad(set_to_none=True)  # set_to_none saves memory
        loss.backward()
        # Clip the global gradient norm to grad_clip to avoid exploding gradients
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
        optimizer.step()
        scheduler.step()  # advance the cosine learning-rate schedule

except KeyboardInterrupt:
    print("Training interrupted. Cleaning up...")

finally:
    # Runs on normal completion, on interruption and on errors alike.
    if wandb_log:
        wandb.finish()
    torch.cuda.empty_cache()
    print("GPU memory released.")

# =============================================================================
# CREDITS - Acknowledgment
# =============================================================================
# Code designed by Javier ideami
# ideami.com
