# Arousal Detection Evaluation using High-Frequency Sleep Staging

This notebook evaluates arousal detection performance by treating short wake episodes
predicted by high-frequency sleep staging models as arousal events.

## Overview

Arousals are brief awakenings during sleep that disrupt sleep architecture. This notebook
tests whether high-frequency sleep stage predictions (with temporal resolution finer than
30 seconds) can detect these events by identifying short wake periods.

## Method

1. **Load ground truth arousals** from MASS dataset annotation files (EDF format)
   - Uses timing offsets from `mass_offsets.py` to align annotations with trimmed PSG data

2. **Load sleep stage predictions** from high-frequency model sweep
   - Multiple temporal resolutions tested (1-3840 predictions per 30s epoch)

3. **Convert wake predictions to arousal events**
   - Find continuous wake periods (sleep stage = 0)
   - Merge consecutive wake periods when the gap between them is at most 10% of the merged duration and the merged event is at most 15 s long
   - Filter by duration: 3-15 seconds (AASM arousal criteria)
   - Require ≥10 s of sleep (no predicted wake) before each arousal

4. **Evaluate using IoU-based metrics**
   - Calculate Intersection over Union between predicted and ground truth arousals
   - Compute precision, recall and F1 at IoU thresholds 0.0, 0.1, …, 1.0 (a match requires IoU strictly above the threshold, so the 1.0 threshold never matches)

## Outputs

- `arousals_aasm_merged.json`: Predicted arousal events per model/resolution/subject
- `arousals_iou_scores_aasm_merged.json`: Precision/recall/F1 scores at each IoU threshold

## Dependencies

- High-frequency prediction files from `predict-high-freq.py` sweep

In [1]:
import gc
import glob
import json
from os.path import dirname

import mne
import numpy as np

from scripts.arousals.mass_offsets import offsets


def load_arousals(ann_file):
    """Load expert-scored EEG micro-arousals from a MASS annotation EDF file.

    The subject id is taken from the file name (text before the first space) and the
    per-subject offset from ``mass_offsets.offsets`` is subtracted from onset/offset so the
    events line up with the trimmed PSG recordings.

    Parameters
    ----------
    ann_file : str
        Path to a ``... Annotations.edf`` file.

    Returns
    -------
    list or np.ndarray
        ``[]`` when the file contains no EEG micro-arousals (none at all, or only ones on
        non-EEG channels). Otherwise a string array of shape (n, 3) with columns
        ``[onset, offset, channel]``. Onset/offset are numbers stored as strings because the
        array mixes them with channel names; cast with ``.astype(float)`` before use.
    """
    # Annotation descriptions look like:
    # <Event channel="EEG C4-LER" groupName="MicroArousal" name="CARSM expert" scoringType="expert"/>
    data = mne.io.read_raw_edf(ann_file, verbose=False)
    s_id = ann_file.split("/")[-1].split(" ")[0]

    rows = []
    for ann in data.annotations:
        desc = ann["description"]
        if "MicroArousal" not in desc:
            continue
        # Slice the value between `channel="` (9 characters) and `" groupName`.
        channel = desc[desc.find("channel") + 9:desc.find("groupName") - 2]
        if channel.startswith("EEG"):
            rows.append([ann["onset"], ann["onset"] + ann["duration"], channel])

    # No EEG micro-arousals: return the same empty value as the "no arousals" case instead of
    # failing on 2-D indexing of an empty 1-D array below.
    if not rows:
        return []

    arousals = np.array(rows)
    cohort = f"mass-c{s_id.split('-')[1][1]}"
    arousals[:, :2] = arousals[:, :2].astype(float) - offsets[cohort][s_id]
    return arousals


