In [1]:
import numpy as np
from scipy.signal import detrend
from sklearn.utils import resample
from scipy.stats import mode
from itertools import groupby
import mne
from mne.time_frequency import psd_array_multitaper
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
import os
import pandas as pd
from datetime import datetime
from scipy.integrate import simpson
import xml.etree.ElementTree as ET

In [2]:
def _resample_to_eeg_length(labels, n_samples):
    """Stretch a 1-D label array to ``n_samples`` entries by nearest-index lookup (no-op if already that long).

    Lets label arrays stored at another resolution (presumably one value per scoring epoch —
    confirm against the data files) be sliced with EEG sample indices.
    """
    labels = np.atleast_1d(np.squeeze(np.asarray(labels)))
    if labels.ndim != 1:
        raise ValueError(f"Expected a 1-D label array, got shape {labels.shape}.")
    n_labels = labels.shape[0]
    if n_labels == n_samples:
        return labels
    if n_labels == 0:
        raise ValueError("Cannot resample an empty label array.")
    # Sample i maps to label floor(i * n_labels / n_samples); categorical codes are never blended.
    idx = (np.arange(n_samples, dtype=np.int64) * n_labels) // n_samples
    return labels[idx]


def get_iso(edf, sleep_stages, artifact_indicator, ch_groups):
    """
    Compute the infraslow-oscillation (ISO) spectrum of sigma-band (11-15 Hz) power during NREM sleep.

    The EEG is cut into 4 s windows every 2 s. In each window, multitaper sigma power is summed and
    converted to dB, giving a sigma time series sampled at 0.5 Hz per channel group. That series is
    split into 256 s segments (64 s step) lying entirely inside runs of clean NREM windows. Each
    segment's multitaper spectrum is interpolated onto 0.001-0.1 Hz and normalised to sum to 1, and
    the segment spectra are averaged per channel group.

    Parameters
    ----------
    edf : mne.io.Raw
        EEG recording; ``get_data()``, ``info['sfreq']`` and ``ch_names`` are used.

    sleep_stages : array-like, 1-D
        Sleep-stage codes. Expected coding: 0 = NREM, 1 = REM, 2 = Wake, 3 = Artifact (only 0 is
        used here). If its length differs from the number of EEG samples it is stretched to that
        length by nearest-index lookup, which assumes the labels evenly span the whole recording.

    artifact_indicator : array-like, 1-D
        Truthy where a sample is an artifact (1 = artifact, 0 = clean). Resampled like ``sleep_stages``.

    ch_groups : list of list of str
        Channel groupings; sigma power (dB) is averaged within each group before the ISO step.

    Returns
    -------
    spec_iso : np.ndarray, shape (n_groups, 1, n_freqs)
        Relative ISO spectrum of 11-15 Hz sigma power for each channel group. A group with no
        usable 256 s segment is all NaN.

    freq_iso : np.ndarray, shape (100,)
        Frequency grid, 0.001 Hz to 0.1 Hz in 0.001 Hz steps.

    Raises
    ------
    ValueError
        If no window is labelled NREM, or a label array is empty / not 1-D.
    """
    eeg = edf.get_data() * 1e6  # MNE returns volts; work in microvolts
    fs = edf.info['sfreq']
    n_samples = eeg.shape[1]

    # Honour the documented contract: bring both label arrays to the EEG sample count so the
    # sample-index slicing below covers the whole recording instead of producing empty slices.
    sleep_stages = _resample_to_eeg_length(sleep_stages, n_samples)
    artifact_indicator = _resample_to_eeg_length(artifact_indicator, n_samples)

    # Segment EEG into 4 s windows with a 2 s step (one sigma-power value every 2 s)
    epoch_time = 2
    window_size = int(round(4 * fs))
    step_size = int(round(epoch_time * fs))
    start_ids = np.arange(0, n_samples - window_size + 1, step_size)
    epochs = np.array([eeg[:, x:x + window_size] for x in start_ids])

    # Per-window labels: majority sleep stage; artifact if any sample in the window is flagged
    sleep_stages = np.array([mode(sleep_stages[x:x + window_size], keepdims=False).mode for x in start_ids])
    artifact_indicator = np.array(
        [artifact_indicator[x:x + window_size].any() for x in start_ids], dtype=bool
    )

    sleep_ids = np.flatnonzero(sleep_stages == 0)
    if sleep_ids.size == 0:
        raise ValueError("No NREM epochs found.")
    # Trim to the span between the first and the last NREM window
    first, last = sleep_ids[0], sleep_ids[-1] + 1
    epochs = epochs[first:last]
    sleep_stages = sleep_stages[first:last]
    artifact_indicator = artifact_indicator[first:last]

    epochs = detrend(epochs, axis=-1)

    # Sigma-band power per window and channel
    spec, freq = psd_array_multitaper(
        epochs, sfreq=fs, fmin=11, fmax=15, bandwidth=1,
        normalization='full', remove_dc=True, verbose=False
    )
    spec[np.isinf(spec)] = np.nan
    dfreq = freq[1] - freq[0]

    sigma_db = 10 * np.log10(np.sum(spec, axis=-1) * dfreq).T  # (n_channels, n_windows)
    sigma_db[np.isinf(sigma_db)] = np.nan  # log10(0), e.g. from a flat channel
    fs_sigma = 1 / epoch_time  # sampling rate of the sigma time series (Hz)

    sigma_db = np.array([
        np.nanmean(sigma_db[[edf.ch_names.index(ch) for ch in group]], axis=0)
        for group in ch_groups
    ])

    # ISO step: 256 s segments of the sigma series, stepped by 64 s
    iso_window = int(round(256 * fs_sigma))
    iso_step = int(round(64 * fs_sigma))
    freq_iso = np.linspace(0, 0.1, 101)[1:]
    # The artifact indicator is shared by all groups, so the usable-window mask is too
    good_ids = (sleep_stages == 0) & ~artifact_indicator

    spec_iso_all_ch = []
    for chi in range(len(ch_groups)):
        spec_isos = []
        cc = 0  # index of the first window of the current run
        for is_good, run in groupby(good_ids):
            run_len = len(list(run))
            if is_good:
                for seg_start in range(cc, cc + run_len - iso_window + 1, iso_step):
                    xx = sigma_db[chi, seg_start:seg_start + iso_window][np.newaxis, :]

                    # detrend uses a least-squares fit that rejects non-finite input, so any NaN
                    # in the segment (not just an all-NaN one) makes it unusable.
                    if np.isnan(xx).any() or np.all(xx == 0):
                        continue

                    xx = detrend(xx, axis=-1)
                    seg_spec, seg_freq = psd_array_multitaper(
                        xx, fs_sigma, fmin=0, fmax=0.2, bandwidth=0.01,
                        normalization='full', verbose=False
                    )

                    # Put every segment on the common grid, then convert to relative power
                    to_grid = interp1d(seg_freq, seg_spec, axis=-1, bounds_error=False, fill_value='extrapolate')
                    seg_spec = to_grid(freq_iso)
                    seg_spec /= seg_spec.sum(axis=-1, keepdims=True)
                    spec_isos.append(seg_spec)
            cc += run_len

        if len(spec_isos) == 0:
            print(f"No valid epochs for channel {ch_groups[chi]}. Returning NaN.")
            spec_iso_ch = np.full((1, len(freq_iso)), np.nan)
        else:
            spec_iso_ch = np.nanmean(np.array(spec_isos), axis=0)

        spec_iso_all_ch.append(spec_iso_ch)

    spec_iso = np.array(spec_iso_all_ch)
    return spec_iso, freq_iso


