# ERP Analysis Workflow
    
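This notebook documents the workflow and code used to analyze event-related potential (ERP) data, focusing on the mismatch negativity (MMN) and P300b difference waves: peak amplitude/latency measures, sample-by-sample t-tests, and cluster-based permutation tests.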

## Setup and Imports

In [None]:
import mne
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
from mne.stats import permutation_cluster_test
from scipy.ndimage import gaussian_filter1d
from scipy.stats import t

## Load Data

In [None]:
# Read the epochs from disk; the analysis functions below use this
# module-level `epochs` object directly rather than taking it as an argument.
epochs_fname = '42.fif'
epochs = mne.read_epochs(fname=epochs_fname)
print(epochs)

## Helper Functions

In [None]:
def calculate_consecutive_significant(p_values, threshold):
    """Return the longest run of consecutive samples with ``p < threshold``.

    Runs are counted along the last axis (time), independently for each row
    (channel), and the maximum over all rows is returned. NaN p-values compare
    as False and therefore break a run.

    Parameters
    ----------
    p_values : array_like
        P-values, typically shaped (n_channels, n_times). A 1-D array is
        treated as a single channel; higher-dimensional input is flattened
        into rows over all leading axes.
    threshold : float
        Strict significance cutoff (``p < threshold`` counts as significant).

    Returns
    -------
    int
        Length of the longest significant run; 0 if there is none or the
        input is empty.
    """
    sig_mask = np.atleast_1d(np.asarray(p_values) < threshold)
    if sig_mask.size == 0:
        return 0
    # Collapse leading axes so each row is one independent time series.
    rows = sig_mask.reshape(-1, sig_mask.shape[-1])

    longest = 0
    for row in rows:
        run = 0  # runs never continue from one channel into the next
        for is_sig in row:
            run = run + 1 if is_sig else 0
            if run > longest:
                longest = run
    return longest

In [None]:
def calculate_and_plot_difference_wave(conditions1, conditions2, electrodes, time_window, title, find_negative=False):
    """Compute, plot and report the difference wave between two pooled condition groups.

    The epochs of every condition in ``conditions1`` are pooled and averaged
    (likewise for ``conditions2``) after restricting to ``electrodes``. The
    difference wave is ``cond1 - cond2`` per channel. The peak is searched on
    the channel-averaged difference wave inside ``time_window``. The function
    reads the module-level ``epochs`` and calls
    ``calculate_consecutive_significant``.

    Parameters
    ----------
    conditions1, conditions2 : list of str
        Condition names used to index ``epochs``; each list is pooled.
    electrodes : list of str
        Channel names kept before averaging.
    time_window : tuple of float
        Inclusive ``(start, stop)`` of the peak search, in the units of the
        epochs' time axis (seconds in MNE).
    title : str
        Plot title. Lower-cased, with spaces replaced by underscores, it also
        names the saved PNG.
    find_negative : bool, optional
        If True, report the most negative point instead of the most positive.

    Raises
    ------
    ValueError
        If ``time_window`` contains no samples.

    Notes
    -----
    Writes ``<title>.png`` (300 dpi) to the working directory and prints the
    peak amplitude (µV), the peak latency (s) and the longest runs of
    consecutive significant samples at p < 0.05 and p < 0.01.
    """
    # Pool each condition group and keep only the requested channels.
    epochs_cond1 = mne.concatenate_epochs([epochs[cond] for cond in conditions1]).pick(electrodes)
    epochs_cond2 = mne.concatenate_epochs([epochs[cond] for cond in conditions2]).pick(electrodes)

    evoked_cond1 = epochs_cond1.average()
    evoked_cond2 = epochs_cond2.average()

    # Per-channel difference wave, wrapped as an Evoked to keep channel info and the time axis.
    diff_data = evoked_cond1.data - evoked_cond2.data
    diff_wave = mne.EvokedArray(diff_data, evoked_cond1.info, tmin=evoked_cond1.times[0])

    times = diff_wave.times
    time_mask = (times >= time_window[0]) & (times <= time_window[1])
    if not time_mask.any():
        raise ValueError("time_window " + str(time_window) + " contains no samples "
                         "(epoch spans " + str(times[0]) + " to " + str(times[-1]) + ").")

    # The peak is taken on the average over the selected channels.
    mean_diff = diff_wave.data.mean(axis=0)
    window_values = mean_diff[time_mask]
    window_times = times[time_mask]
    peak_idx = np.argmin(window_values) if find_negative else np.argmax(window_values)
    peak_amplitude = window_values[peak_idx]
    peak_latency = window_times[peak_idx]

    # Sample-by-sample Welch's t-tests (unequal variances) between single trials.
    # NOTE: these run over the whole epoch (all selected channels and all times),
    # not only time_window, so the run lengths below may include samples outside
    # the window (e.g. any pre-stimulus baseline).
    _, p_values = stats.ttest_ind(epochs_cond1.get_data(), epochs_cond2.get_data(),
                                  axis=0, equal_var=False)

    max_consecutive_005 = calculate_consecutive_significant(p_values, 0.05)
    max_consecutive_001 = calculate_consecutive_significant(p_values, 0.01)

    latency_text = 'Peak Latency: ' + str(round(peak_latency, 3)) + 's'
    amplitude_text = 'Peak Amplitude: ' + str(round(peak_amplitude * 1e6, 3)) + ' µV'

    # Explicit figure handle, closed in `finally` so a failed plot/save does not leak it.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(times, mean_diff * 1e6, 'b-', linewidth=2)  # volts -> microvolts
        ax.axvline(peak_latency, color='r', linestyle='--', label=latency_text)
        ax.axhline(0, color='k', linestyle='-', linewidth=0.5)
        ax.set_title(title)
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Voltage (µV)')
        ax.legend()
        ax.grid(True, linestyle=':', alpha=0.6)

        ax.text(0.02, 0.98, amplitude_text, transform=ax.transAxes, verticalalignment='top')
        ax.text(0.02, 0.93, latency_text, transform=ax.transAxes, verticalalignment='top')

        fig.tight_layout()
        fig.savefig(title.lower().replace(" ", "_") + '.png', dpi=300)
    finally:
        plt.close(fig)

    print(title + " Results:")
    print("Peak Amplitude: " + str(round(peak_amplitude * 1e6, 3)) + " µV")
    print("Peak Latency: " + str(round(peak_latency, 3)) + "s")
    print("Max consecutive significant samples (p < 0.05): " + str(max_consecutive_005))
    print("Max consecutive significant samples (p < 0.01): " + str(max_consecutive_001))

## MMN Analysis

In [None]:
# MMN difference wave: pooled LDGS/LDGD minus pooled LSGS/LSGD over
# F3/Fz/F4/C3/Cz/C4, searching for the most negative point in the window.
mmn_electrodes = ['F3', 'Fz', 'F4', 'C3', 'Cz', 'C4']
mmn_time_window = (0.1, 0.25)  # seconds
calculate_and_plot_difference_wave(
    conditions1=['LDGS', 'LDGD'],
    conditions2=['LSGS', 'LSGD'],
    electrodes=mmn_electrodes,
    time_window=mmn_time_window,
    title='MMN Difference Wave (LDGS + LDGD) vs (LSGS + LSGD)',
    find_negative=True,
)

## P300b Analysis

In [None]:
# P300b difference wave: pooled LSGD/LDGD minus pooled LDGS/LSGS over
# C3/Cz/C4/P3/Pz/P4, searching for the most positive point in the window.
p300b_electrodes = ['C3', 'Cz', 'C4', 'P3', 'Pz', 'P4']
p300b_time_window = (0.25, 0.7)  # seconds
calculate_and_plot_difference_wave(
    conditions1=['LSGD', 'LDGD'],
    conditions2=['LDGS', 'LSGS'],
    electrodes=p300b_electrodes,
    time_window=p300b_time_window,
    title='P300b Difference Wave (LSGD + LDGD) vs (LDGS + LSGS)',
)

## Cluster Permutation Analysis

In [None]:
def cluster_permutation_analysis_adjusted(conditions1, conditions2, electrodes, time_window, title,
                                          n_permutations=10000, alpha=0.05, tail=1,
                                          smooth_sigma=2, seed=None):
    """Run a cluster-based permutation t-test between two pooled condition groups.

    The epochs of each condition group are pooled from the module-level
    ``epochs`` and restricted to ``electrodes``. They are optionally
    Gaussian-smoothed along time, cropped to ``time_window``, and compared
    sample by sample with an independent-samples (pooled-variance)
    t-statistic, cond1 - cond2, inside
    ``mne.stats.permutation_cluster_test``.

    Parameters
    ----------
    conditions1, conditions2 : list of str
        Condition names used to index ``epochs``; each list is pooled.
    electrodes : list of str
        Channel names to keep.
    time_window : tuple of float
        Inclusive ``(start, stop)`` in the units of the epochs' time axis
        (seconds in MNE).
    title : str
        Plot title; also the base of the saved PNG file name.
    n_permutations : int, optional
        Number of permutations.
    alpha : float, optional
        Significance level turned into the cluster-forming t critical value.
    tail : {1, -1, 0}, optional
        1 tests cond1 > cond2, -1 tests cond1 < cond2, 0 is two-sided.
    smooth_sigma : float or None, optional
        Standard deviation, in samples, of the Gaussian smoothing along time.
        0 or None disables smoothing.
    seed : int or None, optional
        Random seed for the permutations (reproducible cluster p-values).

    Raises
    ------
    ValueError
        If ``tail`` is not -1, 0 or 1, or ``time_window`` contains no samples.

    Notes
    -----
    Writes ``<title>_cluster_permutation.png`` (lower-cased, spaces replaced
    by underscores) and prints the threshold and the cluster results.
    """
    # Validate before any expensive epoch concatenation.
    if tail not in (-1, 0, 1):
        raise ValueError("Invalid tail parameter. Use 1 for positive, -1 for negative, or 0 for two-sided.")

    epochs_cond1 = mne.concatenate_epochs([epochs[cond] for cond in conditions1]).pick(electrodes)
    epochs_cond2 = mne.concatenate_epochs([epochs[cond] for cond in conditions2]).pick(electrodes)

    times = epochs_cond1.times
    time_mask = (times >= time_window[0]) & (times <= time_window[1])
    if not time_mask.any():
        raise ValueError("time_window " + str(time_window) + " contains no samples.")

    # Data arrays indexed as (epoch, channel, time), consistent with axis=2 below.
    data1 = epochs_cond1.get_data()
    data2 = epochs_cond2.get_data()

    # Smooth the full epoch BEFORE cropping. Filtering an already-cropped window
    # makes gaussian_filter1d pad the window edges (by reflection) instead of
    # using the real neighbouring samples.
    if smooth_sigma:
        data1 = gaussian_filter1d(data1, sigma=smooth_sigma, axis=2)
        data2 = gaussian_filter1d(data2, sigma=smooth_sigma, axis=2)
    data1 = data1[:, :, time_mask]
    data2 = data2[:, :, time_mask]
    times = times[time_mask]

    # Pooled-variance independent-samples t-test -> df = n1 + n2 - 2.
    df = data1.shape[0] + data2.shape[0] - 2
    if tail == 0:
        t_threshold = t.ppf(1 - alpha / 2, df)
    else:
        # tail=1 -> +t_crit, tail=-1 -> -t_crit
        t_threshold = tail * t.ppf(1 - alpha, df)

    print("Using t-threshold: " + str(round(t_threshold, 3)) + " for alpha=" + str(alpha) +
          ", df=" + str(df) + ", tail=" + str(tail))

    # An explicit t stat_fun is required. permutation_cluster_test defaults to an
    # F-test (mne.stats.f_oneway), whose values are never negative. With that
    # default, tail=-1 with a negative threshold could never form a cluster, and
    # the t critical value would be compared against F values.
    # NOTE(review): with adjacency=None the channel axis is treated as a regular
    # lattice, i.e. channels neighbour each other only by their order in
    # `electrodes`. Confirm this is intended, or pass a spatial adjacency
    # (mne.channels.find_ch_adjacency) with data transposed to (epoch, time, channel).
    t_obs, clusters, cluster_pv, H0 = permutation_cluster_test(
        [data1, data2], n_permutations=n_permutations, threshold=t_threshold,
        tail=tail, stat_fun=mne.stats.ttest_ind_no_p, n_jobs=1, seed=seed, verbose=True)

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        # t_obs is (channel, time); plot its mean over the selected channels.
        ax.plot(times, np.mean(t_obs, axis=0), label='Observed t-values')
        ax.axhline(0, color='k', linestyle='--', linewidth=0.5)
        ax.set_title(title)
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('t-value')
        ax.grid(True, linestyle=':', alpha=0.6)
        ax.legend()
        fig.tight_layout()
        fig.savefig(title.lower().replace(" ", "_") + '_cluster_permutation.png', dpi=300)
    finally:
        plt.close(fig)

    print("Cluster Permutation Analysis Results:")
    print("Number of clusters: " + str(len(clusters)))
    print("Cluster p-values: " + str(cluster_pv))

In [None]:
# Cluster permutation tests for both components, reusing the channel sets and
# time windows defined for the difference-wave cells above.

# MMN is analysed as a negative peak, so test the lower tail (cond1 < cond2).
cluster_permutation_analysis_adjusted(
    conditions1=['LDGS', 'LDGD'],
    conditions2=['LSGS', 'LSGD'],
    electrodes=mmn_electrodes,
    time_window=mmn_time_window,
    title='MMN Cluster Permutation Analysis',
    tail=-1,
)

# P300b is analysed as a positive peak, so test the upper tail (cond1 > cond2).
cluster_permutation_analysis_adjusted(
    conditions1=['LSGD', 'LDGD'],
    conditions2=['LDGS', 'LSGS'],
    electrodes=p300b_electrodes,
    time_window=p300b_time_window,
    title='P300b Cluster Permutation Analysis',
    tail=1,
)