In [2]:
# compute_group_connectivity_to_csv_fast.py
# Ultra-fast version with parallel processing and memory optimizations
# Key improvements:
# - Parallel subject processing
# - Batch connectivity computation
# - Memory safety (no mmap corruption)
# - Optimized MNE operations
# - NOW INTEGRATED WITH PHASE-SEGMENT EPOCHS (10s In/Out of Phase)

import os
from pathlib import Path
import json
import numpy as np
import pandas as pd
import mne
from mne_connectivity import spectral_connectivity_epochs
from nilearn import datasets
from typing import Dict, List, Tuple, Optional
from collections import defaultdict
import warnings
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")
mne.set_log_level('ERROR')

# Configuration
PROJECT_BASE = '/home/jaizor/jaizor/xtra'
BASE_DIR = Path(PROJECT_BASE)
GROUP_OUTPUT_DIR = Path(PROJECT_BASE) / "derivatives" / "group2"/ '10s'
GROUP_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

BANDS = {
    "Theta": (4, 8),
    "Alpha": (8, 12),
    "Low_Beta": (13, 20),
    "High_Beta": (20, 30),
    "Low_Gamma": (30, 60),
    "High_Gamma": (60, 100)
}

METHOD = 'wpli2_debiased'
SFREQ = 500.0
N_ROIS = 512
CONDITIONS = ['InPhase', 'OutofPhase']

# Use all available cores minus 1
N_JOBS = max(1, mp.cpu_count() - 1)


def load_roi_names() -> List[str]:
    """Load and clean DiFuMo ROI names once."""
    try:
        atlas = datasets.fetch_atlas_difumo(dimension=512, resolution_mm=2)
        roi_names = atlas.labels['difumo_names'].astype(str).tolist()
    except Exception:
        roi_names = [f"Component_{i}" for i in range(N_ROIS)]
    
    # Clean names for CSV compatibility
    return [name.replace(',', ';').replace('\n', ' ').replace('\r', ' ') 
            for name in roi_names]


def find_subjects() -> List[str]:
    """Find all subjects with complete data including phase epochs."""
    eeg_dir = BASE_DIR / "derivatives" / "eeg"
    if not eeg_dir.exists():
        return []

    subjects = []
    for item in eeg_dir.iterdir():
        if not (item.is_dir() and item.name.startswith("sub-")):
            continue
            
        # Check required files exist
        events_dir = item / "bima_DBSOFF"
        lcmv_dir = BASE_DIR / "derivatives" / "lcmv" / f"{item.name}_bima_full_off"
        epochs_dir = BASE_DIR / "derivatives" / "epochs" / item.name

        required_files = [
            lcmv_dir / "difumo_time_courses.npy",
            events_dir / f"{item.name}_events_mne_binary-eve.fif",
            events_dir / f"{item.name}_event_id_binary.json",
            epochs_dir / f"{item.name}_in_phase-epo.fif",
            epochs_dir / f"{item.name}_out_of_phase-epo.fif",
        ]
        
        if all(f.exists() for f in required_files):
            subjects.append(item.name)

    return sorted(subjects, key=lambda x: int(x.split('-')[1]))


def load_phase_epochs(subject: str, project_base: str = PROJECT_BASE) -> Optional[Dict[str, mne.Epochs]]:
    """
    Load precomputed In-Phase and Out-of-Phase epochs from disk.
    Returns dict: {'InPhase': epochs_in, 'OutofPhase': epochs_out}
    """
    base_path = Path(project_base)
    epochs_dir = base_path / "derivatives" / "epochs" / subject

    in_file = epochs_dir / f"{subject}_in_phase-epo.fif"
    out_file = epochs_dir / f"{subject}_out_of_phase-epo.fif"

    if not in_file.exists() or not out_file.exists():
        print(f"❌ Missing epoch files for {subject}")
        return None

    try:
        epochs_in = mne.read_epochs(in_file, preload=True, verbose=False)
        epochs_out = mne.read_epochs(out_file, preload=True, verbose=False)
        return {'InPhase': epochs_in, 'OutofPhase': epochs_out}
    except Exception as e:
        print(f"❌ Failed to load epochs for {subject}: {e}")
        return None


