In [1]:
import mne
import matplotlib.pyplot as plt
import numpy as np

import os
import pandas as pd
from mne.preprocessing import ICA
from mne_icalabel import label_components
from autoreject import AutoReject
from scipy.signal.windows import hamming
from scipy.stats import spearmanr
from autoreject import get_rejection_threshold
import gc
from specparam import SpectralGroupModel
#from specparam.analysis import get_band_peak_group
from mne.preprocessing import create_eog_epochs, create_ecg_epochs
from scipy.io import loadmat
from pathlib import Path
import time
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

%matplotlib qt
#the following line allows interactive plotting -> https://mne.discourse.group/t/semi-interactive-plot-with-matplotlib-backend/5094/2
#matplotlib.use('Qt5Agg')

In [None]:
### 1. SETUP 
#load in data and annotations 
subj_id = "0016" 
psg_path = "/Users/elizabethkaplan/Desktop/SS2_Data/01-02-0016-PSG.edf" 
annotation_file_path = "/Users/elizabethkaplan/Desktop/SS2_Data/01-02-0016 KComplexes_E1.edf" 
spindles_annotations = "/Users/elizabethkaplan/Desktop/SS2_Data/01-02-0016 Spindles_E1.edf" 
staging_file_path = '/Users/elizabethkaplan/Desktop/SS2_Data/01-02-0016-Base.edf' 

# preload=True loads the data into memory, enabling faster operations

raw = mne.io.read_raw_edf(psg_path, preload=True) 
annot = mne.read_annotations(annotation_file_path) 
spindles = mne.read_annotations(spindles_annotations) 
stages = mne.read_annotations(staging_file_path)

# Create output folders 
full_sess_name = f"01-02-{subj_id}" 
out_dir = Path("/Users/elizabethkaplan/Desktop/SS2_Results") 
session_name = full_sess_name 
subject_model_dir = out_dir / full_sess_name / "spectral_models" 
subject_model_dir.mkdir(parents=True, exist_ok=True)

# ---- Rename channels properly (run once) ---- 
rename_map = {} 
for ch in raw.ch_names: 
    if ch.startswith("EEG "): 
        base = ch.replace("EEG ", "").replace("-CLE", "") 
        rename_map[ch] = base 
        
raw.rename_channels(rename_map, allow_duplicates=True)

# ---- Attach montage AFTER rename ---- 
montage = mne.channels.make_standard_montage("standard_1020") 
raw.set_montage(montage, match_case=False, on_missing="ignore") 

# set ch types 
ch_types = { "EOG Upper Vertic": "eog", 
             "EOG Lower Vertic": "eog", 
             "EOG Left Horiz": "eog", 
             "EOG Right Horiz": "eog", 
             "EMG Chin": "emg", "ECG ECGI": 
             "ecg", "Resp Nasal": "misc", # or 'resp' if you prefer 
             # A2 is a mastoid ref; you can treat as EEG or misc depending on how you use it 
             "A2": "eeg",} # this is the renamed "EEG A2-CLE" }

raw.set_channel_types({k: v for k, v in ch_types.items() if k in raw.ch_names})

# attach montage 
montage = mne.channels.make_standard_montage("standard_1020") 
raw.set_montage(montage, match_case=False, on_missing="ignore")


# N2 intervals from staging 
n2_intervals = [] 
for a in stages: 
    if a["description"] == "Sleep stage 2": 
        tmin = float(a["onset"]) 
        tmax = float(a["onset"] + a["duration"]) 
        n2_intervals.append((tmin, tmax)) 
    
if not n2_intervals: 
    raise RuntimeError("No 'Sleep stage 2' intervals found.")

