In [2]:
# %%
import os
from pathlib import Path
import json
import numpy as np
import pandas as pd
import mne

# Root directory of the project; every input/output path below is built from it
# (e.g. <PROJECT_BASE>/derivatives/eeg/<subject>/bima_DBSOFF/...).
# NOTE(review): this is a machine-specific absolute path — adjust before running elsewhere.
PROJECT_BASE = '/home/jaizor/jaizor/xtra'

class EventToTrialMapper:
    """
    Map behavioral ("binary") events onto trials defined by pacing stim channels.

    For one subject the pipeline (see ``run``):
    - reads the ICA-cleaned raw .fif and detects rising edges on stim channels '1a'-'6a';
    - groups the pacing onsets into trials (a pause longer than a threshold starts a new trial);
    - loads the binary events (.fif + .json) and keeps those that fall inside a trial window;
    - tags every kept event with the most recent pacing channel.

    No intermediate files are written. The final table is saved as ONE CSV per
    subject, and only when ``save_output`` is True.
    """

    def __init__(self, project_base, subject='sub-06', save_output=False):
        """
        Parameters
        ----------
        project_base : str or Path
            Project root containing the ``derivatives`` tree.
        subject : str
            Subject folder name, e.g. 'sub-06'.
        save_output : bool
            When True the output directory is created now and ``run`` writes the CSV.
        """
        self.project_base = Path(project_base)
        self.subject = subject
        self.save_output = save_output  # Only save final CSV if True

        # Output dir for final CSV (only created when saving is requested)
        self.output_dir = self.project_base / "derivatives" / "features" / self.subject
        if self.save_output:
            self.output_dir.mkdir(parents=True, exist_ok=True)

    def detect_rising_edges(self, data, sfreq):
        """
        Detect rising edges (0 → 1) in a stim channel.

        The signal is binarized at 0.5; an onset is the first sample that is
        high after a low sample. A channel that is already high at sample 0
        produces no onset for that initial plateau.

        Returns a numpy array of onset times in seconds (sample index / sfreq).
        """
        binary = (np.asarray(data) > 0.5).astype(int)
        diff = np.diff(binary)
        onset_indices = np.where(diff == 1)[0] + 1  # +1 because diff shifts left
        return onset_indices / sfreq

    def extract_pacing_onsets_from_raw(self, raw):
        """
        Extract pacing trigger onsets ('1a'-'6a') directly from raw stim channels.

        Channels missing from ``raw`` are reported and skipped.

        Returns DataFrame: ['channel', 'eeg_time_s'] — times are relative to the
        first sample returned by ``raw.get_data``.
        Raises ValueError when no onset is found on any pacing channel.
        """
        sfreq = raw.info['sfreq']
        pacing_channels = ['1a', '2a', '3a', '4a', '5a', '6a']
        all_onsets = []

        for ch_name in pacing_channels:
            if ch_name not in raw.ch_names:
                print(f"⚠️  Stim channel '{ch_name}' not found in {self.subject}")
                continue

            data = raw.get_data(picks=[ch_name])[0]
            onset_times = self.detect_rising_edges(data, sfreq)
            all_onsets.extend({'channel': ch_name, 'eeg_time_s': t} for t in onset_times)

        if not all_onsets:
            raise ValueError(f"❌ No pacing triggers (1a-6a) found in raw data for {self.subject}")

        return pd.DataFrame(all_onsets)

    def define_trials_from_pacing_onsets(self, pacing_onsets_df, trial_gap_threshold_s=2.5, extension_duration_s=4.3):
        """
        Define trials from pacing onsets.

        Onsets are sorted in time; any interval longer than
        ``trial_gap_threshold_s`` separates two blocks. A block becomes a trial
        when it contains a '1a' onset and a later '6a' onset: the trial runs
        from the first '1a' to the last '6a', and ``end_time_extended`` adds
        ``extension_duration_s`` after the last '6a'. Blocks without such a
        pair are dropped, and trials are numbered 1..N over the kept blocks.

        Returns DataFrame: ['trial_number', 'start_time', 'end_time',
        'end_time_extended', 'duration_s', 'subject_id'] (empty DataFrame when
        there are no onsets or no valid block).
        """
        if pacing_onsets_df.empty:
            return pd.DataFrame()

        df = pacing_onsets_df.sort_values('eeg_time_s').reset_index(drop=True)
        times = df['eeg_time_s'].values
        intervals = np.diff(times)
        # gap_indices[i] is the index of the LAST onset before a long pause
        gap_indices = np.where(intervals > trial_gap_threshold_s)[0]

        boundaries = []
        start_idx = 0
        for gap_idx in gap_indices:
            boundaries.append((start_idx, gap_idx))
            start_idx = gap_idx + 1
        boundaries.append((start_idx, len(df) - 1))

        trials = []
        for start_idx, end_idx in boundaries:
            block = df.iloc[start_idx:end_idx + 1]
            first_1a = block[block['channel'] == '1a']['eeg_time_s'].min()
            last_6a = block[block['channel'] == '6a']['eeg_time_s'].max()

            # min()/max() of an empty selection is NaN → block is skipped
            if pd.notna(first_1a) and pd.notna(last_6a) and last_6a > first_1a:
                trials.append({
                    'trial_number': len(trials) + 1,
                    'start_time': first_1a,
                    'end_time': last_6a,
                    'end_time_extended': last_6a + extension_duration_s,
                    'duration_s': last_6a - first_1a,
                    'subject_id': self.subject
                })

        return pd.DataFrame(trials)

    def load_binary_events(self, sfreq=500.0):
        """
        Load binary events (behavioral/condition events) from .fif + .json.

        Parameters
        ----------
        sfreq : float
            Sampling rate used to convert event sample indices to seconds.
            Defaults to 500.0 for backward compatibility; ``run`` passes the
            rate read from the raw file.

        Returns DataFrame: ['eeg_time_s', 'event_id', 'condition'].
        Event ids missing from the JSON mapping get condition "Unknown".
        Raises FileNotFoundError when either input file is missing.

        NOTE(review): sample indices are divided by ``sfreq`` as stored, with no
        ``raw.first_samp`` offset removed — confirm the events file uses the
        same zero-based time base as the pacing onsets.
        """
        session_dir = self.project_base / "derivatives" / "eeg" / self.subject / "bima_DBSOFF"
        events_file = session_dir / f"{self.subject}_events_mne_binary-eve.fif"
        event_id_file = session_dir / f"{self.subject}_event_id_binary.json"

        if not (events_file.exists() and event_id_file.exists()):
            raise FileNotFoundError(f"❌ Binary event files not found for {self.subject}")

        events = mne.read_events(str(events_file), verbose=False)
        with open(event_id_file, 'r') as f:
            event_id = json.load(f)

        event_times = events[:, 0] / sfreq
        # JSON maps condition name → id; invert it for id → condition lookup
        id_to_condition = {v: k for k, v in event_id.items()}
        conditions = [id_to_condition.get(e_id, "Unknown") for e_id in events[:, 2]]

        return pd.DataFrame({
            'eeg_time_s': event_times,
            'event_id': events[:, 2],
            'condition': conditions
        })

    def assign_events_to_trials(self, events_df, trials_df):
        """
        Assign each event to a trial based on time.

        An event belongs to the first trial (in ``trials_df`` row order) whose
        window satisfies ``start_time <= t < end_time_extended``. Events that
        match no trial are dropped.

        Returns a NEW DataFrame (original index preserved) with a
        'trial_number' column; ``events_df`` itself is not modified.
        """
        times = events_df['eeg_time_s'].to_numpy()
        trial_numbers = np.full(len(times), -1, dtype=np.int64)
        unassigned = np.ones(len(times), dtype=bool)

        if not trials_df.empty:
            windows = zip(trials_df['trial_number'],
                          trials_df['start_time'],
                          trials_df['end_time_extended'])
            for number, start, stop in windows:
                # only claim events no earlier trial has taken (windows may overlap)
                hit = unassigned & (times >= start) & (times < stop)
                trial_numbers[hit] = number
                unassigned &= ~hit

        assigned = ~unassigned
        result = events_df.loc[assigned].copy()
        result['trial_number'] = trial_numbers[assigned]
        return result

    def add_stim_channel_info(self, events_df, pacing_onsets_df):
        """
        Add 'stim_channel' column: last pacing trigger (1a-6a) at or before each event.

        Events with no earlier trigger (or with no pacing data at all) get
        "unknown". Returns a NEW DataFrame that always contains the
        'stim_channel' column, even when ``events_df`` is empty;
        ``events_df`` itself is not modified.
        """
        result = events_df.copy()

        if result.empty:
            result['stim_channel'] = pd.Series(index=result.index, dtype=object)
            return result
        if pacing_onsets_df.empty:
            result['stim_channel'] = "unknown"
            return result

        pacing_sorted = pacing_onsets_df.sort_values('eeg_time_s').reset_index(drop=True)
        pacing_times = pacing_sorted['eeg_time_s'].to_numpy()
        channels = pacing_sorted['channel'].to_numpy()

        event_times = result['eeg_time_s'].to_numpy(dtype=float)
        # side='right' → number of triggers with time <= t; minus 1 is the last such trigger
        last_idx = np.searchsorted(pacing_times, event_times, side='right') - 1
        has_trigger = (last_idx >= 0) & ~np.isnan(event_times)

        result['stim_channel'] = np.where(
            has_trigger, channels[np.clip(last_idx, 0, None)], "unknown"
        )
        return result

    def run(self):
        """
        🔥 MAIN ENTRY POINT — returns final DataFrame (and optionally saves it).

        Returns DataFrame: ['eeg_time_s', 'trial_number', 'condition',
        'stim_channel'], sorted by trial then time.
        Raises FileNotFoundError when the raw or event files are missing and
        ValueError when no pacing trigger or no valid trial is found.
        """
        print(f"\n🚀 Processing {self.subject} — extracting stim channels from RAW, defining trials, mapping events...")

        # --- 1. Load RAW EEG file ---
        raw_file = self.project_base / "derivatives" / "eeg" / self.subject / "bima_DBSOFF" / f"{self.subject}_ses-DBSOFF_task-bima_eeg_ica_cleaned_raw.fif"
        if not raw_file.exists():
            raise FileNotFoundError(f"❌ RAW file not found: {raw_file}")

        raw = mne.io.read_raw_fif(str(raw_file), preload=True)

        # --- 2. Extract pacing onsets from stim channels ---
        pacing_onsets_df = self.extract_pacing_onsets_from_raw(raw)
        print(f"⏱️  Extracted {len(pacing_onsets_df)} pacing triggers from stim channels.")

        # --- 3. Define trials from pacing triggers ---
        trials_df = self.define_trials_from_pacing_onsets(pacing_onsets_df)
        if trials_df.empty:
            raise ValueError(f"❌ No valid trials defined for {self.subject}")
        print(f"🧩 Defined {len(trials_df)} trials.")

        # --- 4. Load binary events (behavioral/condition) ---
        # Use the recording's own sampling rate instead of assuming 500 Hz.
        events_df = self.load_binary_events(sfreq=raw.info['sfreq'])
        print(f"📊 Loaded {len(events_df)} binary events.")

        # --- 5. Assign events to trials ---
        events_with_trials = self.assign_events_to_trials(events_df, trials_df)
        print(f"✅ Assigned {len(events_with_trials)} events to trials.")

        # --- 6. Add stim_channel info ---
        final_df = self.add_stim_channel_info(events_with_trials, pacing_onsets_df)
        print(f"🏷️  Added stim_channel info to all events.")

        # --- 7. Final cleanup: select and sort columns ---
        output_columns = ['eeg_time_s', 'trial_number', 'condition', 'stim_channel']
        final_df = final_df[output_columns].sort_values(['trial_number', 'eeg_time_s']).reset_index(drop=True)

        # --- 8. Optionally save to CSV ---
        if self.save_output:
            output_file = self.output_dir / "events_with_trials.csv"
            final_df.to_csv(output_file, index=False)
            print(f"✅ Saved final CSV: {output_file}")

        print(f"🎉 Done processing {self.subject} — returning DataFrame with {len(final_df)} rows.\n")
        return final_df  # 👈 THIS IS YOUR FINAL OUTPUT


