In [22]:
import numpy as np
from scipy.signal import resample, detrend
from scipy.stats import mode
from itertools import groupby
import mne
from mne.time_frequency import psd_array_multitaper
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
import os
import pandas as pd
from datetime import datetime

In [2]:
def get_iso(edf, sleep_stages, artifact_indicator, ch_groups):
    """
    Computes the infraslow oscillation (ISO) spectrum from EEG data during Non-REM (NREM) sleep epochs.

    Parameters
    ----------
    edf : mne.io.Raw
        The EEG data object loaded using MNE, containing EEG signals and sampling frequency.
        
    sleep_stages : np.ndarray, shape (n_samples,)
        Numeric array indicating sleep stages at each EEG sample. Expected coding:
        0 = NREM, 1 = REM, 2 = Wake, 3 = Artifact.
        (Sleep stages will be resampled internally to match EEG data length.)
        
    artifact_indicator : np.ndarray, shape (n_samples,)
        Binary array (1-dimensional) indicating artifacts at each EEG sample (1 = artifact, 0 = clean).
        (Artifacts are resampled internally to match EEG data length.)

    ch_groups : list of list of str
        List defining channel groupings for analysis. Each inner list contains channel names to average.

    Returns
    -------
    spec_iso : np.ndarray, shape (n_channels, 3, n_freqs)
        ISO spectrum, separately computed for sigma_all, sigma_slow, and sigma_fast bands, for each channel group.

    freq_iso : np.ndarray, shape (n_freqs,)
        Frequency bins corresponding to ISO power spectra, typically in the infraslow range (0–0.1 Hz).

    Deviations from the Original Function
    -------------------------------------
    - Sleep stages and artifacts are resampled explicitly from sample-level data, rather than repeated
      assuming epoch-level (30-s) input as in the original.
    - NREM epochs (label = 0) are exclusively and correctly selected for ISO analysis,
      correcting an earlier oversight that selected REM (1) and Wake (2) incorrectly.
    - Artifact indicator handling was adjusted to expect global (1D) artifact indicators, rather than
      the original channel-wise (2D) artifact indicators.
    - The function explicitly handles cases with insufficient valid epochs, providing informative warnings.
    - Empty or invalid epochs are explicitly skipped to prevent NaN outputs in spectral calculations.
    """
    eeg = edf.get_data() * 1e6
    original_fs = edf.info['sfreq']

    # Resample EEG data to 128 Hz
    eeg = resample(eeg, int(round(eeg.shape[1] / original_fs * 128)), axis=1)
    fs = 128

    # Original code repeats sleep_stages assuming they are epoch-level labels (one per 30-sec epoch).
    # Here, instead, sleep stages and artifacts (originally sampled at EEG frequency) are directly resampled
    # to match the EEG data length after downsampling.
    n_samples_resampled = eeg.shape[1]

    # Resample sleep stages (float interpolation then rounding back to integer stages)
    sleep_stages_resampled = resample(sleep_stages.astype(float), n_samples_resampled)
    sleep_stages = np.round(sleep_stages_resampled).astype(int)

    # Segment EEG into epochs
    window_size = int(round(4 * fs))
    step_size = int(round(2 * fs))
    epoch_time = 2
    start_ids = np.arange(0, eeg.shape[1] - window_size + 1, step_size)
    epochs = np.array([eeg[:, x:x + window_size] for x in start_ids])

    # Assign sleep stages and artifact indicators per epoch (identical logic to original function)
    sleep_stages = np.array([mode(sleep_stages[x:x + window_size], keepdims=False).mode for x in start_ids])
    artifact_indicator = np.array([artifact_indicator[x:x + window_size].any(axis=0) for x in start_ids])

    sleep_ids = np.where(sleep_stages == 0)[0]  # Original used [0,1], now explicitly only NREM (0)
    if len(sleep_ids) == 0:
        raise ValueError("No NREM epochs found.")
    start = sleep_ids[0]
    end = sleep_ids[-1] + 1
    epochs = epochs[start:end]
    sleep_stages = sleep_stages[start:end]
    artifact_indicator = artifact_indicator[start:end]

    # Detrend epochs
    epochs = detrend(epochs, axis=-1)

    # Calculate sigma-band power
    spec, freq = psd_array_multitaper(
        epochs, sfreq=fs, fmin=11, fmax=15, bandwidth=1,
        normalization='full', remove_dc=True, verbose=False
    )
    spec[np.isinf(spec)] = np.nan
    dfreq = freq[1] - freq[0]

    sigma_db = 10 * np.log10(np.sum(spec, axis=-1) * dfreq).T
    sigma_db_slow = 10 * np.log10(np.sum(spec[..., freq < 13], axis=-1) * dfreq).T
    sigma_db_fast = 10 * np.log10(np.sum(spec[..., freq >= 13], axis=-1) * dfreq).T
    sigma_db[np.isinf(sigma_db)] = np.nan
    sigma_db_slow[np.isinf(sigma_db_slow)] = np.nan
    sigma_db_fast[np.isinf(sigma_db_fast)] = np.nan
    fs_sigma = 1 / epoch_time

    sigma_db = np.array([np.nanmean(sigma_db[[edf.ch_names.index(x) for x in xx]], axis=0) for xx in ch_groups])
    sigma_db_slow = np.array([np.nanmean(sigma_db_slow[[edf.ch_names.index(x) for x in xx]], axis=0) for xx in ch_groups])
    sigma_db_fast = np.array([np.nanmean(sigma_db_fast[[edf.ch_names.index(x) for x in xx]], axis=0) for xx in ch_groups])

    # Artifact_indicator is 1-dimensional; original assumed 2D channel-specific artifacts
    artifact_indicator2 = np.array([artifact_indicator for _ in ch_groups])

    window_size = int(round(256 * fs_sigma))
    step_size = int(round(64 * fs_sigma))
    freq_iso = np.linspace(0, 0.1, 101)[1:]
    spec_iso_all_ch = []

    for chi in range(len(ch_groups)):
        # Now explicitly selects NREM(0)
        good_ids = (sleep_stages == 0) & (~artifact_indicator2[chi])
        spec_isos = []
        cc = 0
        for k, l in groupby(good_ids):
            ll = len(list(l))
            if not k:
                cc += ll
                continue
            for start in np.arange(cc, cc + ll - window_size + 1, step_size):
                xx = np.array([
                    sigma_db[chi, start:start + window_size],
                    sigma_db_slow[chi, start:start + window_size],
                    sigma_db_fast[chi, start:start + window_size],
                ])

                # Skip empty or all-NaN epochs explicitly
                if np.isnan(xx).all() or np.all(xx == 0):
                    continue

                xx = detrend(xx, axis=-1)
                spec_iso, freq_out = psd_array_multitaper(
                    xx, fs_sigma, fmin=0, fmax=0.2, bandwidth=0.01,
                    normalization='full', verbose=False
                )

                ff = interp1d(freq_out, spec_iso, axis=-1, bounds_error=False, fill_value='extrapolate')
                spec_iso = ff(freq_iso)

                spec_iso /= spec_iso.sum(axis=-1, keepdims=True)
                spec_isos.append(spec_iso)
            cc += ll

        # Explicit handling if no valid epochs found
        if len(spec_isos) == 0:
            print(f"No valid epochs for channel {ch_groups[chi]}. Returning NaN.")
            spec_iso_ch = np.full((3, len(freq_iso)), np.nan)
        else:
            spec_iso_ch = np.nanmean(np.array(spec_isos), axis=0)

        spec_iso_all_ch.append(spec_iso_ch)

    spec_iso = np.array(spec_iso_all_ch)
    return spec_iso, freq_iso

