# Risk-Averse Reward Model Training

This notebook implements the complete risk aversion experiment for training TinyLlama 1.1B to prefer risk-averse choices over risk-neutral ones.

## Features
- **CSV Data Loading**: Loads scenarios from `strict_disagreements_10k_with_prompts_and_bad_formats.csv`
- **Mixed Training**: Combines pairwise ranking (relative scoring) + single-input classification (absolute scoring)
- **GPU Optimizations**: fp16 mixed precision, fused AdamW, device mapping
- **Comprehensive Visualization**: 4-panel plots showing training progress and results
- **Advanced Metrics**: Risk preference rate, score distributions, bad variation handling

## Requirements
- Google Colab with GPU enabled (Runtime → Change runtime type → GPU)
- T4 GPU recommended (15GB VRAM)
- Upload `strict_disagreements_10k_with_prompts_and_bad_formats.csv` to Colab

**Memory Optimized:** Uses batch_size=1 and sequence_length=128 for T4 GPU compatibility

## Expected Output
- Training plots saved to `outputs/training_results_YYYYMMDD_HHMMSS.png`
- Results JSON saved to `outputs/experiment_results.json`
- Model saved to `risk_averse_model/`

## 1. Install Dependencies

In [None]:
# Install required packages.
# NOTE: the `!` prefix is IPython/Colab shell syntax — this cell is not plain Python.
# The confirmation print below runs unconditionally, even if pip reported errors above.
!pip install -q transformers datasets accelerate torch pandas numpy scikit-learn matplotlib seaborn

print("✓ Dependencies installed successfully!")

## 2. Import Libraries

In [None]:
import os
import sys
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer, 
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding
)
from sklearn.model_selection import train_test_split
import json
from typing import List, Dict, Tuple
import warnings
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

# Suppress every Python warning for the rest of the session (keeps notebook output
# readable, but also hides potentially useful library warnings while debugging).
warnings.filterwarnings('ignore')

# Set plotting style shared by all figures below
plt.style.use('default')
sns.set_palette("husl")

# Environment sanity check: report the PyTorch version and whether a GPU is visible.
print("✓ Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 3. Data Loading Class

Loads risk scenarios from CSV file with proper prompt modification and grouping by situation_id.

In [None]:
class RiskAversionDataLoader:
    """Load risk-aversion scenarios from a CSV file and collapse them per situation.

    The CSV must contain ``situation_id``, ``prompt_text``, ``correct_label``
    (risk-averse choice) and ``incorrect_label`` (risk-neutral choice). The
    optional ``bad_correct_answers`` / ``bad_incorrect_answers`` columns hold
    comma-separated alternative spellings of each answer; they are merged and
    de-duplicated across all rows that share a ``situation_id``.
    """

    # Sentence in the source prompts that invites free-form reasoning. It is
    # replaced by an output-only instruction because the reward model scores a
    # bare answer label appended to the prompt.
    _THINKING_INSTRUCTION = 'You can think about which you\'d choose and then output your answer.'
    _OUTPUT_ONLY_INSTRUCTION = 'Output the label of your chosen option only.'

    def __init__(self, csv_file_path="strict_disagreements_10k_with_prompts_and_bad_formats.csv"):
        self.csv_file_path = csv_file_path

    @staticmethod
    def _split_variations(value) -> List[str]:
        """Split a comma-separated cell into stripped, non-empty entries.

        Missing cells (NaN/None) yield ``[]``. Non-string cells (pandas infers a
        numeric dtype when a column only holds numbers) are converted with
        ``str`` first instead of crashing on ``.strip()``.
        """
        if value is None or pd.isna(value):
            return []
        return [part.strip() for part in str(value).split(',') if part.strip()]

    def _build_situation(self, situation_id, group: pd.DataFrame) -> Dict:
        """Collapse all CSV rows of one situation into a single record."""
        # Situation-level fields (prompt, labels) are taken from the first row.
        first_row = group.iloc[0]

        prompt_text = first_row['prompt_text']
        if isinstance(prompt_text, str):
            prompt_text = prompt_text.replace(self._THINKING_INSTRUCTION, self._OUTPUT_ONLY_INSTRUCTION)

        # Gather bad answer variations from every row of the situation.
        bad_correct: List[str] = []
        bad_incorrect: List[str] = []
        for _, row in group.iterrows():
            bad_correct.extend(self._split_variations(row.get('bad_correct_answers', '')))
            bad_incorrect.extend(self._split_variations(row.get('bad_incorrect_answers', '')))

        return {
            'situation_id': situation_id,
            'prompt_text': prompt_text,
            'correct_label': first_row['correct_label'],
            'incorrect_label': first_row['incorrect_label'],
            # dict.fromkeys de-duplicates while keeping first-seen order, so the
            # joined string is reproducible (set order for str changes between
            # interpreter runs because of hash randomization).
            'bad_correct_answers': ','.join(dict.fromkeys(bad_correct)),
            'bad_incorrect_answers': ','.join(dict.fromkeys(bad_incorrect)),
            # Number of CSV rows that belonged to this situation.
            'num_options': len(group),
        }

    def load_and_process_data(self) -> pd.DataFrame:
        """Load the CSV and return one row per ``situation_id``.

        Returns:
            DataFrame with columns ``situation_id``, ``prompt_text``,
            ``correct_label``, ``incorrect_label``, ``bad_correct_answers``,
            ``bad_incorrect_answers`` (comma-joined, de-duplicated, first-seen
            order) and ``num_options``.

        Raises:
            FileNotFoundError: if the CSV file does not exist.
            ValueError: if a required column is missing.
        """
        if not os.path.exists(self.csv_file_path):
            raise FileNotFoundError(
                f"Required data file '{self.csv_file_path}' not found. "
                f"Please ensure the CSV file is uploaded to Colab."
            )

        df = pd.read_csv(self.csv_file_path)
        print(f"Loaded {len(df)} rows from {self.csv_file_path}")

        required_columns = ['situation_id', 'prompt_text', 'correct_label', 'incorrect_label']
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            raise ValueError(
                f"Missing required columns in CSV: {missing_columns}. "
                f"Available columns: {list(df.columns)}"
            )

        situations = [
            self._build_situation(situation_id, group)
            for situation_id, group in df.groupby('situation_id')
        ]

        result_df = pd.DataFrame(situations)
        print(f"Processed into {len(result_df)} unique situations")

        # Display sample data
        if len(result_df) > 0:
            sample = result_df.iloc[0]
            print(f"\nSample situation:")
            print(f"Prompt: {sample['prompt_text'][:200]}...")
            print(f"Risk-averse choice: {sample['correct_label']}")
            print(f"Risk-neutral choice: {sample['incorrect_label']}")

        return result_df

print("✓ RiskAversionDataLoader defined")

## 4. Dataset Classes

PyTorch datasets for both pairwise ranking training and single-input evaluation.

In [None]:
class PairwiseRiskAversionDataset(Dataset):
    """Yield one tokenized (risk-averse, risk-neutral) pair per situation row.

    Each item encodes the prompt followed by ``Chosen option: <label>`` twice:
    once with ``correct_label`` (risk-averse) and once with
    ``incorrect_label`` (risk-neutral). Both encodings are padded/truncated to
    ``max_length`` tokens and returned as 1-D tensors for a ranking loss.
    """

    def __init__(self, dataframe: pd.DataFrame, tokenizer, max_length=128):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def _encode_choice(self, prompt, option):
        """Tokenize the prompt plus a chosen option; return flattened (ids, mask)."""
        enc = self.tokenizer(
            f"{prompt}\n\nChosen option: {option}",
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )
        return enc['input_ids'].flatten(), enc['attention_mask'].flatten()

    def __getitem__(self, idx):
        record = self.data.iloc[idx]
        prompt = record['prompt_text']

        # Risk-averse choice is encoded first, then the risk-neutral one.
        averse_ids, averse_mask = self._encode_choice(prompt, record['correct_label'])
        neutral_ids, neutral_mask = self._encode_choice(prompt, record['incorrect_label'])

        return {
            'risk_averse_input_ids': averse_ids,
            'risk_averse_attention_mask': averse_mask,
            'risk_neutral_input_ids': neutral_ids,
            'risk_neutral_attention_mask': neutral_mask,
            'situation_id': record['situation_id'],
        }


class MixedTrainingDataset(Dataset):
    """Dataset mixing pairwise and single-input examples for reward-model training.

    Every situation (row of ``dataframe``) expands into three consecutive
    examples:

    * index ``3k``: ``'pairwise'`` — risk-averse vs risk-neutral option;
    * index ``3k + 1``: ``'single'`` — risk-averse option, ``labels=1.0``;
    * index ``3k + 2``: ``'single'`` — risk-neutral option, ``labels=0.0``.

    One third of the examples are therefore pairwise (relative scoring) and two
    thirds single-input (absolute scoring).

    Args:
        dataframe: situations with ``situation_id``, ``prompt_text``,
            ``correct_label`` and ``incorrect_label`` columns.
        tokenizer: tokenizer called with ``truncation``/``padding``/
            ``max_length``/``return_tensors='pt'`` keyword arguments.
        max_length: every text is truncated/padded to this many tokens.
        single_input_ratio: currently NOT used — the 1 pairwise : 2 single split
            above is fixed. Kept only for backward compatibility.
    """

    def __init__(self, dataframe: pd.DataFrame, tokenizer, max_length=128, single_input_ratio=0.3):
        # Sequential 0-based index so positional lookups work after train_test_split.
        self.data = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Stored but not consulted anywhere (see class docstring).
        self.single_input_ratio = single_input_ratio

        # Read the ids column once instead of materializing a row Series per situation.
        situation_ids = self.data['situation_id'].tolist() if len(self.data) else []
        self.examples = []
        for idx, situation_id in enumerate(situation_ids):
            self.examples.append({
                'type': 'pairwise',
                'situation_idx': idx,
                'situation_id': situation_id,
            })
            for is_risk_averse in (True, False):
                self.examples.append({
                    'type': 'single',
                    'situation_idx': idx,
                    'is_risk_averse': is_risk_averse,
                    'situation_id': situation_id,
                })

    def __len__(self):
        return len(self.examples)

    def _encode(self, prompt_text, option):
        """Tokenize the prompt followed by ``Chosen option: <option>``; return flat (ids, mask)."""
        encoding = self.tokenizer(
            f"{prompt_text}\n\nChosen option: {option}",
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )
        return encoding['input_ids'].flatten(), encoding['attention_mask'].flatten()

    def __getitem__(self, idx):
        example_info = self.examples[idx]
        row = self.data.iloc[example_info['situation_idx']]

        if example_info['type'] == 'pairwise':
            averse_ids, averse_mask = self._encode(row['prompt_text'], row['correct_label'])
            neutral_ids, neutral_mask = self._encode(row['prompt_text'], row['incorrect_label'])
            return {
                'mode': 'pairwise',
                'risk_averse_input_ids': averse_ids,
                'risk_averse_attention_mask': averse_mask,
                'risk_neutral_input_ids': neutral_ids,
                'risk_neutral_attention_mask': neutral_mask,
                'situation_id': row['situation_id'],
            }

        # Single-input example: absolute target 1.0 for risk-averse, 0.0 for risk-neutral.
        is_risk_averse = example_info['is_risk_averse']
        option = row['correct_label'] if is_risk_averse else row['incorrect_label']
        input_ids, attention_mask = self._encode(row['prompt_text'], option)
        return {
            'mode': 'single',
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': 1.0 if is_risk_averse else 0.0,
            'situation_id': row['situation_id'],
        }


class MixedDataCollator:
    """Batch collator for homogeneous batches of mixed-training features.

    All features in a batch must share one ``'mode'``: ``'pairwise'`` batches
    yield stacked ``long`` tensors for both options; ``'single'`` batches yield
    stacked ``long`` ids/masks plus a ``float`` ``labels`` tensor. The returned
    dict always starts with a ``'mode'`` entry; other fields (e.g.
    ``situation_id``) are dropped.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer  # kept for interface parity; unused when collating

    def __call__(self, features):
        if not features:
            raise ValueError("Empty batch received in MixedDataCollator")

        modes = [feature.get('mode', 'unknown') for feature in features]
        if len(set(modes)) > 1:
            raise ValueError(f"Batch contains mixed modes: {modes}. All examples in a batch must have the same mode.")
        mode = modes[0]

        if mode == 'pairwise':
            batch_kind = 'pairwise'
            tensor_keys = ['risk_averse_input_ids', 'risk_averse_attention_mask',
                           'risk_neutral_input_ids', 'risk_neutral_attention_mask']
        elif mode == 'single':
            batch_kind = 'single-input'
            tensor_keys = ['input_ids', 'attention_mask', 'labels']
        else:
            raise ValueError(f"Unknown mode: {mode}. Expected 'pairwise' or 'single'.")

        # Fail fast with a precise message if any feature lacks a required key.
        for position, feature in enumerate(features):
            absent = [key for key in tensor_keys if key not in feature]
            if absent:
                raise KeyError(f"Feature {position} in {batch_kind} batch missing keys: {absent}")

        batch = {'mode': mode}
        for key in tensor_keys:
            column = [feature[key] for feature in features]
            # Labels are scalars per feature; everything else is a 1-D token tensor.
            batch[key] = torch.tensor(column).float() if key == 'labels' else torch.stack(column).long()
        return batch


# Keep PairwiseDataCollator for compatibility with evaluation
class PairwiseDataCollator:
    """Stack pre-padded pairwise features into a batch of ``long`` tensors.

    Only the four risk-averse / risk-neutral id and mask fields are kept; any
    other keys on the features are ignored.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer  # kept for interface parity; unused when collating

    def __call__(self, features):
        field_names = (
            'risk_averse_input_ids',
            'risk_averse_attention_mask',
            'risk_neutral_input_ids',
            'risk_neutral_attention_mask',
        )
        # Gather every column first so a missing key fails before any stacking happens.
        columns = {name: [feature[name] for feature in features] for name in field_names}
        return {name: torch.stack(values).long() for name, values in columns.items()}

print("✓ Dataset classes defined")

## 5. Reward Model

Risk-averse reward model with **mixed training approach** combining pairwise ranking and single-input classification.

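**Note:** The model weights are loaded without an explicit dtype. When the Trainer is configured for fp16 mixed precision, it keeps full-precision master weights and runs the forward pass in half precision via autocast. The weights themselves are not converted to fp16.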

### Training Approach: Mixed Mode

This model uses **mixed training** to learn both relative and absolute scoring:

| Training Mode | Examples per Situation | Loss Function | Purpose |
|---------------|----------------------|---------------|---------|
| **Pairwise Ranking** | 1 pair | Hybrid loss (margin + Bradley-Terry + L2) | Learn relative scoring: r_A > r_B |
| **Single-Input Classification** | 2 examples | Binary cross-entropy | Learn absolute scoring: risk-averse=1, risk-neutral=0 |

**Total:** Each situation generates **3 training examples** (1 pairwise + 2 single-input)

#### Why Mixed Training?

**The Problem with Pure Pairwise Training:**
- Pairwise training teaches the model to compare options within the same context
- But evaluation uses single-input mode (one option at a time)
- Result: Model learns relative scoring but fails at absolute scoring
- Symptoms: All scores collapse to ~0, no differentiation

**The Solution:**
- **Pairwise mode**: Teaches "risk-averse options should score higher than risk-neutral"
- **Single-input mode**: Teaches "risk-averse options should get sigmoid(score) ≈ 1, risk-neutral ≈ 0" (i.e., positive vs. negative logits)
- Combined: Model learns both relative preferences AND absolute score meanings

### Loss Functions

#### 1. Pairwise Ranking Loss (Hybrid)

Used when training with option pairs:

| Component | Weight | Formula | Purpose |
|-----------|--------|---------|---------|
| **Margin Ranking Loss** | 1.0 | `max(0, margin - (r_A - r_B))` | Enforce hard separation: r_A - r_B ≥ 1.0 |
| **Bradley-Terry Loss** | 0.1 | `-log σ(r_A - r_B)` | Probabilistic preference: P(A>B) = σ(r_A - r_B) |
| **L2 Regularization** | 0.01 | `mean(r_A²) + mean(r_B²)` | Prevent score explosion |

**Total Pairwise Loss:** `total = 1.0 × margin_loss + 0.1 × bradley_terry + 0.01 × L2`

**Bradley-Terry Component:**
```python
# Bradley-Terry: P(A > B) = σ(r_A - r_B) = 1 / (1 + exp(r_B - r_A))
# Bradley-Terry loss: -log P(A > B) = -log σ(r_A - r_B)
sigmoid_loss = F.binary_cross_entropy_with_logits(score_diff, ones)
```

Where:
- `score_diff = risk_averse_scores - risk_neutral_scores`
- Target = 1 (risk-averse should be preferred)
- Taken on its own, this term is the standard Bradley-Terry (pairwise logistic) loss

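⚠️ **Note:** The overall pairwise loss is not pure Bradley-Terry. The margin term has 10× the weight of the Bradley-Terry term (1.0 vs 0.1). Their actual gradient contributions vary per example: the hinge contributes zero gradient once r_A - r_B ≥ 1.0. After that point, only the Bradley-Terry and L2 terms are active.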

#### 2. Single-Input Classification Loss

Used when training with individual options:

```python
# Binary classification with labels
loss = nn.BCEWithLogitsLoss()(score, label)
```

Where:
- Risk-averse options: `label = 1.0` → model learns to output high scores
- Risk-neutral options: `label = 0.0` → model learns to output low scores
- This teaches absolute score meanings, enabling single-input evaluation

### Training Data Distribution

For a dataset with N situations:
- **Pairwise examples**: N (one per situation)
- **Single-input examples**: 2N (two per situation: risk-averse + risk-neutral)
- **Total training examples**: 3N

Example batch composition (batch_size=1):
- 33% of batches: Pairwise ranking (comparing two options)
- 67% of batches: Single-input classification (scoring one option)

In [None]:
class RiskAverseRewardModel(nn.Module):
    """Scalar reward model scoring a prompt followed by a chosen option.

    Wraps ``AutoModelForSequenceClassification`` with ``num_labels=1`` and
    supports two modes, chosen via ``forward(mode=...)`` or inferred from the
    keyword arguments that are passed:

    * ``'single'``: score one text. With ``labels`` (1.0 risk-averse, 0.0
      risk-neutral) the loss is BCE-with-logits (absolute scoring).
    * ``'pairwise'``: score a risk-averse and a risk-neutral text. The loss is
      ``mean(relu(1 - diff)) + 0.1 * BCE(diff, 1) + 0.01 * (mean(r_A^2) + mean(r_B^2))``
      with ``diff = r_A - r_B`` (relative scoring).
    """

    # Print a debug line on every N-th training forward pass (~2% of calls).
    _DEBUG_EVERY = 50

    def __init__(self, model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
        super().__init__()
        # Automatic device placement; no torch_dtype is passed, so the checkpoint
        # loads in the library's default dtype and any mixed precision is applied
        # by the training loop.
        load_kwargs = {
            "num_labels": 1,
            "device_map": "auto",
            "low_cpu_mem_usage": True,
        }

        self.backbone = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            **load_kwargs
        )
        # NOTE(review): Llama-family checkpoints usually ship without a pad token,
        # and the sequence-classification head uses ``config.pad_token_id`` to find
        # the last real token. Confirm the caller sets it consistently with the
        # tokenizer's padding, or scores may be read from a padding position.
        self._debug_calls = 0

    def forward(self, mode=None, input_ids=None, attention_mask=None, labels=None,
                risk_averse_input_ids=None, risk_averse_attention_mask=None,
                risk_neutral_input_ids=None, risk_neutral_attention_mask=None):
        """Dispatch to the pairwise or single-input forward pass.

        When ``mode`` is None it is ``'pairwise'`` if both risk-averse and
        risk-neutral input ids are given, otherwise ``'single'``.
        """
        if mode is None:
            if risk_averse_input_ids is not None and risk_neutral_input_ids is not None:
                mode = 'pairwise'
            else:
                mode = 'single'

        if mode == 'pairwise':
            return self._forward_pairwise(
                risk_averse_input_ids, risk_averse_attention_mask,
                risk_neutral_input_ids, risk_neutral_attention_mask
            )
        return self._forward_single(input_ids, attention_mask, labels)

    def _should_debug(self) -> bool:
        """Return True on every ``_DEBUG_EVERY``-th training forward pass.

        A deterministic counter replaces ``torch.rand`` so debug logging does not
        draw from the global torch RNG, which would perturb anything else that
        samples from it (e.g. a DataLoader's shuffling) and hurt reproducibility.
        """
        if not self.training:
            return False
        self._debug_calls = getattr(self, '_debug_calls', 0) + 1
        return self._debug_calls % self._DEBUG_EVERY == 0

    @staticmethod
    def _to_device(tensor, device):
        """Move ``tensor`` to ``device``; ``None`` passes through unchanged."""
        return None if tensor is None else tensor.to(device)

    def _forward_single(self, input_ids, attention_mask, labels=None):
        """Score single inputs; add a BCE-with-logits loss when ``labels`` are given.

        During training, labels are 1.0 for risk-averse and 0.0 for risk-neutral
        options. During evaluation, labels are None and only logits are returned.
        """
        # Same device handling as the pairwise path, so callers may pass CPU tensors.
        device = next(self.backbone.parameters()).device
        outputs = self.backbone(
            input_ids=self._to_device(input_ids, device),
            attention_mask=self._to_device(attention_mask, device),
        )
        logits = outputs.logits.squeeze(-1)

        if labels is None:
            return {"logits": logits}

        # Match the logits' device and dtype regardless of where labels were built.
        labels = torch.as_tensor(labels, device=logits.device, dtype=logits.dtype)
        loss = nn.BCEWithLogitsLoss()(logits, labels)

        if self._should_debug():
            pred_probs = torch.sigmoid(logits).mean().item()
            print(f"[DEBUG SINGLE] Avg logit: {logits.mean().item():.3f}, "
                  f"Avg prob: {pred_probs:.3f}, Target avg: {labels.mean().item():.3f}")

        return {"loss": loss, "logits": logits}

    def _forward_pairwise(self, risk_averse_input_ids, risk_averse_attention_mask,
                          risk_neutral_input_ids, risk_neutral_attention_mask):
        """Score both options and compute the hybrid pairwise ranking loss.

        Teaches relative scoring: risk-averse options should score higher than
        risk-neutral ones by at least the margin.
        """
        device = next(self.backbone.parameters()).device

        risk_averse_scores = self.backbone(
            input_ids=self._to_device(risk_averse_input_ids, device),
            attention_mask=self._to_device(risk_averse_attention_mask, device),
        ).logits.squeeze(-1)
        risk_neutral_scores = self.backbone(
            input_ids=self._to_device(risk_neutral_input_ids, device),
            attention_mask=self._to_device(risk_neutral_attention_mask, device),
        ).logits.squeeze(-1)

        margin = 1.0
        score_diff = risk_averse_scores - risk_neutral_scores

        # Hinge on the margin + Bradley-Terry (logistic) term + L2 on raw scores
        # to keep them from drifting without bound.
        ranking_loss = torch.relu(margin - score_diff)
        score_regularization = 0.01 * (risk_averse_scores.pow(2).mean() + risk_neutral_scores.pow(2).mean())
        sigmoid_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            score_diff, torch.ones_like(score_diff)
        )
        total_loss = ranking_loss.mean() + 0.1 * sigmoid_loss + score_regularization

        if self._should_debug():
            print(f"[DEBUG] RA_avg: {risk_averse_scores.mean().item():.3f}, "
                  f"RN_avg: {risk_neutral_scores.mean().item():.3f}, "
                  f"Diff: {score_diff.mean().item():.3f}")

        return {
            "loss": total_loss,
            "risk_averse_scores": risk_averse_scores,
            "risk_neutral_scores": risk_neutral_scores,
            "score_difference": score_diff.mean()
        }

print("✓ RiskAverseRewardModel defined")

## 6. Evaluation Function

Comprehensive evaluation with bad variation handling and detailed metrics.

In [None]:
def _score_option(model, tokenizer, prompt_text, option, device, max_length):
    """Return the model's scalar score (logit) for the prompt followed by ``option``."""
    encoding = tokenizer(
        f"{prompt_text}\n\nChosen option: {option}",
        truncation=True,
        padding='max_length',
        max_length=max_length,
        return_tensors='pt'
    )
    encoding = {k: v.to(device) for k, v in encoding.items()}
    outputs = model(input_ids=encoding['input_ids'],
                    attention_mask=encoding['attention_mask'])
    return outputs["logits"].item()


def _split_bad_variations(bad_variations) -> List[str]:
    """Parse a comma-separated cell of alternative answer spellings; empty/NaN -> []."""
    if not bad_variations or not pd.notna(bad_variations):
        return []
    return [x.strip() for x in str(bad_variations).split(',') if x.strip()]


def evaluate_model(model, tokenizer, test_df: pd.DataFrame, return_detailed=False, max_length=128):
    """Score each situation's risk-averse and risk-neutral options and report metrics.

    For every option, the main answer and each listed "bad" variation
    (alternative spelling, excluding the main answer itself) are scored. The
    maximum score is kept as that option's score. Two metrics are reported:

    * accuracy: fraction of options whose ``sigmoid(score)`` lies on the
      expected side of 0.5 (risk-averse > 0.5, risk-neutral <= 0.5);
    * risk-averse preference rate: fraction of situations where the
      risk-averse score is strictly higher than the risk-neutral score.

    Args:
        model: called as ``model(input_ids=..., attention_mask=...)`` and must
            return a dict whose ``"logits"`` holds a single-element tensor.
        tokenizer: tokenizer used to encode prompt + option text.
        test_df: situations with ``situation_id``, ``prompt_text``,
            ``correct_label``, ``incorrect_label`` and optional
            ``bad_correct_answers`` / ``bad_incorrect_answers`` columns.
        return_detailed: also return per-situation scores and predictions.
        max_length: pad/truncate length for tokenization (default 128).

    Returns:
        ``accuracy``, or ``(accuracy, results)`` when ``return_detailed`` is set.
        Both rates are 0.0 for an empty ``test_df`` instead of dividing by zero.
    """
    print(f"Evaluating model on {len(test_df)} test situations...")

    # Remember the mode so evaluation does not leave a training model in eval mode.
    was_training = getattr(model, 'training', False)
    device = next(model.parameters()).device

    correct_predictions = 0
    total_predictions = 0
    risk_averse_wins = 0
    bad_variation_matches = 0

    # Store detailed results for plotting
    results = {
        'risk_averse_scores': [],
        'risk_neutral_scores': [],
        'predictions': [],
        'expected': [],
        'situation_ids': [],
    }

    options = (
        ("correct_label", True, "bad_correct_answers"),
        ("incorrect_label", False, "bad_incorrect_answers"),
    )

    model.eval()
    try:
        with torch.no_grad():
            for situations_processed, (_, row) in enumerate(test_df.iterrows(), start=1):
                best_scores = {}
                for label, is_correct, bad_label in options:
                    chosen_option = row[label]
                    main_score = _score_option(model, tokenizer, row['prompt_text'],
                                               chosen_option, device, max_length)
                    variation_scores = [
                        _score_option(model, tokenizer, row['prompt_text'], bad_answer, device, max_length)
                        for bad_answer in _split_bad_variations(row.get(bad_label, ''))
                        if bad_answer != chosen_option
                    ]

                    # Credit the option with its best-scoring spelling.
                    best_score = max([main_score] + variation_scores)
                    best_scores[is_correct] = best_score

                    prediction = torch.sigmoid(torch.tensor(best_score)).item()
                    if (prediction > 0.5) == is_correct:
                        correct_predictions += 1
                    total_predictions += 1

                    results['predictions'].append(prediction)
                    results['expected'].append(1.0 if is_correct else 0.0)

                    # NOTE: also counts a variation that merely ties the main score.
                    if variation_scores and best_score in variation_scores:
                        bad_variation_matches += 1

                risk_averse_score = best_scores[True]
                risk_neutral_score = best_scores[False]
                if risk_averse_score > risk_neutral_score:
                    risk_averse_wins += 1

                results['risk_averse_scores'].append(risk_averse_score)
                results['risk_neutral_scores'].append(risk_neutral_score)
                results['situation_ids'].append(row['situation_id'])

                if situations_processed % 25 == 0:
                    current_acc = correct_predictions / total_predictions if total_predictions > 0 else 0
                    current_pref = risk_averse_wins / situations_processed
                    print(f"  Progress: {situations_processed}/{len(test_df)} situations | "
                          f"Accuracy: {current_acc:.3f} | Risk-averse preference: {current_pref:.3f}")
    finally:
        if was_training:
            model.train()

    num_situations = len(test_df)
    accuracy = correct_predictions / total_predictions if total_predictions else 0.0
    risk_averse_preference_rate = risk_averse_wins / num_situations if num_situations else 0.0

    print(f"Model accuracy: {accuracy:.3f}")
    print(f"Risk-averse preference rate: {risk_averse_preference_rate:.3f}")
    print(f"Bad variation matches: {bad_variation_matches}/{total_predictions}")

    if return_detailed:
        results['risk_averse_preference_rate'] = risk_averse_preference_rate
        return accuracy, results
    return accuracy

print("✓ Evaluation function defined")

## 7. Plotting Functions

Comprehensive 4-panel visualization showing training progress, score distributions, risk preferences, and performance summary.

In [None]:
def plot_training_loss(trainer, ax):
    """Draw the loss curves recorded in ``trainer.state.log_history`` on ``ax``.

    Log entries carrying ``'loss'`` form the training curve and entries carrying
    ``'eval_loss'`` form the validation curve, which is drawn only when at least
    one such entry exists. Both curves are plotted against each entry's ``'step'``.
    """
    history = trainer.state.log_history

    train_pairs = [(entry['step'], entry['loss']) for entry in history if 'loss' in entry]
    eval_pairs = [(entry['step'], entry['eval_loss']) for entry in history if 'eval_loss' in entry]

    ax.plot([step for step, _ in train_pairs], [value for _, value in train_pairs],
            label='Training Loss', linewidth=2, marker='o', markersize=4)
    if eval_pairs:
        ax.plot([step for step, _ in eval_pairs], [value for _, value in eval_pairs],
                label='Validation Loss', linewidth=2, marker='s', markersize=4)

    ax.set_xlabel('Training Steps')
    ax.set_ylabel('Loss')
    ax.set_title('Training Progress')
    ax.legend()
    ax.grid(True, alpha=0.3)


def plot_score_distribution(eval_results, ax, num_bins=20):
    """Plot density histograms of model scores for risk-averse vs risk-neutral options.

    Both series are drawn over the same bin edges, so their densities are
    directly comparable.

    Args:
        eval_results: Mapping with ``'risk_averse_scores'`` and
            ``'risk_neutral_scores'`` sequences of per-option model scores.
        ax: Matplotlib Axes to draw on.
        num_bins: Number of histogram bins. The default of 20 gives 21 bin
            edges spanning the joint score range.
    """
    risk_averse_scores = np.asarray(eval_results['risk_averse_scores'], dtype=float)
    risk_neutral_scores = np.asarray(eval_results['risk_neutral_scores'], dtype=float)
    all_scores = np.concatenate([risk_averse_scores, risk_neutral_scores])

    if all_scores.size == 0:
        # min()/max() are undefined on empty input; show a note instead of raising.
        ax.text(0.5, 0.5, 'No scores to plot', transform=ax.transAxes,
                ha='center', va='center')
    else:
        lo, hi = float(all_scores.min()), float(all_scores.max())
        if lo == hi:
            # All scores identical: widen the range so bin edges strictly
            # increase. Zero-width bins would make density=True divide by zero.
            lo, hi = lo - 0.5, hi + 0.5
        bins = np.linspace(lo, hi, num_bins + 1)

        for scores, label, color in (
            (risk_averse_scores, 'Risk-Averse Options', 'green'),
            (risk_neutral_scores, 'Risk-Neutral Options', 'red'),
        ):
            # An empty series would only yield NaN densities; skip it.
            if scores.size:
                ax.hist(scores, bins=bins, alpha=0.7, label=label,
                        color=color, density=True)

    ax.set_xlabel('Model Score (logits)')
    ax.set_ylabel('Density')
    ax.set_title('Score Distribution by Option Type')
    ax.legend()
    ax.grid(True, alpha=0.3)


def plot_risk_preference_comparison(eval_results, ax):
    """Scatter each pair's risk-averse score against its risk-neutral score.

    The x-axis is the risk-neutral score and the y-axis is the risk-averse
    score. Points above the dashed y = x line (green region) are pairs where
    the risk-averse option scored higher. Points below it (red region) favour
    the risk-neutral option.
    """
    averse = np.array(eval_results['risk_averse_scores'])
    neutral = np.array(eval_results['risk_neutral_scores'])

    ax.scatter(neutral, averse, alpha=0.6, s=50)

    lo = min(neutral.min(), averse.min())
    hi = max(neutral.max(), averse.max())
    span = [lo, hi]

    # Reference diagonal: both options scored equally.
    ax.plot(span, span, 'r--', alpha=0.8, label='Equal Preference')

    # Shade the half-planes above and below the diagonal.
    ax.fill_between(span, span, [hi, hi],
                    alpha=0.2, color='green', label='Risk-Averse Preferred')
    ax.fill_between(span, [lo, lo], span,
                    alpha=0.2, color='red', label='Risk-Neutral Preferred')

    ax.set_xlabel('Risk-Neutral Option Score')
    ax.set_ylabel('Risk-Averse Option Score')
    ax.set_title('Risk Preference Comparison')
    ax.legend()
    ax.grid(True, alpha=0.3)


def plot_performance_summary(eval_results, accuracy, ax):
    """Bar-chart overall accuracy next to the pairwise risk-averse preference rate.

    The preference rate is the fraction of pairs whose risk-averse score is
    strictly greater than the paired risk-neutral score. Each bar is labelled
    with its value. The mean score gap (risk-averse minus risk-neutral) is
    shown in a text box at the centre of the axes.
    """
    averse = np.array(eval_results['risk_averse_scores'])
    neutral = np.array(eval_results['risk_neutral_scores'])

    preference_rate = np.mean(averse > neutral)
    mean_gap = np.mean(averse) - np.mean(neutral)

    labels = ['Overall\nAccuracy', 'Risk-Averse\nPreference Rate']
    heights = [accuracy, preference_rate]
    bar_artists = ax.bar(labels, heights, color=['blue', 'green'], alpha=0.7)

    # Print each value just above the top of its bar.
    for artist, bar_value in zip(bar_artists, heights):
        x_center = artist.get_x() + artist.get_width() / 2.0
        y_top = artist.get_height() + 0.01
        ax.text(x_center, y_top, f'{bar_value:.3f}',
                ha='center', va='bottom', fontweight='bold')

    ax.set_ylabel('Score')
    ax.set_title('Model Performance Summary')
    ax.set_ylim(0, 1.1)
    ax.grid(True, alpha=0.3, axis='y')

    ax.text(0.5, 0.5, f'Avg Score Difference:\n{mean_gap:+.3f}',
            transform=ax.transAxes, ha='center',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8),
            fontweight='bold', fontsize=12)


def plot_results(trainer, eval_results, accuracy):
    """Build the 4-panel results figure, save it under ``outputs/``, and try to display it.

    Panels: loss curves (top-left), score distributions (top-right),
    risk-preference scatter (bottom-left), performance summary (bottom-right).

    Args:
        trainer: Trainer whose ``state.log_history`` supplies the loss curves.
        eval_results: Dict containing ``'risk_averse_scores'`` and
            ``'risk_neutral_scores'``.
        accuracy: Overall evaluation accuracy shown in the summary panel.

    Returns:
        Path of the saved PNG, ``outputs/training_results_<YYYYmmdd_HHMMSS>.png``.
    """
    os.makedirs("outputs", exist_ok=True)

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    try:
        fig.suptitle('Risk-Averse Reward Model Training Results', fontsize=16, fontweight='bold')

        plot_training_loss(trainer, axes[0, 0])
        plot_score_distribution(eval_results, axes[0, 1])
        plot_risk_preference_comparison(eval_results, axes[1, 0])
        plot_performance_summary(eval_results, accuracy, axes[1, 1])

        fig.tight_layout()

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"outputs/training_results_{timestamp}.png"
        # Save through the explicit figure handle rather than pyplot's "current figure".
        fig.savefig(filename, dpi=300, bbox_inches='tight')
        print(f"Plots saved to {filename}")

        try:
            plt.show()
        except Exception:
            # Catch only ordinary errors so Ctrl-C / SystemExit still propagate.
            print("Display not available - plots saved to file only")
    finally:
        # Release the figure; pyplot keeps figures alive until they are closed.
        plt.close(fig)

    return filename

print("✓ Plotting functions defined")

## 8. Training Function

Complete training pipeline with GPU optimizations and mixed training: a pairwise ranking loss combined with single-input (absolute scoring) examples.

In [None]:
def train_reward_model(dataset_df: pd.DataFrame, model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                       max_length: int = 128):
    """Train the risk-averse reward model with mixed training (pairwise + single-input).

    ``dataset_df`` is split 80/20 into train/validation situations. Each split
    is wrapped in ``MixedTrainingDataset`` and trained with a Hugging Face
    ``Trainer`` using ``MixedDataCollator``.

    Args:
        dataset_df: One row per situation, in the format consumed by
            ``MixedTrainingDataset``.
        model_name: Hugging Face model id used for both the tokenizer and the
            ``RiskAverseRewardModel`` backbone.
        max_length: Token length passed to ``MixedTrainingDataset`` (default
            128, sized for T4 GPU memory).

    Returns:
        Tuple ``(model, tokenizer, trainer)`` after ``trainer.train()`` returns.

    Raises:
        Re-raises any exception from the pre-training forward-pass check.
    """
    print(f"Training reward model with {len(dataset_df)} situations using mixed training...")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # No pad token defined: reuse EOS so sequences can be padded.
        tokenizer.pad_token = tokenizer.eos_token

    model = RiskAverseRewardModel(model_name)

    # Set padding token in model config (required for batch sizes > 1)
    model.backbone.config.pad_token_id = tokenizer.pad_token_id

    train_df, val_df = train_test_split(dataset_df, test_size=0.2, random_state=42)

    # Mixed training datasets (pairwise + single-input)
    train_dataset = MixedTrainingDataset(train_df, tokenizer, max_length=max_length)
    val_dataset = MixedTrainingDataset(val_df, tokenizer, max_length=max_length)

    print(f"Training on {len(train_dataset)} examples ({len(train_df)} situations × 3 modes)")
    print(f"Validation on {len(val_dataset)} examples ({len(val_df)} situations × 3 modes)")
    print("Mixed training: 1 pairwise + 2 single-input examples per situation")

    # fp16 AMP, the fused AdamW kernel and pinned memory are CUDA features.
    # Without CUDA, fall back to fp32 + plain AdamW so the run does not abort.
    use_cuda = torch.cuda.is_available()

    training_args = TrainingArguments(
        output_dir="./risk_averse_model",
        num_train_epochs=3,
        per_device_train_batch_size=1,  # Reduced from 2 for T4 GPU memory
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=2,
        warmup_steps=100,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=50,
        report_to="none",  # Disable wandb and other external loggers
        eval_strategy="steps",
        eval_steps=200,
        save_steps=200,
        # A 1.1B-parameter checkpoint plus optimizer state can take many GB;
        # keep only the newest checkpoints so long runs don't exhaust disk.
        save_total_limit=2,
        load_best_model_at_end=False,
        fp16=use_cuda,
        dataloader_pin_memory=use_cuda,
        dataloader_num_workers=2,
        remove_unused_columns=False,
        optim="adamw_torch_fused" if use_cuda else "adamw_torch",
        prediction_loss_only=True,
    )

    # Mixed data collator (handles both pairwise and single-input batches)
    data_collator = MixedDataCollator(tokenizer)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
    )

    # Run one forward pass on a single example before committing to training.
    print("Validating model setup...")
    sample_batch = next(iter(DataLoader(train_dataset, batch_size=1, collate_fn=data_collator)))
    # The model may already be on the GPU (e.g. placed there by the Trainer),
    # while the collated batch is on CPU; move the tensors to match.
    device = next(model.parameters()).device
    sample_batch = {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in sample_batch.items()
    }
    model.train()

    try:
        with torch.no_grad():
            outputs = model(**sample_batch)
        print(f"✓ Model forward pass successful. Loss: {outputs['loss'].item():.3f}")
    except Exception as e:
        print(f"✗ Model validation failed: {e}")
        raise

    print("Starting mixed (pairwise + single-input) training...")
    trainer.train()

    return model, tokenizer, trainer

print("✓ Training function defined")

## 9. Main Experiment Function

Orchestrates the complete experiment: data loading, training, evaluation, visualization, and saving the results JSON.

In [None]:
def run_experiment(model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0", max_situations=500):
    """Run the complete risk aversion experiment: load, split, train, evaluate, plot, save.

    Args:
        model_name: Hugging Face model id to fine-tune. The same id is recorded
            in the results JSON, so the two cannot drift apart.
        max_situations: Cap on the number of situations used (the first N rows
            returned by the data loader).

    Returns:
        Tuple ``(model, tokenizer, results)``, where ``results`` is the dict
        written to ``outputs/experiment_results.json``.
    """
    print("=== Risk-Averse Reward Model Experiment ===")
    print(f"PyTorch device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

    print("\n1. Loading risk scenario data from CSV...")
    loader = RiskAversionDataLoader()
    full_dataset_df = loader.load_and_process_data()

    if len(full_dataset_df) > max_situations:
        dataset_df = full_dataset_df.head(max_situations)
        print(f"Limited dataset to {len(dataset_df)} situations for training")
    else:
        dataset_df = full_dataset_df
        print(f"Using all {len(dataset_df)} available situations")

    # 70/30 train/test split with a fixed seed for reproducibility
    train_df, test_df = train_test_split(dataset_df, test_size=0.3, random_state=42)

    print("\n2. Training reward model...")
    model, tokenizer, trainer = train_reward_model(train_df, model_name=model_name)

    print("\n3. Evaluating model...")
    accuracy, eval_results = evaluate_model(model, tokenizer, test_df, return_detailed=True)

    print("\n4. Creating visualizations...")
    plot_results(trainer, eval_results, accuracy)

    os.makedirs("outputs", exist_ok=True)
    results = {
        "num_training_situations": len(train_df),
        "num_test_situations": len(test_df),
        # float() so NumPy scalar types serialize cleanly with json.dump
        "final_accuracy": float(accuracy),
        "risk_averse_preference_rate": float(eval_results['risk_averse_preference_rate']),
        "model_name": model_name,
        "timestamp": datetime.now().isoformat(),
    }

    with open("outputs/experiment_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    print("\n=== Experiment Complete ===")
    print("Results saved to outputs/experiment_results.json")
    print(f"Final accuracy: {accuracy:.3f}")
    print(f"Risk-averse preference rate: {eval_results['risk_averse_preference_rate']:.3f}")

    risk_averse_scores = np.array(eval_results['risk_averse_scores'])
    risk_neutral_scores = np.array(eval_results['risk_neutral_scores'])
    score_difference = np.mean(risk_averse_scores) - np.mean(risk_neutral_scores)
    print(f"Average score difference (risk-averse - risk-neutral): {score_difference:+.3f}")

    return model, tokenizer, results

print("✓ Main experiment function defined")

## 10. Run the Experiment

Execute the complete training and evaluation pipeline.

**Before running:** Ensure you have uploaded `strict_disagreements_10k_with_prompts_and_bad_formats.csv` to this Colab environment.

In [None]:
# Execute the full pipeline. A missing CSV gets a targeted upload hint; any
# other failure prints its message followed by the full traceback.
try:
    model, tokenizer, results = run_experiment()
except FileNotFoundError as missing_file:
    print(f"\n✗ Error: {missing_file}")
    print("Please upload 'strict_disagreements_10k_with_prompts_and_bad_formats.csv' to Colab.")
except Exception as failure:
    import traceback

    print(f"\n✗ Experiment failed: {failure}")
    traceback.print_exc()
else:
    print("\n✓ Experiment completed successfully!")

## 11. Test Model on Sample Scenario (Optional)

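Run inference on the first scenario in the CSV to see how the model scores its risk-averse and risk-neutral options. This scenario may have been part of the training split, so treat the result as a sanity check rather than a held-out evaluation.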

In [None]:
# Optional sanity check (run after the experiment completes): score the
# risk-averse and risk-neutral options of the first CSV scenario with the
# trained model. Any failure, including a missing model, is reported and skipped.
try:
    loader = RiskAversionDataLoader()
    dataset_df = loader.load_and_process_data()
    test_row = dataset_df.iloc[0]

    print("Test scenario:")
    print(test_row['prompt_text'])
    print(f"\nRisk-averse choice: {test_row['correct_label']}")
    print(f"Risk-neutral choice: {test_row['incorrect_label']}")

    candidates = (
        ("Risk-averse", test_row['correct_label']),
        ("Risk-neutral", test_row['incorrect_label']),
    )

    model.eval()
    device = next(model.parameters()).device

    with torch.no_grad():
        for option_type, option in candidates:
            # Prompt + chosen option, tokenized to a fixed 128 tokens.
            input_text = f"{test_row['prompt_text']}\n\nChosen option: {option}"
            encoding = tokenizer(
                input_text,
                truncation=True,
                padding='max_length',
                max_length=128,
                return_tensors='pt',
            )
            encoding = {name: tensor.to(device) for name, tensor in encoding.items()}

            outputs = model(**encoding)
            score = outputs["logits"].item()
            sigmoid_score = torch.sigmoid(torch.tensor(score)).item()

            print(f"\n{option_type} option {option}:")
            print(f"  Raw score: {score:.3f}")
            print(f"  Sigmoid score: {sigmoid_score:.3f}")
except Exception as e:
    print(f"Test failed: {e}")

## 12. Download Results (Optional)

Download the results and plots from Colab to your local machine.

In [None]:
# Optional: pull the results JSON and the most recent training plot down from
# Colab. Outside Colab the google.colab import fails and the cell is skipped.
try:
    from google.colab import files

    results_path = "outputs/experiment_results.json"
    if os.path.exists(results_path):
        files.download(results_path)

    import glob

    plot_files = glob.glob("outputs/training_results_*.png")
    if plot_files:
        # Pick the plot with the largest os.path.getctime value.
        latest_plot = max(plot_files, key=os.path.getctime)
        files.download(latest_plot)
        print(f"Downloaded: {latest_plot}")
except ImportError:
    print("Not running in Colab - skip download")
except Exception as e:
    print(f"Download failed: {e}")