def plot_iso_single_subject_save(iso_df, subject_id, save_path):
    """
    Overlay the relative ISO spectrum of every channel for one subject and write the figure to disk.

    Parameters
    ----------
    iso_df : pd.DataFrame
        One subject's ISO data: a ``'frequency'`` column plus one ``'<channel>_sigma_all'``
        column per channel.

    subject_id : str
        Shown in the figure title.

    save_path : str
        Output image path (including filename); written at 300 dpi.

    Returns
    -------
    None
        The figure is closed after saving and a confirmation line is printed.
    """
    frequencies = iso_df['frequency']
    # A column qualifies if it mentions 'sigma_all'; the channel name is the text before '_sigma'
    channel_names = [
        column.split('_sigma')[0]
        for column in iso_df.columns
        if 'sigma_all' in column
    ]

    fig, ax = plt.subplots(figsize=(10, 6))
    for name in channel_names:
        ax.plot(frequencies, iso_df[name + '_sigma_all'], label=f'{name}')

    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Relative Power')
    ax.set_title(f'Combined ISO Spectrum for Subject: {subject_id}')
    ax.legend(loc='best', fontsize='small', ncol=2)
    ax.grid(True, linestyle='--', alpha=0.5)
    fig.tight_layout()

    fig.savefig(save_path, dpi=300)
    plt.close(fig)

    print(f"Plot explicitly saved to: {save_path}")


