# EEG Motor Imagery Model Comparison

This notebook compares different deep learning models (EEGNet, ATCNet, etc.) for EEG motor imagery classification using Leave-One-Subject-Out (LOSO) cross-validation.

**Sections:**
1. Setup & Configuration
2. Preprocessed Data Visualization
3. Training Curves (Averaged Over Folds)
4. Detailed Test Performance
5. Model Comparison Dashboard


In [None]:
import json
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay
import seaborn as sns

# Notebook-wide matplotlib look: apply the base style sheet first, then
# override individual rc entries in a single update.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({
    'figure.figsize': (12, 6),
    'font.family': 'sans-serif',
    'axes.titlesize': 14,
    'axes.labelsize': 12,
})


## 1. Setup & Configuration

Define experiment directories for each model. To add a new model, simply add an entry to `EXPERIMENTS`.


In [None]:
# ============================================================
# CONFIGURATION - Edit these paths to match your experiments
# ============================================================

# Display name -> directory holding that model's LOSO outputs.
EXPERIMENTS: Dict[str, str] = {
    "EEGNet": "../reports/loso/eegnet_experiment",
    "ATCNet": "../reports/loso/atcnet_experiment",
    # Add future models here:
    # "ShallowConvNet": "../reports/loso/shallowconv_experiment",
    # "DeepConvNet": "../reports/loso/deepconv_experiment",
}

# Location of the preprocessed per-subject .npz files.
DATA_DIR = Path("..") / "data" / "processed"

# Subject identifiers A01T .. A09T (one LOSO fold per subject).
SUBJECTS = [f"A{index:02d}T" for index in range(1, 10)]

# Human-readable class labels used in plots and reports.
CLASS_NAMES = ["Left Hand", "Right Hand"]

# Aesthetic configuration: one fixed color per known model name.
MODEL_COLORS = dict(
    EEGNet="#2E86AB",          # Steel blue
    ATCNet="#A23B72",          # Raspberry
    ShallowConvNet="#F18F01",  # Orange
    DeepConvNet="#C73E1D",     # Red
)

def get_model_color(model_name: str) -> str:
    """Return the plot color registered for ``model_name``.

    Models without an entry in ``MODEL_COLORS`` get a neutral gray.
    """
    fallback = "#666666"
    if model_name in MODEL_COLORS:
        return MODEL_COLORS[model_name]
    return fallback


In [None]:
# ============================================================
# HELPER FUNCTIONS
# ============================================================

def load_fold_histories(experiment_dir: Path) -> Dict[str, List[Dict]]:
    """Collect the training history of every fold present in an experiment.

    For each entry of ``SUBJECTS`` the file
    ``<experiment_dir>/fold_<subject>/history.json`` is parsed if it exists;
    folds without that file are simply left out of the result.

    Returns:
        Dict mapping subject ID to the parsed history (list of epoch records).
    """
    histories = {}
    for subject in SUBJECTS:
        history_path = experiment_dir / f"fold_{subject}" / "history.json"
        if not history_path.exists():
            continue
        with history_path.open() as handle:
            histories[subject] = json.load(handle)
    return histories


def load_fold_predictions(experiment_dir: Path) -> Dict[str, Dict[str, np.ndarray]]:
    """Load test predictions from all folds in an experiment.

    For each entry of ``SUBJECTS`` the archive
    ``<experiment_dir>/fold_<subject>/test_predictions.npz`` is read if it
    exists; folds without that file are left out of the result.

    Returns:
        Dict mapping subject ID to dict with 'y_pred' and 'y_true' arrays.

    Raises:
        KeyError: If an archive exists but lacks 'y_pred' or 'y_true'.
    """
    predictions = {}
    for subject in SUBJECTS:
        pred_path = experiment_dir / f"fold_{subject}" / "test_predictions.npz"
        if not pred_path.exists():
            continue

        # NOTE(review): allow_pickle=True will unpickle object arrays, which can
        # execute arbitrary code. Only point this at archives you produced
        # yourself; switch to allow_pickle=False if the arrays are plain numeric.
        #
        # The context manager closes the archive's file handle once both arrays
        # have been read into memory (np.load on .npz keeps the file open).
        with np.load(pred_path, allow_pickle=True) as data:
            predictions[subject] = {
                "y_pred": data["y_pred"],
                "y_true": data["y_true"],
            }
    return predictions


