In [18]:
# import mne

from pathlib import Path

import mne
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from src.dataset_utils import build_manifest_simple, save_manifest

## 0. Create manifest

In [66]:
DATASET_PATH = Path("C:/Users/MDBI/Documents/public-datasets/sleepy-rat/")

df = build_manifest_simple(DATASET_PATH)
save_manifest(df, DATASET_PATH / "manifest.csv")

In [67]:
df

Unnamed: 0,cohort,animal_id,edf_path,scoring_path,has_scoring
0,CohortA,A1,C:\Users\MDBI\Documents\public-datasets\sleepy...,C:\Users\MDBI\Documents\public-datasets\sleepy...,True
1,CohortA,A2,C:\Users\MDBI\Documents\public-datasets\sleepy...,C:\Users\MDBI\Documents\public-datasets\sleepy...,True
2,CohortA,A3,C:\Users\MDBI\Documents\public-datasets\sleepy...,C:\Users\MDBI\Documents\public-datasets\sleepy...,True
3,CohortA,A4,C:\Users\MDBI\Documents\public-datasets\sleepy...,C:\Users\MDBI\Documents\public-datasets\sleepy...,True
4,CohortB,B1,C:\Users\MDBI\Documents\public-datasets\sleepy...,C:\Users\MDBI\Documents\public-datasets\sleepy...,True
5,CohortB,B2,C:\Users\MDBI\Documents\public-datasets\sleepy...,C:\Users\MDBI\Documents\public-datasets\sleepy...,True
6,CohortB,B3,C:\Users\MDBI\Documents\public-datasets\sleepy...,C:\Users\MDBI\Documents\public-datasets\sleepy...,True
7,CohortB,B4,C:\Users\MDBI\Documents\public-datasets\sleepy...,C:\Users\MDBI\Documents\public-datasets\sleepy...,True
8,CohortC,C1,C:\Users\MDBI\Documents\public-datasets\sleepy...,C:\Users\MDBI\Documents\public-datasets\sleepy...,True
9,CohortC,C2,C:\Users\MDBI\Documents\public-datasets\sleepy...,C:\Users\MDBI\Documents\public-datasets\sleepy...,True


## 1. Load single EEG-recording

In [79]:
manifest = pd.read_csv("C:/Users/MDBI/Documents/public-datasets/sleepy-rat/manifest.csv")      # or use your in-memory df

row = manifest.iloc[0]
raw = mne.io.read_raw_edf(str(row['edf_path']), preload=True, stim_channel=None, verbose=False)
raw

Unnamed: 0,General,General.1
,Filename(s),A1.edf
,MNE object type,RawEDF
,Measurement date,1970-01-01 at 00:00:00 UTC
,Participant,WT1
,Experimenter,Unknown
,Acquisition,Acquisition
,Duration,23:59:60 (HH:MM:SS)
,Sampling frequency,128.00 Hz
,Time points,11059200
,Channels,Channels


In [80]:
print(raw)                # basic info
print(raw.info["sfreq"])  # sampling rate (Hz)
print(raw.info["ch_names"])

<RawEDF | A1.edf, 3 x 11059200 (86400.0 s), ~253.1 MB, data loaded>
128.0
['EEG1', 'EEG2', 'EMG']


### Tell MNE what type of data is in each channel, is it EEG or EMG?

In [81]:
# Example mapping — adjust to your channel names:
mapping = {
    "EEG1": "eeg",     # or "EEG1"
    "EEG2": "eeg",
    "EMG": "emg"
}
present = {ch: tp for ch, tp in mapping.items() if ch in raw.ch_names}
if present:
    raw.set_channel_types(present)

raw

Unnamed: 0,General,General.1
,Filename(s),A1.edf
,MNE object type,RawEDF
,Measurement date,1970-01-01 at 00:00:00 UTC
,Participant,WT1
,Experimenter,Unknown
,Acquisition,Acquisition
,Duration,23:59:60 (HH:MM:SS)
,Sampling frequency,128.00 Hz
,Time points,11059200
,Channels,Channels


### Load scorings

In [82]:
scorings = pd.read_csv(row['scoring_path'], index_col=0, names = ["score_1", "score_2"])

In [83]:
scorings

Unnamed: 0,score_1,score_2
0,n,n
1,n,n
2,n,n
3,n,n
4,n,n
...,...,...
10795,w,w
10796,w,w
10797,w,w
10798,w,w


## 2. Pre-processing steps for one EEG-recording

