# 03 - Multimodal EEG-fMRI Fusion

Combine EEG and fMRI features to improve Reward vs. NoReward trial classification.

**Key Analysis:**
1. Extract reward positivity (REWP) features from EEG
2. Extract ROI values from fMRI
3. Early fusion (feature concatenation)
4. Late fusion (probability averaging)
5. Compare unimodal vs. multimodal

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import matplotlib.pyplot as plt

from core.containers import TMazeEEGData, TMAzefMRIData, TMazeSubject
from classification.multimodal import (
    early_fusion,
    late_fusion,
    multimodal_fusion,
    compare_fusion_methods
)
from visualization.plots import plot_multimodal_comparison

## 1. Create Simulated Multimodal Data

In [None]:
# Simulate aligned EEG and fMRI data for one demo subject. Both modalities
# share the same label vector, so row i of X_eeg and row i of X_fmri refer to
# the same trial.
np.random.seed(42)

n_trials = 100
n_channels = 64
n_times = 200
n_rois = 426
sfreq = 200  # EEG sampling rate (Hz); also stored on the EEG container below
n_baseline = 40  # samples before stimulus onset (0.2 s at 200 Hz)

# Labels (shared across modalities); condition_names below lists
# 'NoReward', 'Reward' in that order.
y = np.random.randint(0, 2, n_trials)
reward_trials = y == 1

# EEG time axis with samples exactly 1 / sfreq apart, so it agrees with the
# sfreq passed to the container (np.linspace(-0.2, 0.8, n_times) would space
# them 1 / 199 s apart). Integer numerators make 0.24 s and 0.34 s exactly
# representable, so both edges of the window below are reliably included.
times = np.arange(-n_baseline, n_times - n_baseline) / sfreq
X_eeg = np.random.randn(n_trials, n_channels, n_times) * 5
rewp_mask = (times >= 0.24) & (times <= 0.34)
# Reward trials get +3 on the first 10 channels inside the REWP window.
# np.ix_ turns the per-axis selectors into an open mesh, so this addresses
# the (reward trials x 10 channels x window samples) sub-block in one step.
X_eeg[np.ix_(reward_trials, np.arange(10), rewp_mask)] += 3

# fMRI data with a reward signal in a few ROIs (same +0.8 on each reward trial)
X_fmri = np.random.randn(n_trials, n_rois)
reward_rois = [10, 25, 50]
X_fmri[np.ix_(reward_trials, reward_rois)] += 0.8

# Create containers
channels = [f'Ch{i}' for i in range(n_channels)]
channels[:5] = ['FCz', 'Fz', 'Cz', 'FC1', 'FC2']

eeg_data = TMazeEEGData(
    data=X_eeg,
    times=times,
    labels=y,
    condition_names=['NoReward', 'Reward'],
    channels=channels,
    sfreq=sfreq,
    subject_id='demo'
)

roi_names = [f'ROI_{i:03d}' for i in range(n_rois)]
fmri_data = TMAzefMRIData(
    data=X_fmri,
    labels=y,
    condition_names=['NoReward', 'Reward'],
    roi_names=roi_names,
    subject_id='demo'
)

# Create subject container
subject = TMazeSubject(
    subject_id='demo',
    eeg_data=eeg_data,
    fmri_data=fmri_data
)

print(subject)
print(f"Multimodal: {subject.is_multimodal}")

## 2. Extract Features

In [None]:
# EEG: collapse the REWP window to a single scalar per trial by averaging
# over axes 1 and 2 of the window data (presumably channels x time points —
# TODO confirm against TMazeEEGData.get_rewp_window), then keep it as a
# one-column feature matrix.
rewp_data, rewp_times = eeg_data.get_rewp_window()
eeg_features = rewp_data.mean(axis=(1, 2))[:, None]  # one row per trial, 1 column
print(f"EEG features: {eeg_features.shape}")

# fMRI: every ROI value is used directly as a feature
fmri_features = fmri_data.data
print(f"fMRI features: {fmri_features.shape}")

## 3. Early Fusion

In [None]:
# Early fusion: concatenate the EEG and fMRI feature matrices and
# cross-validate a single SVM on the joint features. normalize_modalities=True
# asks early_fusion to normalize each modality before concatenation (exact
# scaling is defined in classification.multimodal).
early_options = dict(classifier_type='svm', cv=5, normalize_modalities=True)
early_result = early_fusion(eeg_features, fmri_features, y, **early_options)

early_report = [
    "Early Fusion Results:",
    f"  EEG only:  {early_result.eeg_only_accuracy:.1%}",
    f"  fMRI only: {early_result.fmri_only_accuracy:.1%}",
    f"  Fused:     {early_result.accuracy:.1%} (+/- {early_result.accuracy_std:.1%})",
    f"  Improvement: {early_result.fusion_improvement:+.1%}",
]
print("\n".join(early_report))

