In [31]:
import os

import numpy as np
import pandas as pd
from tqdm import tqdm
import mne
from sklearn.utils import check_random_state


from braindecode.datasets import BaseConcatDataset, BaseDataset


path = '/scratch/hmc/recordings'
files = os.listdir(path)

window_size_s = 30
sfreq = 100
window_size_samples = window_size_s * sfreq


In [7]:
data, edf = [], []

for i in range(1, 155):
    current_file = f'SN{i:03d}.edf'
    if  current_file in files:
        data.append(os.path.join(path, current_file))
        edf.append(os.path.join(path, f'SN{i:03d}_sleepscoring.edf'))
        


In [8]:
raw = mne.io.read_raw_edf(data[0])
annots = mne.read_annotations(edf[0])
raw.set_annotations(annots)

Extracting EDF parameters from /scratch/hmc/recordings/SN001.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


0,1
Measurement date,"January 01, 2001 23:59:30 GMT"
Experimenter,Unknown
Digitized points,Not available
Good channels,8 EEG
Bad channels,
EOG channels,Not available
ECG channels,Not available
Sampling frequency,256.00 Hz
Highpass,0.20 Hz
Lowpass,35.00 Hz


In [103]:
class HMCSleepStaging(BaseConcatDataset):
    def __init__(
        self,
        hmc_path='/scratch/hmc/recordings',
        subject_ids=None,
        preload=True,
        load_modality='eeg',
        crop_wake_mins=0,
        crop=None,
    ):
        if subject_ids is None:
            subject_ids = range(1, 155)
        
        self.raw_files, self.edf_files = [], []       
        self._fetch_data(subject_ids, hmc_path)

        all_base_ds = list()
        for raw_fname, ann_fname in zip(self.raw_files, self.edf_files):
            raw, desc = self._load_raw(
                raw_fname,
                ann_fname,
                preload=preload,
                load_modality=load_modality,
                crop_wake_mins=crop_wake_mins,
                crop=crop
            )
            base_ds = BaseDataset(raw, desc)
            all_base_ds.append(base_ds)
        super().__init__(all_base_ds)
    
    def _fetch_data(
        self,
        subject_ids,
        hmc_path,
    ):
        hmc_files = os.listdir(hmc_path)
        for subject in subject_ids:
            current_file = f'SN{subject:03d}.edf'
            if  current_file in hmc_files:
                self.raw_files.append(os.path.join(hmc_path, current_file))
                self.edf_files.append(os.path.join(hmc_path, f'SN{subject:03d}_sleepscoring.edf'))        
    
    @staticmethod
    def _load_raw(
        raw_fname,
        ann_fname,
        preload,
        load_modality,
        crop_wake_mins,
        crop,
    ):
        ch_mapping = ['EEG F4-M1',
             'EEG C4-M1',
             'EEG O2-M1',
             'EEG C3-M2',
             'EMG chin',
             'EOG E1-M2',
             'EOG E2-M2',
             'ECG',
         ]
        if load_modality == 'eeg':
            include = ch_mapping[:4]
        elif load_modality == 'eog':
            include = ch_mapping[5:7]
        elif load_modality == 'emg':
            include = ch_mapping[4:5]
        elif load_modality == 'ecg':
            include = ch_mapping[7:]
        elif load_modality == 'emog':
            include = ch_mapping[4:7]
        else:
            raise Exception("Select a modality")

        raw = mne.io.read_raw_edf(raw_fname, preload=preload, include=include)
        annots = mne.read_annotations(ann_fname)
        raw.set_annotations(annots, emit_warning=False)
        raw.resample(100, npad="auto")

        if crop_wake_mins > 0:
            # Find first and last sleep stages
            mask = [x[-1] in ["1", "2", "3", "R"] for x in annots.description]
            sleep_event_inds = np.where(mask)[0]

            # Crop raw
            tmin = annots[int(sleep_event_inds[0])]["onset"] - crop_wake_mins * 60
            tmax = annots[int(sleep_event_inds[-1])]["onset"] + crop_wake_mins * 60
            raw.crop(tmin=max(tmin, raw.times[0]), tmax=min(tmax, raw.times[-1]))

        if crop is not None:
            raw.crop(*crop)

        basename = os.path.basename(raw_fname)
        subj_nb = int(basename[2:5])
        desc = pd.Series({"subject": subj_nb,}, name="")

        return raw, desc
