# Multi-Subject EEG-fMRI Prediction Pipeline

## Cell 1: Import Libraries and Setup

In [None]:
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import mne
import json
from datetime import datetime
from scipy.signal import stft
from scipy.interpolate import interp1d
from scipy.stats import pearsonr
from sklearn.model_selection import LeaveOneGroupOut, cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import Ridge, ElasticNet
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score, mean_squared_error
import xgboost as xgb
import joblib
import warnings
warnings.filterwarnings('ignore')

# Set style
plt.style.use('seaborn-v0_8-paper')
sns.set_palette("husl")

# Define paths
EDF_PATH = '../edfs/'
CSV_PATH = '../task_output/'

## Cell 2: Data Loading and Organization Functions

In [2]:
def load_all_data_files():
    """
    Load all EDF and CSV files and organize by subject and run.
    """
    # Get all EDF files
    edf_files = glob.glob(os.path.join(EDF_PATH, '*.edf'))
    csv_files = glob.glob(os.path.join(CSV_PATH, '*.csv'))
    
    print(f"Found {len(edf_files)} EDF files and {len(csv_files)} CSV files")
    
    # Extract subject info
    subjects_data = {}
    
    for edf_file in edf_files:
        # Extract subject ID and run
        basename = os.path.basename(edf_file)
        
        # Extract subject ID (e.g., 'dmnelf001' from 'sub-dmnelf001_task-feedback-run01.edf')
        if 'sub-' in basename:
            subject_id = basename.split('_')[0].replace('sub-', '')
        else:
            continue
            
        # Extract run number from format like 'run01' or 'run-01'
        import re
        run_match = re.search(r'run[-]?(\d+)', basename)
        if run_match:
            run_num = run_match.group(1).zfill(2)  # Ensure 2 digits
        else:
            continue
            
        # Find matching CSV file
        # Your CSV format: sub-dmnelf001_DMN_Feedback_run01_roi_outputs.csv
        # Try different patterns
        possible_patterns = [
            f"sub-{subject_id}_DMN_Feedback_run{run_num}_roi_outputs.csv",
            f"sub-{subject_id}_*_run{run_num}_roi_outputs.csv",
            f"sub-{subject_id}_*run{run_num}*.csv"
        ]
        
        matching_csv = None
        for pattern in possible_patterns:
            for csv_file in csv_files:
                csv_basename = os.path.basename(csv_file)
                # Direct match
                if csv_basename == pattern.replace('*', 'DMN_Feedback_'):
                    matching_csv = csv_file
                    break
                # Pattern match for wildcards
                if '*' in pattern:
                    pattern_regex = pattern.replace('*', '.*')
                    if re.match(pattern_regex, csv_basename):
                        matching_csv = csv_file
                        break
            if matching_csv:
                break
        
        if not matching_csv:
            print(f"Warning: No CSV found for {basename}")
            print(f"  Looked for patterns: {possible_patterns}")
            continue
            
        # Initialize subject data structure
        if subject_id not in subjects_data:
            subjects_data[subject_id] = {}
            
        # Store file paths
        subjects_data[subject_id][f'run{run_num}'] = {
            'edf': edf_file,
            'csv': matching_csv
        }
    
    # Print summary
    print("\n" + "="*70)
    print("DATA LOADING SUMMARY")
    print("="*70)
    print(f"Total subjects found: {len(subjects_data)}")
    
    for subject_id, runs in subjects_data.items():
        print(f"\n{subject_id}:")
        for run, files in runs.items():
            print(f"  {run}: EDF ✓, CSV ✓")
    
    return subjects_data


