<a href="https://colab.research.google.com/github/fat-91/Binary-Driver-Attentive-States/blob/main/ERP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install mne numpy pandas scikit-learn tensorflow matplotlib seaborn
!pip install scikit-learn

In [None]:
import mne
import numpy as np
import pandas as pd
import os
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, confusion_matrix, roc_auc_score, roc_curve
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Conv1D, BatchNormalization, ReLU, MaxPooling1D, Flatten, Dropout, Dense
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.utils import class_weight
from sklearn.metrics import f1_score, confusion_matrix, classification_report
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Integer event codes of the two "deviation" markers that open a trial.
# NOTE(review): assumes these match the IDs assigned by
# mne.events_from_annotations() for this dataset — TODO confirm against the
# `event_id` mapping it returns.
DEVIATION_MARKER_IDS = {'dev_left': 1, 'dev_right': 2}
# Integer event code of the response marker used to measure reaction time.
RESPONSE_MARKER_ID = {'response': 3}
# Epoch window handed to mne.Epochs as tmin / tmax.
TMIN, TMAX = -0.2, 0.8
# Channels kept from each recording; names absent from a file are skipped.
channels = ['FZ', 'CZ', 'PZ', 'P3', 'P4']
# Percentile of the reaction-time distribution above which a trial is labelled 1.
RT_THRESHOLD_PERCENTILE = 80

In [None]:
# 1. Mount Google Drive
# Mounts the Drive at /content/drive; the data directory used by the loading
# cell lives below this mount point. Requires the google.colab package.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
base_path = '/content/drive/My Drive/data'

# `file_list` is not defined anywhere in this notebook. Use a list defined in
# an earlier cell if there is one; otherwise fall back to every EEGLAB .set
# file in `base_path` (sorted so the concatenation order is reproducible).
try:
    file_list
except NameError:
    file_list = sorted(
        fname for fname in os.listdir(base_path) if fname.lower().endswith('.set')
    )
if not file_list:
    raise FileNotFoundError(f"No EEGLAB .set files found in {base_path!r}")

full_paths = [os.path.join(base_path, fname) for fname in file_list]
raws = []
channel_map = {}
for fpath in full_paths:
    # Read the EEG file (EEGLAB .set format)
    raw = mne.io.read_raw_eeglab(fpath, preload=True)
    # Record the full channel list now, so the file is not read a second time.
    channel_map[fpath] = list(raw.ch_names)
    common = [ch for ch in channels if ch in raw.ch_names]
    missing = [ch for ch in channels if ch not in raw.ch_names]
    if missing:
        print(f"Warning: {fpath} lacks requested channels {missing}")
    raw.pick(common)
    # Force the order of `channels` so every file (and every array derived
    # from the concatenated recording) shares one channel layout.
    raw.reorder_channels(common)
    raws.append(raw)

raw = mne.concatenate_raws(raws)

# Extract the raw data as a NumPy array once and reuse it below.
data = raw.get_data()
print("Combined data shape (channels, timepoints):", data.shape)
sfreq = raw.info['sfreq']
# NOTE(review): the integer codes in `events` are whatever `event_id` maps the
# annotation descriptions to; the hard-coded marker IDs used later are assumed
# to agree with that mapping — TODO confirm.
events, event_id = mne.events_from_annotations(raw)
print(sfreq)

for fname, chs in channel_map.items():
    print(f"{fname} → {chs}")


selected_channels = raw.ch_names
print("\nSelected Channels after filtering:", selected_channels)

print("\nSnippet of Extracted Data (first 5 time points for each channel):")
for i, ch_name in enumerate(selected_channels):
    print(f"{ch_name}: {data[i, :5]}")

In [None]:
def get_reaction_times(events, sfreq):
    """Compute the reaction time that follows every deviation event.

    For each row of ``events`` whose code (column 2) is one of
    ``DEVIATION_MARKER_IDS``, the reaction time is the sample distance to the
    first later row whose code is in ``RESPONSE_MARKER_ID``, divided by
    ``sfreq``.

    Parameters
    ----------
    events : array-like, shape (n_events, 3)
        Event array; column 0 holds sample indices, column 2 event codes.
    sfreq : float
        Sampling frequency used to convert samples to seconds.

    Returns
    -------
    pandas.DataFrame
        One row per deviation event with columns ``event_index`` (row index
        into ``events``) and ``rt`` (seconds, NaN when no response follows).
        Both columns exist even when there are no deviation events, so callers
        can safely run ``dropna(subset=['rt'])``.

    Notes
    -----
    The "next response" is simply the next response marker in the array, even
    if another deviation occurs in between.
    """
    events = np.asarray(events).reshape(-1, 3)
    codes = events[:, 2]

    deviation_indices = np.flatnonzero(
        np.isin(codes, list(DEVIATION_MARKER_IDS.values()))
    )
    response_indices = np.flatnonzero(
        np.isin(codes, list(RESPONSE_MARKER_ID.values()))
    )

    # Both index arrays are sorted, so one binary search per deviation finds
    # the first response with a larger row index (replaces rescanning the tail
    # of `events` for every deviation).
    next_pos = np.searchsorted(response_indices, deviation_indices, side='right')
    has_response = next_pos < len(response_indices)

    rt = np.full(len(deviation_indices), np.nan)
    if has_response.any():
        response_rows = response_indices[next_pos[has_response]]
        deviation_rows = deviation_indices[has_response]
        rt[has_response] = (events[response_rows, 0] - events[deviation_rows, 0]) / sfreq

    return pd.DataFrame(
        {'event_index': deviation_indices, 'rt': rt},
        columns=['event_index', 'rt'],
    )

def create_labels(rt_df, threshold_percentile):
    """Label each trial as slow (1) or fast (0) from its reaction time.

    Parameters
    ----------
    rt_df : pandas.DataFrame
        Must contain an ``rt`` column. A ``label`` column is added in place.
    threshold_percentile : float
        Percentile (0-100) of the non-missing reaction times used as the
        slow/fast cut-off.

    Returns
    -------
    (pandas.DataFrame, float)
        ``rt_df`` with the new ``label`` column, and the threshold in the
        units of ``rt``. The threshold and all labels are NaN when there is
        no valid reaction time.

    Notes
    -----
    ``label`` is 1 where ``rt > threshold`` and 0 otherwise. Rows with a
    missing ``rt`` get a NaN label rather than 0, because ``NaN > threshold``
    is False and would otherwise mark unanswered trials as "fast".
    """
    valid_rts = rt_df['rt'].dropna()
    if valid_rts.empty:
        rt_df['label'] = np.nan
        return rt_df, np.nan

    threshold = np.percentile(valid_rts, threshold_percentile)
    is_slow = rt_df['rt'] > threshold
    if rt_df['rt'].isna().any():
        # Float column so the missing trials can carry NaN.
        rt_df['label'] = is_slow.astype(float).where(rt_df['rt'].notna())
    else:
        rt_df['label'] = is_slow.astype('int64')
    return rt_df, threshold

In [None]:
rts = get_reaction_times(events, sfreq)
if rts.empty:
    raise ValueError(
        "No deviation events found; check DEVIATION_MARKER_IDS against the "
        "event codes present in `events`."
    )
# .copy() makes the filtered table an independent frame, so the `label` column
# added by create_labels is not written onto a derived selection.
rts = rts.dropna(subset=['rt']).copy()
if rts.empty:
    raise ValueError(
        "No deviation event is followed by a response; reaction times cannot "
        "be computed."
    )
rts, rt_thresh = create_labels(rts, RT_THRESHOLD_PERCENTILE)

# --- Create Labeled Events and Metadata ---
valid_event_indices = rts['event_index'].values
labeled_events = events[valid_event_indices]
metadata = pd.DataFrame({'label': rts['label'].values})

# --- Epoching ---
# Only request deviation codes that actually occur among the labelled events,
# so a recording with one deviation type does not fail on the missing code.
present_codes = set(labeled_events[:, 2].tolist())
epochs = mne.Epochs(
    raw, labeled_events,
    event_id={str(v): v for v in DEVIATION_MARKER_IDS.values() if v in present_codes},
    tmin=TMIN, tmax=TMAX, metadata=metadata,
    baseline=None, preload=True)

# --- Final Dataset ---
# Labels are read back from the epochs' metadata so they stay aligned with the
# epochs MNE actually kept.
X = epochs.get_data()
y = epochs.metadata['label'].values.astype(int)
# (n_epochs, n_channels, n_times) -> (n_epochs, n_times, n_channels) for Conv1D.
X = np.transpose(X, (0, 2, 1))

