# In [None]:
# run_multisubject_connectivity.py
# Fully automatic — adapted to your real structure and event keys.

import json
import os
from collections import Counter
from pathlib import Path

import mne
import numpy as np
import pandas as pd
from mne_connectivity import spectral_connectivity_epochs
from nilearn import datasets
from numpy.random import default_rng

# Project root: input discovery, data loading and feature output all live beneath it.
PROJECT_BASE = "/home/jaizor/jaizor/xtra"
BASE_DIR = Path(PROJECT_BASE)


class CleanConnectivityToCSV:
    """
    Extract connectivity and save ONE clean CSV per band.
    No .npy, no extra files. Just: (bootstrap samples x connections)

    Each row of a band CSV is one random subset ("bootstrap") of a
    condition's epochs. Each connection column holds the band-averaged
    ``self.method`` connectivity of one unique ROI pair.
    """

    # Number of DiFuMo components in the atlas used for labels.
    N_COMPONENTS = 512

    def __init__(self, project_base, subject='sub-06'):
        self.project_base = project_base
        self.subject = subject
        self.sfreq = 500.0
        self.method = 'wpli2_debiased'

        # Frequency bands in Hz as (fmin, fmax).
        self.bands = {
            "Theta":      (4, 8),
            "Alpha":      (8, 12),
            "Low_Beta":   (13, 20),
            "High_Beta":  (20, 30),
            "Low_Gamma":  (30, 60),
            "High_Gamma": (60, 120)
        }

        self.n_bootstraps = 50          # subsets drawn per condition
        self.chunk_size = 30            # epochs per subset (capped at what is available)
        self.rng = default_rng(seed=42)  # fixed seed -> reproducible subsets

        self.component_names = self._load_difumo_names()
        self.regions = self._define_brain_regions()
        # Unique component indices in ascending order; components listed in
        # several regions appear once.
        self.selected_indices = sorted(set(self.regions))

        self.output_dir = f"{project_base}/derivatives/features/{subject}"
        os.makedirs(self.output_dir, exist_ok=True)

    def _load_difumo_names(self):
        """Return one anatomical label per DiFuMo component.

        Falls back to generic ``Component_<i>`` names when the atlas cannot be
        fetched, or when it does not provide exactly ``N_COMPONENTS`` labels.
        The second case guards against labels being shifted relative to
        component indices.
        """
        print("📥 Loading DiFuMo 512 anatomical labels...")
        fallback = [f"Component_{i}" for i in range(self.N_COMPONENTS)]
        try:
            atlas = datasets.fetch_atlas_difumo(dimension=self.N_COMPONENTS, resolution_mm=2)
            names = atlas.labels['difumo_names'].astype(str).tolist()
        except Exception as e:
            print(f"❌ Failed to load DiFuMo: {e}")
            return fallback
        if len(names) != self.N_COMPONENTS:
            print(f"❌ Expected {self.N_COMPONENTS} DiFuMo labels, got {len(names)}; using generic names")
            return fallback
        print(f"✅ Loaded {len(names)} ROI names")
        return names

    def _define_brain_regions(self):
        """Return the DiFuMo component indices of interest (may contain duplicates)."""
        Motor_M1 = [40, 86, 198, 268, 305, 437, 458, 465]
        Motor_SMA_Premotor = [17, 18, 288, 291, 296, 297, 302, 305, 314, 315, 335, 375, 379, 448]
        Motor_Medial = [101, 102, 388, 409, 498]
        Thalamus = [70, 73, 297, 334, 414, 420]
        Basal_Ganglia = [30, 53, 224, 260, 405, 422, 109, 110, 315, 331, 467, 479, 55, 71, 307, 223]
        Cerebellum_Motor = [43, 47, 83, 84, 127, 183, 220, 221, 295, 304, 310, 311, 374, 378, 381, 403, 441, 490, 491]
        Somatosensory = [44, 131, 210, 411, 413, 436]
        Executive_Control = [3, 85, 104, 148, 184, 337, 377, 446, 447, 506, 507]
        Interoception = [2, 387, 358, 389, 165, 469]
        Error_Monitoring = [185, 219, 326, 473, 492]

        # NOTE(review): Cerebellum_Motor and Somatosensory are defined but not
        # included below; confirm whether the exclusion is intentional.
        return (Motor_M1 + Motor_SMA_Premotor + Motor_Medial + Thalamus + Basal_Ganglia + Executive_Control + Interoception + Error_Monitoring)

    def load_data(self):
        """Load epochs for selected ROIs — FIXED TO MATCH YOUR EVENT KEYS.

        Returns a dict that maps a condition name ('InPhase' / 'OutofPhase') to
        the epoch array restricted to ``selected_indices``. The array's axes are
        (epochs, selected components, times). Conditions that are missing from
        the event_id file, or that fail to epoch, are omitted with a printed
        message.
        """
        print("📥 Loading data...")

        # ✅ Paths for your structure
        data_file = f"{self.project_base}/derivatives/lcmv/{self.subject}_bima_full_off/difumo_time_courses.npy"
        events_file = f"{self.project_base}/derivatives/eeg/{self.subject}/bima_DBSOFF/{self.subject}_events_mne_binary-eve.fif"
        event_id_file = f"{self.project_base}/derivatives/eeg/{self.subject}/bima_DBSOFF/{self.subject}_event_id_binary.json"

        # Orient as (components, samples), assuming there are more samples than components.
        data = np.load(data_file)
        if data.shape[0] > data.shape[1]:
            data = data.T

        events = mne.read_events(events_file, verbose=False)
        with open(event_id_file, 'r', encoding='utf-8') as f:
            event_id = json.load(f)

        print(f"🔍 event_id keys: {list(event_id.keys())}")  # DEBUG

        # One misc channel per component. The count comes from the data itself
        # rather than a fixed 512.
        n_channels = data.shape[0]
        info = mne.create_info(ch_names=[f'C{i}' for i in range(n_channels)], sfreq=self.sfreq, ch_types='misc')
        raw = mne.io.RawArray(data, info, verbose=False)

        # ✅ FIXED: Use actual keys from your JSON: "InPhase", "OutofPhase"
        conditions = {
            'InPhase': 'InPhase',
            'OutofPhase': 'OutofPhase'
        }
        epoch_data = {}

        for raw_cond, save_cond in conditions.items():
            if raw_cond not in event_id:
                print(f"   ⚠️ Condition '{raw_cond}' not found in event_id.")
                continue
            try:
                epochs = mne.Epochs(raw, events, {raw_cond: event_id[raw_cond]},
                                    tmin=0, tmax=1.5, preload=True, verbose=False, baseline=None, event_repeated='drop')
                data_cond = epochs.get_data()[:, self.selected_indices, :]
            except Exception as e:
                print(f"   ✗ Failed to create epochs for {raw_cond}: {e}")
                continue
            epoch_data[save_cond] = data_cond
            print(f"   • {len(data_cond)} epochs for {raw_cond}")

        if not epoch_data:
            print("❌ WARNING: No epochs created. Check event_id keys and events file.")

        return epoch_data

    def _connection_labels(self):
        """Return the CSV column name of every ROI pair (i, j) with i < j.

        Pairs are in row-major upper-triangle order, the same order used by
        ``_pairwise_values``. If two selected components share an atlas label,
        the component index is appended to those labels. Without this, one
        column would silently overwrite the other.
        """
        names = [self.component_names[idx] for idx in self.selected_indices]
        counts = Counter(names)
        names = [
            f"{name} #{idx}" if counts[name] > 1 else name
            for name, idx in zip(names, self.selected_indices)
        ]
        rows, cols = np.triu_indices(len(names), k=1)
        return [f"{names[i]} ↔ {names[j]}" for i, j in zip(rows, cols)]

    @staticmethod
    def _pairwise_values(matrix):
        """Return ``matrix`` values for every pair (i, j), i < j, in row-major order.

        With all-to-all indices, mne-connectivity fills only the lower triangle
        of the dense matrix and leaves the upper triangle at zero. The value of
        pair (i, j) is therefore read from ``matrix[j, i]``. Averaging with the
        transpose would halve every value.
        """
        rows, cols = np.triu_indices(matrix.shape[0], k=1)
        return matrix[cols, rows]

    def run(self):
        """Run analysis and save ONE clean CSV per band.

        For each condition, ``n_bootstraps`` subsets of ``min(chunk_size,
        n_epochs)`` epochs are drawn without replacement. All bands are
        estimated from a single ``spectral_connectivity_epochs`` call per
        subset, so rows that share ``(condition, bootstrap)`` across the band
        CSVs come from the same epochs. Output goes to
        ``{output_dir}/ml_features_{band}.csv``, with columns ``condition``,
        ``bootstrap`` and one column per ROI pair.
        """
        print("🚀 Starting clean connectivity export...")

        epoch_data = self.load_data()
        if not epoch_data:
            print("❌ Skipping subject — no valid epochs.")
            return

        band_names = list(self.bands)
        fmins = tuple(fmin for fmin, _ in self.bands.values())
        fmaxs = tuple(fmax for _, fmax in self.bands.values())
        # Column names are identical for every row; build them once.
        connection_labels = self._connection_labels()
        # band -> list of (condition, bootstrap index, pair values)
        records = {band: [] for band in band_names}

        for condition, data_cond in epoch_data.items():
            n_epochs = len(data_cond)
            if n_epochs < 2:
                print(f"  ⚠️ Skipping {condition}: need at least 2 epochs, got {n_epochs}.")
                continue
            size_to_sample = min(self.chunk_size, n_epochs)
            if size_to_sample == n_epochs:
                print(f"  ⚠️ {condition}: only {n_epochs} epochs (chunk_size={self.chunk_size}); "
                      f"every bootstrap uses all epochs, so its rows will be identical.")
            print(f"\n🌈 Processing {condition}: {self.n_bootstraps} bootstraps x {len(band_names)} bands")

            for bootstrap_idx in range(self.n_bootstraps):
                sample_idx = self.rng.choice(n_epochs, size=size_to_sample, replace=False)
                try:
                    con = spectral_connectivity_epochs(
                        data=data_cond[sample_idx],
                        method=self.method,
                        mode='multitaper',
                        sfreq=self.sfreq,
                        fmin=fmins,
                        fmax=fmaxs,
                        faverage=True,
                        verbose=False,
                        n_jobs=1
                    )
                    # Last axis holds one band-averaged value per entry of self.bands.
                    dense = con.get_data(output='dense')
                except Exception as e:
                    print(f"  ✗ Bootstrap {bootstrap_idx+1} failed: {str(e)}")
                    continue

                for band_idx, band_name in enumerate(band_names):
                    values = self._pairwise_values(dense[..., band_idx])
                    records[band_name].append((condition, bootstrap_idx, values))

        for band_name in band_names:
            band_records = records[band_name]
            if not band_records:
                print(f"⚠️  No connectivity data generated for {band_name}.")
                continue

            conditions, bootstraps, values = zip(*band_records)
            df = pd.DataFrame(np.vstack(values), columns=connection_labels)
            df.insert(0, 'bootstrap', list(bootstraps))
            df.insert(0, 'condition', list(conditions))
            file_path = f"{self.output_dir}/ml_features_{band_name}.csv"
            df.to_csv(file_path, index=False)
            print(f"✅ Saved {band_name}: {df.shape} → {file_path}")


