Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/tasks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ Available Tasks
Readmission Prediction <tasks/pyhealth.tasks.readmission_prediction>
Sleep Staging <tasks/pyhealth.tasks.sleep_staging>
Sleep Staging (SleepEDF) <tasks/pyhealth.tasks.SleepStagingSleepEDF>
MVCL Training (SleepEDF EEG) <tasks/pyhealth.tasks.MVCLTrainingSleepEEG>
Temple University EEG Tasks <tasks/pyhealth.tasks.temple_university_EEG_tasks>
Sleep Staging v2 <tasks/pyhealth.tasks.sleep_staging_v2>
Benchmark EHRShot <tasks/pyhealth.tasks.benchmark_ehrshot>
Expand Down
7 changes: 7 additions & 0 deletions docs/api/tasks/pyhealth.tasks.MVCLTrainingSleepEEG.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pyhealth.tasks.MVCLTrainingSleepEEG
===================================

.. autoclass:: pyhealth.tasks.MVCLTrainingSleepEEG
:members:
:undoc-members:
:show-inheritance:
111 changes: 111 additions & 0 deletions examples/mvcl_training_sleepedf.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MVCLTrainingSleepEEG on Sleep-EDF\n",
"\n",
"This example shows how to use `MVCLTrainingSleepEEG` from `pyhealth.tasks`.\n",
"\n",
"The workflow mirrors the task pattern used in PyHealth examples:\n",
"1. Load `SleepEDFDataset`\n",
"2. Run the task on one patient for a quick sanity check\n",
"3. Optionally run `set_task()` for the full dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pyhealth.datasets import SleepEDFDataset\n",
"from pyhealth.tasks import MVCLTrainingSleepEEG\n",
"\n",
"# Update this path to your local Sleep-EDF root.\n",
"DATA_ROOT = \"../sleepedf\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset = SleepEDFDataset(root=DATA_ROOT, subset=\"cassette\")\n",
"dataset.stats()\n",
"print(f\"Number of patients: {len(dataset.unique_patient_ids)}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d838e3b9",
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"This dataset contains 153 whole-night sleep electroencephalography\n",
"(EEG) recordings collected from 82 healthy subjects. Each recording is sampled at 100 Hz using a 1-lead \n",
"EEG signal. The EEG signals are segmented into non-overlapping windows of size 200, each forming\n",
"one sample. Each sample is labeled with one of five sleep stages: Wake (W), Non-rapid Eye Movement\n",
"(N1, N2, N3), and Rapid Eye Movement (REM). This segmentation results in 371,055 samples.\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick sanity check on one patient.\n",
"patient_id = dataset.unique_patient_ids[0]\n",
"patient = dataset.get_patient(patient_id)\n",
"\n",
"task = MVCLTrainingSleepEEG(\n",
"    window_size=200,  # create non-overlapping windows of length 200\n",
"    crop_length=178,  # take the first 178 points of each window to match the Epilepsy data\n",
" eeg_channel=\"EEG Fpz-Cz\",\n",
")\n",
"samples = task(patient)\n",
"\n",
"print(f\"patient_id: {patient_id}\")\n",
"print(f\"sample count: {len(samples)}\")\n",
"print(f\"sample keys: {list(samples[0].keys())}\")\n",
"print(f\"signal shape: {samples[0]['signal'].shape}\")\n",
"print(f\"xt shape: {samples[0]['xt'].shape}\")\n",
"print(f\"xd shape: {samples[0]['xd'].shape}\")\n",
"print(f\"xf shape: {samples[0]['xf'].shape}\")\n",
"print(f\"label: {samples[0]['label']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Full pipeline (can take a while and uses disk cache).\n",
"sample_dataset = dataset.set_task(task, num_workers=1)\n",
"print(f\"Total task samples: {len(sample_dataset)}\")\n",
"print(f\"Input schema: {sample_dataset.input_schema}\")\n",
"print(f\"Output schema: {sample_dataset.output_schema}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
17 changes: 8 additions & 9 deletions pyhealth/tasks/mvcl_training_sleepedf_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ class MVCLTrainingSleepEEG(BaseTask):

Applies MV preprocessing per event file (one PSG/Hypnogram pair at a time),
then appends samples immediately, so each returned sample includes ``xt``,
``dx``, and ``xf`` without a patient-level global buffer.
``xd``, and ``xf`` without a patient-level global buffer.

