In [2]:
!pip install mne
!pip install mne-connectivity
!pip install ema_pytorch



In [3]:
import os
import pandas as pd
import mne
import numpy as np
from pathlib import Path
import torch
from collections import defaultdict

# Load one subject's data by session
def load_data_by_session(root_dir, subject_id, session_idx_list):
    data = np.load(os.path.join(root_dir, f"S{subject_id}_chars.npy"))
    data = data[:, session_idx_list]
    X = data.reshape(-1, 64, 250)
    Y = np.repeat(np.arange(26), len(session_idx_list))
    return torch.tensor(X, dtype=torch.float32), torch.tensor(Y, dtype=torch.long)

# --------- MI ---------
def MI_load_data_by_session(root_dir, subject_id, session_folders, label_dir):
    """
    root_dir/
      first_session/
        A01T_cleaned.fif … A09T_cleaned.fif
      second_session/
        A01E_cleaned.fif … A09E_cleaned.fif

    session_folders: list of folder names, e.g. ["first_session"] or ["second_session"]
    """
    X_list, Y_list = [], []

    prefix = "T"
    fname = f"A{subject_id:02d}{prefix}.fif"
    fpath = os.path.join(root_dir, "first_session", fname)
    raw = mne.io.read_raw_fif(fpath, preload=True, verbose=False)

    # MI cue as '769'~'772'，mapping as 0–3 labels
    events, event_id = mne.events_from_annotations(raw, verbose=False)
    motor_keys = ['769', '770', '771', '772']
    motor_event_id = {k: v for k, v in event_id.items() if k in motor_keys}
    if len(motor_event_id) < 4:
        raise ValueError(f"{fname} missing MI cues. Found: {event_id}")
    events = np.array([e for e in events if e[2] in motor_event_id.values()])
    label_map = {
        motor_event_id['769']: 0,
        motor_event_id['770']: 1,
        motor_event_id['771']: 2,
        motor_event_id['772']: 3,
    }
    labels = np.array([label_map[e[-1]] for e in events])

    epochs = mne.Epochs(
        raw, events,
        tmin=0.0,
        tmax=4.0,
        baseline=None,
        preload=True,
        verbose=False,
        event_repeated="drop"
    )
    data = epochs.get_data()
    X_list.append(torch.from_numpy(data).float())
    Y_list.append(torch.from_numpy(labels).long())

    X = torch.cat(X_list, dim=0)
    Y = torch.cat(Y_list, dim=0)
    return X, Y

# --------- P300 ---------
def P300_load_subject_data(subject_id, root_dir):
    folder = os.path.join(root_dir, f"subject_{subject_id:02d}")
    X = np.load(os.path.join(folder, "X.npy"))                # shape: (n_trials, C, T)
    Y = np.load(os.path.join(folder, "y.npy"))                # shape: (n_trials,)
    Y = np.array([1 if label == 'Target' else 0 for label in Y])

    meta = pd.read_csv(os.path.join(folder, "metadata.csv"))  # contains at least 'session'

    trials_per_repetition = 12
    reps_per_level = 8
    trials_per_level = reps_per_level * trials_per_repetition  # 96
    levels_per_session = 9

    level_list = []
    repetition_list = []

    for sess in sorted(meta["session"].unique()):
        session_idxs = meta.index[meta["session"] == sess].tolist()
        for i, idx in enumerate(session_idxs):
            rep = i // trials_per_repetition
            level = rep // reps_per_level
            repetition = rep % reps_per_level
            level_list.append(level)
            repetition_list.append(repetition)

    meta["level"] = level_list
    meta["repetition"] = repetition_list

    return {
        "X": X,
        "Y": Y,
        "session": meta["session"].tolist(),
        "level": meta["level"].tolist(),
        "repetition": meta["repetition"].tolist()
    }

# --------- Imagined_speech ---------
def ImaginedSpeech_load_subject_data(subject_id, root_dir):
    x_path = os.path.join(root_dir, f"epochs_{subject_id}_notched.npy")
    y_path = os.path.join(root_dir, f"labels_{subject_id}.npy")

    X = np.load(x_path)  # shape: (n_trials, C, T)
    Y_raw = np.load(y_path, allow_pickle=True)  # shape: (n_trials,), string labels

    Y_raw = Y_raw.flatten()

    # Map labels to level indices
    label_set = sorted(set(Y_raw.tolist()))
    label2idx = {label: i for i, label in enumerate(label_set)}  # consistent across subjects

    level = [label2idx[label] for label in Y_raw]

    # Build repetition index for each stimulus
    counter = defaultdict(int)
    repetition = []
    for label in Y_raw:
        repetition.append(counter[label])
        counter[label] += 1

    # All trials are from a single session
    session = [0] * len(Y_raw)

    return {
        "X": X,
        "Y": np.array(level),
        "session": session,
        "level": level,
        "repetition": repetition
    }

In [4]:
"""
Cross-Trial Analysis
"""

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
from collections import defaultdict

# ============================================================================
# CONFIGURATION - Set your task here
# ============================================================================

TASK = 'P300'  # Change to: 'MI', 'SSVEP', 'P300', 'Imagined_speech'

# Data paths (update these)
DATA_PATHS = {
    'MI': '/content/drive/MyDrive/IDL/IDL Project Team 5 F25/dataset/MI/cleaned_data',
    'SSVEP': '/content/drive/MyDrive/IDL/IDL Project Team 5 F25/dataset/ssvep/chars',
    'P300': '/content/drive/MyDrive/IDL/IDL Project Team 5 F25/dataset/p300/bi2015a/cleaned_data',
    'Imagined_speech': '/content/drive/MyDrive/IDL/IDL Project Team 5 F25/dataset/speech_imagined/KARA_ONE/epochs/notched'
}

OUTPUT_DIR = Path(f'/content/drive/MyDrive/IDL/IDL Project Team 5 F25/data analysis/{TASK}/cross_trial_{TASK}')
OUTPUT_DIR.mkdir(exist_ok=True)

# Task-specific parameters
TASK_CONFIG = {
    'MI': {'sfreq': 250, 'n_classes': 4, 'use_alignment': True},
    'SSVEP': {'sfreq': 250, 'n_classes': 26, 'use_alignment': False},
    'P300': {'sfreq': 256, 'n_classes': 2, 'use_alignment': False},
    'Imagined_speech': {'sfreq': 128, 'n_classes': 11, 'use_alignment': False}
}

freq_bands = {'Theta': (4, 8), 'Alpha': (8, 13), 'Beta': (13, 30), 'Gamma': (30, 40)}

# ============================================================================
# NEW: DATA LOADING from dataset.py - Load ALL data for analysis
# ============================================================================

def load_all_subjects_for_task(task, root_dir, **kwargs):
    """
    Load ALL data for all subjects (not split) - for data analysis purposes.
    Returns data organized by subject and condition.
    """
    if task == 'MI':
        return load_all_MI_data(root_dir, kwargs.get('label_dir'))
    elif task == 'SSVEP':
        return load_all_SSVEP_data(root_dir)
    elif task == 'P300':
        return load_all_P300_data(root_dir)
    elif task == 'Imagined_speech':
        return load_all_ImaginedSpeech_data(root_dir)


def load_all_MI_data(root_dir, label_dir=None):
    """Load all MI subjects without splitting - reuse dataset.py loading logic"""
    import numpy as np
    import os

    subject_data = {}

    # Load all 9 subjects (A01-A09)
    for subject_id in range(1, 10):
        try:
            # Load both sessions for each subject
            X, Y = MI_load_data_by_session(
                root_dir, subject_id,
                ["first_session"],
                label_dir
            )

            # Convert to numpy
            X = X.cpu().numpy()  # (n_trials, n_channels, n_times)
            Y = Y.cpu().numpy()

            # Group by condition
            subject_data[subject_id] = {}
            for label in np.unique(Y):
                mask = Y == label
                subject_data[subject_id][int(label)] = X[mask]

            print(f"Loaded Subject {subject_id}: {len(Y)} trials, {len(np.unique(Y))} conditions")

        except Exception as e:
            print(f"Could not load subject {subject_id}: {e}")

    return subject_data


def load_all_SSVEP_data(root_dir):
    """Load all SSVEP subjects without splitting"""
    import numpy as np
    import os

    subject_data = {}

    # Load all 35 subjects
    for subject_id in range(1, 36):
        try:
            # Load all 6 sessions
            X, Y = load_data_by_session(root_dir, subject_id, [0, 1, 2, 3, 4, 5])

            X = X.cpu().numpy()
            Y = Y.cpu().numpy()

            # Group by condition (26 letters)
            subject_data[subject_id] = {}
            for label in np.unique(Y):
                mask = Y == label
                subject_data[subject_id][int(label)] = X[mask]

            print(f"Loaded Subject {subject_id}: {len(Y)} trials")

        except Exception as e:
            print(f"Could not load subject {subject_id}: {e}")

    return subject_data


def load_all_P300_data(root_dir):
    """Load all P300 subjects without splitting"""
    import numpy as np

    subject_data = {}

    # Load all subjects (typically 1-8)
    for subject_id in range(1, 44):
        try:
            data = P300_load_subject_data(subject_id, root_dir)

            X = data['X']  # Already numpy
            Y = data['Y']

            # Group by condition
            subject_data[subject_id] = {}
            for label in np.unique(Y):
                mask = Y == label
                subject_data[subject_id][int(label)] = X[mask]

            print(f"Loaded Subject {subject_id}: {len(Y)} trials")

        except Exception as e:
            print(f"Could not load subject {subject_id}: {e}")

    return subject_data


