In [None]:
!pip install datasets
!pip install wandb
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from tokenizers import Tokenizer
from pathlib import Path

from torch.utils.data import Dataset, DataLoader
import wandb
from transformers import AutoModelForCausalLM, AutoTokenizer


In [None]:
import wandb

from kaggle_secrets import UserSecretsClient
# Read the Weights & Biases API key from the Kaggle secrets client (stored under
# the label 'API_KEY') so the key itself never appears in the notebook.
secret_value = UserSecretsClient().get_secret('API_KEY')
# Programmatic login; replaces running `!wandb login` interactively in a shell cell.
wandb.login(key=secret_value)

In [None]:

# Start a Weights & Biases run. Only the project name is set; the remaining
# options below are left commented out, so W&B falls back to its defaults.
wandb.init(
            # entity = 'rajceo2031',
                        project = 'Phi2-ORPO',
                        # config = CFG,
                        # save_code = True,
                        #group = 'ANN',
                        #job_type = 'train'
)

In [4]:
# Hyperparameters shared by the cells below.

# Number of preference pairs per batch handed to the DataLoaders.
batch_size = 2
# Weight of the odds-ratio term when it is added to the cross-entropy loss.
beta = 0.2
# Learning rate passed to AdamW (no scheduler is configured in this notebook).
max_lr = 8e-6
# AdamW (beta1, beta2) coefficients.
betas = (0.95, 0.99)
# AdamW weight decay.
weight_decay=0.1
# Use the first CUDA device when one is available, otherwise the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [5]:
class ORPO:
    """Odds-ratio preference term of ORPO for a causal language model.

    The term rewards the model for assigning higher odds to the ``chosen``
    sequence than to the ``rejected`` one:

        loss = -log(sigmoid(log_odds(chosen) - log_odds(rejected)))

    where ``log_odds(y) = log p(y) - log(1 - p(y))`` and ``log p(y)`` is the
    length-normalised (mean per-token) log-likelihood of the sequence.

    Args:
        model: Callable that accepts ``input_ids`` / ``attention_mask`` keyword
            arguments and returns an object with a ``logits`` attribute of
            shape ``(batch, seq_len, vocab)``.
        device: Device identifier; stored for callers, not used internally.
        tokenizer: Tokenizer; stored for callers, not used internally.
    """

    def __init__(self, model, device, tokenizer):
        self.model = model
        self.device = device
        self.tokenizer = tokenizer

    def _sequence_log_prob(self, encoded):
        """Return the mean per-token log-likelihood of every sequence in ``encoded``.

        ``encoded`` must provide ``input_ids`` and ``attention_mask`` of shape
        ``(batch, seq_len)``. The result has shape ``(batch,)``.
        """
        input_ids = encoded['input_ids']
        attention_mask = encoded['attention_mask']

        logits = self.model(**encoded).logits

        # A causal LM's logits at position t are the prediction for token t + 1,
        # so drop the last prediction and the first token to align the two.
        token_log_probs = torch.nn.functional.log_softmax(logits[:, :-1, :], dim=-1)
        targets = input_ids[:, 1:]
        target_mask = attention_mask[:, 1:].to(token_log_probs.dtype)

        # Pick the log-probability the model gave to each token that actually occurred.
        picked = torch.gather(token_log_probs, -1, targets.unsqueeze(-1)).squeeze(-1)

        # Average over real (non-padding) tokens only. Normalising by length keeps
        # exp(log p) away from underflow and stops long answers being penalised
        # merely for being long. clamp(min=1) guards an all-padding row.
        token_count = target_mask.sum(dim=-1).clamp(min=1)
        return (picked * target_mask).sum(dim=-1) / token_count

    @staticmethod
    def _log_odds(log_prob):
        """Return ``log(p / (1 - p))`` given ``log p``."""
        # Keep log p strictly below 0 so that log(1 - p) stays finite when the
        # model is (numerically) certain about every token.
        log_prob = log_prob.clamp(max=-torch.finfo(log_prob.dtype).eps)
        return log_prob - torch.log1p(-torch.exp(log_prob))

    def ORloss(self, datapoint):
        """Compute the batch-mean odds-ratio loss.

        Args:
            datapoint: Mapping with ``'chosen'`` and ``'rejected'`` entries, each
                providing ``input_ids`` and ``attention_mask`` for the same batch
                of prompts. The whole sequence (prompt included) is scored,
                because the batch carries no prompt/response boundary.

        Returns:
            Scalar tensor ``-log sigmoid(log_odds(chosen) - log_odds(rejected))``
            averaged over the batch.
        """
        chosen_log_prob = self._sequence_log_prob(datapoint['chosen'])
        rejected_log_prob = self._sequence_log_prob(datapoint['rejected'])

        log_odds_ratio = self._log_odds(chosen_log_prob) - self._log_odds(rejected_log_prob)
        return -torch.nn.functional.logsigmoid(log_odds_ratio).mean()



