In [None]:
# LLM Alignment with ORPO Technique - Notebook created by Javier Ideami (ideami.com)
# ORPO: Monolithic Preference Optimization without Reference Model (https://arxiv.org/abs/2403.07691)

# LLM Alignment needs to be applied to an LLM that is a bit more powerful than the basic one we trained earlier
# That's why I am providing the checkpoint of an already trained more powerful LLM (138 million parameters vs 19 million of ours)
# so that we can apply ORPO alignment on top of it.
# The pretrained 138 million parameter model has been pre-trained on the open source Fineweb-Edu dataset.

# Traditional alignment techniques like RLHF take a lot of time and money to implement. New techniques like ORPO simplify alignment a lot.
# We will be able to run ORPO alignment training on a single GPU with little GPU memory (around 4 gigabytes).
# Of course, results will be imperfect because we are still applying ORPO on top of a model that is small and has been trained with much
# less data than the large models out there, and for a much shorter amount of time. But it will be enough to compare the before and after,
# and to see how ORPO improves the way the LLM communicates on many occasions.
# And we will understand every part of the code that makes it possible.

# This file includes code from the open source ORPO repository licensed under the Apache License 2.0.
# See licenses/orpo-license for details.
# Modifications: Variable names have been changed for educational purposes.

# This file also uses the ORPO-DPO-mix-40k dataset licensed under the Apache License 2.0.
# See licenses/orpo-dpo-mix-40k-license for details.

# Official notebook 

In [None]:
#### For GOOGLE COLAB and similar platform Users:
#### Make sure to select a GPU in the online platform. Don't run this code with a CPU (it will be too slow)

# If you are running this code locally, your GPU should be selected automatically

In [None]:
# uncomment and run the following installation lines ONLY if you haven't already installed these libraries outside of the notebook
#!pip install -q ipdb
#!pip install -q tqdm
#!pip install -q datasets
#!pip install -q transformers
#!pip install -q wandb

# And if you are not in Google Colab and you didn't yet install Pytorch, make sure to do it:
# find the ideal pytorch installation command at https://pytorch.org/get-started/locally/

# in addition, you may try to install the flash-attn library, although at the moment it is having issues


In [None]:
# You can use this command to view information about your GPU and the amount of free memory it has
# Make sure that you have at least 4GB of free GPU memory to do this course
!nvidia-smi 
# If you are using Google Colab or a similar online platform, make sure to select a GPU in the menus
# In Google colab, at the moment the option is within the Runtime menus

In [None]:
# Download necessary files and create necessary folders
# llm.py - llm model: an llm architecture that is more powerful
# models folder: pretrained checkpoints for the base and the aligned model
# data folder: tokenized dataset stored in this folder
# tokenizers folder: pretrained tokenizer on the large dataset FineWeb-Edu

# NOTE: Downloading will take a while, be patient. You can refresh your folder from time to time to see when the files
# have been created. If you have any problems downloading the files with this code, I have also added llm_align.zip
# to the downloadable resources of this lecture (however, best option is to use this code, because then you don't need
# to upload the files or do anything else)

import os, requests, zipfile, io

files_url = "https://ideami.com/llm_align"

# Download only if one of the key files (llm.py) is missing, so re-running the cell is cheap.
if not os.path.exists("llm.py"):
    print("Downloading files using Python")
    # timeout: give up if the server stops responding instead of hanging the notebook forever.
    # raise_for_status(): an HTTP error (404, 500, ...) becomes a clear exception, instead of the
    # error page being handed to ZipFile and failing with a confusing BadZipFile.
    response = requests.get(files_url, timeout=300)
    response.raise_for_status()
    # Context manager closes the archive once it has been extracted
    with zipfile.ZipFile(io.BytesIO(response.content)) as archive:
        archive.extractall(".")
else:
    print("you seem to have already downloaded the files. If you wish to re-download them, delete the llm.py file")



In [None]:
### Import necessary libraries

import os
import sys
import math
from tqdm import tqdm
from datetime import datetime
import ipdb

from typing import List, Dict, Union

# Pytorch
import torch
import torch.nn as nn
from torch.nn import functional as F

# Hugging Face libraries that will accelerate our implementation
import transformers
from datasets import load_dataset, load_from_disk


# Allow TensorFloat-32 math on Ampere (and newer) GPUs, e.g. A100s:
# faster float32 matrix multiplications (cuBLAS) and convolutions (cuDNN)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Optional, for debugging: print tensors of up to 10000 elements in full,
# using fixed-point notation with 2 decimals
torch.set_printoptions(threshold=10000, sci_mode=False, precision=2)