# --- Optimization Strategies ---
# 1. Data Augmentation (Simple Time Warping - careful with EEG)
def time_warp(data, sigma=0.1):
    """Randomly warp the time axis of every (sample, channel) trace.

    Each trace gets its own smooth distortion: random positive step sizes are
    accumulated into a warped time axis spanning exactly [0, 1], a cubic
    spline is fitted through the trace on that axis, and the spline is
    resampled on a uniform grid.

    Parameters
    ----------
    data : numpy.ndarray, shape (n_samples, n_times, n_channels)
        Input traces.
    sigma : float, default 0.1
        Standard deviation of the per-step warp factor (mean 1.0). ``0``
        reproduces the input.

    Returns
    -------
    numpy.ndarray
        Warped copy of ``data`` with the same shape and dtype.

    Notes
    -----
    Uses NumPy's global random state; seed it with ``np.random.seed`` for
    reproducible output. Every channel is warped independently.
    """
    from scipy.interpolate import CubicSpline

    n_samples, n_times, n_channels = data.shape
    if n_times < 2:
        # Nothing to warp, and a spline needs at least two knots.
        return data.copy()

    new_data = np.zeros_like(data)
    tt_new = np.linspace(0, 1, n_times)  # uniform output grid, loop-invariant
    for i in range(n_samples):
        for j in range(n_channels):
            warp_factor = np.random.normal(loc=1.0, scale=sigma, size=n_times)
            # Steps must be strictly positive, otherwise the cumulative axis is
            # not strictly increasing and CubicSpline rejects it.
            warp_factor = np.clip(warp_factor, 1e-3, None)
            tt_stretched = np.cumsum(warp_factor)
            # Map the warped axis onto exactly [0, 1] so the first and last
            # samples stay fixed and resampling never extrapolates.
            tt_stretched = (tt_stretched - tt_stretched[0]) / (
                tt_stretched[-1] - tt_stretched[0]
            )
            f = CubicSpline(tt_stretched, data[i, :, j])
            new_data[i, :, j] = f(tt_new)
    return new_data

# Split first, then augment ONLY the training portion: warping the whole
# dataset before the split would also distort the held-out test data.
X_train_raw, X_test_aug, y_train_aug, y_test_aug = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)
X_train_aug = time_warp(X_train_raw)

In [None]:
# Summarise the epochs object and choose the channel groups used by the
# ERP (N200 / P300) and theta analyses.
try:
    print(f"Epochs object created with {len(epochs)} epochs.")
    print(f"Metadata labels distribution:\n{epochs.metadata['label'].value_counts()}")
    print(f"Available channels in epochs: {epochs.ch_names}")
except NameError:
    print("Error: The 'epochs' object does not seem to exist.")
    print("Please make sure you have successfully run the previous code sections.")
    # Re-raise instead of calling exit(): in a notebook exit() shuts the
    # kernel down and discards all state, whereas raising just stops this cell.
    raise

available_channels = epochs.ch_names
picks_n200 = [ch for ch in ['FZ', 'CZ'] if ch in available_channels]
picks_p300 = [ch for ch in ['PZ', 'P3', 'P4'] if ch in available_channels]
theta_channel_names = ['FZ', 'CZ'] # Check spelling/case!
picks_theta = [ch for ch in theta_channel_names if ch in available_channels]

print(f"\nUsing channels for N200: {picks_n200}")
print(f"Using channels for P300: {picks_p300}")
print(f"Using channels for Theta: {picks_theta}")

# Check if any pick lists are empty
if not picks_n200:
    print("Warning: No suitable channels found for N200 analysis.")
if not picks_p300:
    print("Warning: No suitable channels found for P300 analysis.")
if not picks_theta:
    print("Warning: No suitable channels found for Theta analysis.")

# NOTE: this only defines the baseline window; no correction is applied here.
baseline_period = (TMIN, 0) # (-0.2, 0) seconds
print(f"Applying baseline correction: {baseline_period}")


In [None]:
# Partition the epochs by the RT class stored in their metadata.
label_column = epochs.metadata['label']
epochs_fast = epochs[label_column == 0]
epochs_slow = epochs[label_column == 1]

n_fast, n_slow = len(epochs_fast), len(epochs_slow)
print(f"\nNumber of fast RT epochs (label 0): {n_fast}")
print(f"Number of slow RT epochs (label 1): {n_slow}")


In [None]:
# Report the epoched data dimensions and the channel names kept in the epochs.
epochs_shape = epochs.get_data().shape
print("Epochs data shape:", epochs_shape)
print("Channels used:", epochs.info['ch_names'])


In [None]:
import numpy as np

# Sanity-check the training arrays for non-finite values, then show the class
# counts as the cell's output.
for array_name, array in (("X_train_aug", X_train_aug), ("y_train_aug", y_train_aug)):
    print(f"NaNs in {array_name}:", np.isnan(array).sum())
    print(f"Infs in {array_name}:", np.isinf(array).sum())

np.unique(y_train_aug, return_counts=True)

In [None]:
# Class counts of the train and test label vectors.
for split_name, split_labels in (("y_train", y_train_aug), ("y_test", y_test_aug)):
    print(f"{split_name} labels:", np.unique(split_labels, return_counts=True))

In [None]:
# Hold out the test set BEFORE cross-validation. Splitting after the folds
# were trained on all of X would evaluate the selected model on samples it
# has already seen. Same split parameters as before, so the partition itself
# is unchanged.
X_train_full, X_test_full, y_train_full, y_test_full = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

n_splits = 5
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
all_histories = []
all_val_aucs = []
best_model = None
best_auc = -1
# Normalisation statistics of the fold that produced `best_model`; inputs fed
# to that model must be scaled with exactly these values.
X_mean_full = None
X_std_full = None

for fold, (train_index, val_index) in enumerate(skf.split(X_train_full, y_train_full)):
    print(f"\n--- Fold {fold + 1}/{n_splits} ---")

    X_train, X_val = X_train_full[train_index], X_train_full[val_index]
    y_train, y_val = y_train_full[train_index], y_train_full[val_index]

    # Normalize per fold (scalar mean/std of this fold's training data only)
    X_mean_fold = X_train.mean()
    X_std_fold = X_train.std()
    X_train_norm_fold = (X_train - X_mean_fold) / X_std_fold
    X_val_norm_fold   = (X_val   - X_mean_fold) / X_std_fold

    # Calculate class weights per fold, keyed by the class value itself rather
    # than by its position in the sorted class list.
    fold_classes = np.unique(y_train)
    class_weights_fold = class_weight.compute_class_weight(
        class_weight='balanced',
        classes=fold_classes,
        y=y_train
    )
    class_weights_dict_fold = {
        int(cls): float(weight) for cls, weight in zip(fold_classes, class_weights_fold)
    }
    print("Class Weights (Fold {}):".format(fold + 1), class_weights_dict_fold)
    model = Sequential([
        # Input layer with shape as per fold
        Input(shape=(X_train_norm_fold.shape[1], X_train_norm_fold.shape[2])),

        # First convolution block
        Conv1D(32, kernel_size=7, padding='same', activation='relu', kernel_initializer='he_normal'),
        BatchNormalization(),
        MaxPooling1D(pool_size=2),
        Dropout(0.3),

        # Second convolution block
        Conv1D(64, kernel_size=5, padding='same', activation='relu', kernel_initializer='he_normal'),
        BatchNormalization(),
        MaxPooling1D(pool_size=2),
        Dropout(0.3),

        # Third convolution block - optional if your data supports extra complexity
        Conv1D(128, kernel_size=3, padding='same', activation='relu', kernel_initializer='he_normal'),
        BatchNormalization(),
        MaxPooling1D(pool_size=2),
        Dropout(0.3),

        Flatten(),
        Dense(128, activation='relu', kernel_initializer='he_normal'),
        BatchNormalization(),
        Dropout(0.5),
        Dense(1, activation='sigmoid')  # Binary classification
    ])

    # Compile the model with an Adam optimizer and a learning rate you can experiment with.
    optimizer = Adam(learning_rate=0.001)
    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy', tf.keras.metrics.AUC(name='auc')])

    # Callbacks: early stopping on val_auc, model checkpoint saving best model per fold, and reduce LR on plateau.
    callbacks = [
        EarlyStopping(patience=20, restore_best_weights=True, monitor='val_auc', mode='max'),
        ModelCheckpoint(f"best_model_fold_{fold + 1}.keras", monitor='val_auc', mode='max', save_best_only=True),
        ReduceLROnPlateau(monitor='val_auc', factor=0.5, patience=5, verbose=1, mode='max', min_lr=1e-5)
    ]

    # Train the model for this fold
    history = model.fit(
        X_train_norm_fold, y_train,
        validation_data=(X_val_norm_fold, y_val),
        epochs=100,  # You may increase epochs; early stopping is active.
        batch_size=32,
        class_weight=class_weights_dict_fold,
        callbacks=callbacks,
        verbose=2
    )
    all_histories.append(history.history)

    # Load the best model from this fold based on highest val_auc
    loaded_model_fold = tf.keras.models.load_model(f"best_model_fold_{fold + 1}.keras")

    # Evaluate on the validation set of this fold
    _, _, val_auc = loaded_model_fold.evaluate(X_val_norm_fold, y_val, verbose=0)
    print(f"Fold {fold + 1} Validation AUC: {val_auc:.4f}")
    all_val_aucs.append(val_auc)

    if val_auc > best_auc:
        best_auc = val_auc
        best_model = loaded_model_fold
        # Remember the scaling this model was trained with.
        X_mean_full, X_std_full = X_mean_fold, X_std_fold

if best_model is None:
    raise RuntimeError("No fold produced a usable validation AUC; cannot select a model.")

# Scale the untouched test set with the selected model's own statistics.
X_test_norm_full = (X_test_full - X_mean_full) / X_std_full

loss, acc, auc = best_model.evaluate(X_test_norm_full, y_test_full, verbose=0)
print("\n--- Best Model Evaluation on Test Set ---")
print(f"Test Loss: {loss:.4f}")
print(f"Test Accuracy: {acc:.4f}")
print(f"Test AUC: {auc:.4f}")

# Generate predictions and evaluation metrics
y_pred_prob = best_model.predict(X_test_norm_full)
y_pred = (y_pred_prob > 0.5).astype(int)

f1 = f1_score(y_test_full, y_pred)
cm = confusion_matrix(y_test_full, y_pred)
report = classification_report(y_test_full, y_pred)

print(f"\nF1 Score: {f1:.4f}")
print("\nConfusion Matrix:")
print(cm)
print("\nClassification Report:")
print(report)

