# FLAN-T5 IPO Training

Minimal implementation of Identity Preference Optimization (IPO) training for the FLAN-T5 model.

In [None]:
import jax
import jax.numpy as jnp
from jax import random
import optax
import json
from typing import Dict, List, Tuple
from transformers import AutoTokenizer
from flax.training import train_state
import wandb
import os
from tqdm import tqdm
import numpy as np

# Initialize TPU
# Best-effort: any failure (TensorFlow missing, no TPU attached, ...) falls
# back to whatever devices are available.
# NOTE(review): the training code below runs on JAX, not TensorFlow; confirm
# that initializing the TPU system through TensorFlow is still wanted here.
try:
    import tensorflow as tf
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    print("TPU initialized successfully")
except Exception as e:
    # `Exception` (not a bare `except:`) so Ctrl-C / SystemExit still propagate;
    # the reason is printed so a misconfiguration is not silently hidden.
    print(f"TPU initialization failed, falling back to CPU/GPU ({type(e).__name__}: {e})")

class PreferenceDataset:
    """Flat list of preference pairs loaded from a JSON file.

    The JSON is read as a mapping ``question -> {"hops": {hop: {...}}}`` where
    every hop carries a ``"queries"`` list and a ``"preference_pairs"`` list of
    ``(i, j)`` index pairs. Each pair becomes one record with the keys
    ``"question"``, ``"preferred"`` (``queries[i]``) and ``"dispreferred"``
    (``queries[j]``), stored in ``self.data``.
    """

    def __init__(self, json_path):
        with open(json_path, 'r') as handle:
            payload = json.load(handle)

        self.data = list(self._iter_records(payload))

    @staticmethod
    def _iter_records(payload):
        """Yield one record per preference pair, in file order."""
        for question, entry in payload.items():
            for hop in entry["hops"].values():
                candidates = hop["queries"]
                pairs = hop["preference_pairs"]
                for winner, loser in pairs:
                    yield {
                        "question": question,
                        "preferred": candidates[winner],
                        "dispreferred": candidates[loser],
                    }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def get_batch(self, batch_size, rng_key):
        """Return ``batch_size`` distinct records drawn with ``rng_key``.

        Indices are sampled without replacement through ``random.choice``
        (``random`` being the module-level name the file imports from JAX).
        """
        picked = random.choice(rng_key, len(self.data), (batch_size,), replace=False)
        return [self.data[int(position)] for position in picked]

In [None]:
# Install required packages for TPU
!pip install -q flax transformers[flax] optax

from transformers import FlaxT5ForConditionalGeneration, T5Config
from flax import linen as nn
import jax.numpy as jnp

model_name = "google/flan-t5-small"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Fallback only: reuse EOS as the pad token if the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load Flax model for TPU
model = FlaxT5ForConditionalGeneration.from_pretrained(
    model_name,
    dtype=jnp.float32,
    _do_init=True
)

print(f"Model {model_name} loaded successfully for TPU!")
print(f"JAX devices: {jax.devices()}")
print(f"JAX device count: {jax.device_count()}")

# Setup W&B
# SECURITY: never commit an API key to the notebook. The key that used to be
# hard-coded here must be treated as leaked and revoked in the W&B settings.
# Provide the key through the WANDB_API_KEY environment variable (or a secret
# store) instead; without it, wandb.login() asks for the key interactively.
if not os.environ.get("WANDB_API_KEY"):
    print("WANDB_API_KEY is not set; wandb.login() will prompt for the key.")
wandb.login()

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