#merge intervals if they are close in time 
def merge_intervals(intervals, gap=0.0): 
    """Merge intervals that overlap or are within gap seconds.""" 
    if not intervals: 
        return [] 
    intervals = sorted(intervals, key=lambda x: x[0]) 
    merged = [list(intervals[0])] 
    for start, end in intervals[1:]: 
        if start <= merged[-1][1] + gap: 
            merged[-1][1] = max(merged[-1][1], end) 
        else: merged.append([start, end]) 
    return [(s, e) for s, e in merged]

n2_intervals_merged = merge_intervals(n2_intervals, gap=0.5) # 0.5s gap tolerance is safe 

# Build N2-only raw 
raws = [] 
for tmin, tmax in n2_intervals_merged: 
    r = raw.copy() 
    r.crop(tmin=tmin, tmax=tmax) 
    raws.append(r) 
    
raw_n2 = mne.concatenate_raws(raws)

# Map global time -> compressed N2 time 
def map_to_n2_time(t, intervals): 
    elapsed = 0.0 
    for tmin, tmax in intervals: 
        if t < tmin: 
            break 
        if tmin <= t <= tmax: 
            return elapsed + (t - tmin) 
            elapsed += (tmax - tmin) 
    return None 

# Remap KC + spindle annotations 
all_annot = annot + spindles 

new_onsets, new_durations, new_desc = [], [], [] 
for a in all_annot: 
    t_new = map_to_n2_time(float(a["onset"]), n2_intervals_merged) 
    if t_new is not None: 
        new_onsets.append(t_new) 
        new_durations.append(float(a["duration"])) 
        new_desc.append(a["description"])

annot_on_n2_timeline = mne.Annotations(new_onsets, new_durations, new_desc) 
raw_n2.set_annotations(annot_on_n2_timeline) 

# Save annotations 
annot_on_n2_timeline.save( os.path.join(out_dir, f"{session_name}_N2_annotations.csv"), overwrite=True )

## Save info on sleep durations 
total_sec = raw.times[-1] 
n2_sec = raw_n2.times[-1] 

dur_df = pd.DataFrame([{ "session": full_sess_name, 
                         "total_duration_sec": total_sec, 
                         "total_duration_hr": total_sec / 3600, 
                         "n2_duration_sec": n2_sec, 
                         "n2_duration_hr": n2_sec / 3600, }]) 

out_path = subject_model_dir / f"{full_sess_name}_recording_durations.csv" 
dur_df.to_csv(out_path, index=False)

# Save info on KC and Spindle density 
# N2 duration 
n2_duration_sec = raw_n2.times[-1] 
n2_duration_min = n2_duration_sec / 60 

# Count KCs in N2 
kc_n2 = 0 
for a in annot: 
    if map_to_n2_time(float(a["onset"]), n2_intervals_merged) is not None: 
        kc_n2 += 1 
        
# Count spindles in N2 
spindle_n2 = 0 
for a in spindles: 
    if map_to_n2_time(float(a["onset"]), n2_intervals_merged) is not None: 
        spindle_n2 += 1

# per min 
kc_density_per_min = kc_n2 / n2_duration_min if n2_duration_min > 0 else float("nan") 
spindle_density_per_min = spindle_n2 / n2_duration_min if n2_duration_min > 0 else float("nan") 

# Save 
out_dir = "/Users/elizabethkaplan/Desktop/SS2_Results" 
os.makedirs(out_dir, exist_ok=True)

# Use your subject/session variable if available 
session_name = full_sess_name # e.g., "01-02-0002" 
out_path = os.path.join(out_dir, f"{session_name}_N2_event_summary.csv") 

df = pd.DataFrame([{ "session": session_name, 
                     "n2_duration_sec": n2_duration_sec, 
                     "n2_duration_min": n2_duration_min, 
                     "kc_count_n2": kc_n2, 
                     "kc_density_per_min_n2": kc_density_per_min, 
                     "spindle_count_n2": spindle_n2, 
                     "spindle_density_per_min_n2": spindle_density_per_min, }])

df.to_csv(out_path, index=False) 
print("Saved N2 KC + spindle summary to:", out_path) 

