In [31]:
import os

import numpy as np
import pandas as pd
from tqdm import tqdm
import mne
from sklearn.utils import check_random_state


from braindecode.datasets import BaseConcatDataset, BaseDataset


path = '/scratch/hmc/recordings'
files = os.listdir(path)

window_size_s = 30
sfreq = 100
window_size_samples = window_size_s * sfreq


In [7]:
data, edf = [], []

for i in range(1, 155):
    current_file = f'SN{i:03d}.edf'
    if  current_file in files:
        data.append(os.path.join(path, current_file))
        edf.append(os.path.join(path, f'SN{i:03d}_sleepscoring.edf'))
        


In [8]:
raw = mne.io.read_raw_edf(data[0])
annots = mne.read_annotations(edf[0])
raw.set_annotations(annots)

Extracting EDF parameters from /scratch/hmc/recordings/SN001.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


0,1
Measurement date,"January 01, 2001 23:59:30 GMT"
Experimenter,Unknown
Digitized points,Not available
Good channels,8 EEG
Bad channels,
EOG channels,Not available
ECG channels,Not available
Sampling frequency,256.00 Hz
Highpass,0.20 Hz
Lowpass,35.00 Hz


In [143]:
class HMCSleepStaging(BaseConcatDataset):
    def __init__(
        self,
        hmc_path='/scratch/hmc/recordings',
        subject_ids=None,
        preload=True,
        load_modality='eeg',
        crop_wake_mins=0,
        crop=None,
    ):
        if subject_ids is None:
            subject_ids = range(1, 155)
        
        self.raw_files, self.edf_files = [], []       
        self._fetch_data(subject_ids, hmc_path)

        all_base_ds = list()
        for raw_fname, ann_fname in zip(self.raw_files, self.edf_files):
            raw, desc = self._load_raw(
                raw_fname,
                ann_fname,
                preload=preload,
                load_modality=load_modality,
                crop_wake_mins=crop_wake_mins,
                crop=crop
            )
            base_ds = BaseDataset(raw, desc)
            all_base_ds.append(base_ds)
        super().__init__(all_base_ds)
    
    def _fetch_data(
        self,
        subject_ids,
        hmc_path,
    ):
        hmc_files = os.listdir(hmc_path)
        for subject in subject_ids:
            current_file = f'SN{subject:03d}.edf'
            if  current_file in hmc_files:
                self.raw_files.append(os.path.join(hmc_path, current_file))
                self.edf_files.append(os.path.join(hmc_path, f'SN{subject:03d}_sleepscoring.edf'))        
    
    @staticmethod
    def _load_raw(
        raw_fname,
        ann_fname,
        preload,
        load_modality,
        crop_wake_mins,
        crop,
    ):
        ch_mapping = ['EEG F4-M1',
             'EEG C4-M1',
             'EEG O2-M1',
             'EEG C3-M2',
             'EMG chin',
             'EOG E1-M2',
             'EOG E2-M2',
             'ECG',
         ]
        if load_modality == 'eeg':
            include = ch_mapping[:4]
        elif load_modality == 'eog':
            include = ch_mapping[5:7]
        elif load_modality == 'emg':
            include = ch_mapping[4:5]
        elif load_modality == 'ecg':
            include = ch_mapping[7:]
        elif load_modality == 'emog':
            include = ch_mapping[4:7]
        else:
            raise Exception("Select a modality")

        raw = mne.io.read_raw_edf(raw_fname, preload=preload, include=include)
        annots = mne.read_annotations(ann_fname)
        raw.set_annotations(annots, emit_warning=False)
        raw.resample(100, npad="auto")

        if crop_wake_mins > 0:
            # Find first and last sleep stages
            mask = [x[-1] in ["1", "2", "3", "R"] for x in annots.description]
            sleep_event_inds = np.where(mask)[0]

            # Crop raw
            tmin = annots[int(sleep_event_inds[0])]["onset"] - crop_wake_mins * 60
            tmax = annots[int(sleep_event_inds[-1])]["onset"] + crop_wake_mins * 60
            raw.crop(tmin=max(tmin, raw.times[0]), tmax=min(tmax, raw.times[-1]))

        if crop is not None:
            raw.crop(*crop)

        basename = os.path.basename(raw_fname)
        subj_nb = int(basename[2:5])
        desc = pd.Series({"subject": subj_nb,}, name="")

        return raw, desc
subject_ids = range(1, 10)
dataset = HMCSleepStaging(subject_ids=subject_ids)

mapping = {  
    "Sleep stage W": 0,
    "Sleep stage N1": 1,
    "Sleep stage N2": 2,
    "Sleep stage N3": 3,
    "Sleep stage R": 4,
}

from braindecode.datautil.windowers import create_windows_from_events, create_fixed_length_windows

windows_dataset = create_windows_from_events(
    dataset,
    window_size_samples=window_size_samples,
    window_stride_samples=window_size_samples,
    preload= True,
    mapping=mapping,
    n_jobs=-1
)

