In [None]:
!pip install wandb
!pip install datasets

import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from tokenizers import Tokenizer
from pathlib import Path

from torch.utils.data import Dataset, DataLoader
import wandb
from transformers import AutoModelForCausalLM, AutoTokenizer


In [None]:
# Start a Weights & Biases run; the metrics sent with wandb.log during
# training are recorded under this project.
# Optional arguments (entity, config, save_code, group, job_type) are left at
# their defaults.
wandb.init(project='GPTJ-DPO')

In [3]:
# Hyperparameters

batch_size = 2   # preference pairs per DataLoader batch
beta = 0.1       # scale applied to the log-ratio difference inside the DPO loss
max_lr = 1e-6    # learning rate handed to the AdamW optimizer

In [None]:
class DPO:
  """Direct Preference Optimization (DPO) loss for a policy / reference model pair.

  Args:
    ref_model: frozen reference causal LM; called as ``ref_model(**batch)`` and
      expected to return an object with a ``.logits`` attribute.
    sft_model: policy causal LM being trained (same calling convention).
    device: device identifier, stored for callers; not used by the loss itself.
    beta: scale applied to the difference of log-ratios inside the sigmoid.
    tokenizer: tokenizer, stored for callers; not used by the loss itself.
  """

  def __init__(self, ref_model, sft_model, device, beta, tokenizer):
    self.ref_model = ref_model
    self.sft_model = sft_model
    self.device = device
    self.beta = beta
    self.tokenizer = tokenizer
    # The reference model only supplies fixed log-probabilities, so keep it in
    # eval mode (disables dropout and similar train-time behaviour).
    self.ref_model.eval()

  @staticmethod
  def _sequence_log_prob(model, batch):
    """Return the summed log-probability of each sequence in ``batch`` under ``model``.

    ``batch`` must provide ``input_ids`` and ``attention_mask`` of shape
    (batch, seq_len). The result has shape (batch,).
    """
    logits = model(**batch).logits
    # A causal LM's logits at position t describe the distribution of token
    # t + 1, so drop the last position's logits and the first token's label to
    # line predictions up with their targets.
    log_probs = torch.nn.functional.log_softmax(logits[:, :-1, :], dim=-1)
    targets = batch['input_ids'][:, 1:]
    # gather picks, at every position, the log-probability of the token that
    # actually occurs next.
    token_log_probs = torch.gather(log_probs, -1, targets.unsqueeze(-1)).squeeze(-1)
    # Zero out padded targets so they do not contribute to the sequence score.
    target_mask = batch['attention_mask'][:, 1:].to(token_log_probs.dtype)
    return (token_log_probs * target_mask).sum(dim=-1)

  def DPOloss(self, datapoint):
    """Compute the batch-mean DPO loss.

    Args:
      datapoint: mapping with keys ``'chosen'`` and ``'rejected'``, each a
        tokenized batch (``input_ids`` + ``attention_mask``) that can be
        unpacked into the models with ``**``.

    Returns:
      Scalar tensor ``-logsigmoid(beta * ((pi_w - ref_w) - (pi_l - ref_l)))``
      averaged over the batch. Gradients flow only through ``sft_model``.
    """
    win_prompt = datapoint['chosen']
    lose_prompt = datapoint['rejected']

    # Reference scores are constants of the objective: no graph is needed.
    with torch.no_grad():
      win_log_ref = self._sequence_log_prob(self.ref_model, win_prompt)
      lose_log_ref = self._sequence_log_prob(self.ref_model, lose_prompt)

    win_log_sft = self._sequence_log_prob(self.sft_model, win_prompt)
    lose_log_sft = self._sequence_log_prob(self.sft_model, lose_prompt)

    # Log-ratio of policy to reference for the preferred / dispreferred answer.
    chosen_logratio = win_log_sft - win_log_ref
    rejected_logratio = lose_log_sft - lose_log_ref

    # Maximising the preference margin == minimising the negative log-sigmoid;
    # the mean reduces the per-pair losses to one scalar for backward().
    return -nn.functional.logsigmoid(self.beta * (chosen_logratio - rejected_logratio)).mean()



In [5]:
# Alternative: run `!huggingface-cli login` once instead of passing a token.
try:
    # On Colab the token is read from the notebook's secret store.
    from google.colab import userdata
    HF_TOKEN = userdata.get('HF_TOKEN')
except ImportError:
    # Outside Colab there is no `google.colab` package: fall back to the
    # HF_TOKEN environment variable (None when it is not set).
    import os
    HF_TOKEN = os.environ.get('HF_TOKEN')


In [6]:
# Target device for both models and for the tokenized batches.
device = "cuda:0"

In [7]:
# Make the GPU named by `device` the current CUDA device for this process.
torch.cuda.set_device(device)

In [8]:

MODEL_NAME = "Qwen/Qwen2-0.5B-Instruct"

# Policy model: this copy is the one updated by the optimizer.
sft_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=HF_TOKEN, device_map=device)

# Reference model: a second copy of the same checkpoint that must stay fixed,
# so switch off gradients and train-time behaviour explicitly.
ref_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=HF_TOKEN, device_map=device)
ref_model.requires_grad_(False)
ref_model.eval()

