# FLAN-T5 IPO Training

Minimal implementation of Identity Preference Optimization (IPO) training for the FLAN-T5 model.

In [8]:
import torch
import torch.nn.utils as utils
import json
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.optim import AdamW
from tqdm import tqdm
import wandb
import os

class PreferenceDataset(Dataset):
    """Flat dataset of (question, preferred query, dispreferred query) triples.

    The JSON file at ``json_path`` must map each question string to an entry
    shaped like::

        {"hops": {<hop_id>: {"queries": [...],
                             "preference_pairs": [[i, j], ...]}}}

    Every ``[i, j]`` pair yields one sample in which ``queries[i]`` is the
    preferred completion and ``queries[j]`` the dispreferred one.
    """

    def __init__(self, json_path):
        """Load the file and expand every preference pair into one sample.

        Raises:
            OSError: if the file cannot be opened.
            KeyError / IndexError: if an entry does not follow the layout above.
        """
        # JSON text is UTF-8; decode explicitly so non-ASCII questions load
        # the same way regardless of the platform's default locale encoding.
        with open(json_path, 'r', encoding='utf-8') as f:
            raw_data = json.load(f)

        self.data = []
        for question, entry in raw_data.items():
            # Hop identifiers are not needed, only the per-hop payload.
            for hop_data in entry["hops"].values():
                queries = hop_data["queries"]
                preferences = hop_data["preference_pairs"]
                for i, j in preferences:
                    self.data.append({
                        "question": question,
                        "preferred": queries[i],
                        "dispreferred": queries[j]
                    })

    def __len__(self):
        """Number of preference samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample dict with keys question / preferred / dispreferred."""
        return self.data[idx]

In [11]:
# Load the FLAN-T5 checkpoint and its tokenizer, then authenticate with W&B.

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "google/flan-t5-small"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    device_map="auto",
    low_cpu_mem_usage=True
)

model.train()
print(f"Model {model_name} loaded successfully!")

# SECURITY: never store the W&B API key in the notebook. wandb.login() picks
# the key up from the WANDB_API_KEY environment variable (or an existing
# netrc entry) and otherwise prompts for it interactively, so no secret needs
# to appear in source. Any key that was ever committed in a notebook must be
# treated as compromised and rotated in the W&B account settings.
wandb.login()