def find_all_subjects(project_base: Path):
    """
    Find all subjects that have every file the pipeline needs.

    A subject qualifies when ``derivatives/eeg/<sub-XX>/bima_DBSOFF`` contains
    the ICA-cleaned raw .fif, the binary events .fif and the event-id .json.

    Parameters
    ----------
    project_base : Path or str
        Project root containing the ``derivatives`` tree.

    Returns
    -------
    list[str]
        Subject folder names sorted by their numeric suffix ('sub-2' before
        'sub-10'); names without a numeric suffix come last, alphabetically.
        Empty list when the EEG derivatives directory does not exist.
    """
    eeg_dir = Path(project_base) / "derivatives" / "eeg"
    if not eeg_dir.is_dir():
        # Nothing to scan — let the caller report "no subjects" instead of crashing.
        return []

    subjects = []
    for item in eeg_dir.iterdir():
        if not (item.is_dir() and item.name.startswith("sub-")):
            continue

        session_dir = item / "bima_DBSOFF"
        if not session_dir.exists():
            continue

        # Must have RAW file + binary events (+ their id mapping)
        required_files = (
            session_dir / f"{item.name}_ses-DBSOFF_task-bima_eeg_ica_cleaned_raw.fif",
            session_dir / f"{item.name}_events_mne_binary-eve.fif",
            session_dir / f"{item.name}_event_id_binary.json",
        )
        if all(path.exists() for path in required_files):
            subjects.append(item.name)

    def _sort_key(name):
        # Numeric suffixes sort numerically; anything else is pushed to the end
        # rather than aborting discovery with a ValueError.
        try:
            return (0, int(name.split('-')[1]), name)
        except ValueError:
            return (1, 0, name)

    subjects.sort(key=_sort_key)
    return subjects


