# Risk-Averse Reward Model Training

This notebook implements the complete risk aversion experiment for training **Qwen3-8B** to prefer risk-averse choices over risk-neutral ones.

## Features
- **Separate Train/Val Data**: Uses dedicated training and validation CSV files with different label types
- **CARA-based Training**: Trains on CARA (risk-averse) labels with smart incorrect label selection
- **Cooperation-based Validation**: Validates on cooperate labels to test generalization
- **Per-Epoch Re-randomization**: Training data with "both" bucket_label gets re-randomized each epoch

## Data Files
- **Training**: `data/2025_12_5_training_set_low_stakes_balanced.csv` (CARA labels)
- **Validation**: `data/2025_12_5_val_set_medium_stakes_balanced.csv` (cooperate labels)

**Memory Optimized for 8B Model:** Uses gradient checkpointing, smaller batch size, and fp16 for memory efficiency

## 1. Install Dependencies

In [None]:
# Install required packages
!pip install -q transformers datasets accelerate torch pandas numpy scikit-learn matplotlib seaborn hf_transfer peft

print("âœ“ Dependencies installed successfully!")

## 2. Import Libraries

In [None]:
import os
import sys
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from peft import LoraConfig, get_peft_model, TaskType
from sklearn.model_selection import train_test_split
import json
from typing import List, Dict, Tuple
import warnings
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

# Suppress every Python warning for the rest of the session
# (note: this also hides deprecation warnings raised by the imported libraries)
warnings.filterwarnings('ignore')

# Set plotting style
# matplotlib's built-in 'default' style sheet combined with seaborn's "husl" palette
plt.style.use('default')
sns.set_palette("husl")

# Reproducibility function (will be called after config is loaded)
def set_seed(seed):
    """Seed every random number generator this notebook can draw from.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch.
    When CUDA is available, all CUDA devices are seeded as well and cuDNN is
    switched to deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Kept local so this function does not rely on a file-level import.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Disable cuDNN autotuning and force deterministic kernels so that
        # repeated runs pick the same algorithms.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

# Report the runtime environment: PyTorch version and CUDA availability
print("Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only the device at index 0 is named; additional GPUs are not listed
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 3. Configuration

All configurable parameters for the experiment.

In [None]:
# =============================================================================
# CONFIGURATION - All configurable parameters for the experiment
# =============================================================================

# Model settings
MODEL_NAME = "Qwen/Qwen3-8B"          # Base model to use
MAX_LENGTH = 256                       # Maximum sequence length for tokenization

# Training hyperparameters
BATCH_SIZE = 2                         # Batch size per forward pass
LEARNING_RATE = 2e-4                   # Learning rate for LoRA layers
WEIGHT_DECAY = 0.01                    # L2 regularization
NUM_EPOCHS = 10                        # Number of training epochs

# LoRA configuration
LORA_R = 8                             # LoRA rank (low for small dataset)
LORA_ALPHA = 16                        # LoRA alpha (scaling = alpha/r = 2.0)
LORA_DROPOUT = 0.05                    # LoRA dropout for regularization
LORA_TARGET_MODULES = ["q_proj", "v_proj"]  # Query and Value attention projections

# Data settings - separate training and validation files
# (paths are relative to the notebook's working directory)
TRAIN_DATA_FILE = "data/2025_12_5_training_set_low_stakes_balanced.csv"
VAL_DATA_FILE = "data/2025_12_5_val_set_medium_stakes_balanced.csv"
RANDOM_SEED = 42                       # Random seed for reproducibility

# In-distribution validation split
# (passed as val_fraction when the training situations are split)
IN_DIST_VAL_SPLIT = 0.10               # Hold out 10% of training data for in-distribution validation

# =============================================================================
# Set seed for reproducibility
# Seeds the global RNGs through set_seed() before any data or model work starts
set_seed(RANDOM_SEED)

# Echo the effective settings into the cell output
print("Configuration loaded:")
print(f"  Model: {MODEL_NAME}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Random seed: {RANDOM_SEED}")
print(f"  In-dist val split: {IN_DIST_VAL_SPLIT*100:.0f}%")
print(f"\nLoRA Configuration:")
print(f"  Rank (r): {LORA_R}")
print(f"  Alpha: {LORA_ALPHA}")
print(f"  Dropout: {LORA_DROPOUT}")
print(f"  Target modules: {LORA_TARGET_MODULES}")
print(f"\nData Files:")
print(f"  Training: {TRAIN_DATA_FILE}")
print(f"  Validation (out-of-dist): {VAL_DATA_FILE}")

## 4. Data Loading Classes

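Three specialized loaders for training and validation data:
- **TrainingDataLoader**: Loads CARA-based labels with `low_bucket_label` logic for incorrect label selection
- **InDistributionValidationDataLoader**: Applies the same CARA label logic to the held-out share of the training file, using a fixed seed and tagging each example with an `error_type`
- **ValidationDataLoader**: Loads cooperate-based labels for generalization testing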

In [None]:
def get_train_val_situation_split(csv_file_path: str, val_fraction: float = 0.10, random_seed: int = 42):
    """Partition the situation IDs found in a CSV into train and held-out lists.

    The split is stratified on ``low_bucket_label`` when that column exists.
    If scikit-learn rejects the stratified split with a ``ValueError``, or the
    column is absent, a plain random split is performed instead.

    Args:
        csv_file_path: Path to training CSV file
        val_fraction: Fraction of situations to hold out for validation (default 0.10)
        random_seed: Random seed for reproducibility

    Returns:
        Tuple of (train_situation_ids, val_situation_ids) as lists
    """
    frame = pd.read_csv(csv_file_path)

    # Collapse to one representative row per situation
    per_situation = frame.groupby('situation_id').first().reset_index()
    all_ids = per_situation['situation_id'].tolist()

    def _plain_split():
        # Unstratified split shared by both fallback paths
        return train_test_split(
            all_ids,
            test_size=val_fraction,
            random_state=random_seed
        )

    # No bucket column: nothing to stratify on
    if 'low_bucket_label' not in per_situation.columns:
        train_ids, val_ids = _plain_split()
        print(f"Non-stratified split: {len(train_ids)} train, {len(val_ids)} val")
        return train_ids, val_ids

    def _unquote(label):
        # Bucket labels may carry literal surrounding quotes; non-strings pass through
        return label.strip('"') if isinstance(label, str) else label

    bucket_labels = per_situation['low_bucket_label'].apply(_unquote).tolist()

    try:
        train_ids, val_ids = train_test_split(
            all_ids,
            test_size=val_fraction,
            random_state=random_seed,
            stratify=bucket_labels
        )
    except ValueError:
        # Fall back to non-stratified if stratification fails
        train_ids, val_ids = _plain_split()
        print(f"Non-stratified split (stratification failed): {len(train_ids)} train, {len(val_ids)} val")
    else:
        print(f"Stratified split by low_bucket_label: {len(train_ids)} train, {len(val_ids)} val")

    return train_ids, val_ids


class TrainingDataLoader:
    """Load and process training data with CARA-based labels and low_bucket_label logic.

    This loader handles the special logic for selecting incorrect labels based on
    the low_bucket_label field:
    - "010_only": use CARA_alpha_0_10_best_labels as incorrect (avoid over-risk-aversion)
    - "lin_only": use linear_best_labels as incorrect (avoid being linear/risk-neutral)
    - "both": randomly choose between the two (re-randomizes each epoch)
    - anything else: fall back to CARA_incorrect_labels when that column exists

    Situations whose label cells are missing, empty or not valid JSON are
    skipped (and counted) rather than aborting the load.
    """

    # Columns of the frame returned by load_and_process_data(); passed explicitly
    # so that an empty result still exposes the expected schema.
    _OUTPUT_COLUMNS = [
        'situation_id',
        'prompt_text',
        'correct_label',
        'incorrect_label',
        'low_bucket_label',
    ]

    def __init__(self, csv_file_path: str, epoch: int = 0, random_seed: int = 42,
                 situation_ids: List = None):
        """
        Args:
            csv_file_path: Path to training CSV file
            epoch: Current epoch number (used for reproducible per-epoch randomization)
            random_seed: Base random seed for reproducibility
            situation_ids: Optional list of situation IDs to include (for train/val split)
        """
        self.csv_file_path = csv_file_path
        self.epoch = epoch
        self.rng = np.random.default_rng(random_seed + epoch)  # Per-epoch randomization
        self.situation_ids = situation_ids

    @staticmethod
    def _parse_label_list(raw):
        """Decode a JSON-array cell; a missing (non-string) cell counts as empty.

        Raises:
            json.JSONDecodeError: If the cell is a string but not valid JSON.
        """
        if not isinstance(raw, str):
            # pandas represents an empty CSV cell as NaN/None, never as a string
            return []
        return json.loads(raw)

    def load_and_process_data(self) -> pd.DataFrame:
        """Load CSV data and process it for training.

        Returns:
            DataFrame with columns: situation_id, prompt_text, correct_label, incorrect_label, low_bucket_label

        Raises:
            FileNotFoundError: If the CSV file does not exist.
            ValueError: If a required column is missing from the CSV.
        """
        # Check if CSV file exists
        if not os.path.exists(self.csv_file_path):
            raise FileNotFoundError(
                f"Required training data file '{self.csv_file_path}' not found. "
                f"Please ensure the CSV file exists."
            )

        # Load the CSV file
        df = pd.read_csv(self.csv_file_path)
        print(f"Loaded {len(df)} rows from {self.csv_file_path}")

        # Check required columns exist
        required_columns = ['situation_id', 'prompt_text', 'CARA_correct_labels', 'low_bucket_label']
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            raise ValueError(
                f"Missing required columns in training CSV: {missing_columns}. "
                f"Available columns: {list(df.columns)}"
            )

        # Group by situation_id, take first row of each group (all rows have same labels)
        situations = df.groupby('situation_id').first().reset_index()

        # Filter to specified situation IDs if provided
        if self.situation_ids is not None:
            situations = situations[situations['situation_id'].isin(self.situation_ids)]
            print(f"Filtered to {len(situations)} situations (from {len(self.situation_ids)} specified IDs)")
        else:
            print(f"Found {len(situations)} unique situations")

        processed = []
        skipped = 0

        for _, row in situations.iterrows():
            try:
                prompt_text = row['prompt_text']

                # Parse JSON array for correct labels
                correct_labels = self._parse_label_list(row['CARA_correct_labels'])
                if not correct_labels:
                    skipped += 1
                    continue

                # A missing bucket cell cannot be routed to any label source
                raw_bucket = row['low_bucket_label']
                if not isinstance(raw_bucket, str):
                    skipped += 1
                    continue
                low_bucket = raw_bucket.strip('"')  # Remove surrounding quotes

                if low_bucket == '010_only':
                    # Use CARA_alpha_0_10_best_labels as incorrect
                    incorrect_labels = self._parse_label_list(row['CARA_alpha_0_10_best_labels'])
                elif low_bucket == 'lin_only':
                    # Use linear_best_labels as incorrect
                    incorrect_labels = self._parse_label_list(row['linear_best_labels'])
                elif low_bucket == 'both':
                    # Randomly choose between linear and alpha_0_10 (re-randomizes each epoch)
                    if self.rng.random() < 0.5:
                        incorrect_labels = self._parse_label_list(row['linear_best_labels'])
                    else:
                        incorrect_labels = self._parse_label_list(row['CARA_alpha_0_10_best_labels'])
                else:
                    # Fallback: use CARA_incorrect_labels if available
                    incorrect_labels = self._parse_label_list(row.get('CARA_incorrect_labels', '[]'))

                if not incorrect_labels:
                    skipped += 1
                    continue

                # Randomly select one label from each array
                correct_label = str(self.rng.choice(correct_labels))
                incorrect_label = str(self.rng.choice(incorrect_labels))

                processed.append({
                    'situation_id': row['situation_id'],
                    'prompt_text': prompt_text,
                    'correct_label': correct_label,
                    'incorrect_label': incorrect_label,
                    'low_bucket_label': low_bucket,
                })

            except (json.JSONDecodeError, KeyError) as e:
                print(f"Warning: Error processing situation {row['situation_id']}: {e}")
                skipped += 1
                continue

        if skipped > 0:
            print(f"Warning: Skipped {skipped} situations due to missing/empty labels")

        result_df = pd.DataFrame(processed, columns=self._OUTPUT_COLUMNS)
        print(f"Processed into {len(result_df)} training examples (epoch {self.epoch})")

        # Display low_bucket_label distribution
        if 'low_bucket_label' in result_df.columns and len(result_df) > 0:
            print(f"\nlow_bucket_label distribution:")
            for label, count in result_df['low_bucket_label'].value_counts().items():
                print(f"  {label}: {count} ({100*count/len(result_df):.1f}%)")

        return result_df


def _is_true(value) -> bool:
    """Check if a value is TRUE (handles NaN as FALSE).
    
    Boolean columns in the CSV only contain TRUE or NaN, where NaN means FALSE.
    """
    if pd.isna(value):
        return False
    return str(value).upper() == 'TRUE'


class InDistributionValidationDataLoader:
    """Load and process in-distribution validation data with CARA-based labels.

    Uses the same label logic as training but with fixed randomization (no per-epoch changes).
    Includes error_type categorization for breakdown analysis.

    Situations whose label cells are missing, empty or not valid JSON are
    skipped (and counted) rather than aborting the load.
    """

    # Columns of the frame returned by load_and_process_data(); passed explicitly
    # so that an empty result still exposes the expected schema.
    _OUTPUT_COLUMNS = [
        'situation_id',
        'prompt_text',
        'correct_label',
        'incorrect_label',
        'error_type',
        'low_bucket_label',
    ]

    def __init__(self, csv_file_path: str, random_seed: int = 42, situation_ids: List = None):
        """
        Args:
            csv_file_path: Path to training CSV file
            random_seed: Random seed for reproducibility (fixed, no per-epoch changes)
            situation_ids: List of situation IDs to include (typically the held-out 10%)
        """
        self.csv_file_path = csv_file_path
        self.rng = np.random.default_rng(random_seed)
        self.situation_ids = situation_ids

    @staticmethod
    def _parse_label_list(raw):
        """Decode a JSON-array cell; a missing (non-string) cell counts as empty.

        Raises:
            json.JSONDecodeError: If the cell is a string but not valid JSON.
        """
        if not isinstance(raw, str):
            # pandas represents an empty CSV cell as NaN/None, never as a string
            return []
        return json.loads(raw)

    def load_and_process_data(self) -> pd.DataFrame:
        """Load CSV data and process it for in-distribution validation.

        Returns:
            DataFrame with columns: situation_id, prompt_text, correct_label, incorrect_label,
                                   error_type, low_bucket_label

        Raises:
            FileNotFoundError: If the CSV file does not exist.
        """
        if not os.path.exists(self.csv_file_path):
            raise FileNotFoundError(
                f"Required data file '{self.csv_file_path}' not found."
            )

        df = pd.read_csv(self.csv_file_path)
        print(f"Loaded {len(df)} rows from {self.csv_file_path}")

        # Group by situation_id, take first row of each group
        situations = df.groupby('situation_id').first().reset_index()

        # Filter to specified situation IDs
        if self.situation_ids is not None:
            situations = situations[situations['situation_id'].isin(self.situation_ids)]
            print(f"Filtered to {len(situations)} in-distribution validation situations")
        else:
            print(f"Found {len(situations)} unique situations")

        processed = []
        skipped = 0
        error_type_counts = {'too_risky': 0, 'too_risk_averse': 0, 'other': 0}

        for _, row in situations.iterrows():
            try:
                prompt_text = row['prompt_text']

                # Parse JSON array for correct labels (CARA labels)
                correct_labels = self._parse_label_list(row['CARA_correct_labels'])
                if not correct_labels:
                    skipped += 1
                    continue

                # A missing bucket cell cannot be routed to any label source
                raw_bucket = row['low_bucket_label']
                if not isinstance(raw_bucket, str):
                    skipped += 1
                    continue
                low_bucket = raw_bucket.strip('"')

                # For validation: use fixed selection based on low_bucket_label
                # This determines both the incorrect label AND the error type
                if low_bucket == '010_only':
                    incorrect_labels = self._parse_label_list(row['CARA_alpha_0_10_best_labels'])
                    error_type = 'too_risk_averse'
                elif low_bucket == 'lin_only':
                    incorrect_labels = self._parse_label_list(row['linear_best_labels'])
                    error_type = 'too_risky'
                elif low_bucket == 'both':
                    # For 'both', randomly choose one but stay consistent (fixed seed)
                    if self.rng.random() < 0.5:
                        incorrect_labels = self._parse_label_list(row['linear_best_labels'])
                        error_type = 'too_risky'
                    else:
                        incorrect_labels = self._parse_label_list(row['CARA_alpha_0_10_best_labels'])
                        error_type = 'too_risk_averse'
                else:
                    incorrect_labels = self._parse_label_list(row.get('CARA_incorrect_labels', '[]'))
                    error_type = 'other'

                if not incorrect_labels:
                    skipped += 1
                    continue

                # Randomly select one label from each array
                correct_label = str(self.rng.choice(correct_labels))
                incorrect_label = str(self.rng.choice(incorrect_labels))

                error_type_counts[error_type] += 1

                processed.append({
                    'situation_id': row['situation_id'],
                    'prompt_text': prompt_text,
                    'correct_label': correct_label,
                    'incorrect_label': incorrect_label,
                    'error_type': error_type,
                    'low_bucket_label': low_bucket,
                })

            except (json.JSONDecodeError, KeyError) as e:
                print(f"Warning: Error processing situation {row['situation_id']}: {e}")
                skipped += 1
                continue

        if skipped > 0:
            print(f"Warning: Skipped {skipped} situations due to missing/empty labels")

        result_df = pd.DataFrame(processed, columns=self._OUTPUT_COLUMNS)
        print(f"Processed into {len(result_df)} in-distribution validation examples")

        # Display error type distribution
        print(f"\nError type distribution (if model incorrectly prefers non-CARA option):")
        for error_type, count in error_type_counts.items():
            pct = 100 * count / len(result_df) if len(result_df) > 0 else 0
            print(f"  {error_type}: {count} ({pct:.1f}%)")

        return result_df


class ValidationDataLoader:
    """Load and process validation data with cooperate-based labels and error categorization.

    For validation, we use cooperate_correct_labels and cooperate_incorrect_labels.
    Each pair is categorized by error type:
    - "too_risky": incorrect option is linear_best (risk-seeking error)
    - "too_risk_averse": incorrect option is in CARA_alpha_0_10_best_labels (overly cautious)
    - "other": neither of the above
    """

    # Columns of the frame returned by load_and_process_data(); passed explicitly
    # so that an empty result still exposes the expected schema.
    _OUTPUT_COLUMNS = [
        'situation_id',
        'prompt_text',
        'correct_label',
        'incorrect_label',
        'error_type',
        'incorrect_option_type',
    ]

    def __init__(self, csv_file_path: str, random_seed: int = 42):
        """
        Args:
            csv_file_path: Path to validation CSV file
            random_seed: Random seed for reproducibility
        """
        self.csv_file_path = csv_file_path
        self.rng = np.random.default_rng(random_seed)

    @staticmethod
    def _build_option_properties(df: pd.DataFrame) -> dict:
        """Map (situation_id, label) to the per-option flags read from each CSV row.

        The label for a row is derived from its ``option_index``: numeric
        ("1", "2", ...) when the situation's cooperate_correct_labels hold
        digits, otherwise alphabetic ("a", "b", ...).
        """
        option_properties = {}
        for _, row in df.iterrows():
            sit_id = row['situation_id']
            opt_idx = int(row['option_index'])

            # Determine the label letter (a, b, c, ... or 1, 2, 3, ... based on data)
            # First check if cooperate_correct_labels uses letters or numbers
            correct_labels_str = row['cooperate_correct_labels']
            if correct_labels_str and pd.notna(correct_labels_str):
                try:
                    sample_labels = json.loads(correct_labels_str)
                    if sample_labels and str(sample_labels[0]).isdigit():
                        # Uses numeric labels (1, 2, 3, ...)
                        label = str(opt_idx + 1)
                    else:
                        # Uses letter labels (a, b, c, ...)
                        label = chr(ord('a') + opt_idx)
                except json.JSONDecodeError:
                    label = chr(ord('a') + opt_idx)
            else:
                label = chr(ord('a') + opt_idx)

            # Get CARA_alpha_0_10_best_labels for this situation
            alpha_010_str = row.get('CARA_alpha_0_10_best_labels', '')
            alpha_010_labels = []
            if alpha_010_str and pd.notna(alpha_010_str) and str(alpha_010_str).strip():
                try:
                    alpha_010_labels = json.loads(alpha_010_str)
                except json.JSONDecodeError:
                    alpha_010_labels = []

            # Boolean columns: TRUE or NaN (NaN means FALSE)
            option_properties[(sit_id, label)] = {
                'is_linear_best': _is_true(row.get('is_best_linear_display')),
                'is_cara_best': _is_true(row.get('is_best_cara_display')),
                'is_rebel_fosd_all_coops': _is_true(row.get('option_is_rebel_fosd_all_coops')),
                'is_coop_fosd_all_rebels': _is_true(row.get('option_is_coop_fosd_all_rebels')),
                'is_rebel_best_cara': _is_true(row.get('option_is_rebel_best_cara')),
                'is_coop_best_linear': _is_true(row.get('option_is_coop_best_linear')),
                'option_type': row.get('option_type', ''),
                'alpha_010_labels': alpha_010_labels,
            }
        return option_properties

    def load_and_process_data(self) -> pd.DataFrame:
        """Load CSV data and process it for validation.

        Returns:
            DataFrame with columns: situation_id, prompt_text, correct_label, incorrect_label,
                                   error_type, incorrect_option_type

        Raises:
            FileNotFoundError: If the CSV file does not exist.
            ValueError: If a required column is missing from the CSV.
        """
        # Check if CSV file exists
        if not os.path.exists(self.csv_file_path):
            raise FileNotFoundError(
                f"Required validation data file '{self.csv_file_path}' not found. "
                f"Please ensure the CSV file exists."
            )

        # Load the CSV file
        df = pd.read_csv(self.csv_file_path)
        print(f"Loaded {len(df)} rows from {self.csv_file_path}")

        # Drop completely empty rows (CSV artifact - all columns are NaN)
        original_len = len(df)
        df = df.dropna(how='all')
        if len(df) < original_len:
            print(f"Dropped {original_len - len(df)} empty rows (CSV artifact)")

        # Check required columns exist
        required_columns = ['situation_id', 'prompt_text', 'cooperate_correct_labels', 'cooperate_incorrect_labels']
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            raise ValueError(
                f"Missing required columns in validation CSV: {missing_columns}. "
                f"Available columns: {list(df.columns)}"
            )

        # Build a lookup for each option's properties (linear_best, option_type, etc.)
        # Key: (situation_id, label) -> properties dict
        option_properties = self._build_option_properties(df)

        # Group by situation_id, take first row of each group
        situations = df.groupby('situation_id').first().reset_index()
        print(f"Found {len(situations)} unique situations")

        processed = []
        skipped = 0
        error_type_counts = {'too_risky': 0, 'too_risk_averse': 0, 'other': 0}

        for _, row in situations.iterrows():
            try:
                sit_id = row['situation_id']

                # Skip if cooperate labels are missing
                if pd.isna(row['cooperate_correct_labels']) or pd.isna(row['cooperate_incorrect_labels']):
                    skipped += 1
                    continue

                # Parse JSON arrays
                correct_labels = json.loads(row['cooperate_correct_labels'])
                incorrect_labels = json.loads(row['cooperate_incorrect_labels'])

                # Skip situations with empty labels
                if not correct_labels or not incorrect_labels:
                    skipped += 1
                    continue

                # Randomly select one label from each array
                correct_label = str(self.rng.choice(correct_labels))
                incorrect_label = str(self.rng.choice(incorrect_labels))

                # Determine error type based on the selected incorrect label's properties
                props = option_properties.get((sit_id, incorrect_label), {})
                is_linear_best = props.get('is_linear_best', False)
                option_type = props.get('option_type', '')
                # The selected label is a string, while JSON may decode numeric
                # labels as ints; compare on the string form so "2" matches 2.
                alpha_010_labels = [str(lbl) for lbl in props.get('alpha_010_labels', [])]

                # Categorize the error type
                if is_linear_best:
                    error_type = 'too_risky'
                elif incorrect_label in alpha_010_labels:
                    error_type = 'too_risk_averse'
                else:
                    error_type = 'other'

                error_type_counts[error_type] += 1

                processed.append({
                    'situation_id': sit_id,
                    'prompt_text': row['prompt_text'],
                    'correct_label': correct_label,
                    'incorrect_label': incorrect_label,
                    'error_type': error_type,
                    'incorrect_option_type': option_type,
                })

            except (json.JSONDecodeError, KeyError) as e:
                print(f"Warning: Error processing situation {row['situation_id']}: {e}")
                skipped += 1
                continue

        if skipped > 0:
            print(f"Warning: Skipped {skipped} situations due to missing/empty labels")

        result_df = pd.DataFrame(processed, columns=self._OUTPUT_COLUMNS)
        print(f"Processed into {len(result_df)} validation examples")

        # Display error type distribution
        print(f"\nError type distribution (if model incorrectly prefers non-cooperate):")
        for error_type, count in error_type_counts.items():
            pct = 100 * count / len(result_df) if len(result_df) > 0 else 0
            print(f"  {error_type}: {count} ({pct:.1f}%)")

        # Display option type distribution for incorrect options
        if 'incorrect_option_type' in result_df.columns and len(result_df) > 0:
            print(f"\nIncorrect option types:")
            for opt_type, count in result_df['incorrect_option_type'].value_counts().items():
                print(f"  {opt_type}: {count} ({100*count/len(result_df):.1f}%)")

        return result_df


# Cell-completion notice listing the three loader classes defined in this cell
print("Data loaders defined: TrainingDataLoader, InDistributionValidationDataLoader, ValidationDataLoader")

## 5. Load and Validate Data

Split the training file into training and in-distribution validation situations, then load the separate out-of-distribution validation file.

In [None]:
# =============================================================================
# SPLIT TRAINING DATA INTO TRAIN AND IN-DISTRIBUTION VALIDATION
# =============================================================================
print("Splitting training data into train and in-distribution validation sets...")
train_situation_ids, in_dist_val_situation_ids = get_train_val_situation_split(
    TRAIN_DATA_FILE,
    val_fraction=IN_DIST_VAL_SPLIT,
    random_seed=RANDOM_SEED
)

# Percentages shown in the messages below are derived from the configured split
# so they stay correct if IN_DIST_VAL_SPLIT is changed.
train_pct = (1 - IN_DIST_VAL_SPLIT) * 100
in_dist_val_pct = IN_DIST_VAL_SPLIT * 100

# =============================================================================
# LOAD TRAINING DATA (held-in share of low stakes)
# =============================================================================
print(f"\n{'='*60}")
print(f"Loading Training Data ({train_pct:.0f}% of low stakes)")
print(f"{'='*60}")
train_loader = TrainingDataLoader(
    TRAIN_DATA_FILE, 
    epoch=0, 
    random_seed=RANDOM_SEED,
    situation_ids=train_situation_ids
)
train_df = train_loader.load_and_process_data()

# =============================================================================
# LOAD IN-DISTRIBUTION VALIDATION DATA (held-out share of low stakes)
# =============================================================================
print(f"\n{'='*60}")
print(f"Loading In-Distribution Validation Data ({in_dist_val_pct:.0f}% of low stakes)")
print(f"{'='*60}")
in_dist_val_loader = InDistributionValidationDataLoader(
    TRAIN_DATA_FILE,
    random_seed=RANDOM_SEED,
    situation_ids=in_dist_val_situation_ids
)
in_dist_val_df = in_dist_val_loader.load_and_process_data()

# =============================================================================
# LOAD OUT-OF-DISTRIBUTION VALIDATION DATA (medium stakes, cooperate labels)
# =============================================================================
print(f"\n{'='*60}")
print("Loading Out-of-Distribution Validation Data (medium stakes)")
print(f"{'='*60}")
out_dist_val_loader = ValidationDataLoader(VAL_DATA_FILE, random_seed=RANDOM_SEED)
out_dist_val_df = out_dist_val_loader.load_and_process_data()

# =============================================================================
# DATASET SUMMARY
# =============================================================================
print(f"\n{'='*60}")
print("Dataset Summary")
print(f"{'='*60}")
print(f"  Training file: {TRAIN_DATA_FILE}")
print(f"    Training situations: {len(train_df)} ({train_pct:.0f}%)")
print(f"    In-dist validation situations: {len(in_dist_val_df)} ({in_dist_val_pct:.0f}%)")
print(f"  Out-of-dist validation file: {VAL_DATA_FILE}")
print(f"    Out-of-dist validation situations: {len(out_dist_val_df)}")
print(f"\n  Label types:")
print(f"    Training: CARA (risk-aversion)")
print(f"    In-dist validation: CARA (risk-aversion)")
print(f"    Out-of-dist validation: Cooperate (generalization test)")
print(f"{'='*60}")

# Validate data format: every frame must expose the three columns the
# pairwise dataset reads (prompt_text, correct_label, incorrect_label)
for df_name, df in [('train_df', train_df), ('in_dist_val_df', in_dist_val_df), ('out_dist_val_df', out_dist_val_df)]:
    assert 'prompt_text' in df.columns, f"Missing prompt_text column in {df_name}"
    assert 'correct_label' in df.columns, f"Missing correct_label column in {df_name}"
    assert 'incorrect_label' in df.columns, f"Missing incorrect_label column in {df_name}"

print("\nData validation passed!")
print("Data loading complete!")

## 6. Pairwise Dataset for Reward Modeling

Dataset class that provides pairs of (preferred, rejected) options for Bradley-Terry loss.

In [None]:
class PairwiseRewardDataset(Dataset):
    """Pairs of (preferred, rejected) texts for Bradley-Terry reward training.

    Every item contains the tokenized prompt followed by the correct label and
    the same prompt followed by the incorrect label. If the source frame has
    an ``error_type`` column, its value is attached to each item so that
    evaluation code can report results per category.
    """

    def __init__(self, dataframe: pd.DataFrame, tokenizer, max_length: int = 256):
        """
        Args:
            dataframe: Frame with ``prompt_text``, ``correct_label`` and
                ``incorrect_label`` columns; ``error_type`` is optional.
            tokenizer: Callable invoked with ``truncation``, ``padding``,
                ``max_length`` and ``return_tensors`` keyword arguments.
            max_length: Length each sequence is padded or truncated to.
        """
        # A fresh 0..n-1 index lets positional and label lookups agree.
        self.data = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

        # error_type metadata is only present on some frames.
        self.has_error_types = 'error_type' in self.data.columns
        self.error_types = (
            self.data['error_type'].tolist() if self.has_error_types else None
        )

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        """Tokenize ``text`` to a fixed-length, batched ('pt') encoding."""
        # NOTE(review): truncation uses the tokenizer's configured side. If
        # that is the right side, a prompt longer than max_length would lose
        # the trailing "Chosen option" suffix - confirm max_length is large
        # enough for the longest prompt.
        return self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

    def __getitem__(self, idx):
        record = self.data.iloc[idx]
        prompt = record['prompt_text']

        # Both texts share the prompt and differ only in the appended option.
        chosen = self._encode(f"{prompt}\n\nChosen option: {record['correct_label']}")
        declined = self._encode(f"{prompt}\n\nChosen option: {record['incorrect_label']}")

        # squeeze(0) drops the batch axis added by return_tensors='pt'.
        item = {
            'preferred_input_ids': chosen['input_ids'].squeeze(0),
            'preferred_attention_mask': chosen['attention_mask'].squeeze(0),
            'rejected_input_ids': declined['input_ids'].squeeze(0),
            'rejected_attention_mask': declined['attention_mask'].squeeze(0),
        }

        if self.has_error_types:
            item['error_type'] = self.error_types[idx]

        return item

    def get_error_type_indices(self) -> Dict[str, List[int]]:
        """Group example positions by error type.

        Returns:
            Mapping from each of 'too_risky', 'too_risk_averse' and 'other' to
            the positions carrying that error type. Values outside these three
            categories are left out. An empty dict is returned when the
            dataset has no error_type column.
        """
        if not self.has_error_types:
            return {}

        grouped = {'too_risky': [], 'too_risk_averse': [], 'other': []}
        for position, label in enumerate(self.error_types):
            bucket = grouped.get(label)
            if bucket is not None:
                bucket.append(position)
        return grouped

print("PairwiseRewardDataset defined (with error_type support)")

## 7. Reward Model Architecture

Qwen3-8B base model with LoRA adapters (q_proj, v_proj) + trainable scalar reward head.

In [None]:
class RewardModel(nn.Module):
    """Reward model with LoRA-adapted backbone and trainable scalar reward head.

    The backbone produces per-token hidden states; the hidden state of the
    last attended token of each sequence is passed through a linear layer to
    obtain one scalar reward per sequence.
    """

    def __init__(
        self,
        model_name: str = "Qwen/Qwen3-8B",
        lora_r: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.05,
        lora_target_modules: list = None,
    ):
        """
        Args:
            model_name: Identifier passed to ``AutoModel.from_pretrained``.
            lora_r: LoRA rank.
            lora_alpha: LoRA alpha.
            lora_dropout: Dropout applied inside the LoRA adapters.
            lora_target_modules: Module names to adapt; defaults to
                ``["q_proj", "v_proj"]`` when None.
        """
        super().__init__()

        # Load base model in fp16
        print(f"Loading base model: {model_name}")
        self.backbone = AutoModel.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )

        hidden_size = self.backbone.config.hidden_size

        # Add reward head BEFORE applying LoRA. It is created with PyTorch's
        # default device/dtype; moving it next to the backbone is left to the
        # code that instantiates this class.
        self.reward_head = nn.Linear(hidden_size, 1, bias=True)

        # Initialize reward head with small weights
        nn.init.normal_(self.reward_head.weight, mean=0.0, std=0.01)
        nn.init.zeros_(self.reward_head.bias)

        # Configure and apply LoRA
        if lora_target_modules is None:
            lora_target_modules = ["q_proj", "v_proj"]

        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            target_modules=lora_target_modules,
            bias="none",
            task_type=TaskType.FEATURE_EXTRACTION,
        )

        self.backbone = get_peft_model(self.backbone, lora_config)

        # Print trainable parameter info
        print(f"\nLoRA Configuration:")
        print(f"  Rank: {lora_r}")
        print(f"  Alpha: {lora_alpha}")
        print(f"  Dropout: {lora_dropout}")
        print(f"  Target modules: {lora_target_modules}")

        self.backbone.print_trainable_parameters()

        print(f"\nReward head: Linear({hidden_size} -> 1) with bias")

        # Count total trainable parameters (backbone adapters + reward head)
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in self.parameters())
        print(f"Total trainable: {trainable_params:,} / {total_params:,} ({100*trainable_params/total_params:.4f}%)")

    def forward(self, input_ids, attention_mask):
        """
        Forward pass to compute scalar reward from final hidden state

        Args:
            input_ids: Input token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]; non-zero
                entries mark real tokens.

        Returns:
            rewards: Scalar reward scores [batch_size]
        """
        # Get hidden states - gradients flow through LoRA layers
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden_states = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]

        # Index of the last attended token in each row. Multiplying the mask
        # by the position index and taking the argmax gives the right-most
        # attended position, which is correct for right-padded AND left-padded
        # batches (mask.sum() - 1 is only correct for right padding). A row
        # with no attended token resolves to position 0.
        positions = torch.arange(attention_mask.shape[1], device=attention_mask.device)
        last_token_idx = (attention_mask.long() * positions).argmax(dim=1)

        # Keep both index tensors on the hidden states' device so the gather
        # works even if the mask lives on a different device.
        last_token_idx = last_token_idx.to(hidden_states.device)
        batch_idx = torch.arange(hidden_states.shape[0], device=hidden_states.device)
        last_hidden_states = hidden_states[batch_idx, last_token_idx]

        # Convert to fp32 for numerical stability in reward head
        last_hidden_states = last_hidden_states.float()

        # Compute scalar reward: r = W^T * h_T + b
        rewards = self.reward_head(last_hidden_states).squeeze(-1)

        return rewards

print("RewardModel class defined (with LoRA support)")

## 8. Training Setup

Initialize model, tokenizer, datasets, loss function, and optimizer.

In [None]:
# Tokenizer: fall back to the EOS token for padding when none is configured,
# because the datasets below pad every sequence to a fixed length.
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# One pairwise dataset per split, all sharing the tokenizer and max length.
print("Creating datasets...")
train_dataset, in_dist_val_dataset, out_dist_val_dataset = (
    PairwiseRewardDataset(frame, tokenizer, max_length=MAX_LENGTH)
    for frame in (train_df, in_dist_val_df, out_dist_val_df)
)

print("  Training examples: {}".format(len(train_dataset)))
print("  In-distribution validation examples: {}".format(len(in_dist_val_dataset)))
print("  Out-of-distribution validation examples: {}".format(len(out_dist_val_dataset)))


# Helper function for per-epoch re-randomization of training data
def recreate_training_dataset(epoch: int):
    """Rebuild the training dataset from the training CSV for a given epoch.

    A new TrainingDataLoader is created with the epoch number, the global
    random seed and the training situation ids, and its processed frame is
    wrapped in a PairwiseRewardDataset. The loader is expected to use the
    epoch to re-draw the label choice for 'both' bucket situations (its
    implementation lives elsewhere in this file).

    Args:
        epoch: Current epoch number, forwarded to the loader.

    Returns:
        PairwiseRewardDataset built from the freshly processed training frame.
    """
    # situation_ids restricts the loader to the training situations only.
    epoch_df = TrainingDataLoader(
        TRAIN_DATA_FILE,
        epoch=epoch,
        random_seed=RANDOM_SEED,
        situation_ids=train_situation_ids,
    ).load_and_process_data()
    return PairwiseRewardDataset(epoch_df, tokenizer, max_length=MAX_LENGTH)


# Build the LoRA-wrapped reward model from the notebook configuration.
print("\nInitializing reward model with LoRA...")
model = RewardModel(
    model_name=MODEL_NAME,
    lora_r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    lora_target_modules=LORA_TARGET_MODULES,
)

# Inputs are sent to the device that holds the backbone's first parameter.
device = next(model.backbone.parameters()).device

# The reward head lives on that same device and stays in fp32.
model.reward_head = model.reward_head.to(device=device, dtype=torch.float32)

print("\n  Device: {}".format(device))
print("  Backbone dtype: {}".format(next(model.backbone.parameters()).dtype))
print("  Reward head dtype: {}".format(next(model.reward_head.parameters()).dtype))

# Bradley-Terry loss function
def bradley_terry_loss(preferred_rewards, rejected_rewards):
    """
    Bradley-Terry pairwise ranking loss
    Loss = -log(sigmoid(r_preferred - r_rejected))

    Encourages: r_preferred > r_rejected

    Args:
        preferred_rewards: Tensor of rewards for the preferred options.
        rejected_rewards: Tensor of rewards for the rejected options, same
            shape as ``preferred_rewards``.

    Returns:
        Scalar tensor: the mean loss over all pairs.
    """
    # logsigmoid evaluates log(sigmoid(x)) in a numerically stable way.
    # The naive log(sigmoid(x)) underflows to log(0) = -inf once the margin
    # is strongly negative, which turns the loss into inf.
    return -F.logsigmoid(preferred_rewards - rejected_rewards).mean()

# AdamW with two parameter groups: the backbone at the base learning rate and
# the reward head at 2.5x that rate. Weight decay is shared by both groups.
# NOTE(review): the backbone group is built from *all* backbone parameters,
# including any with requires_grad=False; filtering it would change the
# layout of the optimizer state saved in checkpoints, so it is left as is.
optimizer = torch.optim.AdamW(
    [
        dict(params=model.backbone.parameters(), lr=LEARNING_RATE),
        dict(params=model.reward_head.parameters(), lr=LEARNING_RATE * 2.5),  # Higher LR for reward head
    ],
    weight_decay=WEIGHT_DECAY,
)

print("\n" + "=" * 60)
print("Training Configuration:")
print("  Model: {}".format(MODEL_NAME))
print("  Batch size: {}".format(BATCH_SIZE))
print("  LoRA parameters LR: {}".format(LEARNING_RATE))
print("  Reward head LR: {}".format(LEARNING_RATE * 2.5))
print("  Weight decay (L2): {}".format(WEIGHT_DECAY))
print("  Epochs: {}".format(NUM_EPOCHS))
print("  Max sequence length: {}".format(MAX_LENGTH))
print("  Per-epoch re-randomization: Enabled for 'both' bucket cases")
print("\nDataset sizes:")
print("  Training: {}".format(len(train_dataset)))
print("  In-dist validation: {}".format(len(in_dist_val_dataset)))
print("  Out-of-dist validation: {}".format(len(out_dist_val_dataset)))
print("=" * 60)
print("\nTraining setup complete!")

## 9. Evaluation Function

Compute pairwise accuracy: percentage of pairs where preferred option scores higher.

In [None]:
def evaluate_model(model, dataset, batch_size=None):
    """
    Evaluate model on pairwise accuracy with optional error type breakdown.

    The model is switched to eval mode and left in that mode. Inputs are moved
    to the module-level ``device``; the loss is computed with
    ``bradley_terry_loss``.

    Args:
        model: RewardModel to evaluate
        dataset: PairwiseRewardDataset (may include error_type metadata)
        batch_size: Batch size for evaluation (defaults to BATCH_SIZE * 2)

    Returns:
        dict with keys:
            - accuracy: Float in [0, 1], fraction of pairs where the preferred
              option scores strictly higher (0.0 for an empty dataset)
            - avg_loss: Float, example-weighted average Bradley-Terry loss
            - preferred_scores: List of reward scores for preferred options
            - rejected_scores: List of reward scores for rejected options
            - error_type_breakdown: Dict mapping 'too_risky',
              'too_risk_averse' and 'other' to accuracy/correct/total; empty
              when the batches carry no error_type
            - per_example_results: List of per-pair dicts (scores, margin,
              is_correct, error_type); empty when no error_type is available
            - error_types: List of error_type values in dataset order
    """
    if batch_size is None:
        batch_size = BATCH_SIZE * 2  # Can use larger batch for eval (no gradients)

    model.eval()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    categories = ('too_risky', 'too_risk_averse', 'other')
    correct = 0
    total = 0
    total_loss = 0.0
    preferred_scores_list = []
    rejected_scores_list = []
    error_types_list = []
    correct_by_type = {category: 0 for category in categories}
    total_by_type = {category: 0 for category in categories}

    # Track per-example results for breakdown
    per_example_results = []

    with torch.no_grad():
        for batch in dataloader:
            # Get rewards for preferred options
            preferred_rewards = model(
                input_ids=batch['preferred_input_ids'].to(device),
                attention_mask=batch['preferred_attention_mask'].to(device)
            )

            # Get rewards for rejected options
            rejected_rewards = model(
                input_ids=batch['rejected_input_ids'].to(device),
                attention_mask=batch['rejected_attention_mask'].to(device)
            )

            batch_count = len(preferred_rewards)

            # The loss is a batch mean; weight it by the batch size so the
            # final average is per example even if the last batch is smaller.
            loss = bradley_terry_loss(preferred_rewards, rejected_rewards)
            total_loss += loss.item() * batch_count

            # Compute accuracy: count pairs where preferred > rejected.
            # Copy the flags to the host once instead of calling .item() for
            # every example.
            correct_flags = (preferred_rewards > rejected_rewards).tolist()
            correct += sum(correct_flags)
            total += batch_count

            # Store scores for analysis
            preferred_scores_list.extend(preferred_rewards.cpu().float().numpy())
            rejected_scores_list.extend(rejected_rewards.cpu().float().numpy())

            # Track by error type if available
            if 'error_type' in batch:
                batch_error_types = batch['error_type']
                error_types_list.extend(batch_error_types)

                # Only the three known categories are tallied.
                for error_type, flag in zip(batch_error_types, correct_flags):
                    if error_type in total_by_type:
                        total_by_type[error_type] += 1
                        if flag:
                            correct_by_type[error_type] += 1

                # Store per-example results (one host transfer per tensor)
                preferred_values = preferred_rewards.tolist()
                rejected_values = rejected_rewards.tolist()
                margin_values = (preferred_rewards - rejected_rewards).tolist()
                for i in range(batch_count):
                    per_example_results.append({
                        'preferred_score': preferred_values[i],
                        'rejected_score': rejected_values[i],
                        'margin': margin_values[i],
                        'is_correct': correct_flags[i],
                        'error_type': batch_error_types[i],
                    })

    accuracy = correct / total if total > 0 else 0.0
    avg_loss = total_loss / total if total > 0 else 0.0

    # Compute accuracy by error type
    error_type_breakdown = {}
    has_error_types = len(error_types_list) > 0

    if has_error_types:
        for error_type in categories:
            seen = total_by_type[error_type]
            hits = correct_by_type[error_type]
            # A category with no examples is reported with zeros rather than
            # being left out.
            error_type_breakdown[error_type] = {
                'accuracy': hits / seen if seen > 0 else 0.0,
                'correct': hits if seen > 0 else 0,
                'total': seen if seen > 0 else 0,
            }

    return {
        'accuracy': accuracy,
        'avg_loss': avg_loss,
        'preferred_scores': preferred_scores_list,
        'rejected_scores': rejected_scores_list,
        'error_type_breakdown': error_type_breakdown,
        'per_example_results': per_example_results if has_error_types else [],
        'error_types': error_types_list,
    }


def print_error_type_breakdown(breakdown: Dict, title: str = "Error Type Breakdown"):
    """Print a fixed-width accuracy table, one row per known error category.

    Args:
        breakdown: Mapping from category name to a dict with 'accuracy'
            (fraction), 'correct' and 'total'. Nothing is printed when it is
            empty; categories missing from it are skipped.
        title: Heading printed above the table.
    """
    if not breakdown:
        return

    print(f"\n  {title}:")
    print("  {:<20} {:>10} {:>10} {:>10}".format('Category', 'Accuracy', 'Correct', 'Total'))
    print("  " + "-" * 52)

    # Rows always appear in this order, regardless of the dict's own order.
    for category in ('too_risky', 'too_risk_averse', 'other'):
        if category not in breakdown:
            continue
        entry = breakdown[category]
        percent = entry['accuracy'] * 100
        print("  {:<20} {:>9.1f}% {:>10} {:>10}".format(
            category, percent, entry['correct'], entry['total']
        ))


print("Evaluation function defined (with error type breakdown)")

## 10. Baseline Evaluation (Before Training)

Evaluate the untrained reward model (pretrained backbone with a randomly initialized reward head) to establish a baseline for comparison.

In [None]:
# =============================================================================
# BASELINE EVALUATION - Before any training
# =============================================================================
# This establishes how well the randomly initialized reward head performs
# Expected: ~50% accuracy (random guessing)
# The same steps are run for both validation sets: evaluate, unpack the result
# dict into notebook-level variables, print a short report, and keep a summary
# dict for comparison with the trained model.

print("="*60)
print("BASELINE EVALUATION (Before Training)")
print("="*60)

# =============================================================================
# IN-DISTRIBUTION BASELINE (CARA labels, same as training)
# =============================================================================
print("\n" + "-"*60)
print("In-Distribution Validation (CARA labels)")
print("-"*60)

baseline_in_dist_eval = evaluate_model(model, in_dist_val_dataset)

baseline_in_dist_accuracy = baseline_in_dist_eval['accuracy']
baseline_in_dist_loss = baseline_in_dist_eval['avg_loss']
baseline_in_dist_pref_scores = baseline_in_dist_eval['preferred_scores']
baseline_in_dist_rej_scores = baseline_in_dist_eval['rejected_scores']
baseline_in_dist_error_breakdown = baseline_in_dist_eval['error_type_breakdown']

# Per-pair margin: preferred score minus rejected score (positive = ranked correctly)
baseline_in_dist_margins = np.array(baseline_in_dist_pref_scores) - np.array(baseline_in_dist_rej_scores)

print(f"\nIn-Dist Baseline Results (Untrained Model):")
print(f"  Accuracy: {baseline_in_dist_accuracy:.4f} ({baseline_in_dist_accuracy*100:.2f}%)")
print(f"  Loss: {baseline_in_dist_loss:.4f}")
print(f"  Mean margin: {np.mean(baseline_in_dist_margins):.4f}")

# The breakdown is an empty dict when the dataset carries no error_type column
if baseline_in_dist_error_breakdown:
    print_error_type_breakdown(baseline_in_dist_error_breakdown, "In-Dist Baseline Error Type Breakdown")

# Store in-dist baseline for later comparison
# NOTE(review): 'preferred_scores' / 'rejected_scores' are stored exactly as
# returned by evaluate_model; only 'margins' is converted with .tolist().
# Confirm the element types are acceptable if this dict is serialized later.
baseline_in_dist_results = {
    'accuracy': baseline_in_dist_accuracy,
    'loss': baseline_in_dist_loss,
    'preferred_scores': baseline_in_dist_pref_scores,
    'rejected_scores': baseline_in_dist_rej_scores,
    'margins': baseline_in_dist_margins.tolist(),
    'mean_margin': float(np.mean(baseline_in_dist_margins)),
    'std_margin': float(np.std(baseline_in_dist_margins)),
    'error_type_breakdown': baseline_in_dist_error_breakdown,
}

# =============================================================================
# OUT-OF-DISTRIBUTION BASELINE (Cooperate labels, generalization test)
# =============================================================================
print("\n" + "-"*60)
print("Out-of-Distribution Validation (Cooperate labels)")
print("-"*60)

baseline_out_dist_eval = evaluate_model(model, out_dist_val_dataset)

baseline_out_dist_accuracy = baseline_out_dist_eval['accuracy']
baseline_out_dist_loss = baseline_out_dist_eval['avg_loss']
baseline_out_dist_pref_scores = baseline_out_dist_eval['preferred_scores']
baseline_out_dist_rej_scores = baseline_out_dist_eval['rejected_scores']
baseline_out_dist_error_breakdown = baseline_out_dist_eval['error_type_breakdown']

# Per-pair margin, computed the same way as for the in-distribution set
baseline_out_dist_margins = np.array(baseline_out_dist_pref_scores) - np.array(baseline_out_dist_rej_scores)

print(f"\nOut-Dist Baseline Results (Untrained Model):")
print(f"  Accuracy: {baseline_out_dist_accuracy:.4f} ({baseline_out_dist_accuracy*100:.2f}%)")
print(f"  Loss: {baseline_out_dist_loss:.4f}")
print(f"  Mean margin: {np.mean(baseline_out_dist_margins):.4f}")

if baseline_out_dist_error_breakdown:
    print_error_type_breakdown(baseline_out_dist_error_breakdown, "Out-Dist Baseline Error Type Breakdown")

# Store out-dist baseline for later comparison
baseline_out_dist_results = {
    'accuracy': baseline_out_dist_accuracy,
    'loss': baseline_out_dist_loss,
    'preferred_scores': baseline_out_dist_pref_scores,
    'rejected_scores': baseline_out_dist_rej_scores,
    'margins': baseline_out_dist_margins.tolist(),
    'mean_margin': float(np.mean(baseline_out_dist_margins)),
    'std_margin': float(np.std(baseline_out_dist_margins)),
    'error_type_breakdown': baseline_out_dist_error_breakdown,
}

# Legacy alias for backward compatibility
# (same dict object as baseline_out_dist_results, not a copy)
baseline_results = baseline_out_dist_results

# =============================================================================
# BASELINE SUMMARY
# =============================================================================
print(f"\n{'='*60}")
print("BASELINE SUMMARY")
print(f"{'='*60}")
print(f"  In-Distribution (CARA):        {baseline_in_dist_accuracy*100:.1f}%")
print(f"  Out-of-Distribution (Cooperate): {baseline_out_dist_accuracy*100:.1f}%")
print(f"\nExpected baseline: ~50% (random initialization)")
# "Near random" means within 10 percentage points of 50% on both sets
if abs(baseline_in_dist_accuracy - 0.5) < 0.1 and abs(baseline_out_dist_accuracy - 0.5) < 0.1:
    print("Both baselines are near random as expected.")
else:
    print("NOTE: Baseline deviates from 50% - may indicate bias in initialization or data.")
print("="*60)

## 11. Training Loop

Train the model with logging, validation, and checkpointing.

In [None]:
# Checkpoints are written beneath this directory.
os.makedirs("outputs", exist_ok=True)

# Everything recorded during training. In-distribution entries use CARA
# labels, out-of-distribution entries use Cooperate labels.
history = {
    'train_loss': [],
    'train_steps': [],
    'epochs': [],
    'reward_margins': [],
    'preferred_rewards': [],
    'rejected_rewards': [],
    'in_dist_val_accuracy': [],
    'in_dist_val_loss': [],
    'in_dist_error_type_accuracy': {
        name: [] for name in ('too_risky', 'too_risk_averse', 'other')
    },
    'out_dist_val_accuracy': [],
    'out_dist_val_loss': [],
    'out_dist_error_type_accuracy': {
        name: [] for name in ('too_risky', 'too_risk_averse', 'other')
    },
}

# Legacy keys point at the very same out-of-distribution containers, so
# appending through either name updates both.
history.update(
    val_accuracy=history['out_dist_val_accuracy'],
    val_loss=history['out_dist_val_loss'],
    error_type_accuracy=history['out_dist_error_type_accuracy'],
)

print("Starting training...")
print("Training data will be re-randomized each epoch for 'both' bucket cases")
print("Validating on both in-distribution (CARA) and out-of-distribution (Cooperate) sets\n")

global_step = 0
best_val_accuracy = 0.0

# Detached snapshots of the reward head, used to report parameter drift.
initial_weight = model.reward_head.weight.detach().clone()
initial_bias = model.reward_head.bias.detach().clone()

# Main training loop: one pass over the training pairs per epoch, followed by
# evaluation on both validation sets and a checkpoint whenever the
# out-of-distribution accuracy improves on the best value seen so far.
for epoch in range(NUM_EPOCHS):
    print(f"{'='*60}")
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")
    print(f"{'='*60}")

    # Recreate training dataset with epoch-specific randomization
    # This ensures 'both' bucket cases get new random label choices each epoch
    # (the first epoch uses the dataset that was built during setup)
    if epoch > 0:
        print(f"  Re-randomizing training data for epoch {epoch + 1}...")
        train_dataset = recreate_training_dataset(epoch)

    # Create dataloader for this epoch
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        pin_memory=True if torch.cuda.is_available() else False,
    )

    print(f"  Batches this epoch: {len(train_dataloader)}")

    # evaluate_model() leaves the model in eval mode, so switch back each epoch
    model.train()
    epoch_loss = 0.0
    # These three lists hold one entry per training example (not per batch)
    epoch_preferred_rewards = []
    epoch_rejected_rewards = []
    epoch_margins = []

    for step, batch in enumerate(train_dataloader):
        optimizer.zero_grad()

        # Forward pass for preferred options
        preferred_rewards = model(
            input_ids=batch['preferred_input_ids'].to(device),
            attention_mask=batch['preferred_attention_mask'].to(device)
        )

        # Forward pass for rejected options
        rejected_rewards = model(
            input_ids=batch['rejected_input_ids'].to(device),
            attention_mask=batch['rejected_attention_mask'].to(device)
        )

        # Track reward statistics
        epoch_preferred_rewards.extend(preferred_rewards.detach().cpu().tolist())
        epoch_rejected_rewards.extend(rejected_rewards.detach().cpu().tolist())
        reward_margin = (preferred_rewards - rejected_rewards).detach()
        epoch_margins.extend(reward_margin.cpu().tolist())

        # Compute Bradley-Terry loss
        loss = bradley_terry_loss(preferred_rewards, rejected_rewards)

        # Check for NaN loss
        # NOTE(review): a skipped batch has already been added to the reward
        # statistics above and is still counted in the loss denominators
        # below (step + 1 and len(train_dataloader)); global_step is not
        # advanced for it.
        if torch.isnan(loss):
            print(f"  WARNING: NaN loss detected at step {step + 1}")
            print(f"    Preferred rewards: {preferred_rewards}")
            print(f"    Rejected rewards: {rejected_rewards}")
            continue

        # Backward pass
        loss.backward()
        epoch_loss += loss.item()

        # Detailed diagnostics on first few steps
        # (runs after backward() so gradients exist, but before optimizer.step())
        if global_step < 3:
            print(f"\n  === Diagnostics for Step {global_step + 1} ===")

            # Reward statistics
            print(f"  Rewards:")
            print(f"    Preferred: mean={preferred_rewards.mean().item():.4f}, std={preferred_rewards.std().item():.4f}")
            print(f"    Rejected:  mean={rejected_rewards.mean().item():.4f}, std={rejected_rewards.std().item():.4f}")
            print(f"    Margin:    mean={reward_margin.mean().item():.4f}, std={reward_margin.std().item():.4f}")
            print(f"    Loss:      {loss.item():.4f}")

            # Gradient statistics - reward head
            print(f"  Gradients (Reward Head):")
            for name, param in model.reward_head.named_parameters():
                if param.grad is not None:
                    grad_norm = param.grad.norm().item()
                    grad_mean = param.grad.mean().item()
                    grad_max = param.grad.abs().max().item()
                    print(f"    {name}: norm={grad_norm:.6f}, mean={grad_mean:.6f}, max={grad_max:.6f}")
                else:
                    print(f"    {name}: NO GRADIENT!")

            # Gradient statistics - LoRA layers
            # (every backbone parameter that requires grad and received one)
            lora_grad_norms = []
            for name, param in model.backbone.named_parameters():
                if param.requires_grad and param.grad is not None:
                    lora_grad_norms.append(param.grad.norm().item())
            if lora_grad_norms:
                print(f"  Gradients (LoRA layers):")
                print(f"    mean_norm={np.mean(lora_grad_norms):.6f}, max_norm={max(lora_grad_norms):.6f}")

            # Parameter statistics
            # NOTE(review): measured before this step's optimizer.step(), so
            # the change shown reflects updates from earlier steps only.
            print(f"  Parameters:")
            weight_change = (model.reward_head.weight - initial_weight).abs().max().item()
            bias_change = (model.reward_head.bias - initial_bias).abs().max().item()
            print(f"    Reward head weight max change: {weight_change:.6f}")
            print(f"    Reward head bias max change: {bias_change:.6f}")
            print()

        # Optimizer step
        optimizer.step()
        global_step += 1

        # Log every 50 steps (adjusted for smaller dataset)
        if (step + 1) % 50 == 0:
            # Running mean of the batch losses accumulated so far this epoch
            avg_loss = epoch_loss / (step + 1)
            weight_change = (model.reward_head.weight - initial_weight).abs().max().item()
            # NOTE(review): epoch_margins has one entry per example, so this
            # is the mean over the last 50 examples, not the last 50 steps.
            recent_margin = np.mean(epoch_margins[-50:]) if len(epoch_margins) >= 50 else np.mean(epoch_margins)

            print(f"  Step {step + 1}/{len(train_dataloader)} | "
                  f"Loss: {avg_loss:.4f} | "
                  f"Margin: {recent_margin:+.4f} | "
                  f"Weight: {weight_change:.6f}")

            history['train_loss'].append(avg_loss)
            history['train_steps'].append(global_step)

    # End of epoch statistics
    # NOTE(review): divides by the number of batches; an empty dataloader
    # would raise ZeroDivisionError here.
    avg_train_loss = epoch_loss / len(train_dataloader)

    # Weight changes (mean absolute change relative to the initial snapshot)
    total_weight_change = (model.reward_head.weight - initial_weight).abs().mean().item()
    total_bias_change = (model.reward_head.bias - initial_bias).abs().mean().item()

    # Reward statistics
    mean_preferred = np.mean(epoch_preferred_rewards)
    mean_rejected = np.mean(epoch_rejected_rewards)
    mean_margin = np.mean(epoch_margins)
    std_margin = np.std(epoch_margins)
    # Share of training pairs whose margin was strictly positive
    percent_correct = np.mean([m > 0 for m in epoch_margins]) * 100

    print(f"\n  Epoch {epoch + 1} Summary:")
    print(f"    Loss: {avg_train_loss:.4f}")
    print(f"    Rewards: preferred={mean_preferred:+.4f}, rejected={mean_rejected:+.4f}")
    print(f"    Margin: mean={mean_margin:+.4f}, std={std_margin:.4f}")
    print(f"    Correct ranking: {percent_correct:.1f}% (preferred > rejected)")
    print(f"    Weight changes: weight={total_weight_change:.6f}, bias={total_bias_change:.6f}")

    # Store for history
    history['reward_margins'].append(mean_margin)
    history['preferred_rewards'].append(mean_preferred)
    history['rejected_rewards'].append(mean_rejected)
    history['epochs'].append(epoch + 1)

    # ==========================================================================
    # IN-DISTRIBUTION VALIDATION (CARA labels)
    # ==========================================================================
    print(f"\n  Running in-distribution validation (CARA labels)...")
    in_dist_eval = evaluate_model(model, in_dist_val_dataset)

    in_dist_accuracy = in_dist_eval['accuracy']
    in_dist_loss = in_dist_eval['avg_loss']
    in_dist_error_breakdown = in_dist_eval['error_type_breakdown']

    print(f"    In-dist accuracy: {in_dist_accuracy:.4f} ({in_dist_accuracy*100:.2f}%)")
    print(f"    In-dist loss: {in_dist_loss:.4f}")

    if in_dist_error_breakdown:
        print_error_type_breakdown(in_dist_error_breakdown, "In-Dist Error Type Breakdown")

    # Save to history
    history['in_dist_val_accuracy'].append(in_dist_accuracy)
    history['in_dist_val_loss'].append(in_dist_loss)

    # Per-category history only grows when the breakdown contains that category
    for error_type in ['too_risky', 'too_risk_averse', 'other']:
        if error_type in in_dist_error_breakdown:
            history['in_dist_error_type_accuracy'][error_type].append(
                in_dist_error_breakdown[error_type]['accuracy']
            )

    # ==========================================================================
    # OUT-OF-DISTRIBUTION VALIDATION (Cooperate labels)
    # ==========================================================================
    print(f"\n  Running out-of-distribution validation (Cooperate labels)...")
    out_dist_eval = evaluate_model(model, out_dist_val_dataset)

    out_dist_accuracy = out_dist_eval['accuracy']
    out_dist_loss = out_dist_eval['avg_loss']
    out_dist_pref_scores = out_dist_eval['preferred_scores']
    out_dist_rej_scores = out_dist_eval['rejected_scores']
    out_dist_error_breakdown = out_dist_eval['error_type_breakdown']

    # Difference of the mean preferred and mean rejected scores
    out_dist_margin = np.mean(out_dist_pref_scores) - np.mean(out_dist_rej_scores)

    print(f"    Out-dist accuracy: {out_dist_accuracy:.4f} ({out_dist_accuracy*100:.2f}%)")
    print(f"    Out-dist loss: {out_dist_loss:.4f}")
    print(f"    Out-dist margin: {out_dist_margin:+.4f}")

    if out_dist_error_breakdown:
        print_error_type_breakdown(out_dist_error_breakdown, "Out-Dist Error Type Breakdown")

    # Save to history
    # (the legacy 'val_*' keys alias these lists, so they are updated too)
    history['out_dist_val_accuracy'].append(out_dist_accuracy)
    history['out_dist_val_loss'].append(out_dist_loss)

    for error_type in ['too_risky', 'too_risk_averse', 'other']:
        if error_type in out_dist_error_breakdown:
            history['out_dist_error_type_accuracy'][error_type].append(
                out_dist_error_breakdown[error_type]['accuracy']
            )

    # ==========================================================================
    # CHECKPOINT SAVING (based on out-of-distribution accuracy for generalization)
    # ==========================================================================
    # A strict '>' against an initial best of 0.0: ties do not overwrite, and
    # each new best is written to its own per-epoch directory.
    if out_dist_accuracy > best_val_accuracy:
        best_val_accuracy = out_dist_accuracy
        checkpoint_dir = f"outputs/best_model_epoch{epoch+1}"
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Save LoRA adapter weights
        model.backbone.save_pretrained(checkpoint_dir)

        # Save reward head separately
        # (together with metrics, optimizer state and the LoRA settings)
        torch.save({
            'reward_head_state_dict': model.reward_head.state_dict(),
            'epoch': epoch + 1,
            'in_dist_val_accuracy': in_dist_accuracy,
            'out_dist_val_accuracy': out_dist_accuracy,
            'in_dist_val_loss': in_dist_loss,
            'out_dist_val_loss': out_dist_loss,
            'in_dist_error_breakdown': in_dist_error_breakdown,
            'out_dist_error_breakdown': out_dist_error_breakdown,
            'optimizer_state_dict': optimizer.state_dict(),
            'lora_config': {
                'r': LORA_R,
                'alpha': LORA_ALPHA,
                'dropout': LORA_DROPOUT,
                'target_modules': LORA_TARGET_MODULES,
            }
        }, os.path.join(checkpoint_dir, "reward_head.pt"))

        print(f"\n    New best! Saved checkpoint: {checkpoint_dir}")

    # Clear GPU cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print()

# Post-training banner: reports the best out-of-distribution validation
# accuracy tracked by the checkpointing logic and the value of `global_step`.
_rule = '=' * 60
print("\n".join([
    "\n" + _rule,
    "Training Complete!",
    _rule,
    f"Best out-of-dist validation accuracy: {best_val_accuracy:.4f} ({best_val_accuracy*100:.2f}%)",
    f"Total steps: {global_step}",
    "Final model saved to: outputs/best_model_epoch*/",
]))

## 12. Visualization and Results

Plot training curves, compare against baseline, and analyze model performance.

In [None]:
# =============================================================================
# FINAL EVALUATION ON BOTH VALIDATION SETS
# =============================================================================
# Scores `model` on each validation set with evaluate_model(), unpacks the
# returned dictionaries into the `final_*` variables used by the plotting and
# saving code further down, and prints a summary for both sets.
print("Evaluating final model on both validation sets...")

# --- In-distribution set (CARA labels) ---
print("\nIn-distribution validation (CARA labels)...")
final_in_dist_eval = evaluate_model(model, in_dist_val_dataset)
final_in_dist_accuracy, final_in_dist_loss = (
    final_in_dist_eval['accuracy'],
    final_in_dist_eval['avg_loss'],
)
final_in_dist_pref_scores, final_in_dist_rej_scores = (
    final_in_dist_eval['preferred_scores'],
    final_in_dist_eval['rejected_scores'],
)
final_in_dist_error_breakdown = final_in_dist_eval['error_type_breakdown']
# Element-wise margin: preferred score minus rejected score.
final_in_dist_margins = np.asarray(final_in_dist_pref_scores) - np.asarray(final_in_dist_rej_scores)

# --- Out-of-distribution set (Cooperate labels) ---
print("Out-of-distribution validation (Cooperate labels)...")
final_out_dist_eval = evaluate_model(model, out_dist_val_dataset)
final_out_dist_accuracy, final_out_dist_loss = (
    final_out_dist_eval['accuracy'],
    final_out_dist_eval['avg_loss'],
)
final_out_dist_pref_scores, final_out_dist_rej_scores = (
    final_out_dist_eval['preferred_scores'],
    final_out_dist_eval['rejected_scores'],
)
final_out_dist_error_breakdown = final_out_dist_eval['error_type_breakdown']
final_out_dist_per_example = final_out_dist_eval['per_example_results']
final_out_dist_margins = np.asarray(final_out_dist_pref_scores) - np.asarray(final_out_dist_rej_scores)

# Legacy aliases for backward compatibility: the un-prefixed names all refer
# to the out-of-distribution results.
final_accuracy, final_loss = final_out_dist_accuracy, final_out_dist_loss
preferred_scores, rejected_scores = final_out_dist_pref_scores, final_out_dist_rej_scores
final_error_breakdown = final_out_dist_error_breakdown
per_example_results = final_out_dist_per_example
reward_margins = final_out_dist_margins

print(f"\n{'='*60}")
print("FINAL VALIDATION RESULTS")
print(f"{'='*60}")

# Same three-line summary for each validation set.
for split_title, split_acc, split_loss, split_margins in (
    ("In-Distribution (CARA labels)",
     final_in_dist_accuracy, final_in_dist_loss, final_in_dist_margins),
    ("Out-of-Distribution (Cooperate labels)",
     final_out_dist_accuracy, final_out_dist_loss, final_out_dist_margins),
):
    print(f"\n{split_title}:")
    print(f"  Accuracy: {split_acc:.4f} ({split_acc*100:.2f}%)")
    print(f"  Loss: {split_loss:.4f}")
    print(f"  Mean margin: {np.mean(split_margins):.4f}")

# Print error type breakdowns (skipped for a set whose breakdown is empty)
for split_breakdown, breakdown_title in (
    (final_in_dist_error_breakdown, "In-Dist Final Error Type Breakdown"),
    (final_out_dist_error_breakdown, "Out-Dist Final Error Type Breakdown"),
):
    if split_breakdown:
        print_error_type_breakdown(split_breakdown, breakdown_title)

# =============================================================================
# BASELINE vs TRAINED COMPARISON
# =============================================================================
# Prints a table comparing baseline and trained accuracy/loss on both
# validation sets, then a per-error-type accuracy table for the
# out-of-distribution set when both breakdowns are available.
print(f"\n{'='*60}")
print("BASELINE vs TRAINED MODEL COMPARISON")
print(f"{'='*60}")

print(f"\n                         Baseline    Trained     Improvement")

# Accuracy rows are shown as percentages.
for row_prefix, baseline_results, trained_value in (
    ("  In-Dist Accuracy:      ", baseline_in_dist_results, final_in_dist_accuracy),
    ("  Out-Dist Accuracy:     ", baseline_out_dist_results, final_out_dist_accuracy),
):
    baseline_value = baseline_results['accuracy']
    print(f"{row_prefix}{baseline_value*100:6.2f}%     {trained_value*100:6.2f}%     {(trained_value - baseline_value)*100:+6.2f}%")

# Loss rows are shown as raw values.
for row_prefix, baseline_results, trained_value in (
    ("  In-Dist Loss:          ", baseline_in_dist_results, final_in_dist_loss),
    ("  Out-Dist Loss:         ", baseline_out_dist_results, final_out_dist_loss),
):
    baseline_value = baseline_results['loss']
    print(f"{row_prefix}{baseline_value:6.4f}      {trained_value:6.4f}      {trained_value - baseline_value:+6.4f}")

# Compare error type breakdown for out-of-dist
if final_out_dist_error_breakdown and baseline_out_dist_results.get('error_type_breakdown'):
    baseline_breakdown = baseline_out_dist_results['error_type_breakdown']
    print(f"\n  Out-Dist Error Type Accuracy Comparison:")
    print(f"  {'Category':<20} {'Baseline':>10} {'Trained':>10} {'Improvement':>12}")
    print(f"  {'-'*54}")
    for error_type in ('too_risky', 'too_risk_averse', 'other'):
        # A category missing from either breakdown is reported as 0% accuracy.
        baseline_pct = baseline_breakdown.get(error_type, {}).get('accuracy', 0) * 100
        trained_pct = final_out_dist_error_breakdown.get(error_type, {}).get('accuracy', 0) * 100
        delta_pct = trained_pct - baseline_pct
        print(f"  {error_type:<20} {baseline_pct:>9.1f}% {trained_pct:>9.1f}% {delta_pct:>+11.1f}%")

print(f"{'='*60}")

# =============================================================================
# CREATE COMPREHENSIVE VISUALIZATIONS (5x3 grid with in-dist overlays)
# =============================================================================
# A single figure with a 5x3 GridSpec; the sections below each fill one row.
fig = plt.figure(figsize=(20, 22))
gs = fig.add_gridspec(5, 3, hspace=0.35, wspace=0.3)

fig.suptitle(
    'Reward Model Training Results\nTrained on CARA (Risk-Aversion) | Validated on Both In-Dist and Out-of-Dist',
    fontsize=16,
    fontweight='bold',
    y=0.995,
)

# =============================================================================
# Row 1: Training Progress
# =============================================================================

# Plot 1: Training Loss Over Time
ax1 = fig.add_subplot(gs[0, 0])
if len(history['train_steps']) > 0:
    ax1.plot(history['train_steps'], history['train_loss'], 'b-', linewidth=2, label='Training Loss')
    # Horizontal reference line at the baseline out-of-dist loss.
    out_baseline_loss = baseline_out_dist_results['loss']
    ax1.axhline(
        y=out_baseline_loss, color='gray', linestyle='--', alpha=0.7,
        label=f'Out-Dist Baseline ({out_baseline_loss:.3f})',
    )
    ax1.set(xlabel='Training Steps', ylabel='Loss', title='Training Loss Over Time')
    ax1.legend(fontsize=8)
    ax1.grid(True, alpha=0.3)

# Plot 2: Validation Accuracy Over Epochs (BOTH in-dist and out-of-dist)
ax2 = fig.add_subplot(gs[0, 1])
if len(history['epochs']) > 0:
    # X axis: 0 for the baseline measurement, followed by the recorded epochs.
    epoch_axis = [0] + history['epochs']

    # Out-of-distribution (solid line)
    out_acc_start = baseline_out_dist_results['accuracy']
    ax2.plot(
        epoch_axis, [out_acc_start] + history['out_dist_val_accuracy'],
        'g-o', linewidth=2, markersize=8, label='Out-Dist (Cooperate)',
    )

    # In-distribution (dashed line)
    in_acc_start = baseline_in_dist_results['accuracy']
    ax2.plot(
        epoch_axis, [in_acc_start] + history['in_dist_val_accuracy'],
        'b--s', linewidth=2, markersize=7, label='In-Dist (CARA)',
    )

    ax2.axhline(y=0.5, color='r', linestyle='--', alpha=0.5, label='Random (50%)')
    # Highlight both baseline points; only the second carries a legend entry.
    ax2.scatter([0], [out_acc_start], color='orange', s=100, zorder=5, marker='D')
    ax2.scatter([0], [in_acc_start], color='orange', s=100, zorder=5, marker='D',
                label='Baseline')
    ax2.set(
        xlabel='Epoch',
        ylabel='Accuracy',
        title='Validation Accuracy Over Training\n(In-Dist vs Out-of-Dist)',
        ylim=[0, 1],
    )
    ax2.legend(fontsize=8, loc='lower right')
    ax2.grid(True, alpha=0.3)

# Plot 3: Reward Margin Progression (Training)
ax3 = fig.add_subplot(gs[0, 2])
if len(history['epochs']) > 0 and len(history['reward_margins']) > 0:
    epoch_axis = [0] + history['epochs']
    # The epoch-0 point uses the baseline out-of-dist mean margin.
    margin_start = baseline_out_dist_results['mean_margin']
    ax3.plot(epoch_axis, [margin_start] + history['reward_margins'],
             'purple', linewidth=2, marker='s', markersize=8)
    ax3.axhline(y=0, color='r', linestyle='--', alpha=0.5, label='No Preference')
    ax3.axhline(y=margin_start, color='gray', linestyle=':', alpha=0.7,
                label=f'Baseline ({margin_start:.3f})')
    ax3.scatter([0], [margin_start], color='orange', s=100, zorder=5, marker='s',
                label='Before Training')
    ax3.set(
        xlabel='Epoch',
        ylabel='Mean Reward Margin',
        title='Preference Strength (Training)\n(Preferred - Rejected)',
    )
    ax3.legend(fontsize=8)
    ax3.grid(True, alpha=0.3)

# =============================================================================
# Row 2: Score Distributions (Out-of-Dist Trained Model)
# =============================================================================

# Plot 4: Score Distribution Comparison (Histogram) - Out-Dist Trained
ax4 = fig.add_subplot(gs[1, 0])
# Both histograms share 30 evenly spaced bin edges spanning the combined range.
hist_lo = min(min(preferred_scores), min(rejected_scores))
hist_hi = max(max(preferred_scores), max(rejected_scores))
bins = np.linspace(hist_lo, hist_hi, 30)
hist_style = dict(bins=bins, alpha=0.6, density=True, edgecolor='black', linewidth=0.5)
ax4.hist(preferred_scores, label='Preferred (Cooperate)', color='green', **hist_style)
ax4.hist(rejected_scores, label='Rejected (Non-Cooperate)', color='red', **hist_style)
ax4.set(xlabel='Reward Score', ylabel='Density', title='Out-Dist Trained: Score Distribution')
ax4.legend()
ax4.grid(True, alpha=0.3, axis='y')

# Plot 5: Scatter Plot - Preferred vs Rejected Scores (Out-Dist Trained)
ax5 = fig.add_subplot(gs[1, 1])
ax5.scatter(
    rejected_scores, preferred_scores,
    alpha=0.5, s=30, c=reward_margins, cmap='RdYlGn',
    edgecolors='black', linewidth=0.5,
)
diag_lo = min(min(rejected_scores), min(preferred_scores))
diag_hi = max(max(rejected_scores), max(preferred_scores))
diag_span = [diag_lo, diag_hi]
# Points above the y = x diagonal have preferred > rejected.
ax5.plot(diag_span, diag_span, 'k--', alpha=0.5, linewidth=2, label='Equal Scores')
ax5.fill_between(diag_span, diag_span, [diag_hi, diag_hi],
                 alpha=0.15, color='green', label='Correct (Preferred > Rejected)')
ax5.fill_between(diag_span, [diag_lo, diag_lo], diag_span,
                 alpha=0.15, color='red', label='Incorrect')
ax5.set(xlabel='Rejected Score', ylabel='Preferred Score',
        title='Out-Dist Trained: Pairwise Comparison')
ax5.legend(fontsize=8)
ax5.grid(True, alpha=0.3)
ax5.axis('equal')

# Plot 6: Reward Margin Distribution (Out-Dist Trained)
ax6 = fig.add_subplot(gs[1, 2])
ax6.hist(reward_margins, bins=40, alpha=0.7, color='purple', edgecolor='black', linewidth=0.5)
ax6.axvline(x=0, color='r', linestyle='--', linewidth=2, label='No Preference')
# Vertical markers for the mean margin on each validation set.
out_margin_mean = np.mean(reward_margins)
ax6.axvline(x=out_margin_mean, color='g', linestyle='-', linewidth=2,
            label=f'Out-Dist Mean: {out_margin_mean:.3f}')
in_margin_mean = np.mean(final_in_dist_margins)
ax6.axvline(x=in_margin_mean, color='b', linestyle='--', linewidth=2,
            label=f'In-Dist Mean: {in_margin_mean:.3f}')
ax6.set(xlabel='Reward Margin (Preferred - Rejected)', ylabel='Count',
        title='Trained: Margin Distribution Comparison')
ax6.legend(fontsize=8)
ax6.grid(True, alpha=0.3, axis='y')

# =============================================================================
# Row 3: Error Type Breakdown (with in-dist comparison)
# =============================================================================

# Plot 7: Error Type Accuracy Comparison (Baseline vs Trained, In-Dist vs Out-Dist)
ax7 = fig.add_subplot(gs[2, 0])
if final_out_dist_error_breakdown:
    error_types = ['too_risky', 'too_risk_averse', 'other']
    error_type_labels = ['Too Risky', 'Too Risk-Averse', 'Other']
    x = np.arange(len(error_types))
    width = 0.2

    # Accuracies in percent; a category missing from a breakdown counts as 0.
    baseline_out_breakdown = baseline_out_dist_results.get('error_type_breakdown', {})
    out_baseline_accs = [
        baseline_out_breakdown.get(et, {}).get('accuracy', 0) * 100 for et in error_types
    ]
    out_trained_accs = [
        final_out_dist_error_breakdown.get(et, {}).get('accuracy', 0) * 100 for et in error_types
    ]
    if final_in_dist_error_breakdown:
        in_trained_accs = [
            final_in_dist_error_breakdown.get(et, {}).get('accuracy', 0) * 100 for et in error_types
        ]
    else:
        in_trained_accs = [0, 0, 0]

    # Three bars per category, placed left / centre / right of the tick.
    ax7.bar(x - width, out_baseline_accs, width, label='Out-Dist Baseline', color='gray', alpha=0.7)
    ax7.bar(x, out_trained_accs, width, label='Out-Dist Trained', color='green', alpha=0.7)
    ax7.bar(x + width, in_trained_accs, width, label='In-Dist Trained', color='blue', alpha=0.7)

    ax7.axhline(y=50, color='r', linestyle='--', alpha=0.5, label='Random (50%)')
    ax7.set(ylabel='Accuracy (%)',
            title='Accuracy by Error Type\n(Baseline vs Trained, In vs Out)')
    ax7.set_xticks(x)
    ax7.set_xticklabels(error_type_labels, fontsize=9)
    ax7.legend(fontsize=7, loc='upper right')
    ax7.set_ylim([0, 100])
    ax7.grid(True, alpha=0.3, axis='y')

# Plot 8: Scatter Plot Colored by Error Type (Out-Dist)
ax8 = fig.add_subplot(gs[2, 1])
if per_example_results:
    # (error_type key, marker colour, legend text)
    scatter_specs = (
        ('too_risky', 'orange', 'Too Risky'),
        ('too_risk_averse', 'blue', 'Too Risk-Averse'),
        ('other', 'gray', 'Other'),
    )
    for type_key, type_color, type_label in scatter_specs:
        subset = [r for r in per_example_results if r['error_type'] == type_key]
        if not subset:
            continue  # nothing to draw for this category
        subset_pref = [r['preferred_score'] for r in subset]
        subset_rej = [r['rejected_score'] for r in subset]
        ax8.scatter(subset_rej, subset_pref, alpha=0.5, s=30, c=type_color,
                    label=f'{type_label} (n={len(subset)})', edgecolors='black', linewidth=0.3)

    # y = x reference diagonal across the full score range.
    min_val = min(min(preferred_scores), min(rejected_scores))
    max_val = max(max(preferred_scores), max(rejected_scores))
    ax8.plot([min_val, max_val], [min_val, max_val], 'k--', alpha=0.5, linewidth=2)
    ax8.set(xlabel='Rejected Score', ylabel='Preferred Score',
            title='Out-Dist: Pairwise by Error Type')
    ax8.legend(fontsize=8)
    ax8.grid(True, alpha=0.3)
    ax8.axis('equal')

# Plot 9: Error Type Accuracy Over Epochs (Both in-dist and out-of-dist)
ax9 = fig.add_subplot(gs[2, 2])
if history['out_dist_error_type_accuracy']['too_risky']:
    line_colors = {'too_risky': 'orange', 'too_risk_averse': 'blue', 'other': 'gray'}

    # One entry per validation set: (history key, baseline results, legend
    # prefix, line styling). Out-of-dist is drawn first with solid lines,
    # in-dist second with dashed lines.
    series_specs = (
        ('out_dist_error_type_accuracy', baseline_out_dist_results, 'Out',
         dict(marker='o', linewidth=2, markersize=6)),
        ('in_dist_error_type_accuracy', baseline_in_dist_results, 'In',
         dict(marker='s', linewidth=2, linestyle='--', markersize=5, alpha=0.7)),
    )
    for history_key, baseline_results, legend_prefix, line_style in series_specs:
        for error_type in ('too_risky', 'too_risk_averse'):
            per_epoch_accs = history[history_key][error_type]
            if not per_epoch_accs:
                continue
            # Prepend the baseline accuracy as the epoch-0 point.
            start_acc = baseline_results.get('error_type_breakdown', {}).get(error_type, {}).get('accuracy', 0)
            series = [start_acc] + per_epoch_accs
            ax9.plot(
                [0] + history['epochs'], [a * 100 for a in series],
                color=line_colors[error_type],
                label=f'{legend_prefix}: {error_type.replace("_", " ").title()}',
                **line_style,
            )

    ax9.axhline(y=50, color='r', linestyle='--', alpha=0.5, label='Random (50%)')
    ax9.set(xlabel='Epoch', ylabel='Accuracy (%)',
            title='Error Type Accuracy Over Training\n(Solid=Out-Dist, Dashed=In-Dist)')
    ax9.legend(fontsize=7, loc='lower right', ncol=2)
    ax9.set_ylim([0, 100])
    ax9.grid(True, alpha=0.3)

# =============================================================================
# Row 4: Baseline Comparison and Performance Analysis
# =============================================================================

# Plot 10: Baseline Score Distribution (Out-Dist)
ax10 = fig.add_subplot(gs[3, 0])
baseline_pref = np.array(baseline_out_dist_results['preferred_scores'])
baseline_rej = np.array(baseline_out_dist_results['rejected_scores'])
# Shared bin edges so the two histograms are directly comparable.
bins_baseline = np.linspace(min(baseline_pref.min(), baseline_rej.min()),
                            max(baseline_pref.max(), baseline_rej.max()), 30)
ax10.hist(baseline_pref, bins=bins_baseline, alpha=0.6, label='Preferred (Cooperate)',
          color='green', density=True, edgecolor='black', linewidth=0.5)
ax10.hist(baseline_rej, bins=bins_baseline, alpha=0.6, label='Rejected (Non-Cooperate)',
          color='red', density=True, edgecolor='black', linewidth=0.5)
ax10.set_xlabel('Reward Score')
ax10.set_ylabel('Density')
ax10.set_title('Baseline (Untrained): Score Distribution')
ax10.legend()
ax10.grid(True, alpha=0.3, axis='y')

# Plot 11: Ranking Performance (Correct vs Incorrect by Type) - Out-Dist
ax11 = fig.add_subplot(gs[3, 1])
if final_out_dist_error_breakdown:
    error_types = ['too_risky', 'too_risk_averse', 'other']
    error_type_labels = ['Too Risky', 'Too Risk-Averse', 'Other']

    # Incorrect = total - correct; categories missing from the breakdown count as 0.
    correct_counts = [final_out_dist_error_breakdown.get(et, {}).get('correct', 0) for et in error_types]
    incorrect_counts = [final_out_dist_error_breakdown.get(et, {}).get('total', 0) -
                        final_out_dist_error_breakdown.get(et, {}).get('correct', 0) for et in error_types]

    x = np.arange(len(error_types))
    width = 0.35

    ax11.bar(x - width/2, correct_counts, width, label='Correct',
             color='green', alpha=0.7, edgecolor='black')
    ax11.bar(x + width/2, incorrect_counts, width, label='Incorrect',
             color='red', alpha=0.7, edgecolor='black')

    ax11.set_ylabel('Number of Examples')
    ax11.set_title('Out-Dist: Correct vs Incorrect by Type')
    ax11.set_xticks(x)
    ax11.set_xticklabels(error_type_labels, fontsize=9)
    ax11.legend(fontsize=8)
    ax11.grid(True, alpha=0.3, axis='y')

# Plot 12: Margin Distribution by Error Type - Out-Dist
ax12 = fig.add_subplot(gs[3, 2])
if per_example_results:
    colors = {'too_risky': 'orange', 'too_risk_averse': 'blue', 'other': 'gray'}
    labels = {'too_risky': 'Too Risky', 'too_risk_averse': 'Too Risk-Averse', 'other': 'Other'}

    # Record which error types actually contribute a box. Categories with no
    # examples are skipped, so colours and tick labels must follow this list
    # rather than the full fixed list of error types.
    plotted_types = []
    box_data = []
    for error_type in ['too_risky', 'too_risk_averse', 'other']:
        margins_for_type = [r['margin'] for r in per_example_results if r['error_type'] == error_type]
        if margins_for_type:
            plotted_types.append(error_type)
            box_data.append(margins_for_type)

    if box_data:
        box_labels = [labels[error_type] for error_type in plotted_types]
        bp = ax12.boxplot(box_data, patch_artist=True)
        # Set tick labels explicitly instead of boxplot's `labels=` keyword,
        # which newer Matplotlib releases deprecate in favour of `tick_labels`.
        ax12.set_xticklabels(box_labels)
        for patch, error_type in zip(bp['boxes'], plotted_types):
            patch.set_facecolor(colors[error_type])
            patch.set_alpha(0.6)

        ax12.axhline(y=0, color='r', linestyle='--', alpha=0.5, label='No Preference')
        ax12.set_ylabel('Reward Margin')
        ax12.set_title('Out-Dist: Margin by Error Type')
        # Show the legend so the labelled reference line is identified.
        ax12.legend(fontsize=8)
        ax12.grid(True, alpha=0.3, axis='y')

# =============================================================================
# Row 5: Summary and CDFs
# =============================================================================

# Plot 13: Overall Summary Bar Chart (In-Dist and Out-Dist)
ax13 = fig.add_subplot(gs[4, 0])
# (tick label, accuracy in percent, bar colour) for each of the four bars.
summary_bars = (
    ('Baseline\nIn-Dist', baseline_in_dist_results['accuracy'] * 100, 'lightblue'),
    ('Trained\nIn-Dist', final_in_dist_accuracy * 100, 'blue'),
    ('Baseline\nOut-Dist', baseline_out_dist_results['accuracy'] * 100, 'lightgreen'),
    ('Trained\nOut-Dist', final_out_dist_accuracy * 100, 'green'),
)
categories = [entry[0] for entry in summary_bars]
accuracies = [entry[1] for entry in summary_bars]
colors_bar = [entry[2] for entry in summary_bars]
bars = ax13.bar(categories, accuracies, color=colors_bar, alpha=0.8, edgecolor='black', linewidth=1.5)
ax13.axhline(y=50, color='r', linestyle='--', alpha=0.5, label='Random (50%)')
# Annotate each bar with its value just above the bar top.
for bar, acc in zip(bars, accuracies):
    bar_center = bar.get_x() + bar.get_width()/2.
    ax13.text(bar_center, bar.get_height() + 1,
              f'{acc:.1f}%', ha='center', va='bottom', fontweight='bold', fontsize=10)
ax13.set(ylabel='Accuracy (%)', title='Overall Accuracy Comparison', ylim=[0, 100])
ax13.legend()
ax13.grid(True, alpha=0.3, axis='y')

# Plot 14: Score Evolution Over Training
ax14 = fig.add_subplot(gs[4, 1])
if len(history['epochs']) > 0 and len(history['preferred_rewards']) > 0:
    epoch_axis = [0] + history['epochs']
    # Epoch-0 points are the mean baseline out-of-dist scores.
    pref_start = np.mean(baseline_out_dist_results['preferred_scores'])
    rej_start = np.mean(baseline_out_dist_results['rejected_scores'])
    pref_series = [pref_start] + history['preferred_rewards']
    rej_series = [rej_start] + history['rejected_rewards']
    ax14.plot(epoch_axis, pref_series, 'g-o',
              linewidth=2, markersize=8, label='Preferred (CARA)')
    ax14.plot(epoch_axis, rej_series, 'r-s',
              linewidth=2, markersize=8, label='Rejected')
    ax14.scatter([0, 0], [pref_series[0], rej_series[0]],
                 color='orange', s=100, zorder=5, marker='D', label='Baseline')
    ax14.set(xlabel='Epoch', ylabel='Mean Reward Score', title='Training Score Evolution')
    ax14.legend()
    ax14.grid(True, alpha=0.3)


# Plot 15: CDFs (In-Dist and Out-Dist comparison)
def _empirical_cdf(scores):
    """Return (sorted scores, cumulative probabilities i/n for i = 1..n)."""
    ordered = np.sort(scores)
    return ordered, np.arange(1, len(ordered) + 1) / len(ordered)


ax15 = fig.add_subplot(gs[4, 2])
# Out-dist trained (solid lines)
sorted_pref, cdf_pref = _empirical_cdf(preferred_scores)
sorted_rej, cdf_rej = _empirical_cdf(rejected_scores)
ax15.plot(sorted_pref, cdf_pref, 'g-', linewidth=2, label='Out-Dist: Preferred')
ax15.plot(sorted_rej, cdf_rej, 'r-', linewidth=2, label='Out-Dist: Rejected')

# In-dist trained (dashed lines)
sorted_in_pref, cdf_in_pref = _empirical_cdf(final_in_dist_pref_scores)
sorted_in_rej, cdf_in_rej = _empirical_cdf(final_in_dist_rej_scores)
ax15.plot(sorted_in_pref, cdf_in_pref, 'g--', linewidth=1.5, alpha=0.7, label='In-Dist: Preferred')
ax15.plot(sorted_in_rej, cdf_in_rej, 'r--', linewidth=1.5, alpha=0.7, label='In-Dist: Rejected')

ax15.set(xlabel='Reward Score', ylabel='Cumulative Probability',
         title='CDF: In-Dist (dashed) vs Out-Dist (solid)')
ax15.legend(fontsize=8)
ax15.grid(True, alpha=0.3)

plt.tight_layout()

# Save plot
# `timestamp` is reused below to name the results and history JSON files.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# Create the output directory up front: the checkpoint branch of the training
# loop only creates it when a new best model is saved, so it may not exist.
os.makedirs("outputs", exist_ok=True)
plot_path = f"outputs/training_results_{timestamp}.png"
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
print(f"\nComprehensive plots saved to: {plot_path}")

plt.show()

# =============================================================================
# SAVE RESULTS TO JSON
# =============================================================================
def _breakdown_as_floats(breakdown):
    """Return a JSON-friendly copy of an error-type breakdown.

    Numeric statistics (Python or NumPy integer/float scalars) are converted
    to plain ``float``; any other value is passed through unchanged. A missing
    or empty breakdown yields ``{}``.
    """
    if not breakdown:
        return {}
    numeric_types = (int, float, np.integer, np.floating)
    return {
        error_type: {
            stat: float(value) if isinstance(value, numeric_types) else value
            for stat, value in stats.items()
        }
        for error_type, stats in breakdown.items()
    }


def _to_jsonable(obj):
    """Recursively replace NumPy arrays and scalars in *obj* with built-in types."""
    if isinstance(obj, dict):
        return {key: _to_jsonable(val) for key, val in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_jsonable(item) for item in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    return obj


trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())

results = {
    'model_name': MODEL_NAME,
    'architecture': {
        'type': 'LoRA + Reward Head',
        'lora_config': {
            'r': LORA_R,
            'alpha': LORA_ALPHA,
            'dropout': LORA_DROPOUT,
            'target_modules': LORA_TARGET_MODULES,
        },
        'trainable_params': trainable_params,
        'total_params': total_params,
        'trainable_percent': 100 * trainable_params / total_params,
    },
    'data': {
        'training_file': TRAIN_DATA_FILE,
        'out_dist_validation_file': VAL_DATA_FILE,
        'in_dist_val_split': IN_DIST_VAL_SPLIT,
        'training_label_type': 'CARA (risk-aversion)',
        'in_dist_validation_label_type': 'CARA (risk-aversion)',
        'out_dist_validation_label_type': 'cooperate',
    },
    'baseline': {
        'in_dist': {
            'accuracy': float(baseline_in_dist_results['accuracy']),
            'loss': float(baseline_in_dist_results['loss']),
            'mean_margin': float(baseline_in_dist_results['mean_margin']),
            'error_type_breakdown': _breakdown_as_floats(
                baseline_in_dist_results.get('error_type_breakdown', {})
            ),
        },
        'out_dist': {
            'accuracy': float(baseline_out_dist_results['accuracy']),
            'loss': float(baseline_out_dist_results['loss']),
            'mean_margin': float(baseline_out_dist_results['mean_margin']),
            'error_type_breakdown': _breakdown_as_floats(
                baseline_out_dist_results.get('error_type_breakdown', {})
            ),
        },
    },
    'trained': {
        'in_dist': {
            'final_accuracy': float(final_in_dist_accuracy),
            'final_loss': float(final_in_dist_loss),
            'mean_margin': float(np.mean(final_in_dist_margins)),
            'error_type_breakdown': _breakdown_as_floats(final_in_dist_error_breakdown),
        },
        'out_dist': {
            'final_accuracy': float(final_out_dist_accuracy),
            'final_loss': float(final_out_dist_loss),
            # Best accuracy seen during training (may differ from the final one).
            'best_accuracy': float(best_val_accuracy),
            'mean_margin': float(np.mean(final_out_dist_margins)),
            'error_type_breakdown': _breakdown_as_floats(final_out_dist_error_breakdown),
        },
    },
    'improvement': {
        'in_dist': {
            'accuracy_gain': float(final_in_dist_accuracy - baseline_in_dist_results['accuracy']),
            'accuracy_gain_percent': float((final_in_dist_accuracy - baseline_in_dist_results['accuracy']) * 100),
        },
        'out_dist': {
            'accuracy_gain': float(final_out_dist_accuracy - baseline_out_dist_results['accuracy']),
            'accuracy_gain_percent': float((final_out_dist_accuracy - baseline_out_dist_results['accuracy']) * 100),
        },
    },
    'config': {
        'num_epochs': NUM_EPOCHS,
        'epochs_trained': len(history['epochs']),
        'batch_size': BATCH_SIZE,
        'learning_rate': LEARNING_RATE,
        'weight_decay': WEIGHT_DECAY,
        'training_samples': len(train_dataset),
        'in_dist_validation_samples': len(in_dist_val_dataset),
        'out_dist_validation_samples': len(out_dist_val_dataset),
    },
    'timestamp': timestamp
}

# Make sure the target directory exists before writing either JSON file.
os.makedirs("outputs", exist_ok=True)

results_path = f"outputs/results_{timestamp}.json"
with open(results_path, 'w') as f:
    # _to_jsonable catches NumPy values in pass-through (non-numeric) entries.
    json.dump(_to_jsonable(results), f, indent=2)
print(f"Results saved to: {results_path}")

# Save training history
history_path = f"outputs/training_history_{timestamp}.json"
# Convert NumPy arrays/scalars at any nesting depth to built-in types for
# JSON serialization.
history_serializable = _to_jsonable(history)

with open(history_path, 'w') as f:
    json.dump(history_serializable, f, indent=2)
print(f"Training history saved to: {history_path}")

# =============================================================================
# FINAL SUMMARY
# =============================================================================
# Console recap: dataset sizes, baseline vs trained accuracy for both
# validation sets, and the trainable-parameter share.
summary_rule = "=" * 60
print("\n" + summary_rule)
print("Experiment Complete!")
print(summary_rule)

print("\nDataset Summary:")
for split_name, split_dataset, label_kind in (
    ("Training", train_dataset, "CARA"),
    ("In-Dist Validation", in_dist_val_dataset, "CARA"),
    ("Out-of-Dist Validation", out_dist_val_dataset, "Cooperate"),
):
    print(f"  {split_name}: {len(split_dataset)} examples ({label_kind} labels)")

print("\nAccuracy Results:")
for section_title, baseline_results, trained_accuracy in (
    ("In-Distribution", baseline_in_dist_results, final_in_dist_accuracy),
    ("Out-of-Distribution", baseline_out_dist_results, final_out_dist_accuracy),
):
    print(f"  {section_title}:")
    baseline_accuracy = baseline_results['accuracy']
    print(f"    Baseline: {baseline_accuracy*100:.1f}%")
    print(f"    Trained:  {trained_accuracy*100:.1f}%")
    print(f"    Improvement: {(trained_accuracy - baseline_accuracy)*100:+.1f}%")

trainable_share = 100*trainable_params/total_params
print(f"\nTrainable parameters: {trainable_params:,} ({trainable_share:.4f}%)")

# Find checkpoints
import glob
import shutil

# Keep directories only: the pattern also matches the "<checkpoint>.zip"
# archives created below, which would otherwise be picked as the "latest"
# checkpoint when this cell is run again.
checkpoint_dirs = [
    path for path in glob.glob("outputs/best_model_epoch*") if os.path.isdir(path)
]
# Most recently changed checkpoint directory, or None if there is none.
latest_checkpoint = max(checkpoint_dirs, key=os.path.getctime) if checkpoint_dirs else None

# Automatic downloads (for Google Colab)
print("\n" + "="*60)
print("Downloading Results...")
print("="*60)

# Only the import is guarded, so an ImportError raised during the downloads
# themselves is not mistaken for "not running in Colab".
try:
    from google.colab import files
except ImportError:
    print("\nNOTE: Not running in Google Colab - files saved to outputs/ directory")
    print("Files available:")
    print(f"  - {plot_path}")
    print(f"  - {results_path}")
    print(f"  - {history_path}")
    if latest_checkpoint is not None:
        print(f"  - {latest_checkpoint}/ (LoRA checkpoint directory)")
else:
    for artifact_path in (plot_path, results_path, history_path):
        print(f"Downloading: {artifact_path}")
        files.download(artifact_path)

    if latest_checkpoint is not None:
        # Checkpoints are directories, so zip them before downloading.
        zip_path = f"{latest_checkpoint}.zip"
        shutil.make_archive(latest_checkpoint, 'zip', latest_checkpoint)
        print(f"Downloading: {zip_path}")
        files.download(zip_path)

    print("\nAll files downloaded successfully!")