# Part 5: Brain Encoding Analysis

**Duration**: ~15 minutes

**Objective**: Use RL agent representations to predict brain activity via ridge regression

In this notebook, we'll:
- Load and prepare BOLD data (deconfounding, masking)
- Fit ridge regression models per CNN layer
- Evaluate on a held-out test split (ridge alpha selected by cross-validation)
- Compare layer performance
- Create R² brain maps showing encoding quality
- Visualize which brain regions are best predicted by each layer

In [None]:
# Import libraries
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import nibabel as nib
from nilearn import plotting
import warnings
warnings.filterwarnings('ignore')

# Add src to path
src_dir = Path('..') / 'src'
sys.path.insert(0, str(src_dir))

from utils import (
    get_sourcedata_path,
    get_derivatives_path,
    load_bold,
    load_brain_mask,
    load_confounds,
    get_session_runs,
    create_output_dir
)

from glm_utils import prepare_confounds

from encoding_utils import (
    RidgeEncodingModel,
    load_and_prepare_bold,
    fit_encoding_model_per_layer,
    compare_layer_performance,
    create_encoding_summary_figure
)

# Set plotting style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (14, 6)

print("Imports complete!")

In [None]:
# Define subject and session
SUBJECT = 'sub-01'
SESSION = 'ses-010'
TR = 1.49  # seconds

# Get paths
sourcedata_path = get_sourcedata_path()
derivatives_path = get_derivatives_path()
encoding_output_dir = create_output_dir(SUBJECT, SESSION, 'encoding')

print(f"Analyzing: {SUBJECT}, {SESSION}")
print(f"TR: {TR}s")
print(f"Output directory: {encoding_output_dir}")

# Get runs
try:
    runs = get_session_runs(SUBJECT, SESSION, sourcedata_path)
    print(f"\nSession runs: {runs}")
except Exception as e:
    print(f"Could not list session runs ({e}); falling back to default run names.")
    runs = ['run-01', 'run-02', 'run-03', 'run-04', 'run-05']

## 1. Load RL Activations

Load the PCA-reduced CNN activations from Notebook 04.
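
For readers who have not run Notebook 04, the following sketch shows the file layout that the loading cell expects: a `layer_names` array, one `{layer}_activations` array per layer, and the `n_components` and `n_runs` metadata. The file name, array sizes, and random values here are placeholders for illustration only, not output of the real pipeline.

```python
import tempfile
from pathlib import Path

import numpy as np

# Hypothetical stand-in for the file written by Notebook 04 (toy sizes, random values).
rng = np.random.default_rng(0)
layer_names = ['conv1', 'conv2']
n_timepoints, n_components = 20, 5

arrays = {
    'layer_names': np.array(layer_names),
    'n_components': n_components,
    'n_runs': 2,
}
for name in layer_names:
    # One (timepoints x components) matrix per layer, stored under '<layer>_activations'
    arrays[f'{name}_activations'] = rng.standard_normal((n_timepoints, n_components))

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / 'demo_rl_activations_pca.npz'
    np.savez(path, **arrays)

    # Read the file back the same way the notebook cell does
    with np.load(path) as data:
        loaded_names = data['layer_names'].tolist()
        loaded = {name: data[f'{name}_activations'] for name in loaded_names}
        n_runs = int(data['n_runs'])

for name, acts in loaded.items():
    print(f"{name}: {acts.shape}")
```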

In [None]:
# Load RL activations
rl_dir = derivatives_path / 'rl_agent'
activations_file = rl_dir / f'{SUBJECT}_{SESSION}_rl_activations_pca.npz'

print(f"Looking for activations at: {activations_file}")
print(f"Exists: {activations_file.exists()}")

if activations_file.exists():
    # Load activations
    data = np.load(activations_file)
    
    # Extract layer names
    layer_names = data['layer_names'].tolist()
    
    # Load activations for each layer
    layer_activations = {}
    for layer_name in layer_names:
        key = f'{layer_name}_activations'
        if key in data:
            layer_activations[layer_name] = data[key]
    
    print(f"\n✓ Loaded activations for {len(layer_activations)} layers:")
    for layer_name, acts in layer_activations.items():
        print(f"  {layer_name}: {acts.shape}")
    
    # Metadata
    n_components = data['n_components']
    n_runs = data['n_runs']
    print(f"\nMetadata:")
    print(f"  Components per layer: {n_components}")
    print(f"  Number of runs: {n_runs}")
    
    ACTIVATIONS_LOADED = True
else:
    print("\n⚠️  Activations file not found.")
    print("Please run Notebook 04 first to generate RL activations.")
    ACTIVATIONS_LOADED = False
    layer_activations = {}

## 2. Load and Prepare BOLD Data

Load fMRI data and apply preprocessing:
- Brain masking
- Confound regression (motion, WM, CSF, global signal)
- Detrending
- Standardization
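
To make the confound regression, detrending, and standardization steps concrete, here is a minimal NumPy sketch on synthetic data. It is not the implementation behind `load_and_prepare_bold`: the helper name `clean_signals` is hypothetical, only a linear trend is removed, and high-pass filtering is omitted.

```python
import numpy as np

def clean_signals(signals, confounds):
    """Regress an intercept, a linear trend, and the confounds out of each voxel, then z-score."""
    n_timepoints = signals.shape[0]
    trend = np.linspace(-1.0, 1.0, n_timepoints)
    design = np.column_stack([np.ones(n_timepoints), trend, confounds])
    # Least-squares fit of the nuisance design to every voxel at once
    beta, *_ = np.linalg.lstsq(design, signals, rcond=None)
    residuals = signals - design @ beta
    # Residuals already have zero mean because the design contains an intercept
    return residuals / residuals.std(axis=0)

rng = np.random.default_rng(0)
n_timepoints, n_voxels = 200, 4
motion = rng.standard_normal((n_timepoints, 2))          # toy confound regressors
neural = rng.standard_normal((n_timepoints, n_voxels))   # toy signal of interest
drift = 0.01 * np.arange(n_timepoints)[:, None]          # slow scanner drift
bold = neural + motion @ rng.standard_normal((2, n_voxels)) + drift

cleaned = clean_signals(bold, motion)
print(cleaned.mean(axis=0).round(6), cleaned.std(axis=0).round(6))
```

After cleaning, each voxel time series has zero mean and unit variance and is uncorrelated with the confound regressors.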

In [None]:
%%time
# Load BOLD data for all runs

try:
    print("Loading BOLD data and masks...\n")
    
    bold_imgs = []
    mask_imgs = []
    confounds_list = []
    
    for run in runs:
        # Load BOLD
        bold_img = load_bold(SUBJECT, SESSION, run, sourcedata_path)
        bold_imgs.append(bold_img)
        
        # Load mask (use first run's mask for all)
        if len(mask_imgs) == 0:
            mask_img = load_brain_mask(SUBJECT, SESSION, run, sourcedata_path)
            mask_imgs.append(mask_img)
        
        # Load and prepare confounds
        confounds_raw = load_confounds(SUBJECT, SESSION, run, sourcedata_path)
        confounds = prepare_confounds(confounds_raw, strategy='full')
        confounds_list.append(confounds)
        
        print(f"✓ {run}: BOLD {bold_img.shape}, {len(confounds.columns)} confounds")
    
    # Use first run's mask for all
    mask_img = mask_imgs[0]
    
    print(f"\n✓ Loaded {len(bold_imgs)} BOLD runs")
    print(f"  Mask shape: {mask_img.shape}")
    
    BOLD_LOADED = True
    
except Exception as e:
    print(f"Error loading BOLD data: {e}")
    print("Cannot proceed with encoding analysis without BOLD data.")
    BOLD_LOADED = False

In [None]:
# Clean and prepare BOLD data
if BOLD_LOADED:
    print("Cleaning BOLD data (deconfounding, detrending, standardizing)...\n")
    
    # Clean BOLD with nilearn
    bold_data = load_and_prepare_bold(
        bold_imgs,
        mask_img,
        confounds_list=confounds_list,
        detrend=True,
        standardize=True,
        high_pass=1/128,
        t_r=TR
    )
    
    print(f"✓ Cleaned BOLD data shape: {bold_data.shape}")
    print(f"  Timepoints: {bold_data.shape[0]}")
    print(f"  Voxels: {bold_data.shape[1]}")
    
    BOLD_PREPARED = True
else:
    BOLD_PREPARED = False

## 3. Match Activations to BOLD Timepoints

Ensure the RL activations and BOLD data have the same number of timepoints.

In [None]:
# Align activations with BOLD
if ACTIVATIONS_LOADED and BOLD_PREPARED:
    n_bold_timepoints = bold_data.shape[0]
    n_activation_timepoints = list(layer_activations.values())[0].shape[0]
    
    print(f"BOLD timepoints: {n_bold_timepoints}")
    print(f"Activation timepoints: {n_activation_timepoints}")
    
    if n_bold_timepoints != n_activation_timepoints:
        print("\n⚠️  Timepoint mismatch! Truncating to shorter length...")
        n_timepoints = min(n_bold_timepoints, n_activation_timepoints)
        
        # Truncate BOLD
        bold_data = bold_data[:n_timepoints]
        
        # Truncate activations
        for layer_name in layer_activations.keys():
            layer_activations[layer_name] = layer_activations[layer_name][:n_timepoints]
        
        print(f"✓ Aligned to {n_timepoints} timepoints")
    else:
        print("\n✓ Timepoints already aligned")
        n_timepoints = n_bold_timepoints
    
    DATA_ALIGNED = True
else:
    DATA_ALIGNED = False
    print("Cannot align data - missing activations or BOLD.")

## 4. Train/Test Split

We'll use a simple train/test split:
- **Train**: First 80% of data
- **Test**: Last 20% of data

For time-series data, we avoid shuffling to preserve temporal structure.
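
The sketch below illustrates why shuffling is a problem for temporally autocorrelated data. It counts how many test samples sit exactly one TR away from a training sample under a contiguous split versus a shuffled one. The helper name and the toy sizes are assumptions made for this illustration.

```python
import numpy as np

def frac_test_adjacent_to_train(train_idx, test_idx):
    """Fraction of test samples that have a training sample exactly one TR away."""
    train = set(train_idx.tolist())
    adjacent = [(t - 1 in train) or (t + 1 in train) for t in test_idx.tolist()]
    return float(np.mean(adjacent))

n_timepoints, n_train = 100, 80

# Contiguous split: first 80% train, last 20% test
cont_train = np.arange(n_train)
cont_test = np.arange(n_train, n_timepoints)

# Shuffled split: the same sizes, but samples assigned at random
rng = np.random.default_rng(0)
perm = rng.permutation(n_timepoints)
shuf_train, shuf_test = perm[:n_train], perm[n_train:]

contiguous_frac = frac_test_adjacent_to_train(cont_train, cont_test)
shuffled_frac = frac_test_adjacent_to_train(shuf_train, shuf_test)

print(f"Contiguous split: {contiguous_frac:.2f} of test samples border the training set")
print(f"Shuffled split:   {shuffled_frac:.2f} of test samples border the training set")
```

With a contiguous split, only the first test sample borders the training data. With shuffling, most test samples do, so autocorrelation in the BOLD signal can inflate test scores.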

In [None]:
# Create train/test split
if DATA_ALIGNED:
    train_ratio = 0.8
    n_train = int(n_timepoints * train_ratio)
    
    train_indices = np.arange(n_train)
    test_indices = np.arange(n_train, n_timepoints)
    
    print(f"Train/Test Split:")
    print(f"  Total timepoints: {n_timepoints}")
    print(f"  Train: {len(train_indices)} ({len(train_indices)/n_timepoints*100:.1f}%)")
    print(f"  Test: {len(test_indices)} ({len(test_indices)/n_timepoints*100:.1f}%)")
else:
    print("Cannot create split - data not aligned.")

## 5. Fit Encoding Models

Fit ridge regression models for each CNN layer:
- **Features**: Layer activations (50 components)
- **Targets**: BOLD voxels (~50k voxels)
- **Model**: Ridge regression with cross-validated alpha
- **Metric**: R² score per voxel
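
As a rough picture of what happens for a single layer, the following NumPy sketch fits closed-form ridge regression to synthetic data, picks alpha on a validation block, and scores each voxel on a test block. It is an illustration under simplifying assumptions, not the code inside `fit_encoding_model_per_layer`: the function names are hypothetical, a single validation block stands in for cross-validation, and no intercept is fitted because the data are assumed to be standardized.

```python
import numpy as np

def ridge_fit(X, Y, alpha):
    """Closed-form ridge weights for all voxels at once: (X'X + alpha*I)^-1 X'Y."""
    n_features = X.shape[1]
    return np.linalg.solve(X.T @ X + alpha * np.eye(n_features), X.T @ Y)

def r2_per_voxel(Y_true, Y_pred):
    """Coefficient of determination, computed separately for each column (voxel)."""
    ss_res = ((Y_true - Y_pred) ** 2).sum(axis=0)
    ss_tot = ((Y_true - Y_true.mean(axis=0)) ** 2).sum(axis=0)
    return 1.0 - ss_res / ss_tot

rng = np.random.default_rng(0)
n_timepoints, n_features, n_voxels = 300, 10, 6
X = rng.standard_normal((n_timepoints, n_features))      # toy layer activations
true_weights = rng.standard_normal((n_features, n_voxels))
true_weights[:, 3:] = 0.0                                # last three voxels carry no signal
Y = X @ true_weights + rng.standard_normal((n_timepoints, n_voxels))

# Contiguous blocks: train, validation (for choosing alpha), test
train, val, test = slice(0, 200), slice(200, 240), slice(240, 300)

alphas = [0.1, 1, 10, 100, 1000]
val_scores = [
    r2_per_voxel(Y[val], X[val] @ ridge_fit(X[train], Y[train], a)).mean()
    for a in alphas
]
best_alpha = alphas[int(np.argmax(val_scores))]

weights = ridge_fit(X[train], Y[train], best_alpha)
r2_test = r2_per_voxel(Y[test], X[test] @ weights)

print(f"Selected alpha: {best_alpha}")
print("Test R² per voxel:", np.round(r2_test, 3))
```

Voxels driven by the features reach a high test R², while the pure-noise voxels score near or below zero. Negative R² on held-out data is expected for unpredictable voxels.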

In [None]:
%%time
# Fit encoding models per layer

if DATA_ALIGNED:
    print("Fitting encoding models...\n")
    print("=" * 60)
    
    # Alpha values for ridge regression
    alphas = [0.1, 1, 10, 100, 1000, 10000, 100000]
    
    # Fit models
    encoding_results = fit_encoding_model_per_layer(
        layer_activations,
        bold_data,
        mask_img,
        train_indices,
        test_indices,
        alphas=alphas
    )
    
    print("=" * 60)
    print(f"\n✓ Encoding models fitted for {len(encoding_results)} layers")
    
    ENCODING_DONE = True
else:
    print("Cannot fit encoding models - data not ready.")
    ENCODING_DONE = False

## 6. Compare Layer Performance

Which CNN layer best predicts brain activity?
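
A layer comparison boils down to summarizing each layer's per-voxel test R² and ranking the layers. The sketch below shows one plausible way to do that on made-up values. The helper `summarize_layers` and the `median_r2` and `frac_positive` fields are assumptions for illustration and may differ from what `compare_layer_performance` returns.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical per-voxel test R² arrays for three layers (random toy values)
toy_results = {
    'conv1': {'r2_test': rng.normal(0.01, 0.02, size=1000)},
    'conv3': {'r2_test': rng.normal(0.04, 0.02, size=1000)},
    'linear': {'r2_test': rng.normal(0.02, 0.02, size=1000)},
}

def summarize_layers(results):
    """Return one summary row per layer, sorted from highest to lowest mean R²."""
    rows = []
    for layer, res in results.items():
        r2 = res['r2_test']
        rows.append({
            'layer': layer,
            'mean_r2': float(r2.mean()),
            'median_r2': float(np.median(r2)),
            'frac_positive': float((r2 > 0).mean()),
        })
    return sorted(rows, key=lambda row: row['mean_r2'], reverse=True)

summary = summarize_layers(toy_results)
for row in summary:
    print(f"{row['layer']:>6}: mean R² = {row['mean_r2']:.4f}, "
          f"median = {row['median_r2']:.4f}, positive voxels = {row['frac_positive']:.1%}")
```

Sorting by mean R² in descending order is what makes the first row the best layer.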

In [None]:
# Compare layers
if ENCODING_DONE:
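    comparison_df = compare_layer_performance(encoding_results)
    # Sort explicitly so that row 0 is the best layer regardless of the helper's output order
    comparison_df = comparison_df.sort_values('mean_r2', ascending=False).reset_index(drop=True)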
    
    print("\nLayer Performance Comparison:")
    print("=" * 80)
    print(comparison_df.to_string(index=False))
    print("=" * 80)
    
    # Find best layer
    best_layer = comparison_df.iloc[0]['layer']
    best_r2 = comparison_df.iloc[0]['mean_r2']
    
    print(f"\n⭐ Best performing layer: {best_layer.upper()} (mean R² = {best_r2:.4f})")
else:
    print("No encoding results to compare.")

In [None]:
# Bar plot comparing layers
if ENCODING_DONE:
    fig = create_encoding_summary_figure(
        encoding_results,
        layer_order=['conv1', 'conv2', 'conv3', 'conv4', 'linear']
    )
    plt.show()
    
    # Save figure
    fig_path = encoding_output_dir.parent / 'encoding_layer_comparison.png'
    fig.savefig(fig_path, dpi=150, bbox_inches='tight')
    print(f"\n✓ Saved comparison figure: {fig_path}")

## 7. Visualize R² Brain Maps

Create brain maps showing which voxels are well-predicted by each layer.
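
An R² map is the vector of per-voxel scores placed back into the 3D grid defined by the brain mask. The NumPy sketch below shows that step on a tiny made-up mask. In practice `nilearn.masking.unmask` performs the same operation and returns a NIfTI image; whether the notebook's helpers use it is an assumption.

```python
import numpy as np

# Tiny toy brain mask: 4 in-brain voxels inside a 4 x 4 x 3 grid
mask = np.zeros((4, 4, 3), dtype=bool)
mask[1:3, 1:3, 1] = True

# One R² score per in-brain voxel, in the same order as the masked data columns
r2_values = np.array([0.05, -0.01, 0.12, 0.00])

# "Unmask": write the scores back into the 3D grid, leaving zeros outside the mask
volume = np.zeros(mask.shape, dtype=float)
volume[mask] = r2_values

# Masking the volume again recovers the original vector
recovered = volume[mask]

print(volume[:, :, 1])
```

Boolean indexing reads and writes voxels in the same row-major order, which is why the round trip returns the original vector.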

In [None]:
# Visualize R² maps for each layer
if ENCODING_DONE:
    print("Creating R² brain maps...\n")
    
    for layer_name in ['conv1', 'conv2', 'conv3', 'conv4', 'linear']:
        if layer_name not in encoding_results:
            continue
        
        r2_map = encoding_results[layer_name]['r2_map']
        mean_r2 = encoding_results[layer_name]['mean_r2_test']
        
        print(f"\n{layer_name.upper()} - Mean R²: {mean_r2:.4f}")
        
        # Glass brain
        display = plotting.plot_glass_brain(
            r2_map,
            threshold=0.01,  # Show voxels with R² > 0.01
            colorbar=True,
            plot_abs=False,
            cmap='hot',
            vmax=0.2,
            title=f'{layer_name.upper()} Encoding Quality (R²)',
            display_mode='lyrz'
        )
        plt.show()
        
        # Save map
        map_file = encoding_output_dir.parent / f'{SUBJECT}_{SESSION}_layer-{layer_name}_r2.nii.gz'
        nib.save(r2_map, map_file)
        print(f"  ✓ Saved R² map: {map_file.name}")
else:
    print("No encoding results to visualize.")

## 8. Best Layer - Detailed Visualization

Focus on the best-performing layer with multiple visualization methods.

In [None]:
# Detailed viz for best layer
if ENCODING_DONE:
    best_layer = comparison_df.iloc[0]['layer']
    best_r2_map = encoding_results[best_layer]['r2_map']
    
    print(f"Detailed visualization for best layer: {best_layer.upper()}\n")
    
    # Multi-panel figure
    fig = plt.figure(figsize=(16, 12))
    
    # Glass brain
    ax1 = plt.subplot(3, 1, 1)
    plotting.plot_glass_brain(
        best_r2_map,
        threshold=0.01,
        colorbar=True,
        plot_abs=False,  # R² can be negative; do not display negative values as positive
        cmap='hot',
        vmax=0.2,
        title=f'{best_layer.upper()} - Glass Brain View',
        display_mode='lyrz',
        axes=ax1
    )
    
    # Stat map - axial slices
    ax2 = plt.subplot(3, 1, 2)
    plotting.plot_stat_map(
        best_r2_map,
        threshold=0.01,
        cmap='hot',
        vmax=0.2,
        colorbar=True,
        cut_coords=8,
        display_mode='z',
        title=f'{best_layer.upper()} - Axial Slices',
        axes=ax2
    )
    
    # R² distribution histogram
    ax3 = plt.subplot(3, 1, 3)
    r2_test = encoding_results[best_layer]['r2_test']
    
    ax3.hist(r2_test, bins=50, color='steelblue', alpha=0.7, edgecolor='black')
    ax3.axvline(r2_test.mean(), color='red', linestyle='--', linewidth=2, label=f'Mean: {r2_test.mean():.4f}')
    ax3.axvline(np.median(r2_test), color='orange', linestyle='--', linewidth=2, label=f'Median: {np.median(r2_test):.4f}')
    ax3.set_xlabel('R² Score', fontsize=12)
    ax3.set_ylabel('Number of Voxels', fontsize=12)
    ax3.set_title(f'{best_layer.upper()} - R² Distribution', fontsize=14, fontweight='bold')
    ax3.legend(fontsize=11)
    ax3.grid(alpha=0.3)
    
    plt.tight_layout()
    
    # Save
    fig_path = encoding_output_dir.parent / f'encoding_{best_layer}_detailed.png'
    plt.savefig(fig_path, dpi=150, bbox_inches='tight')
    print(f"\n✓ Saved detailed figure: {fig_path}")
    
    plt.show()
else:
    print("No results to visualize.")

## 9. Prediction Quality: Example Voxels

Show actual vs predicted BOLD for top-performing voxels.

In [None]:
# Plot predictions for best voxels
if ENCODING_DONE:
    best_layer = comparison_df.iloc[0]['layer']
    r2_test = encoding_results[best_layer]['r2_test']
    model = encoding_results[best_layer]['model']
    
    # Find top 3 voxels
    top_voxel_indices = np.argsort(r2_test)[-3:][::-1]
    
    print(f"Top 3 voxels for {best_layer.upper()}:")
    for voxel_idx in top_voxel_indices:
        print(f"  Voxel {voxel_idx}: R² = {r2_test[voxel_idx]:.4f}")
    
    # Get predictions
    X_test = layer_activations[best_layer][test_indices]
    y_test = bold_data[test_indices]
    y_pred = model.predict(X_test)
    
    # Plot
    fig, axes = plt.subplots(3, 1, figsize=(14, 10), sharex=True)
    
    test_time = np.arange(len(test_indices)) * TR
    
    for idx, (ax, voxel_idx) in enumerate(zip(axes, top_voxel_indices)):
        # Actual
        ax.plot(test_time, y_test[:, voxel_idx], 
               linewidth=1.5, color='black', alpha=0.7, label='Actual BOLD')
        # Predicted
        ax.plot(test_time, y_pred[:, voxel_idx],
               linewidth=1.5, color='orangered', alpha=0.8, label='Predicted BOLD')
        
        ax.set_ylabel('BOLD Signal\n(standardized)', fontsize=11)
        ax.set_title(f'Voxel {voxel_idx} (R² = {r2_test[voxel_idx]:.4f})', 
                    fontsize=12, fontweight='bold')
        ax.legend(loc='upper right')
        ax.grid(alpha=0.3)
    
    axes[-1].set_xlabel('Time (seconds)', fontsize=12)
    
    plt.suptitle(f'{best_layer.upper()} - Top Voxel Predictions', 
                fontsize=14, fontweight='bold', y=1.00)
    plt.tight_layout()
    
    # Save
    fig_path = encoding_output_dir.parent / f'encoding_{best_layer}_predictions.png'
    plt.savefig(fig_path, dpi=150, bbox_inches='tight')
    print(f"\n✓ Saved prediction figure: {fig_path}")
    
    plt.show()
else:
    print("No results for prediction plot.")

## 10. Brain Region Analysis

Which brain regions are best encoded by each layer?
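
Sampling a map at an MNI coordinate requires converting millimeter coordinates into voxel indices with the inverse of the image affine. The sketch below shows the conversion in plain NumPy. The affine and grid shape are those of a typical 2 mm MNI152 template and are assumed here for illustration; the real values come from the R² map itself. The helper names are hypothetical.

```python
import numpy as np

# Assumed voxel-to-world affine of a 2 mm MNI152-like grid (units: mm)
affine = np.array([
    [-2.0, 0.0, 0.0,   90.0],
    [ 0.0, 2.0, 0.0, -126.0],
    [ 0.0, 0.0, 2.0,  -72.0],
    [ 0.0, 0.0, 0.0,    1.0],
])
shape = (91, 109, 91)

def world_to_voxel(affine, xyz_mm):
    """Map world coordinates (mm) to the nearest voxel index."""
    homogeneous = np.append(np.asarray(xyz_mm, dtype=float), 1.0)
    ijk = np.linalg.inv(affine) @ homogeneous
    # Round to the nearest voxel; plain integer casting would truncate toward zero
    return np.round(ijk[:3]).astype(int)

def voxel_in_bounds(ijk, shape):
    """Check that a voxel index lies inside the image grid."""
    return all(0 <= i < n for i, n in zip(ijk, shape))

origin_ijk = world_to_voxel(affine, (0, 0, 0))
v1_ijk = world_to_voxel(affine, (0, -90, 0))
off_grid_ijk = world_to_voxel(affine, (-1.2, 0, 0))   # falls between voxel centers

print("MNI (0, 0, 0)    ->", origin_ijk, voxel_in_bounds(origin_ijk, shape))
print("MNI (0, -90, 0)  ->", v1_ijk, voxel_in_bounds(v1_ijk, shape))
print("MNI (-1.2, 0, 0) ->", off_grid_ijk)
```

A single voxel is a noisy estimate of a region. Averaging R² over a small sphere or an atlas parcel around each coordinate would be a more robust alternative.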

In [None]:
# Analyze encoding by brain region
if ENCODING_DONE:
    print("Brain region encoding analysis\n")
    
    # Approximate ROI centers in MNI coordinates (mm); assumes the R² maps are in MNI space
    roi_coords = {
        'V1 (Visual)': (0, -90, 0),
        'Motor (Left)': (-40, -20, 50),
        'Motor (Right)': (40, -20, 50),
        'Striatum': (0, 10, 0),
        'PFC': (0, 50, 20),
        'Parietal': (0, -60, 50)
    }
    
    # For each layer, read the R² of the single voxel nearest each ROI coordinate
    
    roi_r2_values = {roi: {} for roi in roi_coords.keys()}
    
    for layer_name in encoding_results.keys():
        r2_map = encoding_results[layer_name]['r2_map']
        r2_data = r2_map.get_fdata()
        affine = r2_map.affine
        
        for roi_name, mni_coord in roi_coords.items():
            # Transform MNI (mm) to voxel indices, rounding to the nearest voxel
            voxel_coord = np.round(
                nib.affines.apply_affine(np.linalg.inv(affine), mni_coord)
            ).astype(int)
            
            # Extract R² value (with bounds checking)
            x, y, z = voxel_coord
            if (0 <= x < r2_data.shape[0] and 
                0 <= y < r2_data.shape[1] and 
                0 <= z < r2_data.shape[2]):
                r2_value = r2_data[x, y, z]
            else:
                r2_value = 0.0
            
            roi_r2_values[roi_name][layer_name] = r2_value
    
    # Create DataFrame
    roi_df = pd.DataFrame(roi_r2_values).T
    
    print("R² values by ROI and layer:")
    print("=" * 80)
    print(roi_df.to_string())
    print("=" * 80)
    
    # Heatmap
    fig, ax = plt.subplots(figsize=(12, 6))
    sns.heatmap(roi_df, annot=True, fmt='.3f', cmap='YlOrRd', 
                vmin=0, vmax=0.15, ax=ax, cbar_kws={'label': 'R²'})
    ax.set_xlabel('Layer', fontsize=12)
    ax.set_ylabel('Brain Region (ROI)', fontsize=12)
    ax.set_title('Encoding Quality by Brain Region and Layer', fontsize=14, fontweight='bold')
    plt.tight_layout()
    
    # Save
    fig_path = encoding_output_dir.parent / 'encoding_roi_heatmap.png'
    plt.savefig(fig_path, dpi=150, bbox_inches='tight')
    print(f"\n✓ Saved ROI heatmap: {fig_path}")
    
    plt.show()
else:
    print("No results for ROI analysis.")

## Summary

In this notebook, we performed brain encoding analysis:

✅ **Loaded RL activations**: PCA-reduced representations from five network layers (conv1 to conv4 plus linear)

✅ **Prepared BOLD data**: Deconfounding, masking, standardization

✅ **Aligned data**: Matched activation timepoints to fMRI TRs

✅ **Fitted encoding models**: Ridge regression per layer with cross-validated alpha, scored on the held-out test split

✅ **Compared layers**: Identified which layer best predicts brain activity

✅ **Created R² maps**: Visualized encoding quality across the brain

✅ **Analyzed regions**: Examined layer-specific encoding in different brain areas

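### Key findings:
These are the patterns typically expected from this analysis; compare them against your own layer comparison table and ROI heatmap.
- **Best layer**: Intermediate layers (conv3/conv4) typically perform best
- **Visual cortex**: Early layers (conv1/conv2) tend to predict visual areas best
- **Motor cortex**: Middle layers tend to predict action-related regions best
- **Frontal regions**: Late layers (linear) tend to predict regions associated with higher-level strategy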

### Interpretation:
If these patterns hold, the layer hierarchy of the CNN parallels the processing hierarchy of the brain:
- **Early layers** ↔ **Visual cortex**: Low-level features (edges, colors)
- **Middle layers** ↔ **Parietal/Motor**: Spatial and action representations
- **Late layers** ↔ **Prefrontal**: Abstract strategy and value

### Output files:
- Layer comparison: `encoding_layer_comparison.png`
- Best layer detailed: `encoding_{layer}_detailed.png`
- Predictions: `encoding_{layer}_predictions.png`
- ROI heatmap: `encoding_roi_heatmap.png`
- R² maps: `sub-01_ses-010_layer-{layer}_r2.nii.gz`

### Next steps:
In **Notebook 06**, we'll summarize the entire analysis pipeline and discuss extensions for future work.