# ========================
# 🔥 FIXED: PHASE SEGMENT DETECTION — stim_channel taken from the break event (the OutofPhase event that opens the OUT-PHASE window)
# ========================

def find_phase_segments_per_trial(df, subject="unknown", min_in_phase_duration=3.0, max_outphase_in_inphase=1, window_s=10.0):
    """
    Find, for every trial, the first valid in-phase → out-of-phase break.

    Per trial:
    - Candidates are 'OutofPhase' events occurring at least
      ``min_in_phase_duration`` seconds after the trial's first event.
    - A candidate is accepted when the ``window_s`` seconds before it (clipped
      to the trial's first event) contain at most ``max_outphase_in_inphase``
      'OutofPhase' events; otherwise the next candidate is tried.
    - The in-phase segment is the window BEFORE the break, the out-of-phase
      segment the ``window_s`` seconds AFTER it (clipped to the trial's last
      event). Both windows are half-open: start <= t < end.
    - ``stim_channel`` is taken from the break event itself.
    Only the first accepted break per trial is reported. An error inside one
    trial is printed and that trial is skipped.

    Parameters
    ----------
    df : DataFrame
        Must contain 'eeg_time_s', 'trial_number', 'condition', 'stim_channel'.
    subject : str
        Label stored in the output and used in log messages.
    min_in_phase_duration : float
        Minimum time (s) after trial start before a break may occur (default 3.0).
    max_outphase_in_inphase : int
        Maximum 'OutofPhase' events tolerated in the pre-break window (default 1).
    window_s : float
        Length (s) of the in-phase and out-of-phase windows (default 10.0).

    Returns
    -------
    DataFrame with one row per trial that has a valid break (empty DataFrame
    when ``df`` is empty or no trial qualifies).

    Raises
    ------
    ValueError
        When a required column is missing from a non-empty ``df``.
    """
    if df.empty:
        print(f"⚠️  No events to process for {subject}")
        return pd.DataFrame()

    # Ensure required columns exist
    required_cols = ['eeg_time_s', 'trial_number', 'condition', 'stim_channel']
    missing = [col for col in required_cols if col not in df.columns]
    if missing:
        raise ValueError(f"❌ Missing columns in data for {subject}: {missing}")

    results = []

    for trial_num, trial_group in df.groupby('trial_number'):
        try:
            trial_group = trial_group.sort_values('eeg_time_s').reset_index(drop=True)
            event_times = trial_group['eeg_time_s']
            is_out = trial_group['condition'] == 'OutofPhase'

            # Trial extent is defined by its first and last EVENT, not by pacing triggers
            trial_start = event_times.min()
            trial_end = event_times.max()

            # All OutofPhase events after the minimum in-phase duration, in time order
            min_break_time = trial_start + min_in_phase_duration
            candidate_events = trial_group[is_out & (event_times >= min_break_time)]

            if len(candidate_events) == 0:
                print(f"⏭️  No phase break candidates in trial {trial_num} for {subject} (no OutofPhase after {min_in_phase_duration:.1f}s)")
                continue

            for _, candidate in candidate_events.iterrows():
                break_time = candidate['eeg_time_s']
                stim_ch = candidate['stim_channel']

                # In-phase window: window_s BEFORE the break, clipped to trial start
                in_phase_start = max(break_time - window_s, trial_start)
                in_phase_end = break_time
                in_events = trial_group[(event_times >= in_phase_start) & (event_times < in_phase_end)]

                n_out_in_inphase = int((in_events['condition'] == 'OutofPhase').sum())

                # ✅ STRICT RULE: reject if too many OutofPhase events in the "in-phase" window
                if n_out_in_inphase > max_outphase_in_inphase:
                    print(f"   ⚠️  Skipping break at {break_time:.2f}s (trial {trial_num}): {n_out_in_inphase} OutofPhase events in in-phase window (max allowed: {max_outphase_in_inphase})")
                    continue

                # ✅ VALID BREAK FOUND — out-phase window: window_s AFTER, clipped to trial end
                out_phase_start = break_time
                out_phase_end = min(break_time + window_s, trial_end)
                out_events = trial_group[(event_times >= out_phase_start) & (event_times < out_phase_end)]

                results.append({
                    'subject': subject,
                    'trial_number': trial_num,
                    'stim_channel': stim_ch,
                    'break_time': break_time,
                    'in_phase_start': in_phase_start,
                    'in_phase_end': in_phase_end,
                    'out_phase_start': out_phase_start,
                    'out_phase_end': out_phase_end,
                    'n_in_phase_events': len(in_events),
                    'n_out_phase_events': len(out_events),
                    'n_outphase_in_inphase': n_out_in_inphase,
                    'n_inphase_in_inphase': (in_events['condition'] == 'InPhase').sum(),
                    'in_phase_condition_counts': in_events['condition'].value_counts().to_dict(),
                    'out_phase_condition_counts': out_events['condition'].value_counts().to_dict(),
                })

                print(f"✅ ACCEPTED phase break for {subject} trial {trial_num}: "
                      f"at {break_time:.2f}s (stim={stim_ch}) | "
                      f"In-phase [{in_phase_start:.2f}-{in_phase_end:.2f}] → {n_out_in_inphase} OutofPhase (≤{max_outphase_in_inphase}) | "
                      f"Out-phase [{out_phase_start:.2f}-{out_phase_end:.2f}]")
                break  # Take first VALID break
            else:
                # Loop finished without `break`: every candidate was rejected
                print(f"⏭️  No VALID phase break found in trial {trial_num} for {subject} (all candidates had too many OutofPhase events in pre-window)")

        except Exception as e:
            # Best effort: one malformed trial must not discard the other trials
            print(f"❌ Error processing trial {trial_num} for {subject}: {e}")
            continue

    if not results:
        print(f"❌ No valid phase breaks found for {subject}")
        return pd.DataFrame()

    result_df = pd.DataFrame(results)
    print(f"📊 Summary for {subject}: {len(result_df)} trials with STRICTLY VALID phase breaks")
    return result_df

    
