## Single Subject  VIDEO EVENT DETECTOR
Incorporates balanced event detection, enhanced trial plotting, and MNE format saving

In [None]:
# --- Libraries ---
import mne
import json
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter, hilbert, butter, filtfilt
from scipy.interpolate import interp1d
from scipy.ndimage import uniform_filter1d
from matplotlib.lines import Line2D
import warnings
import shutil
warnings.filterwarnings('ignore')


# === Coordination Analyzer Class ===
class CoordinationAnalyzer:
    """Enhanced coordination analysis with improved event detection and plotting"""
    def __init__(self, subject_id, config=None):
        self.subject_id = subject_id
        self.config = self._get_default_config()
        if config:
            self.config.update(config)
        # Data storage
        self.times = None
        self.right_x = None
        self.left_x = None
        self.event_df = None
        self.trials_df = None
        # Preprocessing results
        self.right_x_smooth = None
        self.left_x_smooth = None
        self.right_x_filtered = None
        self.left_x_filtered = None
        self.right_vel = None
        self.left_vel = None
        # Analysis results
        self.phase_diff_wrapped = None
        self.coordination_index = None
        self.coord_smooth = None
        self.std_coord = None

    def _get_default_config(self):
        return {
            'keypoint_index': 8,
            'sample_rate': 60.0,
            'window_sec': 0.5,
            'anti_phase_threshold_rad': 5 * np.pi / 6,
            'breakdown_std_threshold': 0.60,
            'window_size_ratio': 0.5,
            'initial_window_ratio': 1.0,
            'grouping_window_ratio': 0.5,
            'trial_gap_threshold_s': 2.5,
            'extension_duration_s': 4.3,
            'eeg_to_behavior_delay': 0.0,
            'filter_lowcut': 0.5,
            'filter_highcut': 10.0,
            'stim_channels': ['TT140', 'TT255', '1a', '2a', '3a', '4a', '5a', '6a'],
            'pacing_channels': ['1a', '2a', '3a', '4a', '5a', '6a'],
            'target_in_phase_events': 280  # NEW: Target number for balanced sampling
        }

    def load_behavioral_data(self, file_path):
        """Load and preprocess behavioral keypoint data"""
        print(f"Loading behavioral data for {self.subject_id}...")
        with open(file_path, 'r') as f:
            data = json.load(f)
        # Extract data
        times, right_x, left_x = [], [], []
        for frame in data['instance_info']:
            t = frame['frame_id'] / self.config['sample_rate']
            times.append(t)
            rx, lx = np.nan, np.nan
            for inst in frame['instances']:
                label = inst['label']
                kpts = inst['keypoints']
                if self.config['keypoint_index'] < len(kpts):
                    x, y = kpts[self.config['keypoint_index']]
                    if label == "Right":
                        rx = x
                    elif label == "Left":
                        lx = x
            right_x.append(rx)
            left_x.append(lx)
        self.times = np.array(times)
        self.right_x = np.array(right_x)
        self.left_x = np.array(left_x)
        print(f"✓ Loaded {len(self.times)} frames, duration: {self.times[-1]:.1f}s")
        return self

    def preprocess_signals(self):
        """Enhanced preprocessing with better filtering"""
        # Interpolate missing values
        self.right_x_smooth = self._interpolate_and_smooth(self.right_x)
        self.left_x_smooth = self._interpolate_and_smooth(self.left_x)
        # Apply bandpass filter (IMPROVED)
        self.right_x_filtered = self._bandpass_filter(self.right_x_smooth)
        self.left_x_filtered = self._bandpass_filter(self.left_x_smooth)
        # Compute velocities
        self.right_vel = self._compute_velocity(self.right_x_smooth)
        self.left_vel = self._compute_velocity(self.left_x_smooth)
        return self

    def _interpolate_and_smooth(self, x):
        """Interpolate NaN values and apply smoothing"""
        valid = ~np.isnan(x)
        if np.sum(valid) < 2:
            return np.zeros_like(self.times) * np.nan
        f = interp1d(self.times[valid], x[valid], kind='linear', fill_value='extrapolate')
        x_interp = f(self.times)
        window = min(int(self.config['window_sec'] * self.config['sample_rate']) | 1, len(self.times) // 4)
        return savgol_filter(x_interp, window_length=window, polyorder=3)

    def _bandpass_filter(self, data):
        """Enhanced bandpass filter with better error handling"""
        try:
            nyquist = 0.5 * self.config['sample_rate']
            low = self.config['filter_lowcut'] / nyquist
            high = self.config['filter_highcut'] / nyquist
            if low <= 0 or high >= 1 or low >= high:
                print(f"⚠️ Invalid filter frequencies for {self.subject_id}, using unfiltered data")
                return data
            b, a = butter(5, [low, high], btype='band')
            return filtfilt(b, a, data)
        except Exception as e:
            print(f"⚠️ Filter failed for {self.subject_id}: {e}, using unfiltered data")
            return data

    def _compute_velocity(self, pos):
        """Compute and smooth velocity"""
        vel = np.gradient(pos, self.times)
        window = min(int(self.config['window_sec'] * self.config['sample_rate']) | 1, len(self.times) // 4)
        return savgol_filter(vel, window_length=window, polyorder=3)

    def analyze_coordination(self):
        """Enhanced coordination analysis"""
        # Compute phases using Hilbert transform
        phase_left = self._compute_phase_from_velocity(self.left_vel)
        phase_right = self._compute_phase_from_velocity(self.right_vel)
        # Phase difference analysis
        phase_diff_raw = phase_left - phase_right
        phase_diff_smooth = uniform_filter1d(phase_diff_raw, size=int(0.2 * self.config['sample_rate']))
        self.phase_diff_wrapped = ((phase_diff_smooth + np.pi) % (2 * np.pi)) - np.pi
        # Coordination metrics
        self.coordination_index = np.cos(self.phase_diff_wrapped)
        self.coord_smooth = uniform_filter1d(self.coordination_index, size=int(0.3 * self.config['sample_rate']))
        # Compute coordination variability
        window_size = int(self.config['window_size_ratio'] * self.config['sample_rate'])
        self.std_coord = self._compute_rolling_std(self.coord_smooth, window_size)
        return self

    def _compute_phase_from_velocity(self, vel):
        """Compute phase from velocity using Hilbert transform"""
        vel_centered = vel - np.mean(vel)
        analytic = hilbert(vel_centered)
        return np.unwrap(np.angle(analytic))

    def _compute_rolling_std(self, signal, window_size):
        """Compute rolling standard deviation efficiently"""
        std_signal = np.zeros_like(signal)
        half_window = window_size // 2
        for i in range(len(signal)):
            start = max(0, i - half_window)
            end = min(len(signal), i + half_window + 1)
            std_signal[i] = np.std(signal[start:end])
        return std_signal

    def detect_events(self):
        """ENHANCED: Detect events with balanced sampling"""
        # Classify coordination modes
        in_phase_current = np.abs(self.phase_diff_wrapped) < self.config['anti_phase_threshold_rad']
        anti_phase_current = np.abs(self.phase_diff_wrapped) >= self.config['anti_phase_threshold_rad']
        # Get grouping window
        grouping_window = int(self.config['grouping_window_ratio'] * self.config['sample_rate'])
        # Get indices for each state
        in_phase_indices = np.where(in_phase_current)[0]
        anti_phase_indices = np.where(anti_phase_current)[0]
        high_std_indices = np.where(self.std_coord > self.config['breakdown_std_threshold'])[0]
        # Find breakdown candidates (in-phase + high variability)
        breakdown_candidate_indices = np.intersect1d(in_phase_indices, high_std_indices)
        # Group breakdown candidates
        breakdown_groups = self._group_indices(breakdown_candidate_indices, grouping_window)
        # Create In-Phase Breakdown events
        in_phase_breakdown_events = []
        for group in breakdown_groups:
            idx = group[0]
            in_phase_breakdown_events.append({
                'time_s': float(self.times[idx]),
                'frame': int(idx),
                'type': 'In-Phase Breakdown'
            })
        # Create Anti-Phase events
        anti_phase_groups = self._group_indices(anti_phase_indices, grouping_window)
        anti_phase_events = []
        for group in anti_phase_groups:
            idx = group[0]
            anti_phase_events.append({
                'time_s': float(self.times[idx]),
                'frame': int(idx),
                'type': 'Anti-Phase Event'
            })
        # NEW: Create balanced In-Phase events using improved sampling
        in_phase_events = self._create_in_phase_events_balanced(
            in_phase_indices, self.config['target_in_phase_events']
        )
        # Combine all events
        all_events = in_phase_breakdown_events + anti_phase_events + in_phase_events
        all_events.sort(key=lambda x: x['time_s'])
        # Create DataFrame
        if all_events:
            self.event_df = pd.DataFrame(all_events)
            self.event_df['time_s'] = self.event_df['time_s'].round(2)
            self.event_df['subject_id'] = self.subject_id
        else:
            self.event_df = pd.DataFrame(columns=['time_s', 'frame', 'type', 'subject_id'])
        # Print enhanced summary
        self._print_enhanced_analysis_summary(in_phase_current, anti_phase_current, all_events)
        return self

    def _create_in_phase_events_balanced(self, in_phase_indices, target_events=250):
        """NEW: Create balanced in-phase events by intelligent sampling"""
        events = []
        if len(in_phase_indices) == 0:
            return events
        # Group consecutive in-phase points
        grouping_window = int(self.config['grouping_window_ratio'] * self.config['sample_rate'])
        in_phase_groups = self._group_indices(in_phase_indices, grouping_window)
        # Collect all sampleable points
        all_sampleable_points = []
        for group in in_phase_groups:
            if len(group) == 0:
                continue
            elif len(group) <= 5:  # Small group - just take the first point
                all_sampleable_points.append(group[0])
            else:  # Long group - sample multiple points
                step = max(1, len(group) // min(5, len(group)))
                sampled_from_group = group[::step]
                all_sampleable_points.extend(sampled_from_group)
        # Sample to reach target
        if len(all_sampleable_points) <= target_events:
            selected_indices = all_sampleable_points
        else:
            step = len(all_sampleable_points) // target_events
            selected_indices = all_sampleable_points[::step][:target_events]
        # Create events
        for idx in selected_indices:
            events.append({
                'time_s': float(self.times[idx]),
                'frame': int(idx),
                'type': 'In-Phase Event'
            })
        return events[:target_events]

    def _group_indices(self, indices, gap):
        """Group nearby indices"""
        if len(indices) == 0:
            return []
        groups = []
        current_group = [indices[0]]
        for idx in indices[1:]:
            if idx - current_group[-1] <= gap:
                current_group.append(idx)
            else:
                groups.append(np.array(current_group))
                current_group = [idx]
        groups.append(np.array(current_group))
        return groups

    def _print_enhanced_analysis_summary(self, in_phase_current, anti_phase_current, all_events):
        """Enhanced analysis summary with event type breakdown"""
        in_phase_pct = np.mean(in_phase_current) * 100
        anti_phase_pct = np.mean(anti_phase_current) * 100
        print(f"\n📊 {self.subject_id} ENHANCED COORDINATION ANALYSIS")
        print("—" * 60)
        print(f"Total events detected: {len(all_events)}")
        if all_events:
            event_types = {}
            for event in all_events:
                event_type = event['type']
                event_types[event_type] = event_types.get(event_type, 0) + 1
            print(f"Event type distribution:")
            for event_type, count in event_types.items():
                print(f"  {event_type}: {count}")
        print(f"In-phase time: {in_phase_pct:.1f}%")
        print(f"Anti-phase time: {anti_phase_pct:.1f}%")

    def create_plots(self, output_dir="plots", plot_duration=60.0):
        """Generate all standard plots (keeping original functionality)"""
        output_dir = Path(output_dir)
        output_dir.mkdir(exist_ok=True)
        # Main coordination analysis plot
        self._plot_coordination_overview(output_dir, plot_duration)
        # Variability plot
        self._plot_variability(output_dir, plot_duration)
        return self

    def _plot_coordination_overview(self, output_dir, plot_duration):
        """Create main coordination analysis plot"""
        mask_plot = self.times <= plot_duration
        t_plot = self.times[mask_plot]
        # Filter events for plotting
        events_plot = self.event_df[self.event_df['time_s'] <= plot_duration] if not self.event_df.empty else pd.DataFrame()
        colors = {
            'right_hand': "#9121B4", 'left_hand': "#4446D6", 'breakdown': "#E60000",
            'in_phase': "#4446D6", 'anti_phase': "#D81049", 'background_grid': "#FDFDFD"
        }
        fig, axes = plt.subplots(3, 1, figsize=(14, 10))
        # 1. Position
        ax = axes[0]
        ax.plot(t_plot, self.right_x_smooth[mask_plot], '-', color=colors['right_hand'], label='Right Hand', linewidth=2)
        ax.plot(t_plot, self.left_x_smooth[mask_plot], '-', color=colors['left_hand'], label='Left Hand', linewidth=2)
        self._add_event_lines(ax, events_plot, colors)
        ax.set_ylabel("Position (px)", fontweight='bold')
        ax.set_title(f"{self.subject_id} - Hand Positions", fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)
        # 2. Coordination Index
        ax = axes[1]
        in_phase_plot = np.abs(self.phase_diff_wrapped[mask_plot]) < self.config['anti_phase_threshold_rad']
        anti_phase_plot = np.abs(self.phase_diff_wrapped[mask_plot]) >= self.config['anti_phase_threshold_rad']
        ax.plot(t_plot, self.coord_smooth[mask_plot], '-', color='black', linewidth=2, label='Coordination Index')
        ax.fill_between(t_plot, -1, 1, where=in_phase_plot, alpha=0.3, color=colors['in_phase'], label='In-Phase')
        ax.fill_between(t_plot, -1, 1, where=anti_phase_plot, alpha=0.3, color=colors['anti_phase'], label='Anti-Phase')
        ax.axhline(0, color='gray', linestyle=':', alpha=0.6)
        self._add_event_lines(ax, events_plot, colors)
        ax.set_ylabel("cos(Δφ)", fontweight='bold')
        ax.set_title("Coordination Index", fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)
        # 3. Coordination Variability
        ax = axes[2]
        ax.plot(t_plot, self.std_coord[mask_plot], '-', color=colors['in_phase'], linewidth=2, label='Coordination Std')
        ax.axhline(self.config['breakdown_std_threshold'], color=colors['breakdown'], linestyle=':', label='Breakdown Threshold')
        self._add_event_lines(ax, events_plot, colors)
        ax.set_ylabel("Std Dev", fontweight='bold')
        ax.set_xlabel("Time (s)", fontweight='bold')
        ax.set_title("Coordination Variability", fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(output_dir / f"{self.subject_id}_coordination_analysis.png", dpi=300, bbox_inches='tight')
        plt.close()
        print(f"✅ Saved coordination analysis plot: {self.subject_id}_coordination_analysis.png")

    def _plot_variability(self, output_dir, plot_duration):
        """Create enhanced variability plot with all event types"""
        mask_plot = self.times <= plot_duration
        t_plot = self.times[mask_plot]
        events_plot = self.event_df[self.event_df['time_s'] <= plot_duration] if not self.event_df.empty else pd.DataFrame()
        plt.figure(figsize=(14, 5))
        plt.plot(t_plot, self.std_coord[mask_plot], color='indigo', linewidth=2.5, alpha=0.8, label='Coordination Variability')
        plt.axhline(self.config['breakdown_std_threshold'], color='black', linestyle=':', alpha=0.7, label='Instability Threshold')
        # Add event lines with enhanced colors
        event_colors = {
            'In-Phase Breakdown': 'blue',
            'Anti-Phase Event': 'red',
            'In-Phase Event': 'green'  # NEW
        }
        for _, event in events_plot.iterrows():
            color = event_colors.get(event['type'], 'black')
            plt.axvline(x=event['time_s'], color=color, linestyle='--', linewidth=2.5, alpha=0.8)
        # Enhanced legend
        legend_elements = [
            Line2D([0], [0], color='indigo', lw=2.5, label='Coordination Variability'),
            Line2D([0], [0], color='black', lw=1.5, linestyle=':', label='Instability Threshold'),
            Line2D([0], [0], color='blue', lw=2.5, linestyle='--', label='In-Phase Breakdown'),
            Line2D([0], [0], color='red', lw=2.5, linestyle='--', label='Anti-Phase Event'),
            Line2D([0], [0], color='green', lw=2.5, linestyle='--', label='In-Phase Event')
        ]
        plt.ylabel("Std Dev of Coordination Index", fontweight='bold')
        plt.xlabel("Time (s)", fontweight='bold')
        plt.title(f"{self.subject_id} - Enhanced Coordination Variability (0 – {plot_duration:.0f}s)", fontweight='bold')
        plt.legend(handles=legend_elements, loc='upper right')
        plt.grid(True, alpha=0.4)
        plt.ylim(0, np.max(self.std_coord[mask_plot]) * 1.1 if len(self.std_coord[mask_plot]) > 0 else 1.0)
        plt.tight_layout()
        plt.savefig(output_dir / f"{self.subject_id}_enhanced_variability.png", dpi=300, bbox_inches='tight')
        plt.close()
        print(f"✅ Saved enhanced variability plot: {self.subject_id}_enhanced_variability.png")

    def _add_event_lines(self, ax, events_df, colors):
        """Add event lines to plot with enhanced event types"""
        if events_df.empty:
            return
        # Enhanced event type handling
        event_type_colors = {
            'In-Phase Breakdown': colors['breakdown'],
            'Anti-Phase Event': colors['anti_phase'],
            'In-Phase Event': 'green'
        }
        # Group events by type to avoid duplicate labels
        plotted_types = set()
        for _, event in events_df.iterrows():
            event_type = event['type']
            color = event_type_colors.get(event_type, 'black')
            linestyle = '--' if 'Breakdown' in event_type else '-.' if 'Anti-Phase' in event_type else ':'
            label = event_type if event_type not in plotted_types else ""
            plotted_types.add(event_type)
            ax.axvline(event['time_s'], color=color, linestyle=linestyle, alpha=0.7, label=label)


# === Enhanced EEGTrialTracker Class ===
class EEGTrialTracker:
    """Enhanced EEG trial detection and alignment with improved stim_channel handling"""
    def __init__(self, subject_id, base_path, mff_filename, config=None):
        self.subject_id = subject_id
        self.base_path = Path(base_path)
        self.mff_filename = mff_filename
        self.config = config or {}
        self.raw = None

    def load_and_align(self, analyzer):
        """Load EEG data and align with behavioral events"""
        try:
            # Load EEG data
            mff_path = self.base_path / self.mff_filename
            if not mff_path.exists():
                print(f"⚠️ EEG file not found: {mff_path}")
                return analyzer
            self.raw = mne.io.read_raw_egi(str(mff_path), preload=True, verbose=False)
            self._setup_channels()
            # Get trials
            trials_df = self._get_trials_from_pacing()
            if trials_df.empty:
                print(f"⚠️ No trials found for {self.subject_id}")
                return analyzer
            # Align events
            analyzer = self._align_behavioral_events(analyzer, trials_df)
            analyzer.trials_df = trials_df
            print(f"✅ EEG alignment complete for {self.subject_id}")
            return analyzer
        except Exception as e:
            print(f"⚠️ EEG alignment failed for {self.subject_id}: {e}")
            return analyzer

    def _setup_channels(self):
        """Setup channel names and types"""
        rename_dict = {str(i): f'E{i}' for i in range(1, 281)}
        rename_dict['REF CZ'] = 'Cz'
        self.raw.rename_channels(rename_dict)
        stim_channels = self.config.get('stim_channels', ['TT140', 'TT255', '1a', '2a', '3a', '4a', '5a', '6a'])
        existing_stim = [ch for ch in stim_channels if ch in self.raw.ch_names]
        self.raw.set_channel_types({ch: 'stim' for ch in existing_stim})

    def _get_trials_from_pacing(self):
        """Extract trials from pacing channels with enhanced detection"""
        pacing_channels = self.config.get('pacing_channels', ['1a', '2a', '3a', '4a', '5a', '6a'])
        picks = [ch for ch in pacing_channels if ch in self.raw.ch_names]
        if not picks:
            return pd.DataFrame()
        # Extract all pacing events
        all_onsets = []
        for ch_name in picks:
            ch_idx = self.raw.ch_names.index(ch_name)
            ch_data = self.raw.get_data(picks=[ch_idx])[0]
            digital = (ch_data != 0).astype(int)
            starts = np.where(np.diff(digital, prepend=0) == 1)[0]
            event_times = self.raw.times[starts] + self.config.get('eeg_to_behavior_delay', 0.0)
            for t in event_times:
                all_onsets.append({'channel': ch_name, 'eeg_time_s': t})
        if not all_onsets:
            return pd.DataFrame()
        events_df = pd.DataFrame(all_onsets).sort_values('eeg_time_s').reset_index(drop=True)
        # Define trials based on gaps
        return self._define_trials_from_gaps(events_df)

    def _define_trials_from_gaps(self, events_df):
        """Enhanced trial definition with improved gap detection"""
        gap_threshold = self.config.get('trial_gap_threshold_s', 2.5)
        extension_duration = self.config.get('extension_duration_s', 4.3)
        times = events_df['eeg_time_s'].values
        intervals = np.diff(times)
        gap_indices = np.where(intervals > gap_threshold)[0]
        # Define block boundaries
        boundaries = []
        start_idx = 0
        for gap_idx in gap_indices:
            boundaries.append((start_idx, gap_idx))
            start_idx = gap_idx + 1
        boundaries.append((start_idx, len(events_df) - 1))
        # Create trials with enhanced information
        trials = []
        for start_idx, end_idx in boundaries:
            block_events = events_df.iloc[start_idx:end_idx+1]
            first_1a = block_events[block_events['channel'] == '1a']['eeg_time_s'].min()
            last_6a = block_events[block_events['channel'] == '6a']['eeg_time_s'].max()
            if pd.notna(first_1a) and pd.notna(last_6a) and last_6a > first_1a:
                trials.append({
                    'trial_number': len(trials) + 1,
                    'start_time': first_1a,
                    'end_time': last_6a,
                    'end_time_extended': last_6a + extension_duration,
                    'duration_s': last_6a - first_1a,
                    'subject_id': self.subject_id
                })
        return pd.DataFrame(trials)

    def _align_behavioral_events(self, analyzer, trials_df):
        """Enhanced behavioral event alignment with improved stim_channel handling"""
        if analyzer.event_df.empty or trials_df.empty:
            return analyzer
        # Convert behavioral time to EEG time
        first_trial_start = trials_df['start_time'].iloc[0]
        analyzer.event_df['eeg_time_s'] = (analyzer.event_df['time_s'] + first_trial_start).round(3)
        # Assign to trials using extended window
        def assign_to_trial(eeg_time):
            for _, trial in trials_df.iterrows():
                if trial['start_time'] <= eeg_time < trial['end_time_extended']:
                    return trial['trial_number']
            return np.nan
        analyzer.event_df['trial_number'] = analyzer.event_df['eeg_time_s'].apply(assign_to_trial)
        # Enhanced stim_channel assignment
        print(f"🧠 Adding enhanced stim_channel info for {self.subject_id}...")
        pacing_channels = self.config.get('pacing_channels', ['1a', '2a', '3a', '4a', '5a', '6a'])
        pacing_onsets = []
        for ch_name in pacing_channels:
            if ch_name not in self.raw.ch_names:
                continue
            ch_idx = self.raw.ch_names.index(ch_name)
            ch_data = self.raw.get_data(picks=[ch_idx])[0]
            digital = (ch_data != 0).astype(int)
            transitions = np.diff(digital, prepend=0)
            starts = np.where(transitions == 1)[0]  # Rising edges
            event_times = self.raw.times[starts]
            for t in event_times:
                pacing_onsets.append({'eeg_time_s': t, 'stim_channel': ch_name})
        # Sort all pacing triggers by time
        pacing_df = pd.DataFrame(pacing_onsets).sort_values('eeg_time_s').reset_index(drop=True)
        # Enhanced function to get last stim channel before event
        def get_last_stim_channel(eeg_time_s):
            valid = pacing_df[pacing_df['eeg_time_s'] <= eeg_time_s]
            if not valid.empty:
                return valid.iloc[-1]['stim_channel']
            return np.nan
        # Apply to each event
        analyzer.event_df['stim_channel'] = analyzer.event_df['eeg_time_s'].apply(get_last_stim_channel)
        print(f"✅ Enhanced stim_channel added for {self.subject_id}")
        return analyzer

    def save_events_to_mne_format(self, analyzer, output_dir="results"):
        """NEW: Save events in MNE format for further analysis"""
        if analyzer.event_df.empty or self.raw is None:
            print(f"⚠️ Cannot save MNE events for {self.subject_id}: missing data")
            return
        try:
            # Convert events to MNE format
            events_list = []
            event_id = {}
            event_counter = 1
            for _, event in analyzer.event_df.iterrows():
                if pd.notna(event['eeg_time_s']):
                    # Find the sample index corresponding to the time
                    sample_idx = int(event['eeg_time_s'] * self.raw.info['sfreq'])
                    # Create unique event ID for each event type
                    event_type = event['type']
                    if event_type not in event_id:
                        event_id[event_type] = event_counter
                        event_counter += 1
                    events_list.append([sample_idx, 0, event_id[event_type]])
            if events_list:
                events_array = np.array(events_list)
                # Save MNE events
                output_path = Path(output_dir) / self.subject_id
                output_path.mkdir(parents=True, exist_ok=True)
                events_file = output_path / f"{self.subject_id}_events_mne.fif"
                mne.write_events(str(events_file), events_array)
                # Save event_id mapping
                event_id_file = output_path / f"{self.subject_id}_event_id.json"
                with open(event_id_file, 'w') as f:
                    json.dump(event_id, f, indent=2)
                print(f"✅ Saved MNE events: {events_file}")
                print(f"✅ Saved event ID mapping: {event_id_file}")
            else:
                print(f"⚠️ No valid events to save in MNE format for {self.subject_id}")
        except Exception as e:
            print(f"❌ Failed to save MNE events for {self.subject_id}: {e}")


# === Enhanced Processing Functions ===
def process_subject(subject_id, behavioral_file, eeg_config=None, analysis_config=None, output_dir="results"):
    """Process a single subject through the complete pipeline"""
    print(f"\n🔄 Processing {subject_id}...")
    # Initialize analyzer
    analyzer = CoordinationAnalyzer(subject_id, analysis_config)
    # Load and analyze behavioral data
    analyzer.load_behavioral_data(behavioral_file)
    analyzer.preprocess_signals()
    analyzer.analyze_coordination()
    analyzer.detect_events()
    # EEG alignment (if config provided)
    if eeg_config:
        tracker = EEGTrialTracker(subject_id, **eeg_config)
        analyzer = tracker.load_and_align(analyzer)
        # NEW: Save MNE format events
        tracker.save_events_to_mne_format(analyzer, output_dir)
    # Create plots
    output_path = Path(output_dir) / subject_id
    output_path.mkdir(parents=True, exist_ok=True)
    analyzer.create_plots(output_path)
    # Save event data
    if not analyzer.event_df.empty:
        event_file = output_path / f"{subject_id}_events.csv"
        analyzer.event_df.to_csv(event_file, index=False)
        print(f"✅ Saved events: {event_file}")
    # Save trial data (if available)
    if hasattr(analyzer, 'trials_df') and analyzer.trials_df is not None:
        trial_file = output_path / f"{subject_id}_trials.csv"
        analyzer.trials_df.to_csv(trial_file, index=False)
        print(f"✅ Saved trials: {trial_file}")
    print(f"✅ {subject_id} processing complete!")
    return analyzer

def process_multiple_subjects(subject_configs, output_dir="results"):
    """Process multiple subjects in batch"""
    results = {}
    for subject_id, config in subject_configs.items():
        try:
            analyzer = process_subject(
                subject_id=subject_id,
                behavioral_file=config['behavioral_file'],
                eeg_config=config.get('eeg_config'),
                analysis_config=config.get('analysis_config'),
                output_dir=output_dir
            )
            results[subject_id] = analyzer
        except Exception as e:
            print(f"❌ Failed to process {subject_id}: {e}")
            results[subject_id] = None
    # Combine all events
    all_events = []
    for subject_id, analyzer in results.items():
        if analyzer and not analyzer.event_df.empty:
            all_events.append(analyzer.event_df)
    if all_events:
        combined_events = pd.concat(all_events, ignore_index=True)
        combined_file = Path(output_dir) / "combined_events.csv"
        combined_events.to_csv(combined_file, index=False)
        print(f"\n✅ Saved combined events: {combined_file}")
    return results


# === Visualize Trials Functions ===
def plot_single_trial_for_subject(analyzer, trial_number, output_dir="results", show_events=True, duration_limit=50):
    """
    Enhanced trial plotting with 6 subplots including Hilbert phase analysis.
    Plots first `duration_limit` seconds of each trial based on EEG time and includes stim channel events.
    X-axis shows EEG time.

    Parameters:
    - analyzer: CoordinationAnalyzer object after processing.
    - trial_number: int, the trial number to plot.
    - output_dir: str, base directory for saving results.
    - show_events: bool, whether to mark events on the plot.
    - duration_limit: float, maximum duration to plot in seconds (default 50 seconds).
    """
    # --- 1. Check if necessary data exists ---
    if analyzer.trials_df is None or analyzer.trials_df.empty:
        print(f"⚠️ No trial data available for {analyzer.subject_id}. Skipping trial plot {trial_number}.")
        return
    if trial_number not in analyzer.trials_df['trial_number'].values:
        print(f"❌ Trial {trial_number} not found for subject {analyzer.subject_id}. Available trials: {list(analyzer.trials_df['trial_number'])}")
        return
    if not hasattr(analyzer, 'right_x_filtered') or not hasattr(analyzer, 'left_x_filtered'):
        print(f"⚠️ Filtered signals not found for {analyzer.subject_id}. Run preprocessing first. Skipping trial plot {trial_number}.")
        return
    if analyzer.event_df is None or analyzer.event_df.empty: # Check if empty too
        print(f"⚠️ Event data not found or empty for {analyzer.subject_id}. Skipping event marking for trial {trial_number}.")
        show_events = False

    # --- 2. Get trial info (EEG time base) ---
    trial = analyzer.trials_df[analyzer.trials_df['trial_number'] == trial_number].iloc[0]
    start_time_eeg = trial['start_time']  # This is in EEG time (s)
    
    # Determine the end time based on duration limit and trial extension
    # Use 'end_time_extended' if available for a potentially longer view, but still limit by duration_limit
    actual_end_time_eeg = trial.get('end_time_extended', trial['end_time'])
    end_time_eeg = min(start_time_eeg + duration_limit, actual_end_time_eeg)
    plot_duration_eeg = end_time_eeg - start_time_eeg # This is the actual duration plotted

    # --- 3. Align Behavioral Time with EEG Time ---
    # The alignment offset was calculated during EEG processing:
    # analyzer.event_df['eeg_time_s'] = (analyzer.event_df['time_s'] + first_trial_start)
    # Therefore: behavioral_time = eeg_time - first_trial_start
    # We need the `first_trial_start` value. Let's assume it's the start_time of trial 1.
    # A more robust way is to check the offset used in alignment.
    # Let's find the offset by comparing the first event's behavioral and EEG times if possible.
    # Or, we can use the trial's start_time and the corresponding behavioral time segment.
    
    # Find the index in behavioral data (`analyzer.times`) that corresponds to `start_time_eeg`
    # We need to find the behavioral time that maps to `start_time_eeg`
    # Since `eeg_time_s = behavioral_time_s + offset`, we have `offset = eeg_time_s - behavioral_time_s`
    # We can use the first event of this trial to find the offset, or assume it's consistent.
    # Let's find the first event of this trial to get the offset.
    trial_events_all = analyzer.event_df[analyzer.event_df['trial_number'] == trial_number]
    if not trial_events_all.empty:
        first_event_of_trial = trial_events_all.iloc[0]
        # offset = eeg_time_s - behavioral_time_s
        time_offset = first_event_of_trial['eeg_time_s'] - first_event_of_trial['time_s']
        # print(f"Debug: Time offset for trial {trial_number}: {time_offset}")
    else:
        # If no events, we cannot reliably determine the offset. Fallback or skip?
        # Let's try to use the trial start time. The trial start time in EEG should correspond
        # to some point in the behavioral data. This is tricky without the explicit offset.
        # Let's assume the offset is `first_trial_start` from the EEG alignment process.
        # How was `first_trial_start` determined? It was `trials_df['start_time'].iloc[0]`.
        # The behavioral data starts at t=0. So, the offset should be `first_trial_start`.
        # Let's try to get it from the first trial's start time in EEG.
        if not analyzer.trials_df.empty:
             first_trial_start_eeg = analyzer.trials_df['start_time'].iloc[0]
             # Behavioral time 0 maps to EEG time `first_trial_start_eeg`
             time_offset = first_trial_start_eeg # eeg_time = behavioral_time + time_offset
             # print(f"Debug: Using first trial start as offset: {time_offset}")
        else:
             print(f"❌ Cannot determine time alignment for trial {trial_number} of {analyzer.subject_id}. Skipping.")
             return

    # Now, convert EEG time window to behavioral time window for data extraction
    start_time_behavioral = start_time_eeg - time_offset
    end_time_behavioral = end_time_eeg - time_offset

    # --- 4. Extract data segment (using behavioral time for indexing) ---
    mask = (analyzer.times >= start_time_behavioral) & (analyzer.times <= end_time_behavioral)
    t_segment_behavioral = analyzer.times[mask] # Behavioral time axis for data
    if len(t_segment_behavioral) == 0:
        print(f"❌ No behavioral data found for trial {trial_number} of {analyzer.subject_id} (EEG: {start_time_eeg:.2f}s – {end_time_eeg:.2f}s, Behavioral: {start_time_behavioral:.2f}s – {end_time_behavioral:.2f}s)")
        return

    # Convert behavioral time axis to EEG time axis for plotting
    t_segment_eeg = t_segment_behavioral + time_offset # This is the correct EEG time axis for the x-axis

    left_x_segment = analyzer.left_x_filtered[mask]
    right_x_segment = analyzer.right_x_filtered[mask]
    coord_smooth_segment = analyzer.coord_smooth[mask]
    std_coord_segment = analyzer.std_coord[mask]

    # --- 5. Enhanced Hilbert Transform Analysis ---
    try:
        # Left hand analysis
        analytic_left = hilbert(left_x_segment)
        phase_wrapped_left = np.angle(analytic_left)
        phase_unwrapped_left = np.unwrap(phase_wrapped_left)
        amplitude_left = np.abs(analytic_left)
        # Right hand analysis
        analytic_right = hilbert(right_x_segment)
        phase_wrapped_right = np.angle(analytic_right)
        phase_unwrapped_right = np.unwrap(phase_wrapped_right)
        amplitude_right = np.abs(analytic_right)
        # Relative phase analysis
        rp_unwrap = phase_unwrapped_right - phase_unwrapped_left
        rp_normalized = np.angle(np.exp(1j * rp_unwrap))  # Wraps to [-π, π]
    except Exception as e:
        print(f"❌ Error during Hilbert transform for trial {trial_number} of {analyzer.subject_id}: {e}")
        return

    # --- 6. Get Events for this Trial within the specified EEG time window ---
    trial_events_filtered = pd.DataFrame()
    if show_events and not analyzer.event_df.empty:
        # Filter events: must belong to this trial AND fall within the EEG time window of the plot
        trial_events_filtered = analyzer.event_df[
            (analyzer.event_df['trial_number'] == trial_number) &
            (analyzer.event_df['eeg_time_s'] >= start_time_eeg) &
            (analyzer.event_df['eeg_time_s'] <= end_time_eeg)
        ].copy()

        if not trial_events_filtered.empty:
            print(f"   📊 Found {len(trial_events_filtered)} events for trial {trial_number} within EEG time {start_time_eeg:.2f}s - {end_time_eeg:.2f}s")
            if 'stim_channel' in trial_events_filtered.columns:
                stim_channels = trial_events_filtered['stim_channel'].dropna().unique() # Drop NaN if any
                if len(stim_channels) > 0:
                     print(f"   🎛️ Stim channels in this segment: {list(stim_channels)}")
                else:
                     print(f"   ⚠️ No valid stim channels found for events in this segment.")
        else:
             print(f"   🟢 No events found for trial {trial_number} within the specified EEG time window ({start_time_eeg:.2f}s - {end_time_eeg:.2f}s).")


    # === 7. ENHANCED PLOT WITH 6 SUBPLOTS (X-axis: EEG Time) ===
    fig, axes = plt.subplots(6, 1, figsize=(16, 14), sharex=True)

    # Define colors
    colors = {
        'right_hand': '#9121B4',
        'left_hand': '#4446D6',
        'breakdown': '#E60000',
        'in_phase': '#4446D6',
        'anti_phase': '#D81049'
    }
    stim_channel_colors = {
        '1a': '#FF6B6B',
        '2a': '#4ECDC4',
        '3a': '#45B7D1',
        '4a': '#96CEB4',
        '5a': '#FFEAA7',
        '6a': '#DDA0DD',
        'default': '#FF8C00'
    }

    # --- 8. Plotting (using EEG time axis) ---
    # 1. Position with Amplitude Envelope
    ax = axes[0]
    ax.plot(t_segment_eeg, right_x_segment, color=colors['right_hand'], label='Right Hand', lw=2)
    ax.plot(t_segment_eeg, left_x_segment, color=colors['left_hand'], label='Left Hand', lw=2)
    ax.plot(t_segment_eeg, right_x_segment + amplitude_right, color=colors['right_hand'], alpha=0.3, lw=1)
    ax.plot(t_segment_eeg, right_x_segment - amplitude_right, color=colors['right_hand'], alpha=0.3, lw=1)
    ax.fill_between(t_segment_eeg, right_x_segment - amplitude_right, right_x_segment + amplitude_right,
                    color=colors['right_hand'], alpha=0.1)
    # Add events
    for _, e in trial_events_filtered.iterrows():
        stim_channel = e.get('stim_channel', 'default')
        event_type = e.get('type', '')
        color = stim_channel_colors.get(stim_channel, stim_channel_colors['default'])
        linestyle = '--'
        if 'Breakdown' in event_type:
            linestyle = '-.'
        elif 'Anti-Phase' in event_type:
            linestyle = ':'
        ax.axvline(e['eeg_time_s'], color=color, ls=linestyle, lw=1.5, alpha=0.8)
        ax.text(e['eeg_time_s'], ax.get_ylim()[1], f"{stim_channel}",
                rotation=90, verticalalignment='bottom',
                fontsize=8, color=color, alpha=0.7)
    ax.set_ylabel("Position (px)", fontweight='bold')
    ax.set_title(f"Subject {analyzer.subject_id} - Trial {trial_number}: First {plot_duration_eeg:.1f}s (EEG Time) ({len(trial_events_filtered)} events)", fontweight='bold')
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)

    # 2. Wrapped Phase
    ax = axes[1]
    ax.plot(t_segment_eeg, phase_wrapped_left, color=colors['left_hand'], label='Left Hand', lw=2)
    ax.plot(t_segment_eeg, phase_wrapped_right, color=colors['right_hand'], label='Right Hand', lw=2)
    ax.axhline(0, color='k', ls=':', alpha=0.6)
    # Add events
    for _, e in trial_events_filtered.iterrows():
        stim_channel = e.get('stim_channel', 'default')
        color = stim_channel_colors.get(stim_channel, stim_channel_colors['default'])
        linestyle = '--'
        if 'Breakdown' in e.get('type', ''):
            linestyle = '-.'
        elif 'Anti-Phase' in e.get('type', ''):
            linestyle = ':'
        ax.axvline(e['eeg_time_s'], color=color, ls=linestyle, lw=1.0, alpha=0.6)
    ax.set_ylabel("Wrapped Phase (rad)", fontweight='bold')
    ax.set_title("Instantaneous Phase (Wrapped) - EEG Time", fontweight='bold')
    ax.set_ylim(-np.pi - 0.2, np.pi + 0.2)
    ax.set_yticks([-np.pi, -np.pi/2, 0, np.pi/2, np.pi])
    ax.set_yticklabels([r'$-\pi$', r'$-\pi/2$', '0', r'$\pi/2$', r'$\pi$'])
    ax.legend()
    ax.grid(True, alpha=0.3)

    # 3. Unwrapped Phase
    ax = axes[2]
    ax.plot(t_segment_eeg, phase_unwrapped_left, color=colors['left_hand'], label='Left Hand', lw=2)
    ax.plot(t_segment_eeg, phase_unwrapped_right, color=colors['right_hand'], label='Right Hand', lw=2)
    # Add events
    for _, e in trial_events_filtered.iterrows():
        stim_channel = e.get('stim_channel', 'default')
        color = stim_channel_colors.get(stim_channel, stim_channel_colors['default'])
        ax.axvline(e['eeg_time_s'], color=color, ls='--', lw=1.0, alpha=0.6)
    ax.set_ylabel("Unwrapped Phase (rad)", fontweight='bold')
    ax.set_title("Instantaneous Phase (Unwrapped) - EEG Time", fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # 4. Relative Phase
    ax = axes[3]
    ax.plot(t_segment_eeg, rp_normalized, color='indigo', lw=2.5, label='Relative Phase (Right - Left)')
    ax.axhspan(-np.pi/6, np.pi/6, alpha=0.2, color='green', label='In-Phase Region')
    ax.axhspan(5*np.pi/6, np.pi, alpha=0.2, color='red', label='Anti-Phase Region')
    ax.axhspan(-np.pi, -5*np.pi/6, alpha=0.2, color='red')
    ax.axhline(0, color='green', ls='--', lw=1.5, alpha=0.7)
    ax.axhline(np.pi, color='red', ls='--', lw=1.5, alpha=0.7)
    ax.axhline(-np.pi, color='red', ls='--', lw=1.5, alpha=0.7)
    # Add events
    for _, e in trial_events_filtered.iterrows():
        stim_channel = e.get('stim_channel', 'default')
        color = stim_channel_colors.get(stim_channel, stim_channel_colors['default'])
        ax.axvline(e['eeg_time_s'], color=color, ls='--', lw=1.5, alpha=0.8)
    ax.set_ylim(-np.pi - 0.2, np.pi + 0.2)
    ax.set_yticks([-np.pi, -np.pi/2, 0, np.pi/2, np.pi])
    ax.set_yticklabels([r'$-\pi$', r'$-\pi/2$', '0', r'$\pi/2$', r'$\pi$'])
    ax.set_ylabel("Relative Phase (rad)", fontweight='bold')
    ax.set_title("Enhanced Relative Phase Dynamics - EEG Time", fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # 5. Coordination Index
    ax = axes[4]
    ax.plot(t_segment_eeg, coord_smooth_segment, color='black', lw=2, label='Coordination Index')
    ax.axhline(0, color='gray', linestyle=':', alpha=0.6)
    ax.axhline(np.cos(analyzer.config['anti_phase_threshold_rad']), color='red',
              linestyle=':', alpha=0.7, label='Anti-Phase Threshold')
    in_phase_mask = coord_smooth_segment > np.cos(analyzer.config['anti_phase_threshold_rad'])
    ax.fill_between(t_segment_eeg, -1, 1, where=in_phase_mask, alpha=0.2, color='blue', label='In-Phase')
    ax.fill_between(t_segment_eeg, -1, 1, where=~in_phase_mask, alpha=0.2, color='red', label='Anti-Phase')
    # Add events
    for _, e in trial_events_filtered.iterrows():
        stim_channel = e.get('stim_channel', 'default')
        color = stim_channel_colors.get(stim_channel, stim_channel_colors['default'])
        ax.axvline(e['eeg_time_s'], color=color, ls='--', lw=1.0, alpha=0.6)
    ax.set_ylabel("cos(Δφ)", fontweight='bold')
    ax.set_title("Coordination Index - EEG Time", fontweight='bold')
    ax.set_ylim(-1.1, 1.1)
    ax.legend()
    ax.grid(True, alpha=0.3)

    # 6. Coordination Variability
    ax = axes[5]
    ax.plot(t_segment_eeg, std_coord_segment, color='purple', lw=2, label='Coordination Variability')
    ax.axhline(analyzer.config['breakdown_std_threshold'], color='red',
              linestyle=':', alpha=0.7, label='Breakdown Threshold')
    high_var_mask = std_coord_segment > analyzer.config['breakdown_std_threshold']
    ax.fill_between(t_segment_eeg, 0, np.max(std_coord_segment) if len(std_coord_segment) > 0 else 1,
                   where=high_var_mask, alpha=0.3, color='red', label='High Variability')
    # Add events
    for _, e in trial_events_filtered.iterrows():
        stim_channel = e.get('stim_channel', 'default')
        color = stim_channel_colors.get(stim_channel, stim_channel_colors['default'])
        ax.axvline(e['eeg_time_s'], color=color, ls='--', lw=1.5, alpha=0.8)
    ax.set_ylabel("Std Dev", fontweight='bold')
    ax.set_xlabel("Time (s) - EEG Time", fontweight='bold')
    ax.set_title("Coordination Variability - EEG Time", fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Set x-axis limits explicitly to match the intended window
    ax.set_xlim(start_time_eeg, end_time_eeg)

    plt.tight_layout()

    # --- 9. Save plot ---
    subject_output_dir = Path(output_dir) / analyzer.subject_id
    subject_output_dir.mkdir(parents=True, exist_ok=True)
    # Use the actual plotted duration in the filename
    plot_filename = subject_output_dir / f"{analyzer.subject_id}_trial_{trial_number:02d}_eeg_time_{int(plot_duration_eeg)}s.png"
    plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
    plt.close()

    # --- 10. Print Summary ---
    print(f"📈 Enhanced trial plot saved: {analyzer.subject_id}, Trial {trial_number} | EEG Duration: {plot_duration_eeg:.2f}s")
    if len(trial_events_filtered) > 0:
        print(f"   🔴 Events in plotted segment: {len(trial_events_filtered)}")
        if 'stim_channel' in trial_events_filtered.columns:
            stim_summary = trial_events_filtered['stim_channel'].value_counts().to_dict()
            for stim, count in stim_summary.items():
                print(f"     • Stim {stim}: {count} events")
    else:
        print(f"   🟢 No events in this plotted EEG time segment ({start_time_eeg:.2f}s - {end_time_eeg:.2f}s)")


def plot_all_trials_for_subject(analyzer, output_dir="results", max_trials=None, duration_limit=50):
    """Plot first `duration_limit` seconds of all trials for a given subject using enhanced plotting with EEG time"""
    if analyzer.trials_df is None or analyzer.trials_df.empty:
        print(f"⚠️ No trial data available for {analyzer.subject_id}. Skipping all trial plots.")
        return

    trial_numbers_to_plot = analyzer.trials_df['trial_number'].tolist()
    if max_trials is not None:
        trial_numbers_to_plot = trial_numbers_to_plot[:max_trials]

    print(f"\n🔄 Plotting first {duration_limit}s (EEG time) of {len(trial_numbers_to_plot)} trial(s) for {analyzer.subject_id}...")
    for trial_num in trial_numbers_to_plot:
        plot_single_trial_for_subject(analyzer, trial_num, output_dir=output_dir, duration_limit=duration_limit)

    print(f"✅ Finished plotting enhanced trials for {analyzer.subject_id} (EEG Time)")



def save_subject_event_df(analyzer, output_dir="results"):
    """Save the enhanced event_df of a subject to CSV"""
    if analyzer.event_df is not None and not analyzer.event_df.empty:
        subject_output_dir = Path(output_dir) / analyzer.subject_id
        subject_output_dir.mkdir(parents=True, exist_ok=True)
        event_file = subject_output_dir / f"{analyzer.subject_id}_events_enhanced.csv"
        analyzer.event_df.to_csv(event_file, index=False)
        print(f"💾 Saved enhanced event data: {event_file}")
    else:
        print(f"⚠️ No event data to save for {analyzer.subject_id}")


def save_subject_event_df(analyzer, output_dir="results"):
    """Save the enhanced event_df of a subject to CSV"""
    if analyzer.event_df is not None and not analyzer.event_df.empty:
        subject_output_dir = Path(output_dir) / analyzer.subject_id
        subject_output_dir.mkdir(parents=True, exist_ok=True)
        event_file = subject_output_dir / f"{analyzer.subject_id}_events_enhanced.csv"
        analyzer.event_df.to_csv(event_file, index=False)
        print(f"💾 Saved enhanced event data: {event_file}")
    else:
        print(f"⚠️ No event data to save for {analyzer.subject_id}")


def create_binary_mne_events_for_subject(subject_id, base_results_path, binary_suffix="_binary"):
    """
    Creates binary MNE event files for a single subject, replicating the logic
    from the batch BINARY MNE script.

    This function takes the multi-class event files generated by the main analysis
    ('In-Phase Event', 'Anti-Phase Event', 'In-Phase Breakdown') and creates new
    binary event files where:
    - 'In-Phase Event' -> 0 (In-Phase)
    - 'Anti-Phase Event' + 'In-Phase Breakdown' -> 1 (Out-of-Phase)

    Parameters:
    -----------
    subject_id : str
        The ID of the subject (e.g., 'Sbj01').
    base_results_path : str or Path
        The path to the main results directory containing subject subfolders.
    binary_suffix : str, optional
        The suffix to append to the new binary files (default is "_binary").
        This creates files like 'Sbj01_events_mne_binary-eve.fif'.

    Returns:
    --------
    bool
        True if the process was successful, False otherwise.
    """
    import pandas as pd
    import mne
    import json
    import numpy as np
    from pathlib import Path

    base_results_path = Path(base_results_path)
    subject_folder = base_results_path / subject_id

    print(f"\n--- Creating Binary MNE Events for Subject: {subject_id} ---")

    # --- 1. Define File Paths for this Subject ---
    trials_path = subject_folder / f"{subject_id}_trials.csv"
    events_path = subject_folder / f"{subject_id}_events.csv"
    # Try the enhanced version if the standard one doesn't exist
    if not events_path.exists():
        events_path = subject_folder / f"{subject_id}_events_enhanced.csv"

    original_mne_events_path = subject_folder / f"{subject_id}_events_mne.fif"
    original_event_id_path = subject_folder / f"{subject_id}_event_id.json"

    # --- 2. Check if Required Files Exist ---
    required_files = [trials_path, events_path, original_mne_events_path, original_event_id_path]
    missing_files = [f for f in required_files if not f.exists()]

    if missing_files:
        print(f"  ⚠️ Skipping {subject_id}. Missing files: {[f.name for f in missing_files]}")
        return False # Indicate failure

    # --- 3. Load Data for this Subject ---
    try:
        # Load DataFrames (optional for verification, but good to have)
        # trials_df = pd.read_csv(trials_path) # Not strictly necessary for this function
        # events_df = pd.read_csv(events_path) # Not strictly necessary for this function
        # print(f"  ✅ Loaded CSV data.")

        # Load original MNE Events and ID mapping
        original_mne_events = mne.read_events(str(original_mne_events_path))
        with open(original_event_id_path, 'r') as f:
            original_event_id_mapping = json.load(f)
        print(f"  ✅ Loaded original MNE events ({original_mne_events.shape}) and ID mapping.")
        print(f"     Original ID Mapping: {original_event_id_mapping}")

    except Exception as e:
        print(f"  ⚠️ Error loading data for {subject_id}: {e}")
        return False

    # --- 4. Prepare for Binary Recoding ---
    # Define the mapping from ORIGINAL event IDs to NEW binary IDs
    # Goal: 0: In-Phase, 1: Out-of-Phase (Anti-Phase + Breakdown)
    id_transformation_map = {}

    # Find the ID for 'In-Phase Event' and map it to 0
    in_phase_id = original_event_id_mapping.get('In-Phase Event')
    if in_phase_id is not None:
        id_transformation_map[in_phase_id] = 0
    else:
        print(f"  ⚠️ 'In-Phase Event' not found in original mapping for {subject_id}. Skipping recoding.")
        return False

    # Find IDs for 'Anti-Phase Event' and 'In-Phase Breakdown' and map them to 1
    anti_phase_id = original_event_id_mapping.get('Anti-Phase Event')
    if anti_phase_id is not None:
        id_transformation_map[anti_phase_id] = 1

    breakdown_id = original_event_id_mapping.get('In-Phase Breakdown')
    if breakdown_id is not None:
        id_transformation_map[breakdown_id] = 1

    if len(id_transformation_map) < 2: # Need at least In-Phase ID and one Out-of-Phase ID
        print(f"  ⚠️ Incomplete original mapping for {subject_id}. Found IDs: {id_transformation_map}. Skipping recoding.")
        return False

    print(f"  🔄 ID Transformation Map: {id_transformation_map}")

    # --- 5. Modify the Existing MNE Events Array ---
    try:
        binary_mne_events_array = original_mne_events.copy()
        original_ids = binary_mne_events_array[:, 2]

        # Recode IDs
        unique_original_ids_in_data = np.unique(original_ids)
        print(f"  📊 Unique original IDs in data: {unique_original_ids_in_data}")

        for orig_id in unique_original_ids_in_data:
            new_id = id_transformation_map.get(orig_id)
            if new_id is not None:
                indices_to_change = np.where(original_ids == orig_id)[0]
                binary_mne_events_array[indices_to_change, 2] = new_id
                # Get original type name for reporting
                old_type = [k for k, v in original_event_id_mapping.items() if v == orig_id][0]
                new_type = "In-Phase" if new_id == 0 else "Out-of-Phase"
                print(f"    Recoded {len(indices_to_change)} events: {old_type} (ID {orig_id}) -> {new_type} (ID {new_id})")
            else:
                print(f"    ⚠️ No mapping for original ID {orig_id}. Events unchanged.")

        print(f"  ✅ Modified MNE events array. Shape: {binary_mne_events_array.shape}")

    except Exception as e:
        print(f"  ⚠️ Error modifying MNE events for {subject_id}: {e}")
        return False

    # --- 6. Define New Event ID Mapping ---
    new_mne_event_id = {
        'In-Phase': 0,
        'Out-of-Phase': 1 # This now includes both Anti-Phase Events and Breakdowns
    }
    print(f"  🆕 New MNE Event ID Mapping: {new_mne_event_id}")

    # --- 7. Save New Binary Files ---
    try:
        # New filenames (using MNE's recommended naming convention)
        binary_events_fif_path = subject_folder / f"{subject_id}_events_mne{binary_suffix}-eve.fif"
        binary_events_txt_path = subject_folder / f"{subject_id}_events_mne{binary_suffix}-eve.txt" # Optional
        new_event_id_json_path = subject_folder / f"{subject_id}_event_id{binary_suffix}.json"

        # Save the BINARY MNE events (binary .fif - Correct Format)
        mne.write_events(str(binary_events_fif_path), binary_mne_events_array, overwrite=True)
        print(f"  💾 Saved BINARY MNE events (.fif): {binary_events_fif_path.name}")

        # Save the BINARY MNE events (text .txt - Optional)
        mne.write_events(str(binary_events_txt_path), binary_mne_events_array, overwrite=True)
        print(f"  💾 Saved BINARY MNE events (.txt): {binary_events_txt_path.name}")

        # Save the NEW event ID mapping (JSON)
        with open(new_event_id_json_path, 'w') as f:
            json.dump(new_mne_event_id, f, indent=2)
        print(f"  💾 Saved NEW event ID mapping: {new_event_id_json_path.name}")

        print(f"  ✅ Successfully processed and saved binary files for {subject_id}.")
        return True # Indicate success

    except Exception as e:
        print(f"  ⚠️ Error saving new files for {subject_id}: {e}")
        return False # Indicate failure


In [12]:
# --- Configuration for Single Subject ---
# 1. SET YOUR MAIN DATA DIRECTORY
data_dir = Path(r"C:\Users\lacom\Downloads\xtra\data\PD_Keypoints")  # <-- CHANGE THIS!

# 2. SPECIFY THE SUBJECT YOU WANT TO PROCESS (e.g., "Sbj01", "Sbj02")
target_subject_id = "Sbj06"  # <-- CHANGE THIS TO THE SUBJECT YOU WANT TO RUN!

# 3. OUTPUT DIRECTORY NAME
output_dir_main = "results_Enhanced_PD_Analysis"

# --- Find Files for the Target Subject ---
keypoint_files = list(data_dir.glob(f"{target_subject_id}_task_hand_keypoints_cam0.json"))
if not keypoint_files:
    raise FileNotFoundError(f"Keypoint file for {target_subject_id} not found in {data_dir}")
behavioral_file = str(keypoint_files[0])

# Find corresponding EEG directory
eeg_mff_dirs = [p for p in data_dir.iterdir() if p.is_dir() and p.name.startswith(f"PD_{int(target_subject_id.replace('Sbj', '')):03d}_")]
eeg_config = None
if eeg_mff_dirs:
    eeg_mff_path = eeg_mff_dirs[0]
    eeg_config = {
        "base_path": str(data_dir),
        "mff_filename": eeg_mff_path.name,
        "config": { # Nesting config as expected by EEGTrialTracker
            'trial_gap_threshold_s': 2.5,
            'extension_duration_s': 4.3,
            'eeg_to_behavior_delay': 0.1,
            'stim_channels': ['TT140', 'TT255', '1a', '2a', '3a', '4a', '5a', '6a'],
            'pacing_channels': ['1a', '2a', '3a', '4a', '5a', '6a']
        }
    }
    print(f"✅ Found EEG file for {target_subject_id}: {eeg_mff_path.name}")
else:
    print(f"⚠️ No matching EEG file found for {target_subject_id}")

# Analysis configuration
analysis_config = {
    'breakdown_std_threshold': 0.60,
    'target_in_phase_events': 280,
    'anti_phase_threshold_rad': 5 * np.pi / 6,
    'keypoint_index': 8,
    'sample_rate': 60.0,
    'window_sec': 0.5,
    'window_size_ratio': 0.5,
    'initial_window_ratio': 1.0,
    'grouping_window_ratio': 0.5,
    'filter_lowcut': 0.5,
    'filter_highcut': 10.0
}

# --- STEP 1: Initialize Analyzer ---
print(f"\n--- STEP 1: Initializing Analyzer for {target_subject_id} ---")
analyzer = CoordinationAnalyzer(target_subject_id, analysis_config)
print("✅ Analyzer initialized.")

# --- STEP 2: Load Behavioral Data ---
print(f"\n--- STEP 2: Loading Behavioral Data for {target_subject_id} ---")
analyzer.load_behavioral_data(behavioral_file)
print("✅ Behavioral data loaded.")

# --- STEP 3: Preprocess Signals ---
print(f"\n--- STEP 3: Preprocessing Signals for {target_subject_id} ---")
analyzer.preprocess_signals()
print("✅ Signals preprocessed.")

# --- STEP 4: Analyze Coordination ---
print(f"\n--- STEP 4: Analyzing Coordination for {target_subject_id} ---")
analyzer.analyze_coordination()
print("✅ Coordination analyzed.")

# --- STEP 5: Detect Events ---
print(f"\n--- STEP 5: Detecting Events for {target_subject_id} ---")
analyzer.detect_events()
print("✅ Events detected.")

# --- STOPPING POINT ---
print(f"\n--- STOPPING POINT: Inspect and Edit analyzer.event_df ---")
print(f"Location of event_df: analyzer.event_df")
print(f"Shape of event_df: {analyzer.event_df.shape}")
print(f"Columns: {list(analyzer.event_df.columns)}")




✅ Found EEG file for Sbj06: PD_006_bima_DBSOFF.mff

--- STEP 1: Initializing Analyzer for Sbj06 ---
✅ Analyzer initialized.

--- STEP 2: Loading Behavioral Data for Sbj06 ---
Loading behavioral data for Sbj06...
✓ Loaded 31491 frames, duration: 524.8s
✅ Behavioral data loaded.

--- STEP 3: Preprocessing Signals for Sbj06 ---
✅ Signals preprocessed.

--- STEP 4: Analyzing Coordination for Sbj06 ---
✅ Coordination analyzed.

--- STEP 5: Detecting Events for Sbj06 ---

📊 Sbj06 ENHANCED COORDINATION ANALYSIS
————————————————————————————————————————————————————————————
Total events detected: 363
Event type distribution:
  In-Phase Event: 187
  In-Phase Breakdown: 65
  Anti-Phase Event: 111
In-phase time: 90.5%
Anti-phase time: 9.5%
✅ Events detected.

--- STOPPING POINT: Inspect and Edit analyzer.event_df ---
Location of event_df: analyzer.event_df
Shape of event_df: (363, 4)
Columns: ['time_s', 'frame', 'type', 'subject_id']


In [13]:
analyzer.event_df

Unnamed: 0,time_s,frame,type,subject_id
0,0.00,0,In-Phase Event,Sbj06
1,4.18,251,In-Phase Event,Sbj06
2,8.37,502,In-Phase Event,Sbj06
3,12.55,753,In-Phase Event,Sbj06
4,16.73,1004,In-Phase Event,Sbj06
...,...,...,...,...
358,524.38,31463,In-Phase Breakdown,Sbj06
359,524.45,31467,In-Phase Event,Sbj06
360,524.55,31473,In-Phase Event,Sbj06
361,524.65,31479,In-Phase Event,Sbj06


In [14]:
# Example: Drop rows with DataFrame index
#analyzer.event_df = analyzer.event_df.drop(index=[5, 10, 15]).reset_index(drop=True)

In [15]:
# --- STEP 6: Align with EEG (if applicable) and Save MNE Events ---
# This step will use the (potentially modified) analyzer.event_df
if eeg_config:
    print(f"\n--- STEP 6a: Aligning with EEG for {target_subject_id} ---")
    tracker = EEGTrialTracker(target_subject_id, **eeg_config)
    analyzer = tracker.load_and_align(analyzer)
    print("✅ EEG alignment complete.")

    print(f"\n--- STEP 6b: Saving MNE Events for {target_subject_id} ---")
    tracker.save_events_to_mne_format(analyzer, output_dir_main)
    print("✅ MNE events saved.")
    
    # --- NEW STEP 6c: Create Binary MNE Events ---
    print(f"\n--- STEP 6c: Creating Binary MNE Events for {target_subject_id} ---")
    try:
        # Define paths for the current subject
        subject_folder = Path(output_dir_main) / target_subject_id
        binary_suffix = "_binary"
        
        # File paths
        events_path = subject_folder / f"{target_subject_id}_events.csv"
        if not events_path.exists():
            events_path = subject_folder / f"{target_subject_id}_events_enhanced.csv"
            
        original_mne_events_path = subject_folder / f"{target_subject_id}_events_mne.fif"
        original_event_id_path = subject_folder / f"{target_subject_id}_event_id.json"
        
        # Check if required files exist
        required_files = [events_path, original_mne_events_path, original_event_id_path]
        missing_files = [f for f in required_files if not f.exists()]
        
        if missing_files:
            print(f"  ⚠️ Skipping binary creation for {target_subject_id}. Missing files: {[f.name for f in missing_files]}")
        else:
            # Load original MNE Events and ID mapping
            original_mne_events = mne.read_events(str(original_mne_events_path))
            with open(original_event_id_path, 'r') as f:
                original_event_id_mapping = json.load(f)
            print(f"  ✅ Loaded original MNE events ({original_mne_events.shape}) and ID mapping.")
            
            # Prepare for Binary Recoding
            id_transformation_map = {}
            
            # Find the ID for 'In-Phase Event' and map it to 0
            in_phase_id = original_event_id_mapping.get('In-Phase Event')
            if in_phase_id is not None:
                id_transformation_map[in_phase_id] = 0
            else:
                print(f"  ⚠️ 'In-Phase Event' not found in original mapping for {target_subject_id}. Skipping binary recoding.")
                
            # Find IDs for 'Anti-Phase Event' and 'In-Phase Breakdown' and map them to 1
            anti_phase_id = original_event_id_mapping.get('Anti-Phase Event')
            if anti_phase_id is not None:
                id_transformation_map[anti_phase_id] = 1
                
            breakdown_id = original_event_id_mapping.get('In-Phase Breakdown')
            if breakdown_id is not None:
                id_transformation_map[breakdown_id] = 1
                
            if len(id_transformation_map) < 2: # Need at least In-Phase ID and one Out-of-Phase ID
                print(f"  ⚠️ Incomplete original mapping for {target_subject_id}. Found IDs: {id_transformation_map}. Skipping binary recoding.")
            else:
                print(f"  🔄 ID Transformation Map: {id_transformation_map}")
                
                # Modify the Existing MNE Events Array
                binary_mne_events_array = original_mne_events.copy()
                original_ids = binary_mne_events_array[:, 2]
                
                # Recode IDs
                unique_original_ids_in_data = np.unique(original_ids)
                print(f"  📊 Unique original IDs in data: {unique_original_ids_in_data}")
                
                for orig_id in unique_original_ids_in_data:
                    new_id = id_transformation_map.get(orig_id)
                    if new_id is not None:
                        indices_to_change = np.where(original_ids == orig_id)[0]
                        binary_mne_events_array[indices_to_change, 2] = new_id
                        # Get original type name for reporting
                        old_type = [k for k, v in original_event_id_mapping.items() if v == orig_id][0]
                        new_type = "In-Phase" if new_id == 0 else "Out-of-Phase"
                        print(f"    Recoded {len(indices_to_change)} events: {old_type} (ID {orig_id}) -> {new_type} (ID {new_id})")
                    else:
                        print(f"    ⚠️ No mapping for original ID {orig_id}. Events unchanged.")
                        
                print(f"  ✅ Modified MNE events array. Shape: {binary_mne_events_array.shape}")
                
                # Define New Event ID Mapping
                new_mne_event_id = {
                    'In-Phase': 0,
                    'Out-of-Phase': 1 # This now includes both Anti-Phase Events and Breakdowns
                }
                print(f"  🆕 New MNE Event ID Mapping: {new_mne_event_id}")
                
                # Save New Binary Files
                # New filenames (using MNE's recommended naming convention)
                binary_events_fif_path = subject_folder / f"{target_subject_id}_events_mne{binary_suffix}-eve.fif"
                binary_events_txt_path = subject_folder / f"{target_subject_id}_events_mne{binary_suffix}-eve.txt" # Optional
                new_event_id_json_path = subject_folder / f"{target_subject_id}_event_id{binary_suffix}.json"
                
                # Save the BINARY MNE events (binary .fif - Correct Format)
                mne.write_events(str(binary_events_fif_path), binary_mne_events_array, overwrite=True)
                print(f"  💾 Saved BINARY MNE events (.fif): {binary_events_fif_path.name}")
                
                # Save the BINARY MNE events (text .txt - Optional)
                mne.write_events(str(binary_events_txt_path), binary_mne_events_array, overwrite=True)
                print(f"  💾 Saved BINARY MNE events (.txt): {binary_events_txt_path.name}")
                
                # Save the NEW event ID mapping (JSON)
                with open(new_event_id_json_path, 'w') as f:
                    json.dump(new_mne_event_id, f, indent=2)
                print(f"  💾 Saved NEW event ID mapping: {new_event_id_json_path.name}")
                
                print(f"  ✅ Successfully created binary files for {target_subject_id}.")
                
    except Exception as e:
        print(f"  ⚠️ Error creating binary MNE events for {target_subject_id}: {e}")

# --- STEP 7: Create Plots ---
print(f"\n--- STEP 7: Creating Plots for {target_subject_id} ---")
output_path = Path(output_dir_main) / target_subject_id
output_path.mkdir(parents=True, exist_ok=True)
analyzer.create_plots(output_path)
print("✅ Plots created.")

# --- STEP 8: Save Event and Trial Data ---
print(f"\n--- STEP 8: Saving Event and Trial Data for {target_subject_id} ---")
if not analyzer.event_df.empty:
    # Save the potentially modified event_df
    event_file = output_path / f"{target_subject_id}_events.csv"
    analyzer.event_df.to_csv(event_file, index=False)
    print(f"✅ Saved (potentially modified) events: {event_file}")

if hasattr(analyzer, 'trials_df') and analyzer.trials_df is not None and not analyzer.trials_df.empty:
    trial_file = output_path / f"{target_subject_id}_trials.csv"
    analyzer.trials_df.to_csv(trial_file, index=False)
    print(f"✅ Saved trials: {trial_file}")

# --- NEW STEP 8b: Create Binary MNE Events (using the new function) ---
print(f"\n--- STEP 8b: Creating Binary MNE Events for {target_subject_id} ---")
success = create_binary_mne_events_for_subject(target_subject_id, output_dir_main, binary_suffix="_binary")
if success:
    print(f"  ✅ Binary MNE events created successfully for {target_subject_id}.")
else:
    print(f"  ⚠️ Failed to create binary MNE events for {target_subject_id}.")

print(f"\n🎉 Processing for {target_subject_id} complete (up to manual editing point)!")
print(f"📁 Results saved to: {output_path}")





--- STEP 6a: Aligning with EEG for Sbj06 ---
🧠 Adding enhanced stim_channel info for Sbj06...
✅ Enhanced stim_channel added for Sbj06
✅ EEG alignment complete for Sbj06
✅ EEG alignment complete.

--- STEP 6b: Saving MNE Events for Sbj06 ---
✅ Saved MNE events: results_Enhanced_PD_Analysis\Sbj06\Sbj06_events_mne.fif
✅ Saved event ID mapping: results_Enhanced_PD_Analysis\Sbj06\Sbj06_event_id.json
✅ MNE events saved.

--- STEP 6c: Creating Binary MNE Events for Sbj06 ---
  ⚠️ Skipping binary creation for Sbj06. Missing files: ['Sbj06_events_enhanced.csv']

--- STEP 7: Creating Plots for Sbj06 ---
✅ Saved coordination analysis plot: Sbj06_coordination_analysis.png
✅ Saved enhanced variability plot: Sbj06_enhanced_variability.png
✅ Plots created.

--- STEP 8: Saving Event and Trial Data for Sbj06 ---
✅ Saved (potentially modified) events: results_Enhanced_PD_Analysis\Sbj06\Sbj06_events.csv
✅ Saved trials: results_Enhanced_PD_Analysis\Sbj06\Sbj06_trials.csv

--- STEP 8b: Creating Binary MNE

In [18]:
# --- STEP 9: Trial Visualization ---
# This also uses the (potentially modified) event_df for marking events
print(f"\n--- STEP 9: Creating Enhanced Trial Plots for {target_subject_id} ---")
# Check if trial data exists before plotting
if hasattr(analyzer, 'trials_df') and analyzer.trials_df is not None and not analyzer.trials_df.empty:
    print(f"Found trial data. Generating plots...")
    plot_all_trials_for_subject(analyzer, output_dir=output_dir_main, max_trials=10, duration_limit=50)
    print(f"✅ Enhanced trial plots created.")
else:
    print(f"⚠️ No trial data found for {target_subject_id}. Skipping enhanced trial plots.")


--- STEP 9: Creating Enhanced Trial Plots for Sbj06 ---
Found trial data. Generating plots...

🔄 Plotting first 50s (EEG time) of 10 trial(s) for Sbj06...
   📊 Found 59 events for trial 1 within EEG time 46.19s - 96.19s
   🎛️ Stim channels in this segment: ['1a', '2a', '3a', '4a', '5a', '6a']
📈 Enhanced trial plot saved: Sbj06, Trial 1 | EEG Duration: 50.00s
   🔴 Events in plotted segment: 59
     • Stim 6a: 15 events
     • Stim 3a: 12 events
     • Stim 5a: 12 events
     • Stim 4a: 10 events
     • Stim 2a: 6 events
     • Stim 1a: 4 events
   📊 Found 20 events for trial 2 within EEG time 98.95s - 148.95s
   🎛️ Stim channels in this segment: ['1a', '2a', '3a', '5a', '6a']
📈 Enhanced trial plot saved: Sbj06, Trial 2 | EEG Duration: 50.00s
   🔴 Events in plotted segment: 20
     • Stim 1a: 11 events
     • Stim 6a: 6 events
     • Stim 2a: 1 events
     • Stim 3a: 1 events
     • Stim 5a: 1 events
   📊 Found 32 events for trial 3 within EEG time 151.68s - 201.68s
   🎛️ Stim channels 