def create_epochs_from_phase_segments(subject, project_base=PROJECT_BASE, tmin_in=-10.0, tmax_in=0.0, tmin_out=0.0, tmax_out=10.0):
    """
    Create MNE Epochs for In-Phase and Out-of-Phase segments based on detected phase breaks.
    Saves to disk and returns dict.
    """
    project_base = Path(project_base)
    
    # --- Load phase segments ---
    segments_file = project_base / "derivatives" / "phase_segments" / f"{subject}_phase_segments.csv"
    if not segments_file.exists():
        print(f"❌ Phase segments not found for {subject}")
        return None
    
    segments_df = pd.read_csv(segments_file)
    if segments_df.empty:
        print(f"❌ No phase segments for {subject}")
        return None

    # --- Load RAW EEG ---
    raw_file = project_base / "derivatives" / "eeg" / subject / "bima_DBSOFF" / f"{subject}_ses-DBSOFF_task-bima_eeg_ica_cleaned_raw.fif"
    if not raw_file.exists():
        print(f"❌ RAW file not found for {subject}")
        return None

    raw = mne.io.read_raw_fif(str(raw_file), preload=True)
    sfreq = raw.info['sfreq']

    # --- Create events array from break_time ---
    break_times = segments_df['break_time'].values
    event_samples = (break_times * sfreq).astype(int)
    event_id_in = 1  # arbitrary, for in-phase

    events = np.column_stack([
        event_samples,
        np.zeros(len(event_samples), dtype=int),
        np.ones(len(event_samples), dtype=int) * event_id_in
    ])

    # --- Create metadata ---
    metadata = segments_df.copy()
    metadata['event_time'] = break_times
    metadata['event_sample'] = event_samples

    # --- Create Epochs: In-Phase (10s before break) ---
    print(f"🧠 Creating In-Phase epochs for {subject} (tmin={tmin_in}, tmax={tmax_in})...")
    epochs_in = mne.Epochs(
        raw,
        events,
        event_id={'phase_break': event_id_in},
        tmin=tmin_in,
        tmax=tmax_in,
        baseline=None,
        preload=True,
        metadata=metadata
    )

    # --- Create Epochs: Out-of-Phase (10s after break) ---
    print(f"🔥 Creating Out-of-Phase epochs for {subject} (tmin={tmin_out}, tmax={tmax_out})...")
    epochs_out = mne.Epochs(
        raw,
        events,
        event_id={'phase_break': event_id_in},
        tmin=tmin_out,
        tmax=tmax_out,
        baseline=None,
        preload=True,
        metadata=metadata
    )

    print(f"✅ Created {len(epochs_in)} In-Phase and {len(epochs_out)} Out-of-Phase epochs for {subject}")

    # --- Save to disk ---
    output_dir = Path(project_base) / "derivatives" / "epochs" / subject
    output_dir.mkdir(parents=True, exist_ok=True)
    
    epochs_in.save(output_dir / f"{subject}_in_phase-epo.fif", overwrite=True)
    epochs_out.save(output_dir / f"{subject}_out_of_phase-epo.fif", overwrite=True)
    
    print(f"💾 Saved epochs to {output_dir}")

    return {
        'in_phase_epochs': epochs_in,
        'out_of_phase_epochs': epochs_out
    }


def ensure_epochs_exist(subjects: List[str]):
    """Ensure phase epochs exist for all subjects — create if missing."""
    for subject in subjects:
        epochs_dir = Path(PROJECT_BASE) / "derivatives" / "epochs" / subject
        in_file = epochs_dir / f"{subject}_in_phase-epo.fif"
        out_file = epochs_dir / f"{subject}_out_of_phase-epo.fif"

        if not in_file.exists() or not out_file.exists():
            print(f"⏳ Creating epochs for {subject}...")
            epochs_dict = create_epochs_from_phase_segments(subject)
            if epochs_dict:
                print(f"✅ Saved epochs for {subject}")
            else:
                print(f"❌ Failed to create epochs for {subject}")


def compute_single_connectivity(epoch_data: np.ndarray, band_range: Tuple[float, float]) -> Optional[np.ndarray]:
    """Compute connectivity matrix for given epochs and frequency band - optimized version."""
    if len(epoch_data) == 0:
        return None
        
    try:
        # Use fewer tapers for speed while maintaining quality
        con = spectral_connectivity_epochs(
            data=epoch_data,
            method=METHOD,
            mode='multitaper',
            sfreq=SFREQ,
            fmin=band_range[0],
            fmax=band_range[1],
            faverage=True,
            verbose=False,
            n_jobs=1,  # Each process handles one job
            mt_bandwidth=2,  # Reduced bandwidth for speed
            mt_low_bias=True
        )
        
        matrix = con.get_data(output='dense').squeeze()
        
        # Fast symmetrization and diagonal zeroing
        matrix = np.maximum(matrix, matrix.T)
        np.fill_diagonal(matrix, 0)
        
        return matrix.astype(np.float32)  # Use float32 to save memory
        
    except Exception as e:
        print(f"⚠️  Connectivity computation failed: {e}")
        return None


