# In [None]:
# Modified Microstate Analysis Script for Combined Go/NoGo & Pre/Post Comparisons
# Supports: Pre_combined vs Post_combined

import os
import numpy as np
import pandas as pd
import mne
from mne.preprocessing import ICA
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin
import matplotlib.pyplot as plt
import seaborn as sns
import itertools
from scipy.signal import find_peaks

# Parameters
low_freq = 1.0          # band-pass lower edge (Hz) passed to raw.filter
high_freq = 40.0        # band-pass upper edge (Hz)
notch_freq = 50.0       # line-noise notch frequency (Hz)
ica_n_components = 10
ica_random_state = 42   # also reused as the KMeans seed in the main loop
montage_name = "standard_1020"
min_duration = 150      # recordings with raw.times[-1] below this are skipped
resample_freq = 250     # target sampling rate (Hz) after ICA
n_microstates = 4
# Raw strings keep backslashes literally, so each separator needs only one.
input_dir = r"D:\Microstate_Payal\ERP"
output_dir = r"D:\Microstate_Payal\ERP\ERP_Processed"
os.makedirs(output_dir, exist_ok=True)

# Event IDs: annotation descriptions to look up in the EDF annotations.
event_ids = {'Go': 'Go', 'NoGo': 'NoGo'}

# Data containers, keyed by condition ("pre_combined", "post_combined").
conditions = [f"{sess}_combined" for sess in ["pre", "post"]]
group_templates = {cond: [] for cond in conditions}
group_transitions = {cond: [] for cond in conditions}
metrics_records = []  # one dict per subject x condition x microstate


def compute_transition_matrix(labels, n_microstates):
    """Return the row-normalized first-order transition matrix of a label sequence.

    Entry ``[a, b]`` is the fraction of steps leaving state ``a`` that land
    in state ``b`` (self-transitions included). Rows for states that are
    never left stay all-zero instead of becoming NaN.

    Parameters
    ----------
    labels : sequence of int
        State labels in temporal order, expected to lie in
        ``range(n_microstates)``.
    n_microstates : int
        Number of states; sets the matrix size.

    Returns
    -------
    numpy.ndarray
        ``(n_microstates, n_microstates)`` float matrix.
    """
    counts = np.zeros((n_microstates, n_microstates))
    # zip stops at the shorter sequence, so this walks consecutive pairs.
    for src, dst in zip(labels, labels[1:]):
        counts[src, dst] += 1
    totals = counts.sum(axis=1, keepdims=True)
    # Divide empty rows by 1 so they remain zero.
    totals[totals == 0] = 1
    return counts / totals


def preprocess_and_epoch(raw):
    """Clean a continuous recording for microstate analysis.

    Despite its name, this function does not epoch. It strips the "EEG "
    channel-name prefix, drops unwanted channels, applies the montage,
    band-pass and notch filters, removes ICA components flagged against the
    frontal EOG proxies, and resamples to ``resample_freq``.

    Parameters
    ----------
    raw : mne.io.Raw
        Preloaded recording; it is modified in place.

    Returns
    -------
    mne.io.Raw
        The cleaned, resampled recording.
    """
    # Strip the "EEG " prefix so names can match the standard montage.
    rename_mapping = {ch: ch.replace("EEG ", "") for ch in raw.ch_names if ch.startswith("EEG ")}
    if rename_mapping:
        raw.rename_channels(rename_mapping)

    # Channels excluded from the analysis. The list is written with the
    # "EEG " prefix, but the rename above has already removed it, so
    # normalize the names the same way before matching.
    channels_to_drop = ['E', 'EEG AFp3', 'EEG AFp4', 'EEG FTT7h', 'EEG FTT8h',
                        'EEG T9', 'EEG T10', 'EEG P9', 'EEG P10', 'EEG PO5', 'EEG PO6',
                        'EEG Nz', 'EEG Iz', 'EEG I1', 'EEG I2', 'EEG AF5', 'EEG AF6']
    drop_names = {ch.replace("EEG ", "") for ch in channels_to_drop}
    raw.drop_channels([ch for ch in raw.ch_names if ch in drop_names])

    raw.set_montage(montage_name, match_case=False, on_missing='ignore')
    raw.filter(l_freq=low_freq, h_freq=high_freq)
    raw.notch_filter(freqs=notch_freq)

    ica = ICA(n_components=ica_n_components, random_state=ica_random_state, max_iter="auto")
    ica.fit(raw)

    # Mark whichever frontal channels exist as EOG so find_bads_eog can use
    # them. Only present channels are passed: set_channel_types raises on
    # unknown names.
    # NOTE(review): Fpz is not marked; add it here if it should also serve
    # as an ocular reference.
    eog_mapping = {ch: 'eog' for ch in ('Fp1', 'Fp2') if ch in raw.ch_names}
    if eog_mapping:
        raw.set_channel_types(eog_mapping)

    # Best effort: without EOG channels, find_bads_eog raises and the
    # recording is kept without ocular-component removal.
    try:
        eog_inds, _ = ica.find_bads_eog(raw, threshold=3.0)
        ica.exclude = eog_inds
        print(f"✅ Excluding {len(eog_inds)} EOG-related component(s)")
    except Exception as e:
        print(f"⚠️ Skipping EOG artifact removal: {e}")

    raw = ica.apply(raw)
    raw.resample(sfreq=resample_freq)
    return raw