In [None]:
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

# --- Per-fold learning curves: one figure each for loss, accuracy and AUC ---
# The x-axis variable is named `epoch_range`, not `epochs`, so the MNE Epochs
# object called `epochs` is not overwritten for later cells.
for metric_key, metric_label in (('loss', 'Loss'), ('accuracy', 'Accuracy'), ('auc', 'AUC')):
    plt.figure(figsize=(6, 4))
    for fold_num, history in enumerate(all_histories, start=1):
        epoch_range = range(1, len(history[metric_key]) + 1)
        plt.plot(epoch_range, history[metric_key],
                 label=f'Fold {fold_num} - Train {metric_label}', linestyle='--')
        plt.plot(epoch_range, history[f'val_{metric_key}'],
                 label=f'Fold {fold_num} - Validation {metric_label}', linestyle='-')

    plt.title(f'{metric_label} Across All Folds')
    plt.xlabel('Epochs')
    plt.ylabel(metric_label)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

# 2. Visualize Validation AUCs Across Folds
plt.figure(figsize=(6, 4))
plt.bar(range(1, n_splits + 1), all_val_aucs, color='skyblue')
plt.xlabel('Fold')
plt.ylabel('Validation AUC')
plt.title('Validation AUC for Each Fold')
plt.xticks(range(1, n_splits + 1))
plt.ylim(0, 1)
plt.grid(axis='y', linestyle='--')
plt.show()

# 3. Visualize Confusion Matrix
plt.figure(figsize=(6, 4))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Negative (0)', 'Positive (1)'],
            yticklabels=['Negative (0)', 'Positive (1)'])
plt.ylabel('Actual')
plt.xlabel('Predicted')
plt.title('Confusion Matrix on Test Set')
plt.show()

# 4. Visualize ROC Curve
# NOTE: `auc` here is sklearn's area-under-curve function imported above.
fpr, tpr, thresholds = roc_curve(y_test_full, np.ravel(y_pred_prob))
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(6, 4))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
plt.show()



In [None]:
# --- Fold-averaged learning curves ---
# Folds may have trained for different numbers of epochs, so every curve is
# truncated to the shortest fold before the element-wise average is taken.

# Loss
plt.figure(figsize=(6, 4))
all_train_loss = [fold_hist['loss'] for fold_hist in all_histories]
all_val_loss = [fold_hist['val_loss'] for fold_hist in all_histories]

min_epochs_loss = min(map(len, all_train_loss))
avg_train_loss = np.array([curve[:min_epochs_loss] for curve in all_train_loss]).mean(axis=0)
avg_val_loss = np.array([curve[:min_epochs_loss] for curve in all_val_loss]).mean(axis=0)

epochs_loss = range(1, min_epochs_loss + 1)
plt.plot(epochs_loss, avg_train_loss, 'r--', label='Average Train Loss')
plt.plot(epochs_loss, avg_val_loss, 'b-', label='Average Validation Loss')
plt.title('Average Loss Across All Folds')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# Accuracy
plt.figure(figsize=(6, 4))
all_train_accuracy = [fold_hist['accuracy'] for fold_hist in all_histories]
all_val_accuracy = [fold_hist['val_accuracy'] for fold_hist in all_histories]

min_epochs_accuracy = min(map(len, all_train_accuracy))
avg_train_accuracy = np.array(
    [curve[:min_epochs_accuracy] for curve in all_train_accuracy]
).mean(axis=0)
avg_val_accuracy = np.array(
    [curve[:min_epochs_accuracy] for curve in all_val_accuracy]
).mean(axis=0)

epochs_accuracy = range(1, min_epochs_accuracy + 1)
plt.plot(epochs_accuracy, avg_train_accuracy, 'r--', label='Average Train Accuracy')
plt.plot(epochs_accuracy, avg_val_accuracy, 'b-', label='Average Validation Accuracy')
plt.title('Average Accuracy Across All Folds')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# AUC
plt.figure(figsize=(6, 4))
all_train_auc = [fold_hist['auc'] for fold_hist in all_histories]
all_val_auc = [fold_hist['val_auc'] for fold_hist in all_histories]

min_epochs_auc = min(map(len, all_train_auc))
avg_train_auc = np.array([curve[:min_epochs_auc] for curve in all_train_auc]).mean(axis=0)
avg_val_auc = np.array([curve[:min_epochs_auc] for curve in all_val_auc]).mean(axis=0)

epochs_auc = range(1, min_epochs_auc + 1)
plt.plot(epochs_auc, avg_train_auc, 'r--', label='Average Train AUC')
plt.plot(epochs_auc, avg_val_auc, 'b-', label='Average Validation AUC')
plt.title('Average AUC Across All Folds')
plt.xlabel('Epochs')
plt.ylabel('AUC')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
# Wrap the held-out test data in an MNE EpochsArray, tagged with the model's
# predicted class, so the ERP / TFR tools can compare the two predicted groups.
test_events = np.zeros((len(y_pred), 3), dtype=int)
test_events[:, 0] = np.arange(len(y_pred))  # Dummy sample numbers
test_events[:, 2] = 1  # Using a dummy event ID

test_epochs_data = np.transpose(X_test_full, (0, 2, 1)) # Ensure correct shape (n_epochs, n_channels, n_times)

# Describe the channels that are actually in the data. `channels` is only the
# requested list: a recording may lack some of them or store them in another
# order, which would mislabel the traces or fail on a shape mismatch.
data_ch_names = list(raw.ch_names)
if test_epochs_data.shape[1] != len(data_ch_names):
    raise ValueError(
        f"X_test_full has {test_epochs_data.shape[1]} channels but the recording "
        f"provides {len(data_ch_names)} ({data_ch_names})."
    )
info = mne.create_info(ch_names=data_ch_names, sfreq=sfreq, ch_types='eeg')
test_epochs = mne.EpochsArray(test_epochs_data, info, events=test_events, tmin=TMIN)

# Add the model's predictions as metadata
predicted_labels = pd.DataFrame({'label': y_pred.flatten()})
test_epochs.metadata = predicted_labels

# Separate epochs based on the model's predicted labels
epochs_fast_pred = test_epochs[test_epochs.metadata['label'] == 0]
epochs_slow_pred = test_epochs[test_epochs.metadata['label'] == 1]

print(f"\nNumber of epochs predicted as fast RT (label 0): {len(epochs_fast_pred)}")
print(f"Number of epochs predicted as slow RT (label 1): {len(epochs_slow_pred)}")

baseline_period = (TMIN, 0)

# Average each predicted group; a group with no epochs is recorded as None so
# the plotting code can skip the comparison.
if len(epochs_fast_pred) > 0:
    evoked_fast_pred = epochs_fast_pred.average().apply_baseline(baseline=baseline_period)
else:
    evoked_fast_pred = None
    print("Warning: No epochs predicted as fast RT.")

if len(epochs_slow_pred) > 0:
    evoked_slow_pred = epochs_slow_pred.average().apply_baseline(baseline=baseline_period)
else:
    evoked_slow_pred = None
    print("Warning: No epochs predicted as slow RT.")

# Analysis windows in seconds, and theta-band settings for the wavelet TFR.
n200_window = (0.15, 0.28)
p300_window = (0.28, 0.60)
theta_window = (0.10, 0.60)
theta_freqs = np.arange(4, 9, 1)
n_cycles_theta = theta_freqs / 2.0
picks_n200 = [ch for ch in ['FZ', 'CZ'] if ch in data_ch_names]
picks_p300 = [ch for ch in ['PZ', 'P3', 'P4'] if ch in data_ch_names]
picks_theta = [ch for ch in ['FZ', 'CZ'] if ch in data_ch_names]

# --- N200 Analysis based on Model Predictions ---
# Compares the evoked responses of the two predicted classes on the N200
# channel picks and shades the N200 window. Skipped when no pick is available
# or when either predicted class has no evoked response.
print("\n--- N200 Analysis (Based on Model Predictions) ---")
if picks_n200 and evoked_fast_pred is not None and evoked_slow_pred is not None:
    fig_n200_pred = mne.viz.plot_compare_evokeds(
        {'Predicted Fast RT (label 0)': evoked_fast_pred, 'Predicted Slow RT (label 1)': evoked_slow_pred},
        picks=picks_n200,
        title=f'N200 Comparison (Predicted Labels - {", ".join(picks_n200)})',
        show_sensors='upper right',
        legend='upper left',
        ci=0.95
    )
    # The call may return a list of figures; annotate the first one.
    current_fig_n200_pred = fig_n200_pred[0] if isinstance(fig_n200_pred, list) else fig_n200_pred
    for i, ax in enumerate(current_fig_n200_pred.axes[:len(picks_n200)]):
        ax.axvspan(n200_window[0], n200_window[1], color='gray', alpha=0.2, label='N200 Window')
        # Add vertical lines for potential peaks (you might need to calculate these as before)
    plt.show()
elif not picks_n200:
    print("Skipping N200 analysis: No suitable channels found.")
else:
    print("Skipping N200 analysis: Not enough predicted epochs for both conditions.")