def process_subject_parallel(subject: str) -> Dict[Tuple[str, str], np.ndarray]:
    """Process one subject - designed for parallel execution."""
    # File paths
    data_file = BASE_DIR / "derivatives" / "lcmv" / f"{subject}_bima_full_off" / "difumo_time_courses.npy"
    
    subject_matrices = {}
    
    try:
        # 🛡️ SAFE LOAD: Avoid mmap to prevent file corruption
        data = np.load(data_file)  # Load fully into memory
        print(f"📥 Loaded {data.shape} from {data_file.name}")

        if data.size == 0:
            print(f"⚠️  Empty data in {data_file}")
            return {}

        # Ensure (channels, time) format
        if data.shape[0] > data.shape[1]:
            data = data.T.copy()  # 🚨 COPY to ensure data ownership — critical fix!
        else:
            data = data.copy()  # Still copy to be safe

        # Create MNE info object
        info = mne.create_info(
            ch_names=[f'C{i}' for i in range(N_ROIS)], 
            sfreq=SFREQ, 
            ch_types='misc'
        )
        raw = mne.io.RawArray(data, info, verbose=False)
        
        # 🆕 Load precomputed phase-based epochs
        phase_epochs = load_phase_epochs(subject, PROJECT_BASE)
        if not phase_epochs:
            print(f"⚠️  Skipping {subject} — no phase epochs found")
            return {}

        # Process both conditions: InPhase and OutofPhase
        for condition in CONDITIONS:
            epochs = phase_epochs.get(condition)
            if epochs is None:
                print(f"⚠️  No epochs for {condition} in {subject}")
                continue

            # ⚡ Optional: Downsample to 250Hz to speed up connectivity (safe for bands < 120Hz)
            # epochs = epochs.copy().resample(250.0, npad='auto')

            epoch_data = epochs.get_data()
            if len(epoch_data) == 0:
                print(f"⚠️  No epochs for {condition} in {subject}")
                continue

            # Compute connectivity for all bands
            for band_name, band_range in BANDS.items():
                matrix = compute_single_connectivity(epoch_data, band_range)
                if matrix is not None:
                    subject_matrices[(condition, band_name)] = matrix
                else:
                    print(f"⚠️  Failed to compute {band_name} for {condition} in {subject}")
                        
    except Exception as e:
        print(f"❌ Subject {subject} failed: {e}")
        return {}
    
    return subject_matrices


def batch_process_subjects(subjects: List[str]) -> Dict[Tuple[str, str], List[np.ndarray]]:
    """Process all subjects in parallel and collect matrices."""
    print(f"🚀 Processing {len(subjects)} subjects in parallel using {N_JOBS} cores...")
    
    all_matrices = defaultdict(list)
    
    # Process subjects in parallel
    with ProcessPoolExecutor(max_workers=N_JOBS) as executor:
        # Submit all jobs
        future_to_subject = {
            executor.submit(process_subject_parallel, subject): subject 
            for subject in subjects
        }
        
        # Collect results as they complete
        completed = 0
        for future in as_completed(future_to_subject):
            subject = future_to_subject[future]
            try:
                subject_matrices = future.result()
                
                # Add matrices to collections
                for (condition, band_name), matrix in subject_matrices.items():
                    all_matrices[(condition, band_name)].append(matrix)
                
                completed += 1
                success_count = len(subject_matrices)
                print(f"✅ {subject}: {success_count} matrices ({completed}/{len(subjects)})")
                
            except Exception as e:
                print(f"❌ {subject}: failed with exception ({completed}/{len(subjects)}) — {e}")
                completed += 1
    
    return dict(all_matrices)


def compute_fast_averages(all_matrices: Dict[Tuple[str, str], List[np.ndarray]]) -> Dict[Tuple[str, str], np.ndarray]:
    """Compute group averages using optimized numpy operations."""
    print(f"\n⚡ Computing group averages...")
    
    group_averages = {}
    for (condition, band_name), matrix_list in all_matrices.items():
        if not matrix_list:
            print(f"⚠️  No data for {condition} {band_name}")
            continue
        
        # Stack and average in one optimized operation
        stacked = np.stack(matrix_list, axis=0, dtype=np.float32)
        group_avg = np.mean(stacked, axis=0, dtype=np.float32)
        
        group_averages[(condition, band_name)] = group_avg
        print(f"   ✅ {condition} {band_name}: {len(matrix_list)} subjects")
    
    return group_averages