# A tokenizer holds no weights, so it takes no `device_map`.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
# Batches are padded later on; make sure a pad token exists.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [9]:
from datasets import load_dataset, Dataset

# Load the "train" and "test" splits of the preference dataset; the "test"
# split is used as the validation set.
train_dataset, val_dataset = (
    load_dataset("trl-lib/ultrafeedback_binarized", split=split_name, token=HF_TOKEN)
    for split_name in ("train", "test")
)


In [11]:
def dpo_collate_fn_merged_prompt(batch):
    """Turn a list of preference samples into tokenized chosen/rejected batches.

    Each sample must provide ``'chosen'`` and ``'rejected'`` as lists of chat
    messages (dicts with a ``'content'`` key) whose second entry is the
    response. The prompt text is taken from ``sample['prompt']`` when that key
    is present, otherwise from the first message of the chosen conversation.

    Uses the module-level ``tokenizer`` and ``device``.

    Returns:
        dict with keys ``'chosen'`` and ``'rejected'``, each the tokenizer
        output (``input_ids`` and ``attention_mask`` tensors) for the merged
        "Instruction / Output" texts, moved to ``device``.
    """
    merged_chosen_prompts = []
    merged_rejected_prompts = []

    for sample in batch:
        chosen_messages = sample['chosen']
        rejected_messages = sample['rejected']

        # Not every preference dataset ships an explicit prompt column; when it
        # is missing the prompt is the first message of the conversation.
        prompt = sample.get('prompt')
        if prompt is None:
            prompt = chosen_messages[0]['content']

        merged_chosen_prompts.append(
            "Instruction: " + prompt + "\n" + "Output: " + chosen_messages[1]['content'] + "\n"
        )
        merged_rejected_prompts.append(
            "Instruction: " + prompt + "\n" + "Output: " + rejected_messages[1]['content'] + "\n"
        )

    # Pad only to the longest sequence of the batch (still capped at 1024 by
    # truncation) instead of always padding to 1024 tokens: padded positions
    # are masked out of the loss, so they only cost compute and memory.
    tokenizer_kwargs = dict(max_length=1024, padding=True, truncation=True, return_tensors="pt")
    tokenized_win_prompt = tokenizer(merged_chosen_prompts, **tokenizer_kwargs).to(device)
    tokenized_lose_prompt = tokenizer(merged_rejected_prompts, **tokenizer_kwargs).to(device)

    return {
        'chosen': tokenized_win_prompt,    # tokenized prompt + chosen response
        'rejected': tokenized_lose_prompt  # tokenized prompt + rejected response
    }

In [12]:


from torch.utils.data import DataLoader

# One shuffled loader per split; both batch with the DPO collate function so
# every batch arrives already tokenized.
train_loader, val_loader = (
    DataLoader(
        split_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=dpo_collate_fn_merged_prompt,
    )
    for split_dataset in (train_dataset, val_dataset)
)

In [None]:

# Training-length and evaluation constants
total_steps = 3000   # number of optimizer updates
eval_iters = 20      # batches averaged per split in estimate_loss

# Optimizer setup (constant learning rate; no scheduler is used)
sft_model.train()
optimizer = torch.optim.AdamW(sft_model.parameters(), lr=max_lr)

# Loss object pairing the frozen reference with the trainable policy
dpo_loss = DPO(ref_model, sft_model, device, beta, tokenizer)

# Iterators over the two loaders
val_iterator = iter(val_loader)
train_itertaor = iter(train_loader)

@torch.inference_mode()
def estimate_loss():
    """Estimate the DPO loss on the train and validation data.

    For each split, up to ``eval_iters`` batches are drawn from a fresh pass
    over the corresponding loader (``train_loader`` / ``val_loader``) and their
    losses are averaged. Starting a new pass on every call means repeated
    evaluations can never exhaust a shared iterator and abort training with
    ``StopIteration``.

    Uses the module-level ``sft_model``, ``dpo_loss``, ``eval_iters`` and the
    two loaders. The policy is put in eval mode for the measurement and is
    always switched back to train mode, even if an error occurs.

    Returns:
        dict mapping ``'train'`` and ``'val'`` to a 0-dim tensor holding the
        mean loss (NaN if a loader yields no batches).
    """
    out = {}
    sft_model.eval()
    try:
        for split, loader in (('train', train_loader), ('val', val_loader)):
            losses = torch.zeros(eval_iters)
            seen = 0
            # zip stops after eval_iters batches, or earlier if the loader is shorter.
            for k, datapoint in zip(range(eval_iters), loader):
                losses[k] = dpo_loss.DPOloss(datapoint).item()
                seen = k + 1
            # Average only over the batches that were actually evaluated.
            out[split] = losses[:seen].mean()
    finally:
        sft_model.train()
    return out

In [None]:

# Train the model
from tqdm import tqdm

train_iterator = iter(train_loader)

for step in tqdm(range(total_steps)):

    # Evaluate every `eval_iters` steps (never at step 0) and on the last step.
    periodic_eval = step != 0 and step % eval_iters == 0
    if periodic_eval or step == total_steps - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        wandb.log({
            "step": step,
            "training_loss": losses['train'],
            "val_loss": losses['val']
        })

    # One optimisation step on the next preference batch.
    batch = next(train_iterator)
    loss = dpo_loss.DPOloss(batch)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
