In [None]:
# LLM Alignment with ORPO Technique - Notebook created by Javier Ideami (ideami.com)
# ORPO: Monolithic Preference Optimization without Reference Model (https://arxiv.org/abs/2403.07691)

## What is this notebook about?
# This notebook demonstrates how to align a Large Language Model (LLM) using the ORPO (Odds Ratio Preference Optimization) technique.
# ORPO is a modern approach to AI alignment that teaches models to prefer helpful, harmless, and honest responses over harmful ones.

## Why do we need LLM alignment?
# Alignment needs a base LLM that is somewhat more capable than the basic one we trained earlier.
# That's why I am providing the checkpoint of a larger, already trained LLM (138 million parameters vs. our 19 million)
# so that we can apply ORPO alignment on top of it.
# This 138-million-parameter model was pre-trained on the open-source FineWeb-Edu dataset.

## What makes ORPO special?
# Traditional alignment techniques like RLHF (Reinforcement Learning from Human Feedback) are slow and expensive to implement.
# Newer techniques like ORPO simplify alignment considerably.
# We will be able to run ORPO alignment on a single GPU with little GPU memory (around 4 gigabytes).

## What to expect from this tutorial?
# Results will of course be imperfect: we are applying ORPO to a small model that has been trained on much
# less data, and for a much shorter time, than the large models out there. But it will be enough to compare before and after,
# and to see how ORPO improves the way the LLM communicates on many occasions.
# We will also understand every part of the code that makes it possible.

## License Information
# This file includes code from the open source ORPO repository licensed under the Apache License 2.0.
# See licenses/orpo-license for details.
# Modifications: Variable names have been changed for educational purposes.

# This file also uses the ORPO-DPO-mix-40k dataset licensed under the Apache License 2.0.
# See licenses/orpo-dpo-mix-40k-license for details.

# Official notebook 

In [None]:
## 🚨 IMPORTANT: GPU Requirements
#### For GOOGLE COLAB and similar platform Users:
#### Make sure to select a GPU in the online platform. Don't run this code with a CPU (it will be too slow)

# If you are running this code locally, your GPU should be selected automatically

## Why do we need a GPU?
# - ORPO training involves complex mathematical operations on large neural networks
# - CPU processing would take hours or days to complete
# - GPU parallel processing reduces training time from hours to minutes
# - We need at least 4GB of GPU memory for this tutorial

In [None]:
# 📦 INSTALLATION SECTION
# Uncomment and run the following installation lines ONLY if you haven't installed these libraries already outside of the notebook

# Core libraries for debugging and progress tracking
#!pip install -q ipdb        # Interactive Python debugger for troubleshooting
#!pip install -q tqdm        # Progress bars for training loops

# Hugging Face ecosystem for datasets and models
#!pip install -q datasets    # Library for loading and processing datasets
#!pip install -q transformers # Hugging Face transformers library for model handling

# Experiment tracking and visualization
#!pip install -q wandb       # Weights & Biases for experiment logging and visualization

# PyTorch installation
# If you are not in Google Colab and haven't installed PyTorch yet, make sure to do so.
# Find the right installation command for your system at https://pytorch.org/get-started/locally/

# Optional performance optimization
# You may also try installing the flash-attn library. Flash Attention can speed up training,
# but at the moment it may have compatibility issues.


In [None]:
# 🔍 GPU MEMORY CHECK
# This command shows your GPU information and available memory
# Make sure that you have at least 4GB of free GPU memory to run this tutorial successfully

!nvidia-smi 

# 📋 What to look for in the output:
# - GPU Name: Should show your GPU model (e.g., Tesla T4, V100, A100)
# - Memory Usage: Should show available memory (we need at least 4GB free)
# - Driver Version: Should be compatible with PyTorch

# If you are using Google Colab or a similar online platform, make sure to select a GPU in the menus
# In Google Colab, the option is within the Runtime → Change runtime type → Hardware accelerator → GPU
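
# The same check can be run from Python, for example to fail early in a script.
# This is a minimal sketch: the helper name gpu_memory_report and the MIN_FREE_MIB constant
# are illustrative (not part of the notebook), and the function returns an empty list
# when nvidia-smi is not available.

```python
import shutil
import subprocess

MIN_FREE_MIB = 4096  # assumption: mirrors the "at least 4GB free" guideline


def gpu_memory_report():
    """Return a list of (name, total_mib, free_mib) tuples, one per visible GPU."""
    if shutil.which("nvidia-smi") is None:
        return []  # nvidia-smi is not installed or not on the PATH
    try:
        result = subprocess.run(
            ["nvidia-smi",
             "--query-gpu=name,memory.total,memory.free",
             "--format=csv,noheader,nounits"],
            capture_output=True, text=True, timeout=30,
        )
    except (OSError, subprocess.SubprocessError):
        return []
    if result.returncode != 0:
        return []
    gpus = []
    for line in result.stdout.strip().splitlines():
        # Split from the right so a comma inside the GPU name cannot break parsing
        parts = [p.strip() for p in line.rsplit(",", 2)]
        if len(parts) != 3:
            continue
        try:
            gpus.append((parts[0], int(parts[1]), int(parts[2])))
        except ValueError:
            continue  # memory fields were not numeric (e.g. "[N/A]")
    return gpus


report = gpu_memory_report()
if not report:
    print("No NVIDIA GPU detected (or nvidia-smi is unavailable)")
for name, total_mib, free_mib in report:
    verdict = "OK" if free_mib >= MIN_FREE_MIB else "less than 4GB free"
    print(f"{name}: {free_mib} / {total_mib} MiB free -> {verdict}")
```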

In [None]:
# Download the necessary files and create the necessary folders
# llm.py: the model code (a more powerful LLM architecture)
# models folder: pretrained checkpoints for the base and the aligned model
# data folder: the tokenized dataset
# tokenizers folder: the tokenizer pretrained on the large FineWeb-Edu dataset

# NOTE: Downloading will take a while, be patient. You can refresh your folder from time to time to see when the files
# have been created. If you have any problems downloading the files with this code, I have also added llm_align.zip
# to the downloadable resources of this lecture (however, best option is to use this code, because then you don't need
# to upload the files or do anything else)

import os, requests, zipfile, io 

files_url = "https://ideami.com/llm_align"

# Downloading proceeds if we detect that one of the key files to download is not present
if not os.path.exists("llm.py"):
    print("Downloading files using Python")
    response = requests.get(files_url)
    response.raise_for_status()  # Stop here if the download failed (e.g. HTTP 404)
    zipfile.ZipFile(io.BytesIO(response.content)).extractall(".")
else:
    print("You seem to have already downloaded the files. If you wish to re-download them, delete the llm.py file")



In [None]:
### 📚 IMPORT NECESSARY LIBRARIES

# Standard Python libraries for file operations and system utilities
import os          # File and directory operations
import sys         # System-specific parameters and functions
import math        # Mathematical functions (needed for learning rate scheduling)
from tqdm import tqdm        # Progress bars for training loops
from datetime import datetime # Timestamp generation for experiment tracking
import ipdb        # Interactive debugger for troubleshooting

# Type hints for better code documentation and IDE support
from typing import List, Dict, Union

# 🧠 PYTORCH - Core deep learning framework
import torch           # Main PyTorch library
import torch.nn as nn  # Neural network modules and layers
from torch.nn import functional as F  # Functional interface for neural network operations

# 🤗 HUGGING FACE - Pre-trained models and datasets
import transformers    # Hugging Face transformers library for model handling
from datasets import load_dataset, load_from_disk  # Dataset loading utilities

# ⚡ PERFORMANCE OPTIMIZATIONS
# These lines improve performance on NVIDIA GPUs with TF32 support (Ampere and newer, e.g. A100, RTX 30/40 series)
torch.backends.cuda.matmul.allow_tf32 = True  # Allow tf32 precision for matrix multiplications (faster)
torch.backends.cudnn.allow_tf32 = True        # Allow tf32 precision for cuDNN operations (faster)
# Empty GPU cache memory to start with a clean slate
torch.cuda.empty_cache()

# 🔧 DEBUGGING SETTINGS
# Optional settings for debugging - makes tensor printing more readable
torch.set_printoptions(threshold=10000)      # Show more tensor elements when printing
torch.set_printoptions(sci_mode=False, precision=2)  # Use decimal notation instead of scientific

In [None]:
# 🎛️ TRAINING PARAMETERS - Core settings that control the training process
batch_size = 1  # Number of examples processed together (you can change it to 2 if you have enough GPU memory)
epochs = 3      # Number of complete passes through the training dataset
lr = 6e-5       # Learning rate - how fast the model learns (6e-5 = 0.00006)
lr_warmup_steps = 100  # Number of steps to gradually increase learning rate from 0 to full value
context = 1024  # Maximum sequence length the model can handle (in tokens)
alpha = 0.5     # Weighting factor for the ORPO odds ratio (balances standard loss vs preference loss)
prompt_max_size = 512  # Maximum length for the prompt part (leaves room for the response)
compile = False # PyTorch compilation for performance (disable if your system doesn't support it)
dtype = torch.bfloat16  # Data type for calculations (bfloat16 is faster and uses less memory than float32)
log_iters = 100 # How often to print training progress and metrics

# 🔧 HYPERPARAMETERS - Advanced training settings
dropout = 0.    # Dropout rate (0 = no dropout, 1 = maximum dropout)
grad_clip = 1.0 # Maximum gradient norm (prevents exploding gradients)
weight_decay = 0.0  # Weight decay strength (0 = none; AdamW applies it decoupled from the gradient update)

# 💻 DEVICE CONFIGURATION - Choose between GPU and CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device: You will be using: ",device)

# 📊 What these parameters mean:
# - batch_size: Larger batches = more stable training but need more GPU memory
# - epochs: More epochs = better learning but risk of overfitting
# - lr: Higher learning rate = faster learning but risk of instability
# - context: Longer sequences = more context but need more memory
# - alpha: Higher alpha = more focus on preference learning vs standard language modeling

In [None]:
# 📊 EXPERIMENT LOGGING - Track training progress and metrics
project_name="test"  # Name for organizing your experiments
wandb_log = True    # Enable Weights & Biases logging (set to False to disable)
wandb_project = project_name  # Project name in W&B dashboard
wandb_run_name = "test-run" + datetime.now().strftime("%Y_%m_%d_%H_%M_%S")  # Unique run name with timestamp

# Initialize Weights & Biases for experiment tracking
if wandb_log:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name)

# 🔑 W&B SETUP INSTRUCTIONS:
# The first time you run this logging code with wandb_log=True, the weights and biases library
# will ask you for an API key. You can follow the instructions in the video, or you can
# also simply click on a link that should appear when you run this cell, pointing to this
# address: https://wandb.ai/authorize  
# Going to that address will allow you to quickly get an API key as well

# 📈 What W&B tracks:
# - Training and validation losses over time
# - Learning rate schedules
# - Model performance metrics
# - GPU memory usage
# - Training speed and efficiency

In [None]:
# 📁 DATASET PATHS - Where to find and store our training data
dataset_path = "./data/orpo_dataset"  # Local directory where the processed ORPO dataset will be stored

# 🌐 HUGGING FACE DATASET
# This is a special dataset prepared for ORPO alignment training. It contains:
# - Paired examples of "chosen" (good) vs "rejected" (bad) responses
# - Human preferences for AI assistant responses
# - Available at: https://huggingface.co/datasets/mlabonne/orpo-dpo-mix-40k/blob/main/README.md
dataset_name = "mlabonne/orpo-dpo-mix-40k"  # HuggingFace dataset identifier

# 🔤 TOKENIZER AND MODEL PATHS
tokenizer_path = "tokenizers/tok16384"  # Path to the pre-trained tokenizer (converts text to numbers)
checkpoint_dir = './models/'           # Directory where model checkpoints are stored

# 📋 What the download cell above provided:
# - llm.py: The neural network architecture code
# - models/: Pre-trained model weights (base and aligned versions)
# - data/: Processed training dataset
# - tokenizers/: Text tokenizer for converting words to numbers


In [None]:
# 🔤 TOKENIZER SETUP - Convert text to numbers that the model can understand
###########
# Load the pre-trained tokenizer (converts text to token IDs and vice versa)
tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_path)

# 💬 CHAT TEMPLATE CONFIGURATION
# Set up a template that formats conversations between user and assistant
# This template adds special tokens like <|user|>, <|assistant|> to mark different speakers
tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n'  + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"

# 🔧 TOKENIZER CONFIGURATION
# Make padding token equal to the end of sentence token (which has id of 2 in this case)
# This ensures consistent padding behavior during training
tokenizer.pad_token = tokenizer.eos_token

# 🔄 DATASET LOADING STRATEGY
# To re-run (or debug) the tokenization of the dataset, delete the orpo_dataset folder or point
# dataset_path to a folder that does not exist yet, such as "./data/tmp"
# Uncomment the next line to force preprocessing; delete ./data/tmp whenever you want to preprocess again
#dataset_path = "./data/tmp"

# 📂 SMART DATASET LOADING
# If dataset has already been processed and tokenized, we can load it directly from our disk
if os.path.exists(dataset_path):
    print("✅ Loading pre-processed dataset from disk (faster!)")
    dataset = load_from_disk(dataset_path)
# Otherwise, we will load the dataset from HuggingFace and then filter and tokenize it
else:
    print("🔄 Preprocessing and tokenizing dataset (this may take a while...)")
    dataset = load_dataset(dataset_name, split="all")

    # 🚫 CONTENT FILTERING
    # Optional: Filter out toxic entries to improve training quality
    # Without this filter: ~37,136 elements, with filter: ~36,622 elements
    dataset = dataset.filter(lambda r: r["source"] != "toxic-dpo-v0.2")

    # 🔍 DATASET FILTERING FUNCTION
    # This function eliminates entries that are too long to fit in our context window
    # We need prompt + answer to fit within the total context (1024 tokens)
    def filter_dataset(examples):
        # examples['chosen'][:-1] picks the prompt minus the final answer
        # This gives us just the conversation history without the response we want to predict
        prompt_length = tokenizer.apply_chat_template(examples['chosen'][:-1], tokenize=True, add_generation_prompt=True, return_tensors='pt').size(-1)  
        
        # Keep only samples whose prompt is shorter than prompt_max_size (512 tokens)
        # This leaves room for the response within our 1024 token limit
        return prompt_length < prompt_max_size


    # 🔄 MAIN DATASET PREPROCESSING FUNCTION
    # This function converts raw text conversations into tokenized format for training
    # HF Tokenizer Dict Format: Encoding(num_tokens=1024, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])
    def preprocess_dataset(examples: Union[List, Dict]):
        # Process examples in batches of 1000 by default for efficiency
        
        # 📝 PROMPT EXTRACTION
        # Take chosen field, eliminate last answer, apply template adding assistant prompt
        # This creates the "prompt" - everything the user said plus conversation history
        prompt = [tokenizer.apply_chat_template(item[:-1], tokenize=False, add_generation_prompt=True) for item in examples['chosen']]
        
        # 💡 IMPORTANT: Multi-turn conversations
        # Some samples are multi-turn conversations. The prompt includes all interactions between user and assistant
        # until the last question. We remove the last answer and all previous interactions become the prompt.

        # ✅ CHOSEN RESPONSE (Good answer)
        # Take the chosen field (the preferred response), then apply chat template 
        chosen = [tokenizer.apply_chat_template(item, tokenize=False) for item in examples['chosen']]

        # ❌ REJECTED RESPONSE (Bad answer)
        # Take the rejected field (the dispreferred response), then apply chat template 
        rejected = [tokenizer.apply_chat_template(item, tokenize=False) for item in examples['rejected']]
        
        # 🔢 TOKENIZATION - Convert text to numbers
        # Tokenize the prompt (conversation history)
        inputs = tokenizer(prompt, max_length=context, padding='max_length', truncation=True, return_tensors='pt')
        # inputs will have dict fields: input_ids and attention_mask (e.g: inputs.input_ids[0])
        # You can test by decoding: tokenizer.decode(inputs.input_ids[0])

        # 📏 PADDING EXPLANATION
        # All elements will have the same length of 1024 tokens
        # Extra padding tokens will be added to reach 1024 for shorter sequences

        # ✅ TOKENIZE CHOSEN RESPONSE (Good answer)
        pos_labels = tokenizer(chosen, max_length=context, padding='max_length', truncation=True, return_tensors='pt')
        # pos_labels will have dict fields: input_ids and attention_mask (e.g: pos_labels.input_ids[0])
        # You can test by decoding: tokenizer.decode(pos_labels.input_ids[0])
        
        # ❌ TOKENIZE REJECTED RESPONSE (Bad answer)
        neg_labels = tokenizer(rejected, max_length=context, padding='max_length', truncation=True, return_tensors='pt') 
        # Same structure as pos_labels above

        # 🔗 COMBINE ALL DATA INTO SINGLE DICTIONARY
        # Add the chosen-positive and rejected-negative response ids and masks to the prompt data
        # This creates a unified structure containing all the information we need for ORPO training
        inputs['positive_input_ids'] = pos_labels['input_ids']        # Token IDs for good responses
        inputs['positive_attention_mask'] = pos_labels['attention_mask']  # Attention mask for good responses

        inputs['negative_input_ids'] = neg_labels['input_ids']         # Token IDs for bad responses  
        inputs['negative_attention_mask'] = neg_labels['attention_mask']  # Attention mask for bad responses

        return inputs  # Return the complete dataset entry
    
    # 🔍 APPLY FILTERING
    # Filter dataset to exclude prompts that are too long for our context window
    dataset = dataset.filter(filter_dataset)

    # 🔄 APPLY PREPROCESSING
    # Preprocess the dataset (tokenize and format all examples)
    # If you have issues with multiprocessing, make sure to use num_proc=1
    # Multiprocessing alternative: dataset = dataset.map(preprocess_dataset, batched=True, num_proc=min(32,os.cpu_count()), remove_columns=dataset.column_names)  
    dataset = dataset.map(preprocess_dataset, batched=True, num_proc=1, remove_columns=dataset.column_names) 
    # Processed in batches of 1000 by default for efficiency

    # 📊 FINAL DATASET STRUCTURE
    # As a result of the preprocessing, dataset variable will have all these internal fields:
    # Dataset({features: ['input_ids', 'token_type_ids', 'attention_mask', 'positive_input_ids', 'positive_attention_mask', 'negative_input_ids', 'negative_attention_mask'], num_rows: 39091})

    # 💾 SAVE PROCESSED DATASET
    # Save the processed dataset to disk for faster loading in future runs
    dataset.save_to_disk(dataset_path)    
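
# To see how the prompt, chosen and rejected texts relate to each other, here is a small
# pure-Python sketch that approximates the chat template set earlier. Assumptions: the helper
# render_chat and the toy conversation are illustrative, EOS stands in for tokenizer.eos_token,
# and the exact whitespace produced by the real Jinja template may differ slightly.

```python
EOS = "</s>"  # assumption: placeholder for tokenizer.eos_token


def render_chat(messages, add_generation_prompt=False):
    """Approximate the chat template: a '<|role|>' header, the content, then EOS."""
    text = ""
    for i, message in enumerate(messages):
        if message["role"] in ("user", "system", "assistant"):
            text += f"<|{message['role']}|>\n{message['content']}{EOS}\n"
        # After the last message, optionally open the assistant turn
        if add_generation_prompt and i == len(messages) - 1:
            text += "<|assistant|>\n"
    return text


# Toy preference pair in the same layout as the dataset's 'chosen' / 'rejected' fields
chosen = [
    {"role": "user", "content": "Name a primary color."},
    {"role": "assistant", "content": "Red is a primary color."},
]
rejected = [
    {"role": "user", "content": "Name a primary color."},
    {"role": "assistant", "content": "I don't know."},
]

# Same slicing as preprocess_dataset: drop the final answer to obtain the prompt
prompt_text = render_chat(chosen[:-1], add_generation_prompt=True)
chosen_text = render_chat(chosen)
rejected_text = render_chat(rejected)

print(repr(prompt_text))
print(repr(chosen_text))

# Both full conversations begin with the prompt. This is what later lets us subtract the
# prompt attention mask from the response attention mask to isolate the answer tokens.
print(chosen_text.startswith(prompt_text), rejected_text.startswith(prompt_text))
```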


In [None]:
# At this point, you can test some of the content of the dataset with for example these: 
# dataset[0]['input_ids']  /  dataset[0]['positive_input_ids']
# testing Ids to Text: tokenizer.decode(dataset[0]['positive_input_ids'])

# Split the data into train and validation, 5% for the validation set
dataset = dataset.shuffle(42).train_test_split(test_size=0.05)  

train_data = dataset["train"]
#Dataset({features: ['input_ids', 'token_type_ids', 'attention_mask', 'positive_input_ids', 'positive_attention_mask', 'negative_input_ids', 'negative_attention_mask'],num_rows: 37136})

val_data = dataset["test"]
#Dataset({features: ['input_ids', 'token_type_ids', 'attention_mask', 'positive_input_ids', 'positive_attention_mask', 'negative_input_ids', 'negative_attention_mask'],num_rows: 1955})

# Data_collator efficiently prepares your training and validation data for language modeling by batching, padding (optional), and performing masking (optional) according to your configuration
data_collator = transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Setup DataLoaders
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, collate_fn=data_collator, num_workers=0)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False, collate_fn=data_collator, num_workers=0)
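
# The row counts quoted in the comments above can be reproduced with a little arithmetic.
# A minimal sketch, assuming the validation count is rounded up (which is consistent with
# the 37136 / 1955 split shown above); the variable names are illustrative.

```python
import math

total_rows = 39091   # rows in the processed dataset (from the comments above)
test_size = 0.05     # fraction reserved for validation
epochs = 3
batch_size = 1

val_rows = math.ceil(total_rows * test_size)   # assumed rounding rule
train_rows = total_rows - val_rows

# One optimizer step per batch, repeated for every epoch
steps_per_epoch = math.ceil(train_rows / batch_size)
total_steps = steps_per_epoch * epochs

print(f"train rows: {train_rows}, validation rows: {val_rows}")
print(f"training steps for {epochs} epochs: {total_steps}")
```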


In [None]:
# OPTIONAL: Test your DataLoaders
#batch = next(iter(train_loader))
#print(tokenizer.decode(batch['input_ids'][0]))  # this will show just the prompt, followed by padding tokens

# debug: batch['input_ids'] (shape is 1,1024)  
# debug: Ids to Text -> tokenizer.decode(batch['input_ids'][0]) - [0] because needs to address inside batch number dim



In [None]:
###############################################################
###############################################################
################### 🧠 SETUP MODEL ############################
###############################################################
###############################################################

# 🦙 IMPORT LLAMA-BASED MODEL ARCHITECTURE
# Import the Llama model class and configuration from our custom implementation
from llm import Llama, ModelArgs

# 📂 LOAD PRETRAINED MODEL CHECKPOINT
# Load the pre-trained model weights (138 million parameters)
# This model was trained on the FineWeb-Edu dataset and is ready for alignment
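# map_location="cpu" loads the tensors on the CPU first (the model is moved to the device further below).
# weights_only=False is needed because the checkpoint also stores a pickled config object;
# only use it with checkpoints you trust, such as the one downloaded earlier in this notebook.
checkpoint = torch.load(os.path.join(checkpoint_dir, "base_model.pt"), map_location="cpu", weights_only=False)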

# ⚙️ EXTRACT MODEL CONFIGURATION
# Remove the config from checkpoint and store it separately
# The config contains all the architectural parameters (layers, heads, dimensions, etc.)
config = checkpoint.pop("config")

# 🏗️ CREATE MODEL CONFIGURATION
# Instantiate ModelArgs with the necessary parameters from the loaded config
model_args = ModelArgs(
    dim=config.hidden_size,                    # Hidden dimension (768)
    n_layers=config.num_hidden_layers,        # Number of transformer layers (12)
    n_heads=config.num_attention_heads,        # Number of attention heads (12)
    n_kv_heads=config.num_key_value_heads,    # Number of key-value heads (12)
    vocab_size=config.vocab_size,              # Vocabulary size (16384)
    norm_eps=config.rms_norm_eps,             # RMS normalization epsilon (1e-06)
    rope_theta=config.rope_theta,             # RoPE theta parameter (10000.0)
    max_seq_len=context,                       # Maximum sequence length (1024)
    dropout=config.attention_dropout,          # Dropout rate (0.0)
    hidden_dim=config.intermediate_size,       # Feed-forward hidden dimension (3072)
    attention_bias=config.attention_bias,      # Whether to use attention bias (False)
    mlp_bias=config.mlp_bias                   # Whether to use MLP bias (False)
)

# 📊 MODEL ARCHITECTURE SUMMARY
# ModelArgs(dim=768, n_layers=12, n_heads=12, n_kv_heads=12, vocab_size=16384, 
#          norm_eps=1e-06, rope_theta=10000.0, max_seq_len=1024, dropout=0.0, 
#          hidden_dim=3072, attention_bias=False, mlp_bias=False)

# 🚀 INSTANTIATE AND LOAD MODEL
model = Llama(model_args)           # Create the model with our configuration
model.load_state_dict(checkpoint)   # Load the pre-trained weights into the model

# ⚡ MODEL OPTIMIZATION AND DEVICE SETUP
model = model.to(dtype)   # Set the precision type (bfloat16 for faster training)
model = model.to(device)  # Move the model to GPU (or CPU if GPU not available)

# 🎯 SET TRAINING MODE
model.train()  # Enable training mode (enables dropout, batch norm updates, etc.)

# 🚀 OPTIONAL: MODEL COMPILATION
# Torch.compile compiles a PyTorch model to an optimized version for better performance
# This can significantly speed up training but may not be supported on all systems
if compile:
    print("[INFO] Compiling model for optimal performance")
    model = torch.compile(model)

# 📊 MODEL SIZE INFORMATION
# Print the total number of parameters in the model
total_params = sum(p.numel() for p in model.parameters())
print(f"Model has {total_params / 1e6:.1f} Million parameters")
print(f"Model architecture: {config.num_hidden_layers} layers, {config.hidden_size} hidden size")
print(f"Model is ready for ORPO alignment training!")
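
# As a sanity check on the "138 million parameters" figure, the count can be estimated by hand
# from the configuration values listed above. This is a sketch under assumptions: llm.py is not
# shown here, so a standard Llama layout is assumed (gate/up/down MLP projections, RMSNorm
# weights, and an output head that is not tied to the embedding).

```python
dim, n_layers, vocab_size, hidden_dim = 768, 12, 16384, 3072

embedding = vocab_size * dim        # token embedding table
attention = 4 * dim * dim           # Wq, Wk, Wv, Wo (n_kv_heads == n_heads, no biases)
mlp = 3 * dim * hidden_dim          # gate, up and down projections (assumed)
norms = 2 * dim                     # two RMSNorm weight vectors per layer
per_layer = attention + mlp + norms

final_norm = dim                    # RMSNorm before the output head
output_head = vocab_size * dim      # assumed untied from the embedding

total = embedding + n_layers * per_layer + final_norm + output_head
print(f"{total:,} parameters (~{total / 1e6:.1f}M)")
```

# Under these assumptions the estimate is 138,431,232 parameters, in line with the figure above.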

In [None]:
#######################################################
########## 🎯 SETUP TRAINING AND OPTIMIZER ############
#######################################################

# 🚀 OPTIMIZER CONFIGURATION
# Declare optimizer - helps us compute gradients, update parameters, manage learning rate, apply weight decay
optimizer = torch.optim.AdamW(
    model.parameters(), 
    lr=lr,                    # Learning rate (6e-5)
    betas=(0.9, 0.95),        # Control exponential moving averages of gradient and its square (AdamW algorithm)
    eps=1e-8,                 # Small number for numerical stability in computations
    fused=device == 'cuda',   # Fused operations for better GPU performance (combines multiple operations)
    weight_decay=weight_decay # Weight decay strength (AdamW applies it decoupled from the gradient update)
)

# 📊 TRAINING SCHEDULE CALCULATION
# Calculate total number of training steps: length of training loader × number of epochs
num_training_steps = len(train_loader) * epochs  # 111,408 with default settings (batch size = 1)

# 📈 LEARNING RATE SCHEDULER
# First 100 steps: linear warmup (gradually increase LR from 0 to full value)
# After warmup: cosine decay (gradually decrease LR following a cosine curve)
def lr_lambda(current_step):
    # Linear warmup phase
    if current_step < lr_warmup_steps:
        return float(current_step) / float(max(1, lr_warmup_steps))
    
    # Cosine decay phase
    progress = float(current_step - lr_warmup_steps) / float(max(1, num_training_steps - lr_warmup_steps))
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(0.5) * 2.0 * progress)))

# 🔧 CREATE LEARNING RATE SCHEDULER
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)
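
# To get a feel for the schedule without running the training, the multiplier can be evaluated
# at a few steps in plain Python. A minimal sketch: lr_multiplier mirrors lr_lambda above with
# the default values passed in as arguments (100 warmup steps, 111,408 total steps, lr = 6e-5);
# the function name is illustrative.

```python
import math


def lr_multiplier(step, warmup_steps=100, total_steps=111_408):
    """Factor applied to the base learning rate at a given step."""
    # Linear warmup: 0 -> 1 over the first warmup_steps steps
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    # Cosine decay: 1 -> 0 over the remaining steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


base_lr = 6e-5
midpoint = (100 + 111_408) // 2  # halfway through the cosine phase

for step in (0, 50, 100, midpoint, 111_408):
    factor = lr_multiplier(step)
    print(f"step {step:>7}: multiplier {factor:.3f}, lr {base_lr * factor:.2e}")
```

# The multiplier rises linearly to 1.0 at step 100, is about 0.5 halfway through the
# cosine phase, and reaches 0 at the final step.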



In [None]:
# 🧮 COMPUTE LOG PROBABILITIES - Core function for ORPO odds ratio calculation
# This function calculates the log probabilities for chosen responses, necessary for Log Odds Calculation
def compute_logps(prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits):
    """
    Calculate the average log probability of the chosen response tokens.
    
    Args:
        prompt_attention_mask: Mask for prompt tokens (1s where prompt exists, 0s elsewhere)
        chosen_inputs: Token IDs of the chosen response
        chosen_attention_mask: Mask for chosen response tokens
        logits: Model predictions (logits) for all tokens
    
    Returns:
        Average log probability of the chosen response
    """

    # 📝 LABEL SHIFTING EXPLANATION
    # In general, we get rid of the first element in labels because we want to match each token 
    # to the label in the next position (we shift labels one to the left).
    # As a consequence, we get rid of the last element of logits to equalize dimensions
    # and also because we don't care about predictions for the last token (no next token after that)

    # 🎯 CREATE RESPONSE MASK
    # Create a mask with 1s only at the positions of the last answer, starting from the token before the
    # last answer, because that is the position from which we predict the answer's first token
    mask = chosen_attention_mask[:, :-1] - prompt_attention_mask[:, 1:]

    # 🔍 GATHER LOG PROBABILITIES
    # torch.gather selects elements of the log-softmaxed logits along dimension 2, using the token IDs as indices
    # For example: if the index gives us token 1160, we extract the log probability of token 1160 at that position
    # NOTE: log_softmax returns log probabilities, which are always <= 0. This is not yet a loss:
    # the negative sign of a negative log-likelihood is not applied here
    # logits[:,:-1,:] shape: (1, 1023, 16384)
    # index = (mask * chosen_inputs[:, 1:]).unsqueeze(2) shape: (1, 1023, 1)
    # final result: per_token_logps shape: (1, 1023)
    per_token_logps = torch.gather(logits[:, :-1, :].log_softmax(-1), dim=2, 
                                    index=(mask * chosen_inputs[:, 1:]).unsqueeze(2)).squeeze(2)

    # 📊 NORMALIZE BY RESPONSE LENGTH
    # Mask the per_token_logps to leave only positions of the last answer, then normalize
    # mask.sum will only sum the active elements so that we normalize by the total tokens of the answer
    return torch.mul(per_token_logps, mask.to(dtype)).sum(dim=1).to(dtype) / mask.sum(dim=1).to(dtype)
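
# The average log probabilities returned by compute_logps feed the odds ratio term of ORPO.
# The sketch below follows the formulation in the ORPO paper using plain Python floats.
# Assumptions: the function name orpo_terms and the example numbers are illustrative, and the
# way the training loop of this notebook combines the terms is not shown in this section.

```python
import math


def orpo_terms(pos_logp, neg_logp, nll_pos, alpha=0.5):
    """Return (log_odds, ratio, loss) for one chosen/rejected pair.

    pos_logp / neg_logp: average log probability of the chosen / rejected answer (< 0)
    nll_pos: standard language-modeling loss on the chosen answer
    """
    # odds(p) = p / (1 - p), so log odds(p) = log p - log(1 - p)
    log_odds = (pos_logp - neg_logp) - (
        math.log1p(-math.exp(pos_logp)) - math.log1p(-math.exp(neg_logp))
    )
    # log(sigmoid(x)) written as -log(1 + exp(-x))
    ratio = -math.log1p(math.exp(-log_odds))
    # ratio is <= 0, so subtracting alpha * ratio adds a penalty that shrinks
    # as the chosen answer becomes more likely than the rejected one
    loss = nll_pos - alpha * ratio
    return log_odds, ratio, loss


# Chosen answer more likely than the rejected one: small penalty
print(orpo_terms(pos_logp=-0.5, neg_logp=-1.5, nll_pos=0.5))
# Both answers equally likely: log odds is 0 and the ratio is log(0.5)
print(orpo_terms(pos_logp=-1.0, neg_logp=-1.0, nll_pos=1.0))
```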



In [None]:
#def compute_logps(self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits):
        #mask = chosen_attention_mask[:, :-1] - prompt_attention_mask[:, 1:]
        #per_token_logps = torch.gather(logits[:, :-1, :].log_softmax(-1), dim=2, 
        #                               index=(mask * chosen_inputs[:, 1:]).unsqueeze(2)).squeeze(2)
        #return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to(dtype=torch.float64) / mask.sum(dim=1).to(dtype=torch.float64)

# this is the code of the toy example we create in the videos in order
# to understand in depth how the compute_logps function works
# if you wish to run it, remove the triple quotes at the beginning and at the end of the code

'''
p_mask = torch.tensor([1,1,1,1,0,0,0,0,0,0,0])
c_inputs = torch.tensor([2,4,3,2,4,2,1,0,0,0,0])
c_mask = torch.tensor([1,1,1,1,1,1,1,0,0,0,0])
print(f"p_mask: {p_mask}")
print(f"c_inputs: {c_inputs}")
print(f"c_mask: {c_mask}")
mask = c_mask[:-1] - p_mask[1:] 
print(f"mask: {mask}")
logits = torch.tensor([
    [0.2, 0.4, 0.8, 0.1, 0.3],
    [0.2, 0.1, 0.5, 0.12, 0.31],
    [0.22, 0.44, 0.81, 0.13, 0.32],
    [0.29, 0.42, 0.84, 0.15, 0.32],
    [0.24, 0.48, 0.88, 0.17, 0.34],
    [0.21, 0.41, 0.81, 0.14, 0.33],
    [0.23, 0.43, 0.82, 0.16, 0.35],
    [0.2, 0.4, 0.8, 0.1, 0.3],
    [0.2, 0.1, 0.5, 0.12, 0.31],
    [0.22, 0.44, 0.81, 0.13, 0.32],
    [0.22, 0.44, 0.81, 0.13, 0.32]
])

print(f"c_inputs[1:]: {c_inputs[1:]}")
index=(mask * c_inputs[1:])
print(f"index: {index}")

# Expand dimensions for correct gather shape
index_expanded = index.unsqueeze(1)
print(f"index_expanded: {index_expanded}")

print("shapes: ",index_expanded.shape, logits.shape)
# Gather the values at the specified indices
gathered_values = torch.gather(logits[:-1,:], dim=1, index=index_expanded)
print(f"gathered: {gathered_values}")

# Squeeze to remove the unnecessary dimension
per_token_logps = gathered_values.squeeze(1)
print(f"per_token_logps: {per_token_logps}")

result = torch.mul(per_token_logps, mask)
print(f"result: {result}")
f1 = result.sum(dim=0)
f2 = mask.sum(dim=0)
print(f"f1: {f1}")
print(f"f2: {f2}")
final = f1 / f2
print(f"final: {final}")
'''

In [None]:
# Setup Iterators and update key variables
val_iterator = iter(val_loader)
train_iterator = iter(train_loader)
log_iters = 100
eval_iters= 5 # Use a small number, otherwise things will get too slow

print(f"train loader size: {len(train_loader)}")
print(f"validation loader size: {len(val_loader)}")
print(f"number of training steps: {num_training_steps}")

In [None]:
# Calculate average of training and validation losses over multiple batches
@torch.no_grad()  # Prevent gradient calculation
def calculate_loss():
    global train_iterator, val_iterator
    loss_mean={}
    odds_mean={}
    ratio_mean={}
    model.eval()
    for split in ['train','val']: 
        l=torch.zeros(eval_iters)  # Create a tensor of zeros the size of eval_iters
        o=torch.zeros(eval_iters)  # Create a tensor of zeros the size of eval_iters
        r=torch.zeros(eval_iters)  # Create a tensor of zeros the size of eval_iters
        for i in range(eval_iters):
            try:
                if split == 'val':
                    batch = next(val_iterator)
                else:
                    batch = next(train_iterator)
            except StopIteration:
                if split == 'val':
                    print("####### Resetting Validation Iterator")
                    val_iterator = iter(val_loader)
                    batch = next(val_iterator)
                else:
                    print("####### Resetting Training Iterator")
                    train_iterator = iter(train_loader)
                    batch = next(train_iterator)                   

            batch["positive_input_ids"] = batch["positive_input_ids"].to(device) 
            batch["positive_attention_mask"] = batch["positive_attention_mask"].to(device)
            batch["negative_input_ids"] = batch["negative_input_ids"].to(device)
            batch["negative_attention_mask"] = batch["negative_attention_mask"].to(device)
            batch["attention_mask"] = batch["attention_mask"].to(device)
        
            neg_labels = batch['negative_input_ids'].clone()
            pos_labels = batch['positive_input_ids'].clone()
        
            mask = batch['attention_mask'] * batch['positive_attention_mask']  # sets mask to have 1s in only the prompt positions
            pos_labels = pos_labels * mask.logical_not()  # puts 0s where the prompt was, preserve last answer (padding tokens are EOS(2))
        
            pos_labels[pos_labels == 0] = tokenizer.pad_token_id # replaces 0s with EOS(2)
            neg_labels[neg_labels == tokenizer.pad_token_id] = -100 # change 2 to -100 so that the loss ignores padding (loss_neg itself is not used)
            pos_labels[pos_labels == tokenizer.pad_token_id] = -100 # change 2 to -100 so that loss calculations ignore prompt and padding
        
            outputs_pos, loss_pos = model(batch['positive_input_ids'], pos_labels)  #  (1,1024) , (1,1024)
            outputs_neg, loss_neg = model(batch['negative_input_ids'], neg_labels)    
        
            # returns the average of the log probabilities for the positive samples (masking out prompt)
            pos_prob = compute_logps(
                        prompt_attention_mask=batch['attention_mask'], 
                        chosen_inputs=batch["positive_input_ids"], 
                        chosen_attention_mask=batch['positive_attention_mask'], 
                        logits=outputs_pos
                    )
            # returns the average of the log probabilities for the negative samples (masking out prompt)
            neg_prob = compute_logps(
                        prompt_attention_mask=batch['attention_mask'], 
                        chosen_inputs=batch["negative_input_ids"], 
                        chosen_attention_mask=batch['negative_attention_mask'], 
                        logits=outputs_neg
                    )    
        
            # CALCULATE ORPO ODDS RATIO
            log_odds = (pos_prob - neg_prob) - (torch.log1p(-torch.exp(pos_prob)) - torch.log1p(-torch.exp(neg_prob)))  # log1p(-x) is a more accurate log(1 - x)
            sig_ratio = F.sigmoid(log_odds) # constrain to be between 0 and 1
            ratio = torch.log(sig_ratio) # apply the final log to the calculation
        
            # Calculate the Final Total Loss, combination of standard Cross Entropy loss and the weighted Odds Ratio
            loss = torch.mean(loss_pos - (alpha*ratio).mean()).to(dtype=dtype)
            # notice that mean() is useful if batch size is larger than 1  

            l[i]=loss.item()
            o[i]=log_odds.mean().item()
            r[i]=ratio.mean().item()
        
        loss_mean[split]=l.mean().item()
        odds_mean[split]=o.mean().item()
        ratio_mean[split]=r.mean().item()
        
            
    model.train()
    return loss_mean, odds_mean, ratio_mean

l, o, r = calculate_loss()
print(l,o,r)
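
In [None]:
# To make the odds ratio term above concrete, here is a small worked example in plain Python.
# It is a sketch with made-up numbers: toy_alpha and the probabilities 0.6 and 0.2 are
# hypothetical, and in the notebook the inputs are average per-token log-probabilities.

```python
import math


def log_odds_ratio(pos_logp, neg_logp):
    """log(odds(pos) / odds(neg)), with odds(p) = p / (1 - p) and log-probability inputs."""
    log_odds_pos = pos_logp - math.log1p(-math.exp(pos_logp))
    log_odds_neg = neg_logp - math.log1p(-math.exp(neg_logp))
    return log_odds_pos - log_odds_neg


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))


toy_alpha = 0.1  # hypothetical weight of the odds ratio term
pos_logp, neg_logp = math.log(0.6), math.log(0.2)

log_odds = log_odds_ratio(pos_logp, neg_logp)  # log((0.6/0.4) / (0.2/0.8)) = log(6)
ratio = log_sigmoid(log_odds)                  # log(6/7), always <= 0
penalty = -toy_alpha * ratio                   # the term added to the cross entropy loss

# Swapping the two responses (model prefers the rejected one) gives a larger penalty
swapped_penalty = -toy_alpha * log_sigmoid(log_odds_ratio(neg_logp, pos_logp))
print(log_odds, ratio, penalty, swapped_penalty)
```

# The penalty shrinks towards 0 as the model favours the positive response more strongly,
# and grows when the negative response becomes more likely than the positive one.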

In [None]:
################################################
################################################
############### 🎯 ORPO TRAINING ################
################################################
################################################

# 🚀 MAIN TRAINING LOOP - ORPO Alignment Training
try:
    # 📅 EPOCH LOOP - Train for multiple complete passes through the dataset
    for e in range(epochs):
        print(f"\n🔄 Starting Epoch {e+1}/{epochs}")
        
        # 📦 BATCH LOOP - Process each batch of training examples
        for i, batch in tqdm(enumerate(train_loader), total=len(train_loader), dynamic_ncols=True):
        
            # 🧹 RESET GRADIENTS
            optimizer.zero_grad(set_to_none=True)  # Clear gradients from previous iteration
    
            # 🚚 MOVE DATA TO DEVICE
            # Transfer all batch data to GPU (or CPU) for processing
            batch["positive_input_ids"] = batch["positive_input_ids"].to(device) 
            batch["positive_attention_mask"] = batch["positive_attention_mask"].to(device)
            batch["negative_input_ids"] = batch["negative_input_ids"].to(device)
            batch["negative_attention_mask"] = batch["negative_attention_mask"].to(device)
            batch["attention_mask"] = batch["attention_mask"].to(device)

            # 🔍 DEBUGGING HELPER
            # If you want to look inside batch: tokenizer.decode(batch["positive_input_ids"][0])

            # 📋 PREPARE LABELS FOR TRAINING
            # Get the token IDs of positive and negative responses
            neg_labels = batch['negative_input_ids'].clone()  # Copy negative response tokens
            pos_labels = batch['positive_input_ids'].clone()   # Copy positive response tokens
    
            # 🎯 CALCULATE STANDARD CROSS ENTROPY LOSS (focused on POSITIVE responses)

            # 🚫 DISABLE LOSS ON PROMPT TOKENS
            # When we calculate the standard loss, we focus on the loss of the positive responses
            # We want to measure how well the model predicts the next token in the positive, chosen responses
            # So we mask the positive IDs to only consider the response tokens, ignoring the prompt
            mask = batch['attention_mask'] * batch['positive_attention_mask']  # 1s only in prompt positions
            # In our case, this is similar to just mask = batch['attention_mask'] (all sequences have same length)
            pos_labels = pos_labels * mask.logical_not()  # 0s where prompt was, preserve response tokens

            # 🔧 PREPARE LABELS FOR LOSS CALCULATION
            pos_labels[pos_labels == 0] = tokenizer.pad_token_id  # Replace 0s with EOS token (2)
            neg_labels[neg_labels == tokenizer.pad_token_id] = -100  # -100 ignores padding (loss_neg itself is not used)
            pos_labels[pos_labels == tokenizer.pad_token_id] = -100  # -100 ignores prompt and padding

            # 🧠 MODEL FORWARD PASS - POSITIVE RESPONSE
            outputs_pos, loss_pos = model(batch['positive_input_ids'], pos_labels)  # Shape: (1,1024), (1,1024)
            # positive_input_ids: all token IDs including response and padding (EOS token = 2)
            # pos_labels: prompt and padding set to -100 (ignored in loss calculation), only the response tokens keep their IDs

            # 🧠 MODEL FORWARD PASS - NEGATIVE RESPONSE  
            # We don't use the negative loss for standard training, but we need the output logits
            # for the per-token log probability calculations of the negative responses
            outputs_neg, loss_neg = model(batch['negative_input_ids'], neg_labels)    

            # 🧮 CALCULATE PER-TOKEN LOG PROBABILITIES (necessary for ORPO odds ratio)

            # ✅ POSITIVE RESPONSE LOG PROBABILITIES
            # Returns the average log probabilities for the positive samples (masking out prompt)
            pos_prob = compute_logps(
                prompt_attention_mask=batch['attention_mask'], 
                chosen_inputs=batch["positive_input_ids"], 
                chosen_attention_mask=batch['positive_attention_mask'], 
                logits=outputs_pos
            )
            
            # ❌ NEGATIVE RESPONSE LOG PROBABILITIES
            # Returns the average log probabilities for the negative samples (masking out prompt)
            neg_prob = compute_logps(
                prompt_attention_mask=batch['attention_mask'], 
                chosen_inputs=batch["negative_input_ids"], 
                chosen_attention_mask=batch['negative_attention_mask'], 
                logits=outputs_neg
            )    

            # 🎯 CALCULATE ORPO ODDS RATIO
            # This is the core of ORPO: compare how much the model prefers positive vs negative responses
            log_odds = (pos_prob - neg_prob) - (torch.log1p(-torch.exp(pos_prob)) - torch.log1p(-torch.exp(neg_prob)))  # log1p(-x) is a more accurate log(1 - x)
            sig_ratio = F.sigmoid(log_odds)  # Constrain to be between 0 and 1
            ratio = torch.log(sig_ratio)     # Apply final log to the calculation

            # 🎯 CALCULATE FINAL TOTAL LOSS
            # Combination of standard Cross Entropy loss and the weighted Odds Ratio
            loss = torch.mean(loss_pos - (alpha*ratio).mean()).to(dtype=dtype)
            # Note: mean() is useful if batch size is larger than 1

            # 📊 LOGGING AND MONITORING
            if i%log_iters == 0:
                # Calculate average losses across multiple batches for more stable metrics
                loss_m, log_odds_m, ratio_m = calculate_loss()

                # 📈 PRINT TRAINING PROGRESS
                print(f"Epochs [{e+1}/{epochs}] Step: [{i}/{len(train_loader)}], train loss: {loss_m['train']:.4f}, val loss: {loss_m['val']:.4f}, train log odds ratio: {log_odds_m['train']:.4f}, val log odds ratio: {log_odds_m['val']:.4f}")
                
                # 📊 WANDB LOGGING
                if wandb_log:
                    wandb.log({
                        "train_loss": loss_m['train'],           # Training loss
                        "val_loss": loss_m['val'],               # Validation loss
                        "train_log_odds": log_odds_m['train'],   # Training log odds ratio
                        "val_log_odds": log_odds_m['val'],       # Validation log odds ratio
                        "train_ratio": (alpha*ratio_m['train']),  # alpha * log-sigmoid of the log odds ratio (train)
                        "val_ratio": (alpha*ratio_m['val']),      # alpha * log-sigmoid of the log odds ratio (val)
                        # Optional metrics (uncomment to track):
                        #"pos_prob": pos_prob.mean().item(),     # Positive response probability
                        #"neg_prob": neg_prob.mean().item(),     # Negative response probability                        
                        #"lr": scheduler.get_last_lr()[0],       # Current learning rate
                    }, 
                    step = (e*len(train_loader) + i))

                # 🚨 NAN DETECTION (inside the logging branch, so it only runs every log_iters steps)
                if torch.isnan(loss):
                    print("❌ NaN loss detected! Stopping training...")
                    if wandb_log:   
                        wandb.finish()
                    torch.cuda.empty_cache()
                    sys.exit()

            # 🔄 OPTIMIZATION STEP
            loss.backward()  # Calculate gradients
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)  # Clip gradients to prevent explosion
            optimizer.step()  # Update model parameters
            scheduler.step()  # Update learning rate

        # 💾 SAVE CHECKPOINT AT END OF EPOCH
        # Save model state and configuration for future use
        sd = model.state_dict()
        sd['config'] = config
        torch.save(sd, os.path.join(checkpoint_dir, f'{project_name}_{e+1}.pt'))
        print(f"✅ Checkpoint saved: {project_name}_{e+1}.pt")

except KeyboardInterrupt:
    print("⏹️ Training interrupted by user. Cleaning up...")

finally:
    # 🧹 CLEANUP AND MEMORY MANAGEMENT
    # Release GPU memory to free up resources
    torch.cuda.empty_cache()
    print("🧹 GPU memory released.")

# 📊 FINALIZE LOGGING
if wandb_log:   
    wandb.finish()
    print("📊 Weights & Biases logging finished.")

# 🧹 FINAL CLEANUP
torch.cuda.empty_cache()
print("🎉 ORPO training completed!")
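
In [None]:
# The checkpoints saved above store the config inside the state dict, so the 'config'
# entry has to be removed before the weights can be loaded again. The sketch below shows
# that round trip on a tiny stand-in model. toy_model, toy_config and the file name are
# hypothetical, and the notebook's own config may be a different type; depending on the
# PyTorch version, torch.load may need extra arguments if config is a custom object.

```python
import os
import tempfile

import torch
import torch.nn as nn

toy_model = nn.Linear(4, 2)              # stand-in for the LLM
toy_config = {"n_embd": 4, "n_out": 2}   # stand-in for the notebook's config

# Save the same way as the training loop: state dict plus a 'config' entry
sd = toy_model.state_dict()
sd["config"] = toy_config
path = os.path.join(tempfile.mkdtemp(), "toy_1.pt")
torch.save(sd, path)

# Reload: pop 'config' first, because load_state_dict rejects unexpected keys by default
loaded = torch.load(path, map_location="cpu")
loaded_config = loaded.pop("config")
restored = nn.Linear(loaded_config["n_embd"], loaded_config["n_out"])
restored.load_state_dict(loaded)
print(loaded_config)
```

# Keeping the config next to the weights means the architecture can be rebuilt from the file alone.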