In [None]:
def load_eeg_data(base_path):
    """
    Load EEG data (EDF files), sleep stage labels, and artifact indicators from a structured directory.

    The function reads EEG recordings in EDF format along with their corresponding sleep stage and 
    artifact indicator files (stored as `.npz`) for two participant groups: Autism Spectrum Disorder (ASD)
    and Typically Developing (TD). Each group's data should reside within its own subfolder.

    Directory structure should follow this format explicitly:

    base_path/
    ├── edf/
    │   ├── ASD/
    │   └── TD/
    ├── sleep_stages/
    │   ├── ASD/
    │   └── TD/
    └── artifacts/
        ├── ASD/
        └── TD/

    Parameters
    ----------
    base_path : str
        Path to the root directory containing the 'edf', 'sleep_stages', and 'artifacts' subdirectories.

    Returns
    -------
    data : list of dict
        A list of dictionaries, each containing the following keys:
            - 'subject_id': str, subject identifier extracted from the EDF filename.
            - 'group': str, participant group ('ASD' or 'TD').
            - 'edf': mne.io.Raw, raw EEG data object loaded with MNE (not preloaded).
            - 'sleep_stages': np.ndarray, sleep stage labels loaded from the corresponding `.npz` file.
            - 'artifact_indicator': np.ndarray, binary artifact indicators loaded from the corresponding `.npz` file.
    """
    groups = ['ASD', 'TD']
    data = []

    for group in groups:
        edf_path = os.path.join(base_path, 'edf', group)
        stages_path = os.path.join(base_path, 'sleep_stages', group)
        artifact_path = os.path.join(base_path, 'artifacts', group)

        edf_files = sorted([f for f in os.listdir(edf_path) if f.lower().endswith('.edf')])

        for edf_file in edf_files:
            subject_id = os.path.splitext(edf_file)[0]

            stages_file = os.path.join(stages_path, f'sleep_stages_{subject_id}.npz')
            artifact_file = os.path.join(artifact_path, f'artifacts_{subject_id}.npz')
            edf_file_full = os.path.join(edf_path, edf_file)

            if not os.path.exists(stages_file):
                print(f"Missing sleep stage file for {subject_id}, skipping...")
                continue
            if not os.path.exists(artifact_file):
                print(f"Missing artifact file for {subject_id}, skipping...")
                continue

            edf_raw = mne.io.read_raw_edf(edf_file_full, preload=False, verbose=False)

            # Load the first available array from npz files automatically
            with np.load(stages_file) as stages_npz:
                sleep_stages = stages_npz[stages_npz.files[0]]

            with np.load(artifact_file) as artifacts_npz:
                artifact_indicator = artifacts_npz[artifacts_npz.files[0]]

            data.append({
                'subject_id': subject_id,
                'group': group,
                'edf': edf_raw,
                'sleep_stages': sleep_stages,
                'artifact_indicator': artifact_indicator
            })

            print(f"Loaded subject {subject_id} from group {group}")

    return data