# Release cached GPU memory
torch.cuda.empty_cache()

In [None]:
# ---------------- TRAINING PARAMETERS ----------------
batch_size = 1            # raise to 2 if your GPU has enough memory
epochs = 3
lr = 6e-5                 # peak learning rate (reached after the warmup)
lr_warmup_steps = 100     # number of linear learning-rate warmup steps
context = 1024            # max tokens per sequence (prompt + answer)
alpha = 0.5               # weight of the ORPO odds-ratio term in the total loss
prompt_max_size = 512     # samples whose prompt has this many tokens or more are filtered out
compile = False           # torch.compile the model (only on systems that support it)
dtype = torch.bfloat16    # precision used for the model weights and computations
log_iters = 100           # evaluate and log every `log_iters` training steps

# ---------------- HYPERPARAMETERS ----------------
dropout = 0.
grad_clip = 1.0           # maximum global gradient norm
weight_decay = 0.0

# ---------------- DEVICE ----------------
# Use the GPU whenever one is available (training on CPU is far too slow)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("device: You will be using: ", device)

In [None]:
# LOGGING with Weights & Biases
project_name = "test"
wandb_log = True  # set to False to train without Weights & Biases logging
wandb_project = project_name
# Unique run name: "test-run" immediately followed by the current date and time
wandb_run_name = f"test-run{datetime.now():%Y_%m_%d_%H_%M_%S}"

if wandb_log:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name)

# The first time you run this logging code set to True, the weights and biases library
# will ask you for an API key. You can follow the instructions in the video, or you can
# also simply click on a link that should appear when you run this cell, pointing to this
# address: https://wandb.ai/authorize  
# Going to that address will allow you to quickly get an API key as well

In [None]:
# DATASET, TOKENIZER and CHECKPOINT paths

# Hugging Face dataset prepared for ORPO alignment training:
# https://huggingface.co/datasets/mlabonne/orpo-dpo-mix-40k/blob/main/README.md
dataset_name = "mlabonne/orpo-dpo-mix-40k"
# Local folder where the filtered + tokenized ORPO dataset is cached
dataset_path = "./data/orpo_dataset"

tokenizer_path = "tokenizers/tok16384"  # pretrained tokenizer
checkpoint_dir = './models/'            # model checkpoints are read from / written to here


In [None]:
# Tokenizing the Dataset
###########
# Load the pretrained tokenizer (Hugging Face format)
tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_path)

# Padding token = end-of-sentence token (which has id 2 in this case)
tokenizer.pad_token = tokenizer.eos_token

# Interaction (chat) template: user / system / assistant messages are introduced by
# <|user|>, <|system|> and <|assistant|> headers and terminated with eos_token.
# With add_generation_prompt=True an <|assistant|> header is appended after the last message.
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}\n"
    "{{ '<|user|>\n' + message['content'] + eos_token }}\n"
    "{% elif message['role'] == 'system' %}\n"
    "{{ '<|system|>\n' + message['content'] + eos_token }}\n"
    "{% elif message['role'] == 'assistant' %}\n"
    "{{ '<|assistant|>\n'  + message['content'] + eos_token }}\n"
    "{% endif %}\n"
    "{% if loop.last and add_generation_prompt %}\n"
    "{{ '<|assistant|>' }}\n"
    "{% endif %}\n"
    "{% endfor %}"
)

# If you want to debug the tokenization of the dataset, delete the orpo_dataset folder or change
# this dataset_path to something else like "./data/tmp"
# uncomment to activate the preprocessing of dataset
#dataset_path = "./data/tmp" # delete this folder to preprocess again

# If the dataset has already been processed and tokenized, load it directly from disk.
# NOTE: the cached copy is reused as-is. Delete the folder (or change dataset_path) whenever you change
# the tokenizer, the chat template, `context` or `prompt_max_size`, otherwise stale data will be used.
if os.path.exists(dataset_path):
    print("Loading encoded dataset from disk")
    dataset = load_from_disk(dataset_path)
