# Risk-Averse Reward Model Training

This notebook implements the complete risk aversion experiment for training **Qwen3-8B** to prefer risk-averse choices over risk-neutral ones.

## Features
- **CSV Data Loading**: Loads scenarios from `11_7_low_stakes_training_set.csv`
- **CARA vs LINEAR Utility**: Trains on CARA (risk-averse) vs LINEAR (risk-neutral) best options


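**Memory Optimized for 8B Model:** The backbone is loaded in fp16, frozen, and run under `torch.no_grad()`; only the small scalar reward head is trained (in fp32), using a small per-step batch with gradient accumulation.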


## 1. Install Dependencies

In [None]:
# Install required packages
# NOTE: the leading `!` is IPython/Jupyter shell-escape syntax; this cell only runs inside a notebook.
!pip install -q transformers datasets accelerate torch pandas numpy scikit-learn matplotlib seaborn

# NOTE(review): this message is printed unconditionally — `!` commands do not raise on a failed install.
print("✓ Dependencies installed successfully!")

## 2. Import Libraries

In [None]:
import os
import sys
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    get_linear_schedule_with_warmup
)
from sklearn.model_selection import train_test_split
import json
from typing import List, Dict, Tuple
import warnings
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

# Suppress every Python warning for the rest of the session (this also hides deprecation notices).
warnings.filterwarnings('ignore')

# Set plotting style
# Reset matplotlib to its default style first, then choose the seaborn palette; keep this
# order — style.use resets rcParams (presumably including the colour cycle) and would
# otherwise undo the palette.
plt.style.use('default')
sns.set_palette("husl")

# Reproducibility function (will be called after config is loaded)
def set_seed(seed):
    """Seed every random number generator the notebook may touch.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and, when
    available, all CUDA devices), and switches cuDNN to deterministic kernels.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random  # local import keeps this helper self-contained

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Trade cuDNN autotuning speed for run-to-run determinism.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

# Report the runtime environment so a notebook log shows what hardware the run used.
print("Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
_cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {_cuda_ok}")
if _cuda_ok:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 3. Configuration

All configurable parameters for the experiment.

In [None]:
# =============================================================================
# EXPERIMENT CONFIGURATION
# Every tunable knob for the run; later cells read these module-level globals.
# =============================================================================

# --- Backbone / tokenization ---
MODEL_NAME = "Qwen/Qwen3-8B"       # Hugging Face id of the (frozen) base model
MAX_LENGTH = 256                   # tokens per encoded "prompt + chosen option" text

# --- Optimisation ---
BATCH_SIZE = 2                     # examples per forward pass
GRADIENT_ACCUMULATION_STEPS = 4    # micro-batches summed before each optimizer step
LEARNING_RATE = 5e-4               # AdamW learning rate for the reward head
WEIGHT_DECAY = 0.01                # AdamW weight decay on the reward head
NUM_EPOCHS = 5                     # upper bound on training epochs

# --- Stability ---
REWARD_CLIP_VALUE = 10.0           # bound applied to (r_preferred - r_rejected) inside the loss
EARLY_STOPPING_PATIENCE = 2        # epochs without validation-accuracy gain before stopping

# --- Data ---
DATA_FILE = "11_7_low_stakes_training_set.csv"  # per-option scenario CSV
TEST_SPLIT = 0.2                   # fraction of situations held out for validation
RANDOM_SEED = 42                   # shared seed for RNGs and the train/val split

# Seed all generators before any data shuffling or weight initialisation happens.
set_seed(RANDOM_SEED)

_config_summary = [
    "Configuration loaded:",
    f"  Model: {MODEL_NAME}",
    f"  Batch size: {BATCH_SIZE} (effective: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS})",
    f"  Learning rate: {LEARNING_RATE}",
    f"  Epochs: {NUM_EPOCHS}",
    f"  Early stopping patience: {EARLY_STOPPING_PATIENCE}",
    f"  Random seed: {RANDOM_SEED}",
]
print("\n".join(_config_summary))

## 4. Data Loading Class

Loads risk scenarios from the CSV file, groups the per-option rows by `situation_id`, and extracts for each situation the risk-averse (CARA-best) and risk-neutral (LINEAR-best) option labels.

In [None]:
class RiskAversionDataLoader:
    """Load and process data from CSV file for risk aversion training

    Collapses the one-row-per-option CSV into one row per situation holding the
    risk-averse (CARA-best) label as ``correct_label`` and the risk-neutral
    (LINEAR-best) label as ``incorrect_label``.
    """

    # Columns of the returned DataFrame; fixed so that an empty result still has a schema.
    _OUTPUT_COLUMNS = ['situation_id', 'prompt_text', 'correct_label', 'incorrect_label', 'num_options']

    def __init__(self, csv_file_path="11_7_low_stakes_training_set.csv"):
        self.csv_file_path = csv_file_path

    def load_and_process_data(self) -> pd.DataFrame:
        """Load CSV data and process it for training

        CSV format (11_7_low_stakes_training_set.csv):
        - Multiple rows per situation (one per option)
        - is_best_cara = True marks risk-averse option
        - is_best_linear = True marks risk-neutral option
        - correct_label/incorrect_label columns (optional) contain the actual labels;
          otherwise labels are 1-based option numbers derived from option_index

        Situations are skipped when either best option is missing, or when the
        risk-averse and risk-neutral labels coincide (such a pair has identical
        preferred/rejected text and carries no preference signal).

        Returns:
            DataFrame with columns situation_id, prompt_text, correct_label,
            incorrect_label, num_options (one row per kept situation).

        Raises:
            FileNotFoundError: if the CSV file does not exist.
            ValueError: if required columns are missing.
        """
        # Check if CSV file exists
        if not os.path.exists(self.csv_file_path):
            raise FileNotFoundError(
                f"Required data file '{self.csv_file_path}' not found. "
                f"Please ensure the CSV file is uploaded to Colab."
            )

        # Load the CSV file
        df = pd.read_csv(self.csv_file_path)
        print(f"Loaded {len(df)} rows from {self.csv_file_path}")

        # Check required columns exist
        required_columns = ['situation_id', 'prompt_text', 'option_index', 'is_best_cara', 'is_best_linear']
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            raise ValueError(
                f"Missing required columns in CSV: {missing_columns}. "
                f"Available columns: {list(df.columns)}"
            )

        # Check if pre-computed labels exist
        has_precomputed_labels = 'correct_label' in df.columns and 'incorrect_label' in df.columns

        situations = []
        situations_skipped = 0
        situations_without_contrast = 0

        for situation_id, group in df.groupby('situation_id'):
            # Element-wise comparison; also matches 1/1.0 flags and tolerates NaN.
            cara_rows = group[group['is_best_cara'] == True]  # noqa: E712
            linear_rows = group[group['is_best_linear'] == True]  # noqa: E712
            if cara_rows.empty or linear_rows.empty:
                situations_skipped += 1
                continue
            cara_option = cara_rows.iloc[0]
            linear_option = linear_rows.iloc[0]

            # Get prompt text from first row (same for all options)
            prompt_text = group.iloc[0]['prompt_text']

            # Use pre-computed labels only when BOTH are present; a missing
            # incorrect_label would otherwise become the string "nan".
            if (
                has_precomputed_labels
                and pd.notna(cara_option['correct_label'])
                and pd.notna(cara_option['incorrect_label'])
            ):
                correct_label = str(cara_option['correct_label'])
                incorrect_label = str(cara_option['incorrect_label'])
            else:
                # Fallback: convert 0-indexed option_index to 1-indexed option numbers.
                # int() avoids labels like "2.0" when the column was parsed as float.
                correct_label = str(int(cara_option['option_index']) + 1)
                incorrect_label = str(int(linear_option['option_index']) + 1)

            # Identical labels -> identical preferred/rejected texts: no training signal.
            if correct_label == incorrect_label:
                situations_without_contrast += 1
                continue

            situations.append({
                'situation_id': situation_id,
                'prompt_text': prompt_text,
                'correct_label': correct_label,  # Risk-averse option
                'incorrect_label': incorrect_label,  # Risk-neutral option
                'num_options': len(group)
            })

        if situations_skipped > 0:
            print(f"Warning: Skipped {situations_skipped} situations missing CARA or LINEAR best options")
        if situations_without_contrast > 0:
            print(
                f"Warning: Skipped {situations_without_contrast} situations where the CARA-best and "
                f"LINEAR-best options coincide (no preference signal)"
            )

        result_df = pd.DataFrame(situations, columns=self._OUTPUT_COLUMNS)
        print(f"Processed into {len(result_df)} unique situations")

        # Display sample data
        if len(result_df) > 0:
            sample = result_df.iloc[0]
            print(f"\nSample situation:")
            print(f"Prompt: {sample['prompt_text'][:200]}...")
            print(f"Risk-averse choice (CARA best): Option {sample['correct_label']}")
            print(f"Risk-neutral choice (LINEAR best): Option {sample['incorrect_label']}")
            print(f"Number of options in this situation: {sample['num_options']}")

        return result_df

print("RiskAversionDataLoader defined")

## 5. Load and Validate Data

Load the training data and split into training and validation sets.

In [None]:
# Load data using config
loader = RiskAversionDataLoader(DATA_FILE)
full_dataset = loader.load_and_process_data()

# Split into train and validation sets
train_df, val_df = train_test_split(
    full_dataset, 
    test_size=TEST_SPLIT, 
    random_state=RANDOM_SEED
)

print(f"\n{'='*60}")
print(f"Dataset Split:")
print(f"  Total situations: {len(full_dataset)}")
print(f"  Training situations: {len(train_df)} ({100*(1-TEST_SPLIT):.0f}%)")
print(f"  Validation situations: {len(val_df)} ({100*TEST_SPLIT:.0f}%)")
print(f"{'='*60}")

# Validate data format
assert 'prompt_text' in train_df.columns, "Missing prompt_text column"
assert 'correct_label' in train_df.columns, "Missing correct_label column"
assert 'incorrect_label' in train_df.columns, "Missing incorrect_label column"

print("\nData validation passed!")
print("Train/validation split complete!")

## 6. Pairwise Dataset for Reward Modeling

Dataset class that provides pairs of (preferred, rejected) options for Bradley-Terry loss.

In [None]:
class PairwiseRewardDataset(Dataset):
    """Dataset for pairwise reward model training with Bradley-Terry loss

    Each item holds the encoded "prompt + preferred option" and "prompt + rejected
    option" texts, both padded/truncated to exactly ``max_length`` tokens.
    """

    def __init__(self, dataframe: pd.DataFrame, tokenizer, max_length: int = 256):
        """
        Args:
            dataframe: DataFrame with columns: prompt_text, correct_label, incorrect_label
            tokenizer: Tokenizer for encoding text (must define pad_token_id and padding_side)
            max_length: Maximum sequence length
        """
        self.data = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def _encode(self, text: str):
        """Encode ``text`` into fixed-length 1-D ``(input_ids, attention_mask)`` tensors.

        Over-long texts are truncated from the LEFT: the trailing
        "Chosen option: <label>" is the only part that differs between the two texts
        of a pair, so right-truncation would make them identical. Padding is placed
        according to ``tokenizer.padding_side``.

        Raises:
            ValueError: if padding is needed but the tokenizer has no pad_token_id.
        """
        token_ids = list(self.tokenizer(text)['input_ids'])
        if len(token_ids) > self.max_length:
            # Keep the tail of the sequence (slice by index so max_length == 0 is handled).
            token_ids = token_ids[len(token_ids) - self.max_length:]

        num_pad = self.max_length - len(token_ids)
        pad_id = self.tokenizer.pad_token_id
        if num_pad > 0 and pad_id is None:
            raise ValueError("Tokenizer has no pad_token_id; set tokenizer.pad_token before building the dataset")

        attention = [1] * len(token_ids)
        if getattr(self.tokenizer, 'padding_side', 'right') == 'left':
            token_ids = [pad_id] * num_pad + token_ids
            attention = [0] * num_pad + attention
        else:
            token_ids = token_ids + [pad_id] * num_pad
            attention = attention + [0] * num_pad

        return (
            torch.tensor(token_ids, dtype=torch.long),
            torch.tensor(attention, dtype=torch.long),
        )

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        # Format: prompt + chosen option
        preferred_text = f"{row['prompt_text']}\n\nChosen option: {row['correct_label']}"
        rejected_text = f"{row['prompt_text']}\n\nChosen option: {row['incorrect_label']}"

        preferred_ids, preferred_mask = self._encode(preferred_text)
        rejected_ids, rejected_mask = self._encode(rejected_text)

        return {
            'preferred_input_ids': preferred_ids,
            'preferred_attention_mask': preferred_mask,
            'rejected_input_ids': rejected_ids,
            'rejected_attention_mask': rejected_mask,
        }

print("✓ PairwiseRewardDataset defined")

## 7. Reward Model Architecture

Qwen3-8B base model with frozen weights + trainable scalar reward head.

In [None]:
from transformers import AutoModel

class RewardModel(nn.Module):
    """Reward model with frozen backbone and trainable scalar reward head

    The backbone is loaded in fp16 and never receives gradients. The reward is a
    single ``nn.Linear(hidden_size, 1)`` applied (in fp32) to the hidden state of
    the last attended token of each sequence; that position is located from the
    attention mask, so both left- and right-padded batches are handled.
    """

    def __init__(self, model_name: str = "Qwen/Qwen3-8B"):
        super().__init__()

        # Load base model
        print(f"Loading base model: {model_name}")
        self.backbone = AutoModel.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )

        # Freeze all backbone parameters
        for param in self.backbone.parameters():
            param.requires_grad = False

        print(f"Backbone loaded and frozen ({sum(p.numel() for p in self.backbone.parameters())/1e9:.2f}B params)")

        # Trainable scalar reward head - replaces LM head
        # Takes final hidden state h_T and computes: reward = W^T * h_T + b
        hidden_size = self.backbone.config.hidden_size
        self.reward_head = nn.Linear(hidden_size, 1, bias=True)

        # Xavier initialization for better gradient flow
        nn.init.xavier_uniform_(self.reward_head.weight)
        nn.init.zeros_(self.reward_head.bias)

        print(f"Reward head initialized ({sum(p.numel() for p in self.reward_head.parameters())} trainable params)")
        print(f"  Architecture: Linear({hidden_size} -> 1) with bias")

    def train(self, mode: bool = True):
        """Set train/eval mode on the reward head while keeping the frozen backbone in eval mode.

        The backbone is never updated, so its features should not change between
        training and evaluation (e.g. any dropout it has stays disabled).

        Returns:
            self, as ``nn.Module.train`` does.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, input_ids, attention_mask):
        """
        Forward pass to compute scalar reward from final hidden state

        Args:
            input_ids: Input token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len] (1 = real token)

        Returns:
            rewards: Scalar reward scores [batch_size] (fp32)
        """
        # Get hidden states from frozen backbone (in fp16)
        with torch.no_grad():
            outputs = self.backbone(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True
            )

        hidden_states = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]
        hs_device = hidden_states.device

        # Index of the LAST attended position per row: the largest position whose mask is
        # non-zero. Unlike `mask.sum() - 1`, this is correct for left padding too.
        mask = attention_mask.to(hs_device)
        positions = torch.arange(mask.shape[1], device=hs_device)
        last_positions = (positions * mask.ne(0).long()).amax(dim=1)

        batch_index = torch.arange(hidden_states.shape[0], device=hs_device)
        last_hidden_states = hidden_states[batch_index, last_positions]  # [batch_size, hidden_size]

        # Detach, move next to the reward head (the backbone may be sharded across devices
        # by device_map="auto") and convert to fp32 for stable training.
        last_hidden_states = last_hidden_states.detach().to(
            device=self.reward_head.weight.device, dtype=torch.float32
        )

        # Compute scalar reward: r = W^T * h_T + b
        # Shape: [batch_size, 1] -> [batch_size]
        rewards = self.reward_head(last_hidden_states).squeeze(-1)

        return rewards

print("RewardModel class defined")

## 8. Training Setup

Initialize model, tokenizer, datasets, loss function, and optimizer.

In [None]:
import math

# Load tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Pad on the RIGHT: real tokens then occupy positions 0..n-1 of every row, so the last
# real token sits at index (number of attended tokens - 1), which last-token pooling
# relies on. NOTE(review): with left padding, default position ids are presumably also
# shifted by the pad count, making features depend on padding — another reason for 'right'.
tokenizer.padding_side = 'right'
# Truncate from the LEFT: every text ends with "Chosen option: <label>", so cutting the
# right end of an over-long text would drop the label and make preferred/rejected identical.
tokenizer.truncation_side = 'left'

# Create datasets
print("Creating datasets...")
train_dataset = PairwiseRewardDataset(train_df, tokenizer, max_length=MAX_LENGTH)
val_dataset = PairwiseRewardDataset(val_df, tokenizer, max_length=MAX_LENGTH)

print(f"  Training examples: {len(train_dataset)}")
print(f"  Validation examples: {len(val_dataset)}")

# Initialize model
print("\nInitializing reward model...")
model = RewardModel(MODEL_NAME)

# Move reward head to correct device (keep in fp32 for stability)
device = next(model.backbone.parameters()).device
model.reward_head = model.reward_head.to(device)

print(f"  Device: {device}")
print(f"  Backbone dtype: {next(model.backbone.parameters()).dtype}")
print(f"  Reward head dtype: {next(model.reward_head.parameters()).dtype}")

# Bradley-Terry loss function with clipping for stability
def bradley_terry_loss(preferred_rewards, rejected_rewards, clip_value=REWARD_CLIP_VALUE):
    """
    Bradley-Terry pairwise ranking loss with margin clipping
    Loss = -log(sigmoid(r_preferred - r_rejected))

    Clipping prevents extreme logits that can cause numerical instability.
    Encourages: r_preferred > r_rejected

    Note: torch.clamp passes no gradient for margins outside [-clip_value, clip_value],
    so a pair ranked wrong by more than clip_value contributes no learning signal.
    """
    margin = preferred_rewards - rejected_rewards
    margin = torch.clamp(margin, -clip_value, clip_value)  # Clip for stability
    return -F.logsigmoid(margin).mean()  # More numerically stable than log(sigmoid())

# Optimizer with L2 weight decay on reward head only
optimizer = torch.optim.AdamW(
    model.reward_head.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY
)

# Learning rate scheduler with warmup
train_dataloader_for_scheduler = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# The training loop also steps the optimizer on a trailing partial accumulation window,
# so round UP; flooring would make the schedule reach zero LR before training ends.
total_steps = math.ceil(len(train_dataloader_for_scheduler) / GRADIENT_ACCUMULATION_STEPS) * NUM_EPOCHS
warmup_steps = total_steps // 10

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

print(f"\n{'='*60}")
print("Training Configuration:")
print(f"  Model: {MODEL_NAME}")
print(f"  Batch size: {BATCH_SIZE} (effective: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS})")
print(f"  Gradient accumulation steps: {GRADIENT_ACCUMULATION_STEPS}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Weight decay (L2): {WEIGHT_DECAY}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Max sequence length: {MAX_LENGTH}")
print(f"  Reward margin clipping: +/- {REWARD_CLIP_VALUE}")
print(f"  LR scheduler: linear warmup ({warmup_steps} steps) + decay")
print(f"  Early stopping patience: {EARLY_STOPPING_PATIENCE} epochs")
print(f"  Total training steps: {total_steps}")
print(f"{'='*60}")
print("\nTraining setup complete!")

## 9. Evaluation Function

Compute pairwise accuracy: percentage of pairs where preferred option scores higher.

In [None]:
def evaluate_model(model, dataset, batch_size=None):
    """
    Score every (preferred, rejected) pair in ``dataset`` and report pairwise accuracy.

    The model is switched to eval mode and run without gradient tracking; inputs are
    moved to the notebook-level ``device``.

    Args:
        model: RewardModel to evaluate
        dataset: PairwiseRewardDataset
        batch_size: Evaluation batch size; ``None`` means ``BATCH_SIZE * 2``

    Returns:
        (accuracy, avg_loss, preferred_scores, rejected_scores) where accuracy is the
        fraction of pairs with preferred strictly above rejected, avg_loss is the
        per-example mean Bradley-Terry loss, and the score lists hold one value per pair.
    """
    # No gradients are kept at eval time, so a larger batch than training fits.
    eval_batch_size = BATCH_SIZE * 2 if batch_size is None else batch_size

    model.eval()
    loader = DataLoader(dataset, batch_size=eval_batch_size, shuffle=False)

    n_correct = 0
    n_seen = 0
    loss_sum = 0.0
    pref_scores = []
    rej_scores = []

    with torch.no_grad():
        for batch in loader:
            pref = model(
                input_ids=batch['preferred_input_ids'].to(device),
                attention_mask=batch['preferred_attention_mask'].to(device)
            )
            rej = model(
                input_ids=batch['rejected_input_ids'].to(device),
                attention_mask=batch['rejected_attention_mask'].to(device)
            )
            n_in_batch = len(pref)

            # Weight each batch's mean loss by its size to get a per-example average.
            loss_sum += bradley_terry_loss(pref, rej).item() * n_in_batch
            n_correct += (pref > rej).sum().item()
            n_seen += n_in_batch

            pref_scores.extend(pref.cpu().float().numpy())
            rej_scores.extend(rej.cpu().float().numpy())

    if n_seen == 0:
        return 0.0, 0.0, pref_scores, rej_scores
    return n_correct / n_seen, loss_sum / n_seen, pref_scores, rej_scores

print("Evaluation function defined")

## 10. Training Loop

Train the model with logging, per-epoch validation, checkpointing of the best epoch, and early stopping.

In [None]:
import math

# Create output directory for checkpoints
os.makedirs("outputs", exist_ok=True)

# Training history (dumped to outputs/training_history.json at the end of this cell)
history = {
    'train_loss': [],
    'train_steps': [],
    'val_accuracy': [],
    'val_loss': [],
    'epochs': [],
    'reward_margins': [],
    'preferred_rewards': [],
    'rejected_rewards': [],
    'learning_rates': [],
}

# Create dataloader (pinned host memory only helps when copying to a GPU)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=torch.cuda.is_available(),
)

# A trailing partial accumulation window also gets an optimizer step (end of the epoch
# loop below), so the per-epoch optimizer step count rounds up.
optimizer_steps_per_epoch = math.ceil(len(train_dataloader) / GRADIENT_ACCUMULATION_STEPS)

print("Starting training...")
print(f"Total batches per epoch: {len(train_dataloader)}")
print(f"Gradient accumulation steps: {GRADIENT_ACCUMULATION_STEPS}")
print(f"Optimizer steps per epoch: {optimizer_steps_per_epoch}")
print(f"Total optimizer steps: {total_steps}\n")

best_val_accuracy = 0.0
best_epoch = None
# In-memory copy of the reward-head weights from the best validation epoch; restored after
# training so later cells evaluate the best model rather than the last epoch's weights.
best_reward_head_state = None
global_step = 0
optimizer_step = 0
epochs_without_improvement = 0

# Store initial weights for comparison
initial_weight = model.reward_head.weight.clone().detach()
initial_bias = model.reward_head.bias.clone().detach()

# Zero gradients at start
optimizer.zero_grad()

for epoch in range(NUM_EPOCHS):
    print(f"{'='*60}")
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")
    print(f"{'='*60}")

    model.train()
    epoch_loss = 0.0
    epoch_preferred_rewards = []
    epoch_rejected_rewards = []
    epoch_margins = []
    accumulated_loss = 0.0

    for step, batch in enumerate(train_dataloader):
        # Forward pass for preferred options
        preferred_rewards = model(
            input_ids=batch['preferred_input_ids'].to(device),
            attention_mask=batch['preferred_attention_mask'].to(device)
        )

        # Forward pass for rejected options
        rejected_rewards = model(
            input_ids=batch['rejected_input_ids'].to(device),
            attention_mask=batch['rejected_attention_mask'].to(device)
        )
        reward_margin = (preferred_rewards - rejected_rewards).detach()

        # Compute Bradley-Terry loss (scaled for gradient accumulation)
        loss = bradley_terry_loss(preferred_rewards, rejected_rewards) / GRADIENT_ACCUMULATION_STEPS

        # Check for NaN loss: drop this batch and the gradients accumulated so far in this window
        if torch.isnan(loss):
            print(f"  WARNING: NaN loss detected at step {step + 1}")
            print(f"    Preferred rewards: {preferred_rewards}")
            print(f"    Rejected rewards: {rejected_rewards}")
            optimizer.zero_grad()
            continue

        # Track reward statistics (after the NaN check, so one bad batch cannot turn the
        # epoch-level means into NaN)
        epoch_preferred_rewards.extend(preferred_rewards.detach().cpu().tolist())
        epoch_rejected_rewards.extend(rejected_rewards.detach().cpu().tolist())
        epoch_margins.extend(reward_margin.cpu().tolist())

        # Backward pass (accumulate gradients)
        loss.backward()
        unscaled_loss = loss.item() * GRADIENT_ACCUMULATION_STEPS
        accumulated_loss += unscaled_loss
        epoch_loss += unscaled_loss

        at_accumulation_boundary = (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0

        # Detailed diagnostics on first few optimizer steps
        if optimizer_step < 3 and at_accumulation_boundary:
            print(f"\n  === Diagnostics for Optimizer Step {optimizer_step + 1} ===")

            # Reward statistics (for the last micro-batch of this window)
            print(f"  Rewards:")
            print(f"    Preferred: mean={preferred_rewards.mean().item():.4f}, std={preferred_rewards.std().item():.4f}")
            print(f"    Rejected:  mean={rejected_rewards.mean().item():.4f}, std={rejected_rewards.std().item():.4f}")
            print(f"    Margin:    mean={reward_margin.mean().item():.4f}, std={reward_margin.std().item():.4f}")
            # accumulated_loss sums the per-batch losses of this window; report their mean
            print(f"    Loss:      {accumulated_loss / GRADIENT_ACCUMULATION_STEPS:.4f}")

            # Gradient statistics
            print(f"  Gradients:")
            for name, param in model.reward_head.named_parameters():
                if param.grad is not None:
                    grad_norm = param.grad.norm().item()
                    grad_mean = param.grad.mean().item()
                    grad_max = param.grad.abs().max().item()
                    print(f"    {name}: norm={grad_norm:.6f}, mean={grad_mean:.6f}, max={grad_max:.6f}")
                else:
                    print(f"    {name}: NO GRADIENT!")

            # Parameter statistics
            print(f"  Parameters:")
            weight_change = (model.reward_head.weight - initial_weight).abs().max().item()
            bias_change = (model.reward_head.bias - initial_bias).abs().max().item()
            print(f"    Weight max change: {weight_change:.6f}")
            print(f"    Bias max change: {bias_change:.6f}")
            print(f"    Current LR: {scheduler.get_last_lr()[0]:.6f}")
            print()

        # Optimizer step after accumulation
        if at_accumulation_boundary:
            # Gradient clipping
            grad_norm = torch.nn.utils.clip_grad_norm_(model.reward_head.parameters(), max_norm=1.0)

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            optimizer_step += 1
            global_step += 1

            # Log every 50 optimizer steps
            if optimizer_step % 50 == 0:
                avg_loss = epoch_loss / (step + 1)
                weight_change = (model.reward_head.weight - initial_weight).abs().max().item()
                recent_margin = np.mean(epoch_margins[-100:]) if len(epoch_margins) >= 100 else np.mean(epoch_margins)
                current_lr = scheduler.get_last_lr()[0]

                print(f"  Step {optimizer_step} | "
                      f"Loss: {avg_loss:.4f} | "
                      f"Margin: {recent_margin:+.4f} | "
                      f"LR: {current_lr:.2e} | "
                      f"Weight: {weight_change:.6f}")

                history['train_loss'].append(avg_loss)
                history['train_steps'].append(global_step)
                history['learning_rates'].append(current_lr)

            accumulated_loss = 0.0

    # Handle remaining gradients if dataset doesn't divide evenly
    remaining_steps = len(train_dataloader) % GRADIENT_ACCUMULATION_STEPS
    if remaining_steps > 0:
        # Each micro-batch loss was divided by GRADIENT_ACCUMULATION_STEPS, but only
        # `remaining_steps` batches were accumulated; rescale so this step uses the mean
        # gradient like every full window does.
        for param in model.reward_head.parameters():
            if param.grad is not None:
                param.grad.mul_(GRADIENT_ACCUMULATION_STEPS / remaining_steps)
        grad_norm = torch.nn.utils.clip_grad_norm_(model.reward_head.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        optimizer_step += 1
        global_step += 1

    # End of epoch statistics
    avg_train_loss = epoch_loss / len(train_dataloader)

    # Weight changes
    total_weight_change = (model.reward_head.weight - initial_weight).abs().mean().item()
    total_bias_change = (model.reward_head.bias - initial_bias).abs().mean().item()

    # Reward statistics
    mean_preferred = np.mean(epoch_preferred_rewards)
    mean_rejected = np.mean(epoch_rejected_rewards)
    mean_margin = np.mean(epoch_margins)
    std_margin = np.std(epoch_margins)
    percent_correct = np.mean([m > 0 for m in epoch_margins]) * 100

    print(f"\n  Epoch {epoch + 1} Summary:")
    print(f"    Loss: {avg_train_loss:.4f}")
    print(f"    Rewards: preferred={mean_preferred:+.4f}, rejected={mean_rejected:+.4f}")
    print(f"    Margin: mean={mean_margin:+.4f}, std={std_margin:.4f}")
    print(f"    Correct ranking: {percent_correct:.1f}% (preferred > rejected)")
    print(f"    Weight changes: weight={total_weight_change:.6f}, bias={total_bias_change:.6f}")
    print(f"    Learning rate: {scheduler.get_last_lr()[0]:.2e}")

    # Store for history
    history['reward_margins'].append(mean_margin)
    history['preferred_rewards'].append(mean_preferred)
    history['rejected_rewards'].append(mean_rejected)

    # Validation
    print(f"\n  Running validation...")
    val_accuracy, val_loss, val_pref_scores, val_rej_scores = evaluate_model(model, val_dataset)

    val_margin = np.mean(val_pref_scores) - np.mean(val_rej_scores)

    print(f"    Validation accuracy: {val_accuracy:.4f} ({val_accuracy*100:.2f}%)")
    print(f"    Validation loss: {val_loss:.4f}")
    print(f"    Validation margin: {val_margin:+.4f}")

    # Save to history
    history['val_accuracy'].append(val_accuracy)
    history['val_loss'].append(val_loss)
    history['epochs'].append(epoch + 1)

    # Save checkpoint if best model
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        best_epoch = epoch + 1
        epochs_without_improvement = 0
        # state_dict tensors share storage with the live parameters, so clone them.
        best_reward_head_state = {
            name: tensor.detach().clone()
            for name, tensor in model.reward_head.state_dict().items()
        }
        checkpoint_path = f"outputs/best_model_epoch{epoch+1}.pt"
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.reward_head.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_accuracy': val_accuracy,
            'val_loss': val_loss,
        }, checkpoint_path)
        print(f"    New best! Saved checkpoint: {checkpoint_path}")
    else:
        epochs_without_improvement += 1
        print(f"    No improvement for {epochs_without_improvement} epoch(s)")

        # Early stopping check
        if epochs_without_improvement >= EARLY_STOPPING_PATIENCE:
            print(f"\n  Early stopping triggered: no improvement for {EARLY_STOPPING_PATIENCE} epochs")
            break

    # Clear GPU cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print()

print(f"\n{'='*60}")
print("Training Complete!")
print(f"{'='*60}")
print(f"Best validation accuracy: {best_val_accuracy:.4f} ({best_val_accuracy*100:.2f}%)")
print(f"Total optimizer steps: {optimizer_step}")
print(f"Final model saved to: outputs/best_model_epoch*.pt")

# Restore the best weights so subsequent evaluation/visualisation uses the checkpointed model
if best_reward_head_state is not None:
    model.reward_head.load_state_dict(best_reward_head_state)
    print(f"Restored reward head weights from best epoch: {best_epoch}")

# Save training history
history_path = "outputs/training_history.json"
with open(history_path, 'w') as f:
    json.dump(history, f, indent=2)
print(f"Training history saved to: {history_path}")

## 11. Visualization and Results

Plot training curves and analyze model performance.

In [None]:
# Evaluate final model on validation set
print("Evaluating final model on validation set...")
final_accuracy, final_loss, preferred_scores, rejected_scores = evaluate_model(
    model, val_dataset
)

print(f"\nFinal Validation Results:")
print(f"  Accuracy: {final_accuracy:.4f} ({final_accuracy*100:.2f}%)")
print(f"  Loss: {final_loss:.4f}")
print(f"  Mean preferred score: {np.mean(preferred_scores):.4f}")
print(f"  Mean rejected score: {np.mean(rejected_scores):.4f}")
print(f"  Score difference: {np.mean(preferred_scores) - np.mean(rejected_scores):.4f}")

# Calculate reward margins
reward_margins = np.array(preferred_scores) - np.array(rejected_scores)

# Create comprehensive visualizations
fig = plt.figure(figsize=(18, 12))
gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

fig.suptitle('Reward Model Training Results - Risk Aversion Analysis', 
             fontsize=16, fontweight='bold', y=0.995)

# Plot 1: Training Loss Over Time
ax1 = fig.add_subplot(gs[0, 0])
if len(history['train_steps']) > 0:
    ax1.plot(history['train_steps'], history['train_loss'], 'b-', linewidth=2, label='Training Loss')
    ax1.set_xlabel('Training Steps')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training Loss Over Time')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

# Plot 2: Validation Accuracy Over Epochs
ax2 = fig.add_subplot(gs[0, 1])
if len(history['epochs']) > 0:
    ax2.plot(history['epochs'], history['val_accuracy'], 'g-o', linewidth=2, markersize=8)
    ax2.axhline(y=0.5, color='r', linestyle='--', alpha=0.5, label='Random (50%)')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.set_title('Validation Pairwise Accuracy')
    ax2.set_ylim([0, 1])
    ax2.legend()
    ax2.grid(True, alpha=0.3)

# Plot 3: Reward Margin Progression
ax3 = fig.add_subplot(gs[0, 2])
if len(history['epochs']) > 0 and len(history['reward_margins']) > 0:
    ax3.plot(history['epochs'], history['reward_margins'], 'purple', linewidth=2, marker='s', markersize=8)
    ax3.axhline(y=0, color='r', linestyle='--', alpha=0.5, label='No Preference')
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('Mean Reward Margin')
    ax3.set_title('Risk-Averse Preference Strength\n(Preferred - Rejected)')
    ax3.legend()
    ax3.grid(True, alpha=0.3)

# Plot 4: Score Distribution Comparison (Histogram)
ax4 = fig.add_subplot(gs[1, 0])
# Shared score range for plots 4 and 5. If every score is identical (e.g. a
# collapsed model, or all scores pinned at a clip value) the range has zero
# width: np.linspace would return 30 equal edges (zero-width bins, so NaN/inf
# densities) and plot 5's diagonal would collapse to a point. Pad it instead.
min_val = min(min(preferred_scores), min(rejected_scores))
max_val = max(max(preferred_scores), max(rejected_scores))
if max_val == min_val:
    min_val, max_val = min_val - 0.5, max_val + 0.5
bins = np.linspace(min_val, max_val, 30)
ax4.hist(preferred_scores, bins=bins, alpha=0.6, label='Preferred (Risk-Averse)', 
         color='green', density=True, edgecolor='black', linewidth=0.5)
ax4.hist(rejected_scores, bins=bins, alpha=0.6, label='Rejected (Risk-Neutral)', 
         color='red', density=True, edgecolor='black', linewidth=0.5)
ax4.set_xlabel('Reward Score')
ax4.set_ylabel('Density')
ax4.set_title('Score Distribution Comparison')
ax4.legend()
ax4.grid(True, alpha=0.3, axis='y')

# Plot 5: Scatter Plot - Preferred vs Rejected Scores
ax5 = fig.add_subplot(gs[1, 1])
# Centre the red-yellow-green colormap on a zero margin so red always means
# "rejected scored higher" and green "preferred scored higher". Without
# explicit limits matplotlib maps the midpoint of the observed margin range
# to yellow, which can colour wrong rankings green. Fall back to +/-1 when
# every margin is zero.
_margin_span = float(np.max(np.abs(reward_margins))) or 1.0
ax5.scatter(rejected_scores, preferred_scores, alpha=0.5, s=30, c=reward_margins, 
            cmap='RdYlGn', vmin=-_margin_span, vmax=_margin_span,
            edgecolors='black', linewidth=0.5)
ax5.plot([min_val, max_val], [min_val, max_val], 'k--', alpha=0.5, linewidth=2, label='Equal Scores')
# Above the diagonal: preferred (y) > rejected (x).
ax5.fill_between([min_val, max_val], [min_val, max_val], [max_val, max_val],
                alpha=0.15, color='green', label='Risk-Averse Preferred')
# Below the diagonal: rejected (x) > preferred (y).
ax5.fill_between([min_val, max_val], [min_val, min_val], [min_val, max_val],
                alpha=0.15, color='red', label='Risk-Neutral Preferred')
ax5.set_xlabel('Risk-Neutral Score')
ax5.set_ylabel('Risk-Averse Score')
ax5.set_title('Pairwise Score Comparison')
ax5.legend(fontsize=8)
ax5.grid(True, alpha=0.3)
ax5.axis('equal')

# Plot 6: Reward Margin Distribution
ax6 = fig.add_subplot(gs[1, 2])
ax6.hist(reward_margins, bins=40, alpha=0.7, color='purple', edgecolor='black', linewidth=0.5)
ax6.axvline(x=0, color='r', linestyle='--', linewidth=2, label='No Preference')
ax6.axvline(x=np.mean(reward_margins), color='g', linestyle='-', linewidth=2, 
            label=f'Mean: {np.mean(reward_margins):.3f}')
ax6.set_xlabel('Reward Margin (Preferred - Rejected)')
ax6.set_ylabel('Count')
ax6.set_title('Distribution of Reward Margins')
ax6.legend()
ax6.grid(True, alpha=0.3, axis='y')

# Plot 7: Ranking Performance Breakdown -- how many pairs the model ranked
# correctly (preferred > rejected), incorrectly, or as an exact tie.
ax7 = fig.add_subplot(gs[2, 0])
correct_rankings = np.sum(reward_margins > 0)
incorrect_rankings = np.sum(reward_margins < 0)
tied_rankings = np.sum(reward_margins == 0)
categories = ['Correct\n(Pref > Rej)', 'Incorrect\n(Pref < Rej)', 'Tied\n(Pref = Rej)']
counts = [correct_rankings, incorrect_rankings, tied_rankings]
colors = ['green', 'red', 'gray']
bars = ax7.bar(categories, counts, color=colors, alpha=0.7, edgecolor='black', linewidth=1.5)
total_pairs = len(reward_margins)
# Offset the labels in proportion to the tallest bar. A fixed "+1" in data
# units is invisible for large counts and oversized for tiny ones.
tallest = max(counts) if max(counts) > 0 else 1
label_pad = 0.02 * tallest
for bar, count in zip(bars, counts):
    pct = 100 * count / total_pairs if total_pairs else 0.0
    ax7.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + label_pad,
             f'{count}\n({pct:.1f}%)', ha='center', va='bottom', fontweight='bold')
# Text artists are ignored by autoscaling. Reserve headroom so the two-line
# label above the tallest bar stays inside the axes instead of overlapping
# the title.
ax7.set_ylim(0, tallest * 1.25)
ax7.set_ylabel('Number of Pairs')
ax7.set_title('Ranking Performance Breakdown')
ax7.grid(True, alpha=0.3, axis='y')

# Plot 8: Score Evolution Over Training
ax8 = fig.add_subplot(gs[2, 1])
if len(history['epochs']) > 0 and len(history['preferred_rewards']) > 0:
    ax8.plot(history['epochs'], history['preferred_rewards'], 'g-o', 
             linewidth=2, markersize=8, label='Preferred (Risk-Averse)')
    ax8.plot(history['epochs'], history['rejected_rewards'], 'r-s', 
             linewidth=2, markersize=8, label='Rejected (Risk-Neutral)')
    ax8.set_xlabel('Epoch')
    ax8.set_ylabel('Mean Reward Score')
    ax8.set_title('Score Evolution During Training')
    ax8.legend()
    ax8.grid(True, alpha=0.3)

# Plot 9: Cumulative Distribution Functions (empirical CDF of each score set)
ax9 = fig.add_subplot(gs[2, 2])
sorted_pref = np.sort(preferred_scores)
sorted_rej = np.sort(rejected_scores)
cdf_pref = np.arange(1, len(sorted_pref) + 1) / len(sorted_pref)
cdf_rej = np.arange(1, len(sorted_rej) + 1) / len(sorted_rej)
ax9.plot(sorted_pref, cdf_pref, 'g-', linewidth=2, label='Preferred (Risk-Averse)')
ax9.plot(sorted_rej, cdf_rej, 'r-', linewidth=2, label='Rejected (Risk-Neutral)')
ax9.set_xlabel('Reward Score')
ax9.set_ylabel('Cumulative Probability')
ax9.set_title('Cumulative Distribution Comparison')
ax9.legend()
ax9.grid(True, alpha=0.3)

plt.tight_layout()

# Save plot under a timestamped name. `timestamp` is reused below for the
# JSON outputs, so all artifacts of one run share the same suffix.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# savefig does not create missing directories. Make sure outputs/ exists so
# a long training run does not crash at the very last step.
os.makedirs("outputs", exist_ok=True)
plot_path = f"outputs/training_results_{timestamp}.png"
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
print(f"\nComprehensive plots saved to: {plot_path}")

plt.show()

def _json_default(obj):
    """Convert numpy/torch values that ``json`` cannot serialize natively.

    Passed as ``json.dump``'s ``default`` hook. Numpy scalars (e.g.
    ``np.float32``, ``np.int64``) become Python scalars; numpy arrays and
    torch tensors become (nested) lists. Any other type raises TypeError,
    exactly as ``json`` would without the hook.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (np.ndarray, torch.Tensor)):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Save results to JSON: a flat summary of final metrics and hyperparameters.
results = {
    'model_name': MODEL_NAME,
    'final_validation_accuracy': float(final_accuracy),
    'final_validation_loss': float(final_loss),
    'best_validation_accuracy': float(best_val_accuracy),
    'mean_preferred_score': float(np.mean(preferred_scores)),
    'mean_rejected_score': float(np.mean(rejected_scores)),
    'score_difference': float(np.mean(preferred_scores) - np.mean(rejected_scores)),
    'margin_mean': float(np.mean(reward_margins)),
    'margin_std': float(np.std(reward_margins)),
    'correct_rankings': int(np.sum(reward_margins > 0)),
    'incorrect_rankings': int(np.sum(reward_margins < 0)),
    'tied_rankings': int(np.sum(reward_margins == 0)),
    'num_epochs': NUM_EPOCHS,
    'epochs_trained': len(history['epochs']),
    'batch_size': BATCH_SIZE,
    'effective_batch_size': BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS,
    'learning_rate': LEARNING_RATE,
    'weight_decay': WEIGHT_DECAY,
    'reward_clip_value': REWARD_CLIP_VALUE,
    'training_samples': len(train_dataset),
    'validation_samples': len(val_dataset),
    'timestamp': timestamp
}

results_path = f"outputs/results_{timestamp}.json"
with open(results_path, 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2, default=_json_default)

print(f"Results saved to: {results_path}")

# Save training history with timestamp. `history` is built elsewhere; its
# values are not cast to float here. If any entry holds a numpy scalar or a
# tensor, plain json.dump would raise TypeError after the whole run, hence
# the default hook.
history_path = f"outputs/training_history_{timestamp}.json"
with open(history_path, 'w', encoding='utf-8') as f:
    json.dump(history, f, indent=2, default=_json_default)
print(f"Training history saved to: {history_path}")

print("\n" + "="*60)
print("Experiment Complete!")
print("="*60)

# Find checkpoints up front so both the Colab and non-Colab branches can use
# them without risking an undefined variable.
import glob
checkpoints = glob.glob("outputs/best_model_epoch*.pt")
# Pick the most recently written checkpoint, using getmtime (last content
# write) rather than getctime: on Unix ctime is the inode-change time, which
# metadata operations such as chmod or rename also bump.
latest_checkpoint = max(checkpoints, key=os.path.getmtime) if checkpoints else None

# Automatic downloads (for Google Colab)
print("\n" + "="*60)
print("Downloading Results...")
print("="*60)

# Only the import goes inside `try`, so that ImportError from being outside
# Colab is the sole condition that selects the fallback branch.
try:
    from google.colab import files
except ImportError:
    print("\nNOTE: Not running in Google Colab - files saved to outputs/ directory")
    print("Files available:")
    print(f"  - {plot_path}")
    print(f"  - {results_path}")
    print(f"  - {history_path}")
    if latest_checkpoint is not None:
        print(f"  - {latest_checkpoint}")
else:
    # Plots, results JSON, and training history.
    for _artifact_path in (plot_path, results_path, history_path):
        print(f"Downloading: {_artifact_path}")
        files.download(_artifact_path)

    # Best model checkpoint (if one was saved).
    if latest_checkpoint is not None:
        print(f"Downloading: {latest_checkpoint}")
        files.download(latest_checkpoint)

    print("\nAll files downloaded successfully!")