# Risk-Averse Reward Model Training

This notebook implements the complete risk aversion experiment for training **Qwen3-8B** to prefer risk-averse choices over risk-neutral ones.

## Features
- **CSV Data Loading**: Loads scenarios from `11_7_low_stakes_training_set.csv`
- **Pure Single-Input Training**: Binary classification approach (risk-averse=1.0, risk-neutral=0.0)
- **Low Stakes Training**: 1,000 situations with low-stakes gambles
- **Larger Model**: Qwen3-8B (8 billion parameters) for better learning capacity
- **GPU Optimizations**: fp16 mixed precision, gradient checkpointing, fused AdamW
- **Comprehensive Visualization**: 4-panel plots showing training progress and results
- **CARA vs LINEAR Utility**: Trains on CARA (risk-averse) vs LINEAR (risk-neutral) best options

## Requirements
- Google Colab with **High-RAM GPU** (A100 recommended for 8B model)
- Runtime → Change runtime type → GPU → A100
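- Upload `11_7_low_stakes_training_set.csv` to the Colab working directory (the loader's default path is relative to the notebook root; a file placed in `data/` needs an explicit `csv_file_path`)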

**Memory Optimized for 8B Model:** Uses gradient checkpointing, smaller batch size, and fp16 for memory efficiency

## Expected Output
- Training plots saved to `outputs/training_results_YYYYMMDD_HHMMSS.png`
- Results JSON saved to `outputs/experiment_results.json`
- Model checkpoints saved to `risk_averse_model/` (every 250 steps, keeping the 3 most recent)

## Training Scale
- **1,000 situations** (low-stakes dataset)
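- **2,000 examples in total** (2 per situation)
- **1,280 training examples** after the two 80/20 splits in the code (test split in `run_experiment`, validation split in `train_reward_model`), assuming no situations are skipped
- **10 epochs**
- **~1,600 optimizer steps** (1,280 examples × 10 epochs ÷ effective batch size 8)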
- Estimated training time: **30-60 minutes** on A100 GPU

## 1. Install Dependencies

In [None]:
# Install required packages
!pip install -q transformers datasets accelerate torch pandas numpy scikit-learn matplotlib seaborn

print("✓ Dependencies installed successfully!")

## 2. Import Libraries

In [None]:
import os
import sys
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer, 
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding
)
from sklearn.model_selection import train_test_split
import json
from typing import List, Dict, Tuple
import warnings
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

warnings.filterwarnings('ignore')

# Set plotting style
plt.style.use('default')
sns.set_palette("husl")

print("✓ Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 3. Data Loading Class

Loads risk scenarios from the CSV file, groups the per-option rows by `situation_id`, and records the 1-indexed option numbers of the CARA-best and LINEAR-best options for each situation.
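
As a rough illustration of the grouping step, the sketch below uses plain Python dictionaries instead of pandas. The rows and prompt text are made up for the example, and `summarize_situations` is a hypothetical helper; only the column names come from the CSV format used by the loader.

```python
from collections import defaultdict

# Hypothetical rows in the per-option CSV layout (one row per option).
rows = [
    {"situation_id": 1, "prompt_text": "Pick 1, 2 or 3.", "option_index": 0,
     "is_best_cara": False, "is_best_linear": False},
    {"situation_id": 1, "prompt_text": "Pick 1, 2 or 3.", "option_index": 1,
     "is_best_cara": True, "is_best_linear": False},
    {"situation_id": 1, "prompt_text": "Pick 1, 2 or 3.", "option_index": 2,
     "is_best_cara": False, "is_best_linear": True},
]

def summarize_situations(rows):
    """Collapse per-option rows into one record per situation."""
    groups = defaultdict(list)
    for row in rows:
        groups[row["situation_id"]].append(row)

    situations = []
    for situation_id, group in groups.items():
        cara = [r for r in group if r["is_best_cara"]]
        linear = [r for r in group if r["is_best_linear"]]
        if not cara or not linear:
            continue  # skip situations missing either best option
        situations.append({
            "situation_id": situation_id,
            "prompt_text": group[0]["prompt_text"],
            # option_index is 0-based; the prompt numbers options from 1
            "correct_label": str(cara[0]["option_index"] + 1),
            "incorrect_label": str(linear[0]["option_index"] + 1),
            "num_options": len(group),
        })
    return situations

situations = summarize_situations(rows)
print(situations[0]["correct_label"], situations[0]["incorrect_label"])
```

Here the risk-averse label is the option marked `is_best_cara` and the risk-neutral label is the one marked `is_best_linear`, both shifted to the 1-based numbering shown in the prompt.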

In [None]:
class RiskAversionDataLoader:
    """Load and process data from CSV file for risk aversion training"""
    
    def __init__(self, csv_file_path="11_7_low_stakes_training_set.csv"):
        self.csv_file_path = csv_file_path
        
    def load_and_process_data(self) -> pd.DataFrame:
        """Load CSV data and process it for training
        
        New format (11_7_low_stakes_training_set.csv):
        - Multiple rows per situation (one per option)
        - is_best_cara = True marks risk-averse option
        - is_best_linear = True marks risk-neutral option
        - option_index is 0-indexed, prompts use 1-indexed numbers
        """
        # Check if CSV file exists
        if not os.path.exists(self.csv_file_path):
            raise FileNotFoundError(
                f"Required data file '{self.csv_file_path}' not found. "
                f"Please ensure the CSV file is uploaded to Colab."
            )
        
        # Load the CSV file
        df = pd.read_csv(self.csv_file_path)
        print(f"Loaded {len(df)} rows from {self.csv_file_path}")
        
        # Check required columns exist (new format)
        required_columns = ['situation_id', 'prompt_text', 'option_index', 'is_best_cara', 'is_best_linear']
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            raise ValueError(
                f"Missing required columns in CSV: {missing_columns}. "
                f"Available columns: {list(df.columns)}"
            )
        
        # Group by situation_id to get unique situations
        situations = []
        situations_skipped = 0
        
        for situation_id, group in df.groupby('situation_id'):
            # Find risk-averse option (CARA best)
            cara_rows = group[group['is_best_cara'] == True]
            if len(cara_rows) == 0:
                situations_skipped += 1
                continue
            cara_option = cara_rows.iloc[0]
            
            # Find risk-neutral option (LINEAR best)
            linear_rows = group[group['is_best_linear'] == True]
            if len(linear_rows) == 0:
                situations_skipped += 1
                continue
            linear_option = linear_rows.iloc[0]
            
            # Get prompt text from first row (same for all options)
            prompt_text = group.iloc[0]['prompt_text']
            
            # Convert 0-indexed option_index to 1-indexed option numbers (as shown in prompt)
            correct_label = str(cara_option['option_index'] + 1)
            incorrect_label = str(linear_option['option_index'] + 1)
            
            situations.append({
                'situation_id': situation_id,
                'prompt_text': prompt_text,
                'correct_label': correct_label,  # Risk-averse option number
                'incorrect_label': incorrect_label,  # Risk-neutral option number
                'num_options': len(group)
            })
        
        if situations_skipped > 0:
            print(f"Warning: Skipped {situations_skipped} situations missing CARA or LINEAR best options")
        
        result_df = pd.DataFrame(situations)
        print(f"Processed into {len(result_df)} unique situations")
        
        # Display sample data
        if len(result_df) > 0:
            sample = result_df.iloc[0]
            print(f"\nSample situation:")
            print(f"Prompt: {sample['prompt_text'][:200]}...")
            print(f"Risk-averse choice (CARA best): Option {sample['correct_label']}")
            print(f"Risk-neutral choice (LINEAR best): Option {sample['incorrect_label']}")
            print(f"Number of options in this situation: {sample['num_options']}")
        
        return result_df

print("✓ RiskAversionDataLoader defined")

## 4. Dataset Classes

PyTorch datasets for single-input classification training (`SingleInputDataset`) and for pairwise evaluation (`PairwiseRiskAversionDataset`).
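
To make the single-input format concrete, here is a minimal sketch of how one situation becomes two labeled text examples before tokenization. The situation record is hypothetical and `expand_situation` is an illustrative helper, not part of the notebook's classes.

```python
def expand_situation(situation):
    """Return the two (text, label) training examples for one situation."""
    prompt = situation["prompt_text"]
    return [
        # Risk-averse (CARA-best) option gets label 1.0
        (f"{prompt}\n\nChosen option: {situation['correct_label']}", 1.0),
        # Risk-neutral (LINEAR-best) option gets label 0.0
        (f"{prompt}\n\nChosen option: {situation['incorrect_label']}", 0.0),
    ]

# Hypothetical situation record, shaped like one row from the data loader
example_situation = {
    "prompt_text": "Option 1: $5 for sure. Option 2: 50% chance of $12.",
    "correct_label": "1",
    "incorrect_label": "2",
}
examples = expand_situation(example_situation)
for text, label in examples:
    print(repr(text), label)
```

Both examples share the same prompt and differ only in the trailing option number, so the model has to read the prompt to tell them apart.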

In [None]:
class SingleInputDataset(Dataset):
    """Simple dataset for pure single-input classification training
    
    For each situation:
    - Returns 2 examples: risk-averse option (label=1.0) and risk-neutral option (label=0.0)
    
    This teaches the model absolute scoring: risk-averse=high score, risk-neutral=low score.
    """
    
    def __init__(self, dataframe: pd.DataFrame, tokenizer, max_length=128):
        # Reset index to ensure sequential 0-based indexing (important after train_test_split)
        self.data = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length
        
        # Expand dataset: each situation generates 2 examples
        # Example 0: risk-averse option (label=1.0)
        # Example 1: risk-neutral option (label=0.0)
        self.examples = []
        for idx in range(len(self.data)):
            row = self.data.iloc[idx]
            # Risk-averse example
            self.examples.append({
                'situation_idx': idx,
                'is_risk_averse': True,
                'situation_id': row['situation_id']
            })
            # Risk-neutral example
            self.examples.append({
                'situation_idx': idx,
                'is_risk_averse': False,
                'situation_id': row['situation_id']
            })
        
        print(f"SingleInputDataset: {len(self.examples)} examples from {len(self.data)} situations")
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, idx):
        example_info = self.examples[idx]
        row = self.data.iloc[example_info['situation_idx']]
        
        is_risk_averse = example_info['is_risk_averse']
        option_text = row['correct_label'] if is_risk_averse else row['incorrect_label']
        input_text = f"{row['prompt_text']}\n\nChosen option: {option_text}"
        label = 1.0 if is_risk_averse else 0.0
        
        encoding = self.tokenizer(
            input_text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': label,
            'situation_id': row['situation_id']
        }


# PairwiseRiskAversionDataset is kept for pairwise evaluation
# (evaluate_model below tokenizes each pair directly and does not use it)
class PairwiseRiskAversionDataset(Dataset):
    """Dataset that provides pairs of risk-averse vs risk-neutral choices for evaluation"""
    
    def __init__(self, dataframe: pd.DataFrame, tokenizer, max_length=128):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        
        # Create input texts for both choices
        risk_averse_text = f"{row['prompt_text']}\n\nChosen option: {row['correct_label']}"
        risk_neutral_text = f"{row['prompt_text']}\n\nChosen option: {row['incorrect_label']}"
        
        # Tokenize both inputs
        risk_averse_encoding = self.tokenizer(
            risk_averse_text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        risk_neutral_encoding = self.tokenizer(
            risk_neutral_text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        return {
            'risk_averse_input_ids': risk_averse_encoding['input_ids'].flatten(),
            'risk_averse_attention_mask': risk_averse_encoding['attention_mask'].flatten(),
            'risk_neutral_input_ids': risk_neutral_encoding['input_ids'].flatten(),
            'risk_neutral_attention_mask': risk_neutral_encoding['attention_mask'].flatten(),
            'situation_id': row['situation_id']
        }

print("✓ Dataset classes defined")

## 5. Reward Model - **Pure Single-Input Classification**

Risk-averse reward model with **simplified single-input training** using **Qwen3-8B**.

**Note:** Model loads with automatic dtype and uses gradient checkpointing for memory efficiency.

### Training Approach: Pure Single-Input Classification

This model uses **pure single-input binary classification**:

| Parameter | Value |
|-----------|-------|
| **Model Size** | 8B parameters (Qwen3-8B) |
| **Training Data** | 1,000 low-stakes situations |
| **Training Examples** | 2,000 examples in total (2 per situation); 1,280 used for gradient updates after the test and validation splits |
| **Epochs** | 10 epochs |
| **Total Training Steps** | ~1,600 optimizer steps |
| **Training Time** | ~30-60 minutes on A100 |

### Dataset: Low-Stakes Gambles

The new dataset (`11_7_low_stakes_training_set.csv`) contains:
- **1,000 unique situations** with low-stakes monetary gambles
- Each situation has 2-5 options to choose from
- **CARA utility** (risk-averse): `u(w) = 1 - e^(-0.01w)`
- **LINEAR utility** (risk-neutral): `u(w) = w`
- Each situation has one option that maximizes CARA utility (risk-averse best)
- Each situation has one option that maximizes LINEAR utility (risk-neutral best)
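
The two utility functions can disagree about which option is best. The sketch below shows this on a pair of made-up options (not taken from the dataset), using the CARA coefficient 0.01 from the formula above; `expected_value` is a helper introduced only for this example.

```python
import math

def cara_utility(w, a=0.01):
    """CARA utility u(w) = 1 - exp(-a * w)."""
    return 1.0 - math.exp(-a * w)

def expected_value(gamble, utility=lambda w: w):
    """Expected utility of a gamble given as [(probability, payoff), ...]."""
    return sum(p * utility(w) for p, w in gamble)

# Hypothetical options (illustrative only)
options = [
    [(1.0, 40.0)],               # option 1: $40 for sure
    [(0.5, 0.0), (0.5, 100.0)],  # option 2: coin flip between $0 and $100
]

linear_scores = [expected_value(g) for g in options]
cara_scores = [expected_value(g, cara_utility) for g in options]

# Convert 0-based positions to the 1-based option numbers used in prompts
best_linear = max(range(len(options)), key=linear_scores.__getitem__) + 1
best_cara = max(range(len(options)), key=cara_scores.__getitem__) + 1
print(f"LINEAR best: option {best_linear}, CARA best: option {best_cara}")
```

In this example the coin flip has the higher expected payoff, while the sure amount has the higher expected CARA utility, which is the kind of disagreement the training labels encode.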

### Loss Function

**Binary Cross-Entropy Loss (BCE)**

```python
loss = torch.nn.BCEWithLogitsLoss()(score, label)  # score is the raw logit
```

Where:
- **Risk-averse options** (CARA best): `label = 1.0` → model is pushed toward large positive logits (sigmoid of the score approaches 1.0)
- **Risk-neutral options** (LINEAR best): `label = 0.0` → model is pushed toward large negative logits (sigmoid of the score approaches 0.0)
- Simple, direct supervision on individual options
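
For intuition, the per-example loss can be written out without PyTorch. The sketch below uses the standard numerically stable form of binary cross-entropy on a logit; it is a worked formula for illustration, not the notebook's training code (`nn.BCEWithLogitsLoss` additionally averages over the batch by default).

```python
import math

def sigmoid(logit):
    """Map a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-logit))

def bce_with_logits(logit, label):
    """Numerically stable binary cross-entropy for one raw logit."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

for logit in (-2.0, 0.0, 2.0):
    print(f"logit={logit:+.1f} sigmoid={sigmoid(logit):.3f} "
          f"loss(label=1)={bce_with_logits(logit, 1.0):.3f} "
          f"loss(label=0)={bce_with_logits(logit, 0.0):.3f}")
```

A logit of 0 gives the same loss for either label, and the loss for `label=1` shrinks as the logit grows, which is why well-trained risk-averse options should end up with positive logits.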

### Training Data Distribution

For the full dataset:
- **Risk-averse examples**: 1,000 (label=1.0)
- **Risk-neutral examples**: 1,000 (label=0.0)
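- **Total examples**: 2,000
- **Splits as coded**: `run_experiment` holds out 20% of situations for testing (800 / 200), then `train_reward_model` holds out 20% of the remainder for validation, giving 1,280 train / 320 validation / 400 test examples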
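
The counts can be sanity-checked by replaying the two `train_test_split(..., test_size=0.2)` calls in `run_experiment` and `train_reward_model` arithmetically. This sketch assumes all 1,000 situations survive loading and that the test size is rounded up, as scikit-learn does for fractional sizes; `split_counts` is a hypothetical helper.

```python
import math

def split_counts(n_situations, test_pct=20):
    """Approximate train_test_split sizes: the test share is rounded up."""
    n_test = math.ceil(n_situations * test_pct / 100)
    return n_situations - n_test, n_test

n_total = 1000  # assumes no situation is skipped by the loader
n_trainval, n_test = split_counts(n_total)   # split in run_experiment
n_train, n_val = split_counts(n_trainval)    # split in train_reward_model

examples_per_situation = 2
train_examples = n_train * examples_per_situation
effective_batch = 1 * 8  # per-device batch size x gradient accumulation steps
steps_per_epoch = train_examples // effective_batch
total_steps = steps_per_epoch * 10  # 10 epochs

print(f"situations: train={n_train}, val={n_val}, test={n_test}")
print(f"training examples={train_examples}, optimizer steps={total_steps}")
```

If the loader skips situations, the same arithmetic applies with a smaller `n_total`.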

### Memory Optimizations for 8B Model

- **Gradient checkpointing**: Trade compute for memory
- **fp16 mixed precision**: Half precision for forward/backward
- **Batch size 1 + gradient accumulation 8**: Effective batch size 8
- **Sequence length 128**: Sufficient for most prompts
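
A back-of-the-envelope estimate helps explain why these settings matter when all parameters are trained. The sketch below counts only weights, gradients, and AdamW's two moment tensors; activations and framework overhead are ignored, and the byte counts per parameter are assumptions that depend on the dtype the model actually loads in. These are estimates, not measurements from this notebook.

```python
def training_memory_gib(n_params, weight_bytes, grad_bytes, optimizer_bytes):
    """Rough lower bound: weights + gradients + optimizer state, no activations."""
    total_bytes = n_params * (weight_bytes + grad_bytes + optimizer_bytes)
    return total_bytes / 1024**3

n_params = 8_000_000_000  # nominal 8B parameters

# Assumption: 16-bit weights and gradients, two 16-bit AdamW moment tensors
half_precision = training_memory_gib(n_params, 2, 2, 2 * 2)
# Assumption: 32-bit weights and gradients, two 32-bit AdamW moment tensors
full_precision = training_memory_gib(n_params, 4, 4, 2 * 4)

print(f"16-bit state: {half_precision:.1f} GiB, 32-bit state: {full_precision:.1f} GiB")
```

Under these assumptions the parameter-related state alone is a large fraction of an A100's memory, which is why the notebook keeps the batch size at 1 and the sequence length short.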

### Expected Behavior

With 8B parameters and 1K low-stakes situations, we expect:
- **Risk-averse option scores**: High values (~0.6-0.9 in sigmoid space)
- **Risk-neutral option scores**: Low values (~0.1-0.4 in sigmoid space)
- **Score distributions**: Clearly separated (green vs red in histogram)
- **Risk preference scatter**: Points in green zone (above diagonal)
- **Accuracy**: >0.6 (clearly better than random)
- **Risk-averse preference rate**: >0.6 (model shows consistent preference)

### Why Low Stakes?

Low-stakes gambles ($0-$100 range) may be easier to learn because:
- Simpler numerical ranges for the model to understand
- More consistent risk patterns
- Less extreme utility differences

### Why Qwen3-8B?

Qwen3-8B offers:
- Latest Qwen architecture with improved capabilities
- Strong instruction-following and reasoning
- Efficient training and inference
- Open weights for research use

In [None]:
class RiskAverseRewardModel(nn.Module):
    """Reward model for scoring risk-averse behavior with pure single-input classification"""
    
    def __init__(self, model_name="Qwen/Qwen3-8B"):
        super().__init__()
        # Optimized for CUDA with automatic device mapping (fp16 handled by Trainer)
        load_kwargs = {
            "num_labels": 1,
            "device_map": "auto",
            "low_cpu_mem_usage": True,
            "torch_dtype": "auto",  # Let it choose based on hardware
        }

        print(f"Loading {model_name}...")
        self.backbone = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            **load_kwargs
        )
        
        # Enable gradient checkpointing for memory efficiency with 8B model
        self.backbone.gradient_checkpointing_enable()
        print("✓ Gradient checkpointing enabled for memory efficiency")
        
    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Simple single-input forward pass
        
        During training: labels are provided (1.0 for risk-averse, 0.0 for risk-neutral)
        During evaluation: labels are None, just return logits
        """
        # Get model device from backbone parameters
        device = next(self.backbone.parameters()).device
        
        # Ensure all tensors are on the correct device
        if input_ids is not None:
            input_ids = input_ids.to(device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        if labels is not None:
            labels = labels.to(device)
        
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        logits = outputs.logits.squeeze(-1)
        
        if labels is not None:
            # Binary cross-entropy loss for single-input classification
            # Teaches: risk-averse options should score high (1.0), risk-neutral low (0.0)
            loss_fn = nn.BCEWithLogitsLoss()
            loss = loss_fn(logits, labels)
            
            # Occasional debug output during training (roughly 1% of forward passes)
            if self.training and torch.rand(1).item() < 0.01:  # 1% chance
                pred_probs = torch.sigmoid(logits).mean().item()
                print(f"[TRAIN] Avg logit: {logits.mean().item():.3f}, "
                      f"Avg sigmoid: {pred_probs:.3f}, Target avg: {labels.mean().item():.3f}")
            
            return {"loss": loss, "logits": logits}
        
        return {"logits": logits}

print("✓ RiskAverseRewardModel defined")

## 6. Evaluation Function

Scores both options of every test situation and reports two metrics: per-option classification accuracy (sigmoid threshold 0.5) and the pairwise risk-averse preference rate.
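
The two metrics measure different things, and a small pure-Python sketch makes the distinction visible. The logits below are hypothetical, and `summarize_scores` is an illustrative helper that mirrors the counting logic of the evaluation rather than replacing it.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def summarize_scores(risk_averse_logits, risk_neutral_logits):
    """Return (per-option accuracy, pairwise preference rate) from raw logits."""
    correct = 0
    wins = 0
    for ra, rn in zip(risk_averse_logits, risk_neutral_logits):
        correct += sigmoid(ra) > 0.5   # risk-averse option should score above 0.5
        correct += sigmoid(rn) <= 0.5  # risk-neutral option should score at or below 0.5
        wins += ra > rn                # pairwise: risk-averse logit is the higher one
    n = len(risk_averse_logits)
    return correct / (2 * n), wins / n

# Hypothetical logits for three test situations
accuracy, preference_rate = summarize_scores([1.2, 0.3, -0.5], [-0.8, 0.9, -1.5])
print(f"accuracy={accuracy:.3f}, preference rate={preference_rate:.3f}")
```

In the third situation the risk-averse option is misclassified (its logit is negative) but still beats the risk-neutral option, so it counts toward the preference rate while lowering accuracy.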

In [None]:
def evaluate_model(model, tokenizer, test_df: pd.DataFrame, return_detailed=False):
    """Evaluate the trained model on test situations"""
    print(f"Evaluating model on {len(test_df)} test situations...")
    
    model.eval()
    correct_predictions = 0
    total_predictions = 0
    risk_averse_wins = 0
    situations_processed = 0
    
    # Store detailed results for plotting
    results = {
        'risk_averse_scores': [],
        'risk_neutral_scores': [],
        'predictions': [],
        'expected': [],
        'situation_ids': [],
    }
    
    device = next(model.parameters()).device
    
    with torch.no_grad():
        for _, row in test_df.iterrows():
            # Test risk-averse option (correct_label)
            risk_averse_text = f"{row['prompt_text']}\n\nChosen option: {row['correct_label']}"
            risk_averse_encoding = tokenizer(
                risk_averse_text,
                truncation=True,
                padding='max_length',
                max_length=128,
                return_tensors='pt'
            )
            risk_averse_encoding = {k: v.to(device) for k, v in risk_averse_encoding.items()}
            risk_averse_output = model(
                input_ids=risk_averse_encoding['input_ids'], 
                attention_mask=risk_averse_encoding['attention_mask']
            )
            risk_averse_score = risk_averse_output["logits"].item()
            
            # Test risk-neutral option (incorrect_label)
            risk_neutral_text = f"{row['prompt_text']}\n\nChosen option: {row['incorrect_label']}"
            risk_neutral_encoding = tokenizer(
                risk_neutral_text,
                truncation=True,
                padding='max_length',
                max_length=128,
                return_tensors='pt'
            )
            risk_neutral_encoding = {k: v.to(device) for k, v in risk_neutral_encoding.items()}
            risk_neutral_output = model(
                input_ids=risk_neutral_encoding['input_ids'], 
                attention_mask=risk_neutral_encoding['attention_mask']
            )
            risk_neutral_score = risk_neutral_output["logits"].item()
            
            # Binary classification accuracy (each option independently)
            risk_averse_pred = torch.sigmoid(torch.tensor(risk_averse_score)).item()
            risk_neutral_pred = torch.sigmoid(torch.tensor(risk_neutral_score)).item()
            
            if risk_averse_pred > 0.5:  # Correctly classified as risk-averse
                correct_predictions += 1
            if risk_neutral_pred <= 0.5:  # Correctly classified as risk-neutral
                correct_predictions += 1
            total_predictions += 2
            
            # Check if risk-averse option scores higher (preference)
            if risk_averse_score > risk_neutral_score:
                risk_averse_wins += 1
            
            # Store results for plotting
            results['risk_averse_scores'].append(risk_averse_score)
            results['risk_neutral_scores'].append(risk_neutral_score)
            results['predictions'].extend([risk_averse_pred, risk_neutral_pred])
            results['expected'].extend([1.0, 0.0])
            results['situation_ids'].append(row['situation_id'])
            
            # Print progress every 50 situations
            situations_processed += 1
            if situations_processed % 50 == 0:
                current_acc = correct_predictions / total_predictions if total_predictions > 0 else 0
                current_pref = risk_averse_wins / situations_processed
                print(f"  Progress: {situations_processed}/{len(test_df)} situations | "
                      f"Accuracy: {current_acc:.3f} | Risk-averse preference: {current_pref:.3f}")
    
    accuracy = correct_predictions / total_predictions
    risk_averse_preference_rate = risk_averse_wins / len(test_df)
    
    print(f"\n✓ Evaluation complete:")
    print(f"  Model accuracy: {accuracy:.3f}")
    print(f"  Risk-averse preference rate: {risk_averse_preference_rate:.3f}")
    
    if return_detailed:
        results['risk_averse_preference_rate'] = risk_averse_preference_rate
        return accuracy, results
    return accuracy

print("✓ Evaluation function defined")

## 7. Plotting Functions

Comprehensive 4-panel visualization showing training progress, score distributions, risk preferences, and performance summary.

In [None]:
def plot_training_loss(trainer, ax):
    """Plot training and validation loss over time"""
    log_history = trainer.state.log_history
    
    train_steps = []
    train_losses = []
    eval_steps = []
    eval_losses = []
    
    for log_entry in log_history:
        if 'loss' in log_entry:
            train_steps.append(log_entry['step'])
            train_losses.append(log_entry['loss'])
        if 'eval_loss' in log_entry:
            eval_steps.append(log_entry['step'])
            eval_losses.append(log_entry['eval_loss'])
    
    ax.plot(train_steps, train_losses, label='Training Loss', linewidth=2, marker='o', markersize=4)
    if eval_losses:
        ax.plot(eval_steps, eval_losses, label='Validation Loss', linewidth=2, marker='s', markersize=4)
    
    ax.set_xlabel('Training Steps')
    ax.set_ylabel('Loss')
    ax.set_title('Training Progress')
    ax.legend()
    ax.grid(True, alpha=0.3)


def plot_score_distribution(eval_results, ax):
    """Plot distribution of model scores for risk-averse vs risk-neutral choices"""
    risk_averse_scores = eval_results['risk_averse_scores']
    risk_neutral_scores = eval_results['risk_neutral_scores']
    
    bins = np.linspace(min(min(risk_averse_scores), min(risk_neutral_scores)),
                       max(max(risk_averse_scores), max(risk_neutral_scores)), 21)
    ax.hist(risk_averse_scores, bins=bins, alpha=0.7, label='Risk-Averse Options', 
            color='green', density=True)
    ax.hist(risk_neutral_scores, bins=bins, alpha=0.7, label='Risk-Neutral Options', 
            color='red', density=True)
    
    ax.set_xlabel('Model Score (logits)')
    ax.set_ylabel('Density')
    ax.set_title('Score Distribution by Option Type')
    ax.legend()
    ax.grid(True, alpha=0.3)


def plot_risk_preference_comparison(eval_results, ax):
    """Plot comparison of scores for risk-averse vs risk-neutral options"""
    risk_averse_scores = np.array(eval_results['risk_averse_scores'])
    risk_neutral_scores = np.array(eval_results['risk_neutral_scores'])
    
    # Scatter plot
    ax.scatter(risk_neutral_scores, risk_averse_scores, alpha=0.6, s=50)
    
    # Add diagonal line (equal preference)
    min_val = min(risk_neutral_scores.min(), risk_averse_scores.min())
    max_val = max(risk_neutral_scores.max(), risk_averse_scores.max())
    ax.plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.8, label='Equal Preference')
    
    # Add preference regions
    ax.fill_between([min_val, max_val], [min_val, max_val], [max_val, max_val], 
                    alpha=0.2, color='green', label='Risk-Averse Preferred')
    ax.fill_between([min_val, max_val], [min_val, min_val], [min_val, max_val], 
                    alpha=0.2, color='red', label='Risk-Neutral Preferred')
    
    ax.set_xlabel('Risk-Neutral Option Score')
    ax.set_ylabel('Risk-Averse Option Score')
    ax.set_title('Risk Preference Comparison')
    ax.legend()
    ax.grid(True, alpha=0.3)


def plot_performance_summary(eval_results, accuracy, ax):
    """Plot model performance summary statistics"""
    risk_averse_scores = np.array(eval_results['risk_averse_scores'])
    risk_neutral_scores = np.array(eval_results['risk_neutral_scores'])
    
    correctly_prefers_risk_averse = np.mean(risk_averse_scores > risk_neutral_scores)
    avg_risk_averse_score = np.mean(risk_averse_scores)
    avg_risk_neutral_score = np.mean(risk_neutral_scores)
    score_difference = avg_risk_averse_score - avg_risk_neutral_score
    
    metrics = ['Overall\nAccuracy', 'Risk-Averse\nPreference Rate']
    values = [accuracy, correctly_prefers_risk_averse]
    colors = ['blue', 'green']
    
    bars = ax.bar(metrics, values, color=colors, alpha=0.7)
    
    for bar, value in zip(bars, values):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                f'{value:.3f}', ha='center', va='bottom', fontweight='bold')
    
    ax.set_ylabel('Score')
    ax.set_title('Model Performance Summary')
    ax.set_ylim(0, 1.1)
    ax.grid(True, alpha=0.3, axis='y')
    
    ax.text(0.5, 0.5, f'Avg Score Difference:\n{score_difference:+.3f}', 
            transform=ax.transAxes, ha='center', 
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8),
            fontweight='bold', fontsize=12)


def plot_results(trainer, eval_results, accuracy):
    """Create comprehensive plots of training and evaluation results"""
    os.makedirs("outputs", exist_ok=True)
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('Risk-Averse Reward Model Training Results', fontsize=16, fontweight='bold')
    
    plot_training_loss(trainer, axes[0, 0])
    plot_score_distribution(eval_results, axes[0, 1])
    plot_risk_preference_comparison(eval_results, axes[1, 0])
    plot_performance_summary(eval_results, accuracy, axes[1, 1])
    
    plt.tight_layout()
    
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"outputs/training_results_{timestamp}.png"
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    print(f"Plots saved to {filename}")
    
    try:
        plt.show()
    except Exception:
        print("Display not available - plots saved to file only")

print("✓ Plotting functions defined")

## 8. Training Function

Complete training pipeline with GPU optimizations and single-input binary cross-entropy loss.

In [None]:
def train_reward_model(dataset_df: pd.DataFrame, model_name="Qwen/Qwen3-8B"):
    """Train the risk-averse reward model with pure single-input classification"""
    print(f"Training reward model with {len(dataset_df)} situations using single-input classification...")
    print(f"Model: {model_name}")
    print(f"Training scale: ALL situations, 10 epochs")
    
    # Initialize tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    model = RiskAverseRewardModel(model_name)
    
    # Set padding token in model config (required for batch sizes > 1)
    model.backbone.config.pad_token_id = tokenizer.pad_token_id
    
    # Split data
    train_df, val_df = train_test_split(dataset_df, test_size=0.2, random_state=42)
    
    # Create single-input datasets
    train_dataset = SingleInputDataset(train_df, tokenizer, max_length=128)
    val_dataset = SingleInputDataset(val_df, tokenizer, max_length=128)
    
    print(f"Training: {len(train_df)} situations → {len(train_dataset)} examples (2 per situation)")
    print(f"Validation: {len(val_df)} situations → {len(val_dataset)} examples (2 per situation)")
    
    # Custom data collator that handles labels
    def collate_fn(features):
        input_ids = torch.stack([f['input_ids'] for f in features])
        attention_mask = torch.stack([f['attention_mask'] for f in features])
        labels = torch.tensor([f['labels'] for f in features], dtype=torch.float32)
        
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }
    
    # Training arguments optimized for 8B model on A100
    training_args = TrainingArguments(
        output_dir="./risk_averse_model",
        num_train_epochs=10,
        per_device_train_batch_size=1,  # Small batch for 8B model
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=8,  # Effective batch size = 8
        warmup_steps=100,  # Fewer warmup steps for smaller dataset
        weight_decay=0.01,
        learning_rate=2e-5,  # Standard for large models
        logging_dir="./logs",
        logging_steps=50,  # More frequent logging for smaller dataset
        report_to="none",
        eval_strategy="steps",
        eval_steps=250,  # More frequent eval for smaller dataset
        save_steps=250,
        save_total_limit=3,  # Keep only the 3 most recent checkpoints
        load_best_model_at_end=False,
        fp16=True,
        dataloader_pin_memory=True,
        dataloader_num_workers=2,
        remove_unused_columns=False,
        optim="adamw_torch_fused",
        prediction_loss_only=True,
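        # Gradient checkpointing is already enabled on the backbone in
        # RiskAverseRewardModel.__init__. The flag is left unset here because the
        # wrapper nn.Module has no gradient_checkpointing_enable() method for
        # the Trainer to call.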
    )
    
    # Standard Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=collate_fn,
    )
    
    print("✓ Model setup complete")
    
    # Train
    print("\nStarting training:")
    print(f"  - {len(train_dataset)} training examples")
    print(f"  - 10 epochs")
    print(f"  - Effective batch size: 8 (1 × 8 gradient accumulation)")
    print(f"  - Estimated steps: ~{len(train_dataset) * 10 // 8}")
    print("\nWatch for [TRAIN] debug outputs showing model learning progress")
    print("Estimated time: 30-60 minutes on A100 GPU...\n")
    
    trainer.train()
    
    return model, tokenizer, trainer

print("✓ Training function defined")

## 9. Main Experiment Function

Orchestrates the complete experiment from data loading to visualization.

In [None]:
def run_experiment():
    """Run the complete risk aversion experiment with low-stakes training data"""
    print("=== Risk-Averse Reward Model Experiment (Low Stakes) ===")
    print(f"PyTorch device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")
    print(f"Model: Qwen3-8B (8 billion parameters)")
    print(f"Dataset: 11_7_low_stakes_training_set.csv")
    
    # Load data from CSV
    print("\n1. Loading risk scenario data from CSV...")
    loader = RiskAversionDataLoader()
    full_dataset_df = loader.load_and_process_data()
    
    # Use ALL situations (no limiting!)
    dataset_df = full_dataset_df
    print(f"Using ALL {len(dataset_df)} situations for training (no limit)")
    
    # Split into train/test
    train_df, test_df = train_test_split(dataset_df, test_size=0.2, random_state=42)
    print(f"Train: {len(train_df)} situations ({len(train_df)*2} examples)")
    print(f"Test: {len(test_df)} situations ({len(test_df)*2} examples)")
    
    # Train model
    print(f"\n2. Training reward model with Qwen3-8B...")
    print("⚠️  This will take approximately 30-60 minutes on A100 GPU")
    model, tokenizer, trainer = train_reward_model(train_df)
    
    # Evaluate
    print(f"\n3. Evaluating model...")
    accuracy, eval_results = evaluate_model(model, tokenizer, test_df, return_detailed=True)
    
    # Plot results
    print(f"\n4. Creating visualizations...")
    plot_results(trainer, eval_results, accuracy)
    
    # Save results
    os.makedirs("outputs", exist_ok=True)
    results = {
        "num_training_situations": len(train_df),
        "num_test_situations": len(test_df),
        "final_accuracy": accuracy,
        "risk_averse_preference_rate": eval_results['risk_averse_preference_rate'],
        "model_name": "Qwen/Qwen3-8B",
        "training_epochs": 10,
        "dataset": "11_7_low_stakes_training_set.csv",
        "timestamp": datetime.now().isoformat()
    }
    
    with open("outputs/experiment_results.json", "w") as f:
        json.dump(results, f, indent=2)
    
    print(f"\n=== Experiment Complete ===")
    print(f"Results saved to outputs/experiment_results.json")
    print(f"Final accuracy: {accuracy:.3f}")
    print(f"Risk-averse preference rate: {eval_results['risk_averse_preference_rate']:.3f}")
    
    risk_averse_scores = np.array(eval_results['risk_averse_scores'])
    risk_neutral_scores = np.array(eval_results['risk_neutral_scores'])
    score_difference = np.mean(risk_averse_scores) - np.mean(risk_neutral_scores)
    print(f"Average score difference (risk-averse - risk-neutral): {score_difference:+.3f}")
    
    return model, tokenizer, results

print("✓ Main experiment function defined")

## 10. Run the Experiment

Execute the complete training and evaluation pipeline.

**Before running:** Ensure you have uploaded `11_7_low_stakes_training_set.csv` to this Colab environment (either in a `data/` folder or in the root directory).
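
A quick pre-flight check can confirm the file is in one of the two expected locations before starting a long training run. This is a minimal sketch: `find_training_csv` is a hypothetical helper introduced here, and `RiskAversionDataLoader` may search for the file differently.

```python
from pathlib import Path

CSV_NAME = "11_7_low_stakes_training_set.csv"

def find_training_csv(name=CSV_NAME, search_dirs=("data", ".")):
    """Return the first existing path to `name` in `search_dirs`, else None.

    Hypothetical helper: the search order (data/ first, then the root)
    is an assumption, not taken from the notebook's loader.
    """
    for directory in search_dirs:
        candidate = Path(directory) / name
        if candidate.is_file():
            return candidate
    return None

csv_path = find_training_csv()
if csv_path is None:
    print(f"Missing {CSV_NAME}: upload it to data/ or the root directory.")
else:
    print(f"Found training data at {csv_path}")
```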

In [None]:
# Run the experiment
try:
    model, tokenizer, results = run_experiment()
    print("\n✓ Experiment completed successfully!")
except FileNotFoundError as e:
    print(f"\n✗ Error: {e}")
    print("Please upload '11_7_low_stakes_training_set.csv' to Colab (data/ folder or root directory).")
except Exception as e:
    print(f"\n✗ Experiment failed: {e}")
    import traceback
    traceback.print_exc()

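## 11. Inspect Sample Training Data (Optional)

Print two sample situations exactly as they are formatted for training, to check whether the risk-averse and risk-neutral inputs differ in anything beyond the option label. This cell does not run model inference.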
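
Each situation yields two training inputs: the prompt followed by `Chosen option:` and either the risk-averse label (target 1.0) or the risk-neutral label (target 0.0). A minimal sketch of that construction follows. `build_training_pair` and the placeholder prompt are hypothetical and introduced only for illustration; the notebook's own dataset class may build its inputs differently.

```python
def build_training_pair(prompt_text, correct_label, incorrect_label):
    """Return two (text, target) examples for one situation.

    Hypothetical helper mirroring the "Chosen option" format printed
    by the analysis cell; not the notebook's actual dataset code.
    """
    risk_averse_input = f"{prompt_text}\n\nChosen option: {correct_label}"
    risk_neutral_input = f"{prompt_text}\n\nChosen option: {incorrect_label}"
    return [(risk_averse_input, 1.0), (risk_neutral_input, 0.0)]

# Placeholder values, not taken from the real CSV.
pair = build_training_pair("Placeholder prompt text", "A", "B")
for text, target in pair:
    print(f"target={target}: {text!r}")
```

The two inputs differ only in the trailing option label, so the prompt text itself has to carry enough information for the model to tell the two choices apart.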

In [None]:
# Analyze sample training data to understand what the model sees
try:
    print("="*80)
    print("ANALYZING SAMPLE TRAINING DATA")
    print("="*80)
    
    loader = RiskAversionDataLoader()
    dataset_df = loader.load_and_process_data()
    
    # Show first example
    sample = dataset_df.iloc[0]
    
    print("\n" + "="*80)
    print("SAMPLE SITUATION #1")
    print("="*80)
    
    print("\nPROMPT TEXT:")
    print("-"*80)
    print(sample['prompt_text'])
    print("-"*80)
    
    print(f"\nRisk-averse choice (correct_label): {sample['correct_label']}")
    print(f"Risk-neutral choice (incorrect_label): {sample['incorrect_label']}")
    
    # Show what the model actually receives during training
    print("\n" + "="*80)
    print("WHAT THE MODEL SEES DURING TRAINING")
    print("="*80)
    
    print("\nRISK-AVERSE EXAMPLE (label=1.0):")
    print("-"*80)
    risk_averse_input = f"{sample['prompt_text']}\n\nChosen option: {sample['correct_label']}"
    print(risk_averse_input)
    print("-"*80)
    
    print("\nRISK-NEUTRAL EXAMPLE (label=0.0):")
    print("-"*80)
    risk_neutral_input = f"{sample['prompt_text']}\n\nChosen option: {sample['incorrect_label']}"
    print(risk_neutral_input)
    print("-"*80)
    
    # Show a second example for comparison
    print("\n\n" + "="*80)
    print("SAMPLE SITUATION #2 (for comparison)")
    print("="*80)
    
    sample2 = dataset_df.iloc[1]
    print("\nPROMPT TEXT:")
    print("-"*80)
    print(sample2['prompt_text'])
    print("-"*80)
    print(f"\nRisk-averse choice: {sample2['correct_label']}")
    print(f"Risk-neutral choice: {sample2['incorrect_label']}")
    
    # Analysis
    print("\n\n" + "="*80)
    print("ANALYSIS: CAN THE MODEL LEARN FROM THIS?")
    print("="*80)
    
    print("\nQUESTIONS TO ASK:")
    print("1. Does the prompt contain information about probabilities or utilities?")
    print("2. Are the options distinguishable by anything other than their letter?")
    print("3. Is there ANY textual difference between risk-averse and risk-neutral options?")
    print("4. Or is the model being asked to memorize 'A is always risk-averse'?")
    
    print("\nIf the options are just letters (A, B, C) with no context about what")
    print("they mean, the model has NOTHING to learn from. It would be like asking")
    print("the model to learn 'always prefer A over B' without any reason why.")
    
except Exception as e:
    print(f"Analysis failed: {e}")
    import traceback
    traceback.print_exc()

## 12. Download Results (Optional)

Download the results and plots from Colab to your local machine.
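
If several plots have accumulated, bundling the whole `outputs/` directory into one archive may be more convenient than downloading files one at a time. This is a minimal sketch using only the standard library. `archive_outputs` and the archive name `experiment_outputs` are hypothetical choices, not part of the notebook.

```python
import os
import shutil

def archive_outputs(output_dir="outputs", archive_base="experiment_outputs"):
    """Zip `output_dir` and return the archive path, or None if it is missing.

    Hypothetical helper; the default archive name is an assumption.
    """
    if not os.path.isdir(output_dir):
        return None
    # make_archive appends ".zip" to archive_base and returns the full file name
    return shutil.make_archive(archive_base, "zip", root_dir=output_dir)

archive_path = archive_outputs()
print(archive_path or "No outputs/ directory found; run the experiment first.")
```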

In [None]:
# Download results (optional)
try:
    from google.colab import files
    
    # Download experiment results JSON
    if os.path.exists("outputs/experiment_results.json"):
        files.download("outputs/experiment_results.json")
    
    # Download the latest plot
    import glob
    plot_files = glob.glob("outputs/training_results_*.png")
    if plot_files:
        # Use modification time: on Linux, getctime reports metadata changes, not creation
        latest_plot = max(plot_files, key=os.path.getmtime)
        files.download(latest_plot)
        print(f"Downloaded: {latest_plot}")
except ImportError:
    print("Not running in Colab; skipping download. Files remain in outputs/.")
except Exception as e:
    print(f"Download failed: {e}")