def select_subjects_for_analysis(subjects_data):
    """
    Interactive selection of subjects for training and testing.
    """
    print("\n" + "="*70)
    print("SUBJECT SELECTION")
    print("="*70)
    
    all_subjects = list(subjects_data.keys())
    print("\nAvailable subjects:")
    for i, subj in enumerate(all_subjects):
        n_runs = len(subjects_data[subj])
        print(f"{i+1}. {subj} ({n_runs} runs)")
    
    # Option 1: Select specific subjects
    print("\nSelect option:")
    print("1. Use all subjects with leave-one-out")
    print("2. Select specific train/test split")
    print("3. Select subset of subjects")
    
    choice = input("\nEnter choice (1-3) [default=1]: ").strip() or "1"
    
    if choice == "1":
        # Use all subjects
        selected_subjects = all_subjects
        test_mode = 'leave_one_out'
        test_subjects = None
        
    elif choice == "2":
        # Manual train/test split
        print("\nSelect TRAINING subjects (comma-separated numbers):")
        train_idx = input("Training subjects: ").strip()
        train_idx = [int(i)-1 for i in train_idx.split(',')]
        train_subjects = [all_subjects[i] for i in train_idx]
        
        print("\nSelect TEST subjects (comma-separated numbers):")
        test_idx = input("Test subjects: ").strip()
        test_idx = [int(i)-1 for i in test_idx.split(',')]
        test_subjects = [all_subjects[i] for i in test_idx]
        
        selected_subjects = train_subjects + test_subjects
        test_mode = 'manual_split'
        
    else:
        # Select subset
        print("\nSelect subjects to include (comma-separated numbers):")
        subset_idx = input("Subjects: ").strip()
        subset_idx = [int(i)-1 for i in subset_idx.split(',')]
        selected_subjects = [all_subjects[i] for i in subset_idx]
        test_mode = 'leave_one_out'
        test_subjects = None
    
    print(f"\nSelected {len(selected_subjects)} subjects for analysis")
    print(f"Test mode: {test_mode}")
    
    return selected_subjects, test_mode, test_subjects


# Load all data
subjects_data = load_all_data_files()

# Select subjects
selected_subjects, test_mode, test_subjects = select_subjects_for_analysis(subjects_data)

DATA LOADING SUMMARY
Total subjects found: 0

SUBJECT SELECTION

Available subjects:

Select option:
1. Use all subjects with leave-one-out
2. Select specific train/test split
3. Select subset of subjects


KeyboardInterrupt: Interrupted by user

## Cell 3: Import Single-Subject Processing Functions

In [None]:
# Import all the processing functions from your single-subject pipeline
# (You can copy these from your original notebook or import from a module)

# Include these functions:
# - automated_channel_cleaning
# - process_eeg_file
# - align_pda_to_eeg
# - compute_lz_features_matlab_style
# - extract_advanced_eeg_features
# - etc.

# For now, I'll create a placeholder
print("Import your single-subject processing functions here")

## Cell 4: Multi-Subject Processing Pipeline