# Otherwise, load the dataset from Hugging Face, filter it, tokenize it and save it
else:
    print("Preprocessing and tokenizing dataset")
    dataset = load_dataset(dataset_name, split="all")

    # Optional: filter out toxic entries (approx. 37136 vs 36622 rows without vs with this filter)
    dataset = dataset.filter(lambda r: r["source"] != "toxic-dpo-v0.2")

    # FILTER DATASET
    def filter_dataset(examples):
        """Keep a sample only if its prompt is shorter than `prompt_max_size` tokens.

        The prompt is the 'chosen' conversation without its last message (examples['chosen'][:-1]),
        rendered with the chat template plus the trailing assistant header. Samples whose prompt has
        `prompt_max_size` (512) tokens or more are dropped, leaving room for the answer inside the
        total context (1024).
        """
        prompt_length = tokenizer.apply_chat_template(examples['chosen'][:-1], tokenize=True, add_generation_prompt=True, return_tensors='pt').size(-1)
        return prompt_length < prompt_max_size

    # PREPROCESS DATASET: tokenize it and store the fields needed later
    # HF Tokenizer Dict Format
    # Encoding(num_tokens=1024, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])
    def preprocess_dataset(examples: Union[List, Dict]):
        """Tokenize a batch of samples into prompt, chosen and rejected token ids / masks.

        Called through `dataset.map(..., batched=True)`, so `examples` maps column names to lists
        (1000 rows per call by default). Every output is padded / truncated to exactly `context` tokens.

        Returns the tokenizer output for the prompts (input_ids, attention_mask, ...) extended with:
          positive_input_ids / positive_attention_mask: full conversation ending with the chosen answer
          negative_input_ids / negative_attention_mask: full conversation ending with the rejected answer
        """
        # Prompt = chosen conversation minus its last answer, plus the assistant header (explore: prompt[0]).
        # Very important: some samples are multi-turn conversations; every earlier user/assistant
        # interaction up to the last question becomes part of the prompt.
        prompt = [tokenizer.apply_chat_template(item[:-1], tokenize=False, add_generation_prompt=True) for item in examples['chosen']]

        # Full chosen and rejected conversations, rendered with the chat template
        chosen = [tokenizer.apply_chat_template(item, tokenize=False) for item in examples['chosen']]
        rejected = [tokenizer.apply_chat_template(item, tokenize=False) for item in examples['rejected']]

        # Tokenize the prompt. All elements get the same length (1024): padding tokens are added to reach it.
        # Testing by decoding: tokenizer.decode(inputs.input_ids[0])
        inputs = tokenizer(prompt, max_length=context, padding='max_length', truncation=True, return_tensors='pt')

        # Tokenize the chosen (positive) and rejected (negative) conversations the same way
        # Testing by decoding: tokenizer.decode(pos_labels.input_ids[0])
        pos_labels = tokenizer(chosen, max_length=context, padding='max_length', truncation=True, return_tensors='pt')
        neg_labels = tokenizer(rejected, max_length=context, padding='max_length', truncation=True, return_tensors='pt')

        # Store the positive and negative ids/masks next to the prompt ones, so that everything is in `inputs`
        inputs['positive_input_ids'] = pos_labels['input_ids']
        inputs['positive_attention_mask'] = pos_labels['attention_mask']

        inputs['negative_input_ids'] = neg_labels['input_ids']
        inputs['negative_attention_mask'] = neg_labels['attention_mask']

        return inputs

    # Drop samples whose prompt is too long
    dataset = dataset.filter(filter_dataset)

    # Tokenize (sent in batches of 1000 by default). If you have issues with multiprocessing, keep num_proc=1.
    # multiprocessing alternative: dataset = dataset.map(preprocess_dataset, batched=True, num_proc=min(32,os.cpu_count()), remove_columns=dataset.column_names)
    dataset = dataset.map(preprocess_dataset, batched=True, num_proc=1, remove_columns=dataset.column_names)

    # Resulting features: ['input_ids', 'token_type_ids' (if the tokenizer emits it), 'attention_mask',
    # 'positive_input_ids', 'positive_attention_mask', 'negative_input_ids', 'negative_attention_mask']
    dataset.save_to_disk(dataset_path)


In [None]:
# At this point, you can test some of the content of the dataset with for example these: 
# dataset[0]['input_ids']  /  dataset[0]['positive_input_ids']
# testing Ids to Text: tokenizer.decode(dataset[0]['positive_input_ids'])

# Split the data into train and validation sets, 5% for the validation set.
# Both the shuffle and the split are seeded: without a seed, train_test_split reshuffles randomly,
# so each run (or kernel restart) would get a different partition and validation samples of one run
# could be training samples of the next.
# The result goes into a new variable (instead of overwriting `dataset`) so this cell can be re-run.
dataset_splits = dataset.shuffle(seed=42).train_test_split(test_size=0.05, seed=42)

