# Risk-Averse Reward Model Training

This notebook implements the risk-aversion experiment: it trains a reward model on top of **Qwen3-8B** (LoRA adapters plus a scalar reward head) to score risk-averse choices above risk-neutral ones.

## Features
- **Separate Train/Val Data**: Uses dedicated training and validation CSV files with different label types
- **CARA-based Training**: Trains on CARA (risk-averse) labels; the incorrect label for each row is chosen according to its `low_bucket_label`
- **Cooperation-based Validation**: Validates on cooperate labels to test generalization
- **Per-Epoch Re-randomization**: Training rows whose `low_bucket_label` is "both" have their incorrect-label source re-drawn each epoch

## Data Files
- **Training**: `data/2025_12_5_training_set_low_stakes_balanced.csv` (CARA labels)
- **Validation**: `data/2025_12_5_val_set_medium_stakes_balanced.csv` (cooperate labels)

**Memory Optimized for 8B Model:** Loads the backbone in fp16, trains only the LoRA adapters plus the reward head, and uses a small batch size to fit in memory

## 1. Install Dependencies

In [None]:
# Install required packages
!pip install -q transformers datasets accelerate torch pandas numpy scikit-learn matplotlib seaborn hf_transfer peft

print("✓ Dependencies installed successfully!")

## 2. Import Libraries

In [None]:
import os
import sys
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from peft import LoraConfig, get_peft_model, TaskType
from sklearn.model_selection import train_test_split
import json
from typing import List, Dict, Tuple
import warnings
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

warnings.filterwarnings('ignore')

# Set plotting style
plt.style.use('default')
sns.set_palette("husl")

# Reproducibility function (will be called after config is loaded)
def set_seed(seed):
    """Set Python, NumPy, and PyTorch seeds for reproducibility"""
    import random  # local import: `random` is not imported at the top of this cell
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

print("Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 3. Configuration

All configurable parameters for the experiment.
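As a rough sanity check on the LoRA settings, the sketch below computes the LoRA scaling factor (`alpha / r`) and the number of adapter parameters that LoRA adds to a linear layer. The layer count and projection shapes in the example are illustrative assumptions, not values read from the Qwen3-8B config; check `model.config` for the real ones.

```python
def lora_scaling(alpha: int, r: int) -> float:
    """Scaling applied to the low-rank update: alpha / r."""
    return alpha / r


def lora_params_per_linear(in_features: int, out_features: int, r: int) -> int:
    """LoRA adds A (r x in_features) and B (out_features x r) to one linear layer."""
    return r * in_features + out_features * r


# Rank and alpha as set in the configuration cell
print(lora_scaling(alpha=16, r=8))  # 2.0

# Hypothetical shapes for illustration only (NOT read from the model config)
ASSUMED_LAYERS = 36
ASSUMED_SHAPES = {"q_proj": (4096, 4096), "v_proj": (4096, 1024)}

per_layer = sum(
    lora_params_per_linear(i, o, r=8) for i, o in ASSUMED_SHAPES.values()
)
print(f"Adapter parameters under these assumptions: {per_layer * ASSUMED_LAYERS:,}")
```

The figure reported by `print_trainable_parameters()` in the model cell is the authoritative count; this estimate only shows how rank and target modules drive it.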

In [None]:
# =============================================================================
# CONFIGURATION - All configurable parameters for the experiment
# =============================================================================

# Model settings
MODEL_NAME = "Qwen/Qwen3-8B"          # Base model to use
MAX_LENGTH = 256                       # Maximum sequence length for tokenization

# Training hyperparameters
BATCH_SIZE = 2                         # Batch size per forward pass
LEARNING_RATE = 2e-4                   # Learning rate for LoRA layers
WEIGHT_DECAY = 0.01                    # Weight decay (decoupled, as applied by AdamW)
NUM_EPOCHS = 10                        # Number of training epochs

# LoRA configuration
LORA_R = 8                             # LoRA rank (low for small dataset)
LORA_ALPHA = 16                        # LoRA alpha (scaling = alpha/r = 2.0)
LORA_DROPOUT = 0.05                    # LoRA dropout for regularization
LORA_TARGET_MODULES = ["q_proj", "v_proj"]  # Query and Value attention projections

# Data settings - separate training and validation files
TRAIN_DATA_FILE = "data/2025_12_5_training_set_low_stakes_balanced.csv"
VAL_DATA_FILE = "data/2025_12_5_val_set_medium_stakes_balanced.csv"
RANDOM_SEED = 42                       # Random seed for reproducibility

# =============================================================================
# Set seed for reproducibility
set_seed(RANDOM_SEED)

print("Configuration loaded:")
print(f"  Model: {MODEL_NAME}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Random seed: {RANDOM_SEED}")
print(f"\nLoRA Configuration:")
print(f"  Rank (r): {LORA_R}")
print(f"  Alpha: {LORA_ALPHA}")
print(f"  Dropout: {LORA_DROPOUT}")
print(f"  Target modules: {LORA_TARGET_MODULES}")
print(f"\nData Files:")
print(f"  Training: {TRAIN_DATA_FILE}")
print(f"  Validation: {VAL_DATA_FILE}")

## 4. Data Loading Classes

Two specialized loaders for training and validation data:
- **TrainingDataLoader**: Loads CARA-based labels with `low_bucket_label` logic for incorrect label selection
- **ValidationDataLoader**: Loads cooperate-based labels for generalization testing
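The `low_bucket_label` rule used by the training loader can be sketched in isolation. This is a minimal illustration, not the loader itself: it uses the standard-library `random.Random` as a stand-in for the notebook's `np.random.default_rng`, and the row contents are made-up placeholder values.

```python
import json
import random


def pick_incorrect_labels(row: dict, rng: random.Random) -> list:
    """Return the pool of incorrect labels for one row, following low_bucket_label."""
    bucket = str(row["low_bucket_label"]).strip('"')
    if bucket == "010_only":
        return json.loads(row["CARA_alpha_0_10_best_labels"])
    if bucket == "lin_only":
        return json.loads(row["linear_best_labels"])
    if bucket == "both":
        # Coin flip between the two pools; the outcome depends on the RNG state
        if rng.random() < 0.5:
            return json.loads(row["linear_best_labels"])
        return json.loads(row["CARA_alpha_0_10_best_labels"])
    # Fallback for any other bucket value
    return json.loads(row.get("CARA_incorrect_labels", "[]"))


# Hypothetical row; the label values are placeholders for illustration
row = {
    "low_bucket_label": "both",
    "linear_best_labels": '["a"]',
    "CARA_alpha_0_10_best_labels": '["b"]',
}

base_seed = 42
for epoch in range(3):
    rng = random.Random(base_seed + epoch)  # same seed + epoch -> same choice
    print(epoch, pick_incorrect_labels(row, rng))
```

Seeding with `base_seed + epoch` keeps each epoch's draw reproducible across runs while still allowing the draw to differ between epochs.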

In [None]:
class TrainingDataLoader:
    """Load and process training data with CARA-based labels and low_bucket_label logic.
    
    This loader handles the special logic for selecting incorrect labels based on
    the low_bucket_label field:
    - "010_only": use CARA_alpha_0_10_best_labels as incorrect (avoid over-risk-aversion)
    - "lin_only": use linear_best_labels as incorrect (avoid being linear/risk-neutral)
    - "both": randomly choose between the two (re-randomizes each epoch)
    """
    
    def __init__(self, csv_file_path: str, epoch: int = 0, random_seed: int = 42):
        """
        Args:
            csv_file_path: Path to training CSV file
            epoch: Current epoch number (used for reproducible per-epoch randomization)
            random_seed: Base random seed for reproducibility
        """
        self.csv_file_path = csv_file_path
        self.epoch = epoch
        self.rng = np.random.default_rng(random_seed + epoch)  # Per-epoch randomization
        
    def load_and_process_data(self) -> pd.DataFrame:
        """Load CSV data and process it for training.
        
        Returns:
            DataFrame with columns: situation_id, prompt_text, correct_label, incorrect_label, low_bucket_label
        """
        # Check if CSV file exists
        if not os.path.exists(self.csv_file_path):
            raise FileNotFoundError(
                f"Required training data file '{self.csv_file_path}' not found. "
                f"Please ensure the CSV file exists."
            )
        
        # Load the CSV file
        df = pd.read_csv(self.csv_file_path)
        print(f"Loaded {len(df)} rows from {self.csv_file_path}")
        
        # Check required columns exist
        required_columns = ['situation_id', 'prompt_text', 'CARA_correct_labels', 'low_bucket_label']
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            raise ValueError(
                f"Missing required columns in training CSV: {missing_columns}. "
                f"Available columns: {list(df.columns)}"
            )
        
        # Group by situation_id, take first row of each group (all rows have same labels)
        situations = df.groupby('situation_id').first().reset_index()
        print(f"Found {len(situations)} unique situations")
        
        processed = []
        skipped = 0
        
        for _, row in situations.iterrows():
            try:
                prompt_text = row['prompt_text']
                
                # Parse JSON array for correct labels
                correct_labels = json.loads(row['CARA_correct_labels'])
                if not correct_labels:
                    skipped += 1
                    continue
                
                # Get low_bucket_label and determine incorrect labels
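                # str() guards against NaN/non-string cells, which would otherwise raise AttributeError
                low_bucket = str(row['low_bucket_label']).strip().strip('"')  # Remove surrounding quotes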
                
                if low_bucket == '010_only':
                    # Use CARA_alpha_0_10_best_labels as incorrect
                    incorrect_labels = json.loads(row['CARA_alpha_0_10_best_labels'])
                elif low_bucket == 'lin_only':
                    # Use linear_best_labels as incorrect
                    incorrect_labels = json.loads(row['linear_best_labels'])
                elif low_bucket == 'both':
                    # Randomly choose between linear and alpha_0_10 (re-randomizes each epoch)
                    if self.rng.random() < 0.5:
                        incorrect_labels = json.loads(row['linear_best_labels'])
                    else:
                        incorrect_labels = json.loads(row['CARA_alpha_0_10_best_labels'])
                else:
                    # Fallback: use CARA_incorrect_labels if available
                    incorrect_labels = json.loads(row.get('CARA_incorrect_labels', '[]'))
                
                if not incorrect_labels:
                    skipped += 1
                    continue
                
                # Randomly select one label from each array
                correct_label = str(self.rng.choice(correct_labels))
                incorrect_label = str(self.rng.choice(incorrect_labels))
                
                processed.append({
                    'situation_id': row['situation_id'],
                    'prompt_text': prompt_text,
                    'correct_label': correct_label,
                    'incorrect_label': incorrect_label,
                    'low_bucket_label': low_bucket,
                })
                
            except (json.JSONDecodeError, KeyError, TypeError, AttributeError) as e:  # TypeError/AttributeError: NaN or non-string cells
                print(f"Warning: Error processing situation {row['situation_id']}: {e}")
                skipped += 1
                continue
        
        if skipped > 0:
            print(f"Warning: Skipped {skipped} situations due to missing/empty labels")
        
        result_df = pd.DataFrame(processed)
        print(f"Processed into {len(result_df)} training examples (epoch {self.epoch})")
        
        # Display low_bucket_label distribution
        if 'low_bucket_label' in result_df.columns and len(result_df) > 0:
            print(f"\nlow_bucket_label distribution:")
            for label, count in result_df['low_bucket_label'].value_counts().items():
                print(f"  {label}: {count} ({100*count/len(result_df):.1f}%)")
        
        return result_df


class ValidationDataLoader:
    """Load and process validation data with cooperate-based labels.
    
    For validation, we use cooperate_correct_labels and cooperate_incorrect_labels
    because cooperation is the goal we're validating for.
    """
    
    def __init__(self, csv_file_path: str, random_seed: int = 42):
        """
        Args:
            csv_file_path: Path to validation CSV file
            random_seed: Random seed for reproducibility
        """
        self.csv_file_path = csv_file_path
        self.rng = np.random.default_rng(random_seed)
    
    def load_and_process_data(self) -> pd.DataFrame:
        """Load CSV data and process it for validation.
        
        Returns:
            DataFrame with columns: situation_id, prompt_text, correct_label, incorrect_label
        """
        # Check if CSV file exists
        if not os.path.exists(self.csv_file_path):
            raise FileNotFoundError(
                f"Required validation data file '{self.csv_file_path}' not found. "
                f"Please ensure the CSV file exists."
            )
        
        # Load the CSV file
        df = pd.read_csv(self.csv_file_path)
        print(f"Loaded {len(df)} rows from {self.csv_file_path}")
        
        # Check required columns exist
        required_columns = ['situation_id', 'prompt_text', 'cooperate_correct_labels', 'cooperate_incorrect_labels']
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            raise ValueError(
                f"Missing required columns in validation CSV: {missing_columns}. "
                f"Available columns: {list(df.columns)}"
            )
        
        # Group by situation_id, take first row of each group
        situations = df.groupby('situation_id').first().reset_index()
        print(f"Found {len(situations)} unique situations")
        
        processed = []
        skipped = 0
        
        for _, row in situations.iterrows():
            try:
                # Parse JSON arrays
                correct_labels = json.loads(row['cooperate_correct_labels'])
                incorrect_labels = json.loads(row['cooperate_incorrect_labels'])
                
                # Skip situations with empty labels
                if not correct_labels or not incorrect_labels:
                    skipped += 1
                    continue
                
                # Randomly select one label from each array
                correct_label = str(self.rng.choice(correct_labels))
                incorrect_label = str(self.rng.choice(incorrect_labels))
                
                processed.append({
                    'situation_id': row['situation_id'],
                    'prompt_text': row['prompt_text'],
                    'correct_label': correct_label,
                    'incorrect_label': incorrect_label,
                })
                
            except (json.JSONDecodeError, KeyError, TypeError) as e:  # TypeError: json.loads on a NaN cell
                print(f"Warning: Error processing situation {row['situation_id']}: {e}")
                skipped += 1
                continue
        
        if skipped > 0:
            print(f"Warning: Skipped {skipped} situations due to missing/empty labels")
        
        result_df = pd.DataFrame(processed)
        print(f"Processed into {len(result_df)} validation examples")
        
        return result_df


print("TrainingDataLoader and ValidationDataLoader defined")

## 5. Load and Validate Data

Load the separate training and validation data files.

In [None]:
# Load training data (epoch 0 for initial load)
train_loader = TrainingDataLoader(TRAIN_DATA_FILE, epoch=0, random_seed=RANDOM_SEED)
train_df = train_loader.load_and_process_data()

# Load validation data (fixed, no per-epoch changes)
val_loader = ValidationDataLoader(VAL_DATA_FILE, random_seed=RANDOM_SEED)
val_df = val_loader.load_and_process_data()

print(f"\n{'='*60}")
print(f"Dataset Summary:")
print(f"  Training file: {TRAIN_DATA_FILE}")
print(f"  Training situations: {len(train_df)}")
print(f"  Validation file: {VAL_DATA_FILE}")
print(f"  Validation situations: {len(val_df)}")
print(f"{'='*60}")

# Validate data format
assert 'prompt_text' in train_df.columns, "Missing prompt_text column in training data"
assert 'correct_label' in train_df.columns, "Missing correct_label column in training data"
assert 'incorrect_label' in train_df.columns, "Missing incorrect_label column in training data"

assert 'prompt_text' in val_df.columns, "Missing prompt_text column in validation data"
assert 'correct_label' in val_df.columns, "Missing correct_label column in validation data"
assert 'incorrect_label' in val_df.columns, "Missing incorrect_label column in validation data"

print("\nData validation passed!")
print("Data loading complete!")

## 6. Pairwise Dataset for Reward Modeling

Dataset class that provides pairs of (preferred, rejected) options for Bradley-Terry loss.
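The text formatting behind each pair can be shown without a tokenizer. The sketch below mirrors the `"\n\nChosen option: "` template used by the dataset class; the prompt and labels are hypothetical.

```python
def build_pair(prompt_text: str, correct_label: str, incorrect_label: str) -> tuple:
    """Build the (preferred, rejected) texts; they differ only in the final label."""
    template = "{prompt}\n\nChosen option: {label}"
    preferred = template.format(prompt=prompt_text, label=correct_label)
    rejected = template.format(prompt=prompt_text, label=incorrect_label)
    return preferred, rejected


# Hypothetical prompt and labels, for illustration only
preferred, rejected = build_pair("Pick option a or b.", "a", "b")
print(repr(preferred))
print(repr(rejected))
```

Because the label sits at the very end of the text, a prompt longer than `max_length` tokens would lose it if the tokenizer truncates from the right (the usual default), leaving the two texts identical. It may be worth checking prompt lengths against `MAX_LENGTH` before training.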

In [None]:
class PairwiseRewardDataset(Dataset):
    """Dataset for pairwise reward model training with Bradley-Terry loss"""
    
    def __init__(self, dataframe: pd.DataFrame, tokenizer, max_length: int = 256):
        """
        Args:
            dataframe: DataFrame with columns: prompt_text, correct_label, incorrect_label
            tokenizer: Tokenizer for encoding text
            max_length: Maximum sequence length
        """
        self.data = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        
        # Format: prompt + chosen option
        preferred_text = f"{row['prompt_text']}\n\nChosen option: {row['correct_label']}"
        rejected_text = f"{row['prompt_text']}\n\nChosen option: {row['incorrect_label']}"
        
        # Tokenize both options
        preferred_encoding = self.tokenizer(
            preferred_text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        rejected_encoding = self.tokenizer(
            rejected_text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        return {
            'preferred_input_ids': preferred_encoding['input_ids'].squeeze(0),
            'preferred_attention_mask': preferred_encoding['attention_mask'].squeeze(0),
            'rejected_input_ids': rejected_encoding['input_ids'].squeeze(0),
            'rejected_attention_mask': rejected_encoding['attention_mask'].squeeze(0),
        }

print("✓ PairwiseRewardDataset defined")

## 7. Reward Model Architecture

Qwen3-8B base model with LoRA adapters (q_proj, v_proj) + trainable scalar reward head.
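The reward is read from the hidden state of the last non-padding token, so the position lookup depends on the padding side. The plain-Python sketch below (toy masks, no tensors) compares the sum-based shortcut with a lookup that works for either padding side.

```python
def last_token_index(attention_mask: list) -> int:
    """Index of the last position where the mask is 1, for either padding side."""
    for idx in range(len(attention_mask) - 1, -1, -1):
        if attention_mask[idx] == 1:
            return idx
    raise ValueError("attention mask contains no real tokens")


right_padded = [1, 1, 1, 0, 0]
left_padded = [0, 0, 1, 1, 1]

# With right padding, the sum-based shortcut and the scan agree
print(sum(right_padded) - 1, last_token_index(right_padded))  # 2 2
# With left padding, the shortcut no longer points at the last real token
print(sum(left_padded) - 1, last_token_index(left_padded))    # 2 4
```

The shortcut `attention_mask.sum(dim=1) - 1` is therefore only safe when the tokenizer pads on the right.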

In [None]:
class RewardModel(nn.Module):
    """Reward model with LoRA-adapted backbone and trainable scalar reward head"""
    
    def __init__(
        self, 
        model_name: str = "Qwen/Qwen3-8B",
        lora_r: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.05,
        lora_target_modules: list = None,
    ):
        super().__init__()
        
        # Load base model in fp16
        print(f"Loading base model: {model_name}")
        self.backbone = AutoModel.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        
        hidden_size = self.backbone.config.hidden_size
        
        # Add reward head BEFORE applying LoRA
        self.reward_head = nn.Linear(hidden_size, 1, bias=True)
        
        # Initialize reward head with small weights
        nn.init.normal_(self.reward_head.weight, mean=0.0, std=0.01)
        nn.init.zeros_(self.reward_head.bias)
        
        # Configure and apply LoRA
        if lora_target_modules is None:
            lora_target_modules = ["q_proj", "v_proj"]
        
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            target_modules=lora_target_modules,
            bias="none",
            task_type=TaskType.FEATURE_EXTRACTION,
        )
        
        self.backbone = get_peft_model(self.backbone, lora_config)
        
        # Print trainable parameter info
        print(f"\nLoRA Configuration:")
        print(f"  Rank: {lora_r}")
        print(f"  Alpha: {lora_alpha}")
        print(f"  Dropout: {lora_dropout}")
        print(f"  Target modules: {lora_target_modules}")
        
        self.backbone.print_trainable_parameters()
        
        print(f"\nReward head: Linear({hidden_size} -> 1) with bias")
        
        # Count total trainable parameters
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in self.parameters())
        print(f"Total trainable: {trainable_params:,} / {total_params:,} ({100*trainable_params/total_params:.4f}%)")
        
    def forward(self, input_ids, attention_mask):
        """
        Forward pass to compute scalar reward from final hidden state
        
        Args:
            input_ids: Input token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]
            
        Returns:
            rewards: Scalar reward scores [batch_size]
        """
        # Get hidden states - gradients flow through LoRA layers
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        
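        # Extract final hidden state h_T (last non-padding token per sequence).
        # NOTE: this index is only correct with right padding; with left padding
        # the last real token is always at position seq_len - 1.
        sequence_lengths = attention_mask.sum(dim=1) - 1  # 0-indexed position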
        batch_size = input_ids.shape[0]
        
        # Gather the hidden state at the last token position for each sequence
        hidden_states = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]
        last_hidden_states = hidden_states[
            torch.arange(batch_size, device=hidden_states.device),
            sequence_lengths
        ]
        
        # Convert to fp32 for numerical stability in reward head
        last_hidden_states = last_hidden_states.float()
        
        # Compute scalar reward: r = W^T * h_T + b
        rewards = self.reward_head(last_hidden_states).squeeze(-1)
        
        return rewards

print("RewardModel class defined (with LoRA support)")

## 8. Training Setup

Initialize model, tokenizer, datasets, loss function, and optimizer.
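To make the loss concrete, the Bradley-Terry objective for a single pair is `-log(sigmoid(r_preferred - r_rejected))`. The scalar sketch below evaluates it for a few margins using a numerically stable softplus form; `bt_loss_scalar` is a hypothetical helper name used only here.

```python
import math


def bt_loss_scalar(margin: float) -> float:
    """-log(sigmoid(margin)), computed as softplus(-margin) in a stable form."""
    # softplus(x) = max(x, 0) + log1p(exp(-|x|)) avoids overflow for large |x|
    x = -margin
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


for margin in (-2.0, 0.0, 2.0):
    print(f"margin={margin:+.1f}  loss={bt_loss_scalar(margin):.4f}")
```

A margin of zero gives `log(2) ≈ 0.6931`, which is the loss to expect from an untrained reward head whose scores barely separate the two options.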

In [None]:
# Load tokenizer
print("Loading tokenizer...")
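tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.padding_side = "right"  # RewardModel.forward locates the last token assuming right padding
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token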

# Create datasets
print("Creating datasets...")
train_dataset = PairwiseRewardDataset(train_df, tokenizer, max_length=MAX_LENGTH)
val_dataset = PairwiseRewardDataset(val_df, tokenizer, max_length=MAX_LENGTH)

print(f"  Training examples: {len(train_dataset)}")
print(f"  Validation examples: {len(val_dataset)}")


# Helper function for per-epoch re-randomization of training data
def recreate_training_dataset(epoch: int):
    """Recreate training dataset with new randomization for 'both' bucket cases.
    
    This function creates a fresh training dataset where situations with
    low_bucket_label='both' get a new random choice between linear_best_labels
    and CARA_alpha_0_10_best_labels.
    
    Args:
        epoch: Current epoch number (used for reproducible randomization)
        
    Returns:
        PairwiseRewardDataset with re-randomized label selections
    """
    loader = TrainingDataLoader(TRAIN_DATA_FILE, epoch=epoch, random_seed=RANDOM_SEED)
    new_train_df = loader.load_and_process_data()
    return PairwiseRewardDataset(new_train_df, tokenizer, max_length=MAX_LENGTH)


# Initialize model with LoRA
print("\nInitializing reward model with LoRA...")
model = RewardModel(
    model_name=MODEL_NAME,
    lora_r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    lora_target_modules=LORA_TARGET_MODULES,
)

# Get device from backbone
device = next(model.backbone.parameters()).device

# Move reward head to same device and ensure fp32
model.reward_head = model.reward_head.to(device).float()

print(f"\n  Device: {device}")
print(f"  Backbone dtype: {next(model.backbone.parameters()).dtype}")
print(f"  Reward head dtype: {next(model.reward_head.parameters()).dtype}")

# Bradley-Terry loss function
def bradley_terry_loss(preferred_rewards, rejected_rewards):
    """
    Bradley-Terry pairwise ranking loss
    Loss = -log(sigmoid(r_preferred - r_rejected))

    Encourages: r_preferred > r_rejected
    """
    # logsigmoid is the numerically stable form of log(sigmoid(x)); it avoids log(0) for large negative margins
    return -F.logsigmoid(preferred_rewards - rejected_rewards).mean()

# Optimizer - train LoRA parameters AND reward head with different learning rates
# (only parameters with requires_grad=True are passed; frozen base weights are excluded)
optimizer = torch.optim.AdamW([
    {'params': [p for p in model.backbone.parameters() if p.requires_grad], 'lr': LEARNING_RATE},
    {'params': model.reward_head.parameters(), 'lr': LEARNING_RATE * 2.5},  # Higher LR for reward head
], weight_decay=WEIGHT_DECAY)

print(f"\n{'='*60}")
print("Training Configuration:")
print(f"  Model: {MODEL_NAME}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  LoRA parameters LR: {LEARNING_RATE}")
print(f"  Reward head LR: {LEARNING_RATE * 2.5}")
print(f"  Weight decay (AdamW): {WEIGHT_DECAY}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Max sequence length: {MAX_LENGTH}")
print(f"  Per-epoch re-randomization: Enabled for 'both' bucket cases")
print(f"{'='*60}")
print("\nTraining setup complete!")

## 9. Evaluation Function

Compute pairwise accuracy: the fraction of pairs in which the preferred option scores higher than the rejected one.
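The metric itself is simple enough to check by hand. The sketch below computes it on plain lists of scores; the numbers are hypothetical, and ties are counted as incorrect, matching the strict `>` comparison used in the evaluation cell.

```python
def pairwise_accuracy(preferred_scores, rejected_scores) -> float:
    """Fraction of pairs with preferred > rejected (ties count as incorrect)."""
    if len(preferred_scores) != len(rejected_scores):
        raise ValueError("score lists must have the same length")
    if not preferred_scores:
        return 0.0
    wins = sum(p > r for p, r in zip(preferred_scores, rejected_scores))
    return wins / len(preferred_scores)


# Hypothetical scores for illustration: two wins, one loss, one tie
print(pairwise_accuracy([0.9, 0.1, 0.5, 0.3], [0.2, 0.4, 0.5, 0.1]))  # 0.5
```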

In [None]:
def evaluate_model(model, dataset, batch_size=None):
    """
    Evaluate model on pairwise accuracy
    
    Args:
        model: RewardModel to evaluate
        dataset: PairwiseRewardDataset
        batch_size: Batch size for evaluation (defaults to BATCH_SIZE * 2)
        
    Returns:
        accuracy: Float in [0, 1], fraction of pairs where preferred scores higher
        avg_loss: Float, average Bradley-Terry loss
        preferred_scores: List of reward scores for preferred options
        rejected_scores: List of reward scores for rejected options
    """
    if batch_size is None:
        batch_size = BATCH_SIZE * 2  # Can use larger batch for eval (no gradients)
    
    model.eval()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    
    correct = 0
    total = 0
    total_loss = 0.0
    preferred_scores_list = []
    rejected_scores_list = []
    
    with torch.no_grad():
        for batch in dataloader:
            # Get rewards for preferred options
            preferred_rewards = model(
                input_ids=batch['preferred_input_ids'].to(device),
                attention_mask=batch['preferred_attention_mask'].to(device)
            )
            
            # Get rewards for rejected options
            rejected_rewards = model(
                input_ids=batch['rejected_input_ids'].to(device),
                attention_mask=batch['rejected_attention_mask'].to(device)
            )
            
            # Compute loss
            loss = bradley_terry_loss(preferred_rewards, rejected_rewards)
            total_loss += loss.item() * len(preferred_rewards)
            
            # Compute accuracy: count pairs where preferred > rejected
            correct += (preferred_rewards > rejected_rewards).sum().item()
            total += len(preferred_rewards)
            
            # Store scores for analysis
            preferred_scores_list.extend(preferred_rewards.cpu().float().numpy())
            rejected_scores_list.extend(rejected_rewards.cpu().float().numpy())
    
    accuracy = correct / total if total > 0 else 0.0
    avg_loss = total_loss / total if total > 0 else 0.0
    
    return accuracy, avg_loss, preferred_scores_list, rejected_scores_list

print("Evaluation function defined")

## 10. Baseline Evaluation (Before Training)

Evaluate the model before training (pretrained backbone, randomly initialized reward head) to establish a baseline for comparison.

In [None]:
# =============================================================================
# BASELINE EVALUATION - Before any training
# =============================================================================
# This establishes how well the randomly initialized reward head performs
# Expected: ~50% accuracy (random guessing)

print("="*60)
print("BASELINE EVALUATION (Before Training)")
print("="*60)
print("\nEvaluating untrained model (random reward head) on validation set...")

baseline_accuracy, baseline_loss, baseline_pref_scores, baseline_rej_scores = evaluate_model(
    model, val_dataset
)

# Calculate baseline statistics
baseline_margins = np.array(baseline_pref_scores) - np.array(baseline_rej_scores)
baseline_correct = np.sum(baseline_margins > 0)
baseline_incorrect = np.sum(baseline_margins < 0)

print(f"\nBaseline Results (Untrained Model):")
print(f"  Accuracy: {baseline_accuracy:.4f} ({baseline_accuracy*100:.2f}%)")
print(f"  Loss: {baseline_loss:.4f}")
print(f"  Mean preferred score: {np.mean(baseline_pref_scores):.4f}")
print(f"  Mean rejected score: {np.mean(baseline_rej_scores):.4f}")
print(f"  Mean margin: {np.mean(baseline_margins):.4f}")
print(f"  Correct rankings: {baseline_correct} ({100*baseline_correct/len(baseline_margins):.1f}%)")
print(f"  Incorrect rankings: {baseline_incorrect} ({100*baseline_incorrect/len(baseline_margins):.1f}%)")

# Store baseline for later comparison
baseline_results = {
    'accuracy': baseline_accuracy,
    'loss': baseline_loss,
    'preferred_scores': baseline_pref_scores,
    'rejected_scores': baseline_rej_scores,
    'margins': baseline_margins.tolist(),
    'mean_margin': float(np.mean(baseline_margins)),
    'std_margin': float(np.std(baseline_margins)),
}

print(f"\nExpected baseline: ~50% (random initialization)")
print(f"Actual baseline:   {baseline_accuracy*100:.1f}%")
if abs(baseline_accuracy - 0.5) < 0.1:
    print("Baseline is near random as expected - model has not learned any preference yet.")
else:
    print("NOTE: Baseline deviates from 50% - this may indicate bias in initialization or data.")

print("="*60)

## 11. Training Loop

Train the model with logging, validation, and checkpointing.

In [None]:
# Create output directory for checkpoints
os.makedirs("outputs", exist_ok=True)

# Training history
history = {
    'train_loss': [],
    'train_steps': [],
    'val_accuracy': [],
    'val_loss': [],
    'epochs': [],
    'reward_margins': [],
    'preferred_rewards': [],
    'rejected_rewards': [],
}

print("Starting training...")
print(f"Training data will be re-randomized each epoch for 'both' bucket cases\n")

best_val_accuracy = 0.0
global_step = 0

# Store initial weights for comparison
initial_weight = model.reward_head.weight.clone().detach()
initial_bias = model.reward_head.bias.clone().detach()

for epoch in range(NUM_EPOCHS):
    print(f"{'='*60}")
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")
    print(f"{'='*60}")
    
    # Recreate training dataset with epoch-specific randomization
    # This ensures 'both' bucket cases get new random label choices each epoch
    if epoch > 0:
        print(f"  Re-randomizing training data for epoch {epoch + 1}...")
        train_dataset = recreate_training_dataset(epoch)
    
    # Create dataloader for this epoch
    train_dataloader = DataLoader(
        train_dataset, 
        batch_size=BATCH_SIZE, 
        shuffle=True,
        pin_memory=torch.cuda.is_available(),
    )
    
    print(f"  Batches this epoch: {len(train_dataloader)}")
    
    model.train()
    epoch_loss = 0.0
    epoch_preferred_rewards = []
    epoch_rejected_rewards = []
    epoch_margins = []
    
    for step, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        
        # Forward pass for preferred options
        preferred_rewards = model(
            input_ids=batch['preferred_input_ids'].to(device),
            attention_mask=batch['preferred_attention_mask'].to(device)
        )
        
        # Forward pass for rejected options
        rejected_rewards = model(
            input_ids=batch['rejected_input_ids'].to(device),
            attention_mask=batch['rejected_attention_mask'].to(device)
        )
        
        # Track reward statistics
        epoch_preferred_rewards.extend(preferred_rewards.detach().cpu().tolist())
        epoch_rejected_rewards.extend(rejected_rewards.detach().cpu().tolist())
        reward_margin = (preferred_rewards - rejected_rewards).detach()
        epoch_margins.extend(reward_margin.cpu().tolist())
        
        # Compute Bradley-Terry loss
        loss = bradley_terry_loss(preferred_rewards, rejected_rewards)
        
        # Check for NaN loss
        if torch.isnan(loss):
            print(f"  WARNING: NaN loss detected at step {step + 1}")
            print(f"    Preferred rewards: {preferred_rewards}")
            print(f"    Rejected rewards: {rejected_rewards}")
            continue
        
        # Backward pass
        loss.backward()
        epoch_loss += loss.item()
        
        # Detailed diagnostics on first few steps
        if global_step < 3:
            print(f"\n  === Diagnostics for Step {global_step + 1} ===")
            
            # Reward statistics
            print(f"  Rewards:")
            print(f"    Preferred: mean={preferred_rewards.mean().item():.4f}, std={preferred_rewards.std().item():.4f}")
            print(f"    Rejected:  mean={rejected_rewards.mean().item():.4f}, std={rejected_rewards.std().item():.4f}")
            print(f"    Margin:    mean={reward_margin.mean().item():.4f}, std={reward_margin.std().item():.4f}")
            print(f"    Loss:      {loss.item():.4f}")
            
            # Gradient statistics - reward head
            print(f"  Gradients (Reward Head):")
            for name, param in model.reward_head.named_parameters():
                if param.grad is not None:
                    grad_norm = param.grad.norm().item()
                    grad_mean = param.grad.mean().item()
                    grad_max = param.grad.abs().max().item()
                    print(f"    {name}: norm={grad_norm:.6f}, mean={grad_mean:.6f}, max={grad_max:.6f}")
                else:
                    print(f"    {name}: NO GRADIENT!")
            
            # Gradient statistics - LoRA layers
            lora_grad_norms = []
            for name, param in model.backbone.named_parameters():
                if param.requires_grad and param.grad is not None:
                    lora_grad_norms.append(param.grad.norm().item())
            if lora_grad_norms:
                print(f"  Gradients (LoRA layers):")
                print(f"    mean_norm={np.mean(lora_grad_norms):.6f}, max_norm={max(lora_grad_norms):.6f}")
            
            # Parameter statistics (measured before optimizer.step(), so they
            # reflect updates from earlier steps only)
            print(f"  Parameters:")
            weight_change = (model.reward_head.weight - initial_weight).abs().max().item()
            bias_change = (model.reward_head.bias - initial_bias).abs().max().item()
            print(f"    Reward head weight max change: {weight_change:.6f}")
            print(f"    Reward head bias max change: {bias_change:.6f}")
            print()
        
        # Optimizer step
        optimizer.step()
        global_step += 1
        
        # Log every 50 steps (adjusted for smaller dataset)
        if (step + 1) % 50 == 0:
            avg_loss = epoch_loss / (step + 1)
            weight_change = (model.reward_head.weight - initial_weight).abs().max().item()
            # epoch_margins holds one entry per pair, so this averages the last
            # 50 pairs (not the last 50 steps); slicing handles shorter lists
            recent_margin = np.mean(epoch_margins[-50:])
            
            print(f"  Step {step + 1}/{len(train_dataloader)} | "
                  f"Loss: {avg_loss:.4f} | "
                  f"Margin: {recent_margin:+.4f} | "
                  f"Weight: {weight_change:.6f}")
            
            history['train_loss'].append(avg_loss)
            history['train_steps'].append(global_step)
    
    # End of epoch statistics
    avg_train_loss = epoch_loss / len(train_dataloader)
    
    # Weight changes
    total_weight_change = (model.reward_head.weight - initial_weight).abs().mean().item()
    total_bias_change = (model.reward_head.bias - initial_bias).abs().mean().item()
    
    # Reward statistics
    mean_preferred = np.mean(epoch_preferred_rewards)
    mean_rejected = np.mean(epoch_rejected_rewards)
    mean_margin = np.mean(epoch_margins)
    std_margin = np.std(epoch_margins)
    percent_correct = np.mean([m > 0 for m in epoch_margins]) * 100
    
    print(f"\n  Epoch {epoch + 1} Summary:")
    print(f"    Loss: {avg_train_loss:.4f}")
    print(f"    Rewards: preferred={mean_preferred:+.4f}, rejected={mean_rejected:+.4f}")
    print(f"    Margin: mean={mean_margin:+.4f}, std={std_margin:.4f}")
    print(f"    Correct ranking: {percent_correct:.1f}% (preferred > rejected)")
    print(f"    Weight changes: weight={total_weight_change:.6f}, bias={total_bias_change:.6f}")
    
    # Store for history
    history['reward_margins'].append(mean_margin)
    history['preferred_rewards'].append(mean_preferred)
    history['rejected_rewards'].append(mean_rejected)
    
    # Validation
    print(f"\n  Running validation...")
    val_accuracy, val_loss, val_pref_scores, val_rej_scores = evaluate_model(model, val_dataset)
    
    val_margin = np.mean(val_pref_scores) - np.mean(val_rej_scores)
    
    print(f"    Validation accuracy: {val_accuracy:.4f} ({val_accuracy*100:.2f}%)")
    print(f"    Validation loss: {val_loss:.4f}")
    print(f"    Validation margin: {val_margin:+.4f}")
    
    # Save to history
    history['val_accuracy'].append(val_accuracy)
    history['val_loss'].append(val_loss)
    history['epochs'].append(epoch + 1)
    
    # Save checkpoint if best model
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        checkpoint_dir = f"outputs/best_model_epoch{epoch+1}"
        os.makedirs(checkpoint_dir, exist_ok=True)
        
        # Save LoRA adapter weights
        model.backbone.save_pretrained(checkpoint_dir)
        
        # Save reward head separately
        torch.save({
            'reward_head_state_dict': model.reward_head.state_dict(),
            'epoch': epoch + 1,
            'val_accuracy': val_accuracy,
            'val_loss': val_loss,
            'optimizer_state_dict': optimizer.state_dict(),
            'lora_config': {
                'r': LORA_R,
                'alpha': LORA_ALPHA,
                'dropout': LORA_DROPOUT,
                'target_modules': LORA_TARGET_MODULES,
            }
        }, os.path.join(checkpoint_dir, "reward_head.pt"))
        
        print(f"    New best! Saved checkpoint: {checkpoint_dir}")
    
    # Clear GPU cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    print()

print(f"\n{'='*60}")
print("Training Complete!")
print(f"{'='*60}")
print(f"Best validation accuracy: {best_val_accuracy:.4f} ({best_val_accuracy*100:.2f}%)")
print(f"Total steps: {global_step}")
print(f"Best checkpoint(s) saved to: outputs/best_model_epoch*/")
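
The loss values logged during training are easier to read with a reference point in mind. The sketch below assumes that `bradley_terry_loss` (defined earlier in the notebook and not reproduced here) has the standard form `-log(sigmoid(r_preferred - r_rejected))`; the helper name `bt_loss_from_margin` is hypothetical. Under that assumption, a margin of 0 gives ln 2 ≈ 0.693, so a logged loss near 0.693 suggests the model expresses no preference between the two options.

```python
import math

def bt_loss_from_margin(margin: float) -> float:
    """Bradley-Terry loss for one pair, -log(sigmoid(margin)), computed stably.

    Assumes margin = preferred_reward - rejected_reward. Uses the identity
    -log(sigmoid(m)) = softplus(-m) = max(-m, 0) + log1p(exp(-|m|)),
    which avoids overflow for large |m|.
    """
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))

# Negative, zero, and positive margins for comparison
for m in (-2.0, 0.0, 2.0):
    print(f"margin={m:+.1f} -> loss={bt_loss_from_margin(m):.4f}")
```

A loss above 0.693 on average means the rejected option tends to score higher than the preferred one.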

## 12. Visualization and Results

Plot training curves, compare against baseline, and analyze model performance.
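
As a reference for the metrics reported in this section, the sketch below computes pairwise accuracy, mean margin, and the correct/incorrect/tied breakdown from two score lists. The function name `summarize_pairs` and the toy scores are illustrative assumptions. How `evaluate_model` treats ties is not shown here; this sketch counts only strictly positive margins as correct.

```python
from statistics import mean

def summarize_pairs(preferred, rejected):
    """Summarize pairwise ranking quality from per-pair reward scores."""
    if len(preferred) != len(rejected):
        raise ValueError("preferred and rejected must have the same length")
    if not preferred:
        raise ValueError("at least one pair is required")
    margins = [p - r for p, r in zip(preferred, rejected)]
    correct = sum(m > 0 for m in margins)
    incorrect = sum(m < 0 for m in margins)
    tied = len(margins) - correct - incorrect
    return {
        "accuracy": correct / len(margins),
        "mean_margin": mean(margins),
        "correct": correct,
        "incorrect": incorrect,
        "tied": tied,
    }

# Toy scores, not results from the experiment
summary = summarize_pairs([0.5, 1.0, -0.25, 0.0], [0.0, 1.5, -0.75, 0.0])
print(summary)
```

Mean margin and accuracy can move in different directions: a few large positive margins can raise the mean while most pairs are still ranked incorrectly, so it helps to read the two together.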

In [None]:
# Evaluate final model on validation set
print("Evaluating final model on validation set...")
final_accuracy, final_loss, preferred_scores, rejected_scores = evaluate_model(
    model, val_dataset
)

print(f"\nFinal Validation Results:")
print(f"  Accuracy: {final_accuracy:.4f} ({final_accuracy*100:.2f}%)")
print(f"  Loss: {final_loss:.4f}")
print(f"  Mean preferred score: {np.mean(preferred_scores):.4f}")
print(f"  Mean rejected score: {np.mean(rejected_scores):.4f}")
print(f"  Score difference: {np.mean(preferred_scores) - np.mean(rejected_scores):.4f}")

# Calculate reward margins
reward_margins = np.array(preferred_scores) - np.array(rejected_scores)

# Compare with baseline
print(f"\n{'='*60}")
print("BASELINE vs TRAINED MODEL COMPARISON")
print(f"{'='*60}")
# Change = trained - baseline; accuracy change is in percentage points,
# and a negative loss change is an improvement
print(f"                    Baseline    Trained     Change")
print(f"  Accuracy:         {baseline_results['accuracy']*100:6.2f}%     {final_accuracy*100:6.2f}%     {(final_accuracy - baseline_results['accuracy'])*100:+6.2f} pp")
print(f"  Loss:             {baseline_results['loss']:6.4f}      {final_loss:6.4f}      {final_loss - baseline_results['loss']:+6.4f}")
print(f"  Mean Margin:      {baseline_results['mean_margin']:+6.4f}     {np.mean(reward_margins):+6.4f}     {np.mean(reward_margins) - baseline_results['mean_margin']:+6.4f}")
print(f"{'='*60}")

# Create comprehensive visualizations (4x3 grid to include baseline comparisons)
fig = plt.figure(figsize=(20, 16))
gs = fig.add_gridspec(4, 3, hspace=0.35, wspace=0.3)

fig.suptitle('Reward Model Training Results\nTrained on CARA (Risk-Aversion) | Validated on Cooperation', 
             fontsize=16, fontweight='bold', y=0.995)

# Row 1: Training Progress
# Plot 1: Training Loss Over Time
ax1 = fig.add_subplot(gs[0, 0])
if len(history['train_steps']) > 0:
    ax1.plot(history['train_steps'], history['train_loss'], 'b-', linewidth=2, label='Training Loss')
    ax1.axhline(y=baseline_results['loss'], color='gray', linestyle='--', alpha=0.7, label=f'Baseline Loss ({baseline_results["loss"]:.3f})')
    ax1.set_xlabel('Training Steps')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training Loss Over Time')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

# Plot 2: Validation Accuracy Over Epochs (with baseline)
ax2 = fig.add_subplot(gs[0, 1])
if len(history['epochs']) > 0:
    # Add epoch 0 as baseline
    epochs_with_baseline = [0] + history['epochs']
    accuracy_with_baseline = [baseline_results['accuracy']] + history['val_accuracy']
    ax2.plot(epochs_with_baseline, accuracy_with_baseline, 'g-o', linewidth=2, markersize=8, label='Validation Accuracy')
    ax2.axhline(y=0.5, color='r', linestyle='--', alpha=0.5, label='Random (50%)')
    ax2.axhline(y=baseline_results['accuracy'], color='gray', linestyle=':', alpha=0.7, label=f'Baseline ({baseline_results["accuracy"]*100:.1f}%)')
    ax2.scatter([0], [baseline_results['accuracy']], color='orange', s=100, zorder=5, marker='s', label='Before Training')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.set_title('Validation Accuracy (Cooperation)')
    ax2.set_ylim([0, 1])
    ax2.legend(fontsize=8)
    ax2.grid(True, alpha=0.3)

# Plot 3: Reward Margin Progression (with baseline)
ax3 = fig.add_subplot(gs[0, 2])
if len(history['epochs']) > 0 and len(history['reward_margins']) > 0:
    # Note: the epoch-0 point comes from the baseline evaluation on the validation
    # set, while later points are training-set margins, so the first segment
    # mixes the two datasets
    epochs_with_baseline = [0] + history['epochs']
    margins_with_baseline = [baseline_results['mean_margin']] + history['reward_margins']
    ax3.plot(epochs_with_baseline, margins_with_baseline, 'purple', linewidth=2, marker='s', markersize=8)
    ax3.axhline(y=0, color='r', linestyle='--', alpha=0.5, label='No Preference')
    ax3.axhline(y=baseline_results['mean_margin'], color='gray', linestyle=':', alpha=0.7, label=f'Baseline, val ({baseline_results["mean_margin"]:.3f})')
    ax3.scatter([0], [baseline_results['mean_margin']], color='orange', s=100, zorder=5, marker='s', label='Before Training (val)')
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('Mean Reward Margin')
    ax3.set_title('Preference Strength (Training)\n(Preferred - Rejected)')
    ax3.legend(fontsize=8)
    ax3.grid(True, alpha=0.3)

# Row 2: Score Distributions (Trained Model)
# Plot 4: Score Distribution Comparison (Histogram) - Trained
ax4 = fig.add_subplot(gs[1, 0])
bins = np.linspace(min(min(preferred_scores), min(rejected_scores)),
                  max(max(preferred_scores), max(rejected_scores)), 30)
ax4.hist(preferred_scores, bins=bins, alpha=0.6, label='Preferred (Cooperate)', 
         color='green', density=True, edgecolor='black', linewidth=0.5)
ax4.hist(rejected_scores, bins=bins, alpha=0.6, label='Rejected (Non-Cooperate)', 
         color='red', density=True, edgecolor='black', linewidth=0.5)
ax4.set_xlabel('Reward Score')
ax4.set_ylabel('Density')
ax4.set_title('Trained Model: Score Distribution (Val)')
ax4.legend()
ax4.grid(True, alpha=0.3, axis='y')

# Plot 5: Scatter Plot - Preferred vs Rejected Scores (Trained)
ax5 = fig.add_subplot(gs[1, 1])
ax5.scatter(rejected_scores, preferred_scores, alpha=0.5, s=30, c=reward_margins, 
            cmap='RdYlGn', edgecolors='black', linewidth=0.5)
min_val = min(min(rejected_scores), min(preferred_scores))
max_val = max(max(rejected_scores), max(preferred_scores))
ax5.plot([min_val, max_val], [min_val, max_val], 'k--', alpha=0.5, linewidth=2, label='Equal Scores')
ax5.fill_between([min_val, max_val], [min_val, max_val], [max_val, max_val],
                alpha=0.15, color='green', label='Cooperate Preferred')
ax5.fill_between([min_val, max_val], [min_val, min_val], [min_val, max_val],
                alpha=0.15, color='red', label='Non-Cooperate Preferred')
ax5.set_xlabel('Non-Cooperate Score')
ax5.set_ylabel('Cooperate Score')
ax5.set_title('Trained Model: Pairwise Comparison (Val)')
ax5.legend(fontsize=8)
ax5.grid(True, alpha=0.3)
ax5.axis('equal')

# Plot 6: Reward Margin Distribution (Trained)
ax6 = fig.add_subplot(gs[1, 2])
ax6.hist(reward_margins, bins=40, alpha=0.7, color='purple', edgecolor='black', linewidth=0.5)
ax6.axvline(x=0, color='r', linestyle='--', linewidth=2, label='No Preference')
ax6.axvline(x=np.mean(reward_margins), color='g', linestyle='-', linewidth=2, 
            label=f'Trained Mean: {np.mean(reward_margins):.3f}')
ax6.axvline(x=baseline_results['mean_margin'], color='gray', linestyle=':', linewidth=2, 
            label=f'Baseline Mean: {baseline_results["mean_margin"]:.3f}')
ax6.set_xlabel('Reward Margin (Preferred - Rejected)')
ax6.set_ylabel('Count')
ax6.set_title('Trained Model: Margin Distribution (Val)')
ax6.legend(fontsize=8)
ax6.grid(True, alpha=0.3, axis='y')

# Row 3: Baseline Comparison
# Plot 7: Baseline Score Distribution (for comparison)
ax7 = fig.add_subplot(gs[2, 0])
baseline_pref = np.array(baseline_results['preferred_scores'])
baseline_rej = np.array(baseline_results['rejected_scores'])
bins_baseline = np.linspace(min(min(baseline_pref), min(baseline_rej)),
                           max(max(baseline_pref), max(baseline_rej)), 30)
ax7.hist(baseline_pref, bins=bins_baseline, alpha=0.6, label='Preferred (Cooperate)', 
         color='green', density=True, edgecolor='black', linewidth=0.5)
ax7.hist(baseline_rej, bins=bins_baseline, alpha=0.6, label='Rejected (Non-Cooperate)', 
         color='red', density=True, edgecolor='black', linewidth=0.5)
ax7.set_xlabel('Reward Score')
ax7.set_ylabel('Density')
ax7.set_title('Baseline (Untrained): Score Distribution (Val)')
ax7.legend()
ax7.grid(True, alpha=0.3, axis='y')

# Plot 8: Baseline Scatter Plot
ax8 = fig.add_subplot(gs[2, 1])
baseline_margins_arr = np.array(baseline_results['margins'])
ax8.scatter(baseline_rej, baseline_pref, alpha=0.5, s=30, c=baseline_margins_arr, 
            cmap='RdYlGn', edgecolors='black', linewidth=0.5)
min_val_b = min(min(baseline_rej), min(baseline_pref))
max_val_b = max(max(baseline_rej), max(baseline_pref))
ax8.plot([min_val_b, max_val_b], [min_val_b, max_val_b], 'k--', alpha=0.5, linewidth=2, label='Equal Scores')
ax8.fill_between([min_val_b, max_val_b], [min_val_b, max_val_b], [max_val_b, max_val_b],
                alpha=0.15, color='green', label='Cooperate Preferred')
ax8.fill_between([min_val_b, max_val_b], [min_val_b, min_val_b], [min_val_b, max_val_b],
                alpha=0.15, color='red', label='Non-Cooperate Preferred')
ax8.set_xlabel('Non-Cooperate Score')
ax8.set_ylabel('Cooperate Score')
ax8.set_title('Baseline (Untrained): Pairwise Comparison (Val)')
ax8.legend(fontsize=8)
ax8.grid(True, alpha=0.3)
ax8.axis('equal')

# Plot 9: Baseline vs Trained Accuracy Bar Chart
ax9 = fig.add_subplot(gs[2, 2])
categories = ['Baseline\n(Untrained)', 'Trained\n(LoRA)']
accuracies = [baseline_results['accuracy'] * 100, final_accuracy * 100]
colors = ['gray', 'green']
bars = ax9.bar(categories, accuracies, color=colors, alpha=0.7, edgecolor='black', linewidth=1.5)
ax9.axhline(y=50, color='r', linestyle='--', alpha=0.5, label='Random (50%)')
for bar, acc in zip(bars, accuracies):
    height = bar.get_height()
    ax9.text(bar.get_x() + bar.get_width()/2., height + 1,
             f'{acc:.1f}%', ha='center', va='bottom', fontweight='bold', fontsize=12)
ax9.set_ylabel('Accuracy (%)')
ax9.set_title('Cooperation Accuracy: Baseline vs Trained')
ax9.set_ylim([0, 100])
ax9.legend()
ax9.grid(True, alpha=0.3, axis='y')

# Row 4: Performance Analysis
# Plot 10: Ranking Performance Breakdown
ax10 = fig.add_subplot(gs[3, 0])
correct_rankings = np.sum(reward_margins > 0)
incorrect_rankings = np.sum(reward_margins < 0)
tied_rankings = np.sum(reward_margins == 0)
categories = ['Correct\n(Pref > Rej)', 'Incorrect\n(Pref < Rej)', 'Tied\n(Pref = Rej)']
counts = [correct_rankings, incorrect_rankings, tied_rankings]
colors = ['green', 'red', 'gray']
bars = ax10.bar(categories, counts, color=colors, alpha=0.7, edgecolor='black', linewidth=1.5)
for bar, count in zip(bars, counts):
    height = bar.get_height()
    pct = 100 * count / len(reward_margins)
    ax10.text(bar.get_x() + bar.get_width()/2., height + 1,
             f'{count}\n({pct:.1f}%)', ha='center', va='bottom', fontweight='bold')
ax10.set_ylabel('Number of Pairs')
ax10.set_title('Trained Model: Ranking Performance (Val)')
ax10.grid(True, alpha=0.3, axis='y')

# Plot 11: Score Evolution Over Training (with baseline)
ax11 = fig.add_subplot(gs[3, 1])
if len(history['epochs']) > 0 and len(history['preferred_rewards']) > 0:
    # Note: the epoch-0 points are baseline scores on the validation set, while
    # later points are mean training-set rewards
    epochs_with_baseline = [0] + history['epochs']
    pref_with_baseline = [np.mean(baseline_results['preferred_scores'])] + history['preferred_rewards']
    rej_with_baseline = [np.mean(baseline_results['rejected_scores'])] + history['rejected_rewards']
    ax11.plot(epochs_with_baseline, pref_with_baseline, 'g-o', 
             linewidth=2, markersize=8, label='Preferred (CARA)')
    ax11.plot(epochs_with_baseline, rej_with_baseline, 'r-s', 
             linewidth=2, markersize=8, label='Rejected')
    ax11.scatter([0, 0], [pref_with_baseline[0], rej_with_baseline[0]], 
                color='orange', s=100, zorder=5, marker='D', label='Baseline (val)')
    ax11.set_xlabel('Epoch')
    ax11.set_ylabel('Mean Reward Score')
    ax11.set_title('Training Score Evolution')
    ax11.legend()
    ax11.grid(True, alpha=0.3)

# Plot 12: Cumulative Distribution Functions
ax12 = fig.add_subplot(gs[3, 2])
sorted_pref = np.sort(preferred_scores)
sorted_rej = np.sort(rejected_scores)
cdf_pref = np.arange(1, len(sorted_pref) + 1) / len(sorted_pref)
cdf_rej = np.arange(1, len(sorted_rej) + 1) / len(sorted_rej)
ax12.plot(sorted_pref, cdf_pref, 'g-', linewidth=2, label='Trained: Cooperate')
ax12.plot(sorted_rej, cdf_rej, 'r-', linewidth=2, label='Trained: Non-Cooperate')
# Add baseline CDFs
sorted_base_pref = np.sort(baseline_results['preferred_scores'])
sorted_base_rej = np.sort(baseline_results['rejected_scores'])
cdf_base_pref = np.arange(1, len(sorted_base_pref) + 1) / len(sorted_base_pref)
cdf_base_rej = np.arange(1, len(sorted_base_rej) + 1) / len(sorted_base_rej)
ax12.plot(sorted_base_pref, cdf_base_pref, 'g--', linewidth=1.5, alpha=0.5, label='Baseline: Cooperate')
ax12.plot(sorted_base_rej, cdf_base_rej, 'r--', linewidth=1.5, alpha=0.5, label='Baseline: Non-Cooperate')
ax12.set_xlabel('Reward Score')
ax12.set_ylabel('Cumulative Probability')
ax12.set_title('CDF Comparison: Baseline vs Trained (Val)')
ax12.legend(fontsize=8)
ax12.grid(True, alpha=0.3)

plt.tight_layout()

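# Save plot (create outputs/ first in case no checkpoint was written during training)
os.makedirs("outputs", exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
plot_path = f"outputs/training_results_{timestamp}.png"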
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
print(f"\nComprehensive plots saved to: {plot_path}")

plt.show()

# Calculate trainable parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())

# Save results to JSON (including baseline and LoRA config)
results = {
    'model_name': MODEL_NAME,
    'architecture': {
        'type': 'LoRA + Reward Head',
        'lora_config': {
            'r': LORA_R,
            'alpha': LORA_ALPHA,
            'dropout': LORA_DROPOUT,
            'target_modules': LORA_TARGET_MODULES,
        },
        'trainable_params': trainable_params,
        'total_params': total_params,
        'trainable_percent': 100 * trainable_params / total_params,
    },
    'data': {
        'training_file': TRAIN_DATA_FILE,
        'validation_file': VAL_DATA_FILE,
        'training_label_type': 'CARA (risk-aversion)',
        'validation_label_type': 'cooperate',
    },
    'baseline': {
        'accuracy': float(baseline_results['accuracy']),
        'loss': float(baseline_results['loss']),
        'mean_margin': float(baseline_results['mean_margin']),
        'std_margin': float(baseline_results['std_margin']),
    },
    'trained': {
        'final_validation_accuracy': float(final_accuracy),
        'final_validation_loss': float(final_loss),
        'best_validation_accuracy': float(best_val_accuracy),
        'mean_preferred_score': float(np.mean(preferred_scores)),
        'mean_rejected_score': float(np.mean(rejected_scores)),
        'score_difference': float(np.mean(preferred_scores) - np.mean(rejected_scores)),
        'margin_mean': float(np.mean(reward_margins)),
        'margin_std': float(np.std(reward_margins)),
        'correct_rankings': int(np.sum(reward_margins > 0)),
        'incorrect_rankings': int(np.sum(reward_margins < 0)),
        'tied_rankings': int(np.sum(reward_margins == 0)),
    },
    'improvement': {
        'accuracy_gain': float(final_accuracy - baseline_results['accuracy']),
        'accuracy_gain_percent': float((final_accuracy - baseline_results['accuracy']) * 100),
        'loss_reduction': float(baseline_results['loss'] - final_loss),
        'margin_improvement': float(np.mean(reward_margins) - baseline_results['mean_margin']),
    },
    'config': {
        'num_epochs': NUM_EPOCHS,
        'epochs_trained': len(history['epochs']),
        'batch_size': BATCH_SIZE,
        'learning_rate': LEARNING_RATE,
        'weight_decay': WEIGHT_DECAY,
        'training_samples': len(train_dataset),
        'validation_samples': len(val_dataset),
    },
    'timestamp': timestamp
}

results_path = f"outputs/results_{timestamp}.json"
with open(results_path, 'w') as f:
    json.dump(results, f, indent=2)

print(f"Results saved to: {results_path}")

# Save training history with timestamp
# (default=float converts any NumPy scalars that json cannot serialize directly)
history_path = f"outputs/training_history_{timestamp}.json"
with open(history_path, 'w') as f:
    json.dump(history, f, indent=2, default=float)
print(f"Training history saved to: {history_path}")

print("\n" + "="*60)
print("Experiment Complete!")
print("="*60)
print(f"\nSummary:")
print(f"  Training data: CARA labels (risk-aversion)")
print(f"  Validation data: Cooperate labels")
print(f"  Baseline accuracy:  {baseline_results['accuracy']*100:.1f}%")
print(f"  Final accuracy:     {final_accuracy*100:.1f}%")
print(f"  Improvement:        {(final_accuracy - baseline_results['accuracy'])*100:+.1f} pp")
print(f"  Trainable params:   {trainable_params:,} ({100*trainable_params/total_params:.4f}%)")

# Find checkpoints before try block to avoid undefined variable
import glob
import shutil
checkpoint_dirs = glob.glob("outputs/best_model_epoch*")

# Automatic downloads (for Google Colab)
print("\n" + "="*60)
print("Downloading Results...")
print("="*60)

try:
    from google.colab import files
    
    # Download plots
    print(f"Downloading: {plot_path}")
    files.download(plot_path)
    
    # Download results JSON
    print(f"Downloading: {results_path}")
    files.download(results_path)
    
    # Download training history
    print(f"Downloading: {history_path}")
    files.download(history_path)
    
    # Download best model checkpoint (LoRA checkpoints are directories - zip them)
    if checkpoint_dirs:
        latest_checkpoint = max(checkpoint_dirs, key=os.path.getctime)
        if os.path.isdir(latest_checkpoint):
            zip_path = f"{latest_checkpoint}.zip"
            shutil.make_archive(latest_checkpoint, 'zip', latest_checkpoint)
            print(f"Downloading: {zip_path}")
            files.download(zip_path)
    
    print("\nAll files downloaded successfully!")
    
except ImportError:
    print("\nNOTE: Not running in Google Colab - files saved to outputs/ directory")
    print("Files available:")
    print(f"  - {plot_path}")
    print(f"  - {results_path}")
    print(f"  - {history_path}")
    if checkpoint_dirs:
        latest_checkpoint = max(checkpoint_dirs, key=os.path.getctime)
        print(f"  - {latest_checkpoint}/ (LoRA checkpoint directory)")
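
Selecting the checkpoint with `os.path.getctime` depends on filesystem timestamps, which can change when the `outputs/` directory is copied or restored. Because the training loop only writes a `best_model_epoch{N}` directory when validation accuracy improves, the highest epoch number within a single run identifies the best checkpoint. A minimal sketch of that alternative, assuming the directory naming used above and no leftover directories from earlier runs (the helper name `pick_best_checkpoint` is hypothetical):

```python
import re

def pick_best_checkpoint(dirs):
    """Return the path with the largest epoch number, or None if none match."""
    pattern = re.compile(r"best_model_epoch(\d+)$")
    candidates = []
    for path in dirs:
        # Ignore a trailing path separator before matching the epoch suffix
        match = pattern.search(path.rstrip("/\\"))
        if match:
            candidates.append((int(match.group(1)), path))
    # Compare on the integer epoch so that epoch 10 ranks above epoch 2
    return max(candidates)[1] if candidates else None

best = pick_best_checkpoint(["outputs/best_model_epoch2", "outputs/best_model_epoch10"])
print(best)
```

Parsing the epoch as an integer matters here: a plain string sort would place `best_model_epoch2` after `best_model_epoch10`.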