In [None]:
def process_subject_data(subject_id, runs_data, processing_params):
    """
    Process all runs for a single subject.
    """
    subject_features = []
    subject_targets = []
    subject_metadata = []
    
    print(f"\nProcessing subject: {subject_id}")
    
    for run_id, file_paths in runs_data.items():
        print(f"  Processing {run_id}...")
        
        try:
            # 1. Load and clean EEG data
            raw = mne.io.read_raw_edf(file_paths['edf'], preload=True)
            
            # Apply automated cleaning
            raw_processed, bad_channels, qc_stats = automated_channel_cleaning(
                raw, 
                z_score_threshold=processing_params['z_score_threshold'],
                correlation_threshold=processing_params['correlation_threshold'],
                powerline_threshold=processing_params['powerline_threshold'],
                save_report=False
            )
            
            # 2. Load PDA data
            pda_df = pd.read_csv(file_paths['csv'])
            if 'cen' in pda_df.columns and 'dmn' in pda_df.columns:
                pda_signal = (pda_df['cen'] - pda_df['dmn']).values
            else:
                print(f"    Warning: Missing CEN/DMN columns in {run_id}")
                continue
            
            # 3. Align PDA to EEG
            pda_aligned_z, pda_time_aligned = align_pda_to_eeg(
                pda_signal, raw_processed,
                hrf_delay=processing_params['hrf_delay'],
                target_fs=processing_params['target_fs']
            )
            
            # 4. Extract features
            # Get EEG data
            picks_eeg = mne.pick_types(raw_processed.info, eeg=True)
            eeg_data = raw_processed.get_data(picks=picks_eeg)
            fs = raw_processed.info['sfreq']
            
            # Compute LZ complexity features
            lz_features = compute_lz_features_matlab_style(
                eeg_data, fs,
                window_length=processing_params['lz_window_length'],
                overlap=processing_params['lz_overlap'],
                complexity_type='exhaustive',
                use_fast=True
            )
            
            # Compute STFT features
            nperseg = int(processing_params['stft_window'] * fs)
            noverlap = int(processing_params['stft_overlap'] * fs)
            
            f, t_stft, Zxx = stft(eeg_data, fs=fs, nperseg=nperseg, 
                                  noverlap=noverlap, axis=1)
            power = np.abs(Zxx) ** 2
            power = power.transpose(2, 0, 1)
            
            # Extract advanced features
            advanced_features = extract_advanced_eeg_features(
                power, f, t_stft, raw_processed, 
                pda_aligned_z, pda_time_aligned
            )
            
            # 5. Combine all features
            run_features = {
                **lz_features['aligned_features'],
                **advanced_features['aligned_features']
            }
            
            # Create feature matrix
            feature_names = list(run_features.keys())
            X_run = np.column_stack([run_features[feat] for feat in feature_names])
            
            # Store results
            subject_features.append(X_run)
            subject_targets.append(pda_aligned_z)
            subject_metadata.append({
                'subject_id': subject_id,
                'run_id': run_id,
                'n_samples': len(pda_aligned_z),
                'feature_names': feature_names,
                'bad_channels': bad_channels,
                'n_good_channels': len(picks_eeg)
            })
            
            print(f"    ✓ Extracted {X_run.shape[1]} features, {X_run.shape[0]} samples")
            
        except Exception as e:
            print(f"    ✗ Error processing {run_id}: {str(e)}")
            continue
    
    return {
        'features': subject_features,
        'targets': subject_targets,
        'metadata': subject_metadata
    }


# Define processing parameters
processing_params = {
    'z_score_threshold': 5.0,
    'correlation_threshold': 0.2,
    'powerline_threshold': 15.0,
    'hrf_delay': 5.0,
    'target_fs': 1.0,
    'lz_window_length': 2.0,
    'lz_overlap': 0.5,
    'stft_window': 1.0,
    'stft_overlap': 0.5
}

# Process all selected subjects
all_subjects_data = {}

for subject_id in tqdm(selected_subjects, desc="Processing subjects"):
    if subject_id in subjects_data:
        subject_results = process_subject_data(
            subject_id, 
            subjects_data[subject_id],
            processing_params
        )
        all_subjects_data[subject_id] = subject_results

print(f"\n✓ Processed {len(all_subjects_data)} subjects successfully")

## Cell 5: Feature Alignment and Standardization