def find_all_subjects(base_dir: Path):
    """Return the IDs of all subjects that have every input file the pipeline needs.

    A directory ``derivatives/eeg/sub-XX`` under ``base_dir`` qualifies when
    all of these exist:

    * ``derivatives/eeg/sub-XX/bima_DBSOFF/sub-XX_events_mne_binary-eve.fif``
    * ``derivatives/eeg/sub-XX/bima_DBSOFF/sub-XX_event_id_binary.json``
    * ``derivatives/lcmv/sub-XX_bima_full_off/difumo_time_courses.npy``

    Subjects are sorted numerically (``sub-2`` before ``sub-10``). IDs whose
    label is not a plain number (e.g. ``sub-pilot``) sort alphabetically after
    the numeric ones instead of raising ``ValueError``.

    Returns an empty list, after printing a message, if ``derivatives/eeg``
    does not exist.
    """
    eeg_dir = base_dir / "derivatives" / "eeg"
    if not eeg_dir.exists():
        print(f"❌ EEG directory not found: {eeg_dir}")
        return []

    lcmv_dir = base_dir / "derivatives" / "lcmv"

    def has_all_inputs(subject_dir):
        """True when the events, event-id and DiFuMo files all exist for this subject."""
        sub = subject_dir.name
        events_dir = subject_dir / "bima_DBSOFF"
        required = (
            events_dir / f"{sub}_events_mne_binary-eve.fif",
            events_dir / f"{sub}_event_id_binary.json",
            lcmv_dir / f"{sub}_bima_full_off" / "difumo_time_courses.npy",
        )
        return all(path.exists() for path in required)

    def sort_key(sub):
        """Numeric labels first (by value), then any other label alphabetically."""
        label = sub.split('-')[1]
        # The tuple's first element keeps ints and strings from ever being compared.
        return (0, int(label), '') if label.isdecimal() else (1, 0, label)

    subjects = [
        item.name
        for item in eeg_dir.iterdir()
        if item.is_dir() and item.name.startswith("sub-") and has_all_inputs(item)
    ]
    subjects.sort(key=sort_key)
    return subjects


