In [None]:
# --- configuration ---------------------------------------------------------
import os
import logging
import warnings

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import DPOConfig, DPOTrainer

BASE_MODEL = 'google/gemma-3-270m-it'   # checkpoint used as policy (and reference)
DATASET_NAME = 'Dahoas/rm-static'       # preference pairs used for DPO

# os.path.abspath("") resolves to the current working directory, so FILE_PATH
# is the *parent* of the notebook's working directory.
FILE_PATH = os.path.dirname(os.path.abspath(""))

# Keep the output quiet: only CRITICAL log records, and no warnings at all.
logging.basicConfig(level=logging.CRITICAL)
warnings.filterwarnings(action='ignore')


In [None]:
# Preference dataset: only the first 5% of the train split, cached under FILE_PATH.
train_dataset = load_dataset(DATASET_NAME, split='train[:5%]', cache_dir=FILE_PATH)

print(type(train_dataset))


In [None]:
# Peek at the first raw training record.
train_dataset[0]

In [None]:
# Load the tokenizer and the policy model from the same checkpoint, so their
# special-token ids agree. Eager attention is requested explicitly.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, cache_dir=FILE_PATH)
policy_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    cache_dir=FILE_PATH,
    attn_implementation='eager',
)


In [None]:
# Visualize the tokenizer: show its cached config file (if found), then check
# that the special-token ids agree between the tokenizer, the model config and
# the generation config.
import glob
import json

# The HF hub cache keeps model files under
# <cache_dir>/models--<org>--<name>/snapshots/<revision>/. The previous lookup
# used the *dataset* cache folder ('Dahoas__rm-static'), which is not where the
# model's tokenizer_config.json is stored, so nothing was ever printed.
model_cache_dir = os.path.join(FILE_PATH, 'models--' + BASE_MODEL.replace('/', '--'))
cfg_matches = sorted(glob.glob(os.path.join(model_cache_dir, 'snapshots', '*', 'tokenizer_config.json')))
cfg_path = cfg_matches[0] if cfg_matches else os.path.join(model_cache_dir, 'tokenizer_config.json')
print('cfg path', cfg_path)
if os.path.isfile(cfg_path):
    print("tokenizer_config.json:")
    # `with` closes the handle deterministically (json.load(open(...)) leaked it).
    with open(cfg_path, encoding='utf-8') as cfg_file:
        tok_cfg = json.load(cfg_file)
    # 'added_tokens_decoder' can be very large; report its size instead of dumping it.
    added_decoder = tok_cfg.pop('added_tokens_decoder', None)
    print(json.dumps(tok_cfg, indent=2))
    if added_decoder is not None:
        print(f"added_tokens_decoder: {len(added_decoder)} entries (omitted)")

# Quick look at how two sample strings are tokenized.
prompt = 'Hi Chat, How are you doing ?'
prompt2 = 'Hello world ! I am here'
out = tokenizer(prompt)
out2 = tokenizer(prompt2)
print(out)
print(out2)

# Make generation use the tokenizer's BOS / PAD ids.
gen_cfg = getattr(policy_model, "generation_config", None)
if gen_cfg is not None:
    gen_cfg.bos_token_id = tokenizer.bos_token_id
    gen_cfg.pad_token_id = tokenizer.pad_token_id

print('\n\n---- TOKENIZER AND MODELS ----')
print("tokenizer vocab_size:", len(tokenizer))
print("model.config.vocab_size:", getattr(policy_model.config, "vocab_size", None))
# len(tokenizer) and the embedding row count may differ (embedding matrices are
# often padded); any token id >= the number of rows would be out of range.
print("embeddings rows:", policy_model.get_input_embeddings().weight.shape[0])

print("tokenizer special map:", tokenizer.special_tokens_map)
print("tokenizer ids: bos, pad, eos =", tokenizer.bos_token_id, tokenizer.pad_token_id, tokenizer.eos_token_id)
print("model ids:     bos, pad, eos =", policy_model.config.bos_token_id, policy_model.config.pad_token_id, policy_model.config.eos_token_id)
print("generation_config:", getattr(policy_model, "generation_config", None))


The model config only stores the *ids* of the BOS / EOS / PAD tokens that were used when the model was trained (not the token strings). If you plan to fine-tune with a different tokenizer, make sure these ids match — and whenever possible, use the same tokenizer the model was trained with.

In [None]:
# GenerationConfig expects integer token ids. The previous value ['1'] was a list
# holding the *string* '1', which is not a valid token id.
# NOTE(review): this keeps the original intent of stopping on id 1 only —
# compare with tokenizer.eos_token_id printed above and confirm.
policy_model.generation_config.eos_token_id = [1]
print("generation_config:", getattr(policy_model, "generation_config", None))


In [None]:
# Inspect the highest id the tokenizer knows about, then classify that token.
extra_id = len(tokenizer) - 1
extra_token = tokenizer.convert_ids_to_tokens(extra_id)
print("extra_id:", extra_id, "extra_token:", extra_token)

# special token? / user-added tokens (dict token -> id) / part of the vocabulary?
for label, value in (
    ("is special token?", extra_token in tokenizer.all_special_tokens),
    ("added tokens:", tokenizer.get_added_vocab()),
    ("appears in vocab?", extra_token in tokenizer.get_vocab()),
):
    print(label, value)


In [None]:
# Tokenizer class recorded in the model config, if any.
policy_model.config.tokenizer_class


In [None]:

# Baseline: run TRL's DPOTrainer on the preference data.
config = DPOConfig(
    do_train=True,
    per_device_train_batch_size=8,
    learning_rate=5e-8,   # same value as the from-scratch loop later in the notebook
    bf16=False,
    fp16=False,
    logging_strategy='steps',
    logging_steps=2,
)

trainer = DPOTrainer(
    model=policy_model,
    args=config,
    train_dataset=train_dataset,
    # Reuse the tokenizer already loaded from FILE_PATH. Without it, TRL
    # (depending on the version) loads its own processor from the model id,
    # bypassing the cache_dir used above.
    processing_class=tokenizer,
)

trainer.train()


In [None]:
# Free the TRL run before the from-scratch section. The trainer keeps its own
# reference to the model, so deleting `policy_model` alone would not release it.
import gc

for _name in ("trainer", "policy_model", "tokenizer"):
    globals().pop(_name, None)   # tolerate names that were never created (e.g. a failed cell)
gc.collect()


## DPO from scratch 

In [None]:
# Make sure no model/tokenizer from an earlier section is still alive.
# `model` is only defined further down (and `tokenizer` may already be gone), so a
# plain `del model, tokenizer` raises NameError on Restart & Run All.
import gc

for _name in ("model", "tokenizer"):
    globals().pop(_name, None)
gc.collect()

In [None]:
# Let's now implement DPO ourselves and see what actually happens under the hood.
#
# Background: preference / reward model (RLHF)
# --------------------------------------------
# A reward model learns to score how good a response is for a given prompt.
#
# Dataset example
#   Prompt   : Why is the color of the sky blue?
#   Chosen   : It's because air molecules scatter blue wavelengths the most.
#   Rejected : Because the sky likes the color blue.
#
# The reward model is a language model with a linear head on top that maps
# {prompt + response} to a scalar score. It is typically trained on *pairs*, so that
#   score(prompt + chosen) > score(prompt + rejected)
# using a pairwise loss such as -log sigmoid(score_chosen - score_rejected),
# rather than by regressing absolute labels 1 (chosen) / 0 (rejected).
#
# In PPO-based RLHF this reward model is trained first, as a separate step, and
# is then used to score the policy's samples inside the PPO loop; PPO's value
# head (critic) is a different component. DPO, implemented below, skips the
# explicit reward model and optimizes the policy directly on the preference pairs.

import os
import gc
import torch
import torch.nn as nn
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    PreTrainedModel,
    GenerationConfig
)
from transformers.modeling_outputs import SequenceClassifierOutput
from trl import AutoModelForCausalLMWithValueHead
from transformers.modeling_outputs import ModelOutput
import torch.nn.functional as F

from dotenv import load_dotenv
# Load environment variables from a local .env file, if present
# (presumably the Hugging Face token for the gated Gemma weights — confirm).
load_dotenv()

# ------------------ user choice ------------------
BASE_MODEL = "google/gemma-3-270m-it"
DATASET_LINK = "Dahoas/rm-static"

# Parent directory of the current working directory; used as the HF cache dir.
FILE_PATH = os.path.dirname(os.path.abspath(""))
print(FILE_PATH)




In [None]:

# Fresh model / tokenizer / data for the from-scratch DPO section.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    cache_dir=FILE_PATH,
    # Same setting as the TRL section earlier in the notebook; transformers
    # recommends eager attention (not SDPA) when training Gemma 3 models.
    attn_implementation='eager',
)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, cache_dir=FILE_PATH)

# Only the first 20 training pairs are used.
dataset = load_dataset(DATASET_LINK, split="train[:20]", cache_dir=FILE_PATH)



In [None]:

def tokenizer_mapping(example):
    """Tokenize one preference pair into the columns used by the DPO loop.

    Builds two sequences, ``<bos> + prompt + chosen + <eos>`` and
    ``<bos> + prompt + rejected + <eos>``, plus for each:

    * an attention mask of all ones (no padding inside a single example), and
    * a response mask that is 0 over the prompt and 1 over the response, so only
      response tokens are scored later.

    Args:
        example: one dataset row with 'prompt', 'response' and 'rejected' strings.

    Returns:
        dict of int lists keyed by 'correct_input_ids', 'rejected_input_ids',
        'attention_mask_corrected', 'attention_mask_rejected',
        'resp_mask_corrected' and 'resp_mask_rejected'.
    """
    prompt = example['prompt']
    # NOTE(review): 'response' is used as the chosen answer; in Dahoas/rm-static it
    # is expected to equal the 'chosen' column — confirm.
    response = example['response']
    rejected = example['rejected']

    # Gemma sequences start with <bos> (see the chat-template example string in
    # this notebook); add it explicitly because add_special_tokens=False below.
    bos_ids = [tokenizer.bos_token_id] if tokenizer.bos_token_id is not None else []
    eos_ids = [tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else []

    prompt_ids = bos_ids + tokenizer(prompt, add_special_tokens=False).input_ids
    # Append EOS by id instead of concatenating the EOS string to the text, so it
    # cannot be split into ordinary sub-word pieces.
    response_ids = tokenizer(response, add_special_tokens=False).input_ids + eos_ids
    rejected_ids = tokenizer(rejected, add_special_tokens=False).input_ids + eos_ids

    question_correct_ids = prompt_ids + response_ids
    question_rejected_ids = prompt_ids + rejected_ids

    resp_mask_corrected = [0] * len(prompt_ids) + [1] * len(response_ids)
    resp_mask_rejected = [0] * len(prompt_ids) + [1] * len(rejected_ids)

    assert len(question_correct_ids) == len(resp_mask_corrected), "These should be equal but are unequal !"
    assert len(question_rejected_ids) == len(resp_mask_rejected), "These should be equal but are unequal !"

    # datasets.map stores the result as Arrow lists anyway, so plain lists suffice.
    return {
        'correct_input_ids': question_correct_ids,
        'rejected_input_ids': question_rejected_ids,
        'attention_mask_corrected': [1] * len(question_correct_ids),
        'attention_mask_rejected': [1] * len(question_rejected_ids),
        'resp_mask_corrected': resp_mask_corrected,
        'resp_mask_rejected': resp_mask_rejected,
    }

revised_dataset = dataset.map(tokenizer_mapping)

revised_dataset


In [None]:
# Raw (untokenized) first record, for comparison with the tokenized columns.
dataset[0]

In [None]:
# Sanity check on the tokenized columns: for both the chosen and the rejected
# sequence, ids / attention mask / response mask must have the same length.
# (The previous version asked for 'input_ids' / 'attention_mask' columns that
# tokenizer_mapping never creates, raising TypeError / KeyError.)
sample = revised_dataset[0]
lengths = {
    key: len(sample[key])
    for key in (
        'correct_input_ids', 'attention_mask_corrected', 'resp_mask_corrected',
        'rejected_input_ids', 'attention_mask_rejected', 'resp_mask_rejected',
    )
}
assert lengths['correct_input_ids'] == lengths['attention_mask_corrected'] == lengths['resp_mask_corrected']
assert lengths['rejected_input_ids'] == lengths['attention_mask_rejected'] == lengths['resp_mask_rejected']
lengths, len(revised_dataset)

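**DPO**: for every prompt we have a *chosen* response and a *rejected* response. DPO compares how likely the model is to produce the chosen response with how likely it is to produce the rejected one.

- $x$: input / prompt
- $y_c$: chosen (correct) output
- $y_r$: rejected output

> Chosen: $P(y_c \mid x)$, rejected: $P(y_r \mid x)$

The sequence probability is the product of the individual output-token probabilities, so its log is the **sum of the token log-probabilities**:

$$\log P(y_c \mid x) = \sum_{t} \log P(y_{c,t} \mid x, y_{c,<t})$$

> Rejected: the same formula applied to $y_r$, i.e. $\log P(y_r \mid x) = \sum_{t} \log P(y_{r,t} \mid x, y_{r,<t})$.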


In [None]:
import torch.nn.functional as F


def next_token_prediction_values(logp: torch.Tensor, idx: int) -> torch.Tensor:
    """Return the next-token log-distribution predicted at position ``idx``.

    Args:
        logp: log-probabilities indexed as ``(batch, position, vocab)``.
        idx: sequence position.

    Returns:
        ``logp[:, idx, :]`` — the distribution over the token that follows ``idx``.
    """
    return logp[:, idx, :]


def seq_logprob(model, input_ids, attention_mask, resp_mask, reduction: str = "sum"):
    """Log-probability that ``model`` assigns to the response part of one sequence.

    The sequence is ``prompt + response``. For a causal LM, the logits at
    position t predict token t+1, so token i is scored with the distribution
    produced at position i-1. Only tokens where ``resp_mask`` is 1 are counted.

    Args:
        model: causal LM called as ``model(input_ids=(1, T), attention_mask=(1, T))``;
            expected to return an object with ``.logits`` of shape (1, T, vocab).
        input_ids: T token ids (list or 1-D tensor).
        attention_mask: length-T attention mask (list or 1-D tensor).
        resp_mask: length-T mask, 1 on response tokens and 0 on prompt tokens.
        reduction: "sum" (default) returns log P(response | prompt), the
            quantity DPO uses; "mean" returns the per-token average
            (length-normalized; the previous behavior).

    Returns:
        0-d tensor. Gradients flow into ``model`` unless called under no_grad.
        An empty response selection gives 0.

    Raises:
        ValueError: if ``reduction`` is not "sum" or "mean".
    """
    if reduction not in ("sum", "mean"):
        raise ValueError(f"reduction must be 'sum' or 'mean', got {reduction!r}")

    # Rows from `datasets` arrive as lists; tensors are accepted too. Put
    # everything on the model's device (gather needs int64 indices).
    device = next(model.parameters()).device
    input_ids = torch.as_tensor(input_ids, dtype=torch.long, device=device)
    attention_mask = torch.as_tensor(attention_mask, dtype=torch.long, device=device)
    resp_mask = torch.as_tensor(resp_mask, device=device).bool()

    logits = model(input_ids=input_ids.unsqueeze(0), attention_mask=attention_mask.unsqueeze(0)).logits[0]

    # Shift by one: logits[t] predicts token t+1. Token 0 has no prediction, so
    # a response mask starting at index 0 is ignored there instead of wrapping
    # around to the last position.
    logp = F.log_softmax(logits[:-1].float(), dim=-1)
    token_logp = logp.gather(-1, input_ids[1:].unsqueeze(-1)).squeeze(-1)
    picked = token_logp[resp_mask[1:]]

    if reduction == "sum":
        return picked.sum()
    # Avoid NaN when nothing is selected.
    return picked.sum() / max(picked.numel(), 1)


# Example call on one tokenized row:
# seq_logprob(model, revised_dataset[0]['correct_input_ids'],
#             revised_dataset[0]['attention_mask_corrected'],
#             revised_dataset[0]['resp_mask_corrected'])

# For reference, the Gemma chat template renders a conversation like:
# "<bos><start_of_turn>user\nHello, how are you?<end_of_turn>\n<start_of_turn>model\nI'm doing great. How can I help you today?<end_of_turn>\n<start_of_turn>user\nI'd like to show off how chat templating works!<end_of_turn>\n<start_of_turn>model\n"

In [None]:

# Policy = the model being trained; reference = frozen copy of the starting
# weights that DPO measures the policy against.
# Tokenized dataset features: ['prompt', 'response', 'chosen', 'rejected',
#   'correct_input_ids', 'rejected_input_ids', 'attention_mask_corrected',
#   'attention_mask_rejected', 'resp_mask_corrected', 'resp_mask_rejected']

import torch.nn as nn
import torch.nn.functional as F


import copy
policy_model = model
# from_pretrained returns the model in eval mode; switch the policy to training mode.
policy_model.train()
reference_model = copy.deepcopy(model).eval()

# Freeze the reference model completely.
reference_model.requires_grad_(False)
for param in reference_model.parameters():
    param.grad = None

optim = torch.optim.AdamW(policy_model.parameters(), lr=5e-8)

def dpo_loss(policy_correct_log, policy_rejected_log, reference_correct_log, reference_rejected_log, beta=0.1):
    """DPO loss for one or more preference pairs.

        loss = -log sigmoid( beta * [ (log pi(y_c|x) - log pi_ref(y_c|x))
                                      - (log pi(y_r|x) - log pi_ref(y_r|x)) ] )

    Args:
        policy_correct_log: policy log-prob of the chosen response.
        policy_rejected_log: policy log-prob of the rejected response.
        reference_correct_log: reference-model log-prob of the chosen response.
        reference_rejected_log: reference-model log-prob of the rejected response.
        beta: scale of the implicit reward; larger values push harder.

    Returns:
        0-d tensor: the loss averaged over all elements.
    """
    chosen_margin = policy_correct_log - reference_correct_log
    rejected_margin = policy_rejected_log - reference_rejected_log
    # The chosen margin must come first. The previous version had the sign
    # flipped, which trained the policy to prefer the *rejected* answers.
    logits = beta * (chosen_margin - rejected_margin)
    return -F.logsigmoid(logits).mean()


# Everything runs on the CPU for now; move both models there explicitly.
device = 'cpu'
policy_model, reference_model = (net.to(device) for net in (policy_model, reference_model))

def training_loop(epochs=1, log_every=5):
    """Hand-written DPO training over ``revised_dataset``, one preference pair per step.

    For every pair, the chosen and rejected responses are scored by the policy
    (with gradients) and by the frozen reference model (without). Then the DPO
    loss is computed and the optimizer takes one step.

    Args:
        epochs: number of passes over ``revised_dataset`` (previously accepted
            but ignored — exactly one pass was always made).
        log_every: print the loss every ``log_every`` steps (0 disables logging).
    """
    print('Starting the training loop')
    policy_model.train()
    for epoch in range(epochs):
        # Iterate the *tokenized* rows: the raw `dataset` has no
        # 'correct_input_ids' (etc.) columns and raised KeyError on step 0.
        for idx, data in enumerate(revised_dataset):
            optim.zero_grad()   # one pair per step, no gradient accumulation

            correct_args = (data['correct_input_ids'], data['attention_mask_corrected'], data['resp_mask_corrected'])
            rejected_args = (data['rejected_input_ids'], data['attention_mask_rejected'], data['resp_mask_rejected'])

            # Forward pass (policy, with gradients)
            policy_correct_log = seq_logprob(policy_model, *correct_args)
            policy_rejected_log = seq_logprob(policy_model, *rejected_args)

            # Forward pass (reference). no_grad rather than inference_mode: these
            # values feed into the loss graph, and inference tensors must not be
            # saved for backward by autograd.
            with torch.no_grad():
                reference_correct_log = seq_logprob(reference_model, *correct_args)
                reference_rejected_log = seq_logprob(reference_model, *rejected_args)

            # Loss + backward
            loss = dpo_loss(policy_correct_log, policy_rejected_log, reference_correct_log, reference_rejected_log)
            loss.backward()
            optim.step()

            if log_every and idx % log_every == 0:
                print(f"Epoch {epoch + 1}/{epochs} | Step {idx} | Loss {loss.item():.4f}")


training_loop(1)