Extracting EDF parameters from /scratch/hmc/recordings/SN001.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 6566399  =      0.000 ... 25649.996 secs...
Extracting EDF parameters from /scratch/hmc/recordings/SN002.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 6577407  =      0.000 ... 25692.996 secs...
Extracting EDF parameters from /scratch/hmc/recordings/SN003.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 7330815  =      0.000 ... 28635.996 secs...
Extracting EDF parameters from /scratch/hmc/recordings/SN004.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 7804927  =      0.000 ... 30487.996 secs...
Extracting EDF parameters from /scratch/hmc/recordings/SN005.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ...

In [146]:
for windows_subject in windows_dataset.datasets:
    print(windows_subject)
    break

(array([[ 1.7087381e-05,  1.4344381e-05,  1.1525641e-06, ...,
         1.2500958e-05,  2.5288666e-05,  2.6783413e-05],
       [ 1.4664362e-05,  7.3353440e-06,  2.4618416e-06, ...,
         4.4581006e-07,  1.2948366e-05,  1.7326725e-05],
       [ 3.8281894e-05,  2.0248757e-05,  2.4458244e-05, ...,
         7.9606389e-06,  1.8057051e-05,  2.1296686e-05],
       [ 5.8028559e-06,  6.9939942e-06, -2.9917583e-06, ...,
        -1.0429584e-05, -8.3001996e-06, -1.2006130e-06]], dtype=float32), 0, [0, 0, 3000])


In [174]:
a = windows_dataset.datasets[0]


for i in a.windows:
    print(i.shape)
    break

(4, 3000)


In [197]:
df = pd.DataFrame(columns = ['a', 'b'])

for i in range(10):
    df = df.append({'a':10*i, 'b': 8*i}, ignore_index=True)
df

  df = df.append({'a':10*i, 'b': 8*i}, ignore_index=True)
  df = df.append({'a':10*i, 'b': 8*i}, ignore_index=True)
  df = df.append({'a':10*i, 'b': 8*i}, ignore_index=True)
  df = df.append({'a':10*i, 'b': 8*i}, ignore_index=True)
  df = df.append({'a':10*i, 'b': 8*i}, ignore_index=True)
  df = df.append({'a':10*i, 'b': 8*i}, ignore_index=True)
  df = df.append({'a':10*i, 'b': 8*i}, ignore_index=True)
  df = df.append({'a':10*i, 'b': 8*i}, ignore_index=True)
  df = df.append({'a':10*i, 'b': 8*i}, ignore_index=True)
  df = df.append({'a':10*i, 'b': 8*i}, ignore_index=True)


Unnamed: 0,a,b
0,0,0
1,10,8
2,20,16
3,30,24
4,40,32
5,50,40
6,60,48
7,70,56
8,80,64
9,90,72


In [116]:
from torch.utils.data import DataLoader, Dataset
from braindecode.datautil.windowers import create_windows_from_events, create_fixed_length_windows

windows_dataset = create_fixed_length_windows(
    dataset,
    window_size_samples=window_size_samples * 7,
    window_stride_samples=window_size_samples,
    drop_last_window=True,
    # mapping=mapping,
)


dl = DataLoader(
    windows_dataset,
    batch_size=4,
    shuffle=False,
)

Adding metadata with 4 columns
849 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 849 events and 21000 original time points ...
0 bad epochs dropped
Adding metadata with 4 columns
850 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 850 events and 21000 original time points ...
0 bad epochs dropped
Adding metadata with 4 columns
948 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 948 events and 21000 original time points ...
0 bad epochs dropped
Adding metadata with 4 columns
1010 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 1010 events and 21000 original time points ...
0 bad epochs dropped
Adding metadata with 4 columns
953 matching events found
No baseline correction applied
0 projection items activated
Using dat

In [118]:
windows_dataset

Using data from preloaded Raw for 1 events and 21000 original time points ...


(4, 21000)

In [141]:
class test(Dataset):
    def __init__(self):
        self.x =0
    def __getitem__(self, index):
        return index
    def __len__(self):
        return 16*10

dll = DataLoader(test(), batch_size=16, shuffle=True)

a = []
for i in dll:
    a.extend(list(i.tolist()))
    
a.sort()
a

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 105,
 106,
 107,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 125,
 126,
 127,
 128,
 129,
 130,
 131,
 132,
 133,
 134,
 135,
 136,
 137,
 138,
 139,
 140,
 141,
 142,
 143,
 144,
 145,
 146,
 147,
 148,
 149,
 150,
 151,
 152,
 153,
 154,
 155,
 156,
 157,
 158,
 159]

In [102]:
annots.description

array(['Sleep stage W', 'Sleep stage W', 'Lights off@@EEG F4-A1',
       'Sleep stage W', 'Sleep stage W', 'Sleep stage W', 'Sleep stage W',
       'Sleep stage W', 'Sleep stage W', 'Sleep stage N1',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage N1',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage N1',
       'Sleep stage N1', 'Sleep stage N2', 'Sleep stage N1',
       'Sleep stage N2', 'Sleep stage N2', 'Sleep stage N2',
       'Sleep stage N2', 'Sleep stage N2', 'Sleep stage N2',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage W',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage N1',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage N1',
       'Sleep stage N2', 'Sleep stage N1', 'Sleep stage W',
       'Sleep stage W', 'Sleep stage W', 'Sleep stage W',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage N1',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage N1',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage N1',
       'Slee