In [84]:
from pathlib import Path
import numpy as np
import pandas as pd
import mne

# ---------- load from your manifest ----------
row = manifest.iloc[0]                                
EDF_PATH = Path(row["edf_path"])
SCORING_CSV = Path(row["scoring_path"]) if pd.notna(row["scoring_path"]) else None


# ---------- parameters you can tweak ----------
EPOCH_LEN = 4.0         # mouse sleep scoring commonly 10 or 20 s; you have 20 s in many datasets
LINE_FREQ = 50.0         # set 60.0 if you're in a 60 Hz mains region
HP = 0.5                 # high-pass (Hz)
LP = 45.0                # low-pass (Hz); below hardware 64 Hz is fine
TARGET_SFREQ = 128.0     # you are already at 128 Hz; keep it consistent


# ---- 1) Load EDF header (lazy) ----
raw = mne.io.read_raw_edf(str(EDF_PATH), preload=False, stim_channel=None, verbose=False)

# Mark channel types (adjust names if yours differ)
eeg_like = [ch for ch in raw.ch_names if "EEG" in ch.upper()]
emg_like = [ch for ch in raw.ch_names if "EMG" in ch.upper()]
if eeg_like: raw.set_channel_types({ch: "eeg" for ch in eeg_like})
if emg_like: raw.set_channel_types({ch: "emg" for ch in emg_like})

# ---- 2) Read your scoring file (3 columns, no header) ----
# Accept comma OR semicolon just in case; strip spaces
sc = pd.read_csv(SCORING_CSV, header=None, names=["epoch", "score_1", "score_2"],
                 sep=r"[;,]", engine="python")
sc["score_1"] = sc["score_1"].astype(str).str.strip()
sc["score_2"] = sc["score_2"].astype(str).str.strip()

# Map label codes to canonical names
LABEL_MAP = {
    "w":"Wake", "wake":"Wake", "W":"Wake",
    "n":"NREM", "nr":"NREM", "N":"NREM", "NREM":"NREM",
    "r":"REM",  "rem":"REM", "R":"REM"
}

def norm_label(x: str) -> str:
    s = str(x).strip()
    return LABEL_MAP.get(s, LABEL_MAP.get(s.lower(), "Unknown"))

# Simple consensus between the two raters
PRIORITY = {"REM": 3, "NREM": 2, "Wake": 1, "Unknown": 0}
def consensus(a, b) -> str:
    a, b = norm_label(a), norm_label(b)
    if a == b: return a
    # tie-breaker: REM > NREM > Wake
    return a if PRIORITY[a] >= PRIORITY[b] else b

sc["stage"] = [consensus(a, b) for a, b in zip(sc["score_1"], sc["score_2"])]

# ---- 3) Turn rows into onsets/durations ----
rec_dur = float(raw.times[-1])                 # seconds
n_rows  = len(sc)
epoch_len = rec_dur / n_rows                   # ~4.0 s for your file
onsets   = np.arange(n_rows, dtype=float) * epoch_len
durations= np.full(n_rows, epoch_len, dtype=float)

# ---- 4) Attach annotations, crop to scored span, then load data ----
ann = mne.Annotations(onset=onsets, duration=durations,
                      description=sc["stage"].astype(str).to_numpy())
raw.set_annotations(ann)

raw.crop(tmin=float(onsets.min()), tmax=float(onsets.max() + epoch_len))
raw.load_data()

# ---- 5) Denoise (line + bands) ----
LINE = 50.0   # set to 60.0 if you're in a 60 Hz mains region
raw.notch_filter([LINE])

picks_eeg = mne.pick_types(raw.info, eeg=True,  emg=False)
picks_emg = mne.pick_types(raw.info, eeg=False, emg=True)

if len(picks_eeg):
    raw.filter(0.5, 45.0, picks=picks_eeg)   # EEG: sleep band
if len(picks_emg):
    raw.filter(10.0, 45.0, picks=picks_emg)  # EMG: keep activity band

# Optional average ref if ≥2 EEG
if len(picks_eeg) >= 2:
    raw.set_eeg_reference("average")

# ---- 6) Make per-epoch data (4 s) ----
event_id = {"Wake": 1, "NREM": 2, "REM": 3}
events = []
for onset, label in zip(onsets, sc["stage"]):
    if label not in event_id:    # skip Unknown
        continue
    events.append([int(round(onset * raw.info["sfreq"])), 0, event_id[label]])
events = np.asarray(events, dtype=int)