train_data = dataset_splits["train"]  # ~95% of the rows
val_data = dataset_splits["test"]     # ~5% of the rows
# Features of both: ['input_ids', 'token_type_ids', 'attention_mask', 'positive_input_ids',
#                    'positive_attention_mask', 'negative_input_ids', 'negative_attention_mask']

# The data collator stacks samples into batches. mlm=False selects causal language modeling
# (no token masking). Our sequences are already padded to `context` tokens during preprocessing.
data_collator = transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Setup DataLoaders. The data was shuffled once above; with shuffle=False every epoch visits the
# training samples in that same order.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, collate_fn=data_collator, num_workers=0)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False, collate_fn=data_collator, num_workers=0)


In [None]:
# OPTIONAL: Test your DataLoaders
#batch = next(iter(train_loader))
#print(tokenizer.decode(batch['input_ids'][0]))  # this will show just the prompt, followed by padding tokens

# debug: batch['input_ids'] (shape is 1,1024)  
# debug: Ids to Text -> tokenizer.decode(batch['input_ids'][0]) - [0] because needs to address inside batch number dim



In [None]:
###############################################################
###############################################################
################### SETUP MODEL ###############################
###############################################################
###############################################################

# Import Llama based model
from llm import Llama, ModelArgs

# Load the pretrained model checkpoint (138 million parameters).
# - map_location="cpu": tensors are loaded on the CPU whatever device they were saved from;
#   the model is moved to `device` below (this also lets the cell run on CPU-only machines).
# - weights_only=False: besides tensors, the checkpoint stores a `config` object, which the
#   restricted weights-only unpickler (the default since PyTorch 2.6) refuses to load.
#   Full unpickling can execute arbitrary code, so only load checkpoints you trust.
checkpoint = torch.load(os.path.join(checkpoint_dir, "base_model.pt"), map_location="cpu", weights_only=False)

# Extract config from the pretrained model; what remains in `checkpoint` is the state_dict
config = checkpoint.pop("config")

# Instantiate ModelArgs with the necessary parameters (max_seq_len comes from our `context`)
model_args = ModelArgs(
    dim=config.hidden_size,
    n_layers=config.num_hidden_layers,
    n_heads=config.num_attention_heads,
    n_kv_heads=config.num_key_value_heads,
    vocab_size=config.vocab_size,
    norm_eps=config.rms_norm_eps,
    rope_theta=config.rope_theta,
    max_seq_len=context,
    dropout=config.attention_dropout,
    hidden_dim=config.intermediate_size,
    attention_bias=config.attention_bias,
    mlp_bias=config.mlp_bias
)

#ModelArgs(dim=768, n_layers=12, n_heads=12, n_kv_heads=12, vocab_size=16384, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-06, rope_theta=10000.0, max_seq_len=1024, dropout=0.0, hidden_dim=3072, attention_bias=False, mlp_bias=False)

model = Llama(model_args)  # Instantiate model
model.load_state_dict(checkpoint)  # Load pretrained weights

# Cast to the training precision and move to the training device in a single call.
# NOTE(review): the weights (and the AdamW state created from them) live in `dtype` (bfloat16 by default)
# with no float32 master copy, so very small updates can be rounded away; consider float32 weights
# + torch.autocast if training seems to stall.
model = model.to(device=device, dtype=dtype)

model.train()

# Torch.compile compiles a PyTorch model to an optimized version, aiming to improve runtime performance and efficiency.
# Disable if your system doesn't support it
if compile:
    print("[INFO] Compiling model")
    model = torch.compile(model)

# Print the number of parameters of the model
print(sum(p.numel() for p in model.parameters()) / 1e6, " Million parameters")

In [None]:
#######################################################
########## SETUP TRAINING AND OPTIMIZER ###############
#######################################################


# AdamW optimizer: uses the gradients to update the parameters, keeps running averages of the
# gradient and its square, and applies decoupled weight decay. The learning rate is driven by the scheduler.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    betas=(0.9, 0.95),           # decay rates of the running averages of the gradient and of its square
    eps=1e-8,                    # small constant added for numerical stability
    weight_decay=weight_decay,
    fused=(device == 'cuda'),    # fused kernel: combines several operations into one (CUDA only)
)

# Total number of optimizer steps: one per training batch, for every epoch
# (batch size is 1 by default, so one step per sample)
num_training_steps = epochs * len(train_loader)