if __name__ == "__main__":
    # Script entry point: discover every complete subject and export its features.
    import sys
    import traceback

    print("🔎 Discovering subjects...")

    subjects = find_all_subjects(BASE_DIR)

    if not subjects:
        print("❌ No valid subjects found. Check paths:")
        print("   - derivatives/eeg/sub-XX/bima_DBSOFF/sub-XX_*.fif/.json")
        print("   - derivatives/lcmv/sub-XX_bima_full_off/difumo_time_courses.npy")
        # sys.exit is always available, unlike the site-provided exit().
        sys.exit(1)

    print(f"✅ Found {len(subjects)} subjects: {subjects}")

    failed = []
    for subject in subjects:
        print(f"\n{'='*80}")
        print(f"🧬 PROCESSING SUBJECT: {subject}")
        print(f"{'='*80}")

        try:
            extractor = CleanConnectivityToCSV(PROJECT_BASE, subject=subject)
            extractor.run()
        except Exception:
            # One broken subject should not abort the batch: report it and continue.
            traceback.print_exc()
            failed.append(subject)

    if failed:
        print(f"\n⚠️ {len(subjects) - len(failed)}/{len(subjects)} subjects processed; failed: {failed}")
        sys.exit(1)

    print(f"\n🎉 ALL {len(subjects)} SUBJECTS PROCESSED SUCCESSFULLY!")