def load_summary(experiment_dir: Path) -> Optional[Dict]:
    """Read ``loso_summary.json`` from an experiment directory.

    Returns:
        The parsed JSON content, or ``None`` when the file does not exist.
    """
    summary_path = experiment_dir / "loso_summary.json"
    if not summary_path.exists():
        return None
    with summary_path.open() as handle:
        return json.load(handle)


def aggregate_histories(
    histories: Dict[str, List[Dict]], 
    metric: str,
    pad_strategy: str = "final"
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Aggregate a metric across all folds, handling variable epoch lengths.

    Folds may have run for different numbers of epochs. Shorter folds are
    extended to the longest fold's length before averaging. Folds with no
    recorded epochs are ignored.

    Args:
        histories: Dict mapping subject to list of epoch dicts
        metric: Key to extract (e.g., 'train_loss', 'val_acc')
        pad_strategy: 'final' pads with the fold's last value; any other
            value (documented: 'nan') pads with NaN and uses NaN-aware
            statistics, so only folds that reached an epoch contribute to it.

    Returns:
        epochs: Array of epoch numbers [1, 2, ..., max_epochs]
        mean: Mean value per epoch
        std: Std value per epoch
        All three are empty arrays when there is nothing to aggregate.

    Raises:
        KeyError: If an epoch record does not contain ``metric``.
    """
    # Drop folds without any epoch record: they carry no information, and the
    # 'final' strategy would otherwise fail looking up their last value.
    series_list = [
        [epoch_dict[metric] for epoch_dict in history]
        for history in histories.values()
        if history
    ]
    if not series_list:
        return np.array([]), np.array([]), np.array([])

    max_len = max(len(series) for series in series_list)
    hold_last = pad_strategy == "final"

    # (n_folds, max_epochs) matrix, NaN wherever a fold has no value.
    arr = np.full((len(series_list), max_len), np.nan, dtype=float)
    for row, series in enumerate(series_list):
        arr[row, :len(series)] = series
        if hold_last:
            # Carry the fold's final value forward over the missing epochs.
            arr[row, len(series):] = series[-1]

    epochs = np.arange(1, max_len + 1)

    if hold_last:
        return epochs, arr.mean(axis=0), arr.std(axis=0)
    # Every column has at least one real value (the longest fold), so the
    # NaN-aware reductions never see an all-NaN slice.
    return epochs, np.nanmean(arr, axis=0), np.nanstd(arr, axis=0)


def compute_aggregate_metrics(predictions: Dict[str, Dict]) -> Dict:
    """Summarize test performance per fold and pooled over all folds.

    Args:
        predictions: Dict mapping subject ID to a dict holding the 'y_true'
            and 'y_pred' arrays of that fold.

    Returns:
        Dict with keys 'fold_accuracies' (subject -> accuracy),
        'overall_accuracy', 'confusion_matrix' and 'classification_report'
        (the latter as a nested dict), the last three computed on the labels
        of all folds pooled together.
    """
    # Accuracy of each fold on its own held-out subject.
    fold_accuracies = {
        subject: (fold["y_true"] == fold["y_pred"]).mean()
        for subject, fold in predictions.items()
    }

    # Pool the labels of every fold, preserving fold order.
    pooled_true = np.array(
        [label for fold in predictions.values() for label in fold["y_true"]]
    )
    pooled_pred = np.array(
        [label for fold in predictions.values() for label in fold["y_pred"]]
    )

    overall_accuracy = (pooled_true == pooled_pred).mean()
    pooled_cm = confusion_matrix(pooled_true, pooled_pred)
    report = classification_report(
        pooled_true, pooled_pred, target_names=CLASS_NAMES, output_dict=True
    )

    return {
        "fold_accuracies": fold_accuracies,
        "overall_accuracy": overall_accuracy,
        "confusion_matrix": pooled_cm,
        "classification_report": report,
    }


In [None]:
# Validate experiment directories: keep only those that exist and have a
# readable, non-empty LOSO summary.
print("Checking experiment directories...\n")
valid_experiments = {}

for name, path in EXPERIMENTS.items():
    exp_path = Path(path)

    if not exp_path.exists():
        print(f"✗ {name}: Directory not found ({path})")
        continue

    summary = load_summary(exp_path)
    if not summary:
        print(f"⚠ {name}: Directory exists but no summary found")
        continue

    n_folds = summary.get("n_folds", "?")
    mean_acc = summary.get("mean_test_acc", 0) * 100
    std_acc = summary.get("std_test_acc", 0) * 100
    print(f"✓ {name}: {n_folds} folds, {mean_acc:.1f}% ± {std_acc:.1f}%")
    valid_experiments[name] = exp_path

if valid_experiments:
    print(f"\n✓ Found {len(valid_experiments)} valid experiment(s)")
else:
    print("\n❌ No valid experiments found! Run training first:")
    print("   python -m src.train --config configs/eegnet.yaml --output-dir reports/loso/eegnet_experiment")
    print("   python -m src.train --config configs/atcnet.yaml --output-dir reports/loso/atcnet_experiment")


---
## 2. Preprocessed Data Visualization

Visualize the preprocessed EEG data to understand what the models are learning from.


In [None]:
# Load sample subject data
sample_subject = "A01T"
sample_path = DATA_DIR / f"{sample_subject}.npz"

if sample_path.exists():
    # NOTE(review): allow_pickle=True will unpickle object arrays, which can
    # execute arbitrary code; only load archives from a trusted source.
    # The context manager closes the archive's file handle once both arrays
    # have been read into memory.
    with np.load(sample_path, allow_pickle=True) as data:
        X = data["X"]  # (trials, channels, time)
        y = data["y"]
    
    n_trials, n_channels, n_samples = X.shape
    sfreq = 250  # BCI Competition IV 2a sampling frequency
    
    print(f"Subject: {sample_subject}")
    print(f"Shape: {X.shape} (trials × channels × samples)")
    print(f"Sampling frequency: {sfreq} Hz")
    print(f"Trial duration: {n_samples / sfreq:.2f} seconds")
    print(f"Class distribution: {dict(zip(*np.unique(y, return_counts=True)))}")
else:
    print(f"❌ Sample data not found at {sample_path}")
    # Later cells check `X is not None` before plotting.
    X, y = None, None


In [None]:
if X is not None:
    # Select one trial per class
    trial_left = X[np.where(y == 0)[0][0]]  # First left-hand trial
    trial_right = X[np.where(y == 1)[0][0]]  # First right-hand trial
    
    time = np.arange(n_samples) / sfreq
    
    # Channel names for BCI Competition IV 2a (22 EEG channels)
    channel_names = [
        'Fz', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'Cz', 'C2',
        'C4', 'C6', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'P1', 'Pz', 'P2', 'POz'
    ]
    # One label per plotted channel; fall back to a generic name if the data
    # has more channels than the list above so tick and label counts match.
    tick_labels = [
        channel_names[i] if i < len(channel_names) else f'Ch{i + 1}'
        for i in range(n_channels)
    ]
    
    fig, axes = plt.subplots(1, 2, figsize=(16, 10))
    
    for ax, trial, label in zip(axes, [trial_left, trial_right], CLASS_NAMES):
        # Plot each channel with offset for visibility
        offset = 0
        baselines = []  # vertical position each channel is actually drawn at
        for i in range(n_channels):
            ch_data = trial[i] - trial[i].mean()  # Center each channel
            baselines.append(offset)
            ax.plot(time, ch_data + offset, linewidth=0.5, color='#1a1a2e')
            # Spacing depends on this channel's amplitude, so baselines are
            # not evenly spaced.
            offset += np.abs(ch_data).max() * 2.5
        
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Channel')
        ax.set_title(f'{label} Motor Imagery', fontweight='bold')
        # Put each channel label on that channel's own baseline rather than on
        # an evenly spaced grid, which would drift away from the traces.
        ax.set_yticks(baselines)
        ax.set_yticklabels(tick_labels, fontsize=8)
        ax.set_xlim(0, time[-1])
    
    plt.suptitle(f'EEG Time Series - Subject {sample_subject}', fontsize=16, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.show()


In [None]:
if X is not None:
    # Compute and plot spectrograms for motor-relevant channels (C3, Cz, C4)
    motor_channels = {'C3': 7, 'Cz': 9, 'C4': 11}  # Indices for motor cortex channels
    
    # First pass: compute every panel so that a common color range can be
    # chosen before anything is drawn.
    panels = {}
    for row, (trial, label) in enumerate([(trial_left, "Left Hand"), (trial_right, "Right Hand")]):
        for col, (ch_name, ch_idx) in enumerate(motor_channels.items()):
            # Compute spectrogram
            f, t, Sxx = signal.spectrogram(
                trial[ch_idx], fs=sfreq, nperseg=64, noverlap=56
            )
            
            # Keep only 4-40 Hz (relevant for motor imagery)
            freq_mask = (f >= 4) & (f <= 40)
            # Small constant avoids log10(0) for empty bins.
            power_db = 10 * np.log10(Sxx[freq_mask] + 1e-10)
            panels[row, col] = (t, f[freq_mask], power_db, f'{ch_name} - {label}')
    
    # Shared color limits: the single colorbar below is then valid for all
    # panels instead of describing only the last one drawn.
    vmin = min(panel[2].min() for panel in panels.values())
    vmax = max(panel[2].max() for panel in panels.values())
    
    # Constrained layout reserves room for a colorbar spanning several axes,
    # which tight_layout does not handle.
    fig, axes = plt.subplots(2, 3, figsize=(15, 8), layout="constrained")
    
    for (row, col), (t, freqs, power_db, title) in panels.items():
        ax = axes[row, col]
        im = ax.pcolormesh(
            t, freqs, power_db,
            shading='gouraud', cmap='viridis', vmin=vmin, vmax=vmax
        )
        
        ax.set_ylabel('Frequency (Hz)' if col == 0 else '')
        ax.set_xlabel('Time (s)' if row == 1 else '')
        ax.set_title(title, fontweight='bold')
        
        # Highlight mu (8-12 Hz) and beta (18-26 Hz) bands
        for lo, hi in [(8, 12), (18, 26)]:
            ax.axhline(lo, color='white', linestyle='--', alpha=0.5, linewidth=0.5)
            ax.axhline(hi, color='white', linestyle='--', alpha=0.5, linewidth=0.5)
    
    fig.colorbar(im, ax=axes, label='Power (dB)', shrink=0.8)
    fig.suptitle(f'Spectrograms - Motor Cortex Channels ({sample_subject})', fontsize=14, fontweight='bold')
    plt.show()


---
## 3. Training Curves (Averaged Over Folds)

Visualize training dynamics across all 9 LOSO folds.


In [None]:
def plot_training_curves(experiments: Dict[str, Path], metric_pairs: List[Tuple[str, str]]):
    """Draw a grid of fold-averaged training curves.

    The grid has one column per experiment and one row per metric. Each panel
    shows the mean across folds with a ±1 standard deviation band.

    Args:
        experiments: Dict mapping model name to experiment path
        metric_pairs: List of (metric_name, display_name) tuples
    """
    if not experiments:
        print("No experiments to plot.")
        return

    n_cols, n_rows = len(experiments), len(metric_pairs)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 4 * n_rows), squeeze=False)

    for col, (model_name, exp_path) in enumerate(experiments.items()):
        histories = load_fold_histories(exp_path)
        color = get_model_color(model_name)
        column_axes = axes[:, col]

        # Without any history the whole column is just marked as empty.
        if not histories:
            for ax in column_axes:
                ax.text(0.5, 0.5, "No data", ha='center', va='center', transform=ax.transAxes)
            continue

        n_folds = len(histories)
        for ax, (metric, display_name) in zip(column_axes, metric_pairs):
            epochs, mean, std = aggregate_histories(histories, metric)
            lower, upper = mean - std, mean + std

            ax.plot(epochs, mean, color=color, linewidth=2, label=f'Mean ({n_folds} folds)')
            ax.fill_between(epochs, lower, upper, color=color, alpha=0.2, label='±1 Std')

            ax.set_xlabel('Epoch')
            ax.set_ylabel(display_name)
            ax.set_title(f'{model_name} - {display_name}', fontweight='bold')
            ax.legend(loc='best', fontsize=9)
            ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


In [None]:
# Plot loss and accuracy curves for all models
if not valid_experiments:
    print("No valid experiments found.")
else:
    # One grid row per (history key, axis label) pair.
    plot_training_curves(
        valid_experiments,
        metric_pairs=[
            ("train_loss", "Training Loss"),
            ("val_loss", "Validation Loss"),
            ("train_acc", "Training Accuracy"),
            ("val_acc", "Validation Accuracy"),
        ],
    )


---
## 4. Detailed Test Performance

Per-fold accuracy, confusion matrices, and classification metrics.


In [None]:
def plot_fold_accuracies(experiments: Dict[str, Path]):
    """Draw one bar chart per model with the test accuracy of each fold.

    Each panel shows one bar per held-out subject (in percent), a dashed line
    at the mean over folds, and the numeric value above every bar.
    """
    if not experiments:
        print("No experiments to plot.")
        return

    n_panels = len(experiments)
    fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 5), squeeze=False)

    for ax, (model_name, exp_path) in zip(axes[0], experiments.items()):
        predictions = load_fold_predictions(exp_path)
        if not predictions:
            ax.text(0.5, 0.5, "No predictions found", ha='center', va='center', transform=ax.transAxes)
            continue

        fold_accs = compute_aggregate_metrics(predictions)["fold_accuracies"]
        subjects = list(fold_accs)
        accs = [fold_accs[subject] * 100 for subject in subjects]
        mean_acc = np.mean(accs)

        bars = ax.bar(
            subjects, accs,
            color=get_model_color(model_name), alpha=0.8, edgecolor='black', linewidth=0.5,
        )
        ax.axhline(mean_acc, color='#333', linestyle='--', linewidth=2, label=f'Mean: {mean_acc:.1f}%')

        # Value label centered just above each bar.
        for bar, acc in zip(bars, accs):
            label_x = bar.get_x() + bar.get_width()/2
            label_y = bar.get_height() + 1
            ax.text(label_x, label_y, f'{acc:.1f}', ha='center', va='bottom', fontsize=9)

        ax.set_xlabel('Held-out Subject')
        ax.set_ylabel('Test Accuracy (%)')
        ax.set_title(f'{model_name} - Per-Fold Accuracy', fontweight='bold')
        ax.set_ylim(0, 110)  # headroom above 100 for the value labels
        ax.legend(loc='lower right')
        ax.tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.show()


In [None]:
# Per-fold accuracy bars, only when at least one experiment was validated.
if not valid_experiments:
    print("No valid experiments found.")
else:
    plot_fold_accuracies(valid_experiments)


In [None]:
def plot_confusion_matrices(experiments: Dict[str, Path]):
    """Plot aggregated confusion matrices for each model.

    One heatmap per model, built from the predictions of all folds pooled
    together. Cells show the row-normalized percentage (share of each true
    class) with the raw count underneath; the panel title carries the pooled
    accuracy.

    Args:
        experiments: Dict mapping model name to experiment path
    """
    n_models = len(experiments)
    if n_models == 0:
        print("No experiments to plot.")
        return
    
    fig, axes = plt.subplots(1, n_models, figsize=(5 * n_models, 4), squeeze=False)
    
    for idx, (model_name, exp_path) in enumerate(experiments.items()):
        ax = axes[0, idx]
        predictions = load_fold_predictions(exp_path)
        
        if not predictions:
            ax.text(0.5, 0.5, "No predictions found", ha='center', va='center', transform=ax.transAxes)
            continue
        
        metrics = compute_aggregate_metrics(predictions)
        cm = metrics["confusion_matrix"]
        
        # Normalize each row to percentages. Rows without any true samples
        # stay at 0 instead of becoming NaN through a division by zero.
        row_totals = cm.sum(axis=1, keepdims=True)
        cm_normalized = np.divide(
            cm.astype('float'), row_totals,
            out=np.zeros(cm.shape, dtype=float), where=row_totals > 0
        ) * 100
        
        sns.heatmap(
            cm_normalized, annot=True, fmt='.1f', cmap='Blues',
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
            ax=ax, cbar_kws={'label': '%'}
        )
        
        # Add raw counts in smaller text, for every cell of the matrix
        # (not just a fixed 2x2 corner).
        n_rows, n_cols = cm.shape
        for i in range(n_rows):
            for j in range(n_cols):
                ax.text(j + 0.5, i + 0.75, f'(n={cm[i, j]})', 
                       ha='center', va='center', fontsize=8, color='gray')
        
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title(f'{model_name}\nAcc: {metrics["overall_accuracy"]*100:.1f}%', fontweight='bold')
    
    plt.suptitle('Confusion Matrices (All Folds Aggregated)', fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.show()


In [None]:
# Pooled confusion matrices, only when at least one experiment was validated.
if not valid_experiments:
    print("No valid experiments found.")
else:
    plot_confusion_matrices(valid_experiments)


In [None]:
def print_classification_reports(experiments: Dict[str, Path]):
    """Print a per-class precision/recall/F1 table for each model.

    The table is built from the predictions of all folds pooled together and
    ends with the overall accuracy and the macro average. Models without
    stored predictions get a one-line notice instead.
    """
    rule = "-" * 55

    for model_name, exp_path in experiments.items():
        predictions = load_fold_predictions(exp_path)
        if not predictions:
            print(f"\n{model_name}: No predictions found")
            continue

        report = compute_aggregate_metrics(predictions)["classification_report"]

        # Banner and column header.
        print(f"\n{'='*50}")
        print(f" {model_name} - Classification Report")
        print(f"{'='*50}")
        print(f"{'Class':<15} {'Precision':>10} {'Recall':>10} {'F1-Score':>10} {'Support':>10}")
        print(rule)

        # One row per class.
        for class_name in CLASS_NAMES:
            stats = report[class_name]
            precision, recall, f1 = stats['precision'], stats['recall'], stats['f1-score']
            support = int(stats['support'])
            print(f"{class_name:<15} {precision:>10.3f} {recall:>10.3f} {f1:>10.3f} {support:>10}")

        # Footer: accuracy (with total support) and macro average.
        print(rule)
        macro = report['macro avg']
        total_support = int(report['weighted avg']['support'])
        print(f"{'Accuracy':<15} {'':<10} {'':<10} {report['accuracy']:>10.3f} {total_support:>10}")
        print(f"{'Macro Avg':<15} {macro['precision']:>10.3f} {macro['recall']:>10.3f} {macro['f1-score']:>10.3f}")


In [None]:
# Text reports, only when at least one experiment was validated.
if not valid_experiments:
    print("No valid experiments found.")
else:
    print_classification_reports(valid_experiments)


---
## 5. Model Comparison Dashboard

Direct side-by-side comparison of all models.


In [None]:
def plot_comparison_bar_chart(experiments: Dict[str, Path]):
    """Create grouped bar chart comparing per-fold accuracy across models.

    Subjects are laid out along the x-axis with one bar per model in each
    group. A model that has no predictions for a subject gets no bar there,
    so a missing fold cannot be mistaken for 0 % accuracy.

    Args:
        experiments: Dict mapping model name to experiment path
    """
    if len(experiments) < 2:
        print("Need at least 2 experiments for comparison.")
        return
    
    # Collect per-fold accuracies and the union of subjects seen by any model
    model_data = {}
    all_subjects = set()
    
    for model_name, exp_path in experiments.items():
        predictions = load_fold_predictions(exp_path)
        if predictions:
            fold_accs = compute_aggregate_metrics(predictions)["fold_accuracies"]
            model_data[model_name] = fold_accs
            all_subjects.update(fold_accs)
    
    if not model_data:
        print("No prediction data found.")
        return
    
    subjects = sorted(all_subjects)
    n_models = len(model_data)
    
    # Create grouped bar chart
    fig, ax = plt.subplots(figsize=(14, 6))
    
    bar_width = 0.8 / n_models  # each subject group spans 0.8 of a unit
    x = np.arange(len(subjects))
    
    for i, (model_name, fold_accs) in enumerate(model_data.items()):
        # NaN (not 0) for folds this model did not run: matplotlib then draws
        # no bar instead of a misleading zero-accuracy one.
        accs = [fold_accs[s] * 100 if s in fold_accs else np.nan for s in subjects]
        # Center the group of bars on the subject's tick.
        offset = (i - n_models / 2 + 0.5) * bar_width
        color = get_model_color(model_name)
        
        ax.bar(x + offset, accs, bar_width, label=model_name, color=color, 
               alpha=0.85, edgecolor='black', linewidth=0.5)
    
    ax.set_xlabel('Held-out Subject', fontsize=12)
    ax.set_ylabel('Test Accuracy (%)', fontsize=12)
    ax.set_title('Per-Subject Accuracy Comparison', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(subjects)
    ax.set_ylim(0, 100)
    ax.legend(loc='lower right', fontsize=10)
    ax.grid(axis='y', alpha=0.3)
    
    plt.tight_layout()
    plt.show()


In [None]:
# The grouped comparison needs at least two validated experiments.
n_valid = len(valid_experiments)
if n_valid == 0:
    print("No valid experiments found.")
elif n_valid == 1:
    print("Only one experiment found. Run training for another model to enable comparison.")
else:
    plot_comparison_bar_chart(valid_experiments)


In [None]:
def plot_aggregate_summary_table(experiments: Dict[str, Path]):
    """Display summary statistics table for all models.

    Prints one row per model that has a LOSO summary: fold count, mean/std/
    min/max test accuracy taken from the summary file, and the macro F1
    computed from the stored predictions. Values that are not available are
    shown as "N/A". Models without a summary are left out.

    Args:
        experiments: Dict mapping model name to experiment path
    """
    if not experiments:
        print("No experiments found.")
        return

    def _percent(summary: Dict, key: str, prefix: str = "") -> str:
        """Format a 0-1 fraction from the summary as a percentage string."""
        value = summary.get(key)
        if value is None:
            # A missing statistic must not masquerade as 0.00%.
            return "N/A"
        return f"{prefix}{value * 100:.2f}%"
    
    rows = []
    
    for model_name, exp_path in experiments.items():
        summary = load_summary(exp_path)
        if not summary:
            continue
        
        row = {
            "Model": model_name,
            # str() so that any JSON value (including null) can be aligned.
            "Folds": str(summary.get("n_folds", "?")),
            "Mean Acc": _percent(summary, "mean_test_acc"),
            "Std": _percent(summary, "std_test_acc", prefix="±"),
            "Min": _percent(summary, "min_test_acc"),
            "Max": _percent(summary, "max_test_acc"),
        }
        
        # Predictions are only needed for models that make it into the table.
        predictions = load_fold_predictions(exp_path)
        if predictions:
            report = compute_aggregate_metrics(predictions)["classification_report"]
            row["Macro F1"] = f"{report['macro avg']['f1-score']:.3f}"
        else:
            row["Macro F1"] = "N/A"
        
        rows.append(row)
    
    if not rows:
        print("No summary data available.")
        return
    
    # Print as formatted table
    print("\n" + "=" * 85)
    print(" MODEL COMPARISON SUMMARY")
    print("=" * 85)
    
    print(f"{'Model':<15} {'Folds':>6} {'Mean Acc':>10} {'Std':>10} {'Min':>10} {'Max':>10} {'Macro F1':>10}")
    print("-" * 85)
    
    for row in rows:
        print(f"{row['Model']:<15} {row['Folds']:>6} {row['Mean Acc']:>10} {row['Std']:>10} {row['Min']:>10} {row['Max']:>10} {row['Macro F1']:>10}")
    
    print("=" * 85)


In [None]:
# Print the model summary table. No guard is needed here: the function itself
# prints a notice and returns when it is given no experiments.
plot_aggregate_summary_table(valid_experiments)


In [None]:
def plot_overlaid_training_curves(experiments: Dict[str, Path], metric: str, display_name: str):
    """Overlay the fold-averaged curve of one metric for every model.

    Each model contributes a mean line and a ±1 standard deviation band on a
    single set of axes; models without stored histories are skipped.

    Args:
        experiments: Dict mapping model name to experiment path
        metric: History key to plot (e.g. 'val_acc')
        display_name: Label used for the y-axis and the title
    """
    if not experiments:
        print("No experiments found.")
        return

    fig, ax = plt.subplots(figsize=(10, 6))

    for model_name, exp_path in experiments.items():
        histories = load_fold_histories(exp_path)
        if histories:
            epochs, mean, std = aggregate_histories(histories, metric)
            lower, upper = mean - std, mean + std
            color = get_model_color(model_name)

            ax.plot(epochs, mean, color=color, linewidth=2.5, label=model_name)
            ax.fill_between(epochs, lower, upper, color=color, alpha=0.15)

    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(display_name, fontsize=12)
    ax.set_title(f'{display_name} Comparison (Mean ± Std across folds)', fontsize=14, fontweight='bold')
    ax.legend(loc='best', fontsize=11)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


In [None]:
# Overlay validation accuracy, then validation loss, across all models.
if not valid_experiments:
    print("No valid experiments found.")
else:
    for metric, display_name in [
        ("val_acc", "Validation Accuracy"),
        ("val_loss", "Validation Loss"),
    ]:
        plot_overlaid_training_curves(valid_experiments, metric, display_name)


---
## Next Steps

To add a new model to this comparison:

1. **Create a config file** (e.g., `configs/shallowconv.yaml`)

2. **Run training with explicit output directory:**
   ```bash
   python -m src.train --config configs/shallowconv.yaml --output-dir reports/loso/shallowconv_experiment
   ```

3. **Add to EXPERIMENTS dict** at the top of this notebook:
   ```python
   EXPERIMENTS = {
       "EEGNet": "../reports/loso/eegnet_experiment",
       "ATCNet": "../reports/loso/atcnet_experiment",
       "ShallowConvNet": "../reports/loso/shallowconv_experiment",  # New!
   }
   ```

4. **Re-run all cells** to see the new model in comparisons.