## 2. PREPROCESSING 

#High and low pass filter 
raw_n2.filter(l_freq=0.1, h_freq=100.0) 

#notch filter 
raw_n2.notch_filter(freqs=60, picks=None, filter_length='auto', phase='zero') 

# Calculate artifact rejection threshold

# 30 sec epochs 
events_n2_30 = mne.make_fixed_length_events(raw_n2, start=0, stop=raw_n2.times[-1], duration=30.0) 
epochs_n2_30 = mne.Epochs( raw_n2, 
                           events=events_n2_30, 
                           tmin=0.0, tmax=30.0, # 30 seconds 
                           baseline=None, 
                           picks="eeg", 
                           preload=True, 
                           reject_by_annotation=True )

# calc amplitude thrshold 
reject = get_rejection_threshold(epochs_n2_30, decim=1)

# remove signal that exceeds threshold 
epochs_n2_30_clean = epochs_n2_30.copy().drop_bad(reject=reject) 

# --- ICA prep --- 
epochs_for_ica = epochs_n2_30_clean.copy().filter(l_freq=1.0, h_freq=40.0) 
ica = ICA( n_components=0.99, # adapts to channel count 
           max_iter="auto", 
           method="infomax", 
           random_state=97, 
           fit_params=dict(extended=True), ) 
ica.fit(epochs_for_ica)

# ICLabel 
ic_labels = label_components(epochs_for_ica, ica, method="iclabel") 
labels = ic_labels["labels"] # Keep "brain" (and optionally "other"); exclude the rest 
exclude_idx = [i for i, lab in enumerate(labels) if lab not in ("brain", "other")] 
ica.exclude = exclude_idx 

print("ICLabel counts:", {lab: sum(l == lab for l in labels) for lab in set(labels)}) 
print(f"Excluding {len(ica.exclude)} components: {ica.exclude}")

# Optional: inspect what you're excluding (only if there are any) 
if len(ica.exclude) > 0: 
    ica.plot_components(picks=ica.exclude[:20]) 
else: print("No ICs marked for exclusion; skipping ica.plot_components().") 

# Apply ICA 
epochs_n2_30_ica = epochs_n2_30_clean.copy() 
epochs_n2_30_ica = epochs_n2_30_ica.pick_types(eeg=True) 
epochs_n2_30_ica = epochs_n2_30_ica.set_eeg_reference("average", projection=False) # Safe to call even if ica.exclude == [] 
ica.apply(epochs_n2_30_ica) 
print("Done. Final cleaned epochs:", epochs_n2_30_ica)

### Rereference to mastoid 
epochs_n2_30_ica.set_eeg_reference(ref_channels=["A2"], 
                                   projection=False) 
epochs_n2_30_ica.drop_channels(["A2"]) 

### Autorejection 
N_JOBS = 10 
ar = AutoReject(n_jobs=N_JOBS, random_state=42, verbose=False, picks="eeg") 
epochs_ar_final, reject_log = ar.fit_transform(epochs_n2_30_ica, return_log=True) 

print(f"AutoReject removed {len(epochs_n2_30_ica) - len(epochs_ar_final)}/{len(epochs_n2_30_ica)} epochs") 
print("Number of epochs after autoreject:", len(epochs_ar_final))

Extracting EDF parameters from /Users/elizabethkaplan/Desktop/SS2_Data/01-02-0016-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 7255039  =      0.000 ... 28339.996 secs...


  raw.set_channel_types({k: v for k, v in ch_types.items() if k in raw.ch_names})


Overwriting existing file.
Saved N2 KC + spindle summary to: /Users/elizabethkaplan/Desktop/SS2_Results/01-02-0016_N2_event_summary.csv
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.1 - 1e+02 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.10
- Lower transition bandwidth: 0.10 Hz (-6 dB cutoff frequency: 0.05 Hz)
- Upper passband edge: 100.00 Hz
- Upper transition bandwidth: 25.00 Hz (-6 dB cutoff frequency: 112.50 Hz)
- Filter length: 8449 samples (33.004 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.9s


Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 59 - 61 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 59.35
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 59.10 Hz)
- Upper passband edge: 60.65 Hz
- Upper transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 60.90 Hz)
- Filter length: 1691 samples (6.605 s)