In [None]:

# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

# Fetch the tokenizer and the causal-LM weights for the same checkpoint so their
# vocabularies match.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m", token='...')
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", token='...')
# Use '[PAD]' as the padding token; later cells rely on tokenizer.pad_token_id
# as the cross-entropy ignore_index.
# NOTE(review): if '[PAD]' is a new vocabulary entry, check whether
# model.resize_token_embeddings(len(tokenizer)) is needed — confirm against the
# checkpoint's embedding size.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})


In [11]:
from datasets import load_dataset, Dataset

# Seed for the train/validation split. Without it the validation set changes on
# every kernel restart, so validation losses are not comparable between runs.
SPLIT_SEED = 42

# Hold out 10% of the preference pairs for validation.
dataset_splits = load_dataset(
    "argilla/ultrafeedback-binarized-preferences-cleaned", split="train", token='...'
).train_test_split(test_size=0.1, seed=SPLIT_SEED)
train_dataset, val_dataset = dataset_splits['train'], dataset_splits['test']

In [12]:
def orpo_collate_fn_merged_prompt(batch):
    """Turn a list of preference samples into tokenized chosen/rejected batches.

    Each sample must provide a ``'prompt'`` string plus ``'chosen'`` and
    ``'rejected'`` sequences whose element at index 1 has a ``'content'`` string
    (presumably the assistant turn of a two-message conversation — verify
    against the dataset schema). Prompt and response are rendered as
    ``"Instruction: <prompt>\\nOutput: <response>\\n"``.

    Returns:
        dict with keys ``'chosen'`` and ``'rejected'``; each value is the
        tokenizer output (padded / truncated to 1024 tokens, PyTorch tensors)
        moved to the module-level ``device``.
    """

    def render(instruction, conversation):
        # Index 1 is the message whose text is used as the response.
        return "Instruction: " + instruction + "\n" + "Output: " + conversation[1]['content'] + "\n"

    chosen_texts, rejected_texts = [], []
    for example in batch:
        instruction = example['prompt']
        chosen_texts.append(render(instruction, example['chosen']))
        rejected_texts.append(render(instruction, example['rejected']))

    # Both sides are tokenized with identical settings: fixed-length padding so
    # every batch has the same sequence length.
    encode_options = dict(max_length=1024, padding='max_length', truncation=True, return_tensors="pt")
    chosen_encoded = tokenizer(chosen_texts, **encode_options).to(device)
    rejected_encoded = tokenizer(rejected_texts, **encode_options).to(device)

    return {'chosen': chosen_encoded, 'rejected': rejected_encoded}

In [13]:



from torch.utils.data import DataLoader

# Both loaders tokenize on the fly through the collate function, which also
# moves each batch to `device`.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=orpo_collate_fn_merged_prompt)
# The validation loader is shuffled as well, so each pass over it sees the
# batches in a different order.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, collate_fn=orpo_collate_fn_merged_prompt)

In [None]:
# Move the model to the training device; the type is printed before and after
# to check that the object is still the same model class.
print(type(model))
model = model.to(device)
# model = torch.compile(model)

# Optimizer setup (no learning-rate scheduler is created here).
print(type(model))
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, betas=betas, weight_decay=weight_decay)