def load_hf_sleep_stages(sweep_folder):
    """Load labels and predictions of every run in a ``predict-high-freq.py`` sweep.

    Each run directory ``<sweep_folder>/<run_index>/`` is expected to contain ``labels.npz``,
    ``predictions.npz`` and ``predict-high-freq.log``. The log supplies the run's
    ``sleep_stage_frequency`` and ``model.path`` as ``key=value`` lines.

    Parameters
    ----------
    sweep_folder : str
        Sweep root directory; run sub-directories must be named with integers.

    Returns
    -------
    (labels_per_sr_per_model, preds_per_sr_per_model)
        Two nested dicts ``{model_path: {sleep_stage_sr: np.load(...) result}}``.
    """
    label_dirs = {dirname(p) for p in glob.glob(f"{sweep_folder}/*/labels.npz")}
    pred_dirs = {dirname(p) for p in glob.glob(f"{sweep_folder}/*/predictions.npz")}
    # Pair labels and predictions by run directory instead of zipping two independently
    # sorted file lists: a run missing one of the files would otherwise shift every later pair.
    unpaired = label_dirs ^ pred_dirs
    if unpaired:
        print(f"Skipping runs without both labels.npz and predictions.npz: {sorted(unpaired)}")
    run_dirs = sorted(label_dirs & pred_dirs, key=lambda d: int(d.split("/")[-1]))

    labels_per_sr_per_model = {}
    preds_per_sr_per_model = {}
    for run_dir in run_dirs:
        with open(f"{run_dir}/predict-high-freq.log") as f:
            lines = f.readlines()
        sr_line = [line for line in lines if "sleep_stage_frequency" in line][0]
        sleep_stage_sr = int(sr_line.split("=", 1)[1])
        model_line = [line for line in lines if "model.path" in line][0]
        # strip(): otherwise the line's trailing newline becomes part of the model key.
        model = model_line.split("=", 1)[1].strip()

        labels_per_sr_per_model.setdefault(model, {})[sleep_stage_sr] = np.load(f"{run_dir}/labels.npz")
        preds_per_sr_per_model.setdefault(model, {})[sleep_stage_sr] = np.load(f"{run_dir}/predictions.npz")

    return labels_per_sr_per_model, preds_per_sr_per_model


# TODO: replace with path to MASS data
MASS_DIR = "/home/niklas/data/MASS"
# Sorted so the processing order (and insertion order of `all_arousals`) does not depend on the
# file system's directory listing order.
arousal_files = sorted(glob.glob(f"{MASS_DIR}/*Annotations.edf"))

# {subject_id: ground-truth arousals as returned by load_arousals ([] when there are none)}
all_arousals = {}
for af in arousal_files:
    s_id = af.split('/')[-1].split(' ')[0]
    print(f"Processing {s_id}")
    arousals = load_arousals(af)

    if len(arousals) == 0:
        all_arousals[s_id] = arousals
        print(f"No arousals found in {af}")
        continue

    # Store the events sorted by onset so downstream consumers can rely on time order.
    arousals = arousals[np.argsort(arousals[:, 0].astype(float), kind="stable")]
    all_arousals[s_id] = arousals

    # Report (but keep) ground-truth arousals that start before their predecessor ends.
    overlap_mask = arousals[1:, 0].astype(float) < arousals[:-1, 1].astype(float)
    if np.any(overlap_mask):
        print(f"Overlapping arousals found in {af}")
        print(arousals[np.r_[overlap_mask, False]])
        print(arousals[np.r_[False, overlap_mask]])

    gc.collect()

Processing 01-01-0020
Processing 01-03-0060
Processing 01-03-0010
Processing 01-01-0005
Processing 01-01-0041
Processing 01-03-0057
Processing 01-01-0006
Processing 01-03-0037
Processing 01-03-0038
Processing 01-03-0016
Processing 01-03-0042
Processing 01-03-0031
Processing 01-03-0064
Processing 01-01-0011
Processing 01-01-0040
Processing 01-01-0019
Processing 01-03-0008
Processing 01-03-0025
Processing 01-03-0061
Processing 01-03-0033
Processing 01-03-0005
Processing 01-03-0020
Processing 01-03-0017
Processing 01-01-0002
Processing 01-03-0052
Processing 01-01-0037
Processing 01-01-0023
Processing 01-01-0033
Processing 01-01-0031
Processing 01-03-0044
Processing 01-01-0013
Processing 01-03-0056
Processing 01-01-0051
Processing 01-03-0047
Processing 01-01-0046
Processing 01-03-0046
Processing 01-01-0016
Processing 01-01-0018
Processing 01-01-0042
Processing 01-03-0062
Processing 01-01-0022
Processing 01-01-0015
Processing 01-03-0063
Processing 01-01-0027
Processing 01-03-0027
Processing