def load_custom_montage(montage_path, unit_scale=0.01):
    """
    Load EEG sensor positions from an EGI ``coordinates.xml``-style file as an MNE montage.

    Only sensors whose ``<type>`` is ``'0'`` are kept; each is named ``'EEG {number}'`` from its
    ``<number>`` element. The XML namespace ``http://www.egi.com/coordinates_mff`` is expected.

    Parameters
    ----------
    montage_path : str
        Path to the EEG montage XML file.

    unit_scale : float, default 0.01
        Factor applied to the raw x/y/z values before building the montage. EGI coordinate files
        store positions in centimetres, while MNE montages expect metres (MNE's own EGI MFF reader
        divides these values by 100). Pass ``1.0`` to keep the raw file values.

    Returns
    -------
    montage : mne.channels.DigMontage
        Montage in the ``'head'`` coordinate frame.
    """
    namespace = {'ns': 'http://www.egi.com/coordinates_mff'}
    root = ET.parse(montage_path).getroot()

    def _field(sensor, tag):
        return sensor.find(f'ns:{tag}', namespaces=namespace).text

    ch_pos = {}
    for sensor in root.findall('.//ns:sensor', namespaces=namespace):
        if _field(sensor, 'type') != '0':
            continue
        number = int(_field(sensor, 'number'))
        xyz = np.array([float(_field(sensor, axis)) for axis in 'xyz'])
        # Later duplicates of a sensor number overwrite earlier ones
        ch_pos[f'EEG {number}'] = xyz * unit_scale

    return mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame='head')


def _draw_topomap_panel(fig, ax, values, info, cmap):
    """Draw ``values`` as a topomap on ``ax`` using only channels whose value is finite, with a colorbar."""
    finite = np.isfinite(values)
    if not finite.any():
        # NaN values cannot be interpolated; label the panel instead of plotting it
        ax.text(0.5, 0.5, 'No valid data', ha='center', va='center', transform=ax.transAxes)
        ax.set_axis_off()
        return
    im, _ = mne.viz.plot_topomap(
        values[finite], mne.pick_info(info, np.flatnonzero(finite)),
        axes=ax, cmap=cmap, show=False
    )
    fig.colorbar(im, ax=ax, shrink=0.8)