# ========================
# 🧩 MAIN EXECUTION — Process all subjects and extract phase segments
# ========================

# Batch driver: discover subjects, build the per-subject event table, extract
# phase segments, and write per-subject plus combined CSVs under
# <PROJECT_BASE>/derivatives/phase_segments. A failure in one subject is
# reported and the loop moves on to the next subject.
if __name__ == "__main__":
    print("="*80)
    print("🔎 STEP 1: Discovering subjects...")
    print("="*80)

    subjects = find_all_subjects(Path(PROJECT_BASE))
    if not subjects:
        print("❌ No valid subjects found.")
        # `exit()` is the interactive-shell helper: inside a Jupyter kernel it only
        # requests shutdown and does NOT stop this cell. Raising SystemExit halts
        # execution both in a notebook and when run as a script.
        raise SystemExit(1)

    print(f"✅ Found {len(subjects)} subjects: {subjects}")

    # Output directory for phase segments
    PHASE_OUTPUT_DIR = Path(PROJECT_BASE) / "derivatives" / "phase_segments"
    PHASE_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Per-subject segment tables, concatenated into one file at the end
    all_phase_segments = []

    print("\n" + "="*80)
    print("🧬 STEP 2: Processing each subject → events + phase segments")
    print("="*80)

    for subject in subjects:
        print(f"\n{'='*80}")
        print(f"🧬 PROCESSING SUBJECT: {subject}")
        print(f"{'='*80}")

        try:
            # --- 1. Generate events_with_trials.csv ---
            mapper = EventToTrialMapper(PROJECT_BASE, subject=subject, save_output=True)
            df = mapper.run()  # This saves CSV and returns DataFrame

            if df is None or len(df) == 0:
                print(f"❌ No event data for {subject}")
                continue

            # --- 2. Find phase segments ---
            print(f"\n🔍 STEP 3: Finding phase segments for {subject}...")
            segments_df = find_phase_segments_per_trial(df, subject=subject)

            if segments_df.empty:
                print(f"⏭️  No phase segments found for {subject}")
                continue

            # --- 3. Save per-subject phase segments ---
            phase_file = PHASE_OUTPUT_DIR / f"{subject}_phase_segments.csv"
            segments_df.to_csv(phase_file, index=False)
            print(f"✅ Saved phase segments: {phase_file}")

            # Collect for combined file
            all_phase_segments.append(segments_df)

        except Exception as e:
            # Best effort across subjects: report and continue with the next one
            print(f"❌ FAILED {subject}: {e}")
            continue

    # --- 4. Save combined phase segments ---
    if all_phase_segments:
        combined_df = pd.concat(all_phase_segments, ignore_index=True)
        combined_file = PHASE_OUTPUT_DIR / "all_subjects_phase_segments.csv"
        combined_df.to_csv(combined_file, index=False)
        print(f"\n🎉 SAVED COMBINED FILE: {combined_file}")
        print(f"📊 Total trials with phase segments: {len(combined_df)}")

        # Print stim_channel distribution to verify fix
        print("\n✅ STIM CHANNEL DISTRIBUTION (should NOT be all 1a):")
        print(combined_df['stim_channel'].value_counts().sort_index())

        # Show sample
        print("\n📋 SAMPLE OF CORRECTED DATA:")
        print(combined_df[['subject', 'trial_number', 'stim_channel', 'out_phase_start', 'in_phase_start']].head(10))

    else:
        print("\n❌ No phase segments found for any subject.")

    print(f"\n🎉🎉🎉 ALL DONE — Check output in: {PHASE_OUTPUT_DIR}")