In [2]:
def merge_arousals(arousals, max_after_merge_dur=15, max_merge_dist_perc=0.1, verbose=True):
    """Merge consecutive events that lie close together into single events.

    Events are visited in start order. An event is merged into the previous (possibly already
    merged) event when the merged event would last at most ``max_after_merge_dur`` and the gap
    between the two is at most ``max_merge_dist_perc`` times the merged duration.

    Parameters
    ----------
    arousals : array-like of shape (n, 2)
        ``[start, end]`` events. The input is not modified.
    max_after_merge_dur : float
        Maximum duration of a merged event (same unit as the events).
    max_merge_dist_perc : float
        Maximum gap, as a fraction of the merged duration, that still allows merging.
    verbose : bool
        Print how many merges happened (the original behaviour).

    Returns
    -------
    np.ndarray
        Merged events, shape (m, 2) with m <= n; shape (0, 2) for empty input.
    """
    if len(arousals) == 0:
        return np.empty((0, 2))
    # Copy every event into a fresh list: rows of a 2-D ndarray are views, and merging writes to
    # the last event, which would otherwise modify the caller's array (or fail for tuples).
    sorted_events = sorted((list(event) for event in arousals), key=lambda event: event[0])

    merged_events = [sorted_events[0]]
    n_merges = 0
    for current_event in sorted_events[1:]:
        last_merged_event = merged_events[-1]

        # max(): an event contained in the previous one must not shrink the merged event.
        merged_end = max(last_merged_event[1], current_event[1])
        merged_duration = merged_end - last_merged_event[0]
        distance = current_event[0] - last_merged_event[1]

        if merged_duration <= max_after_merge_dur and distance <= merged_duration * max_merge_dist_perc:
            last_merged_event[1] = merged_end
            n_merges += 1
        else:
            merged_events.append(current_event)

    if verbose:
        print(f"Merged {n_merges} arousals, remaining {len(merged_events)}")

    return np.array(merged_events)

In [3]:
# Convert high-frequency wake predictions into candidate arousal events.
# Result layout: {model: {sleep_stage_sr: {subject_key: array of [start_s, end_s] rows}}}
lower_wake_length = 3  # s, shortest wake episode kept as an arousal
upper_wake_length = 15  # s, longest wake episode kept as an arousal
min_sleep_before = 10  # s without predicted wake required before an arousal
EPOCH_SECONDS = 30  # predictions are made at ss_sr samples per 30 s epoch


def wake_stages_to_arousals(stages, ss_sr):
    """Turn one subject's high-frequency stage predictions into arousal events.

    Parameters
    ----------
    stages : np.ndarray
        Predicted sleep stages, 0 meaning wake. Assumed 1-D with ss_sr values per 30 s
        epoch -- TODO confirm against predict-high-freq.py output.
    ss_sr : int
        Number of stage predictions per 30 s epoch.

    Returns
    -------
    np.ndarray of shape (n, 2)
        ``[start, end]`` in seconds of merged wake episodes that last between
        ``lower_wake_length`` and ``upper_wake_length`` seconds and follow at least
        ``min_sleep_before`` seconds without wake. May be empty.
    """
    wake_mask = stages == 0
    # Zero padding makes every wake run produce +1 at its first sample and -1 right after its last.
    wake_borders = np.diff(wake_mask, prepend=0, append=0)
    wake_starts = np.where(wake_borders == 1)[0]
    wake_ends = np.where(wake_borders == -1)[0]
    assert len(wake_starts) == len(wake_ends)
    if len(wake_starts) == 0:
        return np.empty((0, 2))

    # Sample index -> seconds.
    events = np.column_stack((wake_starts, wake_ends)) / ss_sr * EPOCH_SECONDS
    events = merge_arousals(events)

    # Gap to the end of the previous (merged) wake event, measured before duration filtering.
    sleep_before_lengths = events[:, 0] - np.r_[0, events[:-1, 1]]
    events = events[sleep_before_lengths >= min_sleep_before]

    wake_lengths = events[:, 1] - events[:, 0]
    return events[(lower_wake_length <= wake_lengths) & (wake_lengths <= upper_wake_length)]