## 4. Late Fusion

In [None]:
# Late fusion: one classifier per modality, with their predictions combined
# by weighted averaging (probability averaging, per the notebook overview).
late_options = {
    'classifier_type': 'svm',
    'cv': 5,
    'weights': (0.5, 0.5),  # Equal weights for EEG and fMRI
}
late_result = late_fusion(eeg_features, fmri_features, y, **late_options)

late_report = [
    "Late Fusion Results:",
    f"  EEG only:  {late_result.eeg_only_accuracy:.1%}",
    f"  fMRI only: {late_result.fmri_only_accuracy:.1%}",
    f"  Fused:     {late_result.accuracy:.1%} (+/- {late_result.accuracy_std:.1%})",
    f"  Improvement: {late_result.fusion_improvement:+.1%}",
]
print("\n".join(late_report))

## 5. Using the TMazeSubject Container

In [None]:
# High-level API: pass the whole TMazeSubject and let multimodal_fusion
# derive the EEG features itself according to eeg_feature_type.
container_settings = {
    'fusion_type': 'early',
    'eeg_feature_type': 'rewp_mean',  # Options: rewp_mean, all_times, peak
    'classifier_type': 'svm',
    'cv': 5,
}
result = multimodal_fusion(subject, **container_settings)

for report_line in (
    "Multimodal Fusion (Subject Container):",
    f"  Fused accuracy: {result.accuracy:.1%}",
):
    print(report_line)

## 6. Compare Methods

In [None]:
# Compare different fusion approaches across classifier types.
# `comparison` maps a method label to a result object; this cell reads its
# `.accuracy` and `.fusion_improvement` attributes.
comparison = compare_fusion_methods(
    subject,
    classifier_types=['lda', 'svm'],
    cv=5,
)

print("Fusion Method Comparison:")
print("=" * 50)
# Use a distinct loop name so the notebook-level `result` produced by the
# subject-container cell is not silently overwritten when this cell runs.
for method, method_result in comparison.items():
    print(f"{method:15s}: {method_result.accuracy:.1%} "
          f"(improvement: {method_result.fusion_improvement:+.1%})")

## 7. Visualization

In [None]:
# Plot the early-fusion result using the project's plotting helper.
early_title = 'Early Fusion: EEG + fMRI'
plot_multimodal_comparison(early_result, title=early_title)
plt.show()

In [None]:
# Side-by-side summary figure:
#   left  - unimodal baselines vs. early and late fusion
#   right - every entry returned by compare_fusion_methods
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
chance = 0.5

# Early vs Late (the unimodal baselines are taken from the early-fusion run)
methods = ['EEG\nOnly', 'fMRI\nOnly', 'Early\nFusion', 'Late\nFusion']
accuracies = [
    early_result.eeg_only_accuracy,
    early_result.fmri_only_accuracy,
    early_result.accuracy,
    late_result.accuracy,
]
colors = ['#3498db', '#e74c3c', '#2ecc71', '#9b59b6']

# Classifier comparison
clf_methods = list(comparison.keys())
clf_accs = [comparison[m].accuracy for m in clf_methods]

# Keep the y-axis floor at 0.4, but lower it when any accuracy falls near or
# below that; otherwise below-chance bars would be clipped off the plot.
lowest = min(list(accuracies) + list(clf_accs))
y_bottom = max(0.0, min(0.4, lowest - 0.05))

axes[0].bar(methods, accuracies, color=colors, edgecolor='black')
axes[0].set_title('Unimodal vs. Multimodal')

axes[1].bar(clf_methods, clf_accs, color='steelblue', edgecolor='black')
axes[1].set_title('Fusion Method Comparison')
axes[1].tick_params(axis='x', rotation=45)

# Shared styling; the chance line is labelled and a legend is drawn so the
# label is actually shown.
for ax in axes:
    ax.axhline(chance, color='gray', linestyle='--', label='Chance')
    ax.set_ylabel('Accuracy')
    ax.set_ylim([y_bottom, 1.0])
    ax.legend()

plt.tight_layout()
plt.show()

## Summary

This notebook demonstrated:
1. Extracting REWP features from EEG
2. Early fusion (feature concatenation)
3. Late fusion (probability averaging)
4. Comparing fusion approaches

### Key Findings Template:
- EEG-only accuracy: XX%
- fMRI-only accuracy: XX%
- Fused accuracy: XX%
- Improvement: +X%
- Best method: Early/Late fusion with LDA/SVM

### Next Steps:
- **04_rsa_analysis.ipynb**: Representational similarity analysis