subject_ids = range(1, 10)
dataset = HMCSleepStaging(subject_ids=subject_ids)

mapping = {  
    "Sleep stage W": 0,
    "Sleep stage N1": 1,
    "Sleep stage N2": 2,
    "Sleep stage N3": 3,
    "Sleep stage R": 4,
}

from braindecode.datautil.windowers import create_windows_from_events, create_fixed_length_windows

windows_dataset = create_windows_from_events(
    dataset,
    window_size_samples=window_size_samples * 7,
    window_stride_samples=window_size_samples,
    drop_last_window=True,
    preload= True,
    mapping=mapping,
    n_jobs=-1
)

Extracting EDF parameters from /scratch/hmc/recordings/SN001.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 6566399  =      0.000 ... 25649.996 secs...
Extracting EDF parameters from /scratch/hmc/recordings/SN002.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 6577407  =      0.000 ... 25692.996 secs...
Extracting EDF parameters from /scratch/hmc/recordings/SN003.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 7330815  =      0.000 ... 28635.996 secs...
Extracting EDF parameters from /scratch/hmc/recordings/SN004.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 7804927  =      0.000 ... 30487.996 secs...
Extracting EDF parameters from /scratch/hmc/recordings/SN005.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ...

ValueError: Window size 21000 exceeds trial duration (3000) for too many trials (100.0%). Set accepted_bads_ratio to at least 1.0and restart training to be able to continue.

In [109]:
from torch.utils.data import DataLoader, Dataset
from braindecode.datautil.windowers import create_windows_from_events, create_fixed_length_windows

windows_dataset = create_fixed_length_windows(
    dataset,
    window_size_samples=window_size_samples * 7,
    window_stride_samples=window_size_samples,
    drop_last_window=True,
    preload= True,
    mapping=mapping,
)


dl = DataLoader(
    windows_dataset,
    batch_size=4,
    shuffle=False,
)

KeyError: -1

In [89]:
windows_dataset[5]

(array([[-5.8949404e-06,  4.3617911e-06,  2.4263863e-06, ...,
          3.2385326e-06,  7.2328812e-06,  7.9230222e-06],
        [-6.9793809e-06, -4.4613449e-07,  4.7712369e-06, ...,
          1.6334093e-07,  3.6747072e-06,  5.0347580e-06],
        [ 1.5452737e-06,  1.1963056e-05,  1.4000091e-05, ...,
          4.1262833e-06,  6.9388793e-06,  9.6808590e-06],
        [-2.4760914e-06,  1.5252542e-06,  2.5595805e-06, ...,
          8.8233528e-06,  1.5575761e-05,  1.4216160e-05]], dtype=float32),
 0,
 [0, 15000, 18000])

In [97]:
windows_dataset.datasets[0].y[:25]

[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 2, 2, 1]

In [102]:
annots.description

array(['Sleep stage W', 'Sleep stage W', 'Lights off@@EEG F4-A1',
       'Sleep stage W', 'Sleep stage W', 'Sleep stage W', 'Sleep stage W',
       'Sleep stage W', 'Sleep stage W', 'Sleep stage N1',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage N1',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage N1',
       'Sleep stage N1', 'Sleep stage N2', 'Sleep stage N1',
       'Sleep stage N2', 'Sleep stage N2', 'Sleep stage N2',
       'Sleep stage N2', 'Sleep stage N2', 'Sleep stage N2',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage W',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage N1',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage N1',
       'Sleep stage N2', 'Sleep stage N1', 'Sleep stage W',
       'Sleep stage W', 'Sleep stage W', 'Sleep stage W',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage N1',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage N1',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage N1',
       'Slee