pred_arousals = {}

hf_ss_sweep_folder = "../../logs/exp002/exp002a/sweep-2025-07-28_19-56-55_mass/"
_, hf_ss = load_hf_sleep_stages(hf_ss_sweep_folder)

for model, stages_per_sr in hf_ss.items():
    pred_arousals[model] = {}
    for ss_sr, stages_per_subject in stages_per_sr.items():
        print(f"Processing {ss_sr} {lower_wake_length} {upper_wake_length}")
        # Every subject gets an entry (possibly empty) so subjects without any predicted wake
        # are still scored (all ground-truth arousals missed) instead of silently dropped.
        pred_arousals[model][ss_sr] = {
            s_id_ss: wake_stages_to_arousals(stages_per_subject[s_id_ss], ss_sr)
            for s_id_ss in stages_per_subject.keys()
        }

Processing 1 3 15
Merged 0 arousals, remaining 30
Merged 0 arousals, remaining 37
Merged 0 arousals, remaining 26
Merged 0 arousals, remaining 51
Merged 0 arousals, remaining 37
Merged 0 arousals, remaining 42
Merged 0 arousals, remaining 30
Merged 0 arousals, remaining 24
Merged 0 arousals, remaining 41
Merged 0 arousals, remaining 36
Merged 0 arousals, remaining 55
Merged 0 arousals, remaining 58
Merged 0 arousals, remaining 28
Merged 0 arousals, remaining 27
Merged 0 arousals, remaining 43
Merged 0 arousals, remaining 18
Merged 0 arousals, remaining 60
Merged 0 arousals, remaining 37
Merged 0 arousals, remaining 47
Merged 0 arousals, remaining 56
Merged 0 arousals, remaining 60
Merged 0 arousals, remaining 33
Merged 0 arousals, remaining 47
Merged 0 arousals, remaining 41
Merged 0 arousals, remaining 33
Merged 0 arousals, remaining 25
Merged 0 arousals, remaining 22
Merged 0 arousals, remaining 28
Merged 0 arousals, remaining 29
Merged 0 arousals, remaining 23
Merged 0 arousals, rem