# --- P300 Analysis based on Model Predictions ---
# Same comparison on the P300 channel picks, shading the P300 window.
print("\n--- P300 Analysis (Based on Model Predictions) ---")
if picks_p300 and evoked_fast_pred is not None and evoked_slow_pred is not None:
    fig_p300_pred = mne.viz.plot_compare_evokeds(
        {'Predicted Fast RT (label 0)': evoked_fast_pred, 'Predicted Slow RT (label 1)': evoked_slow_pred},
        picks=picks_p300,
        title=f'P300 Comparison (Predicted Labels - {", ".join(picks_p300)})',
        show_sensors='upper right',
        legend='upper left',
        ci=0.95
    )
    # The call may return a list of figures; annotate the first one.
    current_fig_p300_pred = fig_p300_pred[0] if isinstance(fig_p300_pred, list) else fig_p300_pred
    for i, ax in enumerate(current_fig_p300_pred.axes[:len(picks_p300)]):
        ax.axvspan(p300_window[0], p300_window[1], color='gray', alpha=0.2, label='P300 Window')
        # Add vertical lines for potential peaks
    plt.show()
elif not picks_p300:
    print("Skipping P300 analysis: No suitable channels found.")
else:
    print("Skipping P300 analysis: Not enough predicted epochs for both conditions.")





print(f"picks_theta: {picks_theta}")
print(f"Number of epochs predicted as fast RT (label 0): {len(epochs_fast_pred)}")
print(f"Number of epochs predicted as slow RT (label 1): {len(epochs_slow_pred)}")
print("\n--- Time-Frequency Analysis for Theta (Based on Model Predictions) ---")

# Theta-band Morlet power for both predicted classes, baseline-normalised and
# plotted as one time course per class. Guard clauses first, analysis last.
if not picks_theta:
    print("Skipping Theta TFR analysis: No suitable channels found.")
elif len(epochs_fast_pred) == 0 or len(epochs_slow_pred) == 0:
    print("Skipping Theta TFR analysis: Not enough predicted epochs for both conditions.")
else:
    try:
        # Identical wavelet settings for both classes, stated once.
        morlet_settings = dict(
            method="morlet", freqs=theta_freqs, n_cycles=n_cycles_theta, use_fft=True,
            return_itc=False, decim=3, n_jobs=-1, picks=picks_theta, verbose=False
        )
        tfr_fast_pred = epochs_fast_pred.compute_tfr(**morlet_settings)
        tfr_slow_pred = epochs_slow_pred.compute_tfr(**morlet_settings)

        # Express power as percent change relative to the baseline window.
        tfr_fast_pred.apply_baseline(baseline=baseline_period, mode='percent')
        tfr_slow_pred.apply_baseline(baseline=baseline_period, mode='percent')

        # Average across epochs within each class.
        avg_tfr_fast_pred = tfr_fast_pred.average()
        avg_tfr_slow_pred = tfr_slow_pred.average()

        # Collapse the first two axes (channels, frequencies) to get one
        # theta-power value per time point.
        times_tfr_pred = avg_tfr_fast_pred.times
        theta_power_fast_pred = avg_tfr_fast_pred.data.mean(axis=(0, 1))
        theta_power_slow_pred = avg_tfr_slow_pred.data.mean(axis=(0, 1))

        plt.figure(figsize=(10, 6))
        plt.plot(times_tfr_pred, theta_power_fast_pred, label='Predicted Fast Condition')
        plt.plot(times_tfr_pred, theta_power_slow_pred, label='Predicted Slow Condition')
        plt.xlabel('Time (s)')
        plt.ylabel('Average Theta Power (% Change from Baseline)')
        plt.title('Average Theta Power Over Time (Based on Model Predictions)')
        plt.legend()
        plt.grid(True)
        plt.show()

    except Exception as e:
        print(f"Error during TFR analysis (predicted labels): {e}")






In [None]:
# Normalize training data
X_train_norm_full = (X_train_full - X_mean_full) / X_std_full

# Predict on training data
y_pred_prob_train = best_model.predict(X_train_norm_full)
y_pred_train = (y_pred_prob_train > 0.5).astype(int)

# Create dummy events for training data
train_events = np.zeros((len(y_pred_train), 3), dtype=int)
train_events[:, 0] = np.arange(len(y_pred_train))
train_events[:, 2] = 1  # Dummy event code

# Prepare EEG data: (n_epochs, n_channels, n_times)
train_epochs_data = np.transpose(X_train_full, (0, 2, 1))

# Create MNE Epochs
# Describe the channels that are actually in the data. `channels` is only the
# requested list: a recording may lack some of them or store them in another
# order, which would mislabel the traces or fail on a shape mismatch.
data_ch_names = list(raw.ch_names)
if train_epochs_data.shape[1] != len(data_ch_names):
    raise ValueError(
        f"X_train_full has {train_epochs_data.shape[1]} channels but the recording "
        f"provides {len(data_ch_names)} ({data_ch_names})."
    )
info = mne.create_info(ch_names=data_ch_names, sfreq=sfreq, ch_types='eeg')
train_epochs = mne.EpochsArray(train_epochs_data, info, events=train_events, tmin=TMIN)

# Add predictions as metadata
train_epochs.metadata = pd.DataFrame({'label': y_pred_train.flatten()})

# Split into predicted classes
epochs_fast_pred = train_epochs[train_epochs.metadata['label'] == 0]
epochs_slow_pred = train_epochs[train_epochs.metadata['label'] == 1]


print(f"\nNumber of epochs predicted as fast RT (label 0): {len(epochs_fast_pred)}")
print(f"Number of epochs predicted as slow RT (label 1): {len(epochs_slow_pred)}")

baseline_period = (TMIN, 0)

# Average each predicted group; a group with no epochs is recorded as None so
# the plotting code can skip the comparison.
if len(epochs_fast_pred) > 0:
    evoked_fast_pred = epochs_fast_pred.average().apply_baseline(baseline=baseline_period)
else:
    evoked_fast_pred = None
    print("Warning: No epochs predicted as fast RT.")

if len(epochs_slow_pred) > 0:
    evoked_slow_pred = epochs_slow_pred.average().apply_baseline(baseline=baseline_period)
else:
    evoked_slow_pred = None
    print("Warning: No epochs predicted as slow RT.")

# Analysis windows in seconds, and theta-band settings for the wavelet TFR.
n200_window = (0.15, 0.28)
p300_window = (0.28, 0.60)
theta_window = (0.10, 0.60)
theta_freqs = np.arange(4, 9, 1)
n_cycles_theta = theta_freqs / 2.0
picks_n200 = [ch for ch in ['FZ', 'CZ'] if ch in data_ch_names]
picks_p300 = [ch for ch in ['PZ', 'P3', 'P4'] if ch in data_ch_names]
picks_theta = [ch for ch in ['FZ', 'CZ'] if ch in data_ch_names]

# --- N200 Analysis based on Model Predictions ---
# Compares the evoked responses of the two predicted classes on the N200
# channel picks and shades the N200 window. Skipped when no pick is available
# or when either predicted class has no evoked response.
print("\n--- N200 Analysis (Based on Model Predictions) ---")
if picks_n200 and evoked_fast_pred is not None and evoked_slow_pred is not None:
    fig_n200_pred = mne.viz.plot_compare_evokeds(
        {'Fast RT (label 0)': evoked_fast_pred, 'Slow RT (label 1)': evoked_slow_pred},
        picks=picks_n200,
        # The title previously closed a parenthesis it never opened.
        title=f'N200 Comparison ({", ".join(picks_n200)})',
        show_sensors='upper right',
        legend='upper left',
        ci=0.95
    )
    # The call may return a list of figures; annotate the first one.
    current_fig_n200_pred = fig_n200_pred[0] if isinstance(fig_n200_pred, list) else fig_n200_pred
    for ax in current_fig_n200_pred.axes[:len(picks_n200)]:
        ax.axvspan(n200_window[0], n200_window[1], color='gray', alpha=0.2, label='N200 Window')
    plt.show()
elif not picks_n200:
    print("Skipping N200 analysis: No suitable channels found.")
else:
    print("Skipping N200 analysis: Not enough predicted epochs for both conditions.")

# --- P300 Analysis based on Model Predictions ---
# Same comparison on the P300 channel picks, shading the P300 window.
print("\n--- P300 Analysis (Based on Model Predictions) ---")
if picks_p300 and evoked_fast_pred is not None and evoked_slow_pred is not None:
    fig_p300_pred = mne.viz.plot_compare_evokeds(
        {'Fast RT (label 0)': evoked_fast_pred, 'Slow RT (label 1)': evoked_slow_pred},
        picks=picks_p300,
        # The title previously closed a parenthesis it never opened.
        title=f'P300 Comparison ({", ".join(picks_p300)})',
        show_sensors='upper right',
        legend='upper left',
        ci=0.95
    )
    # The call may return a list of figures; annotate the first one.
    current_fig_p300_pred = fig_p300_pred[0] if isinstance(fig_p300_pred, list) else fig_p300_pred
    for ax in current_fig_p300_pred.axes[:len(picks_p300)]:
        ax.axvspan(p300_window[0], p300_window[1], color='gray', alpha=0.2, label='P300 Window')
    plt.show()
elif not picks_p300:
    print("Skipping P300 analysis: No suitable channels found.")
else:
    print("Skipping P300 analysis: Not enough predicted epochs for both conditions.")





# Theta-band time-frequency comparison of the two predicted classes: Morlet
# power per class, baseline-normalised, averaged over epochs and plotted as
# one time course per class.
print(f"picks_theta: {picks_theta}")
print(f"Number of epochs predicted as fast RT (label 0): {len(epochs_fast_pred)}")
print(f"Number of epochs predicted as slow RT (label 1): {len(epochs_slow_pred)}")
print("\n--- Time-Frequency Analysis for Theta (Based on Model Predictions) ---")

# Ensure n_cycles_theta is compatible with epoch length and freqs
# Example: n_cycles_theta = theta_freqs / 2.0