In [None]:
def align_features_across_subjects(all_subjects_data):
    """
    Ensure all subjects have the same features in the same order.
    """
    print("\n" + "="*70)
    print("FEATURE ALIGNMENT")
    print("="*70)
    
    # Collect all unique feature names
    all_feature_names = set()
    
    for subject_id, subject_data in all_subjects_data.items():
        for metadata in subject_data['metadata']:
            all_feature_names.update(metadata['feature_names'])
    
    # Sort for consistency
    common_features = sorted(list(all_feature_names))
    print(f"Total unique features across all subjects: {len(common_features)}")
    
    # Align features for each subject/run
    aligned_data = {}
    
    for subject_id, subject_data in all_subjects_data.items():
        aligned_features = []
        aligned_targets = []
        aligned_metadata = []
        
        for run_idx, (features, target, metadata) in enumerate(zip(
            subject_data['features'],
            subject_data['targets'],
            subject_data['metadata']
        )):
            # Create feature dictionary for this run
            run_features = dict(zip(metadata['feature_names'], features.T))
            
            # Create aligned feature matrix
            aligned_X = []
            for feat_name in common_features:
                if feat_name in run_features:
                    aligned_X.append(run_features[feat_name])
                else:
                    # Fill missing features with zeros
                    aligned_X.append(np.zeros(len(target)))
            
            aligned_X = np.column_stack(aligned_X)
            
            aligned_features.append(aligned_X)
            aligned_targets.append(target)
            
            # Update metadata
            metadata['feature_names'] = common_features
            aligned_metadata.append(metadata)
        
        aligned_data[subject_id] = {
            'features': aligned_features,
            'targets': aligned_targets,
            'metadata': aligned_metadata
        }
    
    return aligned_data, common_features


# Align features
aligned_data, common_features = align_features_across_subjects(all_subjects_data)

## Cell 6: Create Training and Testing Sets

In [None]:
def create_train_test_sets(aligned_data, test_mode, test_subjects=None):
    """
    Create train/test sets based on the selected mode.
    """
    if test_mode == 'leave_one_out':
        # Create leave-one-subject-out splits
        splits = []
        
        for test_subject in aligned_data.keys():
            train_subjects = [s for s in aligned_data.keys() if s != test_subject]
            
            # Combine training data
            X_train = []
            y_train = []
            train_groups = []
            
            for subj_idx, train_subj in enumerate(train_subjects):
                for run_features, run_target in zip(
                    aligned_data[train_subj]['features'],
                    aligned_data[train_subj]['targets']
                ):
                    X_train.append(run_features)
                    y_train.append(run_target)
                    train_groups.extend([subj_idx] * len(run_target))
            
            X_train = np.vstack(X_train)
            y_train = np.hstack(y_train)
            train_groups = np.array(train_groups)
            
            # Test data
            X_test = []
            y_test = []
            
            for run_features, run_target in zip(
                aligned_data[test_subject]['features'],
                aligned_data[test_subject]['targets']
            ):
                X_test.append(run_features)
                y_test.append(run_target)
            
            X_test = np.vstack(X_test)
            y_test = np.hstack(y_test)
            
            splits.append({
                'test_subject': test_subject,
                'train_subjects': train_subjects,
                'X_train': X_train,
                'y_train': y_train,
                'X_test': X_test,
                'y_test': y_test,
                'train_groups': train_groups
            })
        
        return splits
    
    elif test_mode == 'manual_split':
        # Single train/test split
        train_subjects = [s for s in aligned_data.keys() if s not in test_subjects]
        
        # Combine training data
        X_train = []
        y_train = []
        
        for train_subj in train_subjects:
            for run_features, run_target in zip(
                aligned_data[train_subj]['features'],
                aligned_data[train_subj]['targets']
            ):
                X_train.append(run_features)
                y_train.append(run_target)
        
        X_train = np.vstack(X_train)
        y_train = np.hstack(y_train)
        
        # Test data
        X_test = []
        y_test = []
        
        for test_subj in test_subjects:
            for run_features, run_target in zip(
                aligned_data[test_subj]['features'],
                aligned_data[test_subj]['targets']
            ):
                X_test.append(run_features)
                y_test.append(run_target)
        
        X_test = np.vstack(X_test)
        y_test = np.hstack(y_test)
        
        return [{
            'test_subjects': test_subjects,
            'train_subjects': train_subjects,
            'X_train': X_train,
            'y_train': y_train,
            'X_test': X_test,
            'y_test': y_test
        }]


# Create train/test splits
splits = create_train_test_sets(aligned_data, test_mode, test_subjects)
print(f"\nCreated {len(splits)} train/test splits")