Model google/flan-t5-small loaded successfully!


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mohitit20[0m ([33mohitit[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [12]:
def compute_logp(prompt, completion):
    """Score `completion` under the seq2seq model, conditioned on `prompt`.

    The prompt (truncated to 128 tokens) is the encoder input and the
    completion (truncated to 64 tokens) supplies the decoder labels.

    Returns:
        The negated `loss` reported by the model for these labels, as a scalar
        tensor that still carries gradients. NOTE(review): for Hugging Face
        T5 this loss is presumably the token-averaged cross-entropy, which
        would make the result a mean per-token log-likelihood rather than a
        summed sequence log-probability -- confirm against the model docs.
    """
    # Tensors must live on the same device as the model weights.
    device = next(model.parameters()).device

    prompt_enc = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128)
    completion_enc = tokenizer(completion, return_tensors="pt", truncation=True, max_length=64)

    labels = completion_enc.input_ids.to(device)
    # -100 marks label positions that the loss should ignore (padding).
    labels[labels == tokenizer.pad_token_id] = -100

    forward = model(
        input_ids=prompt_enc.input_ids.to(device),
        attention_mask=prompt_enc.attention_mask.to(device),
        labels=labels,
    )
    return -forward.loss

def ipo_loss(logp_win, logp_lose, tau=0.05, ref_logp_win=None, ref_logp_lose=None):
    """Squared-margin IPO loss, averaged over the batch.

    IPO regresses the preference margin ``h`` onto the target ``1 / (2 * tau)``:

        loss = mean((h - 1 / (2 * tau)) ** 2)

    With reference log-probabilities, ``h`` is the policy/reference log-ratio
    gap ``(logp_win - ref_logp_win) - (logp_lose - ref_logp_lose)``, which is
    the quantity the IPO objective is defined on. Without them (the default,
    kept for backward compatibility) ``h`` is the raw gap
    ``logp_win - logp_lose``, i.e. a reference-free variant.

    Args:
        logp_win: Policy log-probabilities of the preferred completions.
        logp_lose: Policy log-probabilities of the dispreferred completions.
        tau: Regularization strength; smaller values demand a larger margin.
        ref_logp_win: Optional reference-model log-probabilities of the
            preferred completions.
        ref_logp_lose: Optional reference-model log-probabilities of the
            dispreferred completions.

    Returns:
        Scalar tensor with the mean squared deviation from the target margin.

    Raises:
        ValueError: if only one of the two reference arguments is given.
    """
    if (ref_logp_win is None) != (ref_logp_lose is None):
        raise ValueError("ref_logp_win and ref_logp_lose must be given together")

    margin = logp_win - logp_lose
    if ref_logp_win is not None:
        # Subtract the reference model's own preference gap so that only the
        # policy's deviation from the reference is pushed towards the target.
        margin = margin - (ref_logp_win - ref_logp_lose)

    diff = margin - 0.5 / tau
    return (diff ** 2).mean()

# Training setup: hyperparameters first, then data and optimizer.
tau = 0.05
num_epochs = 3

# Location of the preference data (relative to the working directory).
parent_path = 'drive/MyDrive/c438_project'
dataset_path = parent_path + '/preference_dataset_hotpotqa_final.json'

# One preference pair per step, visited in random order.
dataset = PreferenceDataset(dataset_path)
dataloader = DataLoader(dataset, shuffle=True, batch_size=1)

optimizer = AdamW(model.parameters(), lr=5e-6)

print("Dataset size: {}".format(len(dataset)))
print("Starting IPO training...")

Dataset size: 70073
Starting IPO training...


In [None]:
import zipfile
from google.colab import files

# Keep periodic and per-epoch checkpoints under one relative directory so all
# artifacts of a run end up in the same place (and not at the filesystem root).
output_dir = "ipo_trained_model"


def _save_model(save_dir):
    """Write the current model and tokenizer to `save_dir`, creating it if needed."""
    os.makedirs(save_dir, exist_ok=True)
    model.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)


def _zip_directory(src_dir, zip_filename):
    """Archive every file below `src_dir` into `zip_filename` with relative paths."""
    with zipfile.ZipFile(zip_filename, 'w') as zipf:
        for root, _dirs, files_list in os.walk(src_dir):
            for file in files_list:
                file_path = os.path.join(root, file)
                zipf.write(file_path, os.path.relpath(file_path, src_dir))


wandb.init(
    project="c438_project",  # Name of the project in W&B
    name="ipo_training_run",    # A specific name for this run
    config={
        "learning_rate": optimizer.defaults['lr'],
        "num_epochs": num_epochs,
        "batch_size": dataloader.batch_size,
        "model_name": model.name_or_path,
        "tau": tau,
    }
)

# Training loop. The try/finally guarantees the W&B run is closed even when
# training is interrupted or raises part-way through.
try:
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        # Number of optimizer updates in this epoch; skipped batches are not
        # counted so the reported averages are not diluted by them.
        num_updates = 0

        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for batch_idx, batch in enumerate(pbar):
            questions = batch["question"]
            preferred = batch["preferred"]
            dispreferred = batch["dispreferred"]

            logp_w_list = []
            logp_l_list = []

            for q, w, l in zip(questions, preferred, dispreferred):
                prompt = f"Generate a search query for: {q}\nQuery: "

                logp_w_list.append(compute_logp(prompt, w.strip()))
                logp_l_list.append(compute_logp(prompt, l.strip()))

            logp_w_batch = torch.stack(logp_w_list)
            logp_l_batch = torch.stack(logp_l_list)

            loss = ipo_loss(logp_w_batch, logp_l_batch, tau)

            # Guard against NaN and +/-inf alike; either would corrupt the
            # weights if it reached the optimizer.
            if not torch.isfinite(loss):
                print(f"Skipping batch {batch_idx} due to non-finite loss.")
                continue

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            current_loss = loss.item()
            total_loss += current_loss
            num_updates += 1
            avg_loss = total_loss / num_updates
            pbar.set_postfix({"loss": f"{avg_loss:.4f}"})

            wandb.log({"step_loss": current_loss, "average_loss": avg_loss})

            # 1-based step counter across all epochs.
            global_step = epoch * len(dataloader) + batch_idx + 1

            # Save model every 500 iterations (overwrites the previous one)
            if global_step % 500 == 0:
                checkpoint_path = os.path.join(output_dir, "checkpoint_500")
                _save_model(checkpoint_path)
                print(f"Model saved at step {global_step} to {checkpoint_path}")

            # Download model every 5000 iterations
            if global_step % 5000 == 0:
                download_path = os.path.join(output_dir, "checkpoint_5000")
                _save_model(download_path)

                zip_filename = f"model_checkpoint_{global_step}.zip"
                _zip_directory(download_path, zip_filename)

                files.download(zip_filename)
                print(f"Model checkpoint at step {global_step} downloaded")

        # max(..., 1) avoids a division by zero if every batch was skipped.
        epoch_avg_loss = total_loss / max(num_updates, 1)
        print(f"[Epoch {epoch + 1}] Average Loss: {epoch_avg_loss:.4f}")

        wandb.log({"epoch": epoch + 1, "epoch_average_loss": epoch_avg_loss})

        epoch_save_path = os.path.join(output_dir, f"epoch_{epoch+1}")
        _save_model(epoch_save_path)
        print(f"Epoch {epoch + 1} model saved to {epoch_save_path}")

    print("Training completed!")
finally:
    # --- Finish the W&B run ---
    wandb.finish()


Epoch 1/3:  27%|██▋       | 19181/70073 [57:57<2:34:43,  5.48it/s, loss=91.4975]

In [4]:
# Mount Google Drive at /content/drive. The training-setup cell reads its
# dataset from a path under drive/MyDrive, so on a fresh runtime this cell
# has to be executed before that one.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
def generate_query(question, max_length=30, temperature=0.8, top_p=0.9):
    """Sample one search query for `question` from the trained T5 model.

    Args:
        question: Natural-language question to turn into a search query.
        max_length: Upper bound on the number of newly generated tokens
            (forwarded as ``max_new_tokens``).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The decoded text with every '?' and '.' removed and surrounding
        whitespace stripped, or the fallback "search query" if nothing
        is left.

    Side effects:
        Puts the global `model` into eval mode.
    """
    model.eval()

    # Use exactly the prompt format the training loop optimizes on; a
    # different wording at inference time would query the model on inputs it
    # was never tuned for.
    prompt = f"Generate a search query for: {question}\nQuery: "
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128)

    device = next(model.parameters()).device
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_length,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            early_stopping=False,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3
        )

    # For seq2seq models the output holds only the generated sequence, so the
    # whole thing is decoded (no prompt prefix to slice off).
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Remove question marks and periods, then strip: stripping afterwards
    # also removes whitespace exposed by the removal (e.g. "foo ." -> "foo")
    # and lets punctuation-only output fall through to the fallback.
    generated_text = generated_text.replace('?', '').replace('.', '').strip()

    return generated_text if generated_text else "search query"

# Example questions for a quick qualitative check of the trained model
test_questions = [
    "Who was the first person to climb Mount Everest?",
    "What is the capital of the country where the Eiffel Tower is located?",
    "Which movie won the Academy Award for Best Picture in 2020?",
    "What is the largest planet in our solar system?",
    "Who wrote the novel '1984'?"
]

# Sampling settings tried for every question, from conservative to diverse
sampling_configs = (
    {"temperature": 0.5, "top_p": 0.8},
    {"temperature": 0.8, "top_p": 0.9},
    {"temperature": 1.0, "top_p": 0.95},
)

print("Testing the trained model on example questions:")
print("=" * 60)

for number, question in enumerate(test_questions, start=1):
    print(f"\n{number}. Question: {question}")

    # One sampled query per sampling setting
    for query_number, sampling in enumerate(sampling_configs, start=1):
        query = generate_query(question, **sampling)
        print(f"   Query {query_number}: {query}")

    print("-" * 40)

print("\nTesting completed!")