In [1]:
# Cell 1: Imports & Configuration
import os
import sys
import glob
import copy
import h5py
import logging
import random
import warnings
from pathlib import Path
from datetime import datetime
from collections import defaultdict
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (accuracy_score, f1_score, precision_score, recall_score,
                              roc_auc_score, confusion_matrix, roc_curve)
from sklearn.utils.class_weight import compute_class_weight

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns

warnings.filterwarnings('ignore')

# Add pipeline directories to path
sys.path.insert(0, os.path.join(os.getcwd(), 'EEG_CODE'))
sys.path.insert(0, os.path.join(os.getcwd(), 'fMRI_CODE'))

In [None]:
class BridgeConfig:
    """Central configuration for the EEG-fMRI bridge fusion notebook.

    Holds checkpoint locations, model sizes, training hyper-parameters, and
    the on-disk layout of the EEG and fMRI inputs.

    Data roots can be overridden without editing the notebook:
      * ``EEG_DATA_PATH``  - root of the EEG tree (default ``E:\\Head_neck``)
      * ``FMRI_DATA_PATH`` - root of the fMRI tree (default ``E:\\Head_neck\\fMRI``)

    Side effect: instantiating the class creates ``output_dir``,
    ``checkpoint_dir`` and ``log_dir`` if they do not exist.
    """

    def __init__(self):
        # --- Pre-trained unimodal checkpoints ---
        self.eeg_checkpoint_dir = Path('./EEG_CODE/checkpoints')
        self.fmri_checkpoint_dir = Path('./fMRI_CODE/checkpoints_fmri')

        # --- Model sizes ---
        self.eeg_hidden_dim = 128
        self.fmri_hidden_dim = 64
        self.bridge_hidden_dim = 128
        self.num_classes = 2

        # --- Subjects present in both modalities (1..32 inclusive) ---
        self.overlap_subjects = list(range(1, 33))

        # --- Training hyper-parameters ---
        self.n_splits = 5
        self.batch_size = 8
        self.num_epochs = 50
        self.lr = 1e-4
        self.weight_decay = 1e-4
        self.patience = 10
        self.grad_clip = 1.0
        self.dropout = 0.3

        # --- EEG inputs (root overridable via EEG_DATA_PATH) ---
        self.eeg_base_path = Path(os.getenv('EEG_DATA_PATH', r'E:\Head_neck'))
        eeg_clean = self.eeg_base_path / 'EEG' / 'DATA' / 'PROC' / 'data_proc' / 'cleaned_data'
        self.eeg_path_pw = eeg_clean / 'TF_dir' / 'pwspctrm' / 'PWS' / 'feat' / 'New'
        self.eeg_path_erp = eeg_clean / 'TF_dir' / 'ERP' / 'New'
        self.eeg_path_conn = eeg_clean / 'conn_dir' / 'New'
        self.eeg_label_path = eeg_clean / 'TF_dir'

        # --- fMRI inputs (root overridable via FMRI_DATA_PATH, mirroring EEG) ---
        self.fmri_base_path = Path(os.getenv('FMRI_DATA_PATH', r'E:\Head_neck\fMRI'))
        self.fmri_data_dir = self.fmri_base_path
        self.fmri_label_path = self.fmri_base_path / 'DATA' / 'labels'
        self.fmri_activation_types = ['sensory', 'AN', 'LN', 'cognitive', 'DMN']
        self.fmri_connectivity_types = ['DMN']
        self.fmri_agg_method = 'both'

        # --- EEG file-name vocabulary: band key -> capitalised band name ---
        self.bands = {'alpha': 'Alpha', 'beta': 'Beta', 'theta': 'Theta'}
        self.eeg_segments = ['1_Hz', '2_Hz', '4_Hz', '6_Hz', '8_Hz', '10_Hz', '12_Hz',
                             '14_Hz', '16_Hz', '18_Hz', '20_Hz', '25_Hz', '30_Hz', '40_Hz']
        self.func_segments = ['open', 'close']

        # --- Outputs (created eagerly so later cells can write to them) ---
        self.output_dir = Path('./results_bridge')
        self.checkpoint_dir = Path('./checkpoints_bridge')
        self.log_dir = Path('./logs_bridge')
        for d in [self.output_dir, self.checkpoint_dir, self.log_dir]:
            d.mkdir(parents=True, exist_ok=True)

In [None]:
# Reproducibility: seed Python's, NumPy's and PyTorch's global RNGs with one value.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Ask cuDNN for deterministic algorithms and turn off its autotuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Instantiating the config also creates the output/checkpoint/log directories.
config = BridgeConfig()
log_file = config.log_dir / 'bridge_fusion.log'
# Log to the file above and to the notebook's stream output.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
)
# Module-level logger used by every later cell.
logger = logging.getLogger('bridge_fusion')
# Use the GPU when one is visible to PyTorch, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Timestamp of this run, formatted YYYYMMDD_HHMMSS.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
logger.info(f'Device: {device}')
logger.info(f'Bridge Config: overlap_subjects={len(config.overlap_subjects)}, '
            f'bridge_hidden_dim={config.bridge_hidden_dim}')
print('Cell 1 complete: Imports & configuration loaded')

2026-02-17 03:01:02,680 - bridge_fusion - INFO - Device: cpu
2026-02-17 03:01:02,680 - bridge_fusion - INFO - Bridge Config: overlap_subjects=32, bridge_hidden_dim=128


Cell 1 complete: Imports & configuration loaded


In [None]:
# Import shared EEG components
from crossmodal_v4_enhancements import (
    EnhancedTriModalFusionNetV4,
    LearnedFusionModule,
    EnhancedERPEncoder,
    EnhancedPowerEncoder,
    get_fusion_weights_from_model
)


# ---- EEG Model Wrapper (matches checkpoint structure) ----
class ImprovedTriModalFusionNet(nn.Module):
    """Thin adapter that hosts an ``EnhancedTriModalFusionNetV4`` under ``self.model``.

    The wrapped network lives in the ``model`` attribute so that parameter
    names carry the ``model.`` prefix expected by the EEG checkpoints.
    """

    def __init__(self, in_pw_dim, in_erp_dim, in_conn_dim,
                 fusion_dim=128, num_classes=2, dropout=0.3,
                 num_transformer_layers=2, num_heads=4):
        super().__init__()
        # Translate this wrapper's argument names into the backbone's.
        backbone_kwargs = dict(
            erp_channels=in_erp_dim,
            pw_channels=in_pw_dim,
            conn_features=in_conn_dim,
            hidden_dim=fusion_dim,
            num_classes=num_classes,
            dropout=dropout,
            num_transformer_layers=num_transformer_layers,
            num_heads=num_heads,
        )
        self.model = EnhancedTriModalFusionNetV4(**backbone_kwargs)
        self.fusion_weight_history = []

    def forward(self, erp, pw, conn, return_feats=False):
        """Run the backbone.

        With ``return_feats=False`` the backbone's plain output is returned.
        With ``return_feats=True`` a dict with keys ``logits``, ``gates``
        (the fusion weights) and ``fused_feats`` is returned.
        """
        if not return_feats:
            return self.model(erp, pw, conn)

        logits, fusion_weights, fused_feats = self.model(
            erp, pw, conn,
            return_fusion_weights=True,
            return_fused_feats=True,
        )
        return dict(logits=logits, gates=fusion_weights, fused_feats=fused_feats)

    def get_fusion_weights(self):
        """Delegate to ``get_fusion_weights_from_model`` on the wrapped backbone."""
        return get_fusion_weights_from_model(self.model)