🔎 STEP 1: Discovering subjects...
✅ Found 12 subjects: ['sub-01', 'sub-02', 'sub-03', 'sub-05', 'sub-06', 'sub-07', 'sub-08', 'sub-09', 'sub-10', 'sub-11', 'sub-12', 'sub-14']

🧬 STEP 2: Processing each subject → events + phase segments

🧬 PROCESSING SUBJECT: sub-01

🚀 Processing sub-01 — extracting stim channels from RAW, defining trials, mapping events...
Opening raw data file /home/jaizor/jaizor/xtra/derivatives/eeg/sub-01/bima_DBSOFF/sub-01_ses-DBSOFF_task-bima_eeg_ica_cleaned_raw.fif...
    Range : 125209 ... 340681 =    250.418 ...   681.362 secs
Ready.
Reading 0 ... 215472  =      0.000 ...   430.944 secs...
⏱️  Extracted 1200 pacing triggers from stim channels.
🧩 Defined 10 trials.
📊 Loaded 578 binary events.
✅ Assigned 505 events to trials.
🏷️  Added stim_channel info to all events.
✅ Saved final CSV: /home/jaizor/jaizor/xtra/derivatives/features/sub-01/events_with_trials.csv
🎉 Done processing sub-01 — returning DataFrame with 505 rows.