def plot_iso_topomap(edf_raw, spec_iso, freq_iso, subject_id, montage_path, save_path):
    """
    Generate and save topographical plots for ISO band power and peak frequency.

    Parameters
    ----------
    edf_raw : mne.io.Raw
        EEG data loaded using MNE. Row ``i`` of ``spec_iso`` must belong to ``edf_raw.ch_names[i]``
        (as produced by ``get_iso`` with one single-channel group per channel).

    spec_iso : np.ndarray, shape (n_channels, 1, n_freqs)
        ISO spectrum computed by ``get_iso``.

    freq_iso : np.ndarray, shape (n_freqs,)
        ISO frequency bins.

    subject_id : str
        Identifier for the subject, used in the figure title.

    montage_path : str
        File path to the EEG channel montage (XML file), read with ``load_custom_montage``.

    save_path : str
        File path where the topomap figure is saved (300 dpi).

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``spec_iso`` does not have one row per channel of ``edf_raw``, or no channel of
        ``edf_raw`` appears in the montage.
    """
    from scipy.integrate import trapezoid

    spectra = np.asarray(spec_iso)[:, 0, :]
    freq_iso = np.asarray(freq_iso)
    if spectra.shape[0] != len(edf_raw.ch_names):
        raise ValueError(
            f"spec_iso has {spectra.shape[0]} channel rows but edf_raw has "
            f"{len(edf_raw.ch_names)} channels; rows must follow edf_raw.ch_names."
        )

    # Keep only channels with a sensor position, in edf_raw's order, so spectra rows stay aligned
    montage = load_custom_montage(montage_path)
    in_montage = set(montage.ch_names)
    keep = [i for i, ch in enumerate(edf_raw.ch_names) if ch in in_montage]
    if not keep:
        raise ValueError(f"No channels of edf_raw are present in the montage {montage_path}.")
    topo_raw = edf_raw.copy().pick([edf_raw.ch_names[i] for i in keep])
    topo_raw.set_montage(montage, on_missing='warn')
    spectra = spectra[keep]

    # ISO metrics; an all-NaN channel gets NaN (plain argmax would report freq_iso[0] as its peak)
    iso_band = (freq_iso >= 0.005) & (freq_iso <= 0.03)
    iso_band_power = trapezoid(spectra[:, iso_band], freq_iso[iso_band], axis=-1)
    all_nan = np.isnan(spectra).all(axis=-1)
    peak_idx = np.nanargmax(np.where(all_nan[:, None], 0.0, spectra), axis=-1)
    iso_peak_freq = np.where(all_nan, np.nan, freq_iso[peak_idx])

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    panels = (
        (axes[0], iso_band_power, 'viridis', 'ISO Band Power (0.005–0.03 Hz)'),
        (axes[1], iso_peak_freq, 'plasma', 'ISO Peak Frequency (Hz)'),
    )
    for ax, values, cmap, title in panels:
        _draw_topomap_panel(fig, ax, values, topo_raw.info, cmap)
        ax.set_title(title)

    fig.suptitle(f'Subject: {subject_id}', fontsize=14)
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])

    fig.savefig(save_path, dpi=300)
    plt.close(fig)
    print(f"Topomap explicitly saved to: {save_path}")


def bootstrap_ci(data, ci=95, n_bootstrap=1000, random_state=None):
    """
    Percentile-bootstrap confidence interval for the NaN-ignoring mean over rows of ``data``.

    Parameters
    ----------
    data : array-like, shape (n_samples,) or (n_samples, n_features)
        Observations (e.g. one row per subject, one column per frequency bin). NaNs are ignored.
    ci : float, default 95
        Confidence level in percent; must lie strictly between 0 and 100.
    n_bootstrap : int, default 1000
        Number of resamples; each draws ``n_samples`` rows with replacement.
    random_state : None, int, or numpy.random.Generator, optional
        Passed to ``np.random.default_rng`` so intervals can be made reproducible.

    Returns
    -------
    mean_spectrum : np.ndarray
        NaN-ignoring mean of ``data`` over axis 0.
    lower_bound, upper_bound : np.ndarray
        Lower / upper percentile bounds of the bootstrapped means.

    Raises
    ------
    ValueError
        If ``data`` has no rows or ``ci`` is outside (0, 100).
    """
    import warnings

    data = np.asarray(data, dtype=float)
    if data.ndim == 0 or data.shape[0] == 0:
        raise ValueError("bootstrap_ci needs at least one observation.")
    if not 0 < ci < 100:
        raise ValueError(f"ci must be between 0 and 100 (exclusive), got {ci}.")

    rng = np.random.default_rng(random_state)
    n_rows = data.shape[0]
    tail = (100 - ci) / 2
    with warnings.catch_warnings():
        # A resample can be all-NaN in some column ("mean of empty slice"); such draws are ignored below.
        warnings.simplefilter('ignore', category=RuntimeWarning)
        boot_means = np.array([
            np.nanmean(data[rng.integers(0, n_rows, size=n_rows)], axis=0)
            for _ in range(n_bootstrap)
        ])
        # nanpercentile so that occasional all-NaN resamples do not turn the whole bound NaN
        lower_bound = np.nanpercentile(boot_means, tail, axis=0)
        upper_bound = np.nanpercentile(boot_means, 100 - tail, axis=0)
    mean_spectrum = np.nanmean(data, axis=0)
    return mean_spectrum, lower_bound, upper_bound