def plot_topomap_with_standard_channels(template, info):
    """Restrict maps and channel info to electrodes known to the 10-20 montage.

    Parameters
    ----------
    template : numpy.ndarray
        Maps indexed as ``template[map_index, channel]``, with channels in
        ``info['ch_names']`` order.
    info : mne.Info
        Measurement info whose channel order matches ``template``'s columns.

    Returns
    -------
    tuple
        ``(template_subset, info_subset)``. The first is the column subset
        of ``template`` for channels present in the standard_1020 montage.
        The second is a fresh all-EEG ``Info`` for those channels with the
        montage applied.
    """
    montage = mne.channels.make_standard_montage("standard_1020")
    known_names = set(montage.ch_names)
    # Indices stay in info order, which is the column order of `template`.
    keep_idx = [idx for idx, name in enumerate(info['ch_names']) if name in known_names]
    kept_names = [info['ch_names'][idx] for idx in keep_idx]
    subset_info = mne.create_info(kept_names, sfreq=info['sfreq'], ch_types='eeg')
    subset_info.set_montage(montage)
    return template[:, keep_idx], subset_info


def save_average_topomap(templates_list, info, condition):
    """Average per-subject microstate maps and save them as one topomap figure.

    Parameters
    ----------
    templates_list : list of numpy.ndarray
        One template array per subject, rows = microstates, columns =
        channels in ``info`` order. All arrays are assumed to share one
        channel set (TODO confirm across subjects).
    info : mne.Info or None
        Channel info matching the template columns.
    condition : str
        Used in the figure title and the output file name.

    Returns
    -------
    numpy.ndarray or None
        The element-wise mean template, or ``None`` when there is nothing
        to average.

    NOTE(review): maps are averaged index-by-index. k-means cluster order
    (and, for plain k-means, map polarity) is arbitrary per subject, so
    this average is only meaningful if subjects' templates were aligned
    beforehand.
    """
    if len(templates_list) == 0 or info is None:
        print(f"⚠️ No templates or info to average for {condition}")
        return None
    avg_templates = np.mean(templates_list, axis=0)
    templates_subset, info_subset = plot_topomap_with_standard_channels(avg_templates, info)
    # squeeze=False keeps `axes` 2-D even when n_microstates == 1, where
    # plt.subplots would otherwise return a single, non-iterable Axes.
    fig, axes = plt.subplots(1, n_microstates, figsize=(4 * n_microstates, 3), squeeze=False)
    for i, ax in enumerate(axes[0]):
        mne.viz.plot_topomap(templates_subset[i], info_subset, axes=ax, show=False)
        ax.set_title(f'Microstate {chr(65 + i)}')
    fig.suptitle(f'Average Microstate Topographies - {condition}', y=1.05)
    fig.savefig(os.path.join(output_dir, f'Average_{condition}_Topomap.png'), dpi=300, bbox_inches='tight')
    plt.close(fig)
    return avg_templates


def save_average_transition(transitions, condition):
    """Average per-subject transition matrices and save them as a heatmap.

    Parameters
    ----------
    transitions : list of numpy.ndarray
        One ``(n_microstates, n_microstates)`` matrix per subject.
    condition : str
        Used in the figure title and the output file name.

    Returns
    -------
    numpy.ndarray or None
        The element-wise mean matrix, or ``None`` when the list is empty.
    """
    if not len(transitions):
        print(f"⚠️ No transitions to average for {condition}")
        return None
    mean_matrix = np.mean(transitions, axis=0)
    state_names = [chr(65 + k) for k in range(n_microstates)]  # 'A', 'B', ...
    table = pd.DataFrame(mean_matrix, index=state_names, columns=state_names)

    plt.figure(figsize=(6, 5))
    sns.heatmap(table, annot=True, cmap='Blues', fmt=".2f")
    plt.title(f'Average Transition Matrix - {condition}')
    out_path = os.path.join(output_dir, f'Average_{condition}_TransitionMatrix.png')
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close()
    return mean_matrix