def save_matrices_fast(group_averages: Dict[Tuple[str, str], np.ndarray], roi_names: List[str]) -> None:
    """Save only the final group average matrices as CSV."""
    print(f"\n💾 Saving {len(group_averages)} group average matrices...")
    
    for (condition, band_name), matrix in group_averages.items():
        df = pd.DataFrame(matrix, index=roi_names, columns=roi_names)
        csv_filename = f"matrix_{condition}_{band_name}_group_avg.csv"
        csv_filepath = GROUP_OUTPUT_DIR / csv_filename
        df.to_csv(csv_filepath, float_format='%.6f')
        print(f"   ✅ {csv_filename}")


def main():
    """Main execution with timing."""
    import time
    start_time = time.time()
    
    print("Starting ultra-fast group connectivity analysis with PHASE SEGMENTS...")
    
    # Load ROI names once
    print("📥 Loading ROI names...")
    roi_names = load_roi_names()
    
    # Find subjects
    subjects = find_subjects()
    if not subjects:
        print("❌ No valid subjects found.")
        return
    
    print(f"🧬 Found {len(subjects)} subjects")

    # 🆕 Ensure epochs exist (auto-create if missing)
    ensure_epochs_exist(subjects)

    # Re-find subjects (in case some were created now)
    subjects = find_subjects()
    if not subjects:
        print("❌ Still no valid subjects after epoch creation.")
        return

    # Phase 1: Extract all matrices in parallel
    all_matrices = batch_process_subjects(subjects)
    
    if not all_matrices:
        print("❌ No matrices computed.")
        return
    
    # Phase 2: Compute averages
    group_averages = compute_fast_averages(all_matrices)
    
    if not group_averages:
        print("❌ No group averages computed.")
        return
    
    # Phase 3: Save CSVs
    save_matrices_fast(group_averages, roi_names)
    
    # Summary
    elapsed = time.time() - start_time
    total_subjects = sum(len(matrices) for matrices in all_matrices.values())
    
    print(f"\n🎉 Analysis Complete!")
    print(f"   • Time elapsed: {elapsed:.1f} seconds")
    print(f"   • Subjects found: {len(subjects)}")
    print(f"   • Total matrices computed: {total_subjects}")
    print(f"   • Conditions: {CONDITIONS}")
    print(f"   • Frequency bands: {list(BANDS.keys())}")
    print(f"   • Output: {GROUP_OUTPUT_DIR}")
    print(f"   • Speed: {total_subjects/elapsed:.1f} matrices/second")


if __name__ == "__main__":
    main()

Starting ultra-fast group connectivity analysis with PHASE SEGMENTS...
📥 Loading ROI names...


🧬 Found 11 subjects
🚀 Processing 11 subjects in parallel using 39 cores...


📥 Loaded (512, 131679) from difumo_time_courses.npy
📥 Loaded (512, 248466) from difumo_time_courses.npy
📥 Loaded (512, 263506) from difumo_time_courses.npy
📥 Loaded (512, 263524) from difumo_time_courses.npy
📥 Loaded (512, 215473) from difumo_time_courses.npy
📥 Loaded (512, 263481) from difumo_time_courses.npy
📥 Loaded (512, 263515) from difumo_time_courses.npy
📥 Loaded (512, 263356) from difumo_time_courses.npy📥 Loaded (512, 210656) from difumo_time_courses.npy

📥 Loaded (512, 263523) from difumo_time_courses.npy
📥 Loaded (512, 263517) from difumo_time_courses.npy
⚠️  No epochs for InPhase in sub-11
⚠️  No epochs for OutofPhase in sub-11
✅ sub-11: 0 matrices (1/11)
⚠️  No epochs for InPhase in sub-08
⚠️  No epochs for OutofPhase in sub-08
⚠️  No epochs for InPhase in sub-14
⚠️  No epochs for OutofPhase in sub-14
⚠️  No epochs for InPhase in sub-07
⚠️  No epochs for OutofPhase in sub-07
⚠️  No epochs for InPhase in sub-06
⚠️  No epochs for OutofPhase in sub-06
⚠️  No epochs for InPhase

ValueError: all input arrays must have the same shape