In [4]:
# JSON serializer for numpy arrays
class NumpyArrayEncoder(json.JSONEncoder):
    """JSON encoder that also serializes numpy arrays and numpy scalars.

    Arrays become (nested) lists. Numpy scalars (e.g. ``np.int64``, ``np.float32``,
    ``np.bool_``) become the equivalent Python scalar. Anything else is deferred to
    :class:`json.JSONEncoder`, which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # np.float64 already subclasses float, but np.int64 / np.float32 / np.bool_ do not.
        if isinstance(obj, np.generic):
            return obj.item()
        return super().default(obj)


# Persist the predicted arousal events; numpy arrays are written as nested lists.
with open("arousals_aasm_merged.json", "w") as events_file:
    json.dump(pred_arousals, events_file, cls=NumpyArrayEncoder)

In [5]:
def calc_tp_fp_fn(gt_arousals, pred_arousals, overlap_thresholds):
    """Count true positives, false positives and false negatives at several IoU thresholds.

    Matching is one-to-one and is redone from scratch for every threshold. Predictions are
    visited in the given order, and each is matched to the not-yet-matched ground-truth event
    with the highest IoU, provided that IoU is strictly greater than the threshold.

    Parameters
    ----------
    gt_arousals : array-like
        Ground-truth events whose first two columns are start/end times (they may be stored as
        strings, as returned by ``load_arousals``). May be empty.
    pred_arousals : array-like of shape (n, 2)
        Predicted ``[start, end]`` events in the same time unit. May be empty.
    overlap_thresholds : sequence of float
        IoU thresholds; a match requires IoU > threshold.

    Returns
    -------
    tp, fp, fn : np.ndarray of int
        One entry per threshold, with ``tp + fp == len(pred_arousals)`` and
        ``tp + fn == len(gt_arousals)``.
    """
    gt = np.asarray(gt_arousals)
    gt = gt[:, :2].astype(float) if len(gt) > 0 else np.empty((0, 2))
    pred = np.asarray(pred_arousals, dtype=float)
    pred = pred[:, :2] if len(pred) > 0 else np.empty((0, 2))

    # Pairwise IoU, shape (n_pred, n_gt). Disjoint or merely touching intervals get IoU 0.
    inter = np.clip(
        np.minimum(pred[:, None, 1], gt[None, :, 1]) - np.maximum(pred[:, None, 0], gt[None, :, 0]),
        0, None,
    )
    union = np.maximum(pred[:, None, 1], gt[None, :, 1]) - np.minimum(pred[:, None, 0], gt[None, :, 0])
    iou = np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)

    n_thr = len(overlap_thresholds)
    tp = np.zeros(n_thr, dtype=int)
    fp = np.zeros(n_thr, dtype=int)
    fn = np.zeros(n_thr, dtype=int)
    for t, threshold in enumerate(overlap_thresholds):
        # Fresh bookkeeping per threshold; a list-multiplied array would be shared and leak
        # matches from one threshold into the others.
        matched = np.zeros(len(gt), dtype=bool)
        for pred_iou in iou:
            # Already-matched GT events cannot be matched again (no double counting).
            available = np.where(matched, -np.inf, pred_iou)
            if available.size > 0 and available.max() > threshold:
                matched[np.argmax(available)] = True
                tp[t] += 1
            else:
                fp[t] += 1
        fn[t] = len(gt) - matched.sum()

    return tp, fp, fn

In [8]:
overlap_thresholds = np.linspace(0, 1, 11)


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0 when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0


# {model: {sleep_stage_sr: {subject_key: {"prec": [...], "rec": [...], "f1": [...]}}}},
# one list entry per IoU threshold.
scores = {}

for model, arousals_per_sr in pred_arousals.items():
    scores[model] = {}
    for ss_sr, arousals_per_subject in arousals_per_sr.items():
        scores[model][ss_sr] = {}
        lower_wake_length, upper_wake_length = 3, 15
        print(f"Processing {ss_sr} {lower_wake_length} {upper_wake_length}")

        for s_id_ss, pred_ar in arousals_per_subject.items():
            # Prediction keys look like "<prefix>_<subject_id>"; the GT dict is keyed by subject id.
            gt_ar = all_arousals[s_id_ss.split('_')[1]]
            tp, fp, fn = calc_tp_fp_fn(gt_ar, pred_ar, overlap_thresholds)
            thr_idx = range(len(overlap_thresholds))
            scores[model][ss_sr][s_id_ss] = {
                "prec": [_safe_ratio(tp[k], tp[k] + fp[k]) for k in thr_idx],
                "rec": [_safe_ratio(tp[k], tp[k] + fn[k]) for k in thr_idx],
                "f1": [_safe_ratio(2 * tp[k], 2 * tp[k] + fp[k] + fn[k]) for k in thr_idx],
            }

Processing 1 3 15
Processing 2 3 15
Processing 4 3 15
Processing 8 3 15
Processing 16 3 15
Processing 32 3 15
Processing 64 3 15
Processing 128 3 15
Processing 256 3 15
Processing 384 3 15
Processing 640 3 15
Processing 960 3 15
Processing 1920 3 15
Processing 3840 3 15
Processing 1 3 15
Processing 2 3 15
Processing 4 3 15
Processing 8 3 15
Processing 16 3 15
Processing 32 3 15
Processing 64 3 15
Processing 128 3 15
Processing 256 3 15
Processing 384 3 15
Processing 640 3 15
Processing 960 3 15
Processing 1920 3 15
Processing 3840 3 15
Processing 1 3 15
Processing 2 3 15
Processing 4 3 15
Processing 8 3 15
Processing 16 3 15
Processing 32 3 15
Processing 64 3 15
Processing 128 3 15
Processing 256 3 15
Processing 384 3 15
Processing 640 3 15
Processing 960 3 15
Processing 1920 3 15
Processing 3840 3 15


In [9]:
# Persist per-subject precision / recall / F1 lists (one value per IoU threshold).
with open("arousals_iou_scores_aasm_merged.json", "w") as scores_file:
    json.dump(scores, fp=scores_file)