# GPT-2 IPO Training

Minimal implementation of Identity Preference Optimization (IPO) training for the GPT-2 model.

In [None]:
import torch
import torch.nn.utils as utils
import json
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.optim import AdamW
from tqdm import tqdm

class PreferenceDataset(Dataset):
    """Flat collection of query preference pairs loaded from a JSON file.

    The file is read as a mapping of the form::

        {question: {"hops": {hop_id: {"queries": [...],
                                      "preference_pairs": [[i, j], ...]}}}}

    Every ``[i, j]`` pair indexes into the ``queries`` list of the same hop
    and produces one sample in which ``queries[i]`` is the preferred query
    and ``queries[j]`` the dispreferred one.

    Args:
        json_path: Path of the JSON file to load.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError, IndexError: If an entry does not follow the layout above.
    """

    def __init__(self, json_path):
        # JSON text is UTF-8; decoding with the platform's locale default
        # would corrupt (or fail on) non-ASCII questions on some systems.
        with open(json_path, 'r', encoding='utf-8') as f:
            raw_data = json.load(f)

        self.data = []
        for question, entry in raw_data.items():
            # Hop identifiers are not needed, only the per-hop payload.
            for hop_data in entry["hops"].values():
                queries = hop_data["queries"]
                for i, j in hop_data["preference_pairs"]:
                    self.data.append({
                        "question": question,
                        "preferred": queries[i],
                        "dispreferred": queries[j],
                    })

    def __len__(self):
        """Return the number of preference pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample dict (question / preferred / dispreferred) at ``idx``."""
        return self.data[idx]

In [None]:
# Hugging Face identifier of the checkpoint to fine-tune.
model_name = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
    # No dedicated padding token is defined: fall back to end-of-sequence.
    tokenizer.pad_token = tokenizer.eos_token

# Keyword arguments forwarded to ``from_pretrained``.
load_options = {
    "torch_dtype": torch.float32,
    "device_map": "auto",
    "low_cpu_mem_usage": True,
}
model = AutoModelForCausalLM.from_pretrained(model_name, **load_options)

# Put the freshly loaded model in training mode.
model.train()
print(f"Model {model_name} loaded successfully!")

In [None]:
def compute_logp(prompt, completion):
    """Score ``completion`` as a continuation of ``prompt`` under the model.

    Uses the module-level ``tokenizer`` and ``model``. The concatenated text
    is tokenized once (truncated to 128 tokens), the prompt positions are
    masked out of the labels with -100, and the negated language-modelling
    loss is returned. With Hugging Face causal-LM heads that loss is the
    mean cross-entropy over the unmasked tokens, so the result is the
    *average per-token* log-probability of the completion, not the sum.

    Args:
        prompt: Conditioning text; its tokens are excluded from the score.
        completion: Text whose likelihood is measured.

    Returns:
        A 0-dim tensor attached to the autograd graph of ``model``.

    Raises:
        ValueError: If no completion token is left to score (empty
            completion, or the prompt alone fills the 128-token window).
            Without this check the loss would be NaN.
    """
    max_length = 128

    # A single sequence never needs padding, so none is requested.
    encoded = tokenizer(prompt + completion, return_tensors="pt",
                        truncation=True, max_length=max_length)

    device = next(model.parameters()).device
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)

    # Find where the completion starts *in the joint tokenization*. Taking
    # the length of the separately tokenized prompt is not reliable: BPE can
    # merge the end of the prompt (e.g. a trailing space) with the first
    # characters of the completion, which makes the standalone prompt longer
    # than its share of the joint sequence and would mask the first
    # completion token. The longest common token prefix is the part of the
    # sequence that belongs to the prompt alone.
    prompt_ids = tokenizer(prompt, truncation=True, max_length=max_length).input_ids
    full_ids = input_ids[0].tolist()
    prompt_len = 0
    for prompt_tok, full_tok in zip(prompt_ids, full_ids):
        if prompt_tok != full_tok:
            break
        prompt_len += 1

    # The model shifts labels by one internally, so position 0 is never a
    # prediction target even when it is not masked.
    num_scored = len(full_ids) - max(prompt_len, 1)
    if num_scored <= 0:
        raise ValueError(
            "No completion tokens left to score "
            f"(prompt tokens: {prompt_len}, total tokens: {len(full_ids)}, "
            f"max_length: {max_length})"
        )

    labels = input_ids.clone()
    labels[:, :prompt_len] = -100  # -100 is ignored by the loss

    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    return -outputs.loss

def ipo_loss(logp_win, logp_lose, tau=0.05):
    """Return the mean squared gap between the preference margin and 1/(2*tau).

    The margin is ``logp_win - logp_lose`` (element-wise); the result is
    averaged over all elements.
    """
    target_margin = 0.5 / tau
    margin = logp_win - logp_lose
    residual = margin - target_margin
    return (residual ** 2).mean()

# Training setup
# Hyperparameters: IPO regularisation strength and number of passes over the data.
tau = 0.05
num_epochs = 3

# Preference pairs, served one at a time in random order.
dataset_path = 'preference_dataset_hotpotqa_final.json'
dataset = PreferenceDataset(dataset_path)
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=1)

# Every model parameter is optimised.
optimizer = AdamW(params=model.parameters(), lr=5e-6)

print(f"Dataset size: {len(dataset)}")
print("Starting IPO training...")

In [None]:
# Training loop
# For every batch: score the preferred and dispreferred query of each pair,
# compute the IPO loss on the two scores, and take one clipped AdamW step.
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    num_steps = 0    # optimizer steps actually taken this epoch
    num_skipped = 0  # batches dropped because their loss was NaN/inf

    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for batch in pbar:
        questions = batch["question"]
        preferred = batch["preferred"]
        dispreferred = batch["dispreferred"]

        logp_w_list = []
        logp_l_list = []

        for q, w, l in zip(questions, preferred, dispreferred):
            prompt = f"Generate a search query for: {q}\nQuery: "

            logp_w_list.append(compute_logp(prompt, w.strip()))
            logp_l_list.append(compute_logp(prompt, l.strip()))

        logp_w_batch = torch.stack(logp_w_list)
        logp_l_batch = torch.stack(logp_l_list)

        loss = ipo_loss(logp_w_batch, logp_l_batch, tau)

        # Back-propagating a NaN/inf loss would write NaN into every
        # parameter on the next optimizer step and ruin the model for the
        # rest of the run, so such batches are skipped instead.
        if not torch.isfinite(loss):
            num_skipped += 1
            continue

        optimizer.zero_grad()
        loss.backward()
        utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()
        num_steps += 1
        avg_loss = total_loss / num_steps
        pbar.set_postfix({"loss": f"{avg_loss:.4f}"})

    # Average over the steps that were taken; an empty dataloader (or an
    # epoch where every batch was skipped) reports NaN instead of raising
    # ZeroDivisionError.
    epoch_loss = total_loss / num_steps if num_steps else float("nan")
    print(f"[Epoch {epoch + 1}] Average Loss: {epoch_loss:.4f}")
    if num_skipped:
        print(f"[Epoch {epoch + 1}] Skipped {num_skipped} batch(es) with non-finite loss")

print("Training completed!")