## Cell 7: Feature Selection and Model Training

In [None]:
def select_top_features_multi(X_train, y_train, feature_names, n_features=20):
    """
    Select top features based on correlation with target across training data.
    """
    correlations = {}
    
    for i, feat_name in enumerate(feature_names):
        feat_data = X_train[:, i]
        if np.std(feat_data) > 1e-10:
            corr, _ = pearsonr(feat_data, y_train)
            correlations[i] = abs(corr)
        else:
            correlations[i] = 0
    
    # Sort by correlation
    sorted_idx = sorted(correlations.keys(), key=lambda x: correlations[x], reverse=True)
    top_idx = sorted_idx[:n_features]
    top_features = [feature_names[i] for i in top_idx]
    
    return top_idx, top_features, correlations


def train_multi_subject_models(splits, common_features, n_features=20):
    """
    Train models for each split and evaluate performance.
    """
    print("\n" + "="*70)
    print("MULTI-SUBJECT MODEL TRAINING")
    print("="*70)
    
    results = []
    
    for split_idx, split in enumerate(splits):
        if 'test_subject' in split:
            print(f"\nFold {split_idx + 1}: Testing on {split['test_subject']}")
        else:
            print(f"\nTesting on: {split['test_subjects']}")
        
        # Feature selection on training data
        top_idx, top_features, _ = select_top_features_multi(
            split['X_train'], split['y_train'], common_features, n_features
        )
        
        # Select features
        X_train_selected = split['X_train'][:, top_idx]
        X_test_selected = split['X_test'][:, top_idx]
        
        # Standardize
        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train_selected)
        X_test_scaled = scaler.transform(X_test_selected)
        
        # Train models
        models = {
            'Ridge': Ridge(alpha=1.0),
            'ElasticNet': ElasticNet(alpha=0.1, l1_ratio=0.5, max_iter=2000),
            'Random Forest': RandomForestRegressor(n_estimators=100, max_depth=5, 
                                                 random_state=42, n_jobs=-1),
            'XGBoost': xgb.XGBRegressor(n_estimators=100, max_depth=3, 
                                       learning_rate=0.1, random_state=42, n_jobs=-1)
        }
        
        fold_results = {}
        
        for model_name, model in models.items():
            # Train
            model.fit(X_train_scaled, split['y_train'])
            
            # Predict
            y_pred_train = model.predict(X_train_scaled)
            y_pred_test = model.predict(X_test_scaled)
            
            # Evaluate
            train_corr, _ = pearsonr(split['y_train'], y_pred_train)
            test_corr, _ = pearsonr(split['y_test'], y_pred_test)
            test_r2 = r2_score(split['y_test'], y_pred_test)
            test_rmse = np.sqrt(mean_squared_error(split['y_test'], y_pred_test))
            
            fold_results[model_name] = {
                'train_corr': train_corr,
                'test_corr': test_corr,
                'test_r2': test_r2,
                'test_rmse': test_rmse,
                'predictions': y_pred_test,
                'model': model
            }
            
            print(f"  {model_name}: test r={test_corr:.3f}, R²={test_r2:.3f}")
        
        results.append({
            'split_info': split,
            'top_features': top_features,
            'scaler': scaler,
            'results': fold_results
        })
    
    return results


# Train models
model_results = train_multi_subject_models(splits, common_features, n_features=20)

## Cell 8: Results Visualization