# Runs only when theta channels exist and both predicted classes have epochs.
if picks_theta and len(epochs_fast_pred) > 0 and len(epochs_slow_pred) > 0:
    try:
        # Use the newer .compute_tfr method
        tfr_fast_pred = epochs_fast_pred.compute_tfr(
            method="morlet", freqs=theta_freqs, n_cycles=n_cycles_theta, use_fft=True,
            return_itc=False, decim=3, n_jobs=-1, picks=picks_theta, verbose=False
        )
        tfr_slow_pred = epochs_slow_pred.compute_tfr(
            method="morlet", freqs=theta_freqs, n_cycles=n_cycles_theta, use_fft=True,
            return_itc=False, decim=3, n_jobs=-1, picks=picks_theta, verbose=False
        )

        # Apply baseline correction (percentage change is common for power)
        tfr_fast_pred.apply_baseline(baseline=baseline_period, mode='percent')
        tfr_slow_pred.apply_baseline(baseline=baseline_period, mode='percent')

        # Calculate the averaged TFR for plotting TFR maps
        avg_tfr_fast_pred = tfr_fast_pred.average()
        avg_tfr_slow_pred = tfr_slow_pred.average()


        # --- Visualize Average Theta Power Over Time ---
        times_tfr_pred = avg_tfr_fast_pred.times  # Get time points from the averaged TFR
        theta_power_fast_pred = avg_tfr_fast_pred.data.mean(axis=(0, 1))  # Mean over channels & freqs
        theta_power_slow_pred = avg_tfr_slow_pred.data.mean(axis=(0, 1))

        plt.figure(figsize=(10, 6))
        plt.plot(times_tfr_pred, theta_power_fast_pred, label='Fast Condition')
        plt.plot(times_tfr_pred, theta_power_slow_pred, label='Slow Condition')
        plt.xlabel('Time (s)')
        plt.ylabel('Average Theta Power (% Change from Baseline)')
        plt.title('Average Theta Power Over Time')
        plt.legend()
        plt.grid(True)
        plt.show()

    # Best-effort: any failure in the analysis is reported, not raised.
    except Exception as e:
        print(f"Error during TFR analysis (predicted labels): {e}")
elif not picks_theta:
    print("Skipping Theta TFR analysis: No suitable channels found.")
else:
    print("Skipping Theta TFR analysis: Not enough predicted epochs for both conditions.")

In [None]:
import mne.stats # Import MNE stats module

# -
# The measurement info (channel names, etc.) is taken from the original epochs
# object, so that object has to exist before anything else in this cell runs.
if 'epochs' not in locals():
    raise NameError("Original 'epochs' object not found. Please ensure your preceding code creates it.")
info = epochs.info

# TMIN is the epoch start time used for EpochsArray and for the baseline;
# fall back to a placeholder when no earlier cell defined it.
if 'TMIN' not in locals():
    print("TMIN not found, using -0.2 as a default. Please define TMIN based on your epoching.")
    TMIN = -0.2  # Placeholder only; replace with the actual epoch start time

# Baseline window: from the epoch start up to time zero
baseline_period = (TMIN, 0)


# --- Create EpochsArray from Split Train Data (for True Label Analysis) ---
print("\n--- Preparing Train Data for Statistical Analysis (True Labels) ---")

# EpochsArray wants (n_epochs, n_channels, n_times). Swap the last two axes
# only when the channel count is found on the last axis and not on axis 1.
n_info_channels = len(info['ch_names'])
if X_train_full.shape[-1] != n_info_channels or X_train_full.shape[1] == n_info_channels:
    print(f"X_train_full shape {X_train_full.shape} looks suitable for MNE (n_epochs, n_channels, n_times).")
    X_train_mne_shape = X_train_full  # Taken as-is
else:
    print(f"Transposing X_train_full from {X_train_full.shape} to (n_epochs, n_channels, n_times)")
    X_train_mne_shape = np.transpose(X_train_full, (0, 2, 1))

# Placeholder events: one row per epoch as (sample, previous id, event id),
# with consecutive sample numbers from 0 and a constant event id of 1.
n_train_epochs = len(X_train_mne_shape)
train_events = np.column_stack((
    np.arange(n_train_epochs),
    np.zeros(n_train_epochs, dtype=int),
    np.ones(n_train_epochs, dtype=int),
))

# Wrap the training array in an EpochsArray that reuses the original info
epochs_train_array = mne.EpochsArray(
    data=X_train_mne_shape,
    info=info,
    events=train_events,
    tmin=TMIN,
    verbose=False,
)

# Attach the true training labels as metadata (one label per epoch required)
if len(y_train_full) != len(epochs_train_array):
     raise ValueError(f"Length of y_train_full ({len(y_train_full)}) does not match the number of train epochs ({len(epochs_train_array)}).")
epochs_train_array.metadata = pd.DataFrame({'label': y_train_full.flatten()})
print(f"Created epochs_train_array with {len(epochs_train_array)} epochs.")

# Split by true label: 0 is reported as fast RT, 1 as slow RT
train_true_labels = epochs_train_array.metadata['label']
epochs_train_fast_true = epochs_train_array[train_true_labels == 0]
epochs_train_slow_true = epochs_train_array[train_true_labels == 1]

print(f"Number of train epochs (True Labels) Fast RT (label 0): {len(epochs_train_fast_true)}")
print(f"Number of train epochs (True Labels) Slow RT (label 1): {len(epochs_train_slow_true)}")



# theta_freqs is expected from an earlier cell; a missing definition is
# treated as "TFR disabled" so the remainder of this cell can still run.
try:
    theta_freqs
except NameError:
    print("theta_freqs not found. Skipping Theta TFR calculation for train data.")
    theta_freqs = None

# Start from None so the statistics section can tell whether TFRs exist.
tfr_fast_train_true = None
tfr_slow_train_true = None

if theta_freqs is not None and len(epochs_train_fast_true) > 0 and len(epochs_train_slow_true) > 0:
    print("\n--- Computing Time-Frequency (True Labels on Train Data) ---")
    # Number of wavelet cycles scales with frequency (freq / 2)
    n_cycles_theta = theta_freqs / 2.0

    # Reuse picks_theta from an earlier cell, or default to the EEG channels
    try:
        picks_theta
    except NameError:
        print("picks_theta not found. Defining default for Theta TFR (all EEG channels).")
        picks_theta = mne.pick_types(info, eeg=True, exclude='bads')
        if len(picks_theta) == 0:
            print("No EEG channels found after excluding bads, trying all channels.")
            # pick_types excludes bad channels by default, so the fallback must
            # pass an empty exclude list to actually include them.
            picks_theta = mne.pick_types(info, eeg=True, exclude=[])

    # pick_types returns a NumPy array; truth-testing it raises ValueError for
    # more than one element and is False for array([0]), so test the length.
    has_theta_picks = picks_theta is not None and len(picks_theta) > 0

    if has_theta_picks:
        try:
            # Use the newer .compute_tfr method (per-epoch Morlet power, no ITC)
            tfr_fast_train_true = epochs_train_fast_true.compute_tfr(
                method="morlet", freqs=theta_freqs, n_cycles=n_cycles_theta, use_fft=True,
                return_itc=False, decim=3, n_jobs=-1, picks=picks_theta, verbose=False
            )
            tfr_slow_train_true = epochs_train_slow_true.compute_tfr(
                method="morlet", freqs=theta_freqs, n_cycles=n_cycles_theta, use_fft=True,
                return_itc=False, decim=3, n_jobs=-1, picks=picks_theta, verbose=False
            )

            # Apply baseline correction (percentage change is common for power)
            tfr_fast_train_true.apply_baseline(baseline=baseline_period, mode='percent')
            tfr_slow_train_true.apply_baseline(baseline=baseline_period, mode='percent')
            print("Train TFR computation complete.")

        except Exception as e:
            # Best-effort: report and reset both so no half-computed pair is used
            print(f"Error during Train TFR computation: {e}")
            tfr_fast_train_true = None
            tfr_slow_train_true = None
    else:
         print("Skipping Train TFR computation: No suitable channels found for picks_theta.")

else:
    print("\nSkipping Train TFR computation: Not enough true train epochs or theta_freqs not defined.")


# --- Statistical Verification (True Labels on Train Data) ---
print("\n--- Statistical Verification (True Labels on Train Data) ---")

# --- 1. Statistical Test for Evoked Potential Differences (True Slow vs Fast on Train Data) ---
print("\n--> Comparing Evoked Potentials (True Slow vs Fast on Train Data)")