# ---- fMRI Model Components (matching checkpoint structure) ----
class ActivationEncoder(nn.Module):
    """MLP encoder for fMRI activation vectors.

    Maps ``in_dim`` features to ``hidden_dim`` through one intermediate layer
    of width ``2 * hidden_dim``; every linear layer is followed by
    BatchNorm1d, ReLU and Dropout.
    """

    def __init__(self, in_dim: int, hidden_dim: int = 64, dropout: float = 0.3):
        super().__init__()
        wide_dim = hidden_dim * 2
        # Layers are positional (unnamed) so parameter keys stay ``encoder.<index>.*``.
        stages = [
            nn.Linear(in_dim, wide_dim),
            nn.BatchNorm1d(wide_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(wide_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        ]
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Encode ``x`` with the sequential stack and return the result."""
        encoded = self.encoder(x)
        return encoded


class ConnectivityEncoderFMRI(nn.Module):
    """MLP encoder for flattened fMRI connectivity vectors.

    Architecture: Linear(in_dim -> 2*hidden_dim) -> BatchNorm1d -> ReLU ->
    Dropout -> Linear(2*hidden_dim -> hidden_dim) -> BatchNorm1d -> ReLU ->
    Dropout.
    """

    def __init__(self, in_dim: int, hidden_dim: int = 64, dropout: float = 0.3):
        super().__init__()
        expanded = hidden_dim * 2
        # Positional layers keep the parameter keys as ``encoder.<index>.*``.
        first_stage = [nn.Linear(in_dim, expanded), nn.BatchNorm1d(expanded),
                       nn.ReLU(), nn.Dropout(dropout)]
        second_stage = [nn.Linear(expanded, hidden_dim), nn.BatchNorm1d(hidden_dim),
                        nn.ReLU(), nn.Dropout(dropout)]
        self.encoder = nn.Sequential(*first_stage, *second_stage)

    def forward(self, x):
        """Return the encoded representation of ``x``."""
        return self.encoder(x)


class fMRIFusionNet(nn.Module):
    """fMRI model fusing an activation branch and a connectivity branch.

    Each branch is encoded to ``hidden_dim``, scaled by a learned
    softmax-normalised scalar weight, concatenated, fused back to
    ``hidden_dim`` and passed to a two-layer head. Attribute names are kept
    as-is so the parameter keys match the fMRI checkpoints.
    """

    def __init__(self, activation_dim: int, connectivity_dim: int, hidden_dim: int = 64,
                 num_classes: int = 2, dropout: float = 0.4, task: str = 'classification'):
        super().__init__()
        self.task = task
        self.activation_encoder = ActivationEncoder(activation_dim, hidden_dim, dropout)
        self.connectivity_encoder = ConnectivityEncoderFMRI(connectivity_dim, hidden_dim, dropout)
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        # Raw (pre-softmax) branch weights, both initialised to 0.5.
        self.activation_weight = nn.Parameter(torch.ones(1) * 0.5)
        self.connectivity_weight = nn.Parameter(torch.ones(1) * 0.5)

        # The head has the same shape for both tasks; only the output width differs.
        out_features = num_classes if task == 'classification' else 1
        half_dim = hidden_dim // 2
        self.head = nn.Sequential(
            nn.Linear(hidden_dim, half_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(half_dim, out_features),
        )

    def _branch_weights(self):
        """Softmax over the two raw branch weights (activation first)."""
        raw = torch.stack([self.activation_weight, self.connectivity_weight])
        return F.softmax(raw, dim=0)

    def forward(self, activation, connectivity, return_features=False):
        """Return the head output, plus the fused features when requested.

        For ``task == 'regression'`` the trailing output dimension is squeezed.
        """
        w_act, w_conn = self._branch_weights()
        scaled_act = self.activation_encoder(activation) * w_act
        scaled_conn = self.connectivity_encoder(connectivity) * w_conn
        fused = self.fusion(torch.cat([scaled_act, scaled_conn], dim=1))

        output = self.head(fused)
        if self.task == 'regression':
            output = output.squeeze(-1)
        return (output, fused) if return_features else output

    def get_fusion_weights(self):
        """Return the normalised branch weights as plain Python floats."""
        with torch.no_grad():
            w_act, w_conn = self._branch_weights()
            return {'activation': w_act.item(), 'connectivity': w_conn.item()}


print('Cell 2 complete: Pipeline model architectures defined')

CrossModal V4 Enhancements Module Loaded

Available classes:
  - EnhancedTriModalFusionNetV4 (tri-modal: ERP + PW + Connectivity)
  - EnhancedSmartFusionNetV4 (bi-modal: ERP + PW, with cross-attention)
  - BiDirectionalCrossAttention (NEW: mutual cross-modal attention)
  - EnhancedERPEncoder (1D CNN + temporal transformers)
  - EnhancedPowerEncoder (multi-scale CNN + transformers)
  - LearnedFusionModule (learned fusion weights with temperature)

Utility functions:
  - get_fusion_weights_from_model(model)
  - count_parameters(model)

V4.1 Updates:
  - Bi-modal now has bi-directional cross-modal attention
  - Increased dropout (0.3 -> 0.4) for small datasets
  - Deeper classifier head matching tri-modal


V4-LITE Components for Small Datasets Added

New classes for improved trimodal performance:
  - EnhancedTriModalFusionNetV4Lite (lightweight, ~400K params)
  - BalancedTriModalDataset (handles sample count mismatch)
  - HybridFusionModule (early ERP-PW + late CONN fusion)
  - EnhancedC

In [None]:
# Cell 3: Data Loading - EEG Side
# Re-implements loading logic from CrossModal_script_01.ipynb for subjects 1-32

from scipy.io import loadmat
from tqdm import tqdm


def load_eeg_conn_features(conn_dir, subject_list, band_list, cond_list):
    """Read EEG connectivity vectors from per-subject .mat files.

    Files are looked up as ``conn_<BandName>_<cond>_sub<NN>.mat`` first and,
    when that yields nothing, as ``conn_<band_key>_<cond>_sub<NN>.mat``.
    ``band_list`` maps band key -> band name. The first non-private variable
    of each file is flattened to float32 with NaNs replaced by 0.

    Returns a dict keyed by ``(subject, band_key, cond, 0)``; the trailing 0
    is a placeholder label slot.
    """
    def _find(file_name):
        return sorted(glob.glob(str(conn_dir / file_name)))

    loaded = {}
    for subject in subject_list:
        sid = f'{subject:02d}'
        for short_band, display_band in band_list.items():
            for condition in cond_list:
                mat_paths = _find(f'conn_{display_band}_{condition}_sub{sid}.mat')
                if not mat_paths:
                    # Fall back to the lowercase band key in the file name.
                    mat_paths = _find(f'conn_{short_band}_{condition}_sub{sid}.mat')
                for mat_path in mat_paths:
                    try:
                        contents = loadmat(mat_path)
                        var_name = next((n for n in contents if not n.startswith('_')), None)
                        if var_name is None:
                            continue
                        flat = np.array(contents[var_name], dtype=np.float32).flatten()
                        loaded[(subject, short_band, condition, 0)] = np.nan_to_num(flat, nan=0.0)
                    except Exception as err:
                        logger.warning(f'Error loading {mat_path}: {err}')
    logger.info(f'Loaded {len(loaded)} EEG connectivity samples')
    return loaded


def load_eeg_pw_features(pw_dir, subject_list, band_list, freq_list):
    """Read EEG power-spectrum vectors from per-subject .mat files.

    Each file is looked up as ``powspctrm_<band>_<freq>_sub<NN>.mat``. The
    first non-private variable is flattened to float32 with NaNs replaced
    by 0.

    Returns a dict keyed by ``(subject, band, freq, 0)``; the trailing 0 is a
    placeholder label slot.
    """
    spectra = {}
    for subject in subject_list:
        sid = f'{subject:02d}'
        for band_name, freq_tag in ((b, fq) for b in band_list for fq in freq_list):
            file_glob = str(pw_dir / f'powspctrm_{band_name}_{freq_tag}_sub{sid}.mat')
            for mat_path in sorted(glob.glob(file_glob)):
                try:
                    contents = loadmat(mat_path)
                    var_name = next((n for n in contents if not n.startswith('_')), None)
                    if var_name is None:
                        continue
                    flat = np.array(contents[var_name], dtype=np.float32).flatten()
                    spectra[(subject, band_name, freq_tag, 0)] = np.nan_to_num(flat, nan=0.0)
                except Exception as err:
                    logger.warning(f'Error loading {mat_path}: {err}')
    logger.info(f'Loaded {len(spectra)} EEG power spectrum samples')
    return spectra


def _read_erp_hdf5(path):
    """Read one ERP array from an HDF5 file; return None if no suitable dataset exists."""
    with h5py.File(path, 'r') as hf:
        # Prefer the known container names, else fall back to the first top-level key.
        if 'erp_struct' in hf:
            erp_group = hf['erp_struct']
        elif 'erp' in hf:
            erp_group = hf['erp']
        else:
            erp_group = hf[list(hf.keys())[0]]

        if 'avg' in erp_group:
            return np.array(erp_group['avg'], dtype=np.float32)
        if 'trial' in erp_group:
            data = np.array(erp_group['trial'], dtype=np.float32)
            # A 3-D 'trial' array is averaged over its first axis.
            return np.mean(data, axis=0) if data.ndim == 3 else data
        # Last resort: first child that looks like an array with >= 2 dimensions.
        for name in erp_group.keys():
            candidate = erp_group[name]
            if hasattr(candidate, 'shape') and len(candidate.shape) >= 2:
                return np.array(candidate, dtype=np.float32)
    return None


def _read_erp_mat(path):
    """Read the first non-private variable of a scipy-readable .mat file (None if absent)."""
    mat = loadmat(path)
    for name in mat:
        if not name.startswith('_'):
            return np.array(mat[name], dtype=np.float32)
    return None


def load_eeg_erp_features(erp_dir, subject_list, band_list, freq_list):
    """Load EEG ERP features from .mat/.h5 files.

    Files matching ``ERP_sub<NN>_<band>_<freq>*.mat`` are read with h5py first.
    If that raises, ``scipy.io.loadmat`` is tried. When both readers fail, a
    single warning reports *both* errors (previously only the HDF5 error was
    logged, which hid the reason the fallback failed).

    Args:
        erp_dir: ``pathlib.Path`` of the directory holding the ERP files.
        subject_list: iterable of integer subject ids.
        band_list: iterable of band names used in the file names.
        freq_list: iterable of frequency tags used in the file names.

    Returns:
        dict mapping ``(subject, band, freq, 0)`` to a float32 array with NaNs
        replaced by 0. The trailing 0 is a placeholder label slot. If several
        files match one key, the last one in sorted order wins.
    """
    erp_features = {}
    for subj in subject_list:
        subj_str = f'{subj:02d}'
        for band in band_list:
            for freq in freq_list:
                pattern = erp_dir / f'ERP_sub{subj_str}_{band}_{freq}*.mat'
                for f in sorted(glob.glob(str(pattern))):
                    try:
                        data = _read_erp_hdf5(f)
                    except Exception as h5_err:
                        # Not readable as HDF5: try the scipy loadmat fallback.
                        try:
                            data = _read_erp_mat(f)
                        except Exception as mat_err:
                            logger.warning(
                                f'Error loading ERP {f}: {h5_err}; '
                                f'loadmat fallback also failed: {mat_err}'
                            )
                            continue
                    if data is None:
                        # Nothing usable in this file; skip it as before.
                        continue
                    erp_features[(subj, band, freq, 0)] = np.nan_to_num(data, nan=0.0)
    logger.info(f'Loaded {len(erp_features)} EEG ERP samples')
    return erp_features


def load_eeg_labels(label_dir, binary=True):
    """Load EEG clinical labels from ``<label_dir>/medical_score.csv``.

    Rows without a ``Postoperative evaluation`` value are dropped. Subject ids
    are taken from the ``Subject`` column; a textual ``sub`` prefix (e.g.
    ``sub07``) is stripped before conversion to ``int``.

    Args:
        label_dir: directory containing ``medical_score.csv``.
        binary: when True, map ``score <= 2`` to 0 and everything else to 1;
            when False, return the raw score unchanged.

    Returns:
        dict mapping integer subject id -> label.

    Raises:
        FileNotFoundError: if the CSV file does not exist.
    """
    csv_path = os.path.join(label_dir, 'medical_score.csv')
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f'Label file not found: {csv_path}')

    df = pd.read_csv(csv_path).dropna(subset=['Postoperative evaluation'])

    subjects = df['Subject']
    if pd.api.types.is_numeric_dtype(subjects):
        subject_ids = subjects.astype(int)
    else:
        # Works for both object and pandas string dtypes.
        subject_ids = subjects.astype(str).str.replace('sub', '', regex=False).astype(int)

    label_dict = {}
    for subj, score in zip(subject_ids, df['Postoperative evaluation']):
        # The binary switch is tested first: the old one-line conditional
        # returned 0 for low scores even when raw scores were requested.
        if binary:
            label_dict[int(subj)] = 0 if score <= 2 else 1
        else:
            label_dict[int(subj)] = score
    return label_dict


# Load EEG data (filtered to subjects 1-32)
logger.info('Loading EEG data for subjects 1-32...')

# Labels are read for every subject in the CSV; no overlap_subjects filter is applied here.
eeg_label_dict = load_eeg_labels(str(config.eeg_label_path))
logger.info(f'EEG labels: {len(eeg_label_dict)} subjects')

# ERP and power loaders take the band keys only; the connectivity loader
# below takes the full key -> name mapping.
band_keys = list(config.bands.keys())

eeg_erp_features = load_eeg_erp_features(
    config.eeg_path_erp, config.overlap_subjects, band_keys, config.eeg_segments
)
eeg_pw_features = load_eeg_pw_features(
    config.eeg_path_pw, config.overlap_subjects, band_keys, config.eeg_segments
)
# Connectivity is indexed by condition (config.func_segments), not by frequency tag.
eeg_conn_features = load_eeg_conn_features(
    config.eeg_path_conn, config.overlap_subjects, config.bands, config.func_segments
)

logger.info(f'EEG data loaded: ERP={len(eeg_erp_features)}, PW={len(eeg_pw_features)}, CONN={len(eeg_conn_features)}')
print('Cell 3 complete: EEG data loaded')

2026-02-17 03:01:02,784 - bridge_fusion - INFO - Loading EEG data for subjects 1-32...
2026-02-17 03:01:02,891 - bridge_fusion - INFO - EEG labels: 66 subjects


In [None]:
# Cell 4: Data Loading - fMRI Side

def load_fmri_activation_features(data_dir, subject_list, activation_types, agg_method='both'):
    """Load per-subject fMRI activation features from CSV files.

    For each subject and activation type, reads
    ``<data_dir>/sub-<id>/subject_<id>_activation_<type>.csv``, drops a
    ``Subject`` column if present, replaces NaNs with 0 and aggregates over
    rows. The per-type vectors are concatenated in ``activation_types`` order;
    missing files are skipped.

    Args:
        data_dir: ``pathlib.Path`` of the fMRI root directory.
        subject_list: iterable of subject ids.
        activation_types: iterable of activation type names.
        agg_method: ``'mean'``, ``'std'`` or ``'both'`` (mean followed by std).

    Returns:
        dict mapping subject id -> 1-D float32 ``torch.Tensor``.

    Raises:
        ValueError: if ``agg_method`` is not one of the supported values. This
            is checked up front; previously it was raised inside the per-file
            ``try`` and swallowed, so a bad argument silently produced no
            features.
    """
    aggregators = {
        'mean': lambda d: np.mean(d, axis=0),
        'std': lambda d: np.std(d, axis=0),
        'both': lambda d: np.concatenate([np.mean(d, axis=0), np.std(d, axis=0)]),
    }
    if agg_method not in aggregators:
        raise ValueError(f'Unknown agg method: {agg_method}')
    aggregate = aggregators[agg_method]

    features = {}
    for subj in tqdm(subject_list, desc='Loading fMRI activations'):
        subj_features = []
        subj_dir = data_dir / f'sub-{subj}'
        for act_type in activation_types:
            filepath = subj_dir / f'subject_{subj}_activation_{act_type}.csv'
            if not filepath.exists():
                continue
            try:
                df = pd.read_csv(filepath)
                if 'Subject' in df.columns:
                    df = df.drop('Subject', axis=1)
                data = df.values.astype(np.float32)
                data = np.nan_to_num(data, nan=0.0)
                subj_features.append(aggregate(data))
            except Exception as e:
                logger.warning(f'Error loading {filepath}: {e}')
        if subj_features:
            features[subj] = torch.tensor(np.concatenate(subj_features), dtype=torch.float32)

    logger.info(f'fMRI activation features: {len(features)} subjects')
    if features:
        sample = list(features.values())[0]
        logger.info(f'  Activation feature dim: {sample.shape[0]}')
        # Skipped files shorten a subject's vector, which breaks any
        # fixed-width model downstream; make that visible here.
        dims = sorted({int(t.shape[0]) for t in features.values()})
        if len(dims) > 1:
            logger.warning(f'  Inconsistent activation feature dims across subjects: {dims}')
    return features


def load_fmri_connectivity_features(data_dir, subject_list, connectivity_types):
    """Load per-subject fMRI connectivity vectors from CSV files.

    Reads ``<data_dir>/sub-<id>/subject_<id>_fdr_PPI_Connectivity_<type>.csv``
    for every requested type, drops a ``Subject`` column if present, flattens
    the table to float32 and replaces NaNs with 0. Per-type vectors are
    concatenated in ``connectivity_types`` order; missing files are skipped.

    Returns a dict mapping subject id -> 1-D float32 ``torch.Tensor``.
    """
    per_subject = {}
    for sid in tqdm(subject_list, desc='Loading fMRI connectivity'):
        folder = data_dir / f'sub-{sid}'
        chunks = []
        for kind in connectivity_types:
            csv_file = folder / f'subject_{sid}_fdr_PPI_Connectivity_{kind}.csv'
            if csv_file.exists():
                try:
                    table = pd.read_csv(csv_file)
                    if 'Subject' in table.columns:
                        table = table.drop(columns='Subject')
                    flat = table.values.astype(np.float32).flatten()
                    chunks.append(np.nan_to_num(flat, nan=0.0))
                except Exception as err:
                    logger.warning(f'Error loading {csv_file}: {err}')
        if chunks:
            per_subject[sid] = torch.tensor(np.concatenate(chunks), dtype=torch.float32)

    logger.info(f'fMRI connectivity features: {len(per_subject)} subjects')
    if per_subject:
        first_vector = next(iter(per_subject.values()))
        logger.info(f'  Connectivity feature dim: {first_vector.shape[0]}')
    return per_subject


def load_fmri_labels(label_path, subject_list):
    """Load binary fMRI labels for the requested subjects.

    The first existing file among ``labels.csv``, ``outcomes.csv``,
    ``subjects_labels.csv`` (inside ``label_path``) and ``labels.csv`` in the
    parent directory is used. Subject and label columns are auto-detected from
    a fixed list of common names.

    String labels map to 1 when, after stripping whitespace and lowercasing,
    they are one of ``good``/``positive``/``yes``/``1``; any other string maps
    to 0. Non-string labels are converted with ``int``. Rows with a missing
    subject or label are skipped with a warning instead of crashing on
    ``int(NaN)``.

    NOTE: when no label file is found, *random* dummy labels are returned
    (drawn from NumPy's global RNG) and a warning is logged. Results produced
    from those labels are meaningless.

    Args:
        label_path: ``pathlib.Path`` of the label directory.
        subject_list: iterable of integer subject ids to keep.

    Returns:
        dict mapping subject id -> integer label.

    Raises:
        ValueError: if the subject or label column cannot be identified.
    """
    candidates = [label_path / 'labels.csv', label_path / 'outcomes.csv',
                  label_path / 'subjects_labels.csv', label_path.parent / 'labels.csv']
    label_file = next((lf for lf in candidates if lf.exists()), None)
    if label_file is None:
        logger.warning('No fMRI label file found. Using dummy labels.')
        return {subj: np.random.randint(0, 2) for subj in subject_list}

    df = pd.read_csv(label_file)
    subj_col = next((c for c in ['Subject', 'subject', 'SubjectID', 'ID', 'id'] if c in df.columns), None)
    label_col = next((c for c in ['Label', 'label', 'Outcome', 'outcome', 'Class', 'class', 'Group', 'group'] if c in df.columns), None)
    if not subj_col or not label_col:
        raise ValueError(f'Cannot identify columns in {label_file}: {df.columns.tolist()}')

    # Drop incomplete rows, as the EEG label loader does for missing scores.
    complete = df.dropna(subset=[subj_col, label_col])
    n_dropped = len(df) - len(complete)
    if n_dropped:
        logger.warning(f'Skipping {n_dropped} fMRI label rows with a missing subject or label')

    wanted = set(subject_list)
    positive_tokens = {'good', 'positive', 'yes', '1'}
    class_labels = {}
    for raw_subj, label in zip(complete[subj_col], complete[label_col]):
        subj = int(raw_subj)
        if subj not in wanted:
            continue
        if isinstance(label, str):
            label = 1 if label.strip().lower() in positive_tokens else 0
        else:
            label = int(label)
        class_labels[subj] = label
    logger.info(f'fMRI labels: {len(class_labels)} subjects, classes={set(class_labels.values())}')
    return class_labels


# Load fMRI data
logger.info('Loading fMRI data...')

# Activation vectors: one concatenated vector per subject, aggregated per config.fmri_agg_method.
fmri_act_features = load_fmri_activation_features(
    config.fmri_data_dir, config.overlap_subjects,
    config.fmri_activation_types, config.fmri_agg_method
)
# Connectivity vectors: flattened tables for the types in config.fmri_connectivity_types.
fmri_conn_features = load_fmri_connectivity_features(
    config.fmri_data_dir, config.overlap_subjects,
    config.fmri_connectivity_types
)
# NOTE(review): load_fmri_labels returns random dummy labels when no label file exists.
fmri_label_dict = load_fmri_labels(config.fmri_label_path, config.overlap_subjects)

logger.info(f'fMRI data loaded: Act={len(fmri_act_features)}, Conn={len(fmri_conn_features)}')
print('Cell 4 complete: fMRI data loaded')

In [None]:
# Cell 5: Subject Alignment & Bridge Dataset

# Use EEG labels as ground truth (both pipelines share same subjects 1-32)
# Prefer EEG labels; fall back to fMRI if an EEG label is missing.
# Subjects with neither label are left out of bridge_labels.
# NOTE(review): the fMRI fallback may be a random dummy label if no fMRI label
# file was found when fmri_label_dict was built - confirm before trusting it.
bridge_labels = {}
for subj in config.overlap_subjects:
    if subj in eeg_label_dict:
        bridge_labels[subj] = eeg_label_dict[subj]
    elif subj in fmri_label_dict:
        bridge_labels[subj] = fmri_label_dict[subj]

logger.info(f'Bridge labels: {len(bridge_labels)} subjects')
# np.unique(..., return_counts=True) gives (values, counts); zip pairs them into {label: count}.
logger.info(f'Class distribution: {dict(zip(*np.unique(list(bridge_labels.values()), return_counts=True)))}')


class BridgeRawDataset(Dataset):
    """Dataset that holds aligned raw EEG + fMRI data per subject.

    EEG has multiple samples per subject (band x freq combos). For each
    subject we store the raw fMRI tensors and a list of EEG
    ``(erp, pw, conn)`` tuples. ``__getitem__`` returns all EEG samples for
    one subject so the feature extractor can aggregate them.

    Args:
        eeg_erp: dict keyed by ``(subject, band, freq, label_slot)``.
        eeg_pw: dict with the same keys as ``eeg_erp``.
        eeg_conn: dict keyed by ``(subject, band, cond, label_slot)``.
        fmri_act: dict mapping subject -> activation features.
        fmri_conn: dict mapping subject -> connectivity features.
        labels: dict mapping subject -> label.
        subject_list: subjects to consider; processed in sorted order.
        bands: accepted for interface compatibility; not used here.
        func_segments: connectivity conditions, tried in order; the first
            condition with an entry supplies the connectivity for a sample.

    A subject is kept only if it has at least one complete EEG sample,
    both fMRI feature vectors, and a label. An empty result is reported with
    a warning rather than raising.
    """

    def __init__(self, eeg_erp, eeg_pw, eeg_conn, fmri_act, fmri_conn,
                 labels, subject_list, bands, func_segments):
        self.samples = []

        # Build per-subject EEG sample lists
        eeg_by_subj = defaultdict(list)
        for key, erp_val in eeg_erp.items():
            subj, band, _freq, label_val = key[0], key[1], key[2], key[3]
            pw_val = eeg_pw.get(key)
            # Connectivity is keyed by condition rather than frequency.
            lookup_band = str(band).lower()
            conn_val = None
            for cond in func_segments:
                conn_key = (subj, lookup_band, cond, label_val)
                if conn_key in eeg_conn:
                    conn_val = eeg_conn[conn_key]
                    break
            if pw_val is not None and conn_val is not None:
                eeg_by_subj[subj].append((erp_val, pw_val, conn_val))

        # Align subjects: keep only those present in every source.
        for subj in sorted(subject_list):
            if (subj not in eeg_by_subj or subj not in fmri_act or
                    subj not in fmri_conn or subj not in labels):
                continue
            self.samples.append({
                'subject': subj,
                'label': labels[subj],
                'eeg_samples': eeg_by_subj[subj],
                'fmri_act': fmri_act[subj],
                'fmri_conn': fmri_conn[subj],
            })

        logger.info(f'BridgeRawDataset: {len(self.samples)} aligned subjects')
        if self.samples:
            counts = [len(s['eeg_samples']) for s in self.samples]
            logger.info(f'  EEG samples per subject: min={min(counts)}, max={max(counts)}')
        else:
            # min()/max() on an empty sequence would raise ValueError here.
            logger.warning('BridgeRawDataset is empty: no subject has EEG samples, '
                           'fMRI features and a label')

    def __len__(self):
        """Number of aligned subjects."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(eeg_samples, fmri_act, fmri_conn, label, subject)`` for one subject."""
        s = self.samples[idx]
        return s['eeg_samples'], s['fmri_act'], s['fmri_conn'], s['label'], s['subject']


# Assemble the per-subject dataset from everything loaded in the previous cells.
bridge_raw_dataset = BridgeRawDataset(
    eeg_erp_features, eeg_pw_features, eeg_conn_features,
    fmri_act_features, fmri_conn_features,
    bridge_labels, config.overlap_subjects,
    config.bands, config.func_segments
)

# len() is the number of subjects that have EEG samples, both fMRI vectors and a label.
print(f'\nAligned subjects: {len(bridge_raw_dataset)}')
print('Cell 5 complete: Subject alignment done')

In [None]:
# Cell 6: Load Pre-trained Models

def find_best_checkpoint(checkpoint_dir, pattern):
    """Return the last checkpoint path matching ``pattern``, or None.

    Matches are ordered with a natural sort, so embedded numbers compare
    numerically: ``fold10`` comes after ``fold2``, where a plain
    lexicographic sort would place it before. The last path in that order is
    returned, i.e. the highest-numbered fold for ``*_fold<N>.pt`` names.

    NOTE: "best" refers to the ``best_*`` file-name convention only; no
    validation metric is inspected to choose between folds.

    Args:
        checkpoint_dir: directory to search.
        pattern: glob pattern relative to ``checkpoint_dir``.

    Returns:
        The selected path as a string, or None (after a warning) if nothing matches.
    """
    import re

    def _natural_key(path):
        # re.split with a capturing group alternates text, digits, text, ...
        # so odd positions are always digit runs and compare as integers.
        parts = re.split(r'(\d+)', path)
        return [int(tok) if i % 2 else tok for i, tok in enumerate(parts)], path

    files = sorted(glob.glob(str(Path(checkpoint_dir) / pattern)), key=_natural_key)
    if not files:
        logger.warning(f'No checkpoint found matching {checkpoint_dir}/{pattern}')
        return None
    return files[-1]


def load_eeg_model(checkpoint_path, dataset_sample, fusion_dim=128):
    """Instantiate an EEG tri-modal model, load its checkpoint, and freeze it.

    Input sizes are inferred from the first EEG sample of ``dataset_sample``:
    the leading dimension of the ERP and power arrays, and the total element
    count of the connectivity array.

    Args:
        checkpoint_path: path to a checkpoint, or a falsy value to keep the
            randomly initialised weights (a warning is logged).
        dataset_sample: one item of ``BridgeRawDataset``; element 0 is the
            list of ``(erp, pw, conn)`` tuples.
        fusion_dim: hidden size passed to the model.

    Returns:
        The model on ``device``, in eval mode, with gradients disabled.
    """
    eeg_samples = dataset_sample[0]  # list of (erp, pw, conn)
    sample_erp, sample_pw, sample_conn = eeg_samples[0]

    # Cast to plain ints: np.prod returns a NumPy scalar, not a Python int.
    in_erp_dim = int(sample_erp.shape[0])
    in_pw_dim = int(sample_pw.shape[0])
    in_conn_dim = int(np.prod(sample_conn.shape))

    logger.info(f'EEG model dims: ERP={in_erp_dim}, PW={in_pw_dim}, CONN={in_conn_dim}')

    model = ImprovedTriModalFusionNet(
        in_pw_dim=in_pw_dim,
        in_erp_dim=in_erp_dim,
        in_conn_dim=in_conn_dim,
        fusion_dim=fusion_dim,
        num_classes=config.num_classes
    )

    if checkpoint_path:
        # NOTE: weights_only=False unpickles arbitrary Python objects; only
        # load checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        # Accept the same checkpoint layouts as load_fmri_model.
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
        logger.info(f'EEG model loaded from {checkpoint_path}')
    else:
        logger.warning('No EEG checkpoint found, using random weights')

    model.to(device)
    model.eval()
    model.requires_grad_(False)
    return model


def load_fmri_model(checkpoint_path, fmri_act_dim, fmri_conn_dim, hidden_dim=64):
    """Build an ``fMRIFusionNet``, optionally restore its weights, and freeze it.

    The checkpoint may be a dict holding the weights under
    ``model_state_dict`` or ``state_dict``, or the state dict itself. With a
    falsy ``checkpoint_path`` the random initialisation is kept and a warning
    is logged. The returned model is on ``device``, in eval mode, with
    gradients disabled.
    """
    net = fMRIFusionNet(
        activation_dim=fmri_act_dim,
        connectivity_dim=fmri_conn_dim,
        hidden_dim=hidden_dim,
        num_classes=config.num_classes,
    )

    if not checkpoint_path:
        logger.warning('No fMRI checkpoint found, using random weights')
    else:
        payload = torch.load(checkpoint_path, map_location=device, weights_only=False)
        # Work out where the weights live inside the loaded object.
        if 'model_state_dict' in payload:
            weights = payload['model_state_dict']
        elif isinstance(payload, dict) and 'state_dict' in payload:
            weights = payload['state_dict']
        else:
            weights = payload
        net.load_state_dict(weights)
        logger.info(f'fMRI model loaded from {checkpoint_path}')

    # Freeze: inference mode, no gradients.
    net.to(device)
    net.eval()
    net.requires_grad_(False)
    return net


# Find checkpoints (None when nothing matches; the loaders then keep random weights)
eeg_ckpt = find_best_checkpoint(config.eeg_checkpoint_dir, 'best_trimodal_fold*.pt')
fmri_ckpt = find_best_checkpoint(config.fmri_checkpoint_dir, 'best_fusion_fold*.pt')

# Determine fMRI input dimensions from the first loaded subject.
# NOTE(review): indexing [0] raises IndexError if no fMRI features were loaded.
sample_fmri_act = list(fmri_act_features.values())[0]
sample_fmri_conn = list(fmri_conn_features.values())[0]
fmri_act_dim = sample_fmri_act.shape[0]
fmri_conn_dim = sample_fmri_conn.shape[0]
logger.info(f'fMRI dims: activation={fmri_act_dim}, connectivity={fmri_conn_dim}')

# Load models; the EEG input sizes are inferred from the first aligned subject.
# NOTE(review): bridge_raw_dataset[0] raises IndexError if no subject was aligned.
sample_data = bridge_raw_dataset[0]
eeg_model = load_eeg_model(eeg_ckpt, sample_data, fusion_dim=config.eeg_hidden_dim)
fmri_model = load_fmri_model(fmri_ckpt, fmri_act_dim, fmri_conn_dim, hidden_dim=config.fmri_hidden_dim)

# Parameter counts are informational; both loaders already disabled gradients.
n_eeg_params = sum(p.numel() for p in eeg_model.parameters())
n_fmri_params = sum(p.numel() for p in fmri_model.parameters())
logger.info(f'EEG model params: {n_eeg_params:,} (all frozen)')
logger.info(f'fMRI model params: {n_fmri_params:,} (all frozen)')

print('Cell 6 complete: Pre-trained models loaded and frozen')

In [None]:
# Cell 7: Feature Extraction Functions

@torch.no_grad()
def extract_eeg_features(model, raw_dataset, device):
    """Extract one fused feature vector per subject from the frozen EEG model.

    EEG has multiple samples per subject (band x freq). Each sample is run
    through ``model(..., return_feats=True)`` and the resulting
    ``'fused_feats'`` are mean-pooled across the subject's samples.

    Samples on which the model raises are skipped, but no longer silently:
    one warning per affected subject reports how many samples failed and the
    first error. A subject whose samples all fail is left out of the result.

    Args:
        model: module called as ``model(erp=..., pw=..., conn=..., return_feats=True)``
            and returning a mapping with a ``'fused_feats'`` tensor.
        raw_dataset: indexable of ``(eeg_samples, fmri_act, fmri_conn, label, subject)``.
        device: torch device the inputs are moved to.

    Returns:
        dict mapping subject -> 1-D CPU tensor (mean of the fused features).
    """
    model.eval()
    features = {}

    for idx in range(len(raw_dataset)):
        eeg_samples, _, _, _, subj = raw_dataset[idx]
        feat_list = []
        n_failed = 0
        first_error = None

        for erp_np, pw_np, conn_np in eeg_samples:
            # Add a batch dimension of 1 to every input.
            erp_t = torch.tensor(erp_np, dtype=torch.float32).unsqueeze(0).to(device)
            pw_t = torch.tensor(pw_np, dtype=torch.float32).unsqueeze(0).to(device)
            conn_t = torch.tensor(conn_np, dtype=torch.float32).unsqueeze(0).to(device)

            # Flatten conn if needed
            if conn_t.dim() > 2:
                conn_t = conn_t.view(conn_t.size(0), -1)

            try:
                out = model(erp=erp_t, pw=pw_t, conn=conn_t, return_feats=True)
                feat_list.append(out['fused_feats'].cpu())
            except Exception as e:
                # Typically a dimension mismatch; remember it so it can be reported.
                n_failed += 1
                if first_error is None:
                    first_error = e

        if n_failed:
            logger.warning(
                f'EEG feature extraction failed for {n_failed}/{len(eeg_samples)} '
                f'samples of subject {subj}; first error: {first_error}'
            )

        if feat_list:
            # Mean pool across all band x freq samples
            stacked = torch.cat(feat_list, dim=0)
            features[subj] = stacked.mean(dim=0)

    logger.info(f'Extracted EEG features for {len(features)} subjects, dim={list(features.values())[0].shape[0] if features else "N/A"}')
    return features


@torch.no_grad()
def extract_fmri_features(model, fmri_act, fmri_conn, subject_list, device):
    """Run the frozen fMRI model once per subject and collect its fused features.

    Subjects missing from either input dict are skipped. A subject whose
    forward pass raises is skipped with a warning. Each stored vector is the
    second element returned by ``model(..., return_features=True)`` with the
    batch dimension removed, moved to the CPU.
    """
    model.eval()
    fused_by_subject = {}

    available = (sid for sid in subject_list if sid in fmri_act and sid in fmri_conn)
    for sid in available:
        # Single-item batches.
        act_batch = fmri_act[sid].unsqueeze(0).to(device)
        conn_batch = fmri_conn[sid].unsqueeze(0).to(device)

        try:
            _, fused = model(act_batch, conn_batch, return_features=True)
            fused_by_subject[sid] = fused.squeeze(0).cpu()
        except Exception as err:
            logger.warning(f'fMRI feature extraction failed for subject {sid}: {err}')

    if fused_by_subject:
        dim_repr = next(iter(fused_by_subject.values())).shape[0]
    else:
        dim_repr = "N/A"
    logger.info(f'Extracted fMRI features for {len(fused_by_subject)} subjects, dim={dim_repr}')
    return fused_by_subject


# Extract features
# Run both frozen encoders once; the resulting dicts map subject id -> feature tensor.
logger.info('Extracting EEG fused features...')
eeg_fused_features = extract_eeg_features(eeg_model, bridge_raw_dataset, device)

logger.info('Extracting fMRI fused features...')
fmri_fused_features = extract_fmri_features(
    fmri_model, fmri_act_features, fmri_conn_features, config.overlap_subjects, device
)

# Verify alignment
# Keep only subjects that have EEG features, fMRI features AND a label; the
# sorted list fixes the sample order used by the dataset built in Cell 9.
common_subjects = sorted(set(eeg_fused_features.keys()) & set(fmri_fused_features.keys()) & set(bridge_labels.keys()))
logger.info(f'Common subjects with both EEG and fMRI features: {len(common_subjects)}')

# Log the feature shapes of the first common subject as a quick sanity check.
if common_subjects:
    s = common_subjects[0]
    logger.info(f'  Sample subject {s}: EEG={eeg_fused_features[s].shape}, fMRI={fmri_fused_features[s].shape}')

print('Cell 7 complete: Features extracted')

In [None]:
# Cell 8: Bridge Fusion Model

class EEGfMRIBridgeFusionNet(nn.Module):
    """Cross-modality bridge fusion: EEG (128-d) + fMRI (64-d).

    Projects both modalities to a shared space of size ``bridge_dim``, lets the
    projected EEG token attend over the (EEG, fMRI) token pair, combines the
    attended EEG vector with the projected fMRI vector through
    ``LearnedFusionModule`` (constructed with ``use_temperature=True``), and
    classifies the fused vector.

    Args:
        eeg_dim: Size of the incoming EEG feature vector.
        fmri_dim: Size of the incoming fMRI feature vector.
        bridge_dim: Size of the shared projection space.
        num_classes: Number of output logits.
        num_heads: Number of heads of the multi-head attention layer.
        dropout: Dropout probability used in projections, attention and classifier.
    """
    def __init__(self, eeg_dim=128, fmri_dim=64, bridge_dim=128,
                 num_classes=2, num_heads=4, dropout=0.3):
        super().__init__()
        self.bridge_dim = bridge_dim

        # Project each modality to the shared space (Linear -> LayerNorm -> GELU -> Dropout)
        self.eeg_proj = nn.Sequential(
            nn.Linear(eeg_dim, bridge_dim),
            nn.LayerNorm(bridge_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        self.fmri_proj = nn.Sequential(
            nn.Linear(fmri_dim, bridge_dim),
            nn.LayerNorm(bridge_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

        # Cross-modal attention. In forward() only the EEG token is used as the
        # query; keys/values are the two-token sequence [EEG, fMRI]. The fMRI
        # token itself is not passed through attention.
        self.cross_attn = nn.MultiheadAttention(
            bridge_dim, num_heads=num_heads, dropout=dropout, batch_first=True
        )

        # Learned fusion with temperature scaling
        # (LearnedFusionModule is defined outside this section of the notebook.)
        self.fusion = LearnedFusionModule(
            num_modalities=2,
            hidden_dim=bridge_dim,
            use_temperature=True
        )

        # Classifier
        # NOTE: BatchNorm1d needs more than one sample per batch in training mode.
        self.classifier = nn.Sequential(
            nn.Linear(bridge_dim, bridge_dim // 2),
            nn.BatchNorm1d(bridge_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(bridge_dim // 2, num_classes)
        )

    def forward(self, eeg_feats, fmri_feats, return_features=False, return_weights=False):
        """Fuse one batch of EEG and fMRI features and classify it.

        Args:
            eeg_feats: (batch, eeg_dim)
            fmri_feats: (batch, fmri_dim)
            return_features: If True, the fused vector is appended to the output.
            return_weights: If True, the fusion weights returned by the fusion
                module and the attention weights returned by the attention
                layer are appended to the output (in that order).
        Returns:
            ``logits`` alone when both flags are False; otherwise a tuple
            ``(logits[, fused][, fusion_weights, attn_weights])``.
        """
        # Project to shared space
        eeg_proj = self.eeg_proj(eeg_feats)    # (batch, bridge_dim)
        fmri_proj = self.fmri_proj(fmri_feats)  # (batch, bridge_dim)

        # Cross-modal attention: stack as sequence of 2 tokens
        modality_seq = torch.stack([eeg_proj, fmri_proj], dim=1)  # (batch, 2, bridge_dim)

        # EEG attends to both modalities
        eeg_q = eeg_proj.unsqueeze(1)  # (batch, 1, bridge_dim)
        attn_out, attn_weights_raw = self.cross_attn(
            eeg_q, modality_seq, modality_seq
        )
        eeg_enhanced = attn_out.squeeze(1)  # (batch, bridge_dim)

        # Learned fusion of the attended EEG vector and the projected fMRI vector
        if return_weights:
            fused, fusion_weights = self.fusion(
                [eeg_enhanced, fmri_proj], return_weights=True
            )
        else:
            fused = self.fusion([eeg_enhanced, fmri_proj])
            fusion_weights = None

        # Classify
        logits = self.classifier(fused)

        # Assemble the variable-length output described in the docstring
        results = [logits]
        if return_features:
            results.append(fused)
        if return_weights:
            results.append(fusion_weights)
            results.append(attn_weights_raw)

        return results[0] if len(results) == 1 else tuple(results)

    def get_fusion_weights(self):
        """Return the current global fusion weights as plain Python floats.

        Computes ``softmax(fusion_logits / temperature)`` from the fusion
        module's parameters and reports index 0 as the EEG weight and index 1
        as the fMRI weight, together with the temperature.
        NOTE(review): relies on ``LearnedFusionModule`` exposing
        ``fusion_logits`` and a scalar ``temperature`` -- defined elsewhere,
        confirm against that class.
        """
        with torch.no_grad():
            logits = self.fusion.fusion_logits
            temp = self.fusion.temperature
            weights = F.softmax(logits / temp, dim=0)
            return {
                'eeg_weight': weights[0].item(),
                'fmri_weight': weights[1].item(),
                'temperature': temp.item()
            }


# Quick architecture test
# Build a throw-away instance with the configured dimensions to report its size.
_test_bridge = EEGfMRIBridgeFusionNet(
    eeg_dim=config.eeg_hidden_dim,
    fmri_dim=config.fmri_hidden_dim,
    bridge_dim=config.bridge_hidden_dim,
    num_classes=config.num_classes,
    dropout=config.dropout
)
n_bridge_params = sum(p.numel() for p in _test_bridge.parameters())
n_trainable = sum(p.numel() for p in _test_bridge.parameters() if p.requires_grad)
print(f'Bridge model: {n_bridge_params:,} total params, {n_trainable:,} trainable')

# Smoke test
# Forward a random batch of 4 on CPU and print the output shapes. eval() is not
# called, so the module runs in its default training mode here.
_eeg_dummy = torch.randn(4, config.eeg_hidden_dim)
_fmri_dummy = torch.randn(4, config.fmri_hidden_dim)
_logits, _fused, _fw, _aw = _test_bridge(_eeg_dummy, _fmri_dummy, return_features=True, return_weights=True)
print(f'Smoke test: logits={_logits.shape}, fused={_fused.shape}, fusion_weights={_fw.shape}, attn_weights={_aw.shape}')
# Free the temporary model and inputs (the small output tensors stay bound).
del _test_bridge, _eeg_dummy, _fmri_dummy

print('Cell 8 complete: Bridge fusion model defined')

In [None]:
# Cell 9: Bridge Dataset from Pre-extracted Features

class BridgeFeatureDataset(Dataset):
    """Subject-aligned pairs of pre-extracted EEG and fMRI feature vectors.

    Only subjects present in all three mappings (EEG features, fMRI features,
    labels) are kept; samples are stored in sorted subject order.
    """
    def __init__(self, eeg_features, fmri_features, labels, subject_list):
        self.samples = [
            {
                'eeg': eeg_features[subj],
                'fmri': fmri_features[subj],
                'label': labels[subj],
                'subject': subj
            }
            for subj in sorted(subject_list)
            if subj in eeg_features and subj in fmri_features and subj in labels
        ]
        logger.info(f'BridgeFeatureDataset: {len(self.samples)} samples')

    def __len__(self):
        """Number of aligned subjects."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(eeg, fmri, label, subject)`` for the sample at ``idx``."""
        entry = self.samples[idx]
        return entry['eeg'], entry['fmri'], entry['label'], entry['subject']


def collate_bridge(batch):
    """Collate ``(eeg, fmri, label, subject)`` items into batched tensors.

    Returns:
        Tuple ``(eeg, fmri, labels, subjects)``: stacked EEG tensor, stacked
        fMRI tensor, ``torch.long`` label tensor and the list of subject ids.
    """
    eeg_rows, fmri_rows, label_values, subjects = [], [], [], []
    for item in batch:
        eeg_rows.append(item[0])
        fmri_rows.append(item[1])
        label_values.append(item[2])
        subjects.append(item[3])

    eeg = torch.stack(eeg_rows)
    fmri = torch.stack(fmri_rows)
    labels = torch.tensor(label_values, dtype=torch.long)
    return eeg, fmri, labels, subjects


# Build the aligned dataset from the features extracted in Cell 7.
bridge_dataset = BridgeFeatureDataset(
    eeg_fused_features, fmri_fused_features, bridge_labels, common_subjects
)

# Labels in dataset order; used for stratified splitting and class weights in Cell 10.
all_labels = np.array([s['label'] for s in bridge_dataset.samples])
print(f'Bridge dataset: {len(bridge_dataset)} samples')
print(f'Class distribution: {dict(zip(*np.unique(all_labels, return_counts=True)))}')
print('Cell 9 complete: Bridge feature dataset created')

In [None]:
# Cell 10: Training Loop

def train_bridge_epoch(model, loader, optimizer, criterion, device, grad_clip=1.0):
    """Train ``model`` for one pass over ``loader``.

    Each batch is unpacked as ``(eeg, fmri, labels, _)``. Gradients are
    norm-clipped to ``grad_clip`` when it is positive.

    Returns:
        Sum of the per-batch losses divided by ``max(len(loader), 1)``.
    """
    model.train()
    batch_losses = []
    for eeg, fmri, labels, _ in loader:
        eeg = eeg.to(device)
        fmri = fmri.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(eeg, fmri), labels)
        loss.backward()
        if grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        batch_losses.append(loss.item())

    n_batches = max(len(loader), 1)
    return sum(batch_losses, 0.0) / n_batches


def evaluate_bridge(model, loader, device):
    """Evaluate ``model`` on ``loader`` and compute classification metrics.

    Returns:
        Tuple ``(metrics, targets, probs, subjects)`` where ``metrics`` holds
        Accuracy, weighted F1/Precision/Recall and AUC (0.5 when the AUC
        computation raises), ``targets`` and ``probs`` are numpy arrays and
        ``subjects`` is the list of subject ids in evaluation order.
    """
    model.eval()
    pred_rows, target_rows, prob_rows = [], [], []
    subject_ids = []
    with torch.no_grad():
        for eeg, fmri, labels, subjects in loader:
            logits = model(eeg.to(device), fmri.to(device))
            probs = F.softmax(logits, dim=1)
            preds = logits.argmax(dim=1)
            pred_rows.extend(preds.cpu().numpy())
            target_rows.extend(labels.numpy())
            prob_rows.extend(probs.cpu().numpy())
            subject_ids.extend(subjects)

    y_pred = np.array(pred_rows)
    y_true = np.array(target_rows)
    y_prob = np.array(prob_rows)

    metrics = {'Accuracy': accuracy_score(y_true, y_pred)}
    weighted_scorers = (('F1', f1_score), ('Precision', precision_score), ('Recall', recall_score))
    for metric_name, scorer in weighted_scorers:
        metrics[metric_name] = scorer(y_true, y_pred, average='weighted', zero_division=0)

    # AUC uses the probability of class index 1; fall back to chance level on any error
    try:
        metrics['AUC'] = roc_auc_score(y_true, y_prob[:, 1])
    except Exception:
        metrics['AUC'] = 0.5
    return metrics, y_true, y_prob, subject_ids


# K-Fold Cross-Validation
# Stratified folds over the subject-level dataset; one fresh bridge model per fold.
cv = StratifiedKFold(n_splits=config.n_splits, shuffle=True, random_state=SEED)
fold_results = []
fold_fusion_weights = []
fold_roc_data = []
all_fold_fused_features = {}  # subject -> fused features (for visualization)

logger.info(f'Starting {config.n_splits}-fold cross-validation')

for fold_idx, (train_idx, test_idx) in enumerate(cv.split(np.zeros(len(bridge_dataset)), all_labels), 1):
    print(f'\n{"="*60}')
    print(f'FOLD {fold_idx}/{config.n_splits}')
    print(f'{"="*60}')

    train_subset = Subset(bridge_dataset, train_idx)
    test_subset = Subset(bridge_dataset, test_idx)

    # The bridge classifier contains BatchNorm1d, which raises in training mode
    # when a batch holds a single sample. Drop the trailing batch only when it
    # would be such a singleton (e.g. 25 samples with batch_size 8).
    drop_singleton_batch = (len(train_idx) > config.batch_size
                            and len(train_idx) % config.batch_size == 1)
    train_loader = DataLoader(train_subset, batch_size=config.batch_size,
                              shuffle=True, collate_fn=collate_bridge,
                              drop_last=drop_singleton_batch)
    test_loader = DataLoader(test_subset, batch_size=config.batch_size,
                             shuffle=False, collate_fn=collate_bridge)

    # Class weights (balanced, computed on the training labels of this fold)
    train_labels = all_labels[train_idx]
    cw = compute_class_weight('balanced', classes=np.unique(train_labels), y=train_labels)
    cw_tensor = torch.tensor(cw, dtype=torch.float32).to(device)
    criterion = nn.CrossEntropyLoss(weight=cw_tensor)

    # Create bridge model
    bridge_model = EEGfMRIBridgeFusionNet(
        eeg_dim=config.eeg_hidden_dim,
        fmri_dim=config.fmri_hidden_dim,
        bridge_dim=config.bridge_hidden_dim,
        num_classes=config.num_classes,
        dropout=config.dropout
    ).to(device)

    optimizer = optim.AdamW(bridge_model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    best_f1 = -1.0
    best_state = None
    patience_counter = 0

    # NOTE(review): the LR schedule, early stopping and best-epoch selection all
    # use metrics computed on this fold's test split, so the reported fold
    # metrics are optimistically biased. A separate validation split carved out
    # of the training indices would remove that bias.
    for epoch in range(1, config.num_epochs + 1):
        train_loss = train_bridge_epoch(bridge_model, train_loader, optimizer, criterion, device, config.grad_clip)
        test_metrics, _, _, _ = evaluate_bridge(bridge_model, test_loader, device)
        # Scheduler is in 'min' mode, so feed it 1 - F1
        scheduler.step(1.0 - test_metrics['F1'])

        if test_metrics['F1'] > best_f1:
            best_f1 = test_metrics['F1']
            best_state = copy.deepcopy(bridge_model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1

        if epoch % 10 == 0:
            logger.info(f'  Fold {fold_idx} Epoch {epoch}: loss={train_loss:.4f}, '
                        f'F1={test_metrics["F1"]:.4f}, Acc={test_metrics["Accuracy"]:.4f}')

        if patience_counter >= config.patience:
            logger.info(f'  Early stopping at epoch {epoch}')
            break

    # Load best model (explicit None check instead of relying on dict truthiness)
    if best_state is not None:
        bridge_model.load_state_dict(best_state)

    # Final evaluation
    test_metrics, targets, probs, subjects = evaluate_bridge(bridge_model, test_loader, device)
    fold_results.append(test_metrics)

    # Collect ROC data (true labels and probability of class index 1)
    fold_roc_data.append((targets, probs[:, 1]))

    # Collect fusion weights
    fw = bridge_model.get_fusion_weights()
    fold_fusion_weights.append(fw)

    # Extract fused features for visualization (held-out subjects of this fold)
    bridge_model.eval()
    with torch.no_grad():
        for eeg, fmri, labels, subjs in test_loader:
            eeg, fmri = eeg.to(device), fmri.to(device)
            _, fused, _, _ = bridge_model(eeg, fmri, return_features=True, return_weights=True)
            for s, feat in zip(subjs, fused.cpu()):
                all_fold_fused_features[s] = feat

    # Save checkpoint (make sure the target directory exists first)
    ckpt_path = config.checkpoint_dir / f'best_bridge_fold{fold_idx}.pt'
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save({
        'model_state_dict': bridge_model.state_dict(),
        'metrics': test_metrics,
        'fold': fold_idx,
        'fusion_weights': fw
    }, ckpt_path)

    logger.info(f'Fold {fold_idx} Results: Acc={test_metrics["Accuracy"]:.4f}, '
                f'F1={test_metrics["F1"]:.4f}, AUC={test_metrics["AUC"]:.4f}')
    logger.info(f'  Fusion weights: EEG={fw["eeg_weight"]:.3f}, fMRI={fw["fmri_weight"]:.3f}, '
                f'T={fw["temperature"]:.3f}')

# Aggregate results (mean +/- std across folds)
print(f'\n{"="*60}')
print('BRIDGE FUSION RESULTS SUMMARY')
print(f'{"="*60}')
for metric in ['Accuracy', 'F1', 'Precision', 'Recall', 'AUC']:
    values = [r[metric] for r in fold_results]
    print(f'  {metric:12s}: {np.mean(values):.4f} +/- {np.std(values):.4f}')

eeg_w = [fw['eeg_weight'] for fw in fold_fusion_weights]
fmri_w = [fw['fmri_weight'] for fw in fold_fusion_weights]
print(f'\n  EEG weight:  {np.mean(eeg_w):.4f} +/- {np.std(eeg_w):.4f}')
print(f'  fMRI weight: {np.mean(fmri_w):.4f} +/- {np.std(fmri_w):.4f}')

print('\nCell 10 complete: Training finished')

In [None]:
# Cell 11: Results Visualization

# All figures of this and the following cells are written below <output_dir>/figures.
fig_dir = config.output_dir / 'figures'
fig_dir.mkdir(parents=True, exist_ok=True)

# --- 1. Performance Summary Table ---
# One row per metric with mean/std across the folds collected in Cell 10.
summary_rows = []
for metric in ['Accuracy', 'F1', 'Precision', 'Recall', 'AUC']:
    values = [r[metric] for r in fold_results]
    summary_rows.append({
        'Metric': metric,
        'Mean': np.mean(values),
        'Std': np.std(values),
        'Summary': f'{np.mean(values):.4f} +/- {np.std(values):.4f}'
    })
summary_df = pd.DataFrame(summary_rows)
summary_df.to_csv(config.output_dir / f'bridge_summary_{timestamp}.csv', index=False)
print('Performance Summary:')
print(summary_df.to_string(index=False))

# --- 2. ROC Curves ---
# One curve per fold; fold_roc_data holds (true labels, class-1 probabilities).
fig, ax = plt.subplots(figsize=(8, 6))
for fold_idx, (targets, scores) in enumerate(fold_roc_data, 1):
    fpr, tpr, _ = roc_curve(targets, scores)
    # AUC is reported as 0.5 when the fold contains a single class only
    auc_val = roc_auc_score(targets, scores) if len(np.unique(targets)) > 1 else 0.5
    ax.plot(fpr, tpr, label=f'Fold {fold_idx} (AUC={auc_val:.3f})')
ax.plot([0, 1], [0, 1], 'k--', alpha=0.5)  # chance diagonal
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Bridge Fusion ROC Curves')
ax.legend(loc='lower right')
ax.grid(alpha=0.3)
plt.tight_layout()
plt.savefig(fig_dir / f'roc_curves_{timestamp}.png', dpi=300, bbox_inches='tight')
plt.show()

# --- 3. Confusion Matrices ---
# One heatmap per fold, thresholding the class-1 probability at 0.5.
fig, axes = plt.subplots(1, config.n_splits, figsize=(4 * config.n_splits, 4))
if config.n_splits == 1:
    axes = [axes]
for fold_idx, (targets, scores) in enumerate(fold_roc_data):
    preds = (scores >= 0.5).astype(int)
    # Pin the label set so the matrix is always 2x2 and matches the fixed tick
    # labels, even when a fold's targets and predictions contain one class only.
    cm = confusion_matrix(targets, preds, labels=[0, 1])
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[fold_idx],
                xticklabels=['Class 0', 'Class 1'], yticklabels=['Class 0', 'Class 1'])
    axes[fold_idx].set_title(f'Fold {fold_idx + 1}')
    axes[fold_idx].set_ylabel('True')
    axes[fold_idx].set_xlabel('Predicted')
plt.suptitle('Confusion Matrices', fontweight='bold')
plt.tight_layout()
plt.savefig(fig_dir / f'confusion_matrices_{timestamp}.png', dpi=300, bbox_inches='tight')
plt.show()

# --- 4. Fusion Weights Across Folds ---
# Left panel: final EEG/fMRI fusion weight of each fold; right panel: mean +/- std.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
folds = list(range(1, config.n_splits + 1))
eeg_weights = [fw['eeg_weight'] for fw in fold_fusion_weights]
fmri_weights = [fw['fmri_weight'] for fw in fold_fusion_weights]

axes[0].plot(folds, eeg_weights, 'o-', label='EEG', color='#2ecc71', linewidth=2, markersize=8)
axes[0].plot(folds, fmri_weights, 's-', label='fMRI', color='#e74c3c', linewidth=2, markersize=8)
axes[0].set_xlabel('Fold')
axes[0].set_ylabel('Weight')
axes[0].set_title('Fusion Weights Across Folds')
axes[0].legend()
axes[0].set_ylim(0, 1)
axes[0].grid(alpha=0.3)

bars = axes[1].bar(['EEG', 'fMRI'],
                    [np.mean(eeg_weights), np.mean(fmri_weights)],
                    yerr=[np.std(eeg_weights), np.std(fmri_weights)],
                    capsize=10, color=['#2ecc71', '#e74c3c'], edgecolor='black', alpha=0.8)
axes[1].set_ylabel('Average Weight')
axes[1].set_title('Average Fusion Weights')
axes[1].set_ylim(0, 1)
# Annotate each bar with its mean value
for bar, mean in zip(bars, [np.mean(eeg_weights), np.mean(fmri_weights)]):
    axes[1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
                 f'{mean:.3f}', ha='center', fontweight='bold')

plt.tight_layout()
plt.savefig(fig_dir / f'fusion_weights_{timestamp}.png', dpi=300, bbox_inches='tight')
plt.show()

# --- 5. t-SNE of Bridge Fused Features ---
# Uses the fused vectors collected in Cell 10 for each fold's held-out subjects,
# i.e. every vector comes from the model of the fold in which the subject was tested.
if all_fold_fused_features:
    from sklearn.manifold import TSNE
    feat_subjects = sorted(all_fold_fused_features.keys())
    feat_matrix = np.stack([all_fold_fused_features[s].numpy() for s in feat_subjects])
    feat_labels = np.array([bridge_labels[s] for s in feat_subjects])

    # Only embed when there are more than 5 subjects
    if len(feat_subjects) > 5:
        # Perplexity is capped so that it stays below the number of samples
        perplexity = min(30, len(feat_subjects) - 1)
        tsne = TSNE(n_components=2, perplexity=perplexity, random_state=SEED)
        embedded = tsne.fit_transform(feat_matrix)

        fig, ax = plt.subplots(figsize=(8, 6))
        for cls in np.unique(feat_labels):
            mask = feat_labels == cls
            ax.scatter(embedded[mask, 0], embedded[mask, 1],
                       label=f'Class {cls}', alpha=0.7, s=60)
        ax.set_title('t-SNE of Bridge Fused Features')
        ax.set_xlabel('t-SNE 1')
        ax.set_ylabel('t-SNE 2')
        ax.legend()
        ax.grid(alpha=0.3)
        plt.tight_layout()
        plt.savefig(fig_dir / f'tsne_bridge_features_{timestamp}.png', dpi=300, bbox_inches='tight')
        plt.show()

print('Cell 11 complete: Visualizations saved')

In [None]:
# Cell 12: XAI - Gradient Saliency

class BridgeGradientSaliency:
    """Gradient saliency for the bridge model.

    The saliency of an input feature is the absolute gradient of the selected
    class logit with respect to that feature.
    """
    def __init__(self, model, device):
        self.model = model
        self.device = device

    @staticmethod
    def _target_index(target_class, logits):
        """Convert an int / sequence / tensor target into a (batch, 1) long index on the logits' device."""
        index = torch.as_tensor(target_class, dtype=torch.long, device=logits.device).reshape(-1)
        if index.numel() == 1:
            # A single class id applies to every sample of the batch
            index = index.expand(logits.size(0))
        return index.reshape(-1, 1)

    def compute(self, eeg_feats, fmri_feats, target_class=None):
        """Compute absolute input gradients for one batch.

        Args:
            eeg_feats: EEG feature tensor, batch first.
            fmri_feats: fMRI feature tensor, batch first.
            target_class: Class whose logit is differentiated. ``None`` uses the
                model's predicted class per sample; an ``int`` (applied to the
                whole batch) or a tensor of per-sample class ids is accepted.

        Returns:
            dict with numpy arrays ``'eeg'`` and ``'fmri'`` holding the absolute
            gradients, shaped like the corresponding inputs.
        """
        self.model.eval()
        # Work on detached copies so the caller's tensors never accumulate gradients
        eeg_feats = eeg_feats.clone().detach().to(self.device).requires_grad_(True)
        fmri_feats = fmri_feats.clone().detach().to(self.device).requires_grad_(True)

        logits = self.model(eeg_feats, fmri_feats)

        if target_class is None:
            target_class = logits.argmax(dim=1)

        self.model.zero_grad()
        # Back-propagate a one-hot vector to select the target logit of each sample
        one_hot = torch.zeros_like(logits)
        one_hot.scatter_(1, self._target_index(target_class, logits), 1)
        logits.backward(gradient=one_hot)

        return {
            'eeg': eeg_feats.grad.abs().cpu().numpy(),
            'fmri': fmri_feats.grad.abs().cpu().numpy()
        }


# Load the best fold model for XAI
# "Best" = fold with the highest F1 in fold_results (fold numbers are 1-based).
best_fold_idx = np.argmax([r['F1'] for r in fold_results]) + 1
xai_model = EEGfMRIBridgeFusionNet(
    eeg_dim=config.eeg_hidden_dim,
    fmri_dim=config.fmri_hidden_dim,
    bridge_dim=config.bridge_hidden_dim,
    num_classes=config.num_classes,
    dropout=config.dropout
).to(device)

# weights_only=False: the checkpoint written in Cell 10 also stores the metrics
# and fusion-weight dicts next to the state dict.
xai_ckpt = torch.load(config.checkpoint_dir / f'best_bridge_fold{best_fold_idx}.pt',
                       map_location=device, weights_only=False)
xai_model.load_state_dict(xai_ckpt['model_state_dict'])
logger.info(f'Loaded best model from fold {best_fold_idx} for XAI')

# Compute gradient saliency for all subjects
# NOTE(review): this iterates over the whole bridge_dataset, so it includes the
# subjects the selected fold model was trained on.
saliency = BridgeGradientSaliency(xai_model, device)
all_eeg_saliency = []
all_fmri_saliency = []

for idx in range(len(bridge_dataset)):
    eeg, fmri, label, subj = bridge_dataset[idx]
    # One subject at a time (batch of one); the predicted class is the target
    result = saliency.compute(eeg.unsqueeze(0), fmri.unsqueeze(0))
    all_eeg_saliency.append(result['eeg'].squeeze())
    all_fmri_saliency.append(result['fmri'].squeeze())

# Average saliency per feature dimension across subjects
mean_eeg_sal = np.mean(all_eeg_saliency, axis=0)
mean_fmri_sal = np.mean(all_fmri_saliency, axis=0)

# Plot
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].bar(range(len(mean_eeg_sal)), mean_eeg_sal, color='#2ecc71', alpha=0.7)
axes[0].set_title('Gradient Saliency: EEG Features')
axes[0].set_xlabel('Feature Dimension')
axes[0].set_ylabel('Saliency')

axes[1].bar(range(len(mean_fmri_sal)), mean_fmri_sal, color='#e74c3c', alpha=0.7)
axes[1].set_title('Gradient Saliency: fMRI Features')
axes[1].set_xlabel('Feature Dimension')
axes[1].set_ylabel('Saliency')

plt.suptitle('Gradient Saliency Analysis', fontweight='bold')
plt.tight_layout()
plt.savefig(fig_dir / f'gradient_saliency_{timestamp}.png', dpi=300, bbox_inches='tight')
plt.show()

# Modality importance comparison
# Sums the mean saliency over each modality's feature dimensions; the totals
# therefore also depend on how many dimensions each modality has.
total_eeg = np.sum(mean_eeg_sal)
total_fmri = np.sum(mean_fmri_sal)
total = total_eeg + total_fmri
print(f'Gradient saliency modality importance:')
print(f'  EEG:  {total_eeg/total:.4f} ({total_eeg:.4f})')
print(f'  fMRI: {total_fmri/total:.4f} ({total_fmri:.4f})')

print('Cell 12 complete: Gradient saliency analysis done')

In [None]:
# Cell 13: XAI - Integrated Gradients

class BridgeIntegratedGradients:
    """Integrated Gradients for the bridge model.

    Attributions are computed against an all-zero baseline by averaging the
    input gradients at ``n_steps`` points evenly spaced on the straight line
    from the baseline to the input and multiplying by ``input - baseline``.
    """
    def __init__(self, model, device, n_steps=50):
        self.model = model
        self.device = device
        self.n_steps = n_steps

    @staticmethod
    def _target_index(target_class, logits):
        """Convert an int / sequence / tensor target into a (batch, 1) long index on the logits' device."""
        index = torch.as_tensor(target_class, dtype=torch.long, device=logits.device).reshape(-1)
        if index.numel() == 1:
            # A single class id applies to every sample of the batch
            index = index.expand(logits.size(0))
        return index.reshape(-1, 1)

    def compute(self, eeg_feats, fmri_feats, target_class=None):
        """Compute absolute Integrated Gradients attributions for one batch.

        Args:
            eeg_feats: EEG feature tensor, batch first.
            fmri_feats: fMRI feature tensor, batch first.
            target_class: Class whose logit is attributed. ``None`` uses the
                model's prediction on the actual input; an ``int`` (applied to
                the whole batch) or a tensor of per-sample class ids is accepted.

        Returns:
            dict with numpy arrays ``'eeg'`` and ``'fmri'`` holding the absolute
            attributions, shaped like the corresponding inputs.
        """
        self.model.eval()
        # Detach so every interpolated point is a leaf tensor that receives .grad
        eeg_feats = eeg_feats.detach().to(self.device)
        fmri_feats = fmri_feats.detach().to(self.device)

        # Resolve the target from the real input. Deriving it inside the loop
        # would take the prediction at alpha=0, i.e. on the all-zero baseline.
        if target_class is None:
            with torch.no_grad():
                target_class = self.model(eeg_feats, fmri_feats).argmax(dim=1)

        eeg_baseline = torch.zeros_like(eeg_feats)
        fmri_baseline = torch.zeros_like(fmri_feats)

        eeg_diff = eeg_feats - eeg_baseline
        fmri_diff = fmri_feats - fmri_baseline

        eeg_grads, fmri_grads = [], []

        for alpha in np.linspace(0, 1, self.n_steps):
            alpha = float(alpha)  # plain Python scalar keeps the tensors' dtype
            eeg_interp = (eeg_baseline + alpha * eeg_diff).requires_grad_(True)
            fmri_interp = (fmri_baseline + alpha * fmri_diff).requires_grad_(True)

            logits = self.model(eeg_interp, fmri_interp)

            self.model.zero_grad()
            # Back-propagate a one-hot vector to select the target logit of each sample
            one_hot = torch.zeros_like(logits)
            one_hot.scatter_(1, self._target_index(target_class, logits), 1)
            logits.backward(gradient=one_hot)

            eeg_grads.append(eeg_interp.grad.detach().cpu().numpy())
            fmri_grads.append(fmri_interp.grad.detach().cpu().numpy())

        # Average gradient along the path times (input - baseline)
        eeg_ig = eeg_diff.cpu().numpy() * np.mean(eeg_grads, axis=0)
        fmri_ig = fmri_diff.cpu().numpy() * np.mean(fmri_grads, axis=0)

        return {'eeg': np.abs(eeg_ig), 'fmri': np.abs(fmri_ig)}


# Reuses xai_model (the best-fold model loaded in Cell 12).
ig = BridgeIntegratedGradients(xai_model, device, n_steps=50)

all_eeg_ig = []
all_fmri_ig = []

# One subject at a time (batch of one) over the whole dataset
for idx in range(len(bridge_dataset)):
    eeg, fmri, label, subj = bridge_dataset[idx]
    result = ig.compute(eeg.unsqueeze(0), fmri.unsqueeze(0))
    all_eeg_ig.append(result['eeg'].squeeze())
    all_fmri_ig.append(result['fmri'].squeeze())

# Average absolute attribution per feature dimension across subjects
mean_eeg_ig = np.mean(all_eeg_ig, axis=0)
mean_fmri_ig = np.mean(all_fmri_ig, axis=0)

# Plot top features
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# NOTE: top_k is assumed not to exceed either feature dimension; barh() below
# pairs range(top_k) with the selected values and would fail otherwise.
top_k = 20
# Indices of the top_k largest mean attributions, in descending order
eeg_top_idx = np.argsort(mean_eeg_ig)[-top_k:][::-1]
fmri_top_idx = np.argsort(mean_fmri_ig)[-top_k:][::-1]

axes[0].barh(range(top_k), mean_eeg_ig[eeg_top_idx], color='#2ecc71', alpha=0.7)
axes[0].set_yticks(range(top_k))
axes[0].set_yticklabels([f'EEG-{i}' for i in eeg_top_idx])
axes[0].set_title(f'Top {top_k} EEG Feature Attributions (IG)')
axes[0].set_xlabel('Attribution')
axes[0].invert_yaxis()  # largest attribution at the top

axes[1].barh(range(top_k), mean_fmri_ig[fmri_top_idx], color='#e74c3c', alpha=0.7)
axes[1].set_yticks(range(top_k))
axes[1].set_yticklabels([f'fMRI-{i}' for i in fmri_top_idx])
axes[1].set_title(f'Top {top_k} fMRI Feature Attributions (IG)')
axes[1].set_xlabel('Attribution')
axes[1].invert_yaxis()  # largest attribution at the top

plt.suptitle('Integrated Gradients Attribution', fontweight='bold')
plt.tight_layout()
plt.savefig(fig_dir / f'integrated_gradients_{timestamp}.png', dpi=300, bbox_inches='tight')
plt.show()

# Modality comparison
# Share of the summed mean attribution that falls on each modality
total_eeg_ig = np.sum(mean_eeg_ig)
total_fmri_ig = np.sum(mean_fmri_ig)
total_ig = total_eeg_ig + total_fmri_ig
print(f'Integrated Gradients modality importance:')
print(f'  EEG:  {total_eeg_ig/total_ig:.4f}')
print(f'  fMRI: {total_fmri_ig/total_ig:.4f}')

print('Cell 13 complete: Integrated gradients analysis done')

In [None]:
# Cell 14: XAI - SHAP Analysis

# shap is an optional dependency; the whole analysis is skipped when it is missing.
try:
    import shap
    SHAP_AVAILABLE = True
except ImportError:
    SHAP_AVAILABLE = False
    print('SHAP not available. Install with: pip install shap')

if SHAP_AVAILABLE:
    # Build wrapper function for SHAP
    def bridge_predict(inputs):
        """Wrapper: concatenated [eeg, fmri] -> class probabilities.

        ``inputs`` is a 2-D array whose first ``config.eeg_hidden_dim`` columns
        are the EEG features and whose remaining columns are the fMRI features.
        """
        inputs_t = torch.tensor(inputs, dtype=torch.float32).to(device)
        eeg_part = inputs_t[:, :config.eeg_hidden_dim]
        fmri_part = inputs_t[:, config.eeg_hidden_dim:]
        xai_model.eval()
        with torch.no_grad():
            logits = xai_model(eeg_part, fmri_part)
            probs = F.softmax(logits, dim=1)
        return probs.cpu().numpy()

    # Prepare data: one concatenated [eeg, fmri] row per subject
    all_features = []
    for idx in range(len(bridge_dataset)):
        eeg, fmri, _, _ = bridge_dataset[idx]
        combined = torch.cat([eeg, fmri]).numpy()
        all_features.append(combined)
    all_features = np.array(all_features)

    # Use a subset as background (the first rows, at most 20)
    n_background = min(20, len(all_features))
    background = all_features[:n_background]

    # Create SHAP explainer
    explainer = shap.KernelExplainer(bridge_predict, background)
    shap_values = explainer.shap_values(all_features, nsamples=100)

    # SHAP values for positive class.
    # Older shap releases return a list with one (n_samples, n_features) array
    # per class; newer releases return a single array with the model outputs on
    # the last axis. Reduce both layouts to the 2-D matrix of class 1.
    if isinstance(shap_values, list):
        sv = shap_values[1]  # Class 1
    else:
        shap_values = np.asarray(shap_values)
        sv = shap_values[..., 1] if shap_values.ndim == 3 else shap_values

    # Split into EEG and fMRI
    eeg_shap = sv[:, :config.eeg_hidden_dim]
    fmri_shap = sv[:, config.eeg_hidden_dim:]

    # Feature names
    feature_names = ([f'EEG-{i}' for i in range(config.eeg_hidden_dim)] +
                     [f'fMRI-{i}' for i in range(config.fmri_hidden_dim)])

    # SHAP summary plot
    fig, ax = plt.subplots(figsize=(12, 8))
    shap.summary_plot(sv, all_features, feature_names=feature_names,
                      max_display=20, show=False)
    plt.title('SHAP Feature Importance (Top 20)')
    plt.tight_layout()
    plt.savefig(fig_dir / f'shap_summary_{timestamp}.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Modality-level SHAP importance (mean absolute SHAP value per modality)
    eeg_importance = np.mean(np.abs(eeg_shap))
    fmri_importance = np.mean(np.abs(fmri_shap))
    total_shap = eeg_importance + fmri_importance

    fig, ax = plt.subplots(figsize=(6, 4))
    bars = ax.bar(['EEG', 'fMRI'],
                  [eeg_importance/total_shap, fmri_importance/total_shap],
                  color=['#2ecc71', '#e74c3c'], edgecolor='black', alpha=0.8)
    ax.set_ylabel('Relative SHAP Importance')
    ax.set_title('SHAP Modality Importance')
    for bar in bars:
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                f'{bar.get_height():.3f}', ha='center', fontweight='bold')
    plt.tight_layout()
    plt.savefig(fig_dir / f'shap_modality_importance_{timestamp}.png', dpi=300, bbox_inches='tight')
    plt.show()

    print(f'SHAP modality importance: EEG={eeg_importance/total_shap:.4f}, fMRI={fmri_importance/total_shap:.4f}')

    # SHAP force plot for first subject
    # Base value of class 1 (falls back to the only entry for a scalar expected value)
    expected_values = np.atleast_1d(explainer.expected_value)
    class1_base_value = expected_values[1] if expected_values.size > 1 else expected_values[0]
    print('\nSHAP force plot for subject', bridge_dataset.samples[0]['subject'], ':')
    shap.force_plot(class1_base_value, sv[0], all_features[0],
                    feature_names=feature_names, matplotlib=True, show=False)
    plt.savefig(fig_dir / f'shap_force_subject1_{timestamp}.png', dpi=300, bbox_inches='tight')
    plt.show()

print('Cell 14 complete: SHAP analysis done')

In [None]:
# Cell 15: XAI - Attention & Fusion Weight Analysis

def extract_attention_and_fusion_weights(model, dataset, device):
    """Run the model on every dataset item and gather its per-subject weights.

    The model is switched to eval mode and queried one item at a time (as a
    batch of one) with gradient tracking disabled.

    Args:
        model: Object with an ``eval()`` method, called as
            ``model(eeg, fmri, return_features=True, return_weights=True)``
            and returning ``(logits, fused, fusion_weights, attn_weights)``.
        dataset: Indexable collection supporting ``len()`` whose items unpack
            to ``(eeg, fmri, label, subject)``; ``eeg`` and ``fmri`` are
            tensors.
        device: Device the two input tensors are moved to before the forward
            pass.

    Returns:
        list[dict]: One record per item, in dataset order, with the keys
        ``'subject'``, ``'label'``, ``'prediction'`` (argmax of the logits
        along dim 1), ``'fusion_weights'`` and ``'attn_weights'`` (NumPy
        arrays with every length-1 axis squeezed out).
    """
    model.eval()

    def _to_squeezed_array(tensor):
        # Copy to host memory and drop all singleton axes (batch axis included).
        return tensor.cpu().numpy().squeeze()

    records = []
    with torch.no_grad():
        for position in range(len(dataset)):
            eeg_sample, fmri_sample, label, subj = dataset[position]

            # Prepend a batch axis so the model receives a batch of one.
            eeg_batch = eeg_sample.unsqueeze(0).to(device)
            fmri_batch = fmri_sample.unsqueeze(0).to(device)

            outputs = model(eeg_batch, fmri_batch,
                            return_features=True, return_weights=True)
            # The fused feature vector is returned by the model but not kept.
            logits, _fused, fusion_weights, attn_weights = outputs

            predicted_class = logits.argmax(dim=1).item()
            records.append(dict(
                subject=subj,
                label=label,
                prediction=predicted_class,
                fusion_weights=_to_squeezed_array(fusion_weights),
                attn_weights=_to_squeezed_array(attn_weights),
            ))

    return records


# Collect per-subject predictions, fusion weights and attention weights from
# the model used for the XAI analyses.
subject_xai = extract_attention_and_fusion_weights(xai_model, bridge_dataset, device)

# --- Attention Heatmap ---
# Stack the per-subject attention arrays along a new leading (subject) axis.
# extract_attention_and_fusion_weights squeezes every singleton axis, so the
# stacked rank depends on what the model returned: e.g. (N, 2) when a single
# attention row came back, or (N, num_heads, 2) when per-head weights did
# (assumes the last axis holds the two key modalities, EEG then fMRI -- TODO
# confirm against the model).
attn_matrix = np.stack([s['attn_weights'] for s in subject_xai])

# Average across subjects and heads by collapsing *all* leading axes.
# Branching on ndim treated a squeezed (N, num_heads, 2) stack like
# (N, 1, 2): it averaged over subjects only and produced a heatmap with
# 2 * num_heads cells under just two tick labels. Folding everything into
# rows of two gives one EEG and one fMRI value for every input layout.
mean_attn = attn_matrix.reshape(-1, 2).mean(axis=0, keepdims=True)  # (1, 2)

fig, ax = plt.subplots(figsize=(6, 3))
sns.heatmap(mean_attn, annot=True, fmt='.3f', cmap='YlOrRd',
            xticklabels=['EEG', 'fMRI'], yticklabels=['Query'], ax=ax)
ax.set_title('Cross-Modal Attention Weights (Mean)')
plt.tight_layout()
plt.savefig(fig_dir / f'attention_heatmap_{timestamp}.png', dpi=300, bbox_inches='tight')
plt.show()

# --- Per-Subject Fusion Weights ---
# One row per subject; column 0 is drawn as the EEG bar, column 1 as fMRI.
fusion_weights_arr = np.stack([entry['fusion_weights'] for entry in subject_xai])
if fusion_weights_arr.ndim == 1:
    # A flat stack is regrouped into rows of two weights.
    fusion_weights_arr = fusion_weights_arr.reshape(-1, 2)

subjects_list = [entry['subject'] for entry in subject_xai]
x = np.arange(len(subjects_list))
width = 0.35

# Grouped bars: each modality is shifted half a bar width off the tick centre.
fig, ax = plt.subplots(figsize=(12, 5))
bar_specs = [
    (-width / 2, 'EEG', '#2ecc71'),
    (width / 2, 'fMRI', '#e74c3c'),
]
for column, (offset, modality, colour) in enumerate(bar_specs):
    ax.bar(x + offset, fusion_weights_arr[:, column], width,
           label=modality, color=colour, alpha=0.8)

ax.set(xlabel='Subject', ylabel='Fusion Weight',
       title='Per-Subject Dynamic Fusion Weights')
ax.set_xticks(x)
ax.set_xticklabels(subjects_list, rotation=45)
ax.legend()
ax.grid(alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig(fig_dir / f'per_subject_fusion_weights_{timestamp}.png', dpi=300, bbox_inches='tight')
plt.show()

# --- Class-wise Fusion Weight Comparison ---
# Boolean masks selecting the subjects whose stored label equals 0 or 1.
class0_mask = np.array([entry['label'] == 0 for entry in subject_xai])
class1_mask = np.array([entry['label'] == 1 for entry in subject_xai])

print('Fusion weights by class:')
# A class without any subject is skipped, so no mean is taken over an empty
# selection. Column 0 is reported as EEG and column 1 as fMRI.
if class0_mask.any():
    c0_eeg, c0_fmri = (fusion_weights_arr[class0_mask, col].mean() for col in (0, 1))
    print(f'  Class 0: EEG={c0_eeg:.4f}, fMRI={c0_fmri:.4f}')
if class1_mask.any():
    c1_eeg, c1_fmri = (fusion_weights_arr[class1_mask, col].mean() for col in (0, 1))
    print(f'  Class 1: EEG={c1_eeg:.4f}, fMRI={c1_fmri:.4f}')

print('Cell 15 complete: Attention & fusion weight analysis done')

In [None]:
# Cell 16: Summary & Export

# Banner framing the final report.
banner = '=' * 70
print(banner)
print('EEG-fMRI BRIDGE FUSION - FINAL SUMMARY')
print(banner)

# Performance table: mean and standard deviation of each metric over the
# per-fold result records.
print(f'\nDataset: {len(bridge_dataset)} subjects (overlap of EEG & fMRI)')
print(f'Cross-validation: {config.n_splits}-fold stratified')
print(f'\nBridge Fusion Performance:')
for metric_name in ['Accuracy', 'F1', 'Precision', 'Recall', 'AUC']:
    per_fold = [fold[metric_name] for fold in fold_results]
    metric_mean = np.mean(per_fold)
    metric_std = np.std(per_fold)
    print(f'  {metric_name:12s}: {metric_mean:.4f} +/- {metric_std:.4f}')

# Mean and spread of the fusion weights held in eeg_w / fmri_w.
print(f'\nLearned Fusion Weights:')
eeg_w_mean, eeg_w_std = np.mean(eeg_w), np.std(eeg_w)
print(f'  EEG:  {eeg_w_mean:.4f} +/- {eeg_w_std:.4f}')
fmri_w_mean, fmri_w_std = np.mean(fmri_w), np.std(fmri_w)
print(f'  fMRI: {fmri_w_mean:.4f} +/- {fmri_w_std:.4f}')

# Save all results
# Aggregate of the cross-validation and XAI outputs produced by this notebook.
# NOTE(review): within this cell the dict itself is never written to disk;
# only its 'per_subject_fusion_weights' entry is exported (as CSV) below.
results_dict = {
    'summary': summary_df.to_dict(),
    'fold_results': fold_results,
    'fold_fusion_weights': fold_fusion_weights,
    'gradient_saliency': {
        'eeg_mean': mean_eeg_sal.tolist(),
        'fmri_mean': mean_fmri_sal.tolist()
    },
    'integrated_gradients': {
        'eeg_mean': mean_eeg_ig.tolist(),
        'fmri_mean': mean_fmri_ig.tolist()
    },
    'per_subject_fusion_weights': [
        {'subject': s['subject'], 'label': s['label'],
         'eeg_weight': float(s['fusion_weights'][0]),
         'fmri_weight': float(s['fusion_weights'][1])}
        for s in subject_xai
    ]
}

# Save as CSV
# Number the folds from the rows actually present instead of config.n_splits:
# assigning a range whose length differs from the frame raises ValueError,
# which would abort the whole export if the number of recorded folds ever
# differed from n_splits. With one row per configured fold the output is
# unchanged.
fold_df = pd.DataFrame(fold_results)
fold_df['Fold'] = range(1, len(fold_df) + 1)
fold_df.to_csv(config.output_dir / f'bridge_fold_results_{timestamp}.csv', index=False)

fw_df = pd.DataFrame(fold_fusion_weights)
fw_df['Fold'] = range(1, len(fw_df) + 1)
fw_df.to_csv(config.output_dir / f'bridge_fusion_weights_{timestamp}.csv', index=False)

subj_fw_df = pd.DataFrame(results_dict['per_subject_fusion_weights'])
subj_fw_df.to_csv(config.output_dir / f'bridge_subject_fusion_weights_{timestamp}.csv', index=False)

# Save XAI arrays
# Bundle the mean saliency / integrated-gradient arrays and the per-subject
# fusion-weight matrix into one .npz archive.
np.savez(
    config.output_dir / f'bridge_xai_arrays_{timestamp}.npz',
    gradient_saliency_eeg=mean_eeg_sal,
    gradient_saliency_fmri=mean_fmri_sal,
    integrated_gradients_eeg=mean_eeg_ig,
    integrated_gradients_fmri=mean_fmri_ig,
    per_subject_fusion_weights=fusion_weights_arr,
)

# Record the output locations in the log and echo them to the cell output.
logger.info(f'All results saved to {config.output_dir}')
logger.info(f'Figures saved to {fig_dir}')
logger.info(f'Checkpoints saved to {config.checkpoint_dir}')

print(f'\nOutput directory: {config.output_dir}')
print(f'Figures directory: {fig_dir}')
print(f'Checkpoints directory: {config.checkpoint_dir}')
print(f'Timestamp: {timestamp}')
print('\nBridge fusion pipeline complete.')