# Learning-rate schedule: linear warmup during the first `lr_warmup_steps` steps (from 0 up to `lr`),
# then a cosine decay down to 0 at `num_training_steps`.
def lr_lambda(current_step, warmup_steps=None, total_steps=None, num_cycles=0.5):
    """Return the learning-rate multiplier (0..1) for `current_step`; used by LambdaLR.

    Args:
        current_step: number of scheduler steps taken so far (LambdaLR passes only this argument).
        warmup_steps: length of the linear warmup. Defaults to the global `lr_warmup_steps`.
        total_steps: total number of training steps. Defaults to the global `num_training_steps`.
        num_cycles: number of cosine periods over the decay phase; 0.5 (default) decays once from 1 to 0.

    The defaults are resolved at call time, so the globals are read exactly as before.
    """
    if warmup_steps is None:
        warmup_steps = lr_warmup_steps
    if total_steps is None:
        total_steps = num_training_steps

    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

# LambdaLR multiplies the optimizer's base lr by lr_lambda(step) on every scheduler.step()
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)



In [None]:
# Compute the average log probability of the last answer (positive or negative response),
# necessary for the ORPO log odds calculation
def compute_logps(prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits):
    """Return the mean per-token log-probability of the answer tokens of every sequence.

    Args:
        prompt_attention_mask: (batch, seq) 0/1 mask of the prompt tokens.
        chosen_inputs: (batch, seq) token ids of the full conversation (prompt + answer).
        chosen_attention_mask: (batch, seq) 0/1 mask of the full conversation.
        logits: (batch, seq, vocab) model outputs for `chosen_inputs`.
        Assumes right padding and that the prompt tokens are a prefix of `chosen_inputs` -- TODO confirm
        for other tokenizers.

    Returns:
        (batch,) float32 tensor: sum of the answer-token log-probabilities divided by the number of
        answer positions (values are <= 0).

    Always computed and returned in float32, even when the model runs in bfloat16: bfloat16 only has
    an 8-bit mantissa, so summing ~1000 log-probs loses precision. Worse, downstream
    log(1 - exp(logp)) would see exp(logp) round to exactly 1.0 for confident answers, making the
    odds ratio -inf / NaN.
    """
    # Logits at position t predict token t+1: the targets are the inputs shifted left by one, and the
    # last logit (which has no next token) is dropped.
    # chosen_mask[t] - prompt_mask[t+1] is 1 from the last prompt token (its logits predict the first
    # answer token) up to the last conversation token (whose target is the token right after the
    # conversation), and 0 elsewhere.
    mask = chosen_attention_mask[:, :-1] - prompt_attention_mask[:, 1:]

    # Log-probabilities in float32 (they are <= 0 simply because probabilities are <= 1)
    log_probs = logits[:, :-1, :].float().log_softmax(-1)  # (batch, seq-1, vocab)

    # At every position pick the log-prob of the actual next token. Masked-out positions gather
    # index 0 (a valid dummy index); their values are zeroed by the mask below.
    index = (mask * chosen_inputs[:, 1:]).unsqueeze(2)  # (batch, seq-1, 1)
    per_token_logps = torch.gather(log_probs, dim=2, index=index).squeeze(2)  # (batch, seq-1)

    # Keep only the answer positions and normalize by their count
    mask = mask.float()
    return (per_token_logps * mask).sum(dim=1) / mask.sum(dim=1)



In [None]:
#def compute_logps(self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits):
        #mask = chosen_attention_mask[:, :-1] - prompt_attention_mask[:, 1:]
        #per_token_logps = torch.gather(logits[:, :-1, :].log_softmax(-1), dim=2, 
        #                               index=(mask * chosen_inputs[:, 1:]).unsqueeze(2)).squeeze(2)
        #return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to(dtype=torch.float64) / mask.sum(dim=1).to(dtype=torch.float64)

# Toy example used in the videos to understand, step by step, how compute_logps works.
# It is wrapped in a triple-quoted string so it does not run. To execute it, delete the two
# ''' lines (the one right below this comment and the one at the end of the example).
# Note: for simplicity it uses 1-D tensors (no batch dimension) and gathers directly from the
# raw `logits` values (no log_softmax), unlike compute_logps.