# Number of optimizer steps: 8 times the number of batches in one pass over train_loader.
total_steps = 8 * len(train_loader)
# Used both as the evaluation interval (in steps) and as the number of
# validation batches averaged per evaluation.
eval_iters = 20

# Helper that computes the odds-ratio term added to the cross-entropy loss.
orpo = ORPO(model, device, tokenizer)


model.train()

# Module-level iterators consumed with next() by the evaluation and training cells.
val_iterator = iter(val_loader)
train_iterator = iter(train_loader)

@torch.inference_mode()
def estimate_loss():
    """Estimate the validation loss over ``eval_iters`` batches.

    The loss for each batch is the next-token cross-entropy on the chosen
    sequence (padding ignored) plus ``beta`` times the ORPO odds-ratio term,
    i.e. the same objective the training loop optimizes.

    Batches are drawn from the module-level ``val_iterator``. When that iterator
    is exhausted it is re-created from ``val_loader``, so evaluation can be
    called any number of times.

    Returns:
        dict mapping ``'val'`` to a 0-dim tensor holding the mean loss.
    """
    global val_iterator

    out = {}
    model.eval()
    try:
        for split in ['val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                try:
                    text = next(val_iterator)
                except StopIteration:
                    # One full pass over the validation set is done; start another.
                    val_iterator = iter(val_loader)
                    text = next(val_iterator)

                chosen = text['chosen']
                # Position t predicts token t + 1: drop the last prediction and
                # the first token so logits and targets line up.
                logits = model(**chosen).logits[..., :-1, :].contiguous()
                targets = chosen['input_ids'][..., 1:].contiguous()

                n_sequences, n_positions, vocab_size = logits.shape
                logits = logits.view(n_sequences * n_positions, vocab_size)
                targets = targets.view(n_sequences * n_positions)

                sft_loss = torch.nn.functional.cross_entropy(
                    logits, targets, ignore_index=tokenizer.pad_token_id
                )
                loss = sft_loss + beta * orpo.ORloss(text)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Always hand the model back in training mode, even if evaluation failed.
        model.train()
    return out

In [None]:

#Train the  model
from tqdm import tqdm



# train_iterator = iter(train_loader)

for step in tqdm(range(total_steps)):

    # Periodic validation (and once more on the very last step).
    if (step % eval_iters == 0 and step != 0) or step == total_steps - 1:
        losses = estimate_loss()
        print(f"step {step}: val loss {losses['val']:.4f}")
        wandb.log({
            "step": step,
            "val_loss": losses['val']
        })

    # total_steps spans several passes over the data, while a DataLoader
    # iterator yields a single pass: start a new one when it runs out.
    try:
        text = next(train_iterator)
    except StopIteration:
        train_iterator = iter(train_loader)
        text = next(train_iterator)

    chosen = text['chosen']

    # Position t predicts token t + 1: drop the last prediction and the first
    # token so logits and targets line up.
    logits = model(**chosen).logits[..., :-1, :].contiguous()
    targets = chosen['input_ids'][..., 1:].contiguous()

    # Flatten to (tokens, vocab) / (tokens,) without rebinding the global
    # `batch_size` hyperparameter.
    vocab_size = logits.size(-1)
    sft_loss = torch.nn.functional.cross_entropy(
        logits.view(-1, vocab_size), targets.view(-1), ignore_index=tokenizer.pad_token_id
    )

    # ORPO objective: supervised loss on the chosen answer plus the weighted
    # odds-ratio preference term.
    loss = sft_loss + beta * orpo.ORloss(text)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()

    # Clip the gradients of THIS step, after backward() and before step(), so
    # the update that is applied is the clipped one. clip_grad_norm_ returns the
    # total norm measured before clipping.
    total_norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    total_norm_after = torch.norm(
        torch.stack([
            torch.norm(p.grad.detach(), 2)
            for p in model.parameters()
            if p.grad is not None  # parameters that received no gradient are skipped
        ]),
        2,
    )
    print(f"Gradient Norm Before Clipping: {total_norm_before.item():.4f}")
    print(f"Gradient Norm After Clipping: {total_norm_after.item():.4f}")

    optimizer.step()

    wandb.log({
        "training_loss": loss.item()
    })