In [1]:
# Compute Group Connectivity + Save Subject-Level Matrices

import os
from pathlib import Path
import json
import numpy as np
import pandas as pd
import mne
from mne_connectivity import spectral_connectivity_epochs
from nilearn import datasets
from typing import Dict, List, Tuple, Optional
from collections import defaultdict
import warnings
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")
mne.set_log_level('ERROR')

# Configuration
PROJECT_BASE = '/home/jaizor/jaizor/xtra'
BASE_DIR = Path(PROJECT_BASE)
GROUP_OUTPUT_DIR = Path(PROJECT_BASE) / "derivatives" / "group"
GROUP_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

BANDS = {
    "Theta": (4, 8),
    "Alpha": (8, 12),
    "Low_Beta": (13, 20),
    "High_Beta": (20, 30),
    "Low_Gamma": (30, 60),
    "High_Gamma": (60, 100)
}

METHOD = 'wpli2_debiased'
SFREQ = 500.0
N_ROIS = 512
CONDITIONS = ['InPhase', 'OutofPhase']

# Use all available cores minus 1
N_JOBS = max(1, mp.cpu_count() - 1)


def load_roi_names() -> List[str]:
    """Load and clean DiFuMo ROI names once."""
    try:
        atlas = datasets.fetch_atlas_difumo(dimension=512, resolution_mm=2)
        roi_names = atlas.labels['difumo_names'].astype(str).tolist()
    except Exception:
        roi_names = [f"Component_{i}" for i in range(N_ROIS)]
    
    # Clean names for CSV compatibility
    return [name.replace(',', ';').replace('\n', ' ').replace('\r', ' ') 
            for name in roi_names]


def find_subjects() -> List[str]:
    """Find all subjects with complete data."""
    eeg_dir = BASE_DIR / "derivatives" / "eeg"
    if not eeg_dir.exists():
        return []

    subjects = []
    for item in eeg_dir.iterdir():
        if not (item.is_dir() and item.name.startswith("sub-")):
            continue
            
        # Check required files exist
        events_dir = item / "bima_DBSOFF"
        required_files = [
            events_dir / f"{item.name}_events_mne_binary-eve.fif",
            events_dir / f"{item.name}_event_id_binary.json",
            BASE_DIR / "derivatives" / "lcmv" / f"{item.name}_bima_full_off" / "difumo_time_courses.npy"
        ]
        
        if all(f.exists() for f in required_files):
            subjects.append(item.name)

    return sorted(subjects, key=lambda x: int(x.split('-')[1]))


def compute_single_connectivity(epoch_data: np.ndarray, band_range: Tuple[float, float]) -> Optional[np.ndarray]:
    """Compute connectivity matrix for given epochs and frequency band - optimized version."""
    if len(epoch_data) == 0:
        return None
        
    try:
        # Use fewer tapers for speed while maintaining quality
        con = spectral_connectivity_epochs(
            data=epoch_data,
            method=METHOD,
            mode='multitaper',
            sfreq=SFREQ,
            fmin=band_range[0],
            fmax=band_range[1],
            faverage=True,
            verbose=False,
            n_jobs=1,  # Each process handles one job
            mt_bandwidth=2,  # Reduced bandwidth for speed
            mt_low_bias=True
        )
        
        matrix = con.get_data(output='dense').squeeze()
        
        # Fast symmetrization and diagonal zeroing
        matrix = np.maximum(matrix, matrix.T)  # Faster than (matrix + matrix.T) / 2
        np.fill_diagonal(matrix, 0)
        
        return matrix.astype(np.float32)  # Use float32 to save memory
        
    except Exception as e:
        print(f"⚠️  Connectivity computation failed: {e}")
        return None


def save_subject_matrices(subject: str, subject_matrices: Dict[Tuple[str, str], np.ndarray], base_output_dir: Path):
    """Save subject-level connectivity matrices as .npy files."""
    subject_dir = base_output_dir / "subject_level" / subject
    subject_dir.mkdir(parents=True, exist_ok=True)
    
    for (condition, band_name), matrix in subject_matrices.items():
        filename = f"{condition}_{band_name}.npy"
        filepath = subject_dir / filename
        np.save(filepath, matrix)
        print(f"   💾 Saved {filename} for {subject}")