'''
p_mask = torch.tensor([1,1,1,1,0,0,0,0,0,0,0])
c_inputs = torch.tensor([2,4,3,2,4,2,1,0,0,0,0])
c_mask = torch.tensor([1,1,1,1,1,1,1,0,0,0,0])
print(f"p_mask: {p_mask}")
print(f"c_inputs: {c_inputs}")
print(f"c_mask: {c_mask}")
mask = c_mask[:-1] - p_mask[1:] 
print(f"mask: {mask}")
logits = torch.tensor([
    [0.2, 0.4, 0.8, 0.1, 0.3],
    [0.2, 0.1, 0.5, 0.12, 0.31],
    [0.22, 0.44, 0.81, 0.13, 0.32],
    [0.29, 0.42, 0.84, 0.15, 0.32],
    [0.24, 0.48, 0.88, 0.17, 0.34],
    [0.21, 0.41, 0.81, 0.14, 0.33],
    [0.23, 0.43, 0.82, 0.16, 0.35],
    [0.2, 0.4, 0.8, 0.1, 0.3],
    [0.2, 0.1, 0.5, 0.12, 0.31],
    [0.22, 0.44, 0.81, 0.13, 0.32],
    [0.22, 0.44, 0.81, 0.13, 0.32]
])

print(f"c_inputs[1:]: {c_inputs[1:]}")
index=(mask * c_inputs[1:])
print(f"index: {index}")

# Expand dimensions for correct gather shape
index_expanded = index.unsqueeze(1)
print(f"index_expanded: {index_expanded}")

print("shapes: ",index_expanded.shape, logits.shape)
# Gather the values at the specified indices
gathered_values = torch.gather(logits[:-1,:], dim=1, index=index_expanded)
print(f"gathered: {gathered_values}")

# Squeeze to remove the unnecessary dimension
per_token_logps = gathered_values.squeeze(1)
print(f"per_token_logps: {per_token_logps}")

result = torch.mul(per_token_logps, mask)
print(f"result: {result}")
f1 = result.sum(dim=0)
f2 = mask.sum(dim=0)
print(f"f1: {f1}")
print(f"f2: {f2}")
final = f1 / f2
print(f"final: {final}")
'''

In [None]:
# Setup the iterators calculate_loss() draws its evaluation batches from.
# log_iters is intentionally NOT redefined here: it comes from the TRAINING PARAMETERS cell, so a
# value changed there actually takes effect (redefining it here used to silently override it).
val_iterator = iter(val_loader)
train_iterator = iter(train_loader)
eval_iters = 5  # batches averaged per split in calculate_loss(); use a small number, otherwise things get too slow

print(f"train loader size: {len(train_loader)}")
print(f"validation loader size: {len(val_loader)}")
print(f"number of training steps: {num_training_steps}")

In [None]:
@torch.no_grad()  # Evaluation only: no gradients are needed
def calculate_loss():
    """Estimate the ORPO loss, log odds and log-sigmoid ratio on training and validation data.

    For each split ('train', 'val'), `eval_iters` batches are drawn from the module-level
    `train_iterator` / `val_iterator` (recreated from their DataLoader when exhausted, so consecutive
    calls see new batches), and the per-batch metrics are averaged.

    Returns:
        (loss_mean, odds_mean, ratio_mean): three dicts keyed by 'train' and 'val' holding Python floats.

    The model is switched to eval mode for the evaluation and is always switched back to train mode,
    even if an exception is raised.
    """
    loss_mean, odds_mean, ratio_mean = {}, {}, {}
    model.eval()
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            odds = torch.zeros(eval_iters)
            ratios = torch.zeros(eval_iters)
            for k in range(eval_iters):
                batch = _next_eval_batch(split)
                loss, log_odds, ratio = _orpo_metrics(batch)
                losses[k] = loss.item()
                odds[k] = log_odds.mean().item()
                ratios[k] = ratio.mean().item()
            loss_mean[split] = losses.mean().item()
            odds_mean[split] = odds.mean().item()
            ratio_mean[split] = ratios.mean().item()
    finally:
        model.train()
    return loss_mean, odds_mean, ratio_mean


def _next_eval_batch(split):
    """Return the next batch of `split` ('train' or 'val'), restarting its iterator when it is exhausted."""
    global train_iterator, val_iterator
    if split == 'val':
        try:
            return next(val_iterator)
        except StopIteration:
            print("####### Resetting Validation Iterator")
            val_iterator = iter(val_loader)
            return next(val_iterator)
    try:
        return next(train_iterator)
    except StopIteration:
        print("####### Resetting Training Iterator")
        train_iterator = iter(train_loader)
        return next(train_iterator)