🔍 STEP 3: Finding phase segments for s

In [3]:
import pandas as pd

# Read back the combined phase-segment table written by the processing cell
combined_csv = "/home/jaizor/jaizor/xtra/derivatives/phase_segments/all_subjects_phase_segments.csv"
df = pd.read_csv(combined_csv)

# One row per accepted phase break; report how many rows and distinct subjects
n_segments = len(df)
n_subjects = df['subject'].nunique()
print(f"Found {n_segments} valid phase segments across {n_subjects} subjects.")

Found 74 valid phase segments across 12 subjects.


In [4]:
# Display the combined phase-segment table loaded above (rich DataFrame repr).
df

Unnamed: 0,subject,trial_number,stim_channel,break_time,in_phase_start,in_phase_end,out_phase_start,out_phase_end,n_in_phase_events,n_out_phase_events,n_outphase_in_inphase,n_inphase_in_inphase,in_phase_condition_counts,out_phase_condition_counts
0,sub-01,4,4a,154.78,144.78,154.78,154.78,164.78,10,15,1,9,"{'InPhase': 9, 'OutofPhase': 1}","{'OutofPhase': 8, 'InPhase': 7}"
1,sub-01,6,2a,235.20,225.20,235.20,235.20,245.20,9,18,1,8,"{'InPhase': 8, 'OutofPhase': 1}","{'InPhase': 11, 'OutofPhase': 7}"
2,sub-02,1,3a,22.30,12.30,22.30,22.30,32.30,2,17,0,2,{'InPhase': 2},"{'InPhase': 11, 'OutofPhase': 6}"
3,sub-02,2,3a,73.35,63.35,73.35,73.35,83.35,2,17,0,2,{'InPhase': 2},"{'InPhase': 12, 'OutofPhase': 5}"
4,sub-02,3,3a,123.53,113.53,123.53,123.53,133.53,3,8,0,3,{'InPhase': 3},"{'InPhase': 5, 'OutofPhase': 3}"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
69,sub-14,6,5a,308.50,298.50,308.50,308.50,318.05,0,7,0,0,{},"{'OutofPhase': 6, 'InPhase': 1}"
70,sub-14,7,5a,360.50,350.50,360.50,360.50,370.50,0,6,0,0,{},{'OutofPhase': 6}
71,sub-14,8,1a,383.35,379.30,383.35,383.35,393.35,1,1,1,0,{'OutofPhase': 1},{'OutofPhase': 1}
72,sub-14,9,4a,456.47,446.47,456.47,456.47,466.47,0,5,0,0,{},{'OutofPhase': 5}


