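1. EEGNet — a compact CNN designed for EEG decoding (Lawhern et al., 2018):
    - learns temporal (frequency-selective) filters
    - learns spatial filters, similar to CSP
    - combines features efficiently with depthwise-separable convolutions
    - has only a few thousand parameters (exact count printed below), which reduces overfitting on small datasets
2. Training with early stopping
3. Comparison with the classical CSP+LDA and CSP+SVM pipelines
4. Visualization of the learned temporal and spatial filters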

In [None]:
import sys
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, cohen_kappa_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Add src to path
sys.path.insert(0, str(Path.cwd().parent / 'src'))

from models import EEGNet, EEGNetTrainer, create_eeg_dataloaders
from visualization import set_style, plot_training_history, plot_confusion_matrix, CLASS_NAMES

set_style()

# Seed NumPy and PyTorch so weight initialisation, batch shuffling and dropout
# masks are reproducible between runs.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    # Seed every visible GPU, not only the current one.
    torch.cuda.manual_seed_all(SEED)

# Prefer CUDA, then Apple-silicon MPS, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f"Using device: {DEVICE}")

# Paths are relative to the notebook's working directory (one level below the project root).
PROCESSED_DIR = Path('../data/processed')
RESULTS_DIR = Path('../results')
FIGURES_DIR = Path('../figures')
# parents=True: create intermediate directories too instead of raising FileNotFoundError.
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
FIGURES_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
# Load preprocessed data.
# The context manager closes the .npz file handle once the arrays are read into memory.
# NOTE: allow_pickle=True permits unpickling object arrays, which can execute code;
# only load files produced by this project's own preprocessing step.
with np.load(PROCESSED_DIR / 'preprocessed_data.npz', allow_pickle=True) as data:
    X_train_all = data['X_train']
    y_train_all = data['y_train']
    subjects_train = data['subjects_train']

    X_test_all = data['X_test']
    y_test_all = data['y_test']
    subjects_test = data['subjects_test']

# Trial arrays are indexed as (trials, channels, time points) throughout this notebook.
n_channels = X_train_all.shape[1]
n_times = X_train_all.shape[2]
n_classes = len(np.unique(y_train_all))

print(f"Training data: {X_train_all.shape}")
print(f"Test data: {X_test_all.shape}")
print(f"Channels: {n_channels}, Time points: {n_times}, Classes: {n_classes}")

Architecture

Input: (batch, channels, times) → processed as a 1-channel image (batch, 1, channels, times)
  │
  ├─► Conv2D(1→F1, kernel=(1, 64))      # Temporal filtering
  ├─► BatchNorm
  │
  ├─► DepthwiseConv2D(F1→F1*D, kernel=(channels, 1))  # Spatial filtering (like CSP)
  ├─► BatchNorm → ELU → AvgPool(1,4) → Dropout
  │
  ├─► SeparableConv2D(F1*D→F2, kernel=(1, 16))  # Feature combination
  ├─► BatchNorm → ELU → AvgPool(1,8) → Dropout
  │
  └─► Flatten → Dense(n_classes)

In [None]:
# Build a reference EEGNet with the default configuration and report its size
model_example = EEGNet(
    n_channels=n_channels,
    n_times=n_times,
    n_classes=n_classes,
    F1=8,         # number of temporal filters
    D=2,          # depth multiplier: spatial filters per temporal filter
    F2=16,        # number of pointwise filters
    dropout=0.5,
)
print(model_example)

all_params = list(model_example.parameters())
total_param_count = sum(p.numel() for p in all_params)
trainable_param_count = sum(p.numel() for p in all_params if p.requires_grad)
print(f"\nTotal parameters: {total_param_count:,}")
print(f"Trainable parameters: {trainable_param_count:,}")

In [None]:
# Smoke-test a forward pass on random input shaped (batch, channels, times).
# eval() disables dropout and makes BatchNorm use its running statistics, so the
# printed logits are deterministic and this demo does not update BatchNorm buffers.
model_example.eval()
x_test = torch.randn(4, n_channels, n_times)  # batch of 4
with torch.no_grad():
    out = model_example(x_test)
print(f"Input shape: {x_test.shape}")
print(f"Output shape: {out.shape}")
print(f"Output (logits): {out[0]}")

In [None]:
# Hyperparameters for within-subject training: one EEGNet per subject,
# mirroring the per-subject setup of the classical pipeline.
HPARAMS = {
    'F1': 8,                         # temporal filters
    'D': 2,                          # depth multiplier
    'F2': 16,                        # pointwise filters
    'dropout': 0.5,
    'learning_rate': 1e-3,
    'batch_size': 32,
    'epochs': 200,                   # maximum epochs passed to fit()
    'early_stopping_patience': 20,
    'val_split': 0.2,                # fraction of each subject's training trials held out for validation
}

print("Hyperparameters:")
print("\n".join(f"  {name}: {value}" for name, value in HPARAMS.items()))

In [None]:
def train_subject_eegnet(subj_id, X_train_all, y_train_all, subjects_train,
                         X_test_all, y_test_all, subjects_test, hparams, device='cpu'):
    """
    Train and evaluate a within-subject EEGNet.

    The subject's training trials are split (stratified) into train and
    validation parts. EEGNet is fit with early stopping on the validation
    part and then evaluated on the subject's held-out test trials.

    Parameters
    ----------
    subj_id : subject identifier, compared element-wise with subjects_train / subjects_test.
    X_train_all, y_train_all, subjects_train : training trials, labels and per-trial subject ids.
    X_test_all, y_test_all, subjects_test : test trials, labels and per-trial subject ids.
    hparams : dict with keys 'F1', 'D', 'F2', 'dropout', 'learning_rate', 'batch_size',
        'epochs', 'early_stopping_patience' and 'val_split'.
    device : device string forwarded to EEGNetTrainer.

    Returns
    -------
    dict with the subject id, train/test sizes, test accuracy and Cohen's kappa,
    true and predicted test labels, training history, the trained model, the best
    validation accuracy, the number of epochs trained and the per-channel
    normalisation statistics ('norm_mean', 'norm_std') used on the inputs.
    """
    # Select this subject's trials
    mask_train = subjects_train == subj_id
    mask_test = subjects_test == subj_id

    X_train = X_train_all[mask_train].astype(np.float32)
    y_train = y_train_all[mask_train].astype(np.int64)
    X_test = X_test_all[mask_test].astype(np.float32)
    y_test = y_test_all[mask_test].astype(np.int64)

    # Split BEFORE normalising so the validation trials used for early stopping do
    # not leak into the normalisation statistics. The split depends only on the
    # labels and the seed, so it selects the same trials as splitting afterwards.
    X_tr, X_val, y_tr, y_val = train_test_split(
        X_train, y_train,
        test_size=hparams['val_split'],
        stratify=y_train,
        random_state=SEED,
    )

    # Per-channel z-score (statistics over trials and time, i.e. axes 0 and 2),
    # estimated on the training split only and applied to every split.
    mean = X_tr.mean(axis=(0, 2), keepdims=True)
    std = X_tr.std(axis=(0, 2), keepdims=True) + 1e-8  # epsilon guards against flat channels
    X_tr = (X_tr - mean) / std
    X_val = (X_val - mean) / std
    X_test = (X_test - mean) / std

    # Create dataloaders
    train_loader, val_loader = create_eeg_dataloaders(
        X_tr, y_tr, X_val, y_val,
        batch_size=hparams['batch_size']
    )

    # Create model
    model = EEGNet(
        n_channels=X_train.shape[1],
        n_times=X_train.shape[2],
        n_classes=len(np.unique(y_train)),
        F1=hparams['F1'],
        D=hparams['D'],
        F2=hparams['F2'],
        dropout=hparams['dropout']
    )

    # Train with early stopping on the validation split
    trainer = EEGNetTrainer(
        model=model,
        device=device,
        learning_rate=hparams['learning_rate']
    )

    history = trainer.fit(
        train_loader, val_loader,
        epochs=hparams['epochs'],
        early_stopping_patience=hparams['early_stopping_patience'],
        verbose=False
    )

    # Evaluate on the held-out test set
    y_pred = trainer.predict(X_test)

    return {
        'subject': subj_id,
        'n_train': len(y_train),
        'n_test': len(y_test),
        'accuracy': accuracy_score(y_test, y_pred),
        'kappa': cohen_kappa_score(y_test, y_pred),
        'y_true': y_test,
        'y_pred': y_pred,
        'history': history,
        'model': model,
        'best_val_acc': max(history['val_acc']),
        'epochs_trained': len(history['train_loss']),
        # Needed to apply the trained model to new raw trials of this subject.
        'norm_mean': mean,
        'norm_std': std,
    }

In [None]:
# Train one EEGNet per subject
all_results = []

# Subject ids come from the data instead of a hard-coded 1..9: only subjects with
# both training and test trials are used. Ids are cast to plain int so they format
# and JSON-serialise like literal ints.
subject_ids = [int(s) for s in np.intersect1d(subjects_train, subjects_test)]

print("Training EEGNet for each subject...")
print("=" * 70)
print(f"{'Subject':<10} {'Train':<8} {'Test':<8} {'Epochs':<10} {'Val Acc':<12} {'Test Acc':<12} {'Kappa':<10}")
print("-" * 70)

for subj in subject_ids:
    results = train_subject_eegnet(
        subj, X_train_all, y_train_all, subjects_train,
        X_test_all, y_test_all, subjects_test,
        HPARAMS, device=DEVICE
    )
    all_results.append(results)

    print(f"{subj:<10} {results['n_train']:<8} {results['n_test']:<8} "
          f"{results['epochs_trained']:<10} {results['best_val_acc']:<12.2%} "
          f"{results['accuracy']:<12.2%} {results['kappa']:<10.3f}")

print("-" * 70)

# Averages across subjects (std is the population std, ddof=0)
avg_acc = np.mean([r['accuracy'] for r in all_results])
avg_kappa = np.mean([r['kappa'] for r in all_results])
std_acc = np.std([r['accuracy'] for r in all_results])

print(f"{'Mean':<10} {'':<8} {'':<8} {'':<10} {'':<12} {avg_acc:<12.2%} {avg_kappa:<10.3f}")
print(f"{'Std':<10} {'':<8} {'':<8} {'':<10} {'':<12} {std_acc:<12.2%}")
print("=" * 70)

In [None]:
# Plot training/validation accuracy curves, one panel per subject.
# The grid grows with the number of subjects (3 columns; 3x3 for nine subjects).
n_cols = 3
n_rows = max(1, (len(all_results) + n_cols - 1) // n_cols)
chance_level = 1.0 / n_classes  # accuracy of uniform guessing; assumes roughly balanced classes

fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 10 * n_rows / 3), squeeze=False)
axes = axes.flatten()

for idx, r in enumerate(all_results):
    ax = axes[idx]
    epochs = range(1, len(r['history']['train_acc']) + 1)

    ax.plot(epochs, r['history']['train_acc'], 'b-', label='Train', alpha=0.7)
    ax.plot(epochs, r['history']['val_acc'], 'r-', label='Val', alpha=0.7)
    ax.axhline(r['accuracy'], color='g', linestyle='--', label=f'Test: {r["accuracy"]:.1%}')
    ax.axhline(chance_level, color='gray', linestyle=':', alpha=0.5)

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.set_title(f'Subject {r["subject"]}')
    ax.set_ylim(0, 1)
    ax.legend(loc='lower right', fontsize=8)
    ax.grid(True, alpha=0.3)

# Hide panels left over when the subject count is not a multiple of n_cols
for ax in axes[len(all_results):]:
    ax.set_visible(False)

plt.tight_layout()
plt.savefig(FIGURES_DIR / 'eegnet_training_curves.png', dpi=150, bbox_inches='tight');

In [None]:
# Pool test labels and predictions over all subjects for one aggregate confusion matrix
y_true_all_test = np.concatenate([res['y_true'] for res in all_results])
y_pred_all_test = np.concatenate([res['y_pred'] for res in all_results])

fig, ax = plt.subplots(figsize=(8, 6))

# Row-normalised: each row shows how the trials of one true class were classified
cm = confusion_matrix(y_true_all_test, y_pred_all_test, normalize='true')
heatmap_style = dict(
    annot=True,
    fmt='.2%',
    cmap='Blues',
    xticklabels=CLASS_NAMES,
    yticklabels=CLASS_NAMES,
    square=True,
    cbar_kws={'label': 'Proportion'},
)
sns.heatmap(cm, ax=ax, **heatmap_style)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
# NOTE: avg_acc in the title is the mean of per-subject accuracies,
# not the pooled accuracy of this matrix.
ax.set_title(f'EEGNet Confusion Matrix (Acc: {avg_acc:.1%})')

plt.tight_layout()
plt.savefig(FIGURES_DIR / 'eegnet_confusion_matrix.png', dpi=150, bbox_inches='tight');

In [None]:
# Per-class precision / recall / F1 on the pooled test predictions
header = "Per-class performance (EEGNet):"
print(header)
print("=" * 50)
report_text = classification_report(
    y_true_all_test,
    y_pred_all_test,
    target_names=CLASS_NAMES,
    digits=3,
)
print(report_text)

In [None]:
# Compare with the classical CSP pipelines, if notebook 03 has saved its results
classical_results_path = RESULTS_DIR / 'classical_ml_results.json'

if not classical_results_path.exists():
    print("Run notebook 03 first to get classical results for comparison")
    df_comparison = None
else:
    with open(classical_results_path) as f:
        classical_results = json.load(f)

    lda_per_subject = classical_results['lda']['per_subject_accuracy']
    svm_per_subject = classical_results['svm']['per_subject_accuracy']

    # NOTE(review): classical scores are matched to all_results by position, which
    # assumes notebook 03 stored subjects in the same order — confirm.
    comparison_data = [
        {
            'Subject': f'S{res["subject"]}',
            'CSP+LDA': lda_per_subject[pos],
            'CSP+SVM': svm_per_subject[pos],
            'EEGNet': res['accuracy'],
        }
        for pos, res in enumerate(all_results)
    ]

    import pandas as pd
    df_comparison = pd.DataFrame(comparison_data)
    print("Per-subject comparison:")
    print(df_comparison.to_string(index=False))

    # Summary of mean accuracies
    print("\nMean accuracy:")
    print(f"  CSP+LDA: {classical_results['lda']['mean_accuracy']:.2%}")
    print(f"  CSP+SVM: {classical_results['svm']['mean_accuracy']:.2%}")
    print(f"  EEGNet:  {avg_acc:.2%}")

In [None]:
# Visual comparison of EEGNet with the classical CSP pipelines
if df_comparison is not None:
    chance_level = 1.0 / n_classes  # accuracy of uniform guessing over n_classes
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left: grouped bars per subject
    x = np.arange(len(df_comparison))
    width = 0.25

    axes[0].bar(x - width, df_comparison['CSP+LDA'], width, label='CSP+LDA', color='steelblue')
    axes[0].bar(x, df_comparison['CSP+SVM'], width, label='CSP+SVM', color='coral')
    axes[0].bar(x + width, df_comparison['EEGNet'], width, label='EEGNet', color='forestgreen')

    axes[0].axhline(chance_level, color='gray', linestyle=':', label='Chance')
    axes[0].set_xlabel('Subject')
    axes[0].set_ylabel('Accuracy')
    axes[0].set_title('Method Comparison by Subject')
    axes[0].set_xticks(x)
    axes[0].set_xticklabels(df_comparison['Subject'])
    axes[0].set_ylim(0, 1)
    axes[0].legend()

    # Right: mean accuracy per method with across-subject std as error bars
    methods = ['CSP+LDA', 'CSP+SVM', 'EEGNet']
    means = [
        classical_results['lda']['mean_accuracy'],
        classical_results['svm']['mean_accuracy'],
        avg_acc
    ]
    stds = [
        classical_results['lda']['std_accuracy'],
        classical_results['svm']['std_accuracy'],
        std_acc
    ]
    colors = ['steelblue', 'coral', 'forestgreen']

    bars = axes[1].bar(methods, means, yerr=stds, capsize=5, color=colors, edgecolor='black')
    axes[1].axhline(chance_level, color='gray', linestyle=':', label='Chance')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('Mean Accuracy Comparison')
    axes[1].set_ylim(0, 1)
    axes[1].legend()  # shows the 'Chance' line label, which was otherwise never displayed

    # Value labels placed just above each error bar
    for bar, mean, std in zip(bars, means, stds):
        axes[1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + std + 0.02,
                     f'{mean:.1%}', ha='center', va='bottom', fontsize=11)

    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'method_comparison.png', dpi=150, bbox_inches='tight');

In [None]:
# Inspect what EEGNet learned in its filters, using the subject with the highest
# test accuracy (the first such subject on ties)
test_accuracies = [res['accuracy'] for res in all_results]
best_subj_idx = int(np.argmax(test_accuracies))
best_result = all_results[best_subj_idx]
best_model = best_result['model']
best_subj = best_result['subject']

print(f"Analyzing model from Subject {best_subj} (accuracy: {best_result['accuracy']:.1%})")

In [None]:
# Plot the kernels learned by the first (temporal) convolution
conv1_weights = best_model.conv1.weight.detach().cpu().numpy()
print(f"Temporal filter shape: {conv1_weights.shape}")
print("  (F1 filters, 1 input channel, 1 height, kernel_length)")

fig, axes = plt.subplots(2, 4, figsize=(14, 5))
axes = axes.flatten()

n_shown = min(8, conv1_weights.shape[0])  # the grid has room for at most 8 filters
for filt_no, panel in enumerate(axes[:n_shown]):
    panel.plot(conv1_weights[filt_no, 0, 0, :], 'b-', linewidth=1.5)
    panel.set_title(f'Temporal Filter {filt_no + 1}')
    panel.set_xlabel('Sample')
    panel.axhline(0, color='gray', linestyle='--', alpha=0.5)
    panel.grid(True, alpha=0.3)

fig.suptitle('Learned Temporal Filters (Conv1)', y=1.02)
plt.tight_layout();

In [None]:
# Frequency response of each learned temporal kernel, treated as an FIR filter
from scipy import signal

# NOTE(review): assumes the preprocessed data is still sampled at 250 Hz — confirm
# against the preprocessing step.
sfreq = 250  # Hz

fig, axes = plt.subplots(2, 4, figsize=(14, 5))
axes = axes.flatten()

for filt_no in range(min(8, conv1_weights.shape[0])):
    panel = axes[filt_no]
    taps = conv1_weights[filt_no, 0, 0, :]

    freqs, response = signal.freqz(taps, worN=512, fs=sfreq)

    panel.plot(freqs, np.abs(response), 'b-', linewidth=1.5)
    panel.set_title(f'Filter {filt_no + 1}')
    panel.set_xlabel('Frequency (Hz)')
    panel.set_ylabel('Magnitude')
    panel.set_xlim(0, 50)
    # Shade the mu and beta bands for reference
    panel.axvspan(8, 12, alpha=0.2, color='blue', label='mu')
    panel.axvspan(13, 30, alpha=0.2, color='red', label='beta')
    panel.grid(True, alpha=0.3)

fig.suptitle('Frequency Response of Learned Temporal Filters', y=1.02)
plt.tight_layout();

print("\nNote: Filters should show sensitivity to mu (8-12 Hz) and beta (13-30 Hz) bands")

In [None]:
# Topographic maps of the depthwise (spatial) convolution weights, analogous to CSP patterns
import mne
from preprocessing import CHANNEL_NAMES

conv2_weights = best_model.conv2.weight.detach().cpu().numpy()
print(f"Spatial filter shape: {conv2_weights.shape}")
print("  (F1*D output filters, 1 per group, n_channels, 1)")

# Electrode positions for the topomaps from the standard 10-20 montage.
# Assumes CHANNEL_NAMES matches the channel axis of the data and uses 10-20 names — confirm.
info = mne.create_info(ch_names=CHANNEL_NAMES, sfreq=250, ch_types='eeg')
info.set_montage(mne.channels.make_standard_montage('standard_1020'))

# Show at most the first 8 spatial filters
fig, axes = plt.subplots(2, 4, figsize=(14, 6))
axes = axes.flatten()

for filt_no in range(min(8, conv2_weights.shape[0])):
    # Weight layout is (out_channels, in_channels/groups, height, width); for the
    # depthwise conv in_channels/groups == 1 and height spans the EEG channels.
    channel_weights = conv2_weights[filt_no, 0, :, 0]
    mne.viz.plot_topomap(channel_weights, info, axes=axes[filt_no], show=False)
    axes[filt_no].set_title(f'Spatial Filter {filt_no + 1}')

fig.suptitle('Learned Spatial Filters (Like CSP Patterns)', y=1.02)
plt.tight_layout();
plt.savefig(FIGURES_DIR / 'eegnet_spatial_filters.png', dpi=150, bbox_inches='tight');

In [None]:
# Save a JSON summary of the EEGNet results. The key names mirror those read from
# the classical results file (per_subject_accuracy, mean_accuracy, std_accuracy).
# Values are cast to plain int/float so json.dump never meets numpy scalar types
# (e.g. np.int64 subject ids or np.float32 scores), which it cannot serialise.
eegnet_results = {
    'method': 'EEGNet',
    'hyperparameters': HPARAMS,
    'subjects': [int(r['subject']) for r in all_results],
    'per_subject_accuracy': [float(r['accuracy']) for r in all_results],
    'per_subject_kappa': [float(r['kappa']) for r in all_results],
    'mean_accuracy': float(avg_acc),
    'std_accuracy': float(std_acc),
    'mean_kappa': float(avg_kappa),
    'epochs_trained': [int(r['epochs_trained']) for r in all_results]
}

with open(RESULTS_DIR / 'eegnet_results.json', 'w', encoding='utf-8') as f:
    json.dump(eegnet_results, f, indent=2)

print(f"Results saved to {RESULTS_DIR / 'eegnet_results.json'}")

In [None]:
# Save the pooled test labels and predictions together with each trial's subject id
trial_subjects = np.concatenate(
    [np.full(len(res['y_true']), res['subject']) for res in all_results]
)
predictions_file = RESULTS_DIR / 'eegnet_predictions.npz'
np.savez(
    predictions_file,
    y_true=y_true_all_test,
    y_pred=y_pred_all_test,
    subjects=trial_subjects,
)

print(f"Predictions saved to {predictions_file}")

In [None]:
# Checkpoint the best subject's model along with the settings needed to rebuild it.
# NOTE(review): the per-subject z-score statistics applied to the inputs during
# training are not part of this checkpoint; using the model on new raw data requires
# recomputing them from the same subject's training trials — consider storing them.
checkpoint = dict(
    model_state_dict=best_model.state_dict(),
    subject=best_subj,
    accuracy=all_results[best_subj_idx]['accuracy'],
    hyperparameters=HPARAMS,
    n_channels=n_channels,
    n_times=n_times,
    n_classes=n_classes,
)
checkpoint_file = RESULTS_DIR / 'best_eegnet_model.pt'
torch.save(checkpoint, checkpoint_file)

print(f"Best model saved to {checkpoint_file}")

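Approximate reference ranges (hand-written, not computed by this notebook — compare with the results printed above):

| Method | Mean Accuracy | Std | Kappa |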
|--------|--------------|-----|-------|
| CSP + LDA | ~65-70% | - | ~0.55 |
| CSP + SVM | ~70-75% | - | ~0.60 |
| EEGNet | ~70-78% | - | ~0.65 |
| Chance | 25% | - | 0.00 |

Potential Improvements

- Data augmentation (time shift, noise)
- Cross-subject pretraining + fine-tuning
- Attention mechanisms (transformers)
- Larger models with more data