def _orpo_metrics(batch):
    """Run the model on one batch and return (loss, log_odds, ratio), mirroring the training step.

    loss: scalar = cross-entropy of the chosen answer (prompt and padding ignored) - alpha * mean(ratio)
    log_odds: per-sample log odds(chosen) - log odds(rejected)
    ratio: per-sample log(sigmoid(log_odds))
    The batch tensors are moved to `device` in place.
    """
    for key in ("positive_input_ids", "positive_attention_mask", "negative_input_ids",
                "negative_attention_mask", "attention_mask"):
        batch[key] = batch[key].to(device)

    neg_labels = batch['negative_input_ids'].clone()
    pos_labels = batch['positive_input_ids'].clone()

    # 1s only on the prompt positions; zeroing them leaves only the chosen answer in pos_labels
    prompt_mask = batch['attention_mask'] * batch['positive_attention_mask']
    pos_labels = pos_labels * prompt_mask.logical_not()
    pos_labels[pos_labels == 0] = tokenizer.pad_token_id  # replaces 0s with the pad token
    # -100 = ignored by the loss (prompt and padding).
    # NOTE: pad == eos here, so the EOS tokens closing each message are ignored as well.
    neg_labels[neg_labels == tokenizer.pad_token_id] = -100
    pos_labels[pos_labels == tokenizer.pad_token_id] = -100

    outputs_pos, loss_pos = model(batch['positive_input_ids'], pos_labels)
    outputs_neg, _ = model(batch['negative_input_ids'], neg_labels)  # only the logits are needed

    # Average log probabilities of the positive / negative answers (prompt masked out)
    pos_prob = compute_logps(
        prompt_attention_mask=batch['attention_mask'],
        chosen_inputs=batch["positive_input_ids"],
        chosen_attention_mask=batch['positive_attention_mask'],
        logits=outputs_pos
    )
    neg_prob = compute_logps(
        prompt_attention_mask=batch['attention_mask'],
        chosen_inputs=batch["negative_input_ids"],
        chosen_attention_mask=batch['negative_attention_mask'],
        logits=outputs_neg
    )

    # ORPO ODDS RATIO: log odds(p) = log p - log(1 - p) with p = exp(logp).
    # log(-expm1(logp)) equals log(1 - exp(logp)) but stays accurate when logp is close to 0,
    # where 1 - exp(logp) would round to 0 and give log(0) = -inf.
    log_odds = (pos_prob - neg_prob) - (torch.log(-torch.expm1(pos_prob)) - torch.log(-torch.expm1(neg_prob)))
    # logsigmoid == log(sigmoid(x)) without underflowing to -inf for very negative x
    ratio = F.logsigmoid(log_odds)

    # Total loss: standard cross entropy + weighted odds ratio (mean() matters when batch size > 1).
    # Kept in full precision here: this value is only logged.
    loss = torch.mean(loss_pos - (alpha * ratio).mean())
    return loss, log_odds, ratio


l, o, r = calculate_loss()
print(l, o, r)

In [None]:
################################################
################################################
############### ORPO TRAINING ##################
################################################
################################################

# Set when the loss becomes NaN/inf: training stops cleanly without saving a corrupted checkpoint
stop_training = False