In [None]:
def visualize_multi_subject_results(model_results):
    """
    Create comprehensive visualization of multi-subject results.
    """
    fig = plt.figure(figsize=(16, 12))
    
    # 1. Model performance across subjects
    ax1 = plt.subplot(3, 2, 1)
    
    model_names = list(model_results[0]['results'].keys())
    n_splits = len(model_results)
    
    # Collect correlations
    correlations = {model: [] for model in model_names}
    
    for split_result in model_results:
        for model_name in model_names:
            correlations[model_name].append(
                split_result['results'][model_name]['test_corr']
            )
    
    # Box plot
    positions = np.arange(len(model_names))
    bp = ax1.boxplot([correlations[m] for m in model_names], 
                     positions=positions, widths=0.6)
    
    ax1.set_xticklabels(model_names, rotation=45, ha='right')
    ax1.set_ylabel('Test Correlation')
    ax1.set_title('Model Performance Across Subjects')
    ax1.grid(True, alpha=0.3)
    
    # Add mean values
    means = [np.mean(correlations[m]) for m in model_names]
    ax1.scatter(positions, means, color='red', s=100, zorder=3, label='Mean')
    ax1.legend()
    
    # 2. Subject-specific performance (best model)
    ax2 = plt.subplot(3, 2, 2)
    
    # Find best model
    mean_corrs = {m: np.mean(correlations[m]) for m in model_names}
    best_model = max(mean_corrs.keys(), key=lambda x: mean_corrs[x])
    
    # Plot individual subject results
    if 'test_subject' in model_results[0]['split_info']:
        subjects = [r['split_info']['test_subject'] for r in model_results]
        subject_corrs = [r['results'][best_model]['test_corr'] for r in model_results]
        
        bars = ax2.bar(range(len(subjects)), subject_corrs, 
                       color='skyblue', edgecolor='navy')
        ax2.set_xticks(range(len(subjects)))
        ax2.set_xticklabels(subjects, rotation=45, ha='right')
        ax2.set_ylabel('Test Correlation')
        ax2.set_title(f'Per-Subject Performance ({best_model})')
        ax2.axhline(y=np.mean(subject_corrs), color='red', 
                   linestyle='--', label=f'Mean: {np.mean(subject_corrs):.3f}')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
    
    # 3. Feature importance (aggregate)
    ax3 = plt.subplot(3, 2, 3)
    
    # Aggregate feature importance from best model
    feature_importance_sum = {}
    
    for result in model_results:
        model = result['results'][best_model]['model']
        if hasattr(model, 'feature_importances_'):
            for i, feat in enumerate(result['top_features']):
                if feat not in feature_importance_sum:
                    feature_importance_sum[feat] = 0
                feature_importance_sum[feat] += model.feature_importances_[i]
    
    # Sort and plot top features
    if feature_importance_sum:
        sorted_features = sorted(feature_importance_sum.items(), 
                               key=lambda x: x[1], reverse=True)[:15]
        features = [f[0] for f in sorted_features]
        importances = [f[1] for f in sorted_features]
        
        y_pos = np.arange(len(features))
        ax3.barh(y_pos, importances, color='lightcoral')
        ax3.set_yticks(y_pos)
        ax3.set_yticklabels([f[:30] + '...' if len(f) > 30 else f 
                            for f in features], fontsize=8)
        ax3.set_xlabel('Cumulative Importance')
        ax3.set_title(f'Top Features Across All Folds ({best_model})')
        ax3.grid(True, alpha=0.3, axis='x')
    
    # 4. Example prediction plot
    ax4 = plt.subplot(3, 2, 4)
    
    # Use first split as example
    example_result = model_results[0]
    y_true = example_result['split_info']['y_test']
    y_pred = example_result['results'][best_model]['predictions']
    
    # Subsample for clarity
    max_points = min(300, len(y_true))
    ax4.plot(y_true[:max_points], 'k-', label='Actual', linewidth=1.5, alpha=0.8)
    ax4.plot(y_pred[:max_points], 'r--', label='Predicted', linewidth=1.5, alpha=0.8)
    ax4.set_xlabel('Time Points')
    ax4.set_ylabel('PDA (z-score)')
    ax4.set_title('Example Prediction (First 300 samples)')
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    
    # 5. Cross-subject generalization matrix
    ax5 = plt.subplot(3, 2, 5)
    
    if len(model_results) > 1 and 'test_subject' in model_results[0]['split_info']:
        # Create correlation matrix
        n_subjects = len(model_results)
        corr_matrix = np.zeros((n_subjects, n_subjects))
        
        subjects = [r['split_info']['test_subject'] for r in model_results]
        
        for i, test_subj in enumerate(subjects):
            for j, train_result in enumerate(model_results):
                if i == j:
                    corr_matrix[i, j] = train_result['results'][best_model]['test_corr']
                else:
                    # This would require additional cross-prediction
                    corr_matrix[i, j] = np.nan
        
        im = ax5.imshow(corr_matrix, cmap='coolwarm', vmin=0, vmax=1)
        ax5.set_xticks(range(n_subjects))
        ax5.set_yticks(range(n_subjects))
        ax5.set_xticklabels(subjects, rotation=45, ha='right')
        ax5.set_yticklabels(subjects)
        ax5.set_xlabel('Test Subject')
        ax5.set_ylabel('Test Subject')
        ax5.set_title('Generalization Performance')
        plt.colorbar(im, ax=ax5, fraction=0.046, pad=0.04)
    
    # 6. Summary statistics
    ax6 = plt.subplot(3, 2, 6)
    ax6.axis('off')
    
    # Calculate summary stats
    summary_text = f"""
    MULTI-SUBJECT ANALYSIS SUMMARY
    
    Total Subjects: {len(model_results)}
    Features Used: {len(model_results[0]['top_features'])}
    
    Best Model: {best_model}
    Mean Correlation: {mean_corrs[best_model]:.3f}
    Std Correlation: {np.std(correlations[best_model]):.3f}
    
    Performance Range:
    Best Subject: {max(correlations[best_model]):.3f}
    Worst Subject: {min(correlations[best_model]):.3f}
    
    All Models Mean±Std:
    """
    
    for model in model_names:
        summary_text += f"\n{model}: {np.mean(correlations[model]):.3f}±{np.std(correlations[model]):.3f}"
    
    ax6.text(0.1, 0.9, summary_text, transform=ax6.transAxes,
             fontsize=10, verticalalignment='top', fontfamily='monospace',
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    plt.suptitle('Multi-Subject EEG-fMRI Prediction Results', fontsize=16)
    plt.tight_layout()
    plt.show()
    
    return correlations, best_model


# Visualize results
correlations, best_model = visualize_multi_subject_results(model_results)

## Cell 9: Save Final Model and Results

In [None]:
def save_multi_subject_model(model_results, best_model, common_features, save_path='multi_subject_model.pkl'):
    """
    Save the trained models and necessary information for future predictions.
    """
    # Calculate average performance
    all_correlations = []
    all_models = []
    
    for result in model_results:
        all_correlations.append(result['results'][best_model]['test_corr'])
        all_models.append(result['results'][best_model]['model'])
    
    # Save comprehensive model data
    model_data = {
        'model_type': best_model,
        'models': all_models,  # All fold models
        'feature_names': common_features,
        'top_features_per_fold': [r['top_features'] for r in model_results],
        'scalers': [r['scaler'] for r in model_results],
        'performance': {
            'mean_correlation': np.mean(all_correlations),
            'std_correlation': np.std(all_correlations),
            'per_fold_correlation': all_correlations
        },
        'processing_params': processing_params,
        'subjects_included': selected_subjects,
        'timestamp': datetime.now().strftime("%Y%m%d_%H%M%S")
    }
    
    # Save
    joblib.dump(model_data, save_path)
    print(f"\nModel saved to: {save_path}")
    
    # Create summary report
    report_path = save_path.replace('.pkl', '_report.txt')
    with open(report_path, 'w') as f:
        f.write("MULTI-SUBJECT EEG-fMRI PREDICTION MODEL REPORT\n")
        f.write("="*70 + "\n\n")
        f.write(f"Created: {model_data['timestamp']}\n")
        f.write(f"Model Type: {best_model}\n")
        f.write(f"Subjects: {', '.join(selected_subjects)}\n")
        f.write(f"Mean Performance: r = {model_data['performance']['mean_correlation']:.3f} "
                f"± {model_data['performance']['std_correlation']:.3f}\n\n")
        
        f.write("Per-Subject Performance:\n")
        for i, (subj, corr) in enumerate(zip(selected_subjects, all_correlations)):
            f.write(f"  {subj}: r = {corr:.3f}\n")
        
        f.write(f"\nTop 10 Features (most frequent):\n")
        feature_counts = {}
        for features in model_data['top_features_per_fold']:
            for feat in features[:10]:
                feature_counts[feat] = feature_counts.get(feat, 0) + 1
        
        sorted_features = sorted(feature_counts.items(), key=lambda x: x[1], reverse=True)
        for feat, count in sorted_features[:10]:
            f.write(f"  {feat}: {count}/{len(model_results)} folds\n")
    
    print(f"Report saved to: {report_path}")
    
    return model_data


# Save the model
saved_model = save_multi_subject_model(model_results, best_model, common_features)

## Cell 10: Prediction Function for New Data

In [None]:
def predict_new_subject(edf_path, csv_path, model_data):
    """
    Make predictions for a new subject using the trained multi-subject model.
    """
    print("Processing new subject data...")
    
    # Load and process EEG
    raw = mne.io.read_raw_edf(edf_path, preload=True)
    
    # Clean EEG
    raw_processed, _, _ = automated_channel_cleaning(
        raw,
        z_score_threshold=model_data['processing_params']['z_score_threshold'],
        correlation_threshold=model_data['processing_params']['correlation_threshold'],
        powerline_threshold=model_data['processing_params']['powerline_threshold'],
        save_report=False
    )
    
    # Load PDA
    pda_df = pd.read_csv(csv_path)
    pda_signal = (pda_df['cen'] - pda_df['dmn']).values
    
    # Align PDA
    pda_aligned_z, pda_time_aligned = align_pda_to_eeg(
        pda_signal, raw_processed,
        hrf_delay=model_data['processing_params']['hrf_delay'],
        target_fs=model_data['processing_params']['target_fs']
    )
    
    # Extract features (same as training)
    # ... (feature extraction code)
    
    # Make predictions using ensemble of models
    predictions = []
    
    for i, (model, scaler, top_features) in enumerate(zip(
        model_data['models'],
        model_data['scalers'],
        model_data['top_features_per_fold']
    )):
        # Select and scale features
        # X_selected = ... (select top features)
        # X_scaled = scaler.transform(X_selected)
        
        # Predict
        # y_pred = model.predict(X_scaled)
        # predictions.append(y_pred)
        pass
    
    # Ensemble prediction (average)
    # final_prediction = np.mean(predictions, axis=0)
    
    print("Prediction complete!")
    # return final_prediction, pda_aligned_z


# Example usage:
# predictions, actual = predict_new_subject('path/to/new_edf', 'path/to/new_csv', saved_model)

## Cell 11: Final Summary and Recommendations

In [None]:
print("="*70)
print("MULTI-SUBJECT PIPELINE COMPLETE")
print("="*70)

print(f"\nProcessed {len(selected_subjects)} subjects")
print(f"Best model: {best_model}")
print(f"Average correlation: {np.mean(correlations[best_model]):.3f}")

print("\nNext steps:")
print("1. Test on completely new subjects")
print("2. Optimize hyperparameters with grid search")
print("3. Implement real-time prediction pipeline")
print("4. Analyze subject-specific differences")
print("5. Create channel-reduced versions for practical BCI")

print("\nTo load and use the saved model:")
print("model_data = joblib.load('multi_subject_model.pkl')")
print("predictions = predict_new_subject(edf_path, csv_path, model_data)")