def process_subject_parallel(subject: str) -> Dict[Tuple[str, str], np.ndarray]:
    """Process one subject - designed for parallel execution."""
    # File paths
    data_file = BASE_DIR / "derivatives" / "lcmv" / f"{subject}_bima_full_off" / "difumo_time_courses.npy"
    events_file = BASE_DIR / "derivatives" / "eeg" / subject / "bima_DBSOFF" / f"{subject}_events_mne_binary-eve.fif"
    event_id_file = BASE_DIR / "derivatives" / "eeg" / subject / "bima_DBSOFF" / f"{subject}_event_id_binary.json"
    
    subject_matrices = {}
    
    try:
        # 🛡️ SAFE LOAD: Avoid mmap to prevent file corruption
        data = np.load(data_file)  # Load fully into memory
        print(f"📥 Loaded {data.shape} from {data_file.name}")

        if data.size == 0:
            print(f"⚠️  Empty data in {data_file}")
            return {}

        # Ensure (channels, time) format
        if data.shape[0] > data.shape[1]:
            data = data.T.copy()  # 🚨 COPY to ensure data ownership — critical fix!
        else:
            data = data.copy()  # Still copy to be safe

        events = mne.read_events(events_file, verbose=False)
        with open(event_id_file, 'r') as f:
            event_id = json.load(f)
        
        # Create MNE objects with minimal overhead
        info = mne.create_info(
            ch_names=[f'C{i}' for i in range(N_ROIS)], 
            sfreq=SFREQ, 
            ch_types='misc'
        )
        raw = mne.io.RawArray(data, info, verbose=False)
        
        # Process all conditions
        for condition in CONDITIONS:
            if condition not in event_id:
                continue
                
            try:
                # Create epochs with optimized parameters
                epochs = mne.Epochs(
                    raw, events, {condition: event_id[condition]},
                    tmin=0, tmax=1.5, preload=True, baseline=None, 
                    event_repeated='drop', verbose=False,
                    proj=False,  # Skip projection for speed
                    reject=None  # Skip rejection for speed
                )
                
                epoch_data = epochs.get_data()
                if len(epoch_data) == 0:
                    print(f"⚠️  No epochs for {condition} in {subject}")
                    continue
                
                # Compute connectivity for all bands at once per condition
                for band_name, band_range in BANDS.items():
                    matrix = compute_single_connectivity(epoch_data, band_range)
                    if matrix is not None:
                        subject_matrices[(condition, band_name)] = matrix
                    else:
                        print(f"⚠️  Failed to compute {band_name} for {condition} in {subject}")
                        
            except Exception as e:
                print(f"⚠️  Epoch creation failed for {condition} in {subject}: {e}")
                continue
    
    except Exception as e:
        print(f"❌ Subject {subject} failed: {e}")
        return {}
    
    return subject_matrices


def batch_process_subjects(subjects: List[str]) -> Dict[Tuple[str, str], List[np.ndarray]]:
    """Process all subjects in parallel and collect matrices + save subject-level."""
    print(f"🚀 Processing {len(subjects)} subjects in parallel using {N_JOBS} cores...")
    
    all_matrices = defaultdict(list)
    
    # Process subjects in parallel
    with ProcessPoolExecutor(max_workers=N_JOBS) as executor:
        # Submit all jobs
        future_to_subject = {
            executor.submit(process_subject_parallel, subject): subject 
            for subject in subjects
        }
        
        # Collect results as they complete
        completed = 0
        for future in as_completed(future_to_subject):
            subject = future_to_subject[future]
            try:
                subject_matrices = future.result()
                
                # Add matrices to collections
                for (condition, band_name), matrix in subject_matrices.items():
                    all_matrices[(condition, band_name)].append(matrix)
                
                # 💾 SAVE SUBJECT-LEVEL MATRICES
                save_subject_matrices(subject, subject_matrices, GROUP_OUTPUT_DIR)
                
                completed += 1
                success_count = len(subject_matrices)
                print(f"✅ {subject}: {success_count} matrices ({completed}/{len(subjects)})")
                
            except Exception as e:
                print(f"❌ {subject}: failed with exception ({completed}/{len(subjects)}) — {e}")
                completed += 1
    
    return dict(all_matrices)