def load_all_ImaginedSpeech_data(root_dir):
    """Load all Imagined Speech subjects without splitting"""
    import numpy as np
    import os
    import re

    subject_data = {}

    # Find all available subjects from file names
    all_subjects = sorted([
        re.findall(r'epochs_(.*)\.npy', f)[0].replace("_notched", "")
        for f in os.listdir(root_dir)
        if f.startswith("epochs_") and f.endswith(".npy")
    ])

    for subject_id in all_subjects:
        try:
            data = ImaginedSpeech_load_subject_data(subject_id, root_dir)

            X = data['X']  # Already numpy
            Y = data['Y']

            # Group by condition
            subject_data[subject_id] = {}
            for label in np.unique(Y):
                mask = Y == label
                subject_data[subject_id][int(label)] = X[mask]

            print(f"Loaded Subject {subject_id}: {len(Y)} trials")

        except Exception as e:
            print(f"Could not load subject {subject_id}: {e}")

    return subject_data

# ============================================================================
# KEEP YOUR EXISTING HELPER FUNCTIONS (from helpers.py)
# ============================================================================

def find_mi_window(trial_data, sfreq, baseline_end_idx):
    """YOUR EXISTING FUNCTION - no changes needed"""
    power = trial_data ** 2
    baseline_power = np.mean(power[:baseline_end_idx])
    baseline_std = np.std(power[:baseline_end_idx])
    threshold = baseline_power + 2 * baseline_std
    active = power > threshold

    min_duration = int(0.5 * sfreq)
    active_start = None

    for i in range(baseline_end_idx, len(active) - min_duration):
        if np.sum(active[i:i+min_duration]) > 0.7 * min_duration:
            active_start = i
            break

    if active_start is None:
        active_start = baseline_end_idx + int(0.5 * sfreq)

    active_end = min(active_start + int(1.5 * sfreq), len(trial_data))
    return active_start, active_end


def align_trials_to_mi_onset(data, sfreq, baseline_end_idx):
    """YOUR EXISTING FUNCTION - no changes needed"""
    n_trials, n_times = data.shape
    aligned_trials = []
    windows = []

    for trial in data:
        start, end = find_mi_window(trial, sfreq, baseline_end_idx)
        windows.append((start, end))

    window_lengths = [end - start for start, end in windows]
    target_length = int(np.median(window_lengths))

    for trial, (start, end) in zip(data, windows):
        if end - start >= target_length:
            aligned_trials.append(trial[start:start+target_length])
        else:
            segment = trial[start:end]
            padded = np.pad(segment, (0, target_length - len(segment)), mode='edge')
            aligned_trials.append(padded)

    return np.array(aligned_trials)

# ============================================================================
# YOUR EXISTING ANALYSIS FUNCTIONS - Minimal changes
# ============================================================================

def compute_channel_consistency_temporal(trials, sfreq, use_alignment):
    """
    Modified version of your function to work with numpy arrays directly.

    Args:
        trials: numpy array (n_trials, n_channels, n_times)
        sfreq: sampling frequency
        use_alignment: whether to align trials (only for MI)
    """
    n_trials, n_channels, n_times = trials.shape

    if n_trials < 2:
        return np.zeros(n_channels)

    baseline_end_idx = int(1.0 * sfreq)  # Assuming 1s baseline for MI
    max_shift_samples = int(0.2 * sfreq)

    channel_consistencies = []

    for ch_idx in range(n_channels):
        channel_data = trials[:, ch_idx, :]  # (n_trials, n_times)

        # Auto-align if requested (only for MI)
        if use_alignment:
            aligned_data = align_trials_to_mi_onset(channel_data, sfreq, baseline_end_idx)
        else:
            aligned_data = channel_data

        # Compute pairwise correlations (YOUR EXISTING LOGIC)
        correlations = []
        for i in range(n_trials):
            for j in range(i+1, n_trials):
                trial1 = aligned_data[i]
                trial2 = aligned_data[j]

                # Normalize
                trial1 = (trial1 - np.mean(trial1)) / (np.std(trial1) + 1e-10)
                trial2 = (trial2 - np.mean(trial2)) / (np.std(trial2) + 1e-10)

                # Cross-correlation (YOUR EXISTING LOGIC)
                xcorr = np.correlate(trial1, trial2, mode='same')
                xcorr = xcorr / len(trial1)

                center = len(xcorr) // 2
                start = max(0, center - max_shift_samples)
                end = min(len(xcorr), center + max_shift_samples)
                max_corr = np.max(xcorr[start:end])

                correlations.append(max_corr)

        channel_consistencies.append(np.mean(correlations))

    return np.array(channel_consistencies)


def compute_channel_consistency_frequency(trials, sfreq, fmin, fmax):
    """
    Modified to work with numpy arrays instead of MNE Epochs.

    Args:
        trials: numpy array (n_trials, n_channels, n_times)
        sfreq: sampling frequency
        fmin, fmax: frequency band
    """
    from scipy import signal as sp_signal

    n_trials, n_channels, n_times = trials.shape

    if n_trials < 2:
        return np.zeros(n_channels)

    channel_consistencies = []

    for ch_idx in range(n_channels):
        channel_data = trials[:, ch_idx, :]  # (n_trials, n_times)

        # Compute PSD for each trial
        psds = []
        for trial in channel_data:
            freqs, psd = sp_signal.welch(trial, fs=sfreq, nperseg=min(512, n_times))
            # Select frequency band
            freq_mask = (freqs >= fmin) & (freqs <= fmax)
            psds.append(psd[freq_mask])

        psds = np.array(psds)  # (n_trials, n_freqs)

        # Compute pairwise correlations (YOUR EXISTING LOGIC)
        corr_matrix = np.corrcoef(psds)
        triu_idx = np.triu_indices_from(corr_matrix, k=1)
        pairwise_corrs = corr_matrix[triu_idx]

        channel_consistencies.append(np.mean(pairwise_corrs))

    return np.array(channel_consistencies)


def analyze_subject(subject_data, subject_id, task_type):
    """
    Modified to work with numpy arrays from DataLoader.

    Args:
        subject_data: {condition_id: numpy_array(n_trials, n_channels, n_times)}
        subject_id: subject identifier
        task_type: task name
    """
    print(f"\nAnalyzing Subject {subject_id}...")

    config = TASK_CONFIG[task_type]
    sfreq = config['sfreq']
    use_alignment = config['use_alignment']

    # Get channel names (generic for now)
    first_condition = list(subject_data.values())[0]
    n_channels = first_condition.shape[1]
    channel_names = [f'Ch{i}' for i in range(n_channels)]

    results = {
        'subject': subject_id,
        'task': task_type,
        'channel_names': channel_names,
        'conditions': {}
    }

    for condition_id, trials in subject_data.items():
        n_trials = trials.shape[0]

        if n_trials < 2:
            print(f"  Skipping condition {condition_id}: only {n_trials} trial(s)")
            continue

        print(f"  Condition {condition_id}: {n_trials} trials")

        # Temporal consistency
        temp_consistency = compute_channel_consistency_temporal(trials, sfreq, use_alignment)

        # Frequency consistency per band
        freq_consistency = {}
        for band_name, (fmin, fmax) in freq_bands.items():
            freq_consistency[band_name] = compute_channel_consistency_frequency(
                trials, sfreq, fmin, fmax
            )

        results['conditions'][condition_id] = {
            'n_trials': n_trials,
            'temporal': temp_consistency,
            'frequency': freq_consistency
        }

    return results

# ============================================================================
# KEEP YOUR EXISTING PLOTTING FUNCTIONS - minimal changes
# ============================================================================