def save_difference_map(template1, template2, info, label):
    """Plot and save the per-microstate topography difference ``template2 - template1``.

    Parameters
    ----------
    template1, template2 : numpy.ndarray or None
        Group-average maps (rows = microstates, columns = channels in
        ``info`` order). ``None`` means that condition had no data.
    info : mne.Info or None
        Channel info matching the template columns.
    label : str
        Used in the figure title and the output file name.

    NOTE(review): row i of one condition is subtracted from row i of the
    other, which assumes both conditions use the same microstate labeling.
    """
    if template1 is None or template2 is None or info is None:
        print(f"⚠️ Skipping difference map for {label} due to missing data")
        return
    diff_templates = template2 - template1
    templates_subset, info_subset = plot_topomap_with_standard_channels(diff_templates, info)
    # squeeze=False keeps `axes` 2-D even when n_microstates == 1.
    fig, axes = plt.subplots(1, n_microstates, figsize=(4 * n_microstates, 3), squeeze=False)
    for i, ax in enumerate(axes[0]):
        mne.viz.plot_topomap(templates_subset[i], info_subset, axes=ax, show=False)
        ax.set_title(f'Diff {chr(65 + i)}')
    fig.suptitle(f'Topography Difference ({label})', y=1.05)
    fig.savefig(os.path.join(output_dir, f'Topomap_Difference_{label}.png'), dpi=300, bbox_inches='tight')
    plt.close(fig)


def save_difference_transition(mat1, mat2, label):
    """Plot and save the transition-probability difference ``mat2 - mat1``.

    Parameters
    ----------
    mat1, mat2 : numpy.ndarray or None
        Group-average transition matrices. ``None`` means that condition
        had no data, and nothing is plotted.
    label : str
        Used in the figure title and the output file name.
    """
    if mat1 is None or mat2 is None:
        print(f"⚠️ Skipping difference matrix for {label} due to missing data")
        return
    letters = [chr(65 + k) for k in range(n_microstates)]
    delta = pd.DataFrame(mat2 - mat1, index=letters, columns=letters)

    plt.figure(figsize=(6, 5))
    # Diverging map centered at zero so gains and losses read symmetrically.
    sns.heatmap(delta, annot=True, cmap='RdBu_r', fmt=".2f", center=0)
    plt.title(f'Transition Matrix Difference ({label})')
    target = os.path.join(output_dir, f'TransitionMatrix_Difference_{label}.png')
    plt.savefig(target, dpi=300, bbox_inches='tight')
    plt.close()


def _parse_subject_session(filename):
    """Map ``<subject>_<session_code>.edf`` to ``(subject, session)`` or ``None``.

    The session code after the last underscore is ``s0`` for the pre
    session; any other code is treated as post. Returns ``None`` when the
    stem has no underscore or an empty subject part.
    """
    stem = os.path.splitext(filename)[0]
    subject, sep, session_code = stem.rpartition("_")
    if not sep or not subject:
        return None
    session = "pre" if session_code == "s0" else "post"
    return subject, session


def _within_epoch_transition_matrix(labels, n_states):
    """Row-normalized transition matrix counting only steps inside each epoch.

    ``labels`` is an ``(n_epochs, n_times)`` integer array. Pairs that
    straddle two epochs are excluded, because epochs are separate windows
    around events and need not be contiguous in time. Rows with no outgoing
    transitions stay zero.
    """
    labels = np.asarray(labels)
    counts = np.zeros((n_states, n_states))
    # Unbuffered add so repeated (from, to) pairs are all counted.
    np.add.at(counts, (labels[:, :-1].ravel(), labels[:, 1:].ravel()), 1)
    row_sums = counts.sum(axis=1, keepdims=True)
    return counts / np.where(row_sums == 0, 1, row_sums)


def _microstate_metrics(labels, n_states, sfreq):
    """Return ``(duration_ms, occurrence_per_sec, coverage_fraction)`` per state.

    ``labels`` is an ``(n_epochs, n_times)`` integer array sampled at
    ``sfreq`` Hz. Segments are run-length encoded within each epoch, so a
    segment never spans an epoch boundary.
    """
    labels = np.asarray(labels)
    total_time = labels.size / sfreq
    runs = [(state, sum(1 for _ in group))
            for epoch in labels
            for state, group in itertools.groupby(epoch)]
    metrics = []
    for i in range(n_states):
        durations = [length for state, length in runs if state == i]
        occurrence = len(durations) / total_time
        coverage = np.mean(labels == i)
        duration_ms = np.mean(durations) * 1000 / sfreq if durations else 0
        metrics.append((duration_ms, occurrence, coverage))
    return metrics