# Need at least two epochs per condition for an independent-samples test
if len(epochs_train_fast_true) > 1 and len(epochs_train_slow_true) > 1:
    # Epoch data comes out as (n_epochs, n_channels, n_times)
    data_fast_train = epochs_train_fast_true.get_data(copy=False).astype(np.float64)
    data_slow_train = epochs_train_slow_true.get_data(copy=False).astype(np.float64)

    # spatio_temporal_cluster_test expects (n_observations, n_times, n_channels),
    # so the channel and time axes are swapped for the test input only.
    X_evoked_train = [
        np.transpose(data_slow_train, (0, 2, 1)),
        np.transpose(data_fast_train, (0, 2, 1)),
    ]

    # Channel adjacency from sensor positions. This needs channel locations in
    # `info`; if it cannot be built, fall back to no explicit adjacency.
    try:
        connectivity, ch_names_conn = mne.channels.find_ch_adjacency(info, ch_type='eeg')
    except Exception as e_conn:
        print(f"Could not determine channel connectivity: {e_conn}")
        print("Train Evoked cluster test will run without spatial connectivity.")
        connectivity = None

    # The adjacency must cover exactly the channels present in the data
    if connectivity is not None and connectivity.shape[0] != data_fast_train.shape[1]:
        print("Channel adjacency does not match the number of data channels.")
        print("Train Evoked cluster test will run without spatial connectivity.")
        connectivity = None

    # --- Cluster-based permutation test ---
    # This test compares two independent groups of epochs across time and channels
    print(f"Running spatio-temporal cluster permutation test ({'with' if connectivity is not None else 'without'} spatial connectivity)...")

    # Cluster-forming threshold on the t-statistic
    threshold = 2.0

    # Number of permutations. More permutations give more accurate p-values but take longer.
    n_permutations = 1000

    try:
        T_obs_train, clusters_train, p_values_train, _ = mne.stats.spatio_temporal_cluster_test(
            X_evoked_train,  # [slow, fast]: positive t means slow > fast
            n_permutations=n_permutations,
            threshold=threshold,
            tail=0,  # two-tailed
            stat_fun=mne.stats.ttest_ind_no_p,  # stat_fun must be a callable, not a name
            adjacency=connectivity,  # sparse channel adjacency, or None
            out_type='indices',
            n_jobs=-1,
            verbose=False
        )
    except Exception as e:
        print(f"An unexpected error occurred during the Train Evoked cluster test: {e}")
        print("Skipping Train Evoked cluster test due to error.")
    else:
        print("Train Evoked cluster permutation test complete.")

        # --- Report significant clusters ---
        alpha = 0.05  # Significance level for the clusters
        good_clusters_train_idx = [i for i, p_val in enumerate(p_values_train) if p_val < alpha]

        if good_clusters_train_idx:
            print(f"\nFound {len(good_clusters_train_idx)} significant spatio-temporal clusters for Train Evoked (p < {alpha}):")
            times = epochs_train_array.times
            for i in good_clusters_train_idx:
                # T_obs_train is (n_times, n_channels); each cluster is a pair
                # of index arrays in that same (time, channel) order.
                cluster_times_idx, cluster_channels_idx = clusters_train[i]
                T_obs_this_cluster = T_obs_train[cluster_times_idx, cluster_channels_idx]
                min_t = np.min(T_obs_this_cluster)
                max_t = np.max(T_obs_this_cluster)
                cluster_p_value = p_values_train[i]

                # Convert time indices to seconds and find range
                unique_times_idx = np.unique(cluster_times_idx)
                cluster_times = times[unique_times_idx]
                tmin_cluster, tmax_cluster = cluster_times.min(), cluster_times.max()

                # Get channel names involved in the cluster
                unique_channels_idx = np.unique(cluster_channels_idx)
                cluster_channel_names = [epochs_train_array.ch_names[j] for j in unique_channels_idx]

                print(f"  - Train Cluster {i+1}: p-value = {cluster_p_value:.4f}")
                print(f"    Time range: [{tmin_cluster:.3f} s, {tmax_cluster:.3f} s]")
                print(f"    Channels ({len(unique_channels_idx)}) : {cluster_channel_names}")
                print(f"    T-value range: [{min_t:.2f}, {max_t:.2f}]")
        else:
            print(f"\nNo significant spatio-temporal clusters found for Train Evoked at p < {alpha}.")
else:
    print("\nSkipping Train Evoked cluster test: Not enough epochs (min 2 per condition) in true train sets.")


# --- 2. Statistical Test for Time-Frequency Power Differences (True Slow vs Fast on Train Data) ---
print("\n--> Comparing Theta Power (True Slow vs Fast on Train Data)")

# Ensure TFR objects for true train data are available and have enough epochs
if tfr_fast_train_true is not None and tfr_slow_train_true is not None and len(tfr_fast_train_true.data) > 1 and len(tfr_slow_train_true.data) > 1:
    # Per-epoch power, (n_epochs, n_channels, n_freqs, n_times)
    power_fast_train = tfr_fast_train_true.data.astype(np.float64)
    power_slow_train = tfr_slow_train_true.data.astype(np.float64)

    # Find frequency indices for the theta band
    try:
        freqs = tfr_fast_train_true.freqs  # Frequencies actually stored in the TFR object
        if theta_freqs is None:
             print("theta_freqs not defined, cannot find indices for theta band.")
             theta_freq_indices = []
        else:
             theta_freq_indices = np.where(np.logical_and(freqs >= theta_freqs.min(), freqs <= theta_freqs.max()))[0]
             if len(theta_freq_indices) == 0:
                 print(f"Warning: No frequencies in TFR object fall within the specified theta band ({theta_freqs.min()}-{theta_freqs.max()} Hz).")

    except Exception as e_freq:
         print(f"Could not get frequencies from Train TFR object or find theta band indices: {e_freq}")
         theta_freq_indices = []

    if len(theta_freq_indices) > 0:
         print(f"Averaging power over {len(theta_freq_indices)} frequencies in the theta band for Train TFR cluster test.")
         # Collapse the frequency axis so the test sees (n_epochs, n_channels, n_times)
         if power_fast_train.ndim == 4 and power_fast_train.shape[2] > 0:
             power_fast_train_theta_avg = power_fast_train[:, :, theta_freq_indices, :].mean(axis=2)
             power_slow_train_theta_avg = power_slow_train[:, :, theta_freq_indices, :].mean(axis=2)
         else:
              print("Train TFR power data shape is not as expected for averaging over frequency (expected 4D).")
              power_fast_train_theta_avg = None
              power_slow_train_theta_avg = None
    else:
         print("Skipping Train TFR cluster test: No frequencies found in the specified theta band or issues getting freq indices.")
         power_fast_train_theta_avg = None
         power_slow_train_theta_avg = None

    if power_fast_train_theta_avg is not None and power_slow_train_theta_avg is not None:
         # Build the adjacency from the TFR's own info: the TFR only holds the
         # picked channels, and the evoked section (which would otherwise
         # provide an adjacency) may have been skipped entirely.
         try:
             adjacency_tf_train, _ = mne.channels.find_ch_adjacency(tfr_fast_train_true.info, ch_type='eeg')
         except Exception as e_conn:
             print(f"Could not determine channel connectivity: {e_conn}")
             adjacency_tf_train = None
         if adjacency_tf_train is not None and adjacency_tf_train.shape[0] != power_fast_train_theta_avg.shape[1]:
             print("Channel adjacency does not match the number of TFR channels.")
             adjacency_tf_train = None

         print(f"Running spatio-temporal cluster permutation test for average Train Theta power ({'with' if adjacency_tf_train is not None else 'without'} spatial connectivity)...")

         # Same cluster-forming threshold and permutation count as the evoked test
         threshold_tf = 2.0
         n_permutations_tf = 1000

         try:
             T_obs_tf_train, clusters_tf_train, p_values_tf_train, _ = mne.stats.spatio_temporal_cluster_test(
                 # The test expects (n_observations, n_times, n_channels)
                 [np.transpose(power_slow_train_theta_avg, (0, 2, 1)),
                  np.transpose(power_fast_train_theta_avg, (0, 2, 1))],  # Slow vs Fast
                 n_permutations=n_permutations_tf,
                 threshold=threshold_tf,
                 tail=0,  # Two-tailed
                 stat_fun=mne.stats.ttest_ind_no_p,  # stat_fun must be a callable
                 adjacency=adjacency_tf_train,
                 out_type='indices',
                 n_jobs=-1,
                 verbose=False
             )
         except Exception as e:
             print(f"An unexpected error occurred during the Train Theta power cluster test: {e}")
             print("Skipping Train Theta power cluster test due to error.")
         else:
             print("Train Theta power cluster permutation test complete.")

             # --- Report significant clusters ---
             alpha = 0.05
             good_clusters_tf_train_idx = [i for i, p_val in enumerate(p_values_tf_train) if p_val < alpha]

             if good_clusters_tf_train_idx:
                 print(f"\nFound {len(good_clusters_tf_train_idx)} significant spatio-temporal clusters for Train Theta power (p < {alpha}):")
                 times_tf = tfr_fast_train_true.times  # Decimated TFR time axis
                 for i in good_clusters_tf_train_idx:
                     # T_obs is (n_times, n_channels); clusters index it in that order
                     cluster_times_idx, cluster_channels_idx = clusters_tf_train[i]
                     T_obs_this_cluster = T_obs_tf_train[cluster_times_idx, cluster_channels_idx]
                     min_t = np.min(T_obs_this_cluster)
                     max_t = np.max(T_obs_this_cluster)
                     cluster_p_value = p_values_tf_train[i]

                     unique_times_idx = np.unique(cluster_times_idx)
                     cluster_times = times_tf[unique_times_idx]
                     tmin_cluster, tmax_cluster = cluster_times.min(), cluster_times.max()

                     unique_channels_idx = np.unique(cluster_channels_idx)
                     cluster_channel_names = [tfr_fast_train_true.ch_names[j] for j in unique_channels_idx]

                     print(f"  - Train Theta Cluster {i+1}: p-value = {cluster_p_value:.4f}")
                     print(f"    Time range: [{tmin_cluster:.3f} s, {tmax_cluster:.3f} s]")
                     print(f"    Channels ({len(unique_channels_idx)}) : {cluster_channel_names}")
                     print(f"    T-value range: [{min_t:.2f}, {max_t:.2f}]")
             else:
                 print(f"\nNo significant spatio-temporal clusters found for Train Theta power at p < {alpha}.")
    else:
         print("\nSkipping Train TFR cluster test: Data averaging failed or data not available.")