try:
    for e in range(epochs):
        for i, batch in tqdm(enumerate(train_loader), total=len(train_loader), dynamic_ncols=True):

            optimizer.zero_grad(set_to_none=True)  # Reset gradients

            # Move batch data to device
            batch["positive_input_ids"] = batch["positive_input_ids"].to(device)
            batch["positive_attention_mask"] = batch["positive_attention_mask"].to(device)
            batch["negative_input_ids"] = batch["negative_input_ids"].to(device)
            batch["negative_attention_mask"] = batch["negative_attention_mask"].to(device)
            batch["attention_mask"] = batch["attention_mask"].to(device)

            # Debug: if anytime you want to look inside batch: #tokenizer.decode(batch["positive_input_ids"][0])

            # Get the token IDs of positive and negative responses
            neg_labels = batch['negative_input_ids'].clone()
            pos_labels = batch['positive_input_ids'].clone()

            # CALCULATE STANDARD CROSS ENTROPY LOSS (focused on POSITIVE responses)

            # Disable the loss on prompt tokens: the standard loss measures how well the model predicts
            # the next token of the positive (chosen) answer only, so the prompt part of the positive
            # IDs is masked out.
            mask = batch['attention_mask'] * batch['positive_attention_mask']  # 1s only in the prompt positions
            # in our case the line above is similar to just mask = batch['attention_mask'] (all our sequences have the same length)
            pos_labels = pos_labels * mask.logical_not()  # 0s where the prompt was, last answer preserved (padding tokens are EOS(2))

            pos_labels[pos_labels == 0] = tokenizer.pad_token_id  # replaces 0s with EOS(2)
            # -100 = ignored by the loss (prompt and padding). NOTE: pad == eos, so the EOS closing each message is ignored too.
            neg_labels[neg_labels == tokenizer.pad_token_id] = -100
            pos_labels[pos_labels == tokenizer.pad_token_id] = -100

            # Run model for positive response
            outputs_pos, loss_pos = model(batch['positive_input_ids'], pos_labels)  #  (1,1024) , (1,1024)
            # positive input ids have all IDs including the last answer, and the padding has EOS 2;
            # pos_labels have everything set to -100 except the IDs of the last answer

            # The negative loss is not used; only the negative logits are needed, for the per-token
            # log probabilities of the negative responses
            outputs_neg, loss_neg = model(batch['negative_input_ids'], neg_labels)

            # CALCULATE PER TOKEN LOG PROBS, necessary to calculate the ORPO ODDS ratio

            # Average log probabilities for the positive samples (prompt masked out)
            pos_prob = compute_logps(
                prompt_attention_mask=batch['attention_mask'],
                chosen_inputs=batch["positive_input_ids"],
                chosen_attention_mask=batch['positive_attention_mask'],
                logits=outputs_pos
            )
            # Average log probabilities for the negative samples (prompt masked out)
            neg_prob = compute_logps(
                prompt_attention_mask=batch['attention_mask'],
                chosen_inputs=batch["negative_input_ids"],
                chosen_attention_mask=batch['negative_attention_mask'],
                logits=outputs_neg
            )

            # CALCULATE ORPO ODDS RATIO: log odds(p) = log p - log(1 - p) with p = exp(logp).
            # log(-expm1(logp)) equals log(1 - exp(logp)) but stays accurate when logp is close to 0,
            # where 1 - exp(logp) would round to 0 and give log(0) = -inf.
            log_odds = (pos_prob - neg_prob) - (torch.log(-torch.expm1(pos_prob)) - torch.log(-torch.expm1(neg_prob)))
            # log(sigmoid(x)) computed without underflowing to -inf for very negative x
            ratio = F.logsigmoid(log_odds)

            # Final total loss: standard cross entropy loss + weighted odds ratio
            loss = torch.mean(loss_pos - (alpha*ratio).mean()).to(dtype=dtype)
            # notice that mean() is useful if batch size is larger than 1

            # log info every few iterations
            if i % log_iters == 0:

                # Calculate average losses
                loss_m, log_odds_m, ratio_m = calculate_loss()

                print(f"Epochs [{e}/{epochs}] Step: [{i}/{len(train_loader)}], train loss: {loss_m['train']:.4f}, val loss: {loss_m['val']:.4f}, Odds Ratio: {log_odds_m['train']:.4f}, val Odds Ratio: {log_odds_m['val']:.4f}")

                if wandb_log:
                    wandb.log({
                        "train_loss": loss_m['train'],
                        "val_loss": loss_m['val'],
                        "train_log_odds": log_odds_m['train'],
                        "val_log_odds": log_odds_m['val'],
                        "train_ratio": (alpha*ratio_m['train']),
                        "val_ratio": (alpha*ratio_m['val']),
                        #"pos_prob": pos_prob.mean().item(),
                        #"neg_prob": neg_prob.mean().item(),
                        #"lr": scheduler.get_last_lr()[0],
                    },
                    step = (e*len(train_loader) + i))

            # Check EVERY step (not only logging steps): a NaN/inf loss reaching backward() would write
            # NaN into the weights, and into the checkpoint saved at the end of the epoch.
            if not torch.isfinite(loss):
                print(f"Non-finite loss ({loss.item()}) at epoch {e}, step {i}: stopping training")
                stop_training = True
                break

            loss.backward()  # Calculate gradients
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)  # Clip gradients
            optimizer.step()  # Update model parameters
            scheduler.step()  # Update learning rate

        if stop_training:
            break

        # At the end of each epoch, save a checkpoint. With compile=True `model` is a compiled wrapper
        # whose state_dict keys carry an "_orig_mod." prefix; saving the underlying module keeps the
        # checkpoint loadable into a plain Llama exactly like base_model.pt.
        sd = getattr(model, "_orig_mod", model).state_dict()
        sd['config'] = config
        torch.save(sd, os.path.join(checkpoint_dir, f'{project_name}_{e+1}.pt'))

except KeyboardInterrupt:
    print("Training interrupted. Cleaning up...")

finally:
    # Release GPU memory
    torch.cuda.empty_cache()
    print("GPU memory released.")

if wandb_log:
    wandb.finish()
torch.cuda.empty_cache()