In [5]:
# %%
import matplotlib.pyplot as plt
import seaborn as sns

def plot_validation_for_subject(subject, project_base=PROJECT_BASE, max_trials=5):
    """
    Validate phase segment detection by plotting, per trial:
    - Original behavioral events (InPhase/OutofPhase) as vertical lines
    - The stim channel attached to each event (text annotation)
    - Detected in-phase / out-of-phase segments as shaded spans
    All aligned to EEG time.

    Reads ``events_with_trials.csv`` and ``<subject>_phase_segments.csv`` from
    the derivatives tree and writes one PNG per trial to
    ``derivatives/validation_plots``. If any input file is missing, a message
    is printed and nothing is plotted.

    Parameters
    ----------
    subject : str
        Subject folder name, e.g. 'sub-01'.
    project_base : str or Path
        Project root containing the ``derivatives`` tree.
    max_trials : int
        Only the first ``max_trials`` trial numbers are plotted.

    Returns
    -------
    None
    """
    print(f"\n🔍 VALIDATING PHASE SEGMENTS FOR {subject}...")

    # --- Locate input files ---
    base_path = Path(project_base)
    events_file = base_path / "derivatives" / "features" / subject / "events_with_trials.csv"
    segments_file = base_path / "derivatives" / "phase_segments" / f"{subject}_phase_segments.csv"
    raw_file = base_path / "derivatives" / "eeg" / subject / "bima_DBSOFF" / f"{subject}_ses-DBSOFF_task-bima_eeg_ica_cleaned_raw.fif"

    if not events_file.exists():
        print(f"❌ events_with_trials.csv not found for {subject}")
        return
    if not segments_file.exists():
        print(f"❌ phase_segments.csv not found for {subject}")
        return
    if not raw_file.exists():
        print(f"❌ RAW file not found for {subject}")
        return

    # Load data
    events_df = pd.read_csv(events_file)
    segments_df = pd.read_csv(segments_file)

    # Recording length is only used to cap the x-axis; fall back to the event
    # range if the raw file cannot be read. `Exception` (not a bare except)
    # keeps KeyboardInterrupt/SystemExit working.
    try:
        raw = mne.io.read_raw_fif(str(raw_file), preload=False)
        total_duration = raw.times[-1]
    except Exception:
        total_duration = events_df['eeg_time_s'].max() + 10

    # Get unique trials
    trials = sorted(events_df['trial_number'].unique())
    trials = trials[:max_trials]  # Limit for readability

    if len(trials) == 0:
        print(f"⚠️ No trials found for {subject}")
        return

    # Output location is the same for every trial — create it once
    output_dir = base_path / "derivatives" / "validation_plots"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Color maps
    condition_colors = {'InPhase': 'green', 'OutofPhase': 'red'}
    stim_colors = {
        '1a': '#FF6B6B', '2a': '#4ECDC4', '3a': '#45B7D1',
        '4a': '#96CEB4', '5a': '#FFEAA7', '6a': '#DDA0DD',
        'unknown': '#888888'
    }

    # Plot per trial
    for trial_num in trials:
        trial_events = events_df[events_df['trial_number'] == trial_num].sort_values('eeg_time_s')
        trial_segments = segments_df[segments_df['trial_number'] == trial_num]

        if len(trial_events) == 0:
            continue

        # Time range: a little before the first event, a full window after the last
        tmin = trial_events['eeg_time_s'].min() - 2
        tmax = min(trial_events['eeg_time_s'].max() + 12, total_duration)

        fig, ax = plt.subplots(1, 1, figsize=(16, 6))

        # --- Plot events as vertical lines ---
        # Label the first line of EACH condition so both appear in the legend.
        labelled_conditions = set()
        for _, row in trial_events.iterrows():
            condition = row['condition']
            color = condition_colors.get(condition, 'gray')
            label = "" if condition in labelled_conditions else condition
            labelled_conditions.add(condition)
            ax.axvline(row['eeg_time_s'], color=color, alpha=0.7, lw=2, label=label)
            # Annotate stim channel
            stim = row['stim_channel']
            ax.text(row['eeg_time_s'], ax.get_ylim()[1] * 0.95, stim,
                   rotation=90, fontsize=8, color=stim_colors.get(stim, '#000'),
                   verticalalignment='top', horizontalalignment='center')

        # --- Plot detected segments ---
        # Label by position within THIS trial, not by the CSV row index, so the
        # legend entries appear on every trial's figure.
        for seg_pos, (_, seg) in enumerate(trial_segments.iterrows()):
            first_segment = seg_pos == 0
            # In-phase segment
            ax.axvspan(seg['in_phase_start'], seg['in_phase_end'],
                      color='blue', alpha=0.2, label='Detected In-Phase (10s)' if first_segment else "")
            # Out-of-phase segment
            ax.axvspan(seg['out_phase_start'], seg['out_phase_end'],
                      color='red', alpha=0.2, label='Detected Out-of-Phase (10s)' if first_segment else "")

            # Mark segment boundaries
            ax.axvline(seg['in_phase_start'], color='blue', linestyle='--', lw=1.5, alpha=0.8)
            ax.axvline(seg['out_phase_end'], color='red', linestyle='--', lw=1.5, alpha=0.8)

        # --- Labels and formatting ---
        ax.set_title(f"{subject} - Trial {trial_num} | EEG Time", fontweight='bold', fontsize=14)
        ax.set_xlabel("Time (s) - EEG Time", fontweight='bold')
        ax.set_ylabel("Events & Segments", fontweight='bold')
        ax.set_xlim(tmin, tmax)
        ax.grid(True, alpha=0.3)

        # Create legend without duplicates
        handles, labels = ax.get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        if by_label:
            ax.legend(by_label.values(), by_label.keys(), loc='upper right')

        # Save plot, then close THIS figure so figures do not pile up in memory
        plot_file = output_dir / f"{subject}_trial{trial_num}_validation.png"
        fig.tight_layout()
        fig.savefig(plot_file, dpi=300, bbox_inches='tight')
        plt.close(fig)

        print(f"✅ Saved validation plot: {plot_file}")

        # Print summary
        n_in = len(trial_events[trial_events['condition'] == 'InPhase'])
        n_out = len(trial_events[trial_events['condition'] == 'OutofPhase'])
        n_segments = len(trial_segments)
        stim_summary = trial_events['stim_channel'].value_counts().to_dict()

        print(f"   📊 Trial {trial_num}: {n_in} InPhase, {n_out} OutofPhase events")
        print(f"   🎛️  Stim channels: {stim_summary}")
        print(f"   🟦🟥 Detected segments: {n_segments}")

    print(f"🎉 Validation plots generated for {subject}")