def plot_group_iso_ci(
    iso_results_df, freq_iso, channel, group_col='group', ci=95, save_path=None
):
    """
    Plot the group-mean ISO spectrum of one channel with a bootstrap confidence band per group.

    Parameters
    ----------
    iso_results_df : pd.DataFrame
        One row per subject; spectrum values live in ``'<channel>_sigma_all_<freq>Hz'`` columns,
        and group labels in ``group_col``.
    freq_iso : np.ndarray
        ISO frequency bins, one per spectrum column of ``channel``.
    channel : str
        EEG channel name selecting the spectrum columns.
    group_col : str
        Column holding group labels (NaN labels are ignored).
    ci : int
        Confidence level in percent, forwarded to ``bootstrap_ci``.
    save_path : str or None
        Where to write the figure (300 dpi); when falsy the figure is not saved.

    Returns
    -------
    None
        The figure is always closed before returning.
    """
    prefix = f'{channel}_sigma_all_'
    spectrum_cols = [name for name in iso_results_df.columns if name.startswith(prefix)]

    fig, ax = plt.subplots(figsize=(10, 6))

    for label in iso_results_df[group_col].dropna().unique():
        # Subjects with any missing bin for this channel are dropped entirely
        rows = iso_results_df.loc[iso_results_df[group_col] == label, spectrum_cols].dropna()

        if rows.empty or np.isnan(rows.values).all():
            print(f"Warning: No valid data for group '{label}' in channel '{channel}'. Skipping.")
            continue

        centre, band_low, band_high = bootstrap_ci(rows.values, ci=ci)
        ax.plot(freq_iso, centre, label=f'{label} Mean')
        ax.fill_between(freq_iso, band_low, band_high, alpha=0.3, label=f'{label} {ci}% CI')

    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Relative Power')
    ax.set_title(f'ISO Spectrum with {ci}% CI - {channel}')
    ax.legend()
    ax.grid(alpha=0.3)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300)
        print(f"Group plot saved explicitly to {save_path}")

    plt.close(fig)


def _load_first_npz_array(npz_path):
    """Return the first array stored in an ``.npz`` archive; the archive is closed before returning."""
    with np.load(npz_path) as archive:
        if not archive.files:
            raise ValueError(f"No arrays found in {npz_path}")
        return archive[archive.files[0]]