def run_iso_and_save_plots(data_list, output_base_path='./iso_analysis_outputs'):
    """
    Runs ISO analysis explicitly for all subjects, generates plots, and saves each subject's plot
    into a new timestamped output folder.

    Parameters
    ----------
    data_list : list of dict
        Each dictionary must contain keys:
        'subject_id', 'edf' (mne.io.Raw), 'sleep_stages', 'artifact_indicator', and 'group'.
        
    output_base_path : str, optional
        Base directory to save the output folders. Defaults to './iso_analysis_outputs'.
        
    Returns
    -------
    iso_results_df : pd.DataFrame
        DataFrame containing ISO analysis results for all subjects.
        
    freq_iso : np.ndarray
        Array of ISO frequency bins.
    """
    iso_results = []
    freq_iso = None

    # Create explicit output folder with current timestamp
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_folder = os.path.join(output_base_path, f'iso_results_{timestamp}')
    os.makedirs(output_folder, exist_ok=True)
    print(f"Saving all plots explicitly to: {output_folder}")

    for idx, subject in enumerate(data_list):
        subject_id = subject['subject_id']
        group = subject['group']
        print(f"Processing subject {idx + 1}/{len(data_list)}: {subject_id}")

        try:
            edf_raw = subject['edf'].copy().load_data()
            sleep_stages = subject['sleep_stages']
            artifact_indicator = subject['artifact_indicator']

            ch_groups = [[ch] for ch in edf_raw.ch_names if ch != 'EEG VREF']

            spec_iso, freq_iso = get_iso(edf_raw, sleep_stages, artifact_indicator, ch_groups)

            # Flatten ISO results into dictionary
            iso_dict = {'subject_id': subject_id, 'group': group}
            for ch_idx, ch_name in enumerate([ch[0] for ch in ch_groups]):
                for band_idx, band in enumerate(['sigma_all', 'sigma_slow', 'sigma_fast']):
                    for f_idx, freq in enumerate(freq_iso):
                        col_label = f'{ch_name}_{band}_{freq:.4f}Hz'
                        iso_dict[col_label] = spec_iso[ch_idx, band_idx, f_idx]

            iso_results.append(iso_dict)

            # Generate and explicitly save plot
            iso_df_single_subject = {'frequency': freq_iso}
            for ch_idx, ch_name in enumerate([ch[0] for ch in ch_groups]):
                iso_df_single_subject[f'{ch_name}_sigma_all'] = spec_iso[ch_idx, 0, :]
                iso_df_single_subject[f'{ch_name}_sigma_slow'] = spec_iso[ch_idx, 1, :]
                iso_df_single_subject[f'{ch_name}_sigma_fast'] = spec_iso[ch_idx, 2, :]
            
            iso_df_single_subject = pd.DataFrame(iso_df_single_subject)

            plot_iso_single_subject_save(
                iso_df_single_subject, subject_id, 
                save_path=os.path.join(output_folder, f'{subject_id}_iso.png')
            )

        except Exception as e:
            print(f"Skipping subject {subject_id} due to error: {e}")
            continue

    iso_results_df = pd.DataFrame(iso_results)

    # Save the full DataFrame explicitly
    iso_results_csv_path = os.path.join(output_folder, 'iso_analysis_all_subjects.csv')
    iso_results_df.to_csv(iso_results_csv_path, index=False)
    print(f"\nISO results DataFrame explicitly saved to: {iso_results_csv_path}")

    return iso_results_df, freq_iso


