# Part 4: RL Agent Training and Activation Extraction

**Duration**: ~20 minutes (Option B: ~5 minutes)

**Objective**: Train (or load) an RL agent and extract learned representations

In this notebook, we'll:
- Explain the PPO (Proximal Policy Optimization) architecture
- **Option B (Recommended)**: Load pre-trained model weights
- Extract CNN activations from all layers (simulated in this tutorial when game frames are unavailable)
- Downsample activations from the 60 Hz frame rate to the fMRI TR (1.49 s)
- Convolve with HRF for fMRI modeling
- Apply PCA dimensionality reduction
- Visualize variance explained per layer

### Training Options:
- **Option A**: Simplified imitation learning (~5 min training)
- **Option B**: Load pre-trained model (~1 min) **← RECOMMENDED**
- **Option C**: Full PPO training (~2 hours) - See extension notebook

In [None]:
# Import libraries
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import warnings
warnings.filterwarnings('ignore')

# Add scripts directory to path
scripts_dir = Path('..') / 'scripts'
sys.path.insert(0, str(scripts_dir))

from utils import (
    get_sourcedata_path,
    get_derivatives_path,
    load_events,
    get_session_runs,
    create_output_dir
)

from rl_utils import (
    SimpleCNN,
    load_pretrained_model,
    create_simple_proxy_features,
    downsample_activations_to_tr,
    convolve_with_hrf,
    apply_pca
)

# Set plotting style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (14, 6)

# Check PyTorch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
print("\nImports complete!")

In [None]:
# Define subject and session
SUBJECT = 'sub-01'
SESSION = 'ses-010'
TR = 1.49  # seconds

# Get paths
sourcedata_path = get_sourcedata_path()
derivatives_path = get_derivatives_path()
rl_output_dir = create_output_dir(SUBJECT, SESSION, 'rl_agent')

print(f"Analyzing: {SUBJECT}, {SESSION}")
print(f"TR: {TR}s")
print(f"Output directory: {rl_output_dir}")

# Get runs
try:
    runs = get_session_runs(SUBJECT, SESSION, sourcedata_path)
    print(f"\nSession runs: {runs}")
except Exception as e:
    print(f"Error: {e}")
    runs = ['run-01', 'run-02', 'run-03', 'run-04', 'run-05']

## 1. PPO Agent Architecture

Our agent is a 4-layer CNN with actor-critic heads, trained with PPO (Proximal Policy Optimization):

**Input**: 4 stacked frames (84×84 grayscale)
- Frame stacking provides temporal context (like motion information)
- Grayscale conversion reduces dimensionality

**Convolutional layers**:
- **conv1**: 4 → 32 channels, 3×3 kernel, stride 2, padding 1 → (42×42)
- **conv2**: 32 → 32 channels, 3×3 kernel, stride 2, padding 1 → (21×21)
- **conv3**: 32 → 32 channels, 3×3 kernel, stride 2, padding 1 → (11×11)
- **conv4**: 32 → 32 channels, 3×3 kernel, stride 2, padding 1 → (6×6)
- ReLU activation after each layer

**Fully connected layer**:
- **linear**: Flatten (32×6×6=1152) → 512 features
- ReLU activation

**Actor-Critic heads**:
- **Actor**: 512 → 12 (logits over the 12 `COMPLEX_MOVEMENT` actions; a softmax turns them into action probabilities)
- **Critic**: 512 → 1 (value estimate)
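The spatial sizes above follow from the standard convolution output formula, out = floor((in + 2·padding − kernel) / stride) + 1. The sketch below recomputes them and tallies parameters, assuming padding 1 and a bias term in every layer. `SimpleCNN` in `rl_utils` is not shown here, so treat the total as a cross-check against the parameter count printed by the next cell rather than as authoritative.

```python
def conv_out(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Spatial output size of a square convolution."""
    return (size + 2 * padding - kernel) // stride + 1


def layer_sizes(input_size: int = 84, n_layers: int = 4) -> list:
    """Output height/width after each of the stacked conv layers."""
    sizes, size = [], input_size
    for _ in range(n_layers):
        size = conv_out(size)
        sizes.append(size)
    return sizes


def count_params(in_channels=4, channels=32, kernel=3, hidden=512, n_actions=12, input_size=84):
    """Weights + biases, ASSUMING every layer has a bias term."""
    total, c_in = 0, in_channels
    for _ in range(4):
        total += c_in * channels * kernel * kernel + channels  # conv weights + biases
        c_in = channels
    flat = channels * layer_sizes(input_size)[-1] ** 2          # 32 * 6 * 6 = 1152
    total += flat * hidden + hidden                             # linear layer
    total += hidden * n_actions + n_actions                     # actor head
    total += hidden * 1 + 1                                     # critic head
    return total


sizes = layer_sizes()
print("conv output sizes:", sizes)  # [42, 21, 11, 6]
print("flattened features per layer:", [32 * s * s for s in sizes])
print("parameters:", count_params())
```

Without padding the same stack would shrink 84 → 41 → 20 → 9 → 4, so the 42/21/11/6 sizes quoted above imply padding 1.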

**Why this architecture?**
- Hierarchical feature learning: Low-level visual features → High-level strategy
- Compact representation: 512 features capture gameplay-relevant information
- Loosely analogous to the primate ventral visual stream: V1 (edges) → V4 (intermediate shapes, color) → IT (objects)

In [None]:
# Initialize model architecture
model = SimpleCNN(n_actions=12, input_channels=4)

print("Model Architecture:")
print("=" * 60)
print(model)
print("=" * 60)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Layer output shapes
print("\nLayer output shapes (for 1 sample):")
dummy_input = torch.randn(1, 4, 84, 84)
with torch.no_grad():  # no gradients needed just to inspect shapes
    activations = model(dummy_input, return_activations=True)

for layer_name, act in activations.items():
    if 'conv' in layer_name or 'linear' in layer_name:
        print(f"  {layer_name:10s}: {tuple(act.shape)}")

## 2. Training Options

### Option A: Simplified Imitation Learning (~5 minutes)
- Train model to predict button presses from game frames
- Uses behavioral annotations as supervision
- Faster than full RL; may capture broadly similar visual representations, though it learns no value function from reward
- Good for tutorial purposes
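For Option A, "predict button presses" amounts to treating the actor head as a classifier: the 12 actor logits are scored against the action the player actually took with a softmax cross-entropy loss. A minimal NumPy sketch of that loss, assuming each frame's button state has already been mapped to one of the 12 `COMPLEX_MOVEMENT` indices (that mapping is hypothetical and not part of this notebook):

```python
import numpy as np


def softmax_cross_entropy(logits: np.ndarray, actions: np.ndarray) -> float:
    """Mean cross-entropy between action logits (n, 12) and integer action labels (n,)."""
    # Subtract the row max before exponentiating for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the recorded action in each row
    return float(-log_probs[np.arange(len(actions)), actions].mean())


# Hypothetical batch: 3 frames, 12 possible actions
uniform_logits = np.zeros((3, 12))
labels = np.array([0, 5, 11])
print(softmax_cross_entropy(uniform_logits, labels))  # log(12) ≈ 2.485
```

An untrained network that is indifferent between actions sits at log(12); minimizing this loss pushes probability mass onto the human's choices.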

### Option B: Pre-trained Model (~1 minute) ← RECOMMENDED
- Load weights from a fully trained PPO agent
- Agent trained on multiple levels for ~5M timesteps
- Skip training, directly extract activations
- Best balance of speed and authenticity

### Option C: Full PPO Training (~2 hours)
- Complete RL training from scratch
- Requires gym-retro environment setup
- Computationally intensive
- Provided as optional extension
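Whichever option supplies the weights, "PPO" refers to how the actor was trained: PPO maximizes a clipped surrogate objective that limits how far each update can move the policy away from the one that collected the data. A NumPy sketch of that objective for a batch of probability ratios r = π_new(a|s) / π_old(a|s) and advantage estimates A, using the common clip range ε = 0.2 (the value used for the pre-trained agent is not stated here):

```python
import numpy as np


def ppo_clipped_objective(ratio, advantage, eps=0.2):
    """Mean of min(r * A, clip(r, 1 - eps, 1 + eps) * A); PPO maximizes this."""
    ratio = np.asarray(ratio, dtype=float)
    advantage = np.asarray(advantage, dtype=float)
    unclipped = ratio * advantage
    clipped = np.clip(ratio, 1.0 - eps, 1.0 + eps) * advantage
    # Taking the minimum makes the objective a pessimistic bound on improvement
    return float(np.minimum(unclipped, clipped).mean())


# A ratio of 1.5 with positive advantage is capped at 1.2;
# a ratio of 0.5 with negative advantage is capped at 0.8 * (-1)
print(ppo_clipped_objective([1.5, 0.5], [1.0, -1.0]))
```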

**For this tutorial, we'll use Option B (pre-trained model).**

In [None]:
# Choose training option
TRAINING_OPTION = 'B'  # 'A', 'B', or 'C'

print(f"Selected Option {TRAINING_OPTION}")

if TRAINING_OPTION == 'A':
    print("\nOption A: Simplified Imitation Learning")
    print("This option is not fully implemented in this notebook.")
    print("See extension materials for imitation learning code.")
    print("\nFalling back to Option B (pre-trained model)...")
    TRAINING_OPTION = 'B'

elif TRAINING_OPTION == 'B':
    print("\nOption B: Load Pre-trained Model (RECOMMENDED)")
    print("Will attempt to load pre-trained weights from derivatives/rl_agent/")

elif TRAINING_OPTION == 'C':
    print("\nOption C: Full PPO Training")
    print("This requires ~2 hours and gym-retro setup.")
    print("See extension notebook: 04b_full_rl_training.ipynb")
    print("\nFalling back to Option B for this tutorial...")
    TRAINING_OPTION = 'B'

## 3. Load Pre-trained Model (Option B)

In [None]:
# Check for pre-trained model
pretrained_model_path = derivatives_path / 'rl_agent' / 'mario_ppo_pretrained.pt'

print(f"Looking for pre-trained model at: {pretrained_model_path}")
print(f"Exists: {pretrained_model_path.exists()}")

if pretrained_model_path.exists():
    print("\n✓ Pre-trained model found! Loading...")
    try:
        model = load_pretrained_model(pretrained_model_path, device=device)
        print("✓ Model loaded successfully")
        MODEL_AVAILABLE = True
    except Exception as e:
        print(f"✗ Error loading model: {e}")
        print("Will use randomly initialized model for demonstration.")
        MODEL_AVAILABLE = False
else:
    print("\n⚠️  Pre-trained model not found.")
    print("Using randomly initialized model for demonstration.")
    print("Note: Random weights won't produce meaningful representations.")
    print("For real analysis, please provide a trained model checkpoint.")
    MODEL_AVAILABLE = False

# Ensure the model (pre-trained or randomly initialized) is on the selected device
model = model.to(device)

# Set to evaluation mode
model.eval()
print("\nModel ready for activation extraction.")

## 4. Alternative: Simplified Proxy Features

Since we may not have pre-trained weights or game frames readily available, we can create simplified proxy features from behavioral annotations. These won't capture the same hierarchical representations as a trained CNN, but they're useful for demonstrating the encoding pipeline.
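One simple way to turn annotated events into TR-resolution regressors is to compute, for each TR, the fraction of that TR covered by a given event type (for example, a button being held). `create_simple_proxy_features` may work differently internally; the sketch below is a hypothetical illustration of the idea using plain onset/duration arrays in seconds.

```python
import numpy as np


def event_coverage_per_tr(onsets, durations, n_trs, tr):
    """Fraction of each TR bin [k*tr, (k+1)*tr) covered by events (onsets/durations in seconds)."""
    bin_starts = np.arange(n_trs) * tr
    bin_ends = bin_starts + tr
    coverage = np.zeros(n_trs)
    for onset, duration in zip(onsets, durations):
        # Length of the overlap between [onset, onset + duration) and every TR bin
        overlap = np.minimum(bin_ends, onset + duration) - np.maximum(bin_starts, onset)
        coverage += np.clip(overlap, 0.0, None) / tr
    # Overlapping events of the same type should not push a bin above full coverage
    return np.clip(coverage, 0.0, 1.0)


# Hypothetical "button held" events: 0.5-1.5 s and 2.0-4.0 s
print(event_coverage_per_tr([0.5, 2.0], [1.0, 2.0], n_trs=4, tr=1.0))
```

Computing one such column per event type (e.g. per button) gives an (n_trs, n_types) feature matrix.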

In [None]:
# Create proxy features from behavioral events
print("Creating simplified proxy features from behavioral annotations...\n")

proxy_features_per_run = []
run_n_trs = []

try:
    for run in runs:
        # Load events
        events = load_events(SUBJECT, SESSION, run, sourcedata_path)
        
        # Estimate number of TRs from the latest event offset (onset + duration)
        run_duration = (events['onset'] + events['duration'].fillna(0)).max()
        n_trs = int(np.ceil(run_duration / TR))
        run_n_trs.append(n_trs)
        
        # Create features
        proxy_feats = create_simple_proxy_features(events, n_trs, TR)
        proxy_features_per_run.append(proxy_feats)
        
        print(f"{run}: {n_trs} TRs")
        print(f"  Button features: {proxy_feats['button_features'].shape}")
        print(f"  Event features: {proxy_feats['event_features'].shape}")
        print(f"  Combined features: {proxy_feats['combined_features'].shape}")
    
    print(f"\n✓ Created proxy features for {len(runs)} runs")
    PROXY_FEATURES_AVAILABLE = True
    
except Exception as e:
    print(f"Error creating proxy features: {e}")
    PROXY_FEATURES_AVAILABLE = False

## 5. Simulate Layer-wise Activations

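For demonstration purposes, we'll simulate layer activations directly at TR resolution, with each layer's feature count matching its flattened output size in the real CNN (e.g. conv1: 32×42×42 = 56,448). The activations are Gaussian noise with a few proxy features mixed into the first 50 units, so they contain no genuine hierarchical structure; with a trained model and game frames, this step would be replaced by real forward passes.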
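With real gameplay, activations would be produced at the emulator frame rate (60 Hz) and then aligned to the fMRI TR, which is what `downsample_activations_to_tr` is imported for. Because one 1.49 s TR spans 89.4 frames, fixed-size chunking does not line up; the hypothetical sketch below instead averages all frames whose timestamp falls within each TR. The actual helper's method is not shown and may differ.

```python
import numpy as np


def downsample_by_tr_bins(acts, frame_rate=60.0, tr=1.49, n_trs=None):
    """Average frame-level activations (n_frames, n_features) within each TR bin."""
    frame_times = np.arange(acts.shape[0]) / frame_rate
    tr_index = np.floor(frame_times / tr).astype(int)   # TR bin of each frame
    if n_trs is None:
        n_trs = int(tr_index.max()) + 1
    keep = tr_index < n_trs                              # drop frames past the last TR
    sums = np.zeros((n_trs, acts.shape[1]))
    counts = np.zeros(n_trs)
    np.add.at(sums, tr_index[keep], acts[keep])          # accumulate rows per bin
    np.add.at(counts, tr_index[keep], 1)
    filled = counts > 0
    sums[filled] /= counts[filled][:, None]              # empty bins stay at 0
    return sums


# Hypothetical input: 30 s of 512-unit activations at 60 Hz
frames = np.random.default_rng(0).standard_normal((60 * 30, 512))
print(downsample_by_tr_bins(frames).shape)
```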

In [None]:
# Simulate layer activations with realistic shapes
print("Simulating layer-wise activations...\n")

# Layer configurations (mimicking real CNN)
LAYER_CONFIGS = {
    'conv1': 32 * 42 * 42,  # Early visual features
    'conv2': 32 * 21 * 21,  # Mid-level features
    'conv3': 32 * 11 * 11,  # High-level visual
    'conv4': 32 * 6 * 6,    # Abstract features
    'linear': 512           # Semantic representations
}

# Create simulated activations for each run
simulated_activations_per_run = []
rng = np.random.default_rng(42)  # fixed seed so the simulation is reproducible

if PROXY_FEATURES_AVAILABLE:
    for run_idx, (run, n_trs) in enumerate(zip(runs, run_n_trs)):
        run_activations = {}
        
        # Use proxy features as basis
        proxy_feats = proxy_features_per_run[run_idx]['combined_features']
        
        for layer_name, n_features in LAYER_CONFIGS.items():
            # Gaussian baseline activity for every unit in this layer
            base_features = rng.standard_normal((n_trs, n_features)) * 0.5
            
            # Mix up to 10 proxy features into the first (at most) 50 units
            n_neurons = min(50, n_features)
            for i in range(min(proxy_feats.shape[1], 10)):
                # Each proxy feature drives those units with random weights
                base_features[:, :n_neurons] += np.outer(proxy_feats[:, i], rng.standard_normal(n_neurons)) * 0.3
            
            run_activations[layer_name] = base_features
        
        simulated_activations_per_run.append(run_activations)
        print(f"{run} activations simulated:")
        for layer_name, acts in run_activations.items():
            print(f"  {layer_name}: {acts.shape}")
    
    print(f"\n✓ Simulated activations for {len(runs)} runs")
    ACTIVATIONS_AVAILABLE = True
else:
    print("Cannot simulate activations without proxy features.")
    ACTIVATIONS_AVAILABLE = False

## 6. HRF Convolution

Brain responses are delayed and dispersed by the hemodynamic response function (HRF). We'll convolve our activations with the canonical HRF to match fMRI timing.
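`convolve_with_hrf` is imported from `rl_utils` and its internals are not shown. As an illustration only, the SPM canonical HRF is a difference of two gamma densities (peak shape 6, undershoot shape 16, undershoot ratio 1/6). The simplified sketch below samples it directly at the TR, whereas SPM and nilearn evaluate it on a finer grid, so values will differ slightly.

```python
import math

import numpy as np


def spm_like_hrf(tr, duration=32.0, peak_shape=6.0, under_shape=16.0, ratio=1.0 / 6.0):
    """Double-gamma HRF sampled every `tr` seconds and normalized to sum to 1."""
    t = np.arange(0.0, duration, tr)

    def gamma_pdf(x, shape):
        # Gamma density with unit scale: x**(shape-1) * exp(-x) / Gamma(shape)
        return x ** (shape - 1.0) * np.exp(-x) / math.gamma(shape)

    h = gamma_pdf(t, peak_shape) - ratio * gamma_pdf(t, under_shape)
    return h / h.sum()


def convolve_columns_with_hrf(X, tr):
    """Convolve each column of X (n_trs, n_features) with the HRF, keeping the first n_trs samples."""
    hrf = spm_like_hrf(tr)
    n = X.shape[0]
    # Truncating the full convolution keeps the output causal and aligned with X
    return np.column_stack([np.convolve(X[:, j], hrf)[:n] for j in range(X.shape[1])])


hrf = spm_like_hrf(1.49)
print(f"HRF peak at {np.argmax(hrf) * 1.49:.2f} s")
```

Convolving a single impulse returns the HRF itself, which is why the convolved trace in the plot below looks like a smoothed, delayed copy of the original.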

In [None]:
# Convolve activations with HRF
if ACTIVATIONS_AVAILABLE:
    print("Convolving activations with HRF...\n")
    
    convolved_activations_per_run = []
    
    for run_idx, run_acts in enumerate(simulated_activations_per_run):
        convolved_acts = {}
        
        for layer_name, acts in run_acts.items():
            # Convolve with SPM canonical HRF
            convolved = convolve_with_hrf(acts, TR, hrf_model='spm')
            convolved_acts[layer_name] = convolved
        
        convolved_activations_per_run.append(convolved_acts)
        print(f"{runs[run_idx]}: HRF convolution complete")
    
    print(f"\n✓ HRF convolution complete for {len(runs)} runs")
    
    # Visualize effect of HRF on one feature
    layer_to_plot = 'linear'
    feature_idx = 0
    
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8), sharex=True)
    
    original = simulated_activations_per_run[0][layer_to_plot][:, feature_idx]
    convolved = convolved_activations_per_run[0][layer_to_plot][:, feature_idx]
    time_points = np.arange(len(original)) * TR
    
    ax1.plot(time_points, original, linewidth=1.5, color='steelblue')
    ax1.set_ylabel('Activation', fontsize=12)
    ax1.set_title(f'Original {layer_to_plot} activation (feature {feature_idx})', fontsize=14, fontweight='bold')
    ax1.grid(alpha=0.3)
    
    ax2.plot(time_points, convolved, linewidth=1.5, color='orangered')
    ax2.set_xlabel('Time (seconds)', fontsize=12)
    ax2.set_ylabel('Activation', fontsize=12)
    ax2.set_title('After HRF convolution', fontsize=14, fontweight='bold')
    ax2.grid(alpha=0.3)
    
    plt.tight_layout()
    plt.show()
else:
    print("Skipping HRF convolution (no activations available).")

## 7. Dimensionality Reduction with PCA

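CNN layers have thousands of features (56,448 for conv1 alone), which makes encoding models computationally expensive. We'll use PCA, fitted on all runs concatenated, to reduce each layer to at most 50 components. `apply_pca` also receives a 90% variance threshold, but how much variance 50 components actually retain varies by layer; Section 8 shows the result.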
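`apply_pca` comes from `rl_utils` and its internals are not shown. As an assumption-labeled illustration, PCA and the per-component variance-explained ratio can be computed directly from an SVD of the centred (time × features) matrix. For conv1-sized inputs, scikit-learn's `PCA(svd_solver='randomized')` is usually much cheaper than the full SVD used here.

```python
import numpy as np


def pca_via_svd(X, n_components):
    """Return (scores, components, variance ratio of every component) for X (n_samples, n_features)."""
    Xc = X - X.mean(axis=0)                        # PCA operates on mean-centred data
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    ratio = S**2 / np.sum(S**2)                    # fraction of total variance per component
    k = min(n_components, S.size)
    scores = U[:, :k] * S[:k]                      # same as Xc @ Vt[:k].T
    return scores, Vt[:k], ratio


rng = np.random.default_rng(0)
# Hypothetical data: 200 TRs, 300 features driven by 3 latent signals plus small noise
latent = rng.standard_normal((200, 3))
X = latent @ rng.standard_normal((3, 300)) + 0.01 * rng.standard_normal((200, 300))
scores, components, ratio = pca_via_svd(X, n_components=50)
print(scores.shape, ratio[:3].sum())
```

When features really share a few underlying signals, a handful of components captures almost all variance; pure noise spreads variance evenly, so 50 components of a very wide random layer can retain far less than 90%.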

In [None]:
%%time
# Apply PCA to each layer

if ACTIVATIONS_AVAILABLE:
    print("Applying PCA dimensionality reduction...\n")
    
    N_COMPONENTS = 50
    
    # Concatenate all runs for PCA fitting
    concatenated_acts = {layer: [] for layer in LAYER_CONFIGS.keys()}
    
    for run_acts in convolved_activations_per_run:
        for layer_name, acts in run_acts.items():
            concatenated_acts[layer_name].append(acts)
    
    # Concatenate across runs
    for layer_name in concatenated_acts.keys():
        concatenated_acts[layer_name] = np.concatenate(concatenated_acts[layer_name], axis=0)
    
    # Apply PCA per layer
    pca_results = {}
    reduced_activations = {}
    
    for layer_name, acts in concatenated_acts.items():
        print(f"\nLayer: {layer_name}")
        print(f"  Original shape: {acts.shape}")
        
        # Apply PCA
        reduced, pca_model, variance_explained = apply_pca(
            acts, n_components=N_COMPONENTS, variance_threshold=0.9
        )
        
        pca_results[layer_name] = {
            'pca': pca_model,
            'variance_explained': variance_explained
        }
        reduced_activations[layer_name] = reduced
        
        print(f"  Reduced shape: {reduced.shape}")
    
    print("\n✓ PCA reduction complete for all layers")
    PCA_AVAILABLE = True
else:
    print("Skipping PCA (no activations available).")
    PCA_AVAILABLE = False

## 8. Visualize Variance Explained

Let's see how much variance is captured by the top principal components in each layer.
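A quick way to read the cumulative curves is to ask how many components a layer needs to reach the 90% line. Given a variance-explained vector such as `pca_results[layer_name]['variance_explained']`, the small hypothetical helper below computes that; it returns `None` when the retained components never reach the threshold, which can happen when only 50 components are kept from a very wide layer.

```python
import numpy as np


def components_to_reach(variance_explained, threshold=0.9):
    """Smallest number of leading components whose cumulative variance >= threshold, or None."""
    cumulative = np.cumsum(variance_explained)
    if cumulative.size == 0 or cumulative[-1] < threshold:
        return None
    # First index where the cumulative sum meets the threshold, converted to a count
    return int(np.searchsorted(cumulative, threshold) + 1)


print(components_to_reach([0.5, 0.3, 0.15, 0.05]))  # 3
print(components_to_reach([0.2, 0.1, 0.05]))        # None: only 35% retained
```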

In [None]:
# Plot variance explained per layer
if PCA_AVAILABLE:
    fig, axes = plt.subplots(2, 3, figsize=(16, 10))
    axes = axes.flatten()
    
    layer_names = list(pca_results.keys())
    
    for idx, layer_name in enumerate(layer_names):
        ax = axes[idx]
        
        variance_explained = pca_results[layer_name]['variance_explained']
        cumsum_var = np.cumsum(variance_explained)
        
        # Plot individual variance
        ax.bar(range(len(variance_explained)), variance_explained, 
               alpha=0.7, color='steelblue', label='Individual')
        
        # Plot cumulative variance
        ax2 = ax.twinx()
        ax2.plot(range(len(cumsum_var)), cumsum_var, 
                color='orangered', linewidth=2, marker='o', markersize=3,
                label='Cumulative')
        ax2.axhline(y=0.9, color='red', linestyle='--', alpha=0.5, label='90% threshold')
        ax2.set_ylim([0, 1.05])
        ax2.set_ylabel('Cumulative Variance', fontsize=10, color='orangered')
        
        # Styling
        ax.set_xlabel('Component', fontsize=10)
        ax.set_ylabel('Variance Explained', fontsize=10, color='steelblue')
        ax.set_title(f'{layer_name.upper()}\n{len(variance_explained)} components', 
                    fontsize=12, fontweight='bold')
        ax.grid(alpha=0.3, axis='y')
        
        # Add total variance text
        total_var = cumsum_var[-1]
        ax.text(0.95, 0.95, f'Total: {total_var*100:.1f}%',
               transform=ax.transAxes, ha='right', va='top',
               bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5),
               fontsize=10, fontweight='bold')
    
    # Hide extra subplot
    axes[-1].axis('off')
    
    plt.suptitle('PCA Variance Explained per Layer', fontsize=16, fontweight='bold', y=1.00)
    plt.tight_layout()
    plt.show()
    
    # Summary table
    print("\nPCA Summary:")
    print("=" * 60)
    print(f"{'Layer':<10} {'Components':<12} {'Variance Explained':<20}")
    print("=" * 60)
    for layer_name in layer_names:
        n_comp = len(pca_results[layer_name]['variance_explained'])
        total_var = np.sum(pca_results[layer_name]['variance_explained'])
        print(f"{layer_name:<10} {n_comp:<12} {total_var*100:>6.2f}%")
    print("=" * 60)
else:
    print("No PCA results to visualize.")

## 9. Save Reduced Activations

Save the PCA-reduced activations for use in the encoding model (Notebook 05).
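When the file is read back, `np.load` returns every entry as an array: scalars such as `tr` come back as 0-d arrays and `layer_names` as a Unicode array. A self-contained round-trip sketch with a temporary file and hypothetical toy arrays, which also splits a concatenated (session-level) timeseries back into runs from made-up per-run TR counts:

```python
import tempfile
from pathlib import Path

import numpy as np

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "toy_rl_activations_pca.npz"
    np.savez_compressed(
        path,
        linear_activations=np.ones((12, 5)),
        layer_names=np.array(["linear"], dtype="U10"),
        tr=1.49,
    )
    with np.load(path) as data:
        acts = data["linear_activations"]          # loaded into memory here
        tr = float(data["tr"])                    # 0-d array -> Python float
        layer_names = data["layer_names"].tolist()

# Hypothetical per-run TR counts; split points are the cumulative sums minus the last
run_lengths = [5, 4, 3]
per_run = np.split(acts, np.cumsum(run_lengths)[:-1])
print(tr, layer_names, [r.shape for r in per_run])
```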

In [None]:
# Save activations
if PCA_AVAILABLE:
    print("Saving PCA-reduced activations...\n")
    
    # Save concatenated (session-level) activations
    activations_file = rl_output_dir.parent / f'{SUBJECT}_{SESSION}_rl_activations_pca.npz'
    
    # Prepare data for saving
    save_data = {}
    for layer_name, acts in reduced_activations.items():
        save_data[f'{layer_name}_activations'] = acts
        save_data[f'{layer_name}_variance_explained'] = pca_results[layer_name]['variance_explained']
    
    # Add metadata
    save_data['layer_names'] = np.array(list(LAYER_CONFIGS.keys()), dtype='U10')
    save_data['n_components'] = N_COMPONENTS
    save_data['tr'] = TR
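    save_data['n_runs'] = len(runs)
    save_data['run_n_trs'] = np.array(run_n_trs)  # per-run TR counts, to split the concatenated timeseries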
    
    np.savez_compressed(activations_file, **save_data)
    
    print(f"✓ Saved activations to: {activations_file}")
    print(f"  File size: {activations_file.stat().st_size / (1024**2):.2f} MB")
    
    # Print what was saved
    print("\nSaved arrays:")
    for key in save_data.keys():
        if isinstance(save_data[key], np.ndarray):
            print(f"  {key}: {save_data[key].shape}")
        else:
            print(f"  {key}: {save_data[key]}")
else:
    print("No activations to save.")

## Summary

In this notebook, we:

✅ **Explained PPO architecture**: 4-layer CNN with actor-critic heads

✅ **Loaded/initialized model**: Option B (pre-trained) or random initialization

✅ **Created proxy features**: Simplified behavioral features from annotations

✅ **Simulated layer activations**: Layer-sized placeholder activations (conv1 → linear) built from noise plus proxy features

✅ **Applied HRF convolution**: Matched fMRI hemodynamic timing

✅ **PCA dimensionality reduction**: Reduced each layer to at most 50 components (90% variance threshold)

✅ **Visualized variance**: Explored information content in each layer

✅ **Saved activations**: Ready for encoding model in Notebook 05

### Key outputs:
- Layer activations: conv1, conv2, conv3, conv4, linear (up to 50 components each)
- Variance explained per layer
- Saved file: `sub-01_ses-010_rl_activations_pca.npz`

### Expected hierarchy (for a trained agent; the simulated activations will not show it):
- **conv1**: Low-level visual features (edges, colors)
- **conv2**: Mid-level patterns (textures, shapes)
- **conv3**: High-level visual (objects, enemies)
- **conv4**: Abstract features (spatial relationships)
- **linear**: Semantic representations (strategy, value)

### Next steps:
In **Notebook 05**, we'll use these RL representations to predict brain activity via ridge regression encoding models.