# T-Maze Deep Learning Classification

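This notebook demonstrates deep learning approaches for T-maze decoding on simulated data:
- EEGNet for EEG classification
- LSTM for temporal patterns
- Transformer comparison and cross-validation
- Multimodal fusion (`CrossAttentionFusion` is imported but not yet demonstrated — TODO)
- Model interpretation
- fMRI ROI classification with an MLP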

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append('..')

# Report the installed PyTorch build and whether a CUDA device is visible.
# A missing install is reported rather than raised here.
try:
    import torch
except ImportError:
    print("PyTorch not installed. Run: pip install torch")
else:
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")

from deeplearning import (
    EEGNet,
    LSTMDecoder,
    EEGTransformer,
    ROIMLP,
    CrossAttentionFusion
)
from deeplearning.training import (
    DeepTrainer,
    EarlyStopping,
    cross_validate_deep
)
from deeplearning.interpretation import (
    GradCAM,
    IntegratedGradients,
    feature_importance_deep
)

plt.style.use('seaborn-v0_8-whitegrid')
%matplotlib inline

## 1. Simulate EEG Data

In [None]:
# Simulate EEG epochs for a two-class problem:
#   Class 0: baseline noise only.
#   Class 1: baseline noise plus a positive half-sine deflection on a few
#            frontocentral channels (a toy "reward positivity").
np.random.seed(42)

n_epochs = 500
n_channels = 64
sfreq = 256        # Hz
n_times = sfreq    # 1 s epochs -> 256 samples

X_eeg = np.random.randn(n_epochs, n_channels, n_times) * 10  # Baseline noise
y_eeg = np.random.randint(0, 2, n_epochs)

# Class-specific pattern parameters
rewp_window = slice(50, 80)   # samples 50-79, i.e. ~195-312 ms at 256 Hz
fcz_channels = [30, 31, 32]   # Frontocentral
rewp_amplitude = 15

# Derive the bump length from the window itself so that editing
# `rewp_window` can never produce a shape mismatch.
window_samples = np.arange(n_times)[rewp_window]
assert window_samples.size > 0, "rewp_window lies outside the epoch"
rewp_bump = np.sin(np.linspace(0, np.pi, window_samples.size)) * rewp_amplitude

# Add the bump to every (class-1 epoch, frontocentral channel) pair at once;
# the (window,) bump broadcasts over the epoch and channel axes.
class1_epochs = np.flatnonzero(y_eeg == 1)
X_eeg[np.ix_(class1_epochs, fcz_channels, window_samples)] += rewp_bump

# Split data (labels are i.i.d. random, so a sequential split is unbiased)
split = int(0.8 * n_epochs)
X_train, X_test = X_eeg[:split], X_eeg[split:]
y_train, y_test = y_eeg[:split], y_eeg[split:]

print(f"Training data: {X_train.shape}")
print(f"Test data: {X_test.shape}")
print(f"Class balance: {np.bincount(y_train)}")

## 2. EEGNet Classification

In [None]:
# Build EEGNet for binary classification of the simulated epochs.
eegnet_config = dict(
    n_classes=2,
    n_channels=n_channels,
    n_times=n_times,
    f1=8,
    d=2,
    dropout=0.5,
    learning_rate=1e-3,
)
eegnet = EEGNet(**eegnet_config)

# compile() is given the per-epoch (channels, times) shape.
eegnet.compile(input_shape=(n_channels, n_times))
print(f"EEGNet parameters: {eegnet.n_parameters():,}")

trainer = DeepTrainer(eegnet)
# NOTE(review): mode='max' presumably means the monitored quantity is
# maximized (e.g. validation accuracy), while the other cells use the default
# mode — confirm which metric EarlyStopping tracks.
early_stop = EarlyStopping(patience=15, mode='max')

# NOTE(review): X_test doubles as the validation set. If early stopping is
# driven by it, the accuracy printed below is not an untouched held-out estimate.
train_kwargs = dict(
    X_val=X_test,
    y_val=y_test,
    epochs=100,
    batch_size=32,
    early_stopping=early_stop,
    verbose=True,
)
result = trainer.train(X_train, y_train, **train_kwargs)

print(f"\nFinal accuracy: {result.accuracy:.3f}")

In [None]:
# Learning curves for the EEGNet run: loss (left) and accuracy (right).
# Validation curves are drawn only when the trainer recorded any.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
loss_ax, acc_ax = axes

panels = [
    (loss_ax, result.train_losses, result.val_losses, 'Loss', 'Training Loss'),
    (acc_ax, result.train_accuracies, result.val_accuracies, 'Accuracy', 'Training Accuracy'),
]

for ax, train_curve, val_curve, ylabel, title in panels:
    ax.plot(train_curve, label='Train')
    if val_curve:
        ax.plot(val_curve, label='Validation')
    if ax is acc_ax:
        # Binary task: 0.5 is chance level.
        ax.axhline(0.5, color='gray', linestyle='--', label='Chance')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()

plt.tight_layout()
plt.show()

## 3. LSTM Decoder

In [None]:
# Recurrent baseline: a 2-layer bidirectional LSTM decoder.
lstm_config = dict(
    n_classes=2,
    hidden_size=64,
    n_layers=2,
    bidirectional=True,
    dropout=0.5,
)
lstm = LSTMDecoder(**lstm_config)

lstm.compile(input_shape=(n_channels, n_times))
print(f"LSTM parameters: {lstm.n_parameters():,}")

trainer_lstm = DeepTrainer(lstm)
lstm_early_stop = EarlyStopping(patience=10)
# NOTE(review): validation again uses X_test (see the EEGNet cell).
result_lstm = trainer_lstm.train(
    X_train,
    y_train,
    X_val=X_test,
    y_val=y_test,
    epochs=50,
    batch_size=32,
    early_stopping=lstm_early_stop,
    verbose=True,
)

print(f"\nLSTM accuracy: {result_lstm.accuracy:.3f}")

## 4. Model Comparison

In [None]:
# Compare architectures under a shorter, identical training budget.
# Loop-local names use a `cmp_` prefix so they do not overwrite the Section 2
# `trainer` / `result` objects. The summary cell at the end of the notebook
# reports `result.accuracy` as the EEGNet accuracy.
models = {
    'EEGNet': (EEGNet, {'f1': 8, 'd': 2}),
    'LSTM': (LSTMDecoder, {'hidden_size': 64}),
    'Transformer': (EEGTransformer, {'d_model': 64, 'n_heads': 4})
}

comparison_results = {}

for name, (model_class, kwargs) in models.items():
    print(f"\n{'='*50}")
    print(f"Training {name}")
    print('='*50)

    # Best-effort: one architecture failing should not abort the comparison.
    try:
        cmp_model = model_class(n_classes=2, **kwargs)
        cmp_model.compile(input_shape=(n_channels, n_times))

        cmp_trainer = DeepTrainer(cmp_model)
        cmp_result = cmp_trainer.train(
            X_train, y_train,
            X_val=X_test, y_val=y_test,
            epochs=30,
            batch_size=32,
            early_stopping=EarlyStopping(patience=10),
            verbose=False
        )

        n_params = cmp_model.n_parameters()
        comparison_results[name] = {
            'accuracy': cmp_result.accuracy,
            'n_params': n_params
        }
        print(f"{name}: accuracy={cmp_result.accuracy:.3f}, params={n_params:,}")
    except Exception as e:
        print(f"{name} failed: {e}")

# Plot comparison
if comparison_results:
    fig, ax = plt.subplots(figsize=(10, 5))
    names = list(comparison_results.keys())
    accuracies = [comparison_results[n]['accuracy'] for n in names]

    ax.bar(names, accuracies, color='steelblue', edgecolor='black')
    ax.axhline(0.5, color='red', linestyle='--', label='Chance')
    ax.set_ylabel('Accuracy')
    ax.set_title('Model Comparison')
    # Keep below-0.4 bars visible and leave headroom above 1.0 for the value
    # labels, which are drawn at acc + 0.02.
    ax.set_ylim([min(0.4, min(accuracies) - 0.05), 1.05])
    ax.legend()  # display the 'Chance' reference-line label

    for i, acc in enumerate(accuracies):
        ax.text(i, acc + 0.02, f'{acc:.1%}', ha='center')

    plt.tight_layout()
    plt.show()

## 5. Model Interpretation

In [None]:
# Gradient-based feature importance (method='gradient') for the Section 2
# EEGNet, computed on the first 50 test epochs.
# Best-effort: any failure is reported rather than stopping the notebook.
try:
    importances = feature_importance_deep(
        eegnet.model,
        X_test[:50],
        y_test[:50],
        method='gradient',
        n_samples=50
    )
    importances = np.asarray(importances)
    # The plots below average over axis 0 for the time course and axis 1 for
    # the channel profile, so they need (n_channels, n_times). Raise a clear
    # error (caught below) instead of silently plotting the wrong axes.
    if importances.shape != (n_channels, n_times):
        raise ValueError(
            f"expected importances of shape {(n_channels, n_times)}, "
            f"got {importances.shape}"
        )

    # The simulated epochs span 1 s, so sample i lies at i / sfreq seconds.
    epoch_duration_s = 1.0
    sfreq = n_times / epoch_duration_s
    times = np.arange(n_times) / sfreq
    # Shade exactly the samples where the simulated deflection was injected.
    rewp_start_s = rewp_window.start / sfreq
    rewp_stop_s = rewp_window.stop / sfreq

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Temporal importance (averaged across channels)
    temporal_imp = np.mean(importances, axis=0)
    axes[0].plot(times, temporal_imp)
    axes[0].axvspan(rewp_start_s, rewp_stop_s, alpha=0.2, color='yellow', label='REWP window')
    axes[0].set_xlabel('Time (s)')
    axes[0].set_ylabel('Importance')
    axes[0].set_title('Temporal Feature Importance')
    axes[0].legend()

    # Channel importance (averaged across time)
    channel_imp = np.mean(importances, axis=1)
    axes[1].bar(range(len(channel_imp)), channel_imp)
    axes[1].set_xlabel('Channel')
    axes[1].set_ylabel('Importance')
    axes[1].set_title('Channel Feature Importance')

    plt.tight_layout()
    plt.show()

except Exception as e:
    print(f"Interpretation failed: {e}")

## 6. Cross-Validation

In [None]:
# 5-fold cross-validation of EEGNet over all simulated epochs (train + test).
# The trailing keyword arguments are model kwargs, presumably forwarded to the
# EEGNet constructor. NOTE(review): unlike Section 2, n_channels/n_times are
# not passed here — confirm cross_validate_deep infers them from X_eeg.
cv_result = cross_validate_deep(
    EEGNet,
    X_eeg, y_eeg,
    n_folds=5,
    epochs=30,
    batch_size=32,
    early_stopping=True,
    verbose=False,
    # Model kwargs
    n_classes=2,
    f1=8,
    d=2
)

print(f"\n{'='*50}")
print("CROSS-VALIDATION RESULTS")
print('='*50)
# Plus-minus sign (was mis-encoded as "Â±").
print(f"Mean accuracy: {cv_result.accuracy:.3f} ± {cv_result.accuracy_std:.3f}")
print(f"Fold accuracies: {cv_result.cv_accuracies}")

## 7. fMRI Classification with MLP

In [None]:
# Simulated fMRI ROI betas: standard-normal noise for every ROI, plus a
# +0.5 shift in a handful of "reward" ROIs on class-1 trials.
np.random.seed(42)

n_trials = 400
n_rois = 426  # HCP atlas

X_fmri = np.random.randn(n_trials, n_rois)
y_fmri = np.random.randint(0, 2, n_trials)

reward_rois = [100, 101, 102, 200, 201]  # Simulated VS, vmPFC
class1_trials = np.flatnonzero(y_fmri == 1)
# One vectorized update over the (class-1 trial, reward ROI) grid.
X_fmri[np.ix_(class1_trials, reward_rois)] += 0.5

# 80/20 sequential split
split = int(0.8 * n_trials)
X_train_fmri = X_fmri[:split]
X_test_fmri = X_fmri[split:]
y_train_fmri = y_fmri[:split]
y_test_fmri = y_fmri[split:]

# Feed-forward classifier on the per-trial ROI feature vectors.
mlp = ROIMLP(n_classes=2, hidden_sizes=[256, 128, 64], dropout=0.5)

mlp.compile(input_shape=(n_rois,))
print(f"MLP parameters: {mlp.n_parameters():,}")

trainer_mlp = DeepTrainer(mlp)
mlp_fit_kwargs = dict(
    X_val=X_test_fmri,
    y_val=y_test_fmri,
    epochs=50,
    batch_size=32,
    early_stopping=EarlyStopping(patience=10),
    verbose=True,
)
result_mlp = trainer_mlp.train(X_train_fmri, y_train_fmri, **mlp_fit_kwargs)

print(f"\nfMRI MLP accuracy: {result_mlp.accuracy:.3f}")

## Summary

In [None]:
# Fixed-width summary table of accuracy and parameter count per model.
rule = "=" * 60
print("\n" + rule)
print("DEEP LEARNING SUMMARY")
print(rule)
print(f"{'Model':<20} {'Accuracy':<12} {'Parameters':<15}")
print("-" * 60)

summary_rows = [
    ('EEGNet', result, eegnet),
    ('LSTM', result_lstm, lstm),
    ('fMRI MLP', result_mlp, mlp),
]
for label, res, mdl in summary_rows:
    print(f"{label:<20} {res.accuracy:<12.3f} {mdl.n_parameters():<15,}")

print(rule)