def plot_iso_single_subject_save(iso_df, subject_id, save_path):
    """
    Generate ISO plots explicitly for a single subject and save the plot image to a specified path.

    Parameters
    ----------
    iso_df : pd.DataFrame
        DataFrame containing ISO data for a single subject.
    
    subject_id : str
        Subject identifier used in plot title.
    
    save_path : str
        Path (including filename) to explicitly save the generated plot.
    
    Returns
    -------
    None
    """
    freq = iso_df['frequency']
    channels = [col.split('_sigma')[0] for col in iso_df.columns if 'sigma_all' in col]

    num_channels = len(channels)
    fig, axes = plt.subplots(num_channels, 1, figsize=(5, 2.5 * num_channels), sharex=True)

    if num_channels == 1:
        axes = [axes]

    for ax, ch in zip(axes, channels):
        ax.plot(freq, iso_df[f'{ch}_sigma_all'], label='Sigma All')
        ax.plot(freq, iso_df[f'{ch}_sigma_slow'], label='Sigma Slow', linestyle='--')
        ax.plot(freq, iso_df[f'{ch}_sigma_fast'], label='Sigma Fast', linestyle=':')
        ax.set_title(f'ISO Spectrum - Channel: {ch}')
        ax.set_ylabel('Relative Power')
        ax.legend(loc='upper right')

    axes[-1].set_xlabel('Frequency (Hz)')
    plt.suptitle(f'ISO Spectrum for Subject: {subject_id}', fontsize=10)
    plt.tight_layout(rect=[0, 0.03, 1, 0.97])
    
    plt.savefig(save_path, dpi=300)
    plt.close(fig)