# Main execution
if __name__ == "__main__":
    info_for_avg = None
    # Sorted so the subject whose info is used for plotting is deterministic.
    for file in sorted(os.listdir(input_dir)):
        if os.path.splitext(file)[1].lower() != ".edf":
            continue
        parsed = _parse_subject_session(file)
        if parsed is None:
            print(f"⚠️ Skipping {file}: expected '<subject>_<session>.edf'")
            continue
        subject, session = parsed
        condition = f"{session}_combined"
        file_path = os.path.join(input_dir, file)
        print(f"\U0001F504 Processing {subject} - Session: {session}")

        try:
            raw = mne.io.read_raw_edf(file_path, preload=True)
            if raw.times[-1] < min_duration:
                print(f"⚠️ Skipping {subject} - {session}: recording shorter than {min_duration} s")
                continue
            raw = preprocess_and_epoch(raw)
            events, event_dict = mne.events_from_annotations(raw)
            # Go and NoGo trials are pooled into one "combined" condition.
            event_id = {name: event_dict[name] for name in event_ids.values() if name in event_dict}
            combined_events = events[np.isin(events[:, 2], list(event_id.values()))]

            if len(combined_events) == 0:
                print(f"⚠️ No Go/NoGo events found for {subject} - {session}")
                continue

            epochs = mne.Epochs(raw, events=combined_events, event_id=event_id,
                                tmin=-0.2, tmax=0.8, baseline=(-0.2, 0), preload=True, detrend=1)

            data = epochs.get_data()  # (n_epochs, n_channels, n_times)
            sfreq = epochs.info['sfreq']
            # One row per sample, one column per channel.
            reshaped = data.transpose(0, 2, 1).reshape(-1, data.shape[1])
            # Global field power: spatial standard deviation at each sample.
            gfp = np.std(reshaped, axis=1)
            peaks, _ = find_peaks(gfp)
            if len(peaks) < n_microstates:
                print(f"⚠️ Only {len(peaks)} GFP peaks for {subject} - {condition}; "
                      f"need at least {n_microstates}")
                continue
            maps = reshaped[peaks]

            # n_init is pinned so results do not depend on the scikit-learn
            # version's default (10 before 1.4, 'auto' afterwards).
            # NOTE(review): plain k-means is polarity-sensitive and its
            # cluster order is arbitrary per subject; the group averages
            # assume a consistent labeling across subjects.
            kmeans = KMeans(n_clusters=n_microstates, n_init=10,
                            random_state=ica_random_state).fit(maps)
            templates = kmeans.cluster_centers_
            # Back-fit every sample to its nearest template; one row per epoch.
            labels = pairwise_distances_argmin(reshaped, templates).reshape(data.shape[0], data.shape[2])

            group_templates[condition].append(templates)
            group_transitions[condition].append(_within_epoch_transition_matrix(labels, n_microstates))

            state_metrics = _microstate_metrics(labels, n_microstates, sfreq)
            for i, (duration_ms, occurrence, coverage) in enumerate(state_metrics):
                metrics_records.append({
                    'Subject': subject,
                    'Condition': condition,
                    'Microstate': chr(65 + i),
                    'Duration_ms': duration_ms,
                    'Occurrence_per_sec': occurrence,
                    'Coverage_percent': coverage * 100
                })

            if info_for_avg is None:
                # Channel layout for the group plots; assumes every subject
                # shares this channel set (TODO confirm).
                info_for_avg = epochs.info.copy()

        except Exception as e:
            print(f"❌ Error processing {file}: {e}")

    metrics_df = pd.DataFrame(metrics_records)
    metrics_csv_path = os.path.join(output_dir, 'Microstate_Metrics_SubjectWise.csv')
    metrics_df.to_csv(metrics_csv_path, index=False)
    print(f"✅ Subject-wise microstate metrics saved to: {metrics_csv_path}")

    for condition in conditions:
        avg_temp = save_average_topomap(group_templates[condition], info_for_avg, condition)
        avg_tran = save_average_transition(group_transitions[condition], condition)
        group_templates[condition] = avg_temp
        group_transitions[condition] = avg_tran

    save_difference_map(group_templates['pre_combined'], group_templates['post_combined'], info_for_avg, 'Pre_vs_Post_Combined')
    save_difference_transition(group_transitions['pre_combined'], group_transitions['post_combined'], 'Pre_vs_Post_Combined')