Model google/flan-t5-small loaded successfully!


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mohitit20[0m ([33mohitit[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
def compute_logp_jax(params, prompt_ids, prompt_mask, target_ids, target_mask, dropout_rng=None):
    """Return the mean per-token log-probability of the targets given the prompts.

    The value is a single scalar: the sum of the target-token log-probabilities
    over every unmasked position in the call, divided by the number of unmasked
    positions (at least 1). It is therefore length-normalised, not a
    sequence-level sum. Call it on one example at a time (e.g. under
    ``jax.vmap``) to obtain per-example values.

    Args:
        params: Model parameters forwarded to the module-level ``model``.
        prompt_ids: Encoder input ids, indexed as ``(batch, prompt_len)``.
        prompt_mask: Encoder attention mask matching ``prompt_ids``.
        target_ids: Target token ids, indexed as ``(batch, target_len)``.
        target_mask: Mask matching ``target_ids``; zero entries are ignored.
        dropout_rng: Optional PRNG key. When given, the model runs with
            ``train=True`` and this key; when ``None`` (default) the forward
            pass is deterministic. Train mode without a key gives the model's
            dropout layers nothing to sample from.

    Returns:
        Scalar mean log-probability (<= 0) of the unmasked target tokens.
    """
    # `is None` rather than `or`: a start-token id of 0 is a valid id.
    decoder_start_token_id = model.config.decoder_start_token_id
    if decoder_start_token_id is None:
        decoder_start_token_id = model.config.pad_token_id

    num_rows = target_ids.shape[0]

    # Teacher forcing: the decoder sees [start, y_0, ..., y_{T-2}].
    decoder_input_ids = jnp.concatenate(
        [
            jnp.full((num_rows, 1), decoder_start_token_id, dtype=target_ids.dtype),
            target_ids[:, :-1],
        ],
        axis=1,
    )

    # The mask is shifted the same way; the start token is always attended to.
    decoder_attention_mask = jnp.concatenate(
        [
            jnp.ones((num_rows, 1), dtype=target_mask.dtype),
            target_mask[:, :-1],
        ],
        axis=1,
    )

    outputs = model(
        input_ids=prompt_ids,
        attention_mask=prompt_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        params=params,
        dropout_rng=dropout_rng,
        train=dropout_rng is not None,
    )

    # Because the decoder inputs are already shifted right, logits[:, t] is the
    # prediction for target_ids[:, t]. Shifting again (logits[:-1] against
    # target_ids[1:]) would score every token against the wrong position.
    log_probs = jax.nn.log_softmax(outputs.logits, axis=-1)

    # Gather the log-probability of each target token directly instead of
    # materialising a (tokens x vocab) one-hot tensor.
    token_logp = jnp.take_along_axis(log_probs, target_ids[..., None], axis=-1)[..., 0]

    weights = target_mask.astype(token_logp.dtype)
    return jnp.sum(token_logp * weights) / jnp.maximum(jnp.sum(weights), 1)

def ipo_loss_jax(logp_win, logp_lose, tau=0.05, ref_logp_win=None, ref_logp_lose=None):
    """Compute the IPO loss ``mean((h - 1 / (2 * tau)) ** 2)``.

    ``h`` is the preference margin. Without reference log-probabilities it is
    ``logp_win - logp_lose`` (the original behaviour of this function). When
    both ``ref_logp_win`` and ``ref_logp_lose`` are given, ``h`` is the margin
    of the log-ratios against the reference policy:
    ``(logp_win - ref_logp_win) - (logp_lose - ref_logp_lose)``.

    Args:
        logp_win: Policy log-probabilities of the preferred completions.
        logp_lose: Policy log-probabilities of the dispreferred completions.
        tau: Regularisation strength; the regression target is ``0.5 / tau``.
        ref_logp_win: Optional reference-policy log-probabilities (preferred).
        ref_logp_lose: Optional reference-policy log-probabilities (dispreferred).

    Returns:
        Scalar mean squared deviation of the margin from ``0.5 / tau``.

    Raises:
        ValueError: If only one of the two reference arguments is supplied.
    """
    if (ref_logp_win is None) != (ref_logp_lose is None):
        raise ValueError("ref_logp_win and ref_logp_lose must be provided together")

    margin = logp_win - logp_lose
    if ref_logp_win is not None:
        margin = margin - (ref_logp_win - ref_logp_lose)

    return jnp.mean((margin - 0.5 / tau) ** 2)

def tokenize_batch(questions, preferred, dispreferred, max_prompt_len=128, max_target_len=64, padding=True):
    """Tokenize prompts and both completions of a preference batch.

    Args:
        questions: Question strings; each is wrapped in the query-generation
            prompt before tokenization.
        preferred: Preferred completions, aligned with ``questions``.
        dispreferred: Dispreferred completions, aligned with ``questions``.
        max_prompt_len: Truncation length for the prompts.
        max_target_len: Truncation length for both completion lists.
        padding: Padding strategy forwarded to the tokenizer. The default
            ``True`` pads each group to its longest member, so array shapes
            vary from batch to batch. Passing ``"max_length"`` gives fixed
            shapes, which avoids re-tracing a ``jax.jit``-compiled consumer
            whenever a new shape shows up.

    Returns:
        Dict of ``jnp`` arrays with the keys ``prompt_ids``, ``prompt_mask``,
        ``preferred_ids``, ``preferred_mask``, ``dispreferred_ids`` and
        ``dispreferred_mask``.
    """
    def _encode(texts, max_length):
        # One place for the tokenizer settings shared by all three groups.
        return tokenizer(
            texts,
            padding=padding,
            truncation=True,
            max_length=max_length,
            return_tensors="np"
        )

    prompts = [f"Generate a search query for: {q}\nQuery: " for q in questions]

    prompt_encodings = _encode(prompts, max_prompt_len)
    preferred_encodings = _encode(preferred, max_target_len)
    dispreferred_encodings = _encode(dispreferred, max_target_len)

    return {
        'prompt_ids': jnp.array(prompt_encodings['input_ids']),
        'prompt_mask': jnp.array(prompt_encodings['attention_mask']),
        'preferred_ids': jnp.array(preferred_encodings['input_ids']),
        'preferred_mask': jnp.array(preferred_encodings['attention_mask']),
        'dispreferred_ids': jnp.array(dispreferred_encodings['input_ids']),
        'dispreferred_mask': jnp.array(dispreferred_encodings['attention_mask'])
    }

@jax.jit
def train_step(state, batch, tau=0.05):
    """Run one jitted IPO update and return ``(new_state, loss)``.

    ``batch`` is read through the keys ``prompt_ids``, ``prompt_mask``,
    ``preferred_ids``, ``preferred_mask``, ``dispreferred_ids`` and
    ``dispreferred_mask``. The loss is differentiated with respect to
    ``state.params`` and the gradients are applied via
    ``state.apply_gradients``.
    """
    prompt_ids = batch['prompt_ids']
    prompt_mask = batch['prompt_mask']

    def objective(params):
        def example_logp(p_ids, p_mask, t_ids, t_mask):
            # vmap strips the leading axis; [None] restores a batch of one so
            # compute_logp_jax yields one value per example.
            return compute_logp_jax(params, p_ids[None], p_mask[None], t_ids[None], t_mask[None])

        batched_logp = jax.vmap(example_logp)

        logp_preferred = batched_logp(
            prompt_ids, prompt_mask, batch['preferred_ids'], batch['preferred_mask']
        )
        logp_dispreferred = batched_logp(
            prompt_ids, prompt_mask, batch['dispreferred_ids'], batch['dispreferred_mask']
        )
        return ipo_loss_jax(logp_preferred, logp_dispreferred, tau)

    loss, grads = jax.value_and_grad(objective)(state.params)
    return state.apply_gradients(grads=grads), loss

# Training setup
# Dataset and checkpoints live under this Google Drive folder; the same
# `parent_path` is reused later when saving per-epoch checkpoints.
parent_path = '/content/drive/MyDrive/c438_project'
dataset_path = f'{parent_path}/preference_dataset_hotpotqa_final.json'
# Reads and flattens the whole JSON file into memory.
dataset = PreferenceDataset(dataset_path)

# Training hyperparameters
batch_size = 4  # preference pairs drawn per training step
learning_rate = 5e-6
tau = 0.05  # IPO regularisation; the loss regresses the margin towards 0.5 / tau
num_epochs = 3

# Initialize optimizer and training state
optimizer = optax.adamw(learning_rate=learning_rate)
rng = random.PRNGKey(42)
# The train state starts from the weights currently held by `model`.
state = train_state.TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=optimizer
)

print(f"Dataset size: {len(dataset)}")
print(f"Batch size: {batch_size}")
# NOTE(review): this message is printed regardless of which device JAX is
# actually using.
print(f"Starting IPO training on TPU...")

Dataset size: 70073
Starting IPO training...


In [None]:
# Start the W&B run and record the hyperparameters and device count with it.
wandb.init(
    project="c438_project",
    name="ipo_training_tpu_run",
    config=dict(
        learning_rate=learning_rate,
        num_epochs=num_epochs,
        batch_size=batch_size,
        model_name=model_name,
        tau=tau,
        device="TPU",
        device_count=jax.device_count(),
    ),
)

# Training loop
# Every "epoch" performs len(dataset) // batch_size steps; each step draws an
# independent random batch, keyed by (epoch, batch index) for reproducibility.
num_batches_per_epoch = len(dataset) // batch_size
rng = random.PRNGKey(42)

for epoch in range(num_epochs):
    total_loss = 0.0
    completed_steps = 0  # steps that produced a usable (non-NaN) loss
    epoch_rng = random.fold_in(rng, epoch)

    pbar = tqdm(range(num_batches_per_epoch), desc=f"Epoch {epoch+1}/{num_epochs}")

    for batch_idx in pbar:
        # Get batch
        batch_rng = random.fold_in(epoch_rng, batch_idx)
        raw_batch = dataset.get_batch(batch_size, batch_rng)

        # Extract data from batch
        questions = [item['question'] for item in raw_batch]
        preferred = [item['preferred'].strip() for item in raw_batch]
        dispreferred = [item['dispreferred'].strip() for item in raw_batch]

        # Tokenize batch
        batch = tokenize_batch(questions, preferred, dispreferred)

        # Training step
        try:
            # Keep the result in a candidate variable: committing it before the
            # NaN check would let a "skipped" batch still overwrite the
            # parameters with NaN-contaminated values.
            candidate_state, loss = train_step(state, batch, tau)
            current_loss = float(loss)

            if np.isnan(current_loss):
                print(f"Skipping batch {batch_idx} due to NaN loss.")
                continue

            state = candidate_state
            completed_steps += 1

            total_loss += current_loss
            # Average over the steps that actually contributed a loss, so
            # skipped or failed batches do not drag the average down.
            avg_loss = total_loss / completed_steps

            pbar.set_postfix({"loss": f"{avg_loss:.4f}"})
            wandb.log({"step_loss": current_loss, "average_loss": avg_loss})

        except Exception as e:
            # Best effort: report the failing batch and keep training.
            print(f"Error in batch {batch_idx}: {e}")
            continue

    # max(..., 1) guards the case where every batch of the epoch was skipped.
    epoch_avg_loss = total_loss / max(completed_steps, 1)
    print(f"[Epoch {epoch + 1}] Average Loss: {epoch_avg_loss:.4f}")
    wandb.log({"epoch": epoch + 1, "epoch_average_loss": epoch_avg_loss})

    # Save model checkpoint
    epoch_save_path = f"{parent_path}/ipo_trained_model_tpu/epoch_{epoch+1}"
    os.makedirs(epoch_save_path, exist_ok=True)

    # Save Flax model
    model.save_pretrained(
        epoch_save_path,
        params=state.params,
        push_to_hub=False
    )
    tokenizer.save_pretrained(epoch_save_path)
    print(f"Epoch {epoch + 1} model saved to {epoch_save_path}")

print("Training completed on TPU!")
wandb.finish()


Epoch 1/3:  27%|██▋       | 19181/70073 [57:57<2:34:43,  5.48it/s, loss=91.4975]

In [None]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from transformers import T5Tokenizer, T5Config, FlaxT5ForConditionalGeneration
import numpy as np
import random

# Load T5 tokenizer and model configuration
# NOTE(review): this cell rebinds `tokenizer`, `model` and `state`, replacing
# the objects created by the training cells. It loads "t5-base", whereas
# training used "google/flan-t5-small" -- confirm which checkpoint is intended
# (presumably one of the per-epoch directories saved by the training loop).
tokenizer = T5Tokenizer.from_pretrained("t5-base")
config = T5Config.from_pretrained("t5-base")

# Initialize the model
# NOTE(review): the model is built from the config alone, not via
# from_pretrained -- presumably this leaves the weights untrained; confirm.
model = FlaxT5ForConditionalGeneration(config)

# Dummy state for demonstration (replace with actual trained model parameters)
class DummyState:
    # Placeholder standing in for the training state; only `.params` is read
    # by the generation code below.
    params = None  # This should be the actual model parameters

state = DummyState()

def _generate_query_impl(params, input_ids, attention_mask, rng_key,
                         max_length=30, temperature=0.8, top_p=0.9):
    """Sample one sequence per input row with the module-level ``model``."""
    return model.generate(
        input_ids,
        attention_mask=attention_mask,
        params=params,
        prng_key=rng_key,
        # T5 is an encoder-decoder model: the generation budget counts decoder
        # tokens only, so the encoder input length must not be added to it.
        max_length=max_length,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        num_beams=1
    ).sequences


# Generate query using JAX/Flax model.
# `max_length`, `temperature` and `top_p` configure the generation loop and
# the logits warpers, so they must be concrete Python values rather than
# traced arrays; marking them static makes jit re-compile per distinct value.
generate_query_jax = jax.jit(
    _generate_query_impl,
    static_argnames=("max_length", "temperature", "top_p"),
)


def generate_query(question, max_length=30, temperature=0.8, top_p=0.9, seed=42):
    """Generate a search query for ``question`` with the module-level model.

    Args:
        question: Natural-language question to turn into a search query.
        max_length: Maximum number of decoder tokens to generate.
        temperature: Sampling temperature forwarded to generation.
        top_p: Nucleus-sampling threshold forwarded to generation.
        seed: Seed of the sampling PRNG key; the same seed and inputs give the
            same sample.

    Returns:
        The decoded query with '?' and '.' removed, or the fallback string
        "search query" when generation fails or yields empty text.
    """
    # Same prompt format that tokenize_batch builds for training, so the model
    # is queried the way it was optimised.
    prompt = f"Generate a search query for: {question}\nQuery: "
    encoded = tokenizer(prompt, return_tensors="np", truncation=True, max_length=128, padding=True)

    input_ids = jnp.array(encoded['input_ids'])
    attention_mask = jnp.array(encoded['attention_mask'])

    # Qualified as jax.random: in this cell the bare name `random` is the
    # standard-library module, which has no PRNGKey.
    rng_key = jax.random.PRNGKey(seed)

    try:
        outputs = generate_query_jax(
            state.params,
            input_ids,
            attention_mask,
            rng_key,
            max_length=int(max_length),
            # Plain floats keep the static arguments hashable and typed as the
            # sampling warpers expect.
            temperature=float(temperature),
            top_p=float(top_p),
        )

        # Decode the generated sequence
        generated_ids = outputs[0]
        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

        # Clean up the generated text by removing the input prompt
        if prompt in generated_text:
            generated_text = generated_text.replace(prompt, "").strip()

        # Remove question marks and periods
        generated_text = generated_text.replace('?', '').replace('.', '')

        return generated_text if generated_text else "search query"

    except Exception as e:
        # Best effort: report the failure and hand back the fallback query.
        print(f"Generation error: {e}")
        return "search query"

# Test with example questions
test_questions = [
    "Who was the first person to climb Mount Everest?",
    "What is the capital of the country where the Eiffel Tower is located?",
    "Which movie won the Academy Award for Best Picture in 2020?",
    "What is the largest planet in our solar system?",
    "Who wrote the novel '1984'?"
]

# Sampling settings tried for every question (defined once, outside the loop).
configs = [
    {"temperature": 0.5, "top_p": 0.8},
    {"temperature": 0.8, "top_p": 0.9},
    {"temperature": 1.0, "top_p": 0.95}
]

print("Testing the TPU-trained model on example questions:")
print("=" * 60)

for i, question in enumerate(test_questions, 1):
    print(f"\n{i}. Question: {question}")

    # Generate multiple queries with different parameters.
    # The loop variable is deliberately not called `config`: that name holds
    # the model's T5Config at module level and must not be overwritten.
    for j, sampling_kwargs in enumerate(configs):
        try:
            query = generate_query(question, **sampling_kwargs)
            print(f"   Query {j+1}: {query}")
        except Exception as e:
            print(f"   Query {j+1}: Error generating query - {e}")

    print("-" * 40)

print("\nTesting completed on TPU!")