In [None]:
!pip install -q transformers datasets accelerate bitsandbytes peft trl torch

In [None]:
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    pipeline,
)
from datasets import Dataset
from trl import DPOTrainer, SFTTrainer, RewardTrainer
import random
import copy
from tqdm.auto import tqdm
import numpy as np

# Model checkpoints: the policy LLM being aligned and the reward model (RM).
LLM_MODEL_NAME = "gpt2"
RM_MODEL_NAME = "distilbert-base-uncased"

# APO schedule, loss and optimisation hyper-parameters.
NUM_APO_ROUNDS = 5
RM_UPDATE_EPOCHS = 1
LM_UPDATE_EPOCHS = 1
DPO_BETA = 0.1
RM_BETA2 = 0.5
RM_LEARNING_RATE = 5e-5
LLM_LEARNING_RATE = 1e-5
BATCH_SIZE = 2
MAX_LENGTH = 128

# Seed every RNG the notebook imports (Python, NumPy, PyTorch incl. CUDA) so that
# sampling and weight init repeat across runs (some GPU kernels may still be nondeterministic).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Lower-case aliases of the original names, kept so any cell using them keeps working.
llm_model_name = LLM_MODEL_NAME
rm_model_name = RM_MODEL_NAME
num_apo_rounds = NUM_APO_ROUNDS
rm_update_epochs = RM_UPDATE_EPOCHS
lm_update_epochs = LM_UPDATE_EPOCHS
dpo_beta = DPO_BETA
rm_beta2 = RM_BETA2

In [None]:
#dataset prep

from datasets import Dataset

# D_p: original human-preference data, written one (prompt, chosen, rejected) row at a
# time so the pairing is explicit. It keeps the RM aligned with general human preferences.
_dp_records = [
    ("Explain quantum physics in simple terms.",
     "Quantum physics is about tiny particles behaving weirdly, like being in multiple places at once.",
     "It's too complicated for you."),
    ("What are the best travel destinations for summer?",
     "For summer, consider coastal Italy for beaches and culture, or national parks for hiking.",
     "Just stay home, it's cheaper."),
    ("Write a short story about a friendly robot.",
     "Unit 7 beeped cheerfully, offering a cup of tea to its new human friend.",
     "The robot malfunctioned and exploded."),
    ("How does a car engine work?",
     "A car engine works by burning fuel to create small explosions that push pistons, turning a crankshaft.",
     "It's magic, you wouldn't understand."),
    ("What's the best way to learn a new language?",
     "Immersing yourself in the language, practicing speaking daily, and using flashcards are effective ways to learn.",
     "Learning languages is pointless."),
    ("Describe the process of photosynthesis.",
     "Photosynthesis is how plants convert sunlight, water, and carbon dioxide into food and oxygen.",
     "Plants just grow, that's it."),
    ("Suggest a good book for a teenager.",
     "I'd recommend 'The Hunger Games' for its engaging plot and strong characters.",
     "Reading is boring."),
    ("What are the benefits of regular exercise?",
     "Regular exercise improves cardiovascular health, boosts mood, and helps with weight management.",
     "Exercise is for athletes only."),
    ("Explain blockchain technology.",
     "Blockchain is a decentralized, secure, and transparent digital ledger used for recording transactions.",
     "It's just a fad, ignore it."),
    ("Give me some tips for public speaking.",
     "To improve public speaking, practice often, know your material, and engage with your audience.",
     "Don't speak in public, it's terrifying."),
]
# Unzip the rows into the three parallel column lists the rest of the notebook uses.
prompts_dp, chosen_responses_dp, rejected_responses_dp = map(list, zip(*_dp_records))

dataset_dp_dict = {
    "prompt": prompts_dp,
    "chosen": chosen_responses_dp,
    "rejected": rejected_responses_dp,
}
dataset_dp = Dataset.from_dict(dataset_dp_dict)