else:
     print("\nSkipping Train TFR cluster test: Not enough epochs (min 2 per condition) or TFR objects not available for true train sets.")


# --- Create EpochsArray from Split Test Data (for Prediction Analysis) ---
# Prepares the held-out data for the Test (Predicted) statistics

print("\n--- Preparing Test Data for Statistical Analysis (Predicted Labels) ---")

# EpochsArray wants (n_epochs, n_channels, n_times). Swap the last two axes
# only when the channel count is found on the last axis and not on axis 1.
n_info_channels = len(info['ch_names'])
if X_test_full.shape[-1] != n_info_channels or X_test_full.shape[1] == n_info_channels:
    print(f"X_test_full shape {X_test_full.shape} looks suitable for MNE (n_epochs, n_channels, n_times).")
    X_test_mne_shape = X_test_full  # Taken as-is
else:
    print(f"Transposing X_test_full from {X_test_full.shape} to (n_epochs, n_channels, n_times)")
    X_test_mne_shape = np.transpose(X_test_full, (0, 2, 1))

# Placeholder events: one row per epoch as (sample, previous id, event id),
# with consecutive sample numbers from 0 and a constant event id of 1.
n_test_epochs = len(X_test_mne_shape)
test_events = np.column_stack((
    np.arange(n_test_epochs),
    np.zeros(n_test_epochs, dtype=int),
    np.ones(n_test_epochs, dtype=int),
))

# Wrap the test array in an EpochsArray that reuses the original info
epochs_test_pred_array = mne.EpochsArray(
    data=X_test_mne_shape,
    info=info,
    events=test_events,
    tmin=TMIN,
    verbose=False,
)

# Attach the model's predictions as metadata (one prediction per epoch required)
if len(y_pred) != len(epochs_test_pred_array):
     raise ValueError(f"Length of y_pred ({len(y_pred)}) does not match the number of test epochs ({len(epochs_test_pred_array)}).")
epochs_test_pred_array.metadata = pd.DataFrame({'predicted_label': y_pred.flatten()})
print(f"Created epochs_test_pred_array with {len(epochs_test_pred_array)} epochs.")

# Split by predicted label: 0 is reported as fast RT, 1 as slow RT
test_pred_labels = epochs_test_pred_array.metadata['predicted_label']
epochs_test_fast_pred = epochs_test_pred_array[test_pred_labels == 0]
epochs_test_slow_pred = epochs_test_pred_array[test_pred_labels == 1]

print(f"Number of test epochs predicted as fast RT (label 0): {len(epochs_test_fast_pred)}")
print(f"Number of test epochs predicted as slow RT (label 1): {len(epochs_test_slow_pred)}")


# --- Compute Time-Frequency (for Predicted Labels on Test Data) ---
# This is needed for Test (Predicted) stats

# Reset first, mirroring the train section: if the computation below is
# skipped, the statistics must not pick up TFR objects left over from an
# earlier cell (which were built from a different set of epochs).
tfr_fast_pred = None
tfr_slow_pred = None

if theta_freqs is not None and len(epochs_test_fast_pred) > 0 and len(epochs_test_slow_pred) > 0:
    print("\n--- Computing Time-Frequency (Predicted Labels on Test Data) ---")
    # Number of wavelet cycles scales with frequency (freq / 2)
    n_cycles_theta = theta_freqs / 2.0

    # Reuse picks_theta if already defined, or default to the EEG channels
    try:
        picks_theta
    except NameError:
        print("picks_theta not found. Defining default for Theta TFR (all EEG channels).")
        picks_theta = mne.pick_types(info, eeg=True, exclude='bads')
        if len(picks_theta) == 0:
            print("No EEG channels found after excluding bads, trying all channels.")
            # pick_types excludes bad channels by default, so the fallback must
            # pass an empty exclude list to actually include them.
            picks_theta = mne.pick_types(info, eeg=True, exclude=[])

    # pick_types returns a NumPy array; truth-testing it raises ValueError for
    # more than one element and is False for array([0]), so test the length.
    has_theta_picks = picks_theta is not None and len(picks_theta) > 0

    if has_theta_picks:
        try:
            tfr_fast_pred = epochs_test_fast_pred.compute_tfr(
                method="morlet", freqs=theta_freqs, n_cycles=n_cycles_theta, use_fft=True,
                return_itc=False, decim=3, n_jobs=-1, picks=picks_theta, verbose=False
            )
            tfr_slow_pred = epochs_test_slow_pred.compute_tfr(
                method="morlet", freqs=theta_freqs, n_cycles=n_cycles_theta, use_fft=True,
                return_itc=False, decim=3, n_jobs=-1, picks=picks_theta, verbose=False
            )

            # Apply baseline correction
            tfr_fast_pred.apply_baseline(baseline=baseline_period, mode='percent')
            tfr_slow_pred.apply_baseline(baseline=baseline_period, mode='percent')
            print("Test TFR computation complete.")

        except Exception as e:
            # Best-effort: report and reset both so no half-computed pair is used
            print(f"Error during Test TFR computation for predicted labels: {e}")
            tfr_fast_pred = None
            tfr_slow_pred = None
    else:
         print("Skipping Test TFR computation: No suitable channels found for picks_theta.")

else:
    print("\nSkipping Test TFR computation: Not enough predicted test epochs or theta_freqs not defined.")


# --- Statistical Verification (Predicted Labels on Test Data) ---
# Adding Cluster-based permutation tests

print("\n--- Statistical Verification (Predicted Labels on Test Data) ---")

# --- 1. Statistical Test for Evoked Potential Differences (Predicted Slow vs Fast on Test Data) ---
print("\n--> Comparing Evoked Potentials (Predicted Slow vs Fast on Test Data)")

# Need at least two epochs per predicted condition for an independent-samples test
if len(epochs_test_fast_pred) > 1 and len(epochs_test_slow_pred) > 1:
    # Epoch data comes out as (n_epochs, n_channels, n_times)
    data_fast = epochs_test_fast_pred.get_data(copy=False).astype(np.float64)
    data_slow = epochs_test_slow_pred.get_data(copy=False).astype(np.float64)

    # Reuse the adjacency from the train statistics when one is available;
    # otherwise (undefined or None) try to build it from the sensor positions.
    try:
        connectivity
    except NameError:
        connectivity = None
    if connectivity is None:
        print("Connectivity not found from train stats, trying to find it now for test stats.")
        try:
            connectivity, ch_names_conn = mne.channels.find_ch_adjacency(info, ch_type='eeg')
        except Exception as e_conn:
            print(f"Could not determine channel connectivity: {e_conn}")
            print("Test Evoked cluster test will run without spatial connectivity.")
            connectivity = None

    # The adjacency must cover exactly the channels present in the data
    if connectivity is not None and connectivity.shape[0] != data_fast.shape[1]:
        print("Channel adjacency does not match the number of data channels.")
        print("Test Evoked cluster test will run without spatial connectivity.")
        connectivity = None

    # --- Cluster-based permutation test ---
    print(f"Running spatio-temporal cluster permutation test ({'with' if connectivity is not None else 'without'} spatial connectivity)...")

    # Same threshold and permutation count as the train data, for comparison
    threshold = 2.0  # Cluster-forming threshold on the t-statistic
    n_permutations = 1000

    try:
        T_obs_test, clusters_test, p_values_test, _ = mne.stats.spatio_temporal_cluster_test(
            # The test expects (n_observations, n_times, n_channels)
            [np.transpose(data_slow, (0, 2, 1)),
             np.transpose(data_fast, (0, 2, 1))],  # [slow, fast]: positive t means slow > fast
            n_permutations=n_permutations,
            threshold=threshold,
            tail=0,  # Two-tailed
            stat_fun=mne.stats.ttest_ind_no_p,  # stat_fun must be a callable, not a name
            adjacency=connectivity,  # sparse channel adjacency, or None
            out_type='indices',
            n_jobs=-1,
            verbose=False
        )
    except Exception as e:
        print(f"An unexpected error occurred during the Test Evoked cluster test: {e}")
        print("Skipping Test Evoked cluster test due to error.")
    else:
        print("Test Evoked cluster permutation test complete.")

        # --- Report significant clusters ---
        alpha = 0.05  # Significance level for the clusters
        good_clusters_test_idx = [i for i, p_val in enumerate(p_values_test) if p_val < alpha]

        if good_clusters_test_idx:
            print(f"\nFound {len(good_clusters_test_idx)} significant spatio-temporal clusters for Test Evoked (p < {alpha}):")
            times = epochs_test_pred_array.times
            for i in good_clusters_test_idx:
                # T_obs_test is (n_times, n_channels); each cluster is a pair
                # of index arrays in that same (time, channel) order.
                cluster_times_idx, cluster_channels_idx = clusters_test[i]
                T_obs_this_cluster = T_obs_test[cluster_times_idx, cluster_channels_idx]
                min_t = np.min(T_obs_this_cluster)
                max_t = np.max(T_obs_this_cluster)
                cluster_p_value = p_values_test[i]

                # Convert time indices to seconds and find range
                unique_times_idx = np.unique(cluster_times_idx)
                cluster_times = times[unique_times_idx]
                tmin_cluster, tmax_cluster = cluster_times.min(), cluster_times.max()

                # Get channel names involved in the cluster
                unique_channels_idx = np.unique(cluster_channels_idx)
                cluster_channel_names = [epochs_test_pred_array.ch_names[j] for j in unique_channels_idx]

                print(f"  - Test Cluster {i+1}: p-value = {cluster_p_value:.4f}")
                print(f"    Time range: [{tmin_cluster:.3f} s, {tmax_cluster:.3f} s]")
                print(f"    Channels ({len(unique_channels_idx)}) : {cluster_channel_names}")
                print(f"    T-value range: [{min_t:.2f}, {max_t:.2f}]")
        else:
            print(f"\nNo significant spatio-temporal clusters found for Test Evoked at p < {alpha}.")