Tensors are stored as ``numpy.float32`` arrays with shape ``(L, C_view)`` where
``C_view`` is 1 by default; with ``time_as_feature=True``, a leading time channel
in ``[0,1]`` is concatenated so ``C_view`` is 2.
"""

task_name: str = "MVCLTrainingSleepEEG"
input_schema = {"signal": "tensor"}
input_schema = {"xt": "tensor", "xd": "tensor", "xf": "tensor"}
output_schema = {"label": "multiclass"}

def __init__(
Expand Down Expand Up @@ -176,9 +176,9 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]:
"epoch_index": b["epoch_index"],
"window_in_epoch": b["window_in_epoch"],
"signal": vec,
"xt": xt[i].detach().cpu().numpy().astype(np.float32),
"xd": dx[i].detach().cpu().numpy().astype(np.float32),
"xf": xf[i].detach().cpu().numpy().astype(np.float32),
"xt": xt[i],
"xd": dx[i],
"xf": xf[i],
"label": b["label"],
}
)
Expand Down Expand Up @@ -376,9 +376,9 @@ def pt_dict_to_pyhealth_samples(
"patient_id": f"{patient_id_prefix}_{i}",
"record_id": f"{record_id_prefix}_{i}",
"signal": signal_array[i][np.newaxis, :],
"xt": xt[i].detach().cpu().numpy().astype(np.float32),
"xd": dx[i].detach().cpu().numpy().astype(np.float32),
"xf": xf[i].detach().cpu().numpy().astype(np.float32),
"xt": xt[i],
"xd": dx[i],
"xf": xf[i],
"label": int(label_array[i]),
}
)
Expand Down Expand Up @@ -415,7 +415,6 @@ def pt_file_to_sample_dataset(
return create_sample_dataset(
samples=samples,
input_schema={
"signal": "tensor",
"xt": "tensor",
"xd": "tensor",
"xf": "tensor",
Expand Down
166 changes: 166 additions & 0 deletions tests/core/test_mvcl_training_sleepedf_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import shutil
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
from collections import Counter

import mne
import numpy as np
import pandas as pd

from pyhealth.datasets import SleepEDFDataset
from pyhealth.tasks import MVCLTrainingSleepEEG


class TestMVCLTrainingSleepEEGTask(unittest.TestCase):
    """End-to-end checks for MVCLTrainingSleepEEG on a tiny synthetic Sleep-EDF tree.

    The fixture builds a throwaway dataset root containing the subject
    spreadsheets, two fake PSG/Hypnogram pairs, and the cassette metadata CSV,
    then mocks MNE's EDF readers so the task runs without real EDF files.
    """

    @classmethod
    def setUpClass(cls):
        # Throwaway dataset root; removed wholesale in tearDownClass.
        cls.temp_dir = tempfile.mkdtemp()
        cls.dummy_dataset_dir = Path(cls.temp_dir) / "dummy_dataset"
        cls.cassette_dir = cls.dummy_dataset_dir / "sleep-cassette"
        cls.cassette_dir.mkdir(parents=True, exist_ok=True)

        # Fixture layout: spreadsheets, per-patient EDF placeholders, metadata CSV.
        cls._create_dummy_subject_spreadsheets()
        cls._create_dummy_patient_files()
        cls._create_dummy_metadata_csv()

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.temp_dir, ignore_errors=True)

    @classmethod
    def _create_dummy_subject_spreadsheets(cls):
        """Create required SC/ST files in both requested locations."""
        frame = pd.DataFrame(
            {
                "subject": [1, 2],
                "night": [1, 2],
                "age": [25, 30],
                "sex (F=1)": [1, 2],
                "LightsOff": ["22:00", "22:30"],
            }
        )
        # These files are only test placeholders; metadata is loaded from
        # sleepedf-cassette-pyhealth.csv created below.
        for directory in (cls.dummy_dataset_dir, cls.cassette_dir):
            for file_name in ("SC-subjects.xls", "ST-subjects.xls"):
                frame.to_csv(directory / file_name, index=False)

    @classmethod
    def _create_dummy_patient_files(cls):
        """Create two patients with 2 x 3000 dummy points each."""
        # The metadata CSV written below references exactly these PSG/Hypnogram
        # pairs, so SleepEDFDataset can load events and MVCLTrainingSleepEEG can
        # read per-patient signal/labels. The payload is a raw float32 dump that
        # _mock_read_raw_edf knows how to decode.
        for record_stem, fill in (("SC4011E0", 1), ("SC4022E0", 2)):
            payload = np.full(6000, fill, dtype=np.float32)
            payload.tofile(cls.cassette_dir / f"{record_stem}-PSG.edf")
            payload.tofile(cls.cassette_dir / f"{record_stem}-Hypnogram.edf")

    @classmethod
    def _create_dummy_metadata_csv(cls):
        # Columns: subject,night,age,sex,lights_off,signal_file,label_file
        subjects = [
            (1, 1, 25, "F", "22:00", "SC4011E0"),
            (2, 2, 30, "M", "22:30", "SC4022E0"),
        ]
        metadata_rows = [
            {
                "subject": subject,
                "night": night,
                "age": age,
                "sex": sex,
                "lights_off": lights_off,
                "signal_file": str(cls.cassette_dir / f"{stem}-PSG.edf"),
                "label_file": str(cls.cassette_dir / f"{stem}-Hypnogram.edf"),
            }
            for subject, night, age, sex, lights_off, stem in subjects
        ]
        pd.DataFrame(metadata_rows).to_csv(
            cls.dummy_dataset_dir / "sleepedf-cassette-pyhealth.csv", index=False
        )

    @staticmethod
    def _mock_read_raw_edf(signal_file, *args, **kwargs):
        """Load the binary dummy payload from .edf placeholder file."""
        raw_values = np.fromfile(signal_file, dtype=np.float32)
        if raw_values.size != 6000:
            raise ValueError(f"Expected 6000 points in {signal_file}, got {raw_values.size}")
        # One 100 Hz EEG channel, matching the channel name the task requests.
        info = mne.create_info(["EEG Fpz-Cz"], sfreq=100, ch_types=["eeg"])
        return mne.io.RawArray(raw_values.reshape(1, -1), info, verbose="error")

    @staticmethod
    def _mock_read_annotations(label_file, *args, **kwargs):
        """Return two 30-second sleep-stage annotations per patient."""
        file_name = Path(label_file).name
        if "SC4011E0" in file_name:
            stage_labels = ["Sleep stage W", "Sleep stage R"]
        elif "SC4022E0" in file_name:
            stage_labels = ["Sleep stage 2", "Sleep stage 4"]
        else:
            raise ValueError(f"Unexpected label file: {label_file}")
        return mne.Annotations(
            onset=[0.0, 30.0],
            duration=[30.0, 30.0],
            description=stage_labels,
        )

    def test_import_from_pyhealth_tasks(self):
        """Matches notebook usage: from pyhealth.tasks import MVCLTrainingSleepEEG."""
        self.assertTrue(callable(MVCLTrainingSleepEEG))

    def test_sleepedf_dummy_dataset_label_mapping(self):
        dataset = SleepEDFDataset(root=str(self.dummy_dataset_dir), subset="cassette")
        task = MVCLTrainingSleepEEG(
            window_size=200, crop_length=178, eeg_channel="EEG Fpz-Cz"
        )

        raw_patch = patch(
            "pyhealth.tasks.mvcl_training_sleepedf_task.mne.io.read_raw_edf",
            side_effect=self._mock_read_raw_edf,
        )
        annotations_patch = patch(
            "pyhealth.tasks.mvcl_training_sleepedf_task.mne.read_annotations",
            side_effect=self._mock_read_annotations,
        )
        with raw_patch, annotations_patch:
            sample_dataset = dataset.set_task(task, num_workers=1)

        # 2 patients x 2 epochs each x (3000 / 200) windows = 60 windows
        self.assertEqual(len(sample_dataset), 60)
        self.assertEqual(
            sample_dataset.input_schema,
            {"xt": "tensor", "xd": "tensor", "xf": "tensor"},
        )
        self.assertEqual(sample_dataset.output_schema, {"label": "multiclass"})

        sample = sample_dataset[0]
        # MV views are stored as [L, C] in the task; enforce equivalent 1x178 content.
        for view_key in ("xt", "xd", "xf"):
            self.assertIn(view_key, sample)
            self.assertEqual(sample[view_key].ndim, 2)
            self.assertIn(1, sample[view_key].shape)
            self.assertIn(178, sample[view_key].shape)

        # set_task() encodes multiclass labels to contiguous ids, but class balance
        # should still match the four injected stages (W, R, 2, 4) => 15 windows each.
        label_counts = Counter(int(item["label"]) for item in sample_dataset)
        self.assertEqual(len(label_counts), 4)
        self.assertEqual(set(label_counts.values()), {15})


# Allow running this test module directly (e.g. `python test_mvcl_training_sleepedf_task.py`)
# instead of through a pytest/unittest discovery runner.
if __name__ == "__main__":
    unittest.main()