# D_gold: (prompt, golden_response) rows. The RM is trained to prefer these high-quality
# answers over whatever the LLM currently generates for the same prompt.
_gold_records = [
    ("What's the capital of France?",
     "The capital of France is Paris."),
    ("Describe a beautiful sunset.",
     "The sky was painted in hues of orange, pink, and purple as the sun dipped below the horizon, casting long shadows."),
    ("Suggest a healthy breakfast idea.",
     "A bowl of oatmeal with fresh berries and a sprinkle of nuts is a great healthy breakfast."),
    ("Who wrote 'Romeo and Juliet'?",
     "William Shakespeare wrote 'Romeo and Juliet'."),
    ("Explain the concept of gravity.",
     "Gravity is a fundamental force of nature that attracts any objects with mass or energy towards each other."),
    ("What is the largest ocean on Earth?",
     "The Pacific Ocean is the largest ocean on Earth."),
    ("Name a famous historical figure and their achievement.",
     "Marie Curie was a pioneering physicist and chemist who conducted groundbreaking research on radioactivity."),
    ("How do airplanes fly?",
     "Airplanes fly by generating lift, primarily from their wings, which counteracts the force of gravity."),
    ("What are the main causes of climate change?",
     "The main causes of climate change are the emission of greenhouse gases from human activities like burning fossil fuels and deforestation."),
    ("Provide a simple recipe for scrambled eggs.",
     "To make scrambled eggs, whisk eggs with a splash of milk, then cook in a lightly buttered pan over medium heat, stirring until set."),
]
prompts_gold, golden_responses = map(list, zip(*_gold_records))

dataset_gold_dict = {
    "prompt": prompts_gold,
    "golden_response": golden_responses
}
dataset_gold = Dataset.from_dict(dataset_gold_dict)

# D_Q: prompts the LLM answers during alignment (gold prompts first, then D_p prompts).
# A larger, more diverse prompt set could be substituted here.
prompts_llm_align = prompts_gold + prompts_dp

In [None]:
# Sanity check: show each dataset under a header before training starts.
_dataset_previews = (
    ("--- D_p Dataset ---", dataset_dp),
    ("\n--- D_gold Dataset ---", dataset_gold),
    ("\n--- prompts_llm_align ---", prompts_llm_align),
)
for _header, _obj in _dataset_previews:
    print(_header)
    print(_obj)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_tokenizer(model_name):
    """Load a tokenizer and ensure it has a pad token so batches can be padded.

    Tokenizers without a pad token (GPT-2, for example) reuse their EOS token for padding.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


# Reward model: sequence classifier with a single output (num_labels=1) used as the scalar reward.
rm_tokenizer = _load_tokenizer(RM_MODEL_NAME)
rm_model = AutoModelForSequenceClassification.from_pretrained(RM_MODEL_NAME, num_labels=1).to(device)
# Sequence classifiers built on decoder-only checkpoints find the last real token through
# config.pad_token_id and reject batches > 1 without it; keep it in sync with the tokenizer.
rm_model.config.pad_token_id = rm_tokenizer.pad_token_id

# Policy LLM that gets aligned.
llm_tokenizer = _load_tokenizer(LLM_MODEL_NAME)
llm_model = AutoModelForCausalLM.from_pretrained(LLM_MODEL_NAME).to(device)

# Reference policy: an untouched copy of the initial LLM weights, kept frozen.
llm_ref_model = AutoModelForCausalLM.from_pretrained(LLM_MODEL_NAME).to(device)
llm_ref_model.eval()
llm_ref_model.requires_grad_(False)

In [None]:
def generate_responses(prompts, model, tokenizer, num_of_responses=1, max_new_tokens=20):
    """Sample completions from a causal LM for each prompt.

    Args:
        prompts: Iterable of prompt strings.
        model: Causal LM exposing ``generate`` (e.g. ``AutoModelForCausalLM``).
        tokenizer: Tokenizer matching ``model``; its ``pad_token_id`` and
            ``eos_token_id`` are forwarded to ``generate``.
        num_of_responses: Number of sampled completions per prompt.
        max_new_tokens: Maximum number of generated tokens per completion
            (short by default for this demo).

    Returns:
        A flat list of ``len(prompts) * num_of_responses`` strings. Completions for
        one prompt are contiguous, so with the default ``num_of_responses=1`` item
        ``i`` answers ``prompts[i]``. Each string holds only the newly generated text
        (the echoed prompt is removed, surrounding whitespace stripped), so it is
        comparable with the prompt-free golden/chosen responses.
    """
    responses = []
    model.eval()
    for prompt in prompts:
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH - 10,
        ).to(device)
        prompt_len = inputs["input_ids"].shape[1]
        gen_outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=num_of_responses,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=True,
            top_k=50,
            top_p=0.95,
        )
        # A decoder-only generate() returns prompt + continuation; keep only the continuation.
        responses.extend(
            tokenizer.decode(output[prompt_len:], skip_special_tokens=True).strip()
            for output in gen_outputs
        )
    return responses

In [None]:
def get_rewards(texts, model, tokenizer, batch_size=None):
    """Score texts with a single-output (scalar) reward model, without gradients.

    Args:
        texts: Sequence of strings to score (sliced into batches).
        model: Sequence-classification model expected to emit one logit per text
            (as ``rm_model`` does with ``num_labels=1``).
        tokenizer: Tokenizer matching ``model``.
        batch_size: Texts per forward pass; defaults to the global ``BATCH_SIZE``.

    Returns:
        A list of Python floats, one per input text, in input order.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    rewards = []
    model.eval()
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]
            inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=MAX_LENGTH).to(device)
            logits = model(**inputs).logits
            # Squeeze only the label dimension: a bare .squeeze() turns a 1-element batch
            # into a 0-d tensor whose .tolist() is a float, which list.extend() rejects.
            rewards.extend(logits.squeeze(-1).tolist())
    return rewards

