In [None]:
!pip install datasets transformers torch==2.7 tqdm numpy pylate bitsandbytes accelerate huggingface_hub wandb torchvision

# FLAN-T5 IPO Training

Minimal implementation of Identity Preference Optimization (IPO) training for the FLAN-T5 model.

# Import libraries and define Preference Dataset Class

In [None]:
import torch
import torch.nn.utils as utils
import json
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.optim import AdamW
from tqdm import tqdm
import wandb
import os

# Modified PreferenceDataset to include hop information
class PreferenceDataset(Dataset):
    """Flat list of (question, preferred query, dispreferred query, hop) examples.

    The JSON file is expected to look like::

        {
            "<question>": {
                "hops": {
                    "<hop name>": {
                        "queries": ["...", "..."],
                        "preference_pairs": [[i, j], ...]
                    }
                }
            }
        }

    Each ``[i, j]`` pair yields one example in which ``queries[i]`` is the
    preferred query and ``queries[j]`` the dispreferred one.

    Args:
        json_path: Path to the preference JSON file.

    Raises:
        KeyError / IndexError: if an entry does not follow the layout above.
    """

    def __init__(self, json_path):
        # Read as UTF-8 explicitly: questions and queries are natural-language
        # text, and the platform default encoding is not UTF-8 everywhere.
        with open(json_path, 'r', encoding='utf-8') as f:
            raw_data = json.load(f)

        self.data = []
        for question, entry in raw_data.items():
            for hop, hop_data in entry["hops"].items():
                queries = hop_data["queries"]
                preferences = hop_data["preference_pairs"]
                for i, j in preferences:
                    self.data.append({
                        "question": question,
                        "preferred": queries[i],
                        "dispreferred": queries[j],
                        "hop": hop,  # kept so the training loop can weight by hop
                    })

    def __len__(self):
        """Number of preference pairs across all questions and hops."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the example dict at position ``idx``."""
        return self.data[idx]

# Import Base Model

In [11]:
# Load the base FLAN-T5 checkpoint and its tokenizer, then log in to Weights & Biases.

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "google/flan-t5-small"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    device_map="auto",
    low_cpu_mem_usage=True
)

# Put the model in training mode; the training loop calls this again every epoch.
model.train()
print(f"Model {model_name} loaded successfully!")

# SECURITY: never hard-code an API key in a notebook. Anyone who can read the
# file (or its git history) can use it. wandb.login() picks the key up from the
# WANDB_API_KEY environment variable or an existing ~/.netrc entry, and
# otherwise prompts for it interactively.
# NOTE(review): the key that was previously committed here must be treated as
# leaked and revoked in the W&B account settings.
wandb.login()