In [6]:
# %%
divider = "=" * 80
print("\n" + divider)
print("🔍 STEP 4: VALIDATING PHASE SEGMENTS WITH PLOTS")
print(divider)

# Plot up to 5 trials per subject; a failure for one subject is reported and
# the loop simply carries on with the next one.
for subject in subjects:
    try:
        plot_validation_for_subject(subject, max_trials=5)
    except Exception as e:
        print(f"❌ Failed to generate validation plot for {subject}: {e}")

validation_dir = Path(PROJECT_BASE) / 'derivatives' / 'validation_plots'
print(f"\n✅ ALL VALIDATION PLOTS SAVED TO: {validation_dir}")


🔍 STEP 4: VALIDATING PHASE SEGMENTS WITH PLOTS

🔍 VALIDATING PHASE SEGMENTS FOR sub-01...
Opening raw data file /home/jaizor/jaizor/xtra/derivatives/eeg/sub-01/bima_DBSOFF/sub-01_ses-DBSOFF_task-bima_eeg_ica_cleaned_raw.fif...
    Range : 125209 ... 340681 =    250.418 ...   681.362 secs
Ready.
✅ Saved validation plot: /home/jaizor/jaizor/xtra/derivatives/validation_plots/sub-01_trial1_validation.png
   📊 Trial 1: 6 InPhase, 33 OutofPhase events
   🎛️  Stim channels: {'1a': 11, '6a': 10, '4a': 6, '3a': 5, '2a': 4, '5a': 3}
   🟦🟥 Detected segments: 0
✅ Saved validation plot: /home/jaizor/jaizor/xtra/derivatives/validation_plots/sub-01_trial2_validation.png
   📊 Trial 2: 31 InPhase, 29 OutofPhase events
   🎛️  Stim channels: {'3a': 13, '1a': 11, '6a': 11, '4a': 11, '2a': 7, '5a': 7}
   🟦🟥 Detected segments: 0
✅ Saved validation plot: /home/jaizor/jaizor/xtra/derivatives/validation_plots/sub-01_trial3_validation.png
   📊 Trial 3: 9 InPhase, 26 OutofPhase events
   🎛️  Stim channels: {'1a