def load_eeg_data(base_path, groups=('ASD', 'TD')):
    """
    Load EEG data (EDF files), sleep stage labels, and artifact indicators from a structured directory.

    For each participant group, every ``*.edf`` file (case-insensitive extension) in
    ``edf/<group>`` is paired with ``sleep_stages/<group>/sleep_stages_<subject_id>.npz`` and
    ``artifacts/<group>/artifacts_<subject_id>.npz``. Subjects missing either ``.npz`` file are
    reported and skipped. The first array in each ``.npz`` archive is used.

    Directory structure (default groups shown):

    base_path/
    ├── edf/
    │   ├── ASD/
    │   └── TD/
    ├── sleep_stages/
    │   ├── ASD/
    │   └── TD/
    └── artifacts/
        ├── ASD/
        └── TD/

    Parameters
    ----------
    base_path : str
        Path to the root directory containing the 'edf', 'sleep_stages', and 'artifacts' subdirectories.

    groups : iterable of str, default ('ASD', 'TD')
        Group sub-folder names to read, in order. Each must exist under ``edf/``.

    Returns
    -------
    data : list of dict
        One dictionary per loaded subject, with keys:
            - 'subject_id': str, EDF filename without extension.
            - 'group': str, the group sub-folder it came from.
            - 'edf': mne.io.Raw, raw EEG data object loaded with MNE (not preloaded).
            - 'sleep_stages': np.ndarray, first array of the sleep-stage `.npz` file.
            - 'artifact_indicator': np.ndarray, first array of the artifact `.npz` file.

    Raises
    ------
    FileNotFoundError
        If ``edf/<group>`` does not exist for a requested group.
    ValueError
        If a companion ``.npz`` archive contains no arrays.
    """
    data = []

    for group in groups:
        edf_path = os.path.join(base_path, 'edf', group)
        stages_path = os.path.join(base_path, 'sleep_stages', group)
        artifact_path = os.path.join(base_path, 'artifacts', group)

        edf_files = sorted(f for f in os.listdir(edf_path) if f.lower().endswith('.edf'))

        for edf_file in edf_files:
            subject_id = os.path.splitext(edf_file)[0]

            stages_file = os.path.join(stages_path, f'sleep_stages_{subject_id}.npz')
            artifact_file = os.path.join(artifact_path, f'artifacts_{subject_id}.npz')

            # Check companions before opening the EDF so incomplete subjects cost nothing
            if not os.path.exists(stages_file):
                print(f"Missing sleep stage file for {subject_id}, skipping...")
                continue
            if not os.path.exists(artifact_file):
                print(f"Missing artifact file for {subject_id}, skipping...")
                continue

            edf_raw = mne.io.read_raw_edf(os.path.join(edf_path, edf_file), preload=False, verbose=False)

            data.append({
                'subject_id': subject_id,
                'group': group,
                'edf': edf_raw,
                'sleep_stages': _load_first_npz_array(stages_file),
                'artifact_indicator': _load_first_npz_array(artifact_file),
            })

            print(f"Loaded subject {subject_id} from group {group}")

    return data

In [3]:
def _iso_band_metrics(spec, freq, band=(0.005, 0.03)):
    """Per-row ISO band power (trapezoidal AUC over ``band``) and peak frequency of a (n_channels, n_freqs) spectrum.

    Rows that are entirely NaN get NaN for both metrics; plain ``np.argmax`` would report ``freq[0]``.
    """
    from scipy.integrate import trapezoid

    spec = np.asarray(spec, dtype=float)
    freq = np.asarray(freq, dtype=float)
    in_band = (freq >= band[0]) & (freq <= band[1])
    band_power = trapezoid(spec[:, in_band], freq[in_band], axis=-1)
    all_nan = np.isnan(spec).all(axis=-1)
    # Zero-fill all-NaN rows so nanargmax does not raise; their result is masked to NaN below
    peak_idx = np.nanargmax(np.where(all_nan[:, None], 0.0, spec), axis=-1)
    peak_freq = np.where(all_nan, np.nan, freq[peak_idx])
    return band_power, peak_freq


def _channels_in_results(iso_results_df):
    """Channel names encoded in ``'<channel>_sigma_all_<freq>Hz'`` columns, in first-seen order."""
    marker = '_sigma_all_'
    names = (col.rsplit(marker, 1)[0] for col in iso_results_df.columns if marker in col)
    return list(dict.fromkeys(names))