mam
    print(f"Plot explicitly saved to: {save_path}")

In [5]:
# Load all data
base_path = '/Users/kevinliu/git/mgh_eeg_spindle/data'
data_list = load_eeg_data(base_path)

# Quick check of first loaded subject
first_subject = data_list[0]

print("Subject ID:", first_subject['subject_id'])
print("Group:", first_subject['group'])
print("Sleep stages shape:", first_subject['sleep_stages'].shape)
print("Artifacts shape:", first_subject['artifact_indicator'].shape)
print("EDF info:", first_subject['edf'].info)

Loaded subject 10918067_20220415_142547_fil from group ASD
Loaded subject 10994324_20220805_082743_fil from group ASD
Loaded subject 11113647_20220818_142928_fil from group ASD
Loaded subject 11196000_20220412_142232_fil from group ASD
Loaded subject 11200338-1_20221102_144056_fil from group ASD
Loaded subject 11212718-1_20221101_150706_fil from group ASD
Loaded subject 11336460_20220725_100434_fil from group ASD
Loaded subject 11353940-1_20221017_151814_fil from group ASD
Loaded subject 11367187_20220621_143328_fil from group ASD
Loaded subject 11402947_20220330_143311_fil from group ASD
Loaded subject 11460732_20220524_100527_fil from group ASD
Loaded subject 11537456_20220614_101224_fil from group ASD
Loaded subject 11550215_20220331_142205_fil from group ASD
Loaded subject 11630678_20220629_102630_fil from group ASD
Loaded subject 11806907_20220406_142529 2_fil from group ASD
Loaded subject 12315596-2_20220210_105023_fil from group ASD
Loaded subject 12361287_20220401_102433_fil fr

In [23]:
# Run ISO analysis explicitly for all subjects
iso_df_all, freq_bins = run_iso_and_save_plots(data_list)

Saving all plots explicitly to: ./iso_analysis_outputs/iso_results_20250329_041327
Processing subject 1/65: 10918067_20220415_142547_fil
Reading 0 ... 433749  =      0.000 ...  1734.996 secs...


  sigma_db = 10 * np.log10(np.sum(spec, axis=-1) * dfreq).T
  sigma_db_slow = 10 * np.log10(np.sum(spec[..., freq < 13], axis=-1) * dfreq).T
  sigma_db_fast = 10 * np.log10(np.sum(spec[..., freq >= 13], axis=-1) * dfreq).T


No valid epochs for channel ['EEG 1']. Returning NaN.
No valid epochs for channel ['EEG 2']. Returning NaN.
No valid epochs for channel ['EEG 3']. Returning NaN.
No valid epochs for channel ['EEG 4']. Returning NaN.
No valid epochs for channel ['EEG 5']. Returning NaN.
No valid epochs for channel ['EEG 6']. Returning NaN.
No valid epochs for channel ['EEG 7']. Returning NaN.
No valid epochs for channel ['EEG 8']. Returning NaN.
No valid epochs for channel ['EEG 9']. Returning NaN.
No valid epochs for channel ['EEG 10']. Returning NaN.
No valid epochs for channel ['EEG 11']. Returning NaN.
No valid epochs for channel ['EEG 12']. Returning NaN.
No valid epochs for channel ['EEG 13']. Returning NaN.
No valid epochs for channel ['EEG 14']. Returning NaN.
No valid epochs for channel ['EEG 15']. Returning NaN.
No valid epochs for channel ['EEG 16']. Returning NaN.
No valid epochs for channel ['EEG 17']. Returning NaN.
No valid epochs for channel ['EEG 18']. Returning NaN.
No valid epochs for