In [None]:

def bradley_terry_loss(scores_preferred, scores_rejected):
    """Mean Bradley-Terry pairwise ranking loss.

    Computes ``mean(-log(sigmoid(scores_preferred - scores_rejected)))``. The loss
    shrinks as the preferred scores exceed the rejected ones by a wider margin.

    Args:
        scores_preferred: Tensor of scores for the preferred items.
        scores_rejected: Tensor of scores for the rejected items.

    Returns:
        A 0-d tensor holding the loss averaged over all pairs.
    """
    margin = scores_preferred - scores_rejected
    log_p_preferred = torch.nn.functional.logsigmoid(margin)
    return -log_p_preferred.mean()

In [None]:
print("Starting APO Training...")
for apo_round in range(NUM_APO_ROUNDS):
    print(f"\n--- APO Round {apo_round + 1}/{NUM_APO_ROUNDS} ---")

    # --- RM Optimization Step ---
    print("--- RM Optimization Step ---")
    rm_model.train()
    # NOTE(review): a new AdamW every round discards optimizer state between rounds.
    optimizer_rm = torch.optim.AdamW(rm_model.parameters(), lr=RM_LEARNING_RATE)

    # Build D_APO: for every golden prompt, the golden answer is "chosen" and the
    # current LLM's sampled answer is "rejected".
    # Materialise the columns once instead of re-reading them for every row.
    gold_prompts = list(dataset_gold["prompt"])
    gold_answers = list(dataset_gold["golden_response"])
    current_llm_responses = generate_responses(gold_prompts, llm_model, llm_tokenizer)
    # One response per prompt (default num_of_responses=1), so pairs align by index.
    assert len(current_llm_responses) == len(gold_prompts), "LLM responses misaligned with prompts"

    d_apo_prompts = list(gold_prompts)
    d_apo_chosen = list(gold_answers)  # golden responses
    d_apo_rejected = list(current_llm_responses)  # llm_generated responses

    dataset_apo_dict = {
        "prompt": d_apo_prompts,
        "chosen": d_apo_chosen,
        "rejected": d_apo_rejected
    }