def run_iso_analysis(
    data_list,
    montage_paths=None,
    output_base_path='./iso_analysis_outputs'
):
    """
    Run the ISO analysis on all subjects and save per-subject plots, CSV tables and group plots.

    Results go to a new timestamped folder ``<output_base_path>/iso_results_<YYYYmmdd_HHMMSS>``.
    A subject that raises is reported and skipped. Its spectra and metrics are kept if the
    failure happened only while plotting.

    Parameters
    ----------
    data_list : list of dict
        EEG data loaded with ``load_eeg_data``.

    montage_paths : dict or None
        Keys ``'32'`` and ``'64'`` mapping to montage XML files. ``None`` means
        ``{'32': 'eeg_ch_coords/coordinates32.xml', '64': 'eeg_ch_coords/coordinates64.xml'}``.
        The 64-channel montage is used when a recording has more than 32 channels besides 'EEG VREF'.

    output_base_path : str
        Directory under which the results folder is created.

    Returns
    -------
    iso_results_df : pd.DataFrame
        One row per analysed subject with ``'<channel>_sigma_all_<freq>Hz'`` spectrum columns.

    additional_metrics_df : pd.DataFrame
        ISO band power (0.005-0.03 Hz) and peak frequency per channel; NaN for channels without
        a valid spectrum.
    """
    if montage_paths is None:
        montage_paths = {'32': 'eeg_ch_coords/coordinates32.xml', '64': 'eeg_ch_coords/coordinates64.xml'}

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_folder = os.path.join(output_base_path, f'iso_results_{timestamp}')
    os.makedirs(output_folder, exist_ok=True)
    print(f"Results explicitly saved to: {output_folder}")

    iso_results = []
    additional_metrics = []
    freq_iso = None

    for idx, subject in enumerate(data_list):
        subject_id = subject['subject_id']
        group = subject['group']
        print(f"\n[{idx+1}/{len(data_list)}] Processing subject: {subject_id} ({group})")

        try:
            edf_raw = subject['edf'].copy().load_data()
            sleep_stages = subject['sleep_stages']
            artifact_indicator = subject['artifact_indicator']

            # Choose the montage by channel count, not counting the vertex reference
            n_channels = len(edf_raw.ch_names) - ('EEG VREF' in edf_raw.ch_names)
            montage_key = '64' if n_channels > 32 else '32'
            montage_path = montage_paths[montage_key]

            montage = load_custom_montage(montage_path)
            edf_raw.pick(picks=[ch for ch in edf_raw.ch_names if ch != 'EEG VREF'])
            edf_raw.set_montage(montage, on_missing='warn')
            ch_names = list(edf_raw.ch_names)

            # One single-channel group per electrode, so spectrum rows follow ch_names
            spec_iso, freq_iso = get_iso(edf_raw, sleep_stages, artifact_indicator, [[ch] for ch in ch_names])
            spectra = spec_iso[:, 0, :]

            iso_dict = {'subject_id': subject_id, 'group': group}
            for ch_idx, ch_name in enumerate(ch_names):
                for f_idx, freq in enumerate(freq_iso):
                    iso_dict[f'{ch_name}_sigma_all_{freq:.4f}Hz'] = spectra[ch_idx, f_idx]
            iso_results.append(iso_dict)

            # Metrics are stored before plotting so a plotting failure cannot drop them
            band_power, peak_freq = _iso_band_metrics(spectra, freq_iso)
            metrics = {'subject_id': subject_id, 'group': group}
            metrics.update({f'{ch}_ISO_bandpower_0.005-0.03Hz': bp for ch, bp in zip(ch_names, band_power)})
            metrics.update({f'{ch}_ISO_peak_frequency': pf for ch, pf in zip(ch_names, peak_freq)})
            additional_metrics.append(metrics)

            iso_df_single_subject = pd.DataFrame({
                'frequency': freq_iso,
                **{f'{ch}_sigma_all': spectra[ch_idx] for ch_idx, ch in enumerate(ch_names)},
            })
            spectrum_plot_path = os.path.join(output_folder, f'{subject_id}_iso_spectrum.png')
            plot_iso_single_subject_save(iso_df_single_subject, subject_id, spectrum_plot_path)

            topomap_plot_path = os.path.join(output_folder, f'{subject_id}_iso_topomaps.png')
            plot_iso_topomap(
                edf_raw, spec_iso, freq_iso, subject_id, montage_path, topomap_plot_path
            )

        except Exception as e:
            print(f"Error processing subject {subject_id}: {e}")

    iso_results_df = pd.DataFrame(iso_results)
    iso_results_csv = os.path.join(output_folder, 'iso_analysis_all_subjects.csv')
    iso_results_df.to_csv(iso_results_csv, index=False)
    print(f"\nISO spectra explicitly saved: {iso_results_csv}")

    additional_metrics_df = pd.DataFrame(additional_metrics)
    metrics_csv = os.path.join(output_folder, 'iso_additional_metrics.csv')
    additional_metrics_df.to_csv(metrics_csv, index=False)
    print(f"Additional ISO metrics explicitly saved: {metrics_csv}")

    if freq_iso is None:
        print("No subject produced an ISO spectrum; skipping group plots.")
        return iso_results_df, additional_metrics_df

    # Group plots cover every channel seen in any subject (32- and 64-channel recordings may mix)
    group_plot_dir = os.path.join(output_folder, 'group_plots')
    os.makedirs(group_plot_dir, exist_ok=True)

    for channel in _channels_in_results(iso_results_df):
        group_plot_path = os.path.join(group_plot_dir, f'ISO_CI_{channel}.png')
        plot_group_iso_ci(iso_results_df, freq_iso, channel, save_path=group_plot_path)

    return iso_results_df, additional_metrics_df