def plot_subject_channel_consistency(results, save_path):
    """YOUR EXISTING FUNCTION - just adjusted for condition_id instead of names"""
    subject_id = results['subject']
    channel_names = results['channel_names']
    conditions = list(results['conditions'].keys())

    if not conditions:
        print(f"No conditions to plot for subject {subject_id}")
        return

    n_conds = min(len(conditions), 4)
    n_bands = len(freq_bands)
    n_freq_rows = (n_bands + 1) // 2  # 2 bands per row
    total_rows = 2 + n_freq_rows      # 2 rows for temporal + freq rows

    fig = plt.figure(figsize=(24, 6 * total_rows))
    gs = fig.add_gridspec(total_rows, 4, hspace=0.4, wspace=0.4)

    # Plot temporal consistency per channel for each condition
    subplot_positions = [(0, 0), (0, 1), (1, 0), (1, 1)]
    for idx, condition_id in enumerate(conditions[:4]):
        row, col = subplot_positions[idx]
        ax = fig.add_subplot(gs[row, col])

        temp_scores = results['conditions'][condition_id]['temporal']

        x = np.arange(len(channel_names))
        colors = ['red' if score < 0.5 else 'orange' if score < 0.7 else 'green'
                  for score in temp_scores]
        bars = ax.bar(x, temp_scores, color=colors, alpha=0.7, edgecolor='black', linewidth=0.5)

        ax.axhline(0.7, color='green', linestyle='--', alpha=0.5, linewidth=2, label='High (>0.7)')
        ax.axhline(0.5, color='orange', linestyle='--', alpha=0.5, linewidth=2, label='Moderate (>0.5)')
        ax.set_xlabel('Channel Index')
        ax.set_ylabel('Mean Trial Correlation', fontsize=10)
        ax.set_title(f'Condition {condition_id} - Temporal Consistency', fontweight='bold', fontsize=11)
        ax.set_ylim([0, 1])
        ax.legend(loc='upper right', fontsize=8)
        ax.grid(True, alpha=0.3, axis='y')

    # Frequency band comparison - ONE SUBPLOT PER BAND (bottom 2 rows)
    if len(conditions) > 0:
        band_names = list(freq_bands.keys())

        for band_idx, band_name in enumerate(band_names):
            ax = fig.add_subplot(gs[2 + band_idx // 2, (band_idx % 2) * 2:(band_idx % 2) * 2 + 2])

            # Collect scores for this band across all conditions
            x = np.arange(len(conditions))
            band_scores = []

            for condition_id in conditions:
                freq_scores = results['conditions'][condition_id]['frequency'][band_name]
                band_scores.append(np.mean(freq_scores))

            # Color bars by consistency level
            colors = ['red' if s < 0.5 else 'orange' if s < 0.7 else 'green' for s in band_scores]
            bars = ax.bar(x, band_scores, color=colors, alpha=0.7, edgecolor='black', linewidth=1)

            ax.axhline(0.7, color='green', linestyle='--', alpha=0.5, linewidth=2)
            ax.axhline(0.5, color='orange', linestyle='--', alpha=0.5, linewidth=2)
            ax.set_xlabel('Condition', fontsize=10)
            ax.set_ylabel('Mean Correlation', fontsize=10)
            ax.set_title(f'{band_name} Band ({freq_bands[band_name][0]}-{freq_bands[band_name][1]} Hz)',
                        fontweight='bold', fontsize=11)
            ax.set_xticks(x)
            ax.set_xticklabels([f'C{c}' for c in conditions])
            ax.set_ylim([0, 1])
            ax.grid(True, alpha=0.3, axis='y')

    plt.suptitle(f'Subject {subject_id} - {results["task"]} Task',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Saved plot: {save_path}")


def create_summary_dataframe(all_results):
    """YOUR EXISTING FUNCTION - minimal changes"""
    summary_data = []

    for result in all_results:
        subject = result['subject']
        task = result['task']

        for condition_id, cond_data in result['conditions'].items():
            temp_consistency = cond_data['temporal']

            row = {
                'Subject': subject,
                'Task': task,
                'Condition': condition_id,
                'N_Trials': cond_data['n_trials'],
                'Temporal_Mean': np.mean(temp_consistency),
                'Temporal_Std': np.std(temp_consistency),
            }

            # Add frequency bands
            for band_name in freq_bands.keys():
                freq_scores = cond_data['frequency'][band_name]
                row[f'{band_name}_Mean'] = np.mean(freq_scores)
                row[f'{band_name}_Std'] = np.std(freq_scores)

            summary_data.append(row)

    return pd.DataFrame(summary_data)


def plot_cross_subject_summary(df, save_path):
    """YOUR EXISTING FUNCTION - keep as-is"""
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    conditions = df['Condition'].unique()

    # Plot 1: Temporal consistency by subject
    ax = axes[0, 0]
    for condition in conditions:
        data = df[df['Condition'] == condition]
        ax.plot(data['Subject'], data['Temporal_Mean'],
                marker='o', label=f'Cond {condition}', linewidth=2, markersize=8)
    ax.axhline(0.7, color='green', linestyle='--', alpha=0.3)
    ax.axhline(0.5, color='orange', linestyle='--', alpha=0.3)
    ax.set_xlabel('Subject')
    ax.set_ylabel('Mean Consistency')
    ax.set_title('Temporal Consistency Across Subjects', fontweight='bold')
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)
    plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha='right')

    # Plot 2: Distribution by condition
    ax = axes[0, 1]
    temp_data = [df[df['Condition'] == c]['Temporal_Mean'].values for c in conditions]
    bp = ax.boxplot(temp_data, labels=[f'C{c}' for c in conditions], patch_artist=True)
    for patch in bp['boxes']:
        patch.set_facecolor('lightblue')
    ax.axhline(0.7, color='green', linestyle='--', alpha=0.3)
    ax.axhline(0.5, color='orange', linestyle='--', alpha=0.3)
    ax.set_ylabel('Consistency')
    ax.set_title('Distribution by Condition', fontweight='bold')
    ax.grid(True, alpha=0.3, axis='y')

    # Plot 3: Subject ranking
    ax = axes[1, 0]
    subject_avg = df.groupby('Subject')['Temporal_Mean'].mean().sort_values(ascending=False)
    colors = ['green' if v > 0.7 else 'orange' if v > 0.5 else 'red'
              for v in subject_avg.values]
    ax.bar(range(len(subject_avg)), subject_avg.values, color=colors, alpha=0.7)
    ax.axhline(0.7, color='green', linestyle='--', alpha=0.3)
    ax.axhline(0.5, color='orange', linestyle='--', alpha=0.3)
    ax.set_xlabel('Subject (ranked)')
    ax.set_ylabel('Mean Consistency')
    ax.set_title('Subject Ranking', fontweight='bold')
    ax.set_xticks(range(len(subject_avg)))
    ax.set_xticklabels([f'S{s}' for s in subject_avg.index], rotation=45, ha='right')
    ax.grid(True, alpha=0.3, axis='y')

    # Plot 4: Frequency bands heatmap
    ax = axes[1, 1]
    band_cols = [col for col in df.columns if col.endswith('_Mean') and col != 'Temporal_Mean']
    band_data = df.groupby('Condition')[band_cols].mean()
    band_data.columns = [col.replace('_Mean', '') for col in band_data.columns]

    im = ax.imshow(band_data.values, cmap='RdYlGn', vmin=0, vmax=1, aspect='auto')
    ax.set_xticks(range(len(band_data.columns)))
    ax.set_xticklabels(band_data.columns)
    ax.set_yticks(range(len(band_data.index)))
    ax.set_yticklabels([f'C{i}' for i in band_data.index])
    ax.set_title('Frequency Band Consistency', fontweight='bold')
    plt.colorbar(im, ax=ax, label='Mean Correlation')

    for i in range(len(band_data.index)):
        for j in range(len(band_data.columns)):
            ax.text(j, i, f'{band_data.values[i, j]:.2f}',
                   ha="center", va="center", color="black", fontsize=9)

    plt.suptitle(f'Cross-Subject Summary - {TASK} Task', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print("Saved cross-subject summary")


def print_report(df):
    """YOUR EXISTING FUNCTION - keep as-is"""
    print("\n" + "="*80)
    print(f"CROSS-TRIAL CONSISTENCY ANALYSIS - {TASK} TASK")
    print("="*80)
    print(f"\nTotal subjects: {df['Subject'].nunique()}")
    print(f"Total trials: {df['N_Trials'].sum()}")
    print(f"Total conditions: {df['Condition'].nunique()}")

    print("\n" + "-"*80)
    print("OVERALL CONSISTENCY:")
    print(f"  Mean: {df['Temporal_Mean'].mean():.3f} ± {df['Temporal_Mean'].std():.3f}")
    print(f"  Range: [{df['Temporal_Mean'].min():.3f}, {df['Temporal_Mean'].max():.3f}]")

    print("\n" + "-"*80)
    print("PER-CONDITION BREAKDOWN:")
    for condition in df['Condition'].unique():
        cond_data = df[df['Condition'] == condition]
        print(f"\n  Condition {condition}:")
        print(f"    Temporal: {cond_data['Temporal_Mean'].mean():.3f} ± {cond_data['Temporal_Mean'].std():.3f}")

        # Best frequency band
        band_cols = [col for col in df.columns if col.endswith('_Mean') and col != 'Temporal_Mean']
        band_means = {col.replace('_Mean', ''): cond_data[col].mean() for col in band_cols}
        best_band = max(band_means, key=band_means.get)
        print(f"    Best freq band: {best_band} ({band_means[best_band]:.3f})")

# ============================================================================
# MAIN EXECUTION
# ============================================================================

if __name__ == "__main__":
    print(f"\n{'='*80}")
    print(f"Cross-Trial Consistency Analysis - {TASK} Task")
    print(f"{'='*80}\n")

    # Load ALL data for all subjects (not split - this is for analysis, not training)
    root_dir = DATA_PATHS[TASK]
    print(f"Loading ALL {TASK} data from: {root_dir}")

    kwargs = {}

    subject_data = load_all_subjects_for_task(TASK, root_dir, **kwargs)

    print(f"\n{'='*40}")
    print(f"Found {len(subject_data)} subjects")
    for sid, conds in subject_data.items():
        print(f"  Subject {sid}: {len(conds)} conditions, ", end='')
        total_trials = sum(trials.shape[0] for trials in conds.values())
        print(f"{total_trials} total trials")
    print(f"{'='*40}\n")

    # Analyze each subject
    all_results = []

    for subject_id, conditions in subject_data.items():
        try:
            result = analyze_subject(conditions, subject_id, TASK)
            all_results.append(result)

            # Plot individual subject
            plot_path = OUTPUT_DIR / f'subject_{subject_id}_consistency.png'
            plot_subject_channel_consistency(result, plot_path)

        except Exception as e:
            print(f"Error with subject {subject_id}: {e}")
            import traceback
            traceback.print_exc()

    # Create summary
    if all_results:
        df = create_summary_dataframe(all_results)
        df.to_csv(OUTPUT_DIR / 'consistency_summary.csv', index=False)
        print("\nSaved CSV summary")

        summary_plot_path = OUTPUT_DIR / 'cross_trial_summary.png'
        plot_cross_subject_summary(df, summary_plot_path)

        # Print report
        print_report(df)

        print(f"\n{'='*80}")
        print(f"Analysis complete! Results saved to: {OUTPUT_DIR}")
        print(f"{'='*80}\n")
    else:
        print("\nNo results to summarize!")


Cross-Trial Consistency Analysis - P300 Task

Loading ALL P300 data from: /content/drive/MyDrive/IDL/IDL Project Team 5 F25/dataset/p300/bi2015a/cleaned_data


KeyboardInterrupt: 

In [None]:
"""
Unified Channel Selection Analysis for All EEG Task Types
Identifies which channels are most distinct and should be kept for modeling.

Only creates essential plots:
- Per subject: Incremental selection strategy
- Cross-subject: All channels ranked
"""

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
from scipy.spatial.distance import squareform
from scipy.cluster import hierarchy

# ============================================================================
# CONFIGURATION
# ============================================================================

TASK = 'P300'  # Change to: 'MI', 'SSVEP', 'P300', 'Imagined_speech'

OUTPUT_DIR = Path(f'/content/drive/MyDrive/IDL/IDL Project Team 5 F25/data analysis/{TASK}/channel_selection_{TASK}')
OUTPUT_DIR.mkdir(exist_ok=True)

TASK_CONFIG = {
    'MI': {'sfreq': 250, 'n_classes': 4},
    'SSVEP': {'sfreq': 250, 'n_classes': 26},
    'P300': {'sfreq': 256, 'n_classes': 2},
    'Imagined_speech': {'sfreq': 128, 'n_classes': 11}
}

# Efficiency settings
MAX_TRIALS_FOR_ANALYSIS = 50  # Subsample trials if more

# ============================================================================
# CHANNEL SELECTION ANALYSIS
# ============================================================================

def compute_channel_similarity(trial_data):
    """
    Compute correlation between all channel pairs within a single trial.

    Args:
        trial_data: (n_channels, n_times) - single trial

    Returns:
        similarity_matrix: (n_channels, n_channels)
    """
    return np.corrcoef(trial_data)


def identify_representative_channels(trials, max_trials=50):
    """
    Identify which channels are most representative and should be kept.

    Args:
        trials: (n_trials, n_channels, n_times)
        max_trials: subsample if more

    Returns:
        Dictionary with:
        - diversity_ranking: channels ranked by distinctiveness
        - incremental_selection: ordered list of channels to keep
        - avg_similarity_matrix: average channel-channel correlation
    """
    n_trials, n_channels, n_times = trials.shape

    # Subsample if needed
    if n_trials > max_trials:
        print(f"    Subsampling {n_trials} -> {max_trials} trials for efficiency")
        indices = np.random.choice(n_trials, max_trials, replace=False)
        trials = trials[indices]
        n_trials = max_trials

    # Compute average similarity across all trials
    print(f"    Computing channel similarities across {n_trials} trials...")
    trial_similarities = []
    for trial_idx in range(n_trials):
        sim_matrix = compute_channel_similarity(trials[trial_idx])
        trial_similarities.append(sim_matrix)

    avg_similarity = np.mean(trial_similarities, axis=0)

    # Method 1: DIVERSITY SCORE - channels least correlated with others
    diversity_scores = []
    for i in range(n_channels):
        # Lower average correlation = more unique information
        avg_corr = np.mean(np.abs(avg_similarity[i, :]))
        diversity_score = 1 - avg_corr
        diversity_scores.append({
            'channel_idx': i,
            'diversity_score': diversity_score,
            'avg_correlation': avg_corr
        })

    # Sort by diversity (highest = most distinct = keep first)
    diversity_scores = sorted(diversity_scores, key=lambda x: x['diversity_score'], reverse=True)

    # Method 2: INCREMENTAL SELECTION - maximize diversity while adding channels
    selected_indices = []

    # Start with most diverse channel
    first_idx = diversity_scores[0]['channel_idx']
    selected_indices.append(first_idx)

    # Incrementally add channels that are least correlated with already selected
    for _ in range(min(n_channels - 1, 14)):  # Select up to 15 total
        best_score = -1
        best_idx = None

        for d in diversity_scores:
            ch_idx = d['channel_idx']
            if ch_idx in selected_indices:
                continue

            # Calculate average correlation with already selected channels
            avg_corr_with_selected = np.mean([np.abs(avg_similarity[ch_idx, si])
                                             for si in selected_indices])

            score = 1 - avg_corr_with_selected
            if score > best_score:
                best_score = score
                best_idx = ch_idx

        if best_idx is not None:
            selected_indices.append(best_idx)

    return {
        'diversity_ranking': diversity_scores,
        'incremental_selection': selected_indices,
        'avg_similarity_matrix': avg_similarity
    }


def analyze_subject_channels(subject_data, subject_id, task_type):
    """
    Analyze channel selection for one subject across all conditions.

    Args:
        subject_data: {condition_id: numpy_array(n_trials, n_channels, n_times)}
        subject_id: subject identifier
        task_type: task name

    Returns:
        Dictionary with channel rankings and selections
    """
    print(f"\nAnalyzing Subject {subject_id} (channel selection)...")

    # Collect all trials across conditions for this subject
    all_trials = []
    for condition_id, trials in subject_data.items():
        all_trials.append(trials)
        print(f"  Condition {condition_id}: {trials.shape[0]} trials")

    # Concatenate all trials
    combined_trials = np.concatenate(all_trials, axis=0)
    n_trials, n_channels, n_times = combined_trials.shape

    print(f"  Total: {n_trials} trials, {n_channels} channels")

    # Analyze channel selection
    results = identify_representative_channels(
        combined_trials,
        max_trials=MAX_TRIALS_FOR_ANALYSIS
    )

    results['subject'] = subject_id
    results['task'] = task_type
    results['n_channels'] = n_channels
    results['n_trials'] = n_trials

    return results

# ============================================================================
# SIMPLIFIED PLOTTING - Only essential plots
# ============================================================================

def plot_subject_incremental_selection(results, save_path):
    """
    Plot ONLY the incremental selection strategy for one subject.
    This is the row 2, left plot from your original visualization.
    """
    subject_id = results['subject']
    incremental_indices = results['incremental_selection'][:15]
    n_channels = results['n_channels']

    # Create channel names
    channel_names = [f'Ch{i}' for i in range(n_channels)]
    inc_channels = [channel_names[idx] for idx in incremental_indices]

    fig, ax = plt.subplots(1, 1, figsize=(12, 8))

    y_pos = np.arange(len(inc_channels))
    ax.barh(y_pos, range(len(inc_channels), 0, -1),
            color='steelblue', alpha=0.7, edgecolor='black', linewidth=1.5)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(inc_channels, fontsize=11)
    ax.set_xlabel('Selection Priority (higher = more important)', fontsize=13, fontweight='bold')
    ax.set_title(f'Subject {subject_id} - Incremental Channel Selection\n'
                 f'(Maximizes diversity at each step)',
                fontsize=14, fontweight='bold')
    ax.invert_yaxis()
    ax.grid(True, alpha=0.3, axis='x')

    # Add selection cutoff lines
    ax.axhline(4.5, color='red', linestyle='--', linewidth=2, label='Top 5')
    ax.axhline(9.5, color='orange', linestyle='--', linewidth=2, label='Top 10')
    ax.legend(fontsize=11)

    # Add text annotation
    top5 = ', '.join([f'Ch{incremental_indices[i]}' for i in range(min(5, len(incremental_indices)))])
    ax.text(0.02, 0.98, f'Top 5 Channels: {top5}',
            transform=ax.transAxes, fontsize=10,
            verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Saved: {save_path}")


def create_cross_subject_summary(all_results):
    """
    Create summary of channel rankings across all subjects.
    """
    n_channels = all_results[0]['n_channels']
    channel_scores = {i: [] for i in range(n_channels)}

    # Collect scores for each channel across subjects
    for result in all_results:
        for ch_data in result['diversity_ranking']:
            ch_idx = ch_data['channel_idx']
            channel_scores[ch_idx].append(ch_data['diversity_score'])

    # Calculate statistics
    channel_summary = []
    for ch_idx in range(n_channels):
        scores = channel_scores[ch_idx]
        channel_summary.append({
            'channel_idx': ch_idx,
            'channel': f'Ch{ch_idx}',
            'avg_score': np.mean(scores),
            'std_score': np.std(scores),
            'min_score': np.min(scores),
            'max_score': np.max(scores)
        })

    # Sort by average score
    channel_summary = sorted(channel_summary, key=lambda x: x['avg_score'], reverse=True)

    return channel_summary


def plot_cross_subject_rankings(channel_summary, save_path):
    """
    Plot ONLY the all channels ranked visualization.
    This is the bottom left plot from your original visualization.
    """
    fig, ax = plt.subplots(1, 1, figsize=(16, 8))

    all_channels = [c['channel'] for c in channel_summary]
    all_scores = [c['avg_score'] for c in channel_summary]

    # Color code channels
    colors = ['darkgreen' if i < 5 else 'green' if i < 10 else 'yellowgreen' if i < 15
              else 'orange' if i < 20 else 'red' for i in range(len(all_channels))]

    ax.bar(range(len(all_channels)), all_scores, color=colors, alpha=0.7, edgecolor='black')

    # Add reference lines
    ax.axhline(y=np.mean(all_scores), color='black', linestyle='--', linewidth=2,
              label=f'Average: {np.mean(all_scores):.3f}')
    ax.axvline(x=4.5, color='red', linestyle='--', alpha=0.5, linewidth=2)
    ax.axvline(x=9.5, color='orange', linestyle='--', alpha=0.5, linewidth=2)
    ax.axvline(x=14.5, color='yellow', linestyle='--', alpha=0.5, linewidth=2)

    ax.set_xlabel('Channel (ranked)', fontsize=13, fontweight='bold')
    ax.set_ylabel('Average Distinctiveness', fontsize=13, fontweight='bold')
    ax.set_title(f'All Channels Ranked by Distinctiveness - {TASK} Task',
                 fontsize=15, fontweight='bold')
    ax.set_xticks(range(len(all_channels)))
    ax.set_xticklabels(all_channels, rotation=90, fontsize=9)
    ax.legend(fontsize=12)
    ax.grid(True, alpha=0.3, axis='y')

    # Add text box with recommendations
    top5 = ', '.join([c['channel'] for c in channel_summary[:5]])
    top10 = ', '.join([c['channel'] for c in channel_summary[:10]])
    remove = ', '.join([c['channel'] for c in channel_summary[-5:]])

    textstr = f'Top 5: {top5}\nTop 10: {top10}\nRemove: {remove}'
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.8)
    ax.text(0.02, 0.98, textstr, transform=ax.transAxes, fontsize=10,
            verticalalignment='top', bbox=props, family='monospace')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Saved: {save_path}")

# ============================================================================
# MAIN EXECUTION
# ============================================================================

if __name__ == "__main__":
    print(f"\n{'='*80}")
    print(f"Channel Selection Analysis - {TASK} Task")
    print(f"{'='*80}\n")

    # Load data
    root_dir = DATA_PATHS[TASK]
    print(f"Loading {TASK} data from: {root_dir}")

    kwargs = {}
    if TASK == 'MI':
        kwargs['label_dir'] = MI_LABEL_DIR

    subject_data = load_all_subjects_for_task(TASK, root_dir, **kwargs)

    print(f"\n{'='*40}")
    print(f"Found {len(subject_data)} subjects")
    print(f"{'='*40}\n")

    # Analyze each subject
    all_results = []

    for subject_id, conditions in subject_data.items():
        try:
            result = analyze_subject_channels(conditions, subject_id, TASK)
            all_results.append(result)

            # Plot individual subject (simplified)
            plot_path = OUTPUT_DIR / f'subject_{subject_id}_channel_selection.png'
            plot_subject_incremental_selection(result, plot_path)

        except Exception as e:
            print(f"Error with subject {subject_id}: {e}")
            import traceback
            traceback.print_exc()

    # Create cross-subject summary
    if all_results:
        print("\n" + "="*80)
        print("GENERATING CROSS-SUBJECT SUMMARY")
        print("="*80)

        channel_summary = create_cross_subject_summary(all_results)

        # Save CSV
        df = pd.DataFrame(channel_summary)
        df.to_csv(OUTPUT_DIR / 'channel_rankings.csv', index=False)
        print(f"\nSaved: {OUTPUT_DIR / 'channel_rankings.csv'}")

        # Plot summary (simplified)
        plot_path = OUTPUT_DIR / 'cross_subject_channel_rankings.png'
        plot_cross_subject_rankings(channel_summary, plot_path)

        # Save text recommendations
        summary_path = OUTPUT_DIR / 'channel_recommendations.txt'
        with open(summary_path, 'w') as f:
            f.write("="*80 + "\n")
            f.write(f"CHANNEL RECOMMENDATIONS - {TASK} TASK\n")
            f.write("="*80 + "\n\n")

            f.write("RANKED CHANNELS:\n")
            f.write("-"*80 + "\n")
            for i, ch_data in enumerate(channel_summary, 1):
                f.write(f"{i:3d}. {ch_data['channel']:10s} | "
                       f"Score: {ch_data['avg_score']:.4f} ± {ch_data['std_score']:.4f}\n")

            f.write("\n" + "="*80 + "\n")
            f.write("RECOMMENDATIONS:\n")
            f.write("="*80 + "\n")
            f.write(f"\nTop 5 (Must Keep):  {', '.join([c['channel'] for c in channel_summary[:5]])}\n")
            f.write(f"\nTop 10 (Recommended): {', '.join([c['channel'] for c in channel_summary[:10]])}\n")
            f.write(f"\nTop 15 (Conservative): {', '.join([c['channel'] for c in channel_summary[:15]])}\n")
            f.write(f"\nRemove (Lowest Priority): {', '.join([c['channel'] for c in channel_summary[-5:]])}\n")

        print(f"Saved: {summary_path}")

        print(f"\n{'='*80}")
        print(f"Analysis complete! Results saved to: {OUTPUT_DIR}")
        print(f"{'='*80}\n")

        # Print quick summary
        print("\nQUICK SUMMARY:")
        print(f"  Top 5:  {', '.join([c['channel'] for c in channel_summary[:5]])}")
        print(f"  Top 10: {', '.join([c['channel'] for c in channel_summary[:10]])}")

    else:
        print("\nNo results to summarize!")


Channel Selection Analysis - P300 Task



NameError: name 'DATA_PATHS' is not defined

In [5]:
"""
Cross-Subject Analysis for All Tasks
Analyzes: MI, SSVEP, P300, Imagined Speech

Questions:
Q3: Does the same channel for the same label look similar across subjects?
Q4: Are label representations consistent across subjects?
"""

from itertools import combinations
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# ============================================================================
# CONFIGURATION - Set your task here
# ============================================================================

TASK = 'P300'  # Change to: 'MI', 'SSVEP', 'P300', 'Imagined_speech'

OUTPUT_DIR = Path(f'/content/drive/MyDrive/IDL/IDL Project Team 5 F25/data analysis/{TASK}/cross_subject_{TASK}')
OUTPUT_DIR.mkdir(exist_ok=True)

# ============================================================================
# ALIGNMENT AND PREPROCESSING FUNCTIONS
# ============================================================================

def find_mi_window(trial_data, sfreq, baseline_end_idx):
    """Find MI onset window (only used for MI task)"""
    power = trial_data ** 2
    baseline_power = np.mean(power[:baseline_end_idx])
    baseline_std = np.std(power[:baseline_end_idx])
    threshold = baseline_power + 2 * baseline_std
    active = power > threshold

    min_duration = int(0.5 * sfreq)
    active_start = None

    for i in range(baseline_end_idx, len(active) - min_duration):
        if np.sum(active[i:i+min_duration]) > 0.7 * min_duration:
            active_start = i
            break

    if active_start is None:
        active_start = baseline_end_idx + int(0.5 * sfreq)

    active_end = min(active_start + int(1.5 * sfreq), len(trial_data))
    return active_start, active_end


def align_trials_to_mi_onset(data, sfreq, baseline_end_idx):
    """Align MI trials to onset (only used for MI task)"""
    n_trials = data.shape[0]
    aligned_trials = []

    for trial_idx in range(n_trials):
        trial = data[trial_idx]
        start, end = find_mi_window(trial, sfreq, baseline_end_idx)
        aligned_trial = trial[start:end]
        aligned_trials.append(aligned_trial)

    min_len = min(len(t) for t in aligned_trials)
    aligned_trials = [t[:min_len] for t in aligned_trials]

    return np.array(aligned_trials)


def compute_subject_channel_average_temporal(data, condition, task, sfreq):
    """
    Compute average temporal pattern per channel for a subject-condition pair.

    Args:
        data: dict of condition -> numpy array (n_trials, n_channels, n_times)
        condition: condition label
        task: task name
        sfreq: sampling frequency

    Returns: array of shape (n_channels, n_times_aligned)
    """
    trial_data = data[condition]  # (n_trials, n_channels, n_times)
    n_channels = trial_data.shape[1]

    use_alignment = TASK_CONFIG[task]['use_alignment']

    channel_averages = []

    for ch_idx in range(n_channels):
        channel_data = trial_data[:, ch_idx, :]  # (n_trials, n_times)

        # Apply alignment only for MI task
        if use_alignment and task == 'MI':
            baseline_end_idx = int(1.0 * sfreq)
            aligned_data = align_trials_to_mi_onset(channel_data, sfreq, baseline_end_idx)
        else:
            aligned_data = channel_data

        # Average across trials
        avg_pattern = np.mean(aligned_data, axis=0)

        # Normalize
        avg_pattern = (avg_pattern - np.mean(avg_pattern)) / (np.std(avg_pattern) + 1e-10)
        channel_averages.append(avg_pattern)

    # Find the minimum length across all channels
    min_length = min(len(pattern) for pattern in channel_averages)

    # Truncate all patterns to the minimum length
    channel_averages_truncated = [pattern[:min_length] for pattern in channel_averages]

    return np.array(channel_averages_truncated)


def compute_subject_channel_average_frequency(data, condition, fmin, fmax, sfreq):
    """
    Compute average frequency pattern per channel for a subject-condition pair.

    Args:
        data: dict of condition -> numpy array (n_trials, n_channels, n_times)
        condition: condition label
        fmin, fmax: frequency band
        sfreq: sampling frequency

    Returns: array of shape (n_channels, n_freqs)
    """
    from scipy import signal

    trial_data = data[condition]  # (n_trials, n_channels, n_times)
    n_trials, n_channels, n_times = trial_data.shape

    # Compute PSD using Welch's method
    psds_all = []

    for ch_idx in range(n_channels):
        channel_data = trial_data[:, ch_idx, :]  # (n_trials, n_times)

        trial_psds = []
        for trial_idx in range(n_trials):
            freqs, psd = signal.welch(
                channel_data[trial_idx],
                fs=sfreq,
                nperseg=min(512, n_times),
                noverlap=None
            )

            # Filter to frequency band
            freq_mask = (freqs >= fmin) & (freqs <= fmax)
            trial_psds.append(psd[freq_mask])

        # Average across trials
        avg_psd = np.mean(trial_psds, axis=0)
        psds_all.append(avg_psd)

    psds_all = np.array(psds_all)  # (n_channels, n_freqs)

    # Normalize each channel
    for ch_idx in range(psds_all.shape[0]):
        psds_all[ch_idx] = (psds_all[ch_idx] - np.mean(psds_all[ch_idx])) / (np.std(psds_all[ch_idx]) + 1e-10)

    return psds_all

# ============================================================================
# ANALYSIS FUNCTIONS
# ============================================================================

def analyze_q3_channel_consistency(all_subject_data, task, sfreq):
    """
    Q3: Does the same channel for the same label look similar across subjects?

    For each condition and each channel, compare patterns across all subject pairs.
    """
    results = {
        'temporal': {},
        'frequency': {band: {} for band in freq_bands.keys()}
    }

    subject_list = list(all_subject_data.keys())

    # Get all conditions (use first subject as reference)
    first_subject = subject_list[0]
    conditions = list(all_subject_data[first_subject].keys())

    # Get number of channels (assume consistent across subjects)
    first_condition = conditions[0]
    n_channels = all_subject_data[first_subject][first_condition].shape[1]

    print("\n" + "="*80)
    print("ANALYZING CHANNEL CONSISTENCY ACROSS SUBJECTS")
    print("="*80)

    # Precompute all subject patterns for each condition
    print("\nPrecomputing subject patterns...")
    subject_patterns = {}

    for subject_id in subject_list:
        print(f"  Processing subject {subject_id}...")
        subject_patterns[subject_id] = {}

        for condition in conditions:
            print(f"Condition {condition}")
            if condition not in all_subject_data[subject_id]:
                continue

            # Temporal patterns
            temporal = compute_subject_channel_average_temporal(
                all_subject_data[subject_id], condition, task, sfreq
            )

            # Frequency patterns
            freq_patterns = {}
            for band_name, (fmin, fmax) in freq_bands.items():
                freq_patterns[band_name] = compute_subject_channel_average_frequency(
                    all_subject_data[subject_id], condition, fmin, fmax, sfreq
                )

            subject_patterns[subject_id][condition] = {
                'temporal': temporal,
                'frequency': freq_patterns
            }

    # Now compute correlations
    for condition in conditions:
        print(f"\nAnalyzing condition {condition}...")

        # Filter subjects that have this condition
        available_subjects = [s for s in subject_list if condition in subject_patterns[s]]

        if len(available_subjects) < 2:
            print(f"  Skipping condition {condition} - not enough subjects")
            continue

        # Temporal domain
        temporal_correlations = np.zeros((n_channels, len(available_subjects), len(available_subjects)))

        for i, subj1 in enumerate(available_subjects):
            for j, subj2 in enumerate(available_subjects):
                pattern1 = subject_patterns[subj1][condition]['temporal']
                pattern2 = subject_patterns[subj2][condition]['temporal']

                # Compare each channel
                for ch_idx in range(n_channels):
                    # Handle different lengths by truncating to shorter
                    len1, len2 = len(pattern1[ch_idx]), len(pattern2[ch_idx])
                    min_len = min(len1, len2)

                    if min_len > 10:  # Minimum reasonable length
                        corr = np.corrcoef(
                            pattern1[ch_idx][:min_len],
                            pattern2[ch_idx][:min_len]
                        )[0, 1]
                        temporal_correlations[ch_idx, i, j] = corr
                    else:
                        temporal_correlations[ch_idx, i, j] = 0

        # Extract upper triangle (unique pairs) for each channel
        channel_consistency_scores = []
        for ch_idx in range(n_channels):
            corr_matrix = temporal_correlations[ch_idx]
            triu_idx = np.triu_indices_from(corr_matrix, k=1)
            if len(triu_idx[0]) > 0:
                channel_consistency_scores.append(np.mean(corr_matrix[triu_idx]))
            else:
                channel_consistency_scores.append(0)

        results['temporal'][condition] = {
            'channel_scores': np.array(channel_consistency_scores),
            'channel_names': [f'Ch{i+1}' for i in range(n_channels)],
            'mean': np.mean(channel_consistency_scores),
            'std': np.std(channel_consistency_scores)
        }

        # Frequency domain
        for band_name, (fmin, fmax) in freq_bands.items():
            freq_correlations = np.zeros((n_channels, len(available_subjects), len(available_subjects)))

            for i, subj1 in enumerate(available_subjects):
                for j, subj2 in enumerate(available_subjects):
                    pattern1 = subject_patterns[subj1][condition]['frequency'][band_name]
                    pattern2 = subject_patterns[subj2][condition]['frequency'][band_name]

                    for ch_idx in range(n_channels):
                        corr = np.corrcoef(pattern1[ch_idx], pattern2[ch_idx])[0, 1]
                        freq_correlations[ch_idx, i, j] = corr

            channel_consistency_scores = []
            for ch_idx in range(n_channels):
                corr_matrix = freq_correlations[ch_idx]
                triu_idx = np.triu_indices_from(corr_matrix, k=1)
                if len(triu_idx[0]) > 0:
                    channel_consistency_scores.append(np.mean(corr_matrix[triu_idx]))
                else:
                    channel_consistency_scores.append(0)

            results['frequency'][band_name][condition] = {
                'channel_scores': np.array(channel_consistency_scores),
                'mean': np.mean(channel_consistency_scores),
                'std': np.std(channel_consistency_scores)
            }

    return results


def analyze_q4_label_consistency(all_subject_data, task, sfreq):
    """
    Q4: Are label representations consistent across subjects?

    For each subject pair, compare how similar the same label is.
    """
    results = {
        'temporal': {},
        'frequency': {band: {} for band in freq_bands.keys()}
    }

    subject_list = list(all_subject_data.keys())

    # Get all conditions
    first_subject = subject_list[0]
    conditions = list(all_subject_data[first_subject].keys())

    print("\n" + "="*80)
    print("ANALYZING LABEL CONSISTENCY ACROSS SUBJECTS")
    print("="*80)

    # Precompute all subject patterns
    print("\nPrecomputing subject patterns...")
    subject_patterns = {}

    for subject_id in subject_list:
        subject_patterns[subject_id] = {}

        for condition in conditions:
            if condition not in all_subject_data[subject_id]:
                continue

            # Temporal patterns
            temporal = compute_subject_channel_average_temporal(
                all_subject_data[subject_id], condition, task, sfreq
            )

            # Frequency patterns
            freq_patterns = {}
            for band_name, (fmin, fmax) in freq_bands.items():
                freq_patterns[band_name] = compute_subject_channel_average_frequency(
                    all_subject_data[subject_id], condition, fmin, fmax, sfreq
                )

            subject_patterns[subject_id][condition] = {
                'temporal': temporal,
                'frequency': freq_patterns
            }
            print('next')

    # Compute label consistency
    for condition in conditions:
        print(f"\nAnalyzing condition {condition}...")

        # Filter subjects that have this condition
        available_subjects = [s for s in subject_list if condition in subject_patterns[s]]

        if len(available_subjects) < 2:
            print(f"  Skipping condition {condition} - not enough subjects")
            continue

        # Temporal domain: average across all channels for whole-brain representation
        pairwise_correlations = []

        for subj1, subj2 in combinations(available_subjects, 2):
            pattern1 = subject_patterns[subj1][condition]['temporal']
            pattern2 = subject_patterns[subj2][condition]['temporal']

            # Flatten all channels and compute correlation
            pattern1_flat = pattern1.flatten()
            pattern2_flat = pattern2.flatten()

            # Handle different lengths
            min_len = min(len(pattern1_flat), len(pattern2_flat))

            if min_len > 100:  # Reasonable minimum
                corr = np.corrcoef(
                    pattern1_flat[:min_len],
                    pattern2_flat[:min_len]
                )[0, 1]
                pairwise_correlations.append(corr)

        if pairwise_correlations:
            results['temporal'][condition] = {
                'correlations': np.array(pairwise_correlations),
                'mean': np.mean(pairwise_correlations),
                'std': np.std(pairwise_correlations),
                'median': np.median(pairwise_correlations)
            }

        # Frequency domain
        for band_name, (fmin, fmax) in freq_bands.items():
            pairwise_correlations = []

            for subj1, subj2 in combinations(available_subjects, 2):
                pattern1 = subject_patterns[subj1][condition]['frequency'][band_name]
                pattern2 = subject_patterns[subj2][condition]['frequency'][band_name]

                pattern1_flat = pattern1.flatten()
                pattern2_flat = pattern2.flatten()

                corr = np.corrcoef(pattern1_flat, pattern2_flat)[0, 1]
                pairwise_correlations.append(corr)

            if pairwise_correlations:
                results['frequency'][band_name][condition] = {
                    'correlations': np.array(pairwise_correlations),
                    'mean': np.mean(pairwise_correlations),
                    'std': np.std(pairwise_correlations),
                    'median': np.median(pairwise_correlations)
                }

    return results

# ============================================================================
# VISUALIZATION FUNCTIONS
# ============================================================================

def plot_q3_results(q3_results, save_path, task):
    """Visualize Q3: channel consistency across subjects."""
    conditions = list(q3_results['temporal'].keys())
    n_conditions = len(conditions)

    # Create figure with appropriate size
    n_rows = min(3, (n_conditions + 1) // 2 + 1)
    fig, axes = plt.subplots(n_rows, 3, figsize=(18, 6*n_rows))
    axes = axes.flatten()

    # Plot temporal consistency per channel for each condition
    for idx, condition in enumerate(conditions[:6]):  # Limit to 6 conditions for space
        if idx >= len(axes) - 3:
            break

        ax = axes[idx]
        scores = q3_results['temporal'][condition]['channel_scores']
        channel_names = q3_results['temporal'][condition]['channel_names']

        x = np.arange(len(scores))
        bars = ax.bar(x, scores, alpha=0.7, edgecolor='black', linewidth=1)

        # Color code by performance
        colors = ['green' if s > 0.7 else 'orange' if s > 0.5 else 'red' for s in scores]
        for bar, color in zip(bars, colors):
            bar.set_color(color)

        ax.axhline(0.7, color='green', linestyle='--', alpha=0.5, linewidth=2, label='High')
        ax.axhline(0.5, color='orange', linestyle='--', alpha=0.5, linewidth=2, label='Medium')
        ax.set_xlabel('Channel', fontsize=10)
        ax.set_ylabel('Mean Correlation', fontsize=10)
        ax.set_title(f'Condition {condition} (Temporal)', fontweight='bold', fontsize=11)
        ax.set_xticks(x[::max(1, len(x)//10)])
        ax.set_xticklabels([channel_names[i] for i in x[::max(1, len(x)//10)]], rotation=45, ha='right')
        ax.set_ylim([0, 1])
        ax.grid(True, alpha=0.3, axis='y')
        if idx == 0:
            ax.legend(fontsize=8)

    # Frequency band comparison
    n_conditions_shown = min(len(conditions), 6)  # Only show up to 6 conditions
    if len(axes) > n_conditions_shown:
        ax = axes[n_conditions_shown]  # Use the next available subplot
        band_names = list(freq_bands.keys())

        # Show frequency bands for first 4 displayed conditions
        for condition in conditions[:4]:
            if condition in q3_results['frequency'][band_names[0]]:  # Check if data exists
                band_scores = [q3_results['frequency'][band][condition]['mean']
                              for band in band_names if condition in q3_results['frequency'][band]]
                if len(band_scores) == len(band_names):  # Only plot if we have all bands
                    ax.plot(band_names, band_scores, marker='o', label=f'C{condition}', linewidth=2, markersize=8)

        ax.axhline(0.7, color='green', linestyle='--', alpha=0.3, linewidth=2)
        ax.axhline(0.5, color='orange', linestyle='--', alpha=0.3, linewidth=2)
        ax.set_xlabel('Frequency Band', fontsize=10)
        ax.set_ylabel('Mean Correlation', fontsize=10)
        ax.set_title('Frequency Band Consistency', fontweight='bold', fontsize=11)
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)
        ax.set_ylim([0, 1])

    # Statistical summary
    if len(axes) > n_conditions_shown + 1:
        ax = axes[n_conditions_shown + 1]
        ax.axis('off')

        summary_text = "STATISTICAL SUMMARY\n" + "="*40 + "\n\n"
        summary_text += "Temporal Domain:\n"
        for condition in conditions[:8]:  # Limit display
            mean = q3_results['temporal'][condition]['mean']
            std = q3_results['temporal'][condition]['std']
            summary_text += f"  C{condition}: μ={mean:.3f}, σ={std:.3f}\n"

        ax.text(0.1, 0.5, summary_text, fontsize=9, family='monospace',
                verticalalignment='center', transform=ax.transAxes)

    # Hide unused axes
    for idx in range(n_conditions_shown + 2, len(axes)):
        axes[idx].axis('off')

    plt.suptitle(f'Q3: Channel Consistency Across Subjects - {task} Task',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Saved Q3 plot: {save_path}")


def plot_q4_results(q4_results, save_path, task):
    """Visualize Q4: label consistency across subjects."""
    conditions = list(q4_results['temporal'].keys())
    band_names = list(freq_bands.keys())

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # Plot 1: Temporal consistency per condition
    ax = axes[0, 0]
    means = [q4_results['temporal'][c]['mean'] for c in conditions]
    stds = [q4_results['temporal'][c]['std'] for c in conditions]

    x = np.arange(len(conditions))
    bars = ax.bar(x, means, yerr=stds, capsize=5, alpha=0.7, edgecolor='black', linewidth=1)

    colors = ['green' if m > 0.6 else 'orange' if m > 0.3 else 'red' for m in means]
    for bar, color in zip(bars, colors):
        bar.set_color(color)

    ax.axhline(0.6, color='green', linestyle='--', alpha=0.3, label='High')
    ax.axhline(0.3, color='orange', linestyle='--', alpha=0.3, label='Medium')
    ax.set_xlabel('Condition', fontsize=10)
    ax.set_ylabel('Mean Correlation', fontsize=10)
    ax.set_title('Temporal Label Consistency', fontweight='bold', fontsize=11)
    ax.set_xticks(x)
    ax.set_xticklabels([f'C{c}' for c in conditions])
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3, axis='y')
    ax.set_ylim([0, 1])

    # Plot 2: Distribution of correlations
    ax = axes[0, 1]
    all_corrs = [q4_results['temporal'][c]['correlations'] for c in conditions]
    bp = ax.boxplot(all_corrs, labels=[f'C{c}' for c in conditions], patch_artist=True)
    for patch in bp['boxes']:
        patch.set_facecolor('lightblue')
    ax.axhline(0.6, color='green', linestyle='--', alpha=0.3)
    ax.axhline(0.3, color='orange', linestyle='--', alpha=0.3)
    ax.set_xlabel('Condition', fontsize=10)
    ax.set_ylabel('Correlation', fontsize=10)
    ax.set_title('Distribution of Pairwise Correlations', fontweight='bold', fontsize=11)
    ax.grid(True, alpha=0.3, axis='y')

    # Plot 3: Mean vs Median
    ax = axes[0, 2]
    means = [q4_results['temporal'][c]['mean'] for c in conditions]
    medians = [q4_results['temporal'][c]['median'] for c in conditions]

    x = np.arange(len(conditions))
    width = 0.35
    ax.bar(x - width/2, means, width, label='Mean', alpha=0.7)
    ax.bar(x + width/2, medians, width, label='Median', alpha=0.7)
    ax.axhline(0.6, color='green', linestyle='--', alpha=0.3)
    ax.axhline(0.3, color='orange', linestyle='--', alpha=0.3)
    ax.set_xticks(x)
    ax.set_xticklabels([f'C{c}' for c in conditions])
    ax.set_ylabel('Correlation', fontsize=10)
    ax.set_title('Mean vs Median Consistency', fontweight='bold', fontsize=11)
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')

    # Plot 4: Frequency band comparison
    ax = axes[1, 0]
    x = np.arange(len(band_names))
    width = 0.8 / len(conditions)

    for i, condition in enumerate(conditions[:8]):  # Limit to 8 for visibility
        if condition not in q4_results['frequency'][band_names[0]]:
            continue
        band_scores = [q4_results['frequency'][band][condition]['mean']
                      for band in band_names]
        ax.bar(x + i*width, band_scores, width, label=f'C{condition}', alpha=0.7)

    ax.axhline(0.6, color='green', linestyle='--', alpha=0.3)
    ax.axhline(0.3, color='orange', linestyle='--', alpha=0.3)
    ax.set_xlabel('Frequency Band', fontsize=10)
    ax.set_ylabel('Mean Correlation', fontsize=10)
    ax.set_title('Frequency Band Label Consistency', fontweight='bold', fontsize=11)
    ax.set_xticks(x + width * (len(conditions[:8])-1) / 2)
    ax.set_xticklabels(band_names)
    ax.legend(fontsize=8, ncol=2)
    ax.grid(True, alpha=0.3, axis='y')

    # Plot 5: Temporal vs best frequency band
    ax = axes[1, 1]
    temporal_means = [q4_results['temporal'][c]['mean'] for c in conditions]

    # Find best band for each condition
    best_freq_means = []
    for condition in conditions:
        band_scores = {band: q4_results['frequency'][band][condition]['mean']
                      for band in band_names if condition in q4_results['frequency'][band]}
        if band_scores:
            best_freq_means.append(max(band_scores.values()))
        else:
            best_freq_means.append(0)

    x = np.arange(len(conditions))
    width = 0.35
    ax.bar(x - width/2, temporal_means, width, label='Temporal', alpha=0.7)
    ax.bar(x + width/2, best_freq_means, width, label='Best Frequency', alpha=0.7)
    ax.axhline(0.6, color='green', linestyle='--', alpha=0.3)
    ax.axhline(0.3, color='orange', linestyle='--', alpha=0.3)
    ax.set_xticks(x)
    ax.set_xticklabels([f'C{c}' for c in conditions])
    ax.set_ylabel('Mean Correlation', fontsize=10)
    ax.set_title('Temporal vs Frequency Consistency', fontweight='bold', fontsize=11)
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')

    # Plot 6: Statistical summary
    ax = axes[1, 2]
    ax.axis('off')

    summary_text = "STATISTICAL SUMMARY\n" + "="*40 + "\n\n"
    summary_text += "Temporal Domain:\n"
    for condition in conditions[:8]:  # Limit display
        mean = q4_results['temporal'][condition]['mean']
        median = q4_results['temporal'][condition]['median']
        summary_text += f"  C{condition}: μ={mean:.3f}, med={median:.3f}\n"

    summary_text += "\nBest Frequency Band:\n"
    for condition in conditions[:8]:
        band_scores = {band: q4_results['frequency'][band][condition]['mean']
                      for band in band_names if condition in q4_results['frequency'][band]}
        if band_scores:
            best_band = max(band_scores, key=band_scores.get)
            best_score = band_scores[best_band]
            summary_text += f"  C{condition}: {best_band} ({best_score:.3f})\n"

    ax.text(0.1, 0.5, summary_text, fontsize=9, family='monospace',
            verticalalignment='center', transform=ax.transAxes)

    plt.suptitle(f'Q4: Label Consistency Across Subjects - {task} Task',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Saved Q4 plot: {save_path}")

# ============================================================================
# REPORTING FUNCTIONS
# ============================================================================

def create_summary_report(q3_results, q4_results, task):
    """Generate comprehensive text report."""
    print("\n" + "="*80)
    print(f"CROSS-SUBJECT CONSISTENCY ANALYSIS SUMMARY - {task} TASK")
    print("="*80)

    conditions = list(q3_results['temporal'].keys())

    # Q3 Summary
    print("\n" + "-"*80)
    print("Q3: CHANNEL CONSISTENCY ACROSS SUBJECTS")
    print("-"*80)
    print("\nDoes the same channel for the same label look similar across subjects?")
    print("\nTemporal Domain (per condition):")

    for condition in conditions:
        mean = q3_results['temporal'][condition]['mean']
        std = q3_results['temporal'][condition]['std']
        print(f"  Condition {condition:3d}: {mean:.3f} ± {std:.3f}")

    overall_temporal_mean = np.mean([q3_results['temporal'][c]['mean'] for c in conditions])
    print(f"\n  Overall:         {overall_temporal_mean:.3f}")

    print("\nFrequency Domain (averaged across conditions):")
    for band in freq_bands.keys():
        band_means = [q3_results['frequency'][band][c]['mean']
                     for c in conditions if c in q3_results['frequency'][band]]
        if band_means:
            print(f"  {band:8s}: {np.mean(band_means):.3f}")

    # Q4 Summary
    print("\n" + "-"*80)
    print("Q4: LABEL CONSISTENCY ACROSS SUBJECTS")
    print("-"*80)
    print("\nAre label representations consistent across subjects?")
    print("\nTemporal Domain (per condition):")

    for condition in conditions:
        if condition in q4_results['temporal']:
            mean = q4_results['temporal'][condition]['mean']
            median = q4_results['temporal'][condition]['median']
            std = q4_results['temporal'][condition]['std']
            print(f"  Condition {condition:3d}: μ={mean:.3f}, med={median:.3f}, σ={std:.3f}")

    valid_conditions = [c for c in conditions if c in q4_results['temporal']]
    if valid_conditions:
        overall_q4_temporal = np.mean([q4_results['temporal'][c]['mean'] for c in valid_conditions])
        print(f"\n  Overall:         {overall_q4_temporal:.3f}")

    print("\nFrequency Domain (best band per condition):")
    for condition in conditions:
        band_scores = {band: q4_results['frequency'][band][condition]['mean']
                      for band in freq_bands.keys()
                      if condition in q4_results['frequency'][band]}
        if band_scores:
            best_band = max(band_scores, key=band_scores.get)
            best_score = band_scores[best_band]
            print(f"  Condition {condition:3d}: {best_band} ({best_score:.3f})")


def save_results_to_csv(q3_results, q4_results, output_dir):
    """Save results to CSV files."""
    conditions = list(q3_results['temporal'].keys())

    # Q3 CSV - Channel consistency
    q3_rows = []
    for condition in conditions:
        channel_names = q3_results['temporal'][condition]['channel_names']
        channel_scores = q3_results['temporal'][condition]['channel_scores']
        for ch_name, score in zip(channel_names, channel_scores):
            q3_rows.append({
                'condition': condition,
                'channel': ch_name,
                'temporal_correlation': score
            })

    if q3_rows:
        q3_df = pd.DataFrame(q3_rows)
        q3_df.to_csv(output_dir / 'q3_channel_consistency.csv', index=False)
        print(f"\nSaved Q3 results to: {output_dir / 'q3_channel_consistency.csv'}")

    # Q4 CSV - Label consistency
    q4_rows = []
    for condition in conditions:
        if condition not in q4_results['temporal']:
            continue

        row = {
            'condition': condition,
            'temporal_mean': q4_results['temporal'][condition]['mean'],
            'temporal_std': q4_results['temporal'][condition]['std'],
            'temporal_median': q4_results['temporal'][condition]['median']
        }

        # Add frequency bands
        for band in freq_bands.keys():
            if condition in q4_results['frequency'][band]:
                row[f'{band}_mean'] = q4_results['frequency'][band][condition]['mean']
                row[f'{band}_std'] = q4_results['frequency'][band][condition]['std']

        q4_rows.append(row)

    if q4_rows:
        q4_df = pd.DataFrame(q4_rows)
        q4_df.to_csv(output_dir / 'q4_label_consistency.csv', index=False)
        print(f"Saved Q4 results to: {output_dir / 'q4_label_consistency.csv'}")

# ============================================================================
# MAIN EXECUTION
# ============================================================================

if __name__ == "__main__":
    print(f"\n{'='*80}")
    print(f"Cross-Subject Consistency Analysis - {TASK} Task")
    print(f"{'='*80}\n")

    # Get task configuration
    sfreq = TASK_CONFIG[TASK]['sfreq']

    # Load ALL data for all subjects
    root_dir = DATA_PATHS[TASK]
    print(f"Loading ALL {TASK} data from: {root_dir}")

    kwargs = {}
    if TASK == 'MI':
        kwargs['label_dir'] = None  # Set your label directory if needed

    subject_data = load_all_subjects_for_task(TASK, root_dir, **kwargs)

    print(f"\n{'='*40}")
    print(f"Found {len(subject_data)} subjects")
    for sid, conds in subject_data.items():
        print(f"  Subject {sid}: {len(conds)} conditions, ", end='')
        total_trials = sum(trials.shape[0] for trials in conds.values())
        print(f"{total_trials} total trials")
    print(f"{'='*40}\n")

    # # Analyze Q3: Channel consistency
    # print("\n" + "="*80)
    # print("STARTING Q3 ANALYSIS")
    # print("="*80)
    # q3_results = analyze_q3_channel_consistency(subject_data, TASK, sfreq)
    # plot_q3_results(q3_results, OUTPUT_DIR / 'q3_channel_consistency.png', TASK)

    # Analyze Q4: Label consistency
    print("\n" + "="*80)
    print("STARTING Q4 ANALYSIS")
    print("="*80)
    q4_results = analyze_q4_label_consistency(subject_data, TASK, sfreq)
    plot_q4_results(q4_results, OUTPUT_DIR / 'q4_label_consistency.png', TASK)

    # Generate report
    create_summary_report(q3_results, q4_results, TASK)

    # Save results to CSV
    save_results_to_csv(q3_results, q4_results, OUTPUT_DIR)

    print(f"\n{'='*80}")
    print(f"Analysis complete! Results saved to: {OUTPUT_DIR}")
    print(f"{'='*80}\n")


Cross-Subject Consistency Analysis - P300 Task

Loading ALL P300 data from: /content/drive/MyDrive/IDL/IDL Project Team 5 F25/dataset/p300/bi2015a/cleaned_data
Loaded Subject 1: 4956 trials
Loaded Subject 2: 1512 trials
Loaded Subject 3: 1044 trials
Loaded Subject 4: 1440 trials
Loaded Subject 5: 1188 trials
Loaded Subject 6: 1584 trials
Loaded Subject 7: 1224 trials
Loaded Subject 8: 1512 trials
Loaded Subject 9: 1332 trials
Loaded Subject 10: 1116 trials
Loaded Subject 11: 1440 trials
Loaded Subject 12: 2232 trials
Loaded Subject 13: 1368 trials
Loaded Subject 14: 1404 trials
Loaded Subject 15: 1260 trials
Loaded Subject 16: 1260 trials
Loaded Subject 17: 1620 trials
Loaded Subject 18: 1620 trials
Loaded Subject 19: 1800 trials
Loaded Subject 20: 1620 trials
Loaded Subject 21: 1260 trials
Loaded Subject 22: 1584 trials
Loaded Subject 23: 1368 trials
Loaded Subject 24: 1296 trials
Loaded Subject 25: 2088 trials
Loaded Subject 26: 1116 trials
Loaded Subject 27: 2844 trials
Loaded Subj

  bp = ax.boxplot(all_corrs, labels=[f'C{c}' for c in conditions], patch_artist=True)


Saved Q4 plot: /content/drive/MyDrive/IDL/IDL Project Team 5 F25/data analysis/P300/cross_subject_P300/q4_label_consistency.png


NameError: name 'q3_results' is not defined