## Sleep stage classification from polysomnography (PSG) data

In [1]:
import numpy as np
import matplotlib.pyplot as plt

import mne
from mne.datasets.sleep_physionet.age import fetch_data

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer

from utils import load_psg_samples, DATA_READERS, eegs_power_band

## Load PSG data

In [2]:
# Which recordings to use: subject ids 0 .. n_subjects-1, recording 1 only.
n_subjects = 10
subject_ids = [*range(n_subjects)]

# Loader configuration: file format (selects the reader) and channel picks.
data_format, channel = "edf", "eeg"

sleep_data_files = fetch_data(subjects=subject_ids, recording=[1])

# NOTE(review): the epoching cell unpacks each item of `raw` as
# (raw recording, annotations) — confirm against utils.load_psg_samples.
raw = load_psg_samples(
    sleep_data_files,
    DATA_READERS[data_format],
    picks=channel,
)

Using default location ~/mne_data for PHYSIONET_SLEEP...
Extracting EDF parameters from /Users/maxverbiest/mne_data/physionet-sleep-data/SC4001E0-PSG.edf...
EDF file detected
Channel 'EEG Fpz-Cz' recognized as type EEG (renamed to 'Fpz-Cz').
Channel 'EEG Pz-Oz' recognized as type EEG (renamed to 'Pz-Oz').
Channel 'EOG horizontal' recognized as type EOG (renamed to 'horizontal').
Channel 'Resp oro-nasal' recognized as type RESP (renamed to 'oro-nasal').
Channel 'EMG submental' recognized as type EMG (renamed to 'submental').
Channel 'Temp rectal' recognized as type TEMP (renamed to 'rectal').
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 7949999  =      0.000 ... 79499.990 secs...
Extracting EDF parameters from /Users/maxverbiest/mne_data/physionet-sleep-data/SC4011E0-PSG.edf...
EDF file detected
Channel 'EEG Fpz-Cz' recognized as type EEG (renamed to 'Fpz-Cz').
Channel 'EEG Pz-Oz' recognized as type EEG (renamed to 'Pz-Oz').
Channel 'EOG horizontal' rec

In [3]:
# Hypnogram annotation description -> integer class id. Stages 3 and 4 both
# map to 3, i.e. they are merged into the single "3/4" class of `event_id`.
annotation_desc_2_event_id = {
    "Sleep stage W": 0,
    "Sleep stage 1": 1,
    "Sleep stage 2": 2,
    "Sleep stage 3": 3,
    "Sleep stage 4": 3,
    "Sleep stage R": 4,
}
# Class name -> class id used for the Epochs (and for reporting later on).
event_id = {
    "Sleep stage W": 0,
    "Sleep stage 1": 1,
    "Sleep stage 2": 2,
    "Sleep stage 3/4": 3,
    "Sleep stage R": 4,
}

WAKE_MARGIN_S = 30 * 60  # wake kept before sleep onset / after final awakening (s)
EPOCH_LENGTH_S = 30.0    # length of one scored epoch (s)

epochs = []
# keep last 30-min wake events before sleep and first 30-min wake events after
# sleep and redefine annotations on raw data
for recording, annotations_full in raw:
    # NOTE(review): assumes annotation [1] is the first sleep-stage annotation
    # and [-2] the final wake period (as in the MNE sleep tutorial) — confirm
    # this holds for every subject.
    # Crop a *copy*: cropping the loaded annotations in place would re-crop
    # already-cropped annotations every time this cell is re-run.
    annotations = annotations_full.copy().crop(
        annotations_full[1]["onset"] - WAKE_MARGIN_S,
        annotations_full[-2]["onset"] + WAKE_MARGIN_S,
    )
    recording.set_annotations(annotations, emit_warning=False)
    # Split each annotation into consecutive fixed-length chunks -> one event each.
    events, _ = mne.events_from_annotations(
        recording, event_id=annotation_desc_2_event_id, chunk_duration=EPOCH_LENGTH_S
    )

    tmax = EPOCH_LENGTH_S - 1.0 / recording.info["sfreq"]  # tmax is included
    epochs.append(
        mne.Epochs(
            raw=recording,
            events=events,
            event_id=event_id,
            tmin=0.0,
            tmax=tmax,
            baseline=None,
            # A subject may lack a stage entirely; warn instead of aborting.
            on_missing="warn",
        )
    )