In [4]:
# Load all data. Set the ISO_DATA_DIR environment variable to use another data folder
# instead of editing this machine-specific absolute path.
base_path = os.environ.get('ISO_DATA_DIR', '/Users/kevinliu/git/mgh_eeg_spindle/data')
data_list = load_eeg_data(base_path)

# Quick check of the first loaded subject; fail with a clear message if nothing was loaded
if not data_list:
    raise RuntimeError(f"No subjects with complete EDF/sleep-stage/artifact files found under {base_path}")
first_subject = data_list[0]

print("Subject ID:", first_subject['subject_id'])
print("Group:", first_subject['group'])
print("Sleep stages shape:", first_subject['sleep_stages'].shape)
print("Artifacts shape:", first_subject['artifact_indicator'].shape)
print("EDF info:", first_subject['edf'].info)

Loaded subject 10918067_20220415_142547_fil from group ASD
Loaded subject 10994324_20220805_082743_fil from group ASD
Loaded subject 11113647_20220818_142928_fil from group ASD
Loaded subject 11196000_20220412_142232_fil from group ASD
Loaded subject 11200338-1_20221102_144056_fil from group ASD
Loaded subject 11212718-1_20221101_150706_fil from group ASD
Loaded subject 11336460_20220725_100434_fil from group ASD
Loaded subject 11353940-1_20221017_151814_fil from group ASD
Loaded subject 11367187_20220621_143328_fil from group ASD
Loaded subject 11402947_20220330_143311_fil from group ASD
Loaded subject 11460732_20220524_100527_fil from group ASD
Loaded subject 11537456_20220614_101224_fil from group ASD
Loaded subject 11550215_20220331_142205_fil from group ASD
Loaded subject 11630678_20220629_102630_fil from group ASD
Loaded subject 11806907_20220406_142529 2_fil from group ASD
Loaded subject 12315596-2_20220210_105023_fil from group ASD
Loaded subject 12361287_20220401_102433_fil fr

In [5]:
# Run the ISO pipeline for every loaded subject; the 32- vs 64-channel montage is chosen per subject.
MONTAGE_PATHS = {
    '32': 'eeg_ch_coords/coordinates32.xml',
    '64': 'eeg_ch_coords/coordinates64.xml',
}
OUTPUT_DIR = './iso_analysis_outputs'

iso_results_df, additional_metrics_df = run_iso_analysis(
    data_list,
    montage_paths=MONTAGE_PATHS,
    output_base_path=OUTPUT_DIR,
)

Results explicitly saved to: ./iso_analysis_outputs/iso_results_20250401_184315

[1/65] Processing subject: 10918067_20220415_142547_fil (ASD)
Reading 0 ... 433749  =      0.000 ...  1734.996 secs...
No valid epochs for channel ['EEG 1']. Returning NaN.
No valid epochs for channel ['EEG 2']. Returning NaN.
No valid epochs for channel ['EEG 3']. Returning NaN.
No valid epochs for channel ['EEG 4']. Returning NaN.
No valid epochs for channel ['EEG 5']. Returning NaN.
No valid epochs for channel ['EEG 6']. Returning NaN.
No valid epochs for channel ['EEG 7']. Returning NaN.
No valid epochs for channel ['EEG 8']. Returning NaN.
No valid epochs for channel ['EEG 9']. Returning NaN.
No valid epochs for channel ['EEG 10']. Returning NaN.
No valid epochs for channel ['EEG 11']. Returning NaN.
No valid epochs for channel ['EEG 12']. Returning NaN.
No valid epochs for channel ['EEG 13']. Returning NaN.
No valid epochs for channel ['EEG 14']. Returning NaN.
No valid epochs for channel ['EEG 15'].