else:
    print("\nSkipping Test Evoked cluster test: Not enough epochs (min 2 per condition) in predicted test sets.")


# --- 2. Statistical Test for Time-Frequency Power Differences (Predicted Slow vs Fast on Test Data) ---
print("\n--> Comparing Theta Power (Predicted Slow vs Fast on Test Data)")

# Ensure TFR objects for test data predicted labels are available and have enough epochs
if tfr_fast_pred is not None and tfr_slow_pred is not None and len(tfr_fast_pred.data) > 1 and len(tfr_slow_pred.data) > 1:
    # Per-epoch power, (n_epochs, n_channels, n_freqs, n_times)
    power_fast = tfr_fast_pred.data.astype(np.float64)
    power_slow = tfr_slow_pred.data.astype(np.float64)

    # Find frequency indices for the theta band
    try:
        freqs = tfr_fast_pred.freqs  # Frequencies actually stored in the TFR object
        if theta_freqs is None:
             print("theta_freqs not defined, cannot find indices for theta band.")
             theta_freq_indices = []
        else:
             theta_freq_indices = np.where(np.logical_and(freqs >= theta_freqs.min(), freqs <= theta_freqs.max()))[0]
             if len(theta_freq_indices) == 0:
                 print(f"Warning: No frequencies in TFR object fall within the specified theta band ({theta_freqs.min()}-{theta_freqs.max()} Hz).")

    except Exception as e_freq:
         print(f"Could not get frequencies from Test TFR object or find theta band indices: {e_freq}")
         theta_freq_indices = []

    if len(theta_freq_indices) > 0:
         print(f"Averaging power over {len(theta_freq_indices)} frequencies in the theta band for Test TFR cluster test.")
         # Collapse the frequency axis so the test sees (n_epochs, n_channels, n_times)
         if power_fast.ndim == 4 and power_fast.shape[2] > 0:
             power_fast_theta_avg = power_fast[:, :, theta_freq_indices, :].mean(axis=2)
             power_slow_theta_avg = power_slow[:, :, theta_freq_indices, :].mean(axis=2)
         else:
              print("Test TFR power data shape is not as expected for averaging over frequency (expected 4D).")
              power_fast_theta_avg = None
              power_slow_theta_avg = None
    else:
         print("Skipping Test TFR cluster test: No frequencies found in the specified theta band or issues getting freq indices.")
         power_fast_theta_avg = None
         power_slow_theta_avg = None

    if power_fast_theta_avg is not None and power_slow_theta_avg is not None:
         # Build the adjacency from the TFR's own info: the TFR only holds the
         # picked channels, and the evoked section (which would otherwise
         # provide an adjacency) may have been skipped entirely.
         try:
             adjacency_tf_test, _ = mne.channels.find_ch_adjacency(tfr_fast_pred.info, ch_type='eeg')
         except Exception as e_conn:
             print(f"Could not determine channel connectivity: {e_conn}")
             adjacency_tf_test = None
         if adjacency_tf_test is not None and adjacency_tf_test.shape[0] != power_fast_theta_avg.shape[1]:
             print("Channel adjacency does not match the number of TFR channels.")
             adjacency_tf_test = None

         print(f"Running spatio-temporal cluster permutation test for average Test Theta power ({'with' if adjacency_tf_test is not None else 'without'} spatial connectivity)...")

         # Same cluster-forming threshold and permutation count as the evoked test
         threshold_tf = 2.0
         n_permutations_tf = 1000

         try:
             T_obs_tf_test, clusters_tf_test, p_values_tf_test, _ = mne.stats.spatio_temporal_cluster_test(
                 # The test expects (n_observations, n_times, n_channels)
                 [np.transpose(power_slow_theta_avg, (0, 2, 1)),
                  np.transpose(power_fast_theta_avg, (0, 2, 1))],  # Slow vs Fast
                 n_permutations=n_permutations_tf,
                 threshold=threshold_tf,
                 tail=0,  # Two-tailed
                 stat_fun=mne.stats.ttest_ind_no_p,  # stat_fun must be a callable
                 adjacency=adjacency_tf_test,
                 out_type='indices',
                 n_jobs=-1,
                 verbose=False
             )
         except Exception as e:
             print(f"An unexpected error occurred during the Test Theta power cluster test: {e}")
             print("Skipping Test Theta power cluster test due to error.")
         else:
             print("Test Theta power cluster permutation test complete.")

             # --- Report significant clusters ---
             alpha = 0.05
             good_clusters_tf_test_idx = [i for i, p_val in enumerate(p_values_tf_test) if p_val < alpha]

             if good_clusters_tf_test_idx:
                 print(f"\nFound {len(good_clusters_tf_test_idx)} significant spatio-temporal clusters for Test Theta power (p < {alpha}):")
                 times_tf = tfr_fast_pred.times  # Decimated TFR time axis
                 for i in good_clusters_tf_test_idx:
                     # T_obs is (n_times, n_channels); clusters index it in that order
                     cluster_times_idx, cluster_channels_idx = clusters_tf_test[i]
                     T_obs_this_cluster = T_obs_tf_test[cluster_times_idx, cluster_channels_idx]
                     min_t = np.min(T_obs_this_cluster)
                     max_t = np.max(T_obs_this_cluster)
                     cluster_p_value = p_values_tf_test[i]

                     unique_times_idx = np.unique(cluster_times_idx)
                     cluster_times = times_tf[unique_times_idx]
                     tmin_cluster, tmax_cluster = cluster_times.min(), cluster_times.max()

                     unique_channels_idx = np.unique(cluster_channels_idx)
                     cluster_channel_names = [tfr_fast_pred.ch_names[j] for j in unique_channels_idx]

                     print(f"  - Test Theta Cluster {i+1}: p-value = {cluster_p_value:.4f}")
                     print(f"    Time range: [{tmin_cluster:.3f} s, {tmax_cluster:.3f} s]")
                     print(f"    Channels ({len(unique_channels_idx)}) : {cluster_channel_names}")
                     print(f"    T-value range: [{min_t:.2f}, {max_t:.2f}]")
             else:
                 print(f"\nNo significant spatio-temporal clusters found for Test Theta power at p < {alpha}.")
    else:
         print("\nSkipping Test TFR cluster test: Data averaging failed or data not available.")

else:
     print("\nSkipping Test TFR cluster test: Not enough epochs (min 2 per condition) or TFR objects not available for predicted test sets.")


print("\n--- All Statistical Verification Complete ---")

In [None]:
import numpy as np
import mne
from mne.stats import spatio_temporal_cluster_test
from mne.channels import find_ch_adjacency

import matplotlib.pyplot as plt  # needed to title and show each topomap

# 1) Compute the difference-wave arrays, shape (n_fast_epochs, n_channels, n_times).
#    The fast and slow sets are separate trials and generally differ in size,
#    so an epoch-by-epoch subtraction fails to broadcast (and would pair
#    unrelated trials even when the counts match). Each fast epoch is instead
#    referenced to the mean slow response.
#    NOTE(review): confirm this is the intended definition of the difference wave.
data_train = epochs_fast_train.get_data() - epochs_slow_train.get_data().mean(axis=0, keepdims=True)
data_test  = epochs_fast_test.get_data()  - epochs_slow_test.get_data().mean(axis=0, keepdims=True)

# 2) Build adjacency for EEG sensors
adjacency, _ = find_ch_adjacency(info, ch_type='eeg')

# 3) Cluster permutation test (train vs. test difference-waves)
#    The test expects (n_observations, n_times, n_channels) so that the last
#    axis lines up with the channel adjacency; epoch data is channels-first.
X = [np.transpose(data_train, (0, 2, 1)), np.transpose(data_test, (0, 2, 1))]
T_obs, clusters, cluster_p_values, H0 = spatio_temporal_cluster_test(
    X,
    n_permutations=1000,
    adjacency=adjacency,
    tail=1,            # default statistic is a one-way F-test, which is one-tailed
    n_jobs=1,
    out_type='mask'    # clusters as boolean masks
)

# 4) Report significant clusters
sig_idxs = np.where(cluster_p_values < 0.05)[0]
print("Significant cluster indices:", sig_idxs)

# 5) Plot every significant cluster (if any) at its peak time
if sig_idxs.size > 0:
    times = epochs_fast_train.times
    for clu in sig_idxs:
        mask = clusters[clu]  # shape: (n_times, n_channels), same as T_obs
        # Time points that belong to this cluster
        t_inds = np.where(mask.any(axis=1))[0]
        # Peak = cluster time point with the largest mean |statistic| over channels
        peak_idx = t_inds[np.argmax(np.abs(T_obs[t_inds, :]).mean(axis=1))]
        t_peak = times[peak_idx]
        # Topomap at peak; plot_topomap takes no title argument, so the title
        # is set on the axes it draws into.
        fig, ax = plt.subplots()
        mne.viz.plot_topomap(
            T_obs[peak_idx, :],
            info,
            mask=mask[peak_idx, :],
            axes=ax,
            show=False
        )
        ax.set_title(f"Cluster {clu} at {t_peak*1000:.0f} ms")
        plt.show()
else:
    print("No significant train–test cluster differences found (p < 0.05).")