epochs4 = mne.Epochs(raw, events=events, event_id=event_id,
                     tmin=0.0, tmax=epoch_len, baseline=None,
                     preload=True, reject_by_annotation=True)

print(f"Epochs shape: {epochs4.get_data().shape}  (n_epochs, n_channels, n_times)")
# Optional: rebin to 20 s (group 5×4 s)
# ...

# ---- 7) Save outputs ----
raw.save("C:/Users/MDBI/Documents/public-datasets/sleepy-rat/processed-data/raw_preproc.fif", overwrite=True)
epochs4.save("C:/Users/MDBI/Documents/public-datasets/sleepy-rat/processed-data/epochs4s-epo.fif", overwrite=True)


Reading 0 ... 11059199  =      0.000 ... 86399.992 secs...
Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 49 - 51 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 49.38
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 49.12 Hz)
- Upper passband edge: 50.62 Hz
- Upper transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 50.88 Hz)
- Filter length: 845 samples (6.602 s)

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.5

  raw.save("C:/Users/MDBI/Documents/public-datasets/sleepy-rat/processed-data/raw_preproc.fif", overwrite=True)


Closing C:\Users\MDBI\Documents\public-datasets\sleepy-rat\processed-data\raw_preproc.fif
[done]
Overwriting existing file.
Overwriting existing file.


In [85]:
raw

Unnamed: 0,General,General.1
,Filename(s),A1.edf
,MNE object type,RawEDF
,Measurement date,1970-01-01 at 00:00:00 UTC
,Participant,WT1
,Experimenter,Unknown
,Acquisition,Acquisition
,Duration,23:59:60 (HH:MM:SS)
,Sampling frequency,128.00 Hz
,Time points,11059200
,Channels,Channels


In [75]:
epochs4

Unnamed: 0,General,General.1
,MNE object type,Epochs
,Measurement date,1970-01-01 at 00:00:00 UTC
,Participant,WT1
,Experimenter,Unknown
,Acquisition,Acquisition
,Total number of events,21597
,Events counts,NREM: 8539  REM: 2494  Wake: 10564
,Time range,0.000 – 4.000 s
,Baseline,off
,Sampling frequency,128.00 Hz


In [76]:
import numpy as np
import pandas as pd
import mne

# Frequency bands (tweak if you like)
BANDS = {
    "delta": (0.5, 4.0),
    "theta": (6.0, 9.0),
    "sigma": (10.0, 15.0),
    "beta":  (15.0, 30.0),
}

def _integrate_band(psd, freqs, fmin, fmax):
    """Integrate power between fmin..fmax using the trapezoid rule."""
    idx = (freqs >= fmin) & (freqs < fmax)
    if not np.any(idx):
        return np.zeros(psd.shape[:-1])
    return np.trapz(psd[..., idx], freqs[idx], axis=-1)

def extract_epoch_features(epochs: mne.Epochs,
                           eeg_picks=None,
                           emg_pick=None,
                           bands=BANDS) -> tuple[pd.DataFrame, np.ndarray]:
    """
    Returns:
      - X: DataFrame (one row per epoch)
      - y: array of stage labels ('Wake','NREM','REM') from epochs.event_id
    """
    sfreq = epochs.info["sfreq"]

    # --- labels y from events ---
    inv_id = {v: k for k, v in epochs.event_id.items()}
    y = np.array([inv_id.get(code, "Unknown") for code in epochs.events[:, 2]])

    # --- choose channels ---
    if eeg_picks is None:
        eeg_picks = mne.pick_types(epochs.info, eeg=True, emg=False)
    if emg_pick is None:
        emg_candidates = mne.pick_types(epochs.info, eeg=False, emg=True)
        emg_pick = emg_candidates[0] if len(emg_candidates) else None

    # --- EEG PSD (Welch) ---
    # Using MNE's built-in welch on epochs; returns V^2/Hz
    psd = epochs.compute_psd(method="welch", fmin=0.5, fmax=45.0)
    freqs = psd.freqs
    psd_eeg = psd.get_data(picks=eeg_picks)  # (n_epochs, n_eeg, n_freqs)
    if psd_eeg.size == 0:
        raise RuntimeError("No EEG channels found. Check channel types.")
    psd_eeg_mean = psd_eeg.mean(axis=1)      # average across EEG channels → (n_epochs, n_freqs)

    # Total power in analysis band
    total = _integrate_band(psd_eeg_mean, freqs, 0.5, 45.0) + 1e-20  # +eps to avoid div/0

    # Bandpowers (absolute and relative)
    feats = {}
    for name, (f1, f2) in bands.items():
        bp = _integrate_band(psd_eeg_mean, freqs, f1, f2)           # V^2
        feats[f"{name}_abs"] = bp
        feats[f"{name}_rel"] = bp / total

    # Simple EEG ratios
    feats["theta_delta"] = feats["theta_abs"] / (feats["delta_abs"] + 1e-20)
    feats["sigma_delta"] = feats["sigma_abs"] / (feats["delta_abs"] + 1e-20)

    # --- EMG features (RMS) ---
    if emg_pick is not None:
        emg = epochs.get_data(picks=[emg_pick])[:, 0, :]  # (n_epochs, n_times)
        emg_rms_v = np.sqrt((emg ** 2).mean(axis=1))      # in Volts
        feats["emg_rms_uv"] = emg_rms_v * 1e6             # report in µV
        feats["emg_rms_log"] = np.log10(feats["emg_rms_uv"] + 1e-6)
    else:
        feats["emg_rms_uv"] = np.zeros(len(y))
        feats["emg_rms_log"] = np.zeros(len(y))

    # Assemble DataFrame
    X = pd.DataFrame(feats)
    return X, y