Not setting metadata
348 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 348 events and 7681 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.7s


0 bad epochs dropped
Estimating rejection dictionary for eeg
0 bad epochs dropped
Setting up band-pass filter from 1 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 845 samples (3.301 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done  71 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done 161 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 287 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 449 tasks      | elapsed:    0.2s
[Parallel(n_jobs=1)]: Done 647 tasks      | elapsed:    0.3s
[Parallel(n_jobs=1)]: Done 881 tasks      | elapsed:    0.4s
[Parallel(n_jobs=1)]: Done 1151 tasks      | elapsed:    0.5s
[Parallel(n_jobs=1)]: Done 1457 tasks      | elapsed:    0.7s
[Parallel(n_jobs=1)]: Done 1799 tasks      | elapsed:    0.9s
[Parallel(n_jobs=1)]: Done 2177 tasks      | elapsed:    1.1s
[Parallel(n_jobs=1)]: Done 2591 tasks      | elapsed:    1.3s
[Parallel(n_jobs=1)]: Done 3041 tasks      | elapsed:    1.6s
[Parallel(n_jobs=1)]: Done 3527 tasks      | elapsed:    1.8s
[Parallel(n_jobs=1)]: Done 4049 tasks      | elapsed:    2.1s
[Parallel(n_jobs=1)]: Done 4607 tasks      | elapsed:    2.4s
[Parallel(n_job

Fitting ICA to data using 20 channels (please be patient, this may take a while)
Selecting by explained variance: 10 components
Computing Extended Infomax ICA


In [6]:
# =========================
# SAVE CLEANED DATA + UPDATED ANNOTATIONS (post-AutoReject)
# Robust version: uses epochs_ar_final.events (no .selection)
# =========================
import numpy as np
import pandas as pd
import mne
from pathlib import Path

# ---- output folder ----
save_root = Path("/Users/elizabethkaplan/Desktop/SS2_Results") / full_sess_name / "cleaned"
save_root.mkdir(parents=True, exist_ok=True)

raw_fif_path   = save_root / f"{full_sess_name}_N2_clean_autoreject_raw.fif"
ann_mne_path   = save_root / f"{full_sess_name}_N2_clean_autoreject_annotations.csv"
ann_table_path = save_root / f"{full_sess_name}_N2_clean_autoreject_annotations_table.csv"
epoch_map_path = save_root / f"{full_sess_name}_N2_clean_autoreject_epoch_time_map.csv"

# -----------------------
# 1) Build gapless cleaned Raw from kept 30s epochs
# -----------------------
data = epochs_ar_final.get_data()  # (n_kept_epochs, n_ch, n_samp)
n_ep, n_ch, n_samp = data.shape
data_cont = data.transpose(1, 0, 2).reshape(n_ch, n_ep * n_samp)

raw_clean = mne.io.RawArray(data_cont, epochs_ar_final.info.copy(), verbose=False)

# -----------------------
# 2) Old->New time mapping using kept epochs' original start times
# -----------------------
sfreq = raw_n2.info["sfreq"]
epoch_len = epochs_ar_final.tmax - epochs_ar_final.tmin  # should be 30.0

# Old start times (raw_n2 timeline) of the KEPT epochs:
# epochs_ar_final.events[:, 0] are sample indices relative to raw_n2 start
starts_old = epochs_ar_final.events[:, 0] / sfreq

# New start times are back-to-back (gapless)
starts_new = np.arange(len(starts_old)) * epoch_len

# Save mapping table
pd.DataFrame({
    "kept_epoch_order": np.arange(len(starts_old)),
    "old_start_s": starts_old,
    "new_start_s": starts_new,
    "epoch_len_s": epoch_len,
}).to_csv(epoch_map_path, index=False)
print("Saved epoch time mapping to:", epoch_map_path)

# Helper: map old time -> new time if it falls inside a kept epoch
# Use searchsorted for speed and correctness even if old epochs are not perfectly regular.
starts_old_sorted_idx = np.argsort(starts_old)
starts_old_sorted = starts_old[starts_old_sorted_idx]
starts_new_sorted = starts_new[starts_old_sorted_idx]

def map_old_to_new_time(t_old: float):
    j = np.searchsorted(starts_old_sorted, t_old, side="right") - 1
    if j < 0:
        return None
    s_old = float(starts_old_sorted[j])
    if t_old >= s_old + epoch_len:
        return None
    s_new = float(starts_new_sorted[j])
    return s_new + (t_old - s_old)

# -----------------------
# 3) Remap annotations to the cleaned timeline
# STRICT: keep only events fully contained within a kept epoch
# -----------------------
old_ann = annot_on_n2_timeline  # annotations on raw_n2 timeline

new_onsets, new_durs, new_desc = [], [], []
dropped_outside = 0
dropped_boundary = 0

for onset, dur, desc in zip(old_ann.onset, old_ann.duration, old_ann.description):
    onset = float(onset)
    dur = float(dur)

    onset_new = map_old_to_new_time(onset)
    if onset_new is None:
        dropped_outside += 1
        continue

    # Ensure full duration stays inside the same kept epoch
    end_new = map_old_to_new_time(onset + dur)
    if end_new is None:
        dropped_boundary += 1
        continue

    new_onsets.append(onset_new)
    new_durs.append(dur)
    new_desc.append(str(desc))

print(f"Dropped {dropped_outside} annotations (fell in rejected time)")
print(f"Dropped {dropped_boundary} annotations (duration crossed a rejected/boundary region)")

ann_clean = mne.Annotations(new_onsets, new_durs, new_desc)
raw_clean.set_annotations(ann_clean)

# Save annotations (MNE) + readable table
ann_clean.save(ann_mne_path, overwrite=True)

pd.DataFrame({
    "onset_s": ann_clean.onset,
    "duration_s": ann_clean.duration,
    "description": ann_clean.description,
}).to_csv(ann_table_path, index=False)

# -----------------------
# 4) Save cleaned raw (FIF)
# -----------------------
raw_clean.save(raw_fif_path, overwrite=True)

print("Saved cleaned raw:", raw_fif_path)
print("Saved cleaned annotations:", ann_mne_path)
print("Saved cleaned annotations table:", ann_table_path)


Saved epoch time mapping to: /Users/elizabethkaplan/Desktop/SS2_Results/01-02-0016/cleaned/01-02-0016_N2_clean_autoreject_epoch_time_map.csv
Dropped 1759 annotations (fell in rejected time)
Dropped 0 annotations (duration crossed a rejected/boundary region)
Writing /Users/elizabethkaplan/Desktop/SS2_Results/01-02-0016/cleaned/01-02-0016_N2_clean_autoreject_raw.fif
Closing /Users/elizabethkaplan/Desktop/SS2_Results/01-02-0016/cleaned/01-02-0016_N2_clean_autoreject_raw.fif
[done]
Saved cleaned raw: /Users/elizabethkaplan/Desktop/SS2_Results/01-02-0016/cleaned/01-02-0016_N2_clean_autoreject_raw.fif
Saved cleaned annotations: /Users/elizabethkaplan/Desktop/SS2_Results/01-02-0016/cleaned/01-02-0016_N2_clean_autoreject_annotations.csv
Saved cleaned annotations table: /Users/elizabethkaplan/Desktop/SS2_Results/01-02-0016/cleaned/01-02-0016_N2_clean_autoreject_annotations_table.csv