# `raw` is intentionally NOT deleted: these Epochs are lazily loaded
# ("data not loaded") and read their samples from the Raw objects later, so
# `del raw` freed no signal data and only made this cell impossible to re-run.
epochs

Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']
Not setting metadata
841 matching events found
No baseline correction applied
0 projection items activated
Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']
Not setting metadata
1103 matching events found
No baseline correction applied
0 projection items activated
Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']
Not setting metadata
1025 matching events found
No baseline correction applied
0 projection items activated
Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']
Not setting metadata
952 matching events found
No baseline correction applied
0 projection items activated
Used Annotations descriptions: ['S

[<Epochs |  841 events (good & bad), 0 – 29.99 s, baseline off, ~8 kB, data not loaded,
  'Sleep stage W': 188
  'Sleep stage 1': 58
  'Sleep stage 2': 250
  'Sleep stage 3/4': 220
  'Sleep stage R': 125>,
 <Epochs |  1103 events (good & bad), 0 – 29.99 s, baseline off, ~8 kB, data not loaded,
  'Sleep stage W': 157
  'Sleep stage 1': 109
  'Sleep stage 2': 562
  'Sleep stage 3/4': 105
  'Sleep stage R': 170>,
 <Epochs |  1025 events (good & bad), 0 – 29.99 s, baseline off, ~8 kB, data not loaded,
  'Sleep stage W': 128
  'Sleep stage 1': 94
  'Sleep stage 2': 545
  'Sleep stage 3/4': 95
  'Sleep stage R': 163>,
 <Epochs |  952 events (good & bad), 0 – 29.99 s, baseline off, ~8 kB, data not loaded,
  'Sleep stage W': 140
  'Sleep stage 1': 61
  'Sleep stage 2': 485
  'Sleep stage 3/4': 57
  'Sleep stage R': 209>,
 <Epochs |  1235 events (good & bad), 0 – 29.99 s, baseline off, ~8 kB, data not loaded,
  'Sleep stage W': 200
  'Sleep stage 1': 166
  'Sleep stage 2': 620
  'Sleep stage 3/

In [9]:
# Feature extraction (utils.eegs_power_band, presumably one power-band feature
# row per epoch — confirm in utils) followed by a seeded random forest.
pipe = make_pipeline(
    FunctionTransformer(eegs_power_band, validate=False),
    RandomForestClassifier(n_estimators=100, random_state=42),
)

# Subject-wise hold-out: the first recording is the test set, the rest train.
train_epochs = epochs[1:]
test_epochs = epochs[0]

# The Epochs are lazily loaded, so bad/truncated epochs are only dropped on the
# first data access (i.e. inside fit/predict), which shrinks `.events`.
# Resolve drops up front so the labels read below line up with the feature rows.
for e in epochs:
    e.drop_bad()

# Train
y_train = np.concatenate([e.events[:, 2] for e in train_epochs])
pipe.fit(train_epochs, y_train)

# Test
y_pred = pipe.predict([test_epochs])

# Assess the results
y_test = test_epochs.events[:, 2]
acc = accuracy_score(y_test, y_pred)

print(f"Accuracy score: {acc}")

Using data from preloaded Raw for 1103 events and 3000 original time points ...
0 bad epochs dropped
    Using multitaper spectrum estimation with 7 DPSS windows
Using data from preloaded Raw for 1025 events and 3000 original time points ...
0 bad epochs dropped
    Using multitaper spectrum estimation with 7 DPSS windows
Using data from preloaded Raw for 952 events and 3000 original time points ...
0 bad epochs dropped
    Using multitaper spectrum estimation with 7 DPSS windows
Using data from preloaded Raw for 1235 events and 3000 original time points ...
0 bad epochs dropped
    Using multitaper spectrum estimation with 7 DPSS windows
Using data from preloaded Raw for 672 events and 3000 original time points ...
0 bad epochs dropped
    Using multitaper spectrum estimation with 7 DPSS windows
Using data from preloaded Raw for 843 events and 3000 original time points ...
0 bad epochs dropped
    Using multitaper spectrum estimation with 7 DPSS windows
Using data from preloaded Raw f

In [10]:
# Rows = true stage, columns = predicted stage, both in `event_id` order.
# Passing `labels` keeps the matrix 5x5 and in that fixed order even when a
# stage is absent from both y_test and y_pred for the held-out subject.
confusion_matrix(y_test, y_pred, labels=list(event_id.values()))

array([[131,  26,   5,   1,  25],
       [  1,   8,   9,   0,  40],
       [  1,   1, 209,  36,   3],
       [  0,   0,  19, 201,   0],
       [  1,  14,   7,   0, 103]])

In [11]:
# Per-stage precision / recall / F1 on the held-out subject. `labels` pins the
# class order to `event_id`, so `target_names` always lines up; without it the
# report raises a size mismatch if the held-out recording lacks a stage.
print(
    classification_report(
        y_test,
        y_pred,
        labels=list(event_id.values()),
        target_names=list(event_id.keys()),
    )
)

                 precision    recall  f1-score   support

  Sleep stage W       0.98      0.70      0.81       188
  Sleep stage 1       0.16      0.14      0.15        58
  Sleep stage 2       0.84      0.84      0.84       250
Sleep stage 3/4       0.84      0.91      0.88       220
  Sleep stage R       0.60      0.82      0.70       125

       accuracy                           0.78       841
      macro avg       0.69      0.68      0.67       841
   weighted avg       0.79      0.78      0.77       841



In [12]:
# Support per class id in the held-out recording (ids as in `event_id`).
stage_ids, stage_counts = np.unique(y_test, return_counts=True)
stage_ids, stage_counts

(array([0, 1, 2, 3, 4]), array([188,  58, 250, 220, 125]))