def compute_fast_averages(all_matrices: Dict[Tuple[str, str], List[np.ndarray]]) -> Dict[Tuple[str, str], np.ndarray]:
    """Compute group averages using optimized numpy operations."""
    print(f"\n⚡ Computing group averages...")
    
    group_averages = {}
    for (condition, band_name), matrix_list in all_matrices.items():
        if not matrix_list:
            print(f"⚠️  No data for {condition} {band_name}")
            continue
        
        # Stack and average in one optimized operation
        # Use float32 to reduce memory usage
        stacked = np.stack(matrix_list, axis=0, dtype=np.float32)
        group_avg = np.mean(stacked, axis=0, dtype=np.float32)
        
        group_averages[(condition, band_name)] = group_avg
        print(f"   ✅ {condition} {band_name}: {len(matrix_list)} subjects")
    
    return group_averages


def save_matrices_fast(group_averages: Dict[Tuple[str, str], np.ndarray], roi_names: List[str]) -> None:
    """Save only the final group average matrices as CSV."""
    print(f"\n💾 Saving {len(group_averages)} group average matrices...")
    
    for (condition, band_name), matrix in group_averages.items():
        df = pd.DataFrame(matrix, index=roi_names, columns=roi_names)
        csv_filename = f"matrix_{condition}_{band_name}_group_avg.csv"
        csv_filepath = GROUP_OUTPUT_DIR / csv_filename
        df.to_csv(csv_filepath, float_format='%.6f')
        print(f"   ✅ {csv_filename}")


def main():
    """Main execution with timing."""
    import time
    start_time = time.time()
    
    print("Starting ultra-fast group connectivity analysis...")
    
    # Load ROI names once
    print("📥 Loading ROI names...")
    roi_names = load_roi_names()
    
    # Find subjects
    subjects = find_subjects()
    if not subjects:
        print("❌ No valid subjects found.")
        return
    
    print(f"🧬 Found {len(subjects)} subjects")
    
    # Phase 1: Extract all matrices in parallel
    all_matrices = batch_process_subjects(subjects)
    
    if not all_matrices:
        print("❌ No matrices computed.")
        return
    
    # Phase 2: Compute averages
    group_averages = compute_fast_averages(all_matrices)
    
    if not group_averages:
        print("❌ No group averages computed.")
        return
    
    # Phase 3: Save CSVs
    save_matrices_fast(group_averages, roi_names)
    
    # Summary
    elapsed = time.time() - start_time
    total_subjects = sum(len(matrices) for matrices in all_matrices.values())
    
    print(f"\n🎉 Analysis Complete!")
    print(f"   • Time elapsed: {elapsed:.1f} seconds")
    print(f"   • Subjects found: {len(subjects)}")
    print(f"   • Total matrices computed: {total_subjects}")
    print(f"   • Conditions: {CONDITIONS}")
    print(f"   • Frequency bands: {list(BANDS.keys())}")
    print(f"   • Output: {GROUP_OUTPUT_DIR}")
    print(f"   • Speed: {total_subjects/elapsed:.1f} matrices/second")
    print(f"   • Subject-level matrices saved in: {GROUP_OUTPUT_DIR / 'subject_level'}")


if __name__ == "__main__":
    main()

Starting ultra-fast group connectivity analysis...
📥 Loading ROI names...


🧬 Found 12 subjects
🚀 Processing 12 subjects in parallel using 39 cores...
📥 Loaded (512, 126537) from difumo_time_courses.npy
📥 Loaded (512, 131679) from difumo_time_courses.npy
📥 Loaded (512, 215473) from difumo_time_courses.npy
📥 Loaded (512, 210656) from difumo_time_courses.npy
📥 Loaded (512, 248466) from difumo_time_courses.npy
📥 Loaded (512, 263523) from difumo_time_courses.npy📥 Loaded (512, 263517) from difumo_time_courses.npy

📥 Loaded (512, 263356) from difumo_time_courses.npy
📥 Loaded (512, 263506) from difumo_time_courses.npy
📥 Loaded (512, 263481) from difumo_time_courses.npy📥 Loaded (512, 263515) from difumo_time_courses.npy
📥 Loaded (512, 263524) from difumo_time_courses.npy

   💾 Saved InPhase_Theta.npy for sub-08
   💾 Saved InPhase_Alpha.npy for sub-08
   💾 Saved InPhase_Low_Beta.npy for sub-08
   💾 Saved InPhase_High_Beta.npy for sub-08
   💾 Saved InPhase_Low_Gamma.npy for sub-08
   💾 Saved InPhase_High_Gamma.npy for sub-08
   💾 Saved OutofPhase_Theta.npy for sub-08
  