Model google/flan-t5-small loaded successfully!


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mohitit20[0m ([33mohitit[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

# Loss Calculation Functions

In [None]:
def compute_logp(prompt, completion):
    """Score `completion` as the decoder target for `prompt` with the global seq2seq model.

    The prompt is tokenized (truncated to 128 tokens) as the encoder input and
    the completion (truncated to 64 tokens) as the labels. Both are moved to
    the device of the model's first parameter.

    Returns:
        The negated ``loss`` reported by the model for these labels, as a
        scalar tensor that still carries gradients.
        NOTE(review): for Hugging Face seq2seq models this loss is presumably
        the token-averaged cross-entropy, which would make the result a
        length-normalised log-likelihood rather than a summed sequence
        log-probability. Confirm against the model documentation.
    """
    target_device = next(model.parameters()).device

    source = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128)
    target = tokenizer(completion, return_tensors="pt", truncation=True, max_length=64)

    # Padding positions in the labels are replaced with -100 so that they are
    # excluded from the loss.
    label_ids = target.input_ids.to(target_device)
    label_ids = label_ids.masked_fill(label_ids == tokenizer.pad_token_id, -100)

    result = model(
        input_ids=source.input_ids.to(target_device),
        attention_mask=source.attention_mask.to(target_device),
        labels=label_ids,
    )
    return -result.loss

def ipo_loss(logp_win, logp_lose, tau=0.05):
    """Unweighted IPO objective.

    Penalises the squared distance between the preferred-minus-dispreferred
    log-probability gap and the target margin ``0.5 / tau``, averaged over all
    elements.

    Args:
        logp_win: Tensor of scores for the preferred completions.
        logp_lose: Tensor of scores for the dispreferred completions.
        tau: Regularisation strength; the target margin is ``0.5 / tau``.

    Returns:
        Scalar tensor with the mean squared deviation from the target margin.
    """
    target_margin = 0.5 / tau
    gap_error = (logp_win - logp_lose) - target_margin
    return gap_error.pow(2).mean()

# Modified IPO loss function with weighting
def ipo_loss_weighted(logp_win, logp_lose, weights, tau=0.05):
    """IPO objective with a per-example weight.

    Each example's squared deviation from the target margin ``0.5 / tau`` is
    multiplied by its weight before averaging. The mean is taken over the
    number of elements, not normalised by the sum of the weights.

    Args:
        logp_win: Tensor of scores for the preferred completions.
        logp_lose: Tensor of scores for the dispreferred completions.
        weights: Tensor of per-example weights, broadcastable against the scores.
        tau: Regularisation strength; the target margin is ``0.5 / tau``.

    Returns:
        Scalar tensor with the mean of the weighted squared deviations.
    """
    target_margin = 0.5 / tau
    gap_error = (logp_win - logp_lose) - target_margin
    per_example = weights * gap_error.pow(2)
    return per_example.mean()



Dataset size: 70073
Starting IPO training...


# Define Training Setup

- Set the path to the preference dataset JSON file (`dataset_path`) here.

In [None]:
# Training setup: dataset, dataloader, optimizer and IPO hyper-parameters.
#
# The dataset lives on Google Drive, so Drive must already be mounted at
# /content/drive before this cell runs. The path is absolute so it does not
# depend on the notebook's current working directory.
parent_path = '/content/drive/MyDrive/c438_project'
dataset_path = f'{parent_path}/preference_dataset_hotpotqa_final.json'

# Fail early with an actionable message instead of a bare open() error.
if not os.path.isfile(dataset_path):
    raise FileNotFoundError(
        f"Preference dataset not found at {dataset_path}. "
        "Mount Google Drive with drive.mount('/content/drive') before running this cell."
    )

dataset = PreferenceDataset(dataset_path)
# One preference pair per batch; the training loop accumulates gradients over
# several batches before each optimizer step.
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

optimizer = AdamW(model.parameters(), lr=5e-6)
tau = 0.05       # IPO regularisation strength; the target log-prob margin is 0.5 / tau
num_epochs = 3

print(f"Dataset size: {len(dataset)}")
print("Starting IPO training...")

In [None]:
import zipfile
from google.colab import files
import os

# Request synchronous CUDA kernel launches so errors surface at the failing call.
# NOTE(review): the model has already been loaded by this point; this variable
# is presumably only honoured if set before CUDA is initialised. Confirm, and
# move it to the first cell if so.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Label for the model variant being trained; it is embedded in the per-epoch
# save directory name. Change it to match your configuration.
model_configuration = "base_flan-t5-small"

# Start the W&B run and record the hyper-parameters defined in the setup cell.
wandb.init(
    project="c438_project",
    name="ipo_training_run",
    config=dict(
        learning_rate=optimizer.defaults['lr'],
        num_epochs=num_epochs,
        batch_size=dataloader.batch_size,
        # Fall back to a fixed label when the model object has no name_or_path.
        model_name=getattr(model, 'name_or_path', "flan-t5-small"),
        tau=tau,
    ),
)

# Hop-weighted IPO training loop with gradient accumulation.
#
# Uses objects created in earlier cells: model, tokenizer, optimizer,
# dataloader, tau, num_epochs, model_configuration, compute_logp and
# ipo_loss_weighted.

# Number of micro-batches whose gradients are summed before one optimizer step.
accumulation_steps = 4

# Loss weight per hop: earlier hops count more. Hops that are not listed fall
# back to 0.5. The table is loop-invariant, so it is built once.
hop_weights = {
    "hop_1": 1.0,    # Full weight for first hop
    "hop_2": 0.7,    # Reduced weight for second hop
    "hop_3": 0.5,    # Even lower for third hop (if exists)
    "hop_4": 0.3     # Minimal weight for fourth hop (if exists)
}


def _save_model_and_tokenizer(directory):
    """Write the current model and tokenizer to `directory`, creating it if needed."""
    os.makedirs(directory, exist_ok=True)
    model.save_pretrained(directory)
    tokenizer.save_pretrained(directory)


def _apply_accumulated_gradients():
    """Clip the accumulated gradients, take one optimizer step, then clear the gradients."""
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()


def _batch_ipo_loss(batch, batch_idx):
    """Compute the hop-weighted IPO loss for one batch.

    Args:
        batch: Collated batch with parallel sequences under the keys
            "question", "preferred", "dispreferred" and "hop".
        batch_idx: Index of the batch within the epoch (used in messages only).

    Returns:
        ``(loss, weights_batch)`` on success, or ``None`` when the batch has
        to be skipped (non-finite log-prob or loss, a CUDA error while scoring
        a pair, any other non-RuntimeError exception while scoring, or an
        empty batch).

    Raises:
        RuntimeError: non-CUDA runtime errors raised while scoring a pair, and
            any runtime error raised while stacking or computing the loss, are
            propagated to the caller.
    """
    logp_w_list = []
    logp_l_list = []
    weights_list = []

    pairs = zip(batch["question"], batch["preferred"], batch["dispreferred"], batch["hop"])
    for q, w, l, hop in pairs:
        try:
            prompt = f"Generate a search query for: {q}\nQuery: "

            logp_w = compute_logp(prompt, w.strip())
            logp_l = compute_logp(prompt, l.strip())

            # A NaN/inf score would poison the whole batch, so drop the batch.
            if not (torch.isfinite(logp_w) and torch.isfinite(logp_l)):
                print(f"Invalid logprob detected, skipping batch {batch_idx}")
                return None

            logp_w_list.append(logp_w)
            logp_l_list.append(logp_l)
            # Get weight for this hop (default to 0.5 if hop not found)
            weights_list.append(hop_weights.get(hop, 0.5))

        except RuntimeError as e:
            if "CUDA" not in str(e):
                raise
            print(f"CUDA error in batch {batch_idx}: {e}")
            print("Clearing CUDA cache and skipping batch...")
            torch.cuda.empty_cache()
            return None
        except Exception as e:
            print(f"Unexpected error in batch {batch_idx}: {e}")
            return None

    if len(logp_w_list) == 0:
        return None

    logp_w_batch = torch.stack(logp_w_list)
    logp_l_batch = torch.stack(logp_l_list)
    weights_batch = torch.tensor(weights_list, device=logp_w_batch.device)

    loss = ipo_loss_weighted(logp_w_batch, logp_l_batch, weights_batch, tau)
    if not torch.isfinite(loss):
        print(f"Skipping batch {batch_idx} due to invalid loss.")
        return None

    return loss, weights_batch


# Start from clean gradients in case a previous (interrupted) run of this cell
# left some behind.
optimizer.zero_grad()

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0             # sum of unscaled batch losses, for the running/epoch average
    accumulated_loss = 0.0       # sum of scaled losses in the current accumulation window
    pending_micro_batches = 0    # batches that have contributed gradients since the last step
    processed_batches = 0
    skipped_batches = 0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for batch_idx, batch in enumerate(pbar):
        batch_trained = False
        try:
            batch_result = _batch_ipo_loss(batch, batch_idx)
            if batch_result is not None:
                loss, weights_batch = batch_result

                # Normalize loss by accumulation steps so the summed gradients
                # of a full window correspond to the window's mean loss.
                scaled_loss = loss / accumulation_steps
                scaled_loss.backward()
                accumulated_loss += scaled_loss.item()
                pending_micro_batches += 1

                # Step on the number of batches that actually contributed
                # gradients, not on batch_idx, so skipped batches cannot
                # shorten or misalign an accumulation window.
                if pending_micro_batches >= accumulation_steps:
                    _apply_accumulated_gradients()
                    step_loss = accumulated_loss
                    # Reset before logging so a logging failure cannot leave
                    # the window counters stale.
                    accumulated_loss = 0.0
                    pending_micro_batches = 0

                    # avg_hop_weight reflects the last batch of the window only.
                    avg_hop_weight = torch.mean(weights_batch).item()
                    wandb.log({
                        "step_loss": step_loss,
                        "batch_idx": batch_idx,
                        "avg_hop_weight": avg_hop_weight,
                        "skipped_batches": skipped_batches
                    })

                current_loss = scaled_loss.item() * accumulation_steps  # Scale back for display
                total_loss += current_loss
                batch_trained = True

        except RuntimeError as e:
            if "CUDA" in str(e):
                print(f"CUDA error during loss computation in batch {batch_idx}: {e}")
                torch.cuda.empty_cache()
            else:
                print(f"Critical error in batch {batch_idx}: {e}")
        except Exception as e:
            print(f"Critical error in batch {batch_idx}: {e}")

        if batch_trained:
            processed_batches += 1
            avg_loss = total_loss / processed_batches
            pbar.set_postfix({"loss": f"{avg_loss:.4f}", "skipped": skipped_batches})
        else:
            skipped_batches += 1

        # Checkpointing runs for every batch position, including skipped ones,
        # so a skipped batch can no longer cause a scheduled checkpoint to be missed.
        global_step = epoch * len(dataloader) + batch_idx + 1

        # Save model every 500 iterations (overwrite)
        # NOTE(review): this path is rooted at "/", unlike the relative
        # per-epoch save path below. Confirm that is intended.
        if global_step % 500 == 0:
            try:
                checkpoint_path = "/ipo_trained_model/checkpoint_500"
                _save_model_and_tokenizer(checkpoint_path)
                print(f"Model saved at step {global_step} to {checkpoint_path}")
            except Exception as e:
                print(f"Error saving checkpoint: {e}")

        # Download model every 5000 iterations
        if global_step % 5000 == 0:
            try:
                download_path = "/ipo_trained_model/checkpoint_5000"
                _save_model_and_tokenizer(download_path)

                # Create zip file and download
                zip_filename = f"model_checkpoint_{global_step}.zip"
                with zipfile.ZipFile(zip_filename, 'w') as zipf:
                    for root, dirs, files_list in os.walk(download_path):
                        for file in files_list:
                            file_path = os.path.join(root, file)
                            # Store paths relative to the checkpoint directory.
                            arcname = os.path.relpath(file_path, download_path)
                            zipf.write(file_path, arcname)

                files.download(zip_filename)
                print(f"Model checkpoint at step {global_step} downloaded")
            except Exception as e:
                print(f"Error creating checkpoint download: {e}")

    # Apply gradients left over from an incomplete accumulation window so they
    # are neither lost nor carried into the next epoch. Each contribution was
    # scaled by 1/accumulation_steps, so this final step is proportionally
    # smaller than a full-window step.
    if pending_micro_batches > 0:
        try:
            _apply_accumulated_gradients()
        except Exception as e:
            print(f"Error applying leftover gradients at end of epoch {epoch + 1}: {e}")
            optimizer.zero_grad()

    effective_batches = processed_batches
    epoch_avg_loss = total_loss / effective_batches if effective_batches > 0 else 0
    print(f"[Epoch {epoch + 1}] Average Loss: {epoch_avg_loss:.4f}, Skipped Batches: {skipped_batches}")

    wandb.log({
        "epoch": epoch + 1,
        "epoch_average_loss": epoch_avg_loss,
        "epoch_skipped_batches": skipped_batches,
        "epoch_effective_batches": effective_batches
    })

    try:
        epoch_save_path = f"ipo_trained_model/epoch_{epoch+1}_{model_configuration}"
        _save_model_and_tokenizer(epoch_save_path)
        print(f"Epoch {epoch + 1} model saved to {epoch_save_path}")
    except Exception as e:
        print(f"Error saving epoch model: {e}")

print("Training completed!")

# --- Finish the W&B run ---
wandb.finish()

Epoch 1/3:  27%|██▋       | 19181/70073 [57:57<2:34:43,  5.48it/s, loss=91.4975]

In [4]:
# Mount Google Drive so the preference dataset under MyDrive is readable.
# This cell must be executed before the "Training setup" cell, which loads the
# dataset from a path inside the mounted Drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
def generate_query(question, max_length=30, temperature=0.8, top_p=0.9):
    """Sample one search query for `question` from the global seq2seq model.

    Side effect: switches the global model to eval mode and leaves it there.

    Args:
        question: Natural-language question to generate a search query for.
        max_length: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The decoded query with surrounding whitespace, question marks and
        periods removed, or the literal "search query" when the model produces
        nothing usable.
    """
    model.eval()

    # Use exactly the prompt template the training loop scores completions
    # under. A differently worded prompt at inference time would query the
    # model off the distribution it was preference-tuned on.
    prompt = f"Generate a search query for: {question}\nQuery: "
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128)

    device = next(model.parameters()).device
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_length,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            early_stopping=False,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3
        )

    # For seq2seq models the output holds only the generated target, so the
    # whole sequence is decoded (no prompt prefix to slice off).
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    # Remove question marks and periods
    generated_text = generated_text.replace('?', '').replace('.', '')

    return generated_text if generated_text else "search query"

# Smoke-test the trained model: sample a few queries per example question.
test_questions = [
    "Who was the first person to climb Mount Everest?",
    "What is the capital of the country where the Eiffel Tower is located?",
    "Which movie won the Academy Award for Best Picture in 2020?",
    "What is the largest planet in our solar system?",
    "Who wrote the novel '1984'?"
]

# Sampling settings tried for every question, from most to least conservative.
sampling_configs = [
    {"temperature": 0.5, "top_p": 0.8},
    {"temperature": 0.8, "top_p": 0.9},
    {"temperature": 1.0, "top_p": 0.95}
]

header_rule = "=" * 60
question_rule = "-" * 40

print("Testing the trained model on example questions:")
print(header_rule)

for i, question in enumerate(test_questions, 1):
    print(f"\n{i}. Question: {question}")

    # One generated query per sampling configuration.
    for j, sampling_kwargs in enumerate(sampling_configs):
        query = generate_query(question, **sampling_kwargs)
        print(f"   Query {j+1}: {query}")

    print(question_rule)

print("\nTesting completed!")