In [77]:
X_one, y_one = extract_epoch_features(epochs4)
X_one.head(), np.unique(y_one, return_counts=True)

Effective window size : 4.008 (s)


  return np.trapz(psd[..., idx], freqs[idx], axis=-1)


(      delta_abs  delta_rel     theta_abs  theta_rel     sigma_abs  sigma_rel  \
 0  1.424768e-10   0.114969  2.160520e-10   0.174340  3.300775e-10   0.266351   
 1  2.720282e-10   0.180978  1.633843e-10   0.108698  3.307798e-10   0.220065   
 2  3.060551e-10   0.212357  2.520312e-10   0.174872  2.243878e-10   0.155692   
 3  3.448240e-10   0.139380  5.506970e-10   0.222595  5.821324e-10   0.235302   
 4  3.805450e-10   0.151019  5.054402e-10   0.200583  6.454865e-10   0.256161   
 
        beta_abs  beta_rel  theta_delta  sigma_delta  emg_rms_uv  emg_rms_log  
 0  2.114217e-10  0.170603     1.516402     2.316711   27.976080     1.446787  
 1  3.267474e-10  0.217383     0.600615     1.215976   22.469516     1.351594  
 2  2.326511e-10  0.161425     0.823483     0.733161   21.324376     1.328876  
 3  4.648554e-10  0.187898     1.597038     1.688202   24.588754     1.390737  
 4  2.697594e-10  0.107054     1.328201     1.696216   27.791812     1.443917  ,
 (array(['NREM', 'REM', 'Wake']

In [78]:
X_one

Unnamed: 0,delta_abs,delta_rel,theta_abs,theta_rel,sigma_abs,sigma_rel,beta_abs,beta_rel,theta_delta,sigma_delta,emg_rms_uv,emg_rms_log
0,1.424768e-10,0.114969,2.160520e-10,0.174340,3.300775e-10,0.266351,2.114217e-10,0.170603,1.516402,2.316711,27.976080,1.446787
1,2.720282e-10,0.180978,1.633843e-10,0.108698,3.307798e-10,0.220065,3.267474e-10,0.217383,0.600615,1.215976,22.469516,1.351594
2,3.060551e-10,0.212357,2.520312e-10,0.174872,2.243878e-10,0.155692,2.326511e-10,0.161425,0.823483,0.733161,21.324376,1.328876
3,3.448240e-10,0.139380,5.506970e-10,0.222595,5.821324e-10,0.235302,4.648554e-10,0.187898,1.597038,1.688202,24.588754,1.390737
4,3.805450e-10,0.151019,5.054402e-10,0.200583,6.454865e-10,0.256161,2.697594e-10,0.107054,1.328201,1.696216,27.791812,1.443917
...,...,...,...,...,...,...,...,...,...,...,...,...
21592,8.152086e-11,0.062781,4.328103e-10,0.333317,2.703968e-10,0.208239,2.150307e-10,0.165600,5.309197,3.316903,35.963510,1.555862
21593,1.207700e-10,0.113713,3.468966e-10,0.326628,2.133983e-10,0.200930,1.665801e-10,0.156847,2.872374,1.766981,44.334133,1.646738
21594,4.051657e-11,0.033221,1.032704e-10,0.084674,2.960185e-10,0.242713,2.103430e-10,0.172466,2.548844,7.306108,31.980311,1.504883
21595,2.919143e-10,0.175854,2.069523e-10,0.124672,3.903967e-10,0.235182,3.329385e-10,0.200568,0.708949,1.337368,34.064550,1.532303


## 3. Pre-process all data

In [88]:
from pathlib import Path
import numpy as np
import pandas as pd
import mne

def preprocess_recording_to_epochs(
    edf_path: Path,
    scoring_csv: Path,
    line_freq: float = 50.0,            # set 60.0 if you're in a 60 Hz region
    eeg_band: tuple = (0.5, 45.0),
    emg_band: tuple = (10.0, 45.0),
    consensus_rule: str = "priority"    # "priority" or "agree"
):
    """
    Load one EDF + its 3-column scoring CSV (epoch_idx, score_1, score_2) and return clean 4-s epochs.
    Stage codes: w/n/r. Artifact codes: 1/2/3 (dropped).
    Returns:
        epochs4 : mne.Epochs
        y       : np.ndarray[str] of labels ('Wake','NREM','REM')
        meta    : dict
    """
    edf_path, scoring_csv = Path(edf_path), Path(scoring_csv)

    # 1) Raw (lazy); set channel types (EEG1/EEG2/EMG)
    raw = mne.io.read_raw_edf(str(edf_path), preload=False, stim_channel=None, verbose=False)
    eeg_like = [ch for ch in raw.ch_names if "EEG" in ch.upper()]
    emg_like = [ch for ch in raw.ch_names if "EMG" in ch.upper()]
    if eeg_like: raw.set_channel_types({ch: "eeg" for ch in eeg_like})
    if emg_like: raw.set_channel_types({ch: "emg" for ch in emg_like})

    # 2) Read scoring (3 columns, no header)
    sc = pd.read_csv(scoring_csv, header=None, names=["epoch", "s1", "s2"],
                     sep=r"[;,]", engine="python")
    sc["s1"] = sc["s1"].astype(str).str.strip()
    sc["s2"] = sc["s2"].astype(str).str.strip()

    # Stage / artifact handling
    STAGE_MAP = {"w":"Wake", "n":"NREM", "r":"REM", "W":"Wake", "N":"NREM", "R":"REM"}
    ART = {"1","2","3"}  # 1: wake artifact, 2: NREM artifact, 3: REM artifact
    PRIORITY = {"REM": 3, "NREM": 2, "Wake": 1}

    def norm_stage(x): return STAGE_MAP.get(str(x).strip(), "Unknown")
    def is_artifact(a, b): return (str(a).strip() in ART) or (str(b).strip() in ART)
    def consensus(a, b):
        if is_artifact(a, b): return "Artifact"
        a, b = norm_stage(a), norm_stage(b)
        if consensus_rule == "agree": return a if a == b else "Unknown"
        if a == b: return a
        if a in PRIORITY and b in PRIORITY:
            return a if PRIORITY[a] >= PRIORITY[b] else b
        return "Unknown"

    sc["stage"] = [consensus(a, b) for a, b in zip(sc["s1"], sc["s2"])]
    sc["artifact"] = [is_artifact(a, b) for a, b in zip(sc["s1"], sc["s2"])]

    # 3) Build onsets/durations from row count (≈ 4 s)
    rec_dur = float(raw.times[-1])       # seconds
    n = len(sc)
    epoch_len = rec_dur / n
    onsets = np.arange(n, dtype=float) * epoch_len
    durations = np.full(n, epoch_len, float)

    # --- IMPORTANT: use the same orig_time for all new annotations ---
    # Prefer the raw meas_date (in your files it's 1970-01-01) to keep everything consistent.
    orig = raw.info.get("meas_date", None)

    # Stage annotations (all epochs)
    ann_stage = mne.Annotations(onset=onsets,
                                duration=durations,
                                description=sc["stage"].astype(str).to_numpy(),
                                orig_time=orig)

    # BAD_artifact annotations (only artifact epochs)
    new_ann = ann_stage
    art_mask = sc["artifact"].to_numpy()
    if art_mask.any():
        bad_ann = mne.Annotations(onset=onsets[art_mask],
                                  duration=durations[art_mask],
                                  description=["BAD_artifact"] * int(art_mask.sum()),
                                  orig_time=orig)
        new_ann = ann_stage + bad_ann   # same orig_time -> safe to add

    # Set annotations ONCE (don’t add to existing to avoid orig_time clashes)
    raw.set_annotations(new_ann)

    # Crop to scored span and load data
    raw.crop(onsets.min(), onsets.max() + epoch_len)
    raw.load_data()

    # 4) Denoise
    if line_freq is not None:
        raw.notch_filter([float(line_freq)])
    picks_eeg = mne.pick_types(raw.info, eeg=True, emg=False)
    picks_emg = mne.pick_types(raw.info, eeg=False, emg=True)
    if len(picks_eeg): raw.filter(eeg_band[0], eeg_band[1], picks=picks_eeg)
    if len(picks_emg): raw.filter(emg_band[0], emg_band[1], picks=picks_emg)
    if len(picks_eeg) >= 2:
        raw.set_eeg_reference("average")

    # 5) Events for clean (non-artifact) epochs
    event_id = {"Wake": 1, "NREM": 2, "REM": 3}
    keep = sc["stage"].isin(event_id.keys()) & (~sc["artifact"])
    evt_onsets = onsets[keep.to_numpy()]
    evt_labels = sc.loc[keep, "stage"].to_numpy()
    events = np.c_[
        np.round(evt_onsets * raw.info["sfreq"]).astype(int),
        np.zeros(len(evt_onsets), dtype=int),
        np.array([event_id[s] for s in evt_labels], dtype=int)
    ]

    epochs4 = mne.Epochs(raw, events, event_id, tmin=0.0, tmax=epoch_len,
                         baseline=None, preload=True, reject_by_annotation=True)

    inv = {v: k for k, v in event_id.items()}
    y = np.array([inv[c] for c in epochs4.events[:, 2]])

    meta = {
        "epoch_len": epoch_len,
        "n_epochs_total": n,
        "n_epochs_kept": len(epochs4),
        "kept_fraction": len(epochs4) / n,
        "sfreq": float(raw.info["sfreq"]),
        "channels": raw.ch_names,
        "orig_time_used": str(orig),
    }
    return epochs4, y, meta


In [89]:
# manifest must have columns: edf_path, scoring_path, and (optionally) cohort/animal_id
results = []
for _, row in manifest.iterrows():
    try:
        epochs4, y, meta = preprocess_recording_to_epochs(
            edf_path=row["edf_path"],
            scoring_csv=row["scoring_path"],
            line_freq=50.0,              # change to 60.0 if needed
            consensus_rule="priority"    # or "agree"
        )
        results.append({
            "cohort": row.get("cohort", None),
            "animal_id": row.get("animal_id", None),
            "epochs": epochs4,
            "labels": y,
            "meta": meta
        })
    except Exception as e:
        print(f"Failed on {row.get('animal_id', row['edf_path'])}: {e}")


Reading 0 ... 11059199  =      0.000 ... 86399.992 secs...
Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 49 - 51 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 49.38
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 49.12 Hz)
- Upper passband edge: 50.62 Hz
- Upper transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 50.88 Hz)
- Filter length: 845 samples (6.602 s)

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.5

In [90]:
results

[{'cohort': 'CohortA',
  'animal_id': 'A1',
  'epochs': <Epochs | 18221 events (all good), 0 – 4 s (baseline off), ~214.0 MB, data loaded,
   'Wake': 9785
   'NREM': 6908
   'REM': 1528>,
  'labels': array(['NREM', 'NREM', 'NREM', ..., 'Wake', 'Wake', 'Wake'], dtype='<U4'),
  'meta': {'epoch_len': 3.999999638310185,
   'n_epochs_total': 21600,
   'n_epochs_kept': 18221,
   'kept_fraction': 0.8435648148148148,
   'sfreq': 128.0,
   'channels': ['EEG1', 'EEG2', 'EMG'],
   'orig_time_used': '1970-01-01 00:00:00+00:00'}},
 {'cohort': 'CohortA',
  'animal_id': 'A2',
  'epochs': <Epochs | 14477 events (all good), 0 – 4 s (baseline off), ~170.0 MB, data loaded,
   'Wake': 5835
   'NREM': 7367
   'REM': 1275>,
  'labels': array(['Wake', 'Wake', 'Wake', ..., 'Wake', 'Wake', 'Wake'], dtype='<U4'),
  'meta': {'epoch_len': 3.999999638310185,
   'n_epochs_total': 21600,
   'n_epochs_kept': 14477,
   'kept_fraction': 0.6702314814814815,
   'sfreq': 128.0,
   'channels': ['EEG1', 'EEG2', 'EMG'],
   '

## 4. Create features

In [93]:
import numpy as np
import pandas as pd
import mne

# Frequency bands
BANDS = {
    "delta": (0.5, 4.0),
    "theta": (6.0, 9.0),
    "sigma": (10.0, 15.0),
    "beta":  (15.0, 30.0),
}

def _integrate_band(psd, freqs, fmin, fmax):
    idx = (freqs >= fmin) & (freqs < fmax)
    if not np.any(idx):
        return np.zeros(psd.shape[:-1])
    return np.trapezoid(psd[..., idx], freqs[idx], axis=-1)

def extract_basic_features(epochs: mne.Epochs, bands=BANDS):
    """
    From an MNE Epochs (e.g., 4-s epochs), return:
      X: DataFrame of features (n_epochs x n_features)
      y: array of labels ('Wake','NREM','REM') from epochs.event_id
    """
    # Labels y from events
    inv_id = {v: k for k, v in epochs.event_id.items()}
    y = np.array([inv_id.get(code, "Unknown") for code in epochs.events[:, 2]])

    # Picks
    eeg_picks = mne.pick_types(epochs.info, eeg=True, emg=False)
    emg_picks = mne.pick_types(epochs.info, eeg=False, emg=True)
    emg_pick = emg_picks[0] if len(emg_picks) else None

    # PSD (Welch)
    psd = epochs.compute_psd(method="welch", fmin=0.5, fmax=45.0)
    freqs = psd.freqs
    psd_data = psd.get_data()  # (n_epochs, n_channels, n_freqs)

    if len(eeg_picks) == 0:
        raise RuntimeError("No EEG channels found — check channel types.")
    psd_eeg = psd_data[:, eeg_picks, :]              # (n_epochs, n_eeg, n_freqs)
    psd_eeg_mean = psd_eeg.mean(axis=1)              # avg across EEGs → (n_epochs, n_freqs)

    # Total power for relative features
    total = _integrate_band(psd_eeg_mean, freqs, 0.5, 45.0) + 1e-20

    feats = {}
    for name, (f1, f2) in bands.items():
        bp = _integrate_band(psd_eeg_mean, freqs, f1, f2)      # V^2
        feats[f"{name}_abs"] = bp
        feats[f"{name}_rel"] = bp / total

    feats["theta_delta"] = feats["theta_abs"] / (feats["delta_abs"] + 1e-20)
    feats["sigma_delta"] = feats["sigma_abs"] / (feats["delta_abs"] + 1e-20)

    # EMG RMS (µV)
    if emg_pick is not None:
        emg = epochs.get_data(picks=[emg_pick])[:, 0, :]       # (n_epochs, n_times)
        emg_rms_v = np.sqrt((emg ** 2).mean(axis=1))
        feats["emg_rms_uv"] = emg_rms_v * 1e6
        feats["emg_rms_log"] = np.log10(feats["emg_rms_uv"] + 1e-6)
    else:
        feats["emg_rms_uv"] = np.zeros(len(y))
        feats["emg_rms_log"] = np.zeros(len(y))

    X = pd.DataFrame(feats)
    return X, y


In [105]:
def dataset_from_results(results):
    """
    results: list of dicts with keys 'epochs', 'labels', 'animal_id', 'cohort' (as in your loop)
    Returns:
      X (np.ndarray), y (np.ndarray), groups (np.ndarray), df (features+meta DataFrame)
    """
    frames = []
    for r in results:
        X_i, y_i = extract_basic_features(r["epochs"])
        X_i["label"] = y_i
        X_i["animal_id"] = r.get("animal_id")
        X_i["cohort"] = r.get("cohort")
        frames.append(X_i)

    df = pd.concat(frames, ignore_index=True)
    feature_cols = [c for c in df.columns if c not in ["label", "animal_id", "cohort"]]
    X = df[feature_cols].to_numpy()
    y = df["label"].to_numpy()
    groups = df["animal_id"].to_numpy()
    return X, y, groups, df

X, y, groups, df = dataset_from_results(results)

Effective window size : 4.008 (s)
Effective window size : 4.008 (s)
Effective window size : 4.008 (s)
Effective window size : 4.008 (s)
Effective window size : 4.008 (s)
Effective window size : 4.008 (s)
Effective window size : 4.008 (s)
Effective window size : 4.008 (s)
Effective window size : 4.008 (s)
Effective window size : 4.008 (s)
Effective window size : 4.008 (s)
Effective window size : 4.008 (s)
Effective window size : 4.008 (s)
Effective window size : 4.008 (s)


### Class balance

In [117]:
import numpy as np

labels, counts = np.unique(y, return_counts=True)
proportions = counts / counts.sum()   # or counts / y.size

print("Proportions:")
for lab, p in zip(labels, proportions):
    print(f"  {lab}: {p:.3f}")


Proportions:
  NREM: 0.484
  REM: 0.075
  Wake: 0.440


## 5. Train, val, test

In [None]:
import numpy as np
from sklearn.model_selection import GroupShuffleSplit
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report

def grouped_train_val_test_split(X, y, groups, *, test_size=0.2, val_size=0.2, random_state=42):
    """
    Returns indices (train_idx, val_idx, test_idx) with group-wise disjointness.
    val_size is the *overall* fraction. Internally we split test first, then val from the remainder.
    """
    X = np.asarray(X); y = np.asarray(y); groups = np.asarray(groups)

    # 1) Hold out TEST by groups
    gss_test = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
    trainval_idx, test_idx = next(gss_test.split(X, y, groups=groups))

    # 2) From the remaining groups, hold out VAL by groups
    rel_val_size = val_size / (1.0 - test_size)  # fraction of remaining pool
    gss_val = GroupShuffleSplit(n_splits=1, test_size=rel_val_size, random_state=random_state + 1)
    tr_local, val_local = next(gss_val.split(X[trainval_idx], y[trainval_idx], groups=groups[trainval_idx]))

    train_idx = trainval_idx[tr_local]
    val_idx   = trainval_idx[val_local]
    return train_idx, val_idx, test_idx

# --- use it ---
train_idx, val_idx, test_idx = grouped_train_val_test_split(
    X, y, groups, test_size=0.15, val_size=0.2, random_state=42
)

X_train, y_train = X[train_idx], y[train_idx]
X_val,   y_val   = X[val_idx],   y[val_idx]
X_test,  y_test  = X[test_idx],  y[test_idx]

# sanity: no group leakage
assert set(groups[train_idx]).isdisjoint(groups[val_idx])
assert set(groups[train_idx]).isdisjoint(groups[test_idx])
assert set(groups[val_idx]).isdisjoint(groups[test_idx])

## 5. Predict sleep-wake stage

### 5.1 Logistic regression

In [None]:
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report


pipe = Pipeline([
    ("scaler", StandardScaler()),
    ("clf", LogisticRegression(max_iter=200, class_weight="balanced"))
])

pipe.fit(X_train, y_train)
print(classification_report(y_val, pipe.predict(X_val), digits=3))

              precision    recall  f1-score   support

        NREM      0.920     0.790     0.850     24012
         REM      0.448     0.620     0.520      3794
        Wake      0.906     0.975     0.939     25718

    accuracy                          0.867     53524
   macro avg      0.758     0.795     0.770     53524
weighted avg      0.880     0.867     0.869     53524



### 5.2 XGBoost

In [139]:
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
from xgboost import XGBClassifier

# ---- 3) Fit XGBoost (basic settings, no early stopping)
le = LabelEncoder()
y_tr_enc = le.fit_transform(y_train)
y_val_enc = le.transform(y_val)

# Class-balanced sample weights (boosts REM)
classes = np.unique(y_train)
cls_weights = compute_class_weight(class_weight="balanced", classes=classes, y=y_train)
w_map = dict(zip(classes, cls_weights))
sw_train = np.array([w_map[c] for c in y_train])

# XGBoost (basic)
clf = XGBClassifier(
    objective="multi:softprob",
    num_class=len(le.classes_),
    n_estimators=100,
    learning_rate=0.001,
    max_depth=4,
    subsample=0.9,
    colsample_bytree=0.9,
    reg_lambda=1.0,
    tree_method="hist",   # "gpu_hist" if you have a GPU
    random_state=2025,
)

clf.fit(X_train, y_tr_enc, sample_weight=sw_train)

# ---- 4) Evaluate on held-out animals
y_pred_enc = clf.predict(X_val)
y_pred = le.inverse_transform(y_pred_enc)
print(classification_report(y_val, y_pred, digits=3))

              precision    recall  f1-score   support

        NREM      0.906     0.817     0.859     24012
         REM      0.442     0.469     0.455      3794
        Wake      0.898     0.973     0.934     25718

    accuracy                          0.867     53524
   macro avg      0.749     0.753     0.749     53524
weighted avg      0.869     0.867     0.866     53524

