# Dataset and DataLoader

This notebook loads the `CAUEEG` dataset, tests several useful preprocessing transforms, and builds the PyTorch DataLoader instances used for training.

-----

## Configurations

In [1]:
# Auto-reload external modules (e.g. the local `datasets` package) so that edits
# to them are picked up without restarting the kernel.
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [2]:
# Load some packages
import os
import glob
import json
import pprint

import numpy as np
import random
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

# custom package
from datasets.caueeg_dataset import *
from datasets.caueeg_script import *
from datasets.pipeline import *

In [3]:
# Report the PyTorch build and select the compute device:
# the GPU when CUDA is present, otherwise the CPU.
print('PyTorch version:', torch.__version__)
has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if has_cuda else 'cpu')
print('cuda is available.' if has_cuda else 'cuda is unavailable.')

PyTorch version: 1.11.0+cu113
cuda is available.


In [4]:
# Data file path
data_path = r'local/dataset/02_Curated_Data_220412/'

# Global random seed. The random crops and the shuffling DataLoaders below are stochastic,
# so seed Python's `random`, NumPy and PyTorch to make a Restart & Run All reproducible.
# NOTE(review): which of these generators the custom transforms actually draw from is
# not visible here, so all three are seeded.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

-----

## Load the CAUEEG dataset

### Load the whole CAUEEG dataset as a single PyTorch dataset instance, independent of any target task (no train/val/test split and no class labels).

In [5]:
# Load every recording as one task-agnostic dataset, reading the raw EDF files.
config_data, full_eeg_dataset = load_caueeg_full_dataset(
    dataset_path=data_path,
    load_event=False,
    file_format='edf',
    transform=None,
)

pprint.pprint(config_data, width=250)
print('\n', '-' * 100, '\n')

# Show the first two recordings, each followed by a separator line.
for sample_index in (0, 1):
    pprint.pprint(full_eeg_dataset[sample_index])
    print('\n', '-' * 100, '\n')

{'dataset_name': 'CAUEEG dataset',
 'signal_header': ['Fp1-AVG', 'F3-AVG', 'C3-AVG', 'P3-AVG', 'O1-AVG', 'Fp2-AVG', 'F4-AVG', 'C4-AVG', 'P4-AVG', 'O2-AVG', 'F7-AVG', 'T3-AVG', 'T5-AVG', 'F8-AVG', 'T4-AVG', 'T6-AVG', 'FZ-AVG', 'CZ-AVG', 'PZ-AVG', 'EKG', 'Photic']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 78,
 'serial': '00001',
 'signal': array([[  0., -11., -13., ...,   0.,   0.,   0.],
       [ 29.,  33.,  34., ...,   0.,   0.,   0.],
       [ -3.,  -6.,  -3., ...,   0.,   0.,   0.],
       ...,
       [ -4.,  -2.,   1., ...,   0.,   0.,   0.],
       [112.,  67.,  76., ...,   0.,   0.,   0.],
       [ -1.,  -1.,  -1., ...,   0.,   0.,   0.]]),
 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 56,
 'serial': '00002',
 'signal': array([[  39.,   58.,   72., ...,    0.,    0.,    0.],
       [   4.,

### Load the CAUEEG task1 datasets as PyTorch dataset instances.

In [6]:
# Load the task1 (Normal / MCI / Dementia) train, validation and test splits from EDF files.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task1',
    load_event=False,
    file_format='edf',
    transform=None,
)
pprint.pprint(config_data)

# Print the first sample of each split, each preceded by a separator line.
for split_dataset in (train_dataset, val_dataset, test_dataset):
    print('\n', '-' * 100, '\n')
    pprint.pprint(split_dataset[0])

{'class_label_to_name': ['Normal', 'MCI', 'Dementia'],
 'class_name_to_label': {'Dementia': 2, 'MCI': 1, 'Normal': 0},
 'task_description': 'Classification of [Normal], [MCI], and [Dementia] '
                     'symptoms.',
 'task_name': 'CAUEEG-task1 benchmark'}

 ---------------------------------------------------------------------------------------------------- 

{'age': 81,
 'class_label': 2,
 'class_name': 'Dementia',
 'serial': '01222',
 'signal': array([[  1.,   1.,  -1., ...,   0.,   0.,   0.],
       [-24., -14., -14., ...,   0.,   0.,   0.],
       [ -6.,  -5.,  -6., ...,   0.,   0.,   0.],
       ...,
       [ -1.,   6.,   6., ...,   0.,   0.,   0.],
       [-42., -62., -59., ...,   0.,   0.,   0.],
       [  0.,   0.,  -1., ...,   0.,   0.,   0.]]),
 'symptom': ['dementia', 'ad', 'load']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 74,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00857',
 'si

### Load the CAUEEG task2 datasets as PyTorch dataset instances.

In [7]:
# Load the task2 (Normal / Abnormal) train, validation and test splits from EDF files.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task2',
    load_event=False,
    file_format='edf',
    transform=None,
)
pprint.pprint(config_data)

# Print the first sample of each split, each preceded by a separator line.
for split_dataset in (train_dataset, val_dataset, test_dataset):
    print('\n', '-' * 100, '\n')
    pprint.pprint(split_dataset[0])

{'class_label_to_name': ['Normal', 'Abnormal'],
 'class_name_to_label': {'Abnormal': 1, 'Normal': 0},
 'task_description': 'Classification of [Normal] and [Abnormal] symptoms',
 'task_name': 'CAUEEG-task2 benchmark'}

 ---------------------------------------------------------------------------------------------------- 

{'age': 66,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00085',
 'signal': array([[ 39.,  87.,  85., ...,   0.,   0.,   0.],
       [ 17.,  64.,  67., ...,   0.,   0.,   0.],
       [ -1.,   9.,  12., ...,   0.,   0.,   0.],
       ...,
       [ 10.,   3.,   3., ...,   0.,   0.,   0.],
       [ 30.,  46., -30., ...,   0.,   0.,   0.],
       [  2.,   2.,   1., ...,   0.,   0.,   0.]]),
 'symptom': ['normal', 'cb_normal']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 79,
 'class_label': 1,
 'class_name': 'Abnormal',
 'serial': '01203',
 'signal': array([[117., 165., 163., ...,   0.,   0.,   0

### Event information

In [8]:
# Reload task1 with load_event=True so each sample also carries an 'event' list.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task1',
    load_event=True,
    file_format='edf',
    transform=None,
)
pprint.pprint(train_dataset[0])

{'age': 81,
 'class_label': 2,
 'class_name': 'Dementia',
 'event': [[0, 'Start Recording'],
           [0, 'New Montage - Montage 002'],
           [13926, 'Eyes Open'],
           [14562, 'Eyes Closed'],
           [16000, 'Paused'],
           [18000, 'Recording Resumed'],
           [20984, 'Eyes Open'],
           [21898, 'Eyes Closed'],
           [23258, 'Eyes Open'],
           [23866, 'Eyes Closed'],
           [38064, 'Eyes Open'],
           [53800, 'Paused'],
           [55600, 'Recording Resumed'],
           [62090, 'artifact'],
           [62800, 'Paused'],
           [64600, 'Recording Resumed'],
           [66332, 'Move'],
           [70290, 'Move'],
           [71122, 'Move'],
           [73030, 'Move'],
           [73762, 'Move'],
           [74600, 'Paused'],
           [86600, 'Recording Resumed'],
           [88970, 'swallowing'],
           [90020, 'Eyes Closed'],
           [91490, 'Eyes Open'],
           [92708, 'Eyes Closed'],
           [95124, 'Eyes Open'],

### Data Format: `EDF`

In [9]:
%%time
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='edf', 
                                                                                  transform=None)

print(train_dataset[0])
print(train_dataset[1])

{'signal': array([[ 39.,  87.,  85., ...,   0.,   0.,   0.],
       [ 17.,  64.,  67., ...,   0.,   0.,   0.],
       [ -1.,   9.,  12., ...,   0.,   0.,   0.],
       ...,
       [ 10.,   3.,   3., ...,   0.,   0.,   0.],
       [ 30.,  46., -30., ...,   0.,   0.,   0.],
       [  2.,   2.,   1., ...,   0.,   0.,   0.]]), 'serial': '00085', 'age': 66, 'symptom': ['normal', 'cb_normal'], 'class_name': 'Normal', 'class_label': 0}
{'signal': array([[  45.,   -7.,   -9., ...,    0.,    0.,    0.],
       [   1.,   -2.,  -11., ...,    0.,    0.,    0.],
       [   2.,   -6.,   -5., ...,    0.,    0.,    0.],
       ...,
       [  -4.,    2.,    4., ...,    0.,    0.,    0.],
       [ -71., -115.,  -98., ...,    0.,    0.,    0.],
       [   0.,    0.,   -1., ...,    0.,    0.,    0.]]), 'serial': '00790', 'age': 59, 'symptom': ['normal', 'cb_normal'], 'class_name': 'Normal', 'class_label': 0}
CPU times: total: 188 ms
Wall time: 188 ms


### Data Format: `PyArrow Feather`

In [10]:
%%time
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='feather', 
                                                                                  transform=None)

print(train_dataset[0])
print(train_dataset[1])

{'signal': array([[ 39,  87,  85, ..., -52, -49, -45],
       [ 17,  64,  67, ..., -62, -58, -56],
       [ -1,   9,  12, ...,  -8,  -7,  -6],
       ...,
       [ 10,   3,   3, ...,  -1,   0,   2],
       [ 30,  46, -30, ...,  -6, -17, -15],
       [  2,   2,   1, ...,   1,   0,   0]]), 'serial': '00085', 'age': 66, 'symptom': ['normal', 'cb_normal'], 'class_name': 'Normal', 'class_label': 0}
{'signal': array([[  45,   -7,   -9, ...,   15,   18,   16],
       [   1,   -2,  -11, ...,  -15,  -14,  -14],
       [   2,   -6,   -5, ...,   12,   10,   11],
       ...,
       [  -4,    2,    4, ...,   -2,   -1,   -1],
       [ -71, -115,  -98, ...,    3,    5,   -2],
       [   0,    0,   -1, ...,    1,   -1,    0]]), 'serial': '00790', 'age': 59, 'symptom': ['normal', 'cb_normal'], 'class_name': 'Normal', 'class_label': 0}
CPU times: total: 109 ms
Wall time: 17 ms


### Data Format: `NumPy Memmap`

In [11]:
%%time
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='memmap', 
                                                                                  transform=None)

print(train_dataset[0])
print(train_dataset[1])

{'signal': memmap([[ 39,  87,  85, ..., -52, -49, -45],
        [ 17,  64,  67, ..., -62, -58, -56],
        [ -1,   9,  12, ...,  -8,  -7,  -6],
        ...,
        [ 10,   3,   3, ...,  -1,   0,   2],
        [ 30,  46, -30, ...,  -6, -17, -15],
        [  2,   2,   1, ...,   1,   0,   0]]), 'serial': '00085', 'age': 66, 'symptom': ['normal', 'cb_normal'], 'class_name': 'Normal', 'class_label': 0}
{'signal': memmap([[  45,   -7,   -9, ...,   15,   18,   16],
        [   1,   -2,  -11, ...,  -15,  -14,  -14],
        [   2,   -6,   -5, ...,   12,   10,   11],
        ...,
        [  -4,    2,    4, ...,   -2,   -1,   -1],
        [ -71, -115,  -98, ...,    3,    5,   -2],
        [   0,    0,   -1, ...,    1,   -1,    0]]), 'serial': '00790', 'age': 59, 'symptom': ['normal', 'cb_normal'], 'class_name': 'Normal', 'class_label': 0}
CPU times: total: 0 ns
Wall time: 4 ms


---

## PyTorch Transforms

### Drop channel(s)

In [12]:
# Read the channel names from the dataset annotation file and locate the
# EKG and Photic channels, which the transforms below can drop.
anno_path = os.path.join(data_path, 'annotation.json')
# JSON is UTF-8 by specification; pass the encoding explicitly so decoding does not
# depend on the platform's default locale encoding (e.g. a non-UTF-8 Windows code page).
with open(anno_path, 'r', encoding='utf-8') as json_file:
    annotation = json.load(json_file)
signal_headers = annotation['signal_header']
del annotation  # only the channel names are needed from here on
print(signal_headers)

channel_ekg = signal_headers.index('EKG')
print('channel_ekg: ', channel_ekg)

channel_photic = signal_headers.index('Photic')
print('channel_photic: ', channel_photic)

['Fp1-AVG', 'F3-AVG', 'C3-AVG', 'P3-AVG', 'O1-AVG', 'Fp2-AVG', 'F4-AVG', 'C4-AVG', 'P4-AVG', 'O2-AVG', 'F7-AVG', 'T3-AVG', 'T5-AVG', 'F8-AVG', 'T4-AVG', 'T6-AVG', 'FZ-AVG', 'CZ-AVG', 'PZ-AVG', 'EKG', 'Photic']
channel_ekg:  19
channel_photic:  20


In [13]:
# Baseline: the first task1 training sample with no transform applied.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task1',
    load_event=False,
    file_format='feather',
    transform=None,
)
print('before:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

print('\n' + '-' * 100 + '\n')

# Same sample with the EKG channel removed (one fewer row).
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task1',
    load_event=False,
    file_format='feather',
    transform=EegDropChannels(channel_ekg),
)
print('after:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

before: (21, 156600)
[[  1   1  -1 ...   2   3   3]
 [-24 -14 -14 ... -12  -9  -8]
 [ -6  -5  -6 ...  -3  -1  -1]
 ...
 [ -1   6   6 ...  -1  -2  -2]
 [-42 -62 -59 ... -12 -14 -16]
 [  0   0  -1 ...  -1   0   0]]

----------------------------------------------------------------------------------------------------

after: (20, 156600)
[[  1   1  -1 ...   2   3   3]
 [-24 -14 -14 ... -12  -9  -8]
 [ -6  -5  -6 ...  -3  -1  -1]
 ...
 [ 14  22  23 ...  -5  -4  -3]
 [ -1   6   6 ...  -1  -2  -2]
 [  0   0  -1 ...  -1   0   0]]


In [14]:
# Baseline: the first task1 training sample with no transform applied.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task1',
    load_event=False,
    file_format='feather',
    transform=None,
)
print('before:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

print('\n' + '-' * 100 + '\n')

# Same sample with the Photic channel removed (one fewer row).
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task1',
    load_event=False,
    file_format='feather',
    transform=EegDropChannels(channel_photic),
)
print('after:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

before: (21, 156600)
[[  1   1  -1 ...   2   3   3]
 [-24 -14 -14 ... -12  -9  -8]
 [ -6  -5  -6 ...  -3  -1  -1]
 ...
 [ -1   6   6 ...  -1  -2  -2]
 [-42 -62 -59 ... -12 -14 -16]
 [  0   0  -1 ...  -1   0   0]]

----------------------------------------------------------------------------------------------------

after: (20, 156600)
[[  1   1  -1 ...   2   3   3]
 [-24 -14 -14 ... -12  -9  -8]
 [ -6  -5  -6 ...  -3  -1  -1]
 ...
 [ 14  22  23 ...  -5  -4  -3]
 [ -1   6   6 ...  -1  -2  -2]
 [-42 -62 -59 ... -12 -14 -16]]


In [15]:
# Baseline: the first task1 training sample with no transform applied.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task1',
    load_event=False,
    file_format='feather',
    transform=None,
)
print('before:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

print('\n' + '-' * 100 + '\n')

# Same sample with both the EKG and Photic channels removed (two fewer rows).
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task1',
    load_event=False,
    file_format='feather',
    transform=EegDropChannels([channel_ekg, channel_photic]),
)
print('after:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

before: (21, 156600)
[[  1   1  -1 ...   2   3   3]
 [-24 -14 -14 ... -12  -9  -8]
 [ -6  -5  -6 ...  -3  -1  -1]
 ...
 [ -1   6   6 ...  -1  -2  -2]
 [-42 -62 -59 ... -12 -14 -16]
 [  0   0  -1 ...  -1   0   0]]

----------------------------------------------------------------------------------------------------

after: (19, 156600)
[[  1   1  -1 ...   2   3   3]
 [-24 -14 -14 ... -12  -9  -8]
 [ -6  -5  -6 ...  -3  -1  -1]
 ...
 [  9  18  19 ...  -3  -2  -1]
 [ 14  22  23 ...  -5  -4  -3]
 [ -1   6   6 ...  -1  -2  -2]]


### Random crop

In [16]:
# Crop a random 100-sample window. The same sample is fetched twice to show that
# each access yields a different crop.
transform = EegRandomCrop(crop_length=100)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task2',
    load_event=False,
    file_format='feather',
    transform=transform,
)
for i in range(2):
    d = train_dataset[0]
    pprint.pprint(d)
    print('\n>>> signal shape:', d['signal'].shape)
    print('\n', '-' * 100, '\n')

{'age': 66,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00085',
 'signal': array([[-61, -60, -60, ..., -45, -46, -45],
       [ 17,  17,  18, ...,  -4,  -6,  -6],
       [  0,   1,   2, ...,   6,   6,   7],
       ...,
       [ -2,   0,   0, ...,   8,   9,   7],
       [ -6,  -2,   7, ...,  -8,   1,   5],
       [  0,   0,   0, ...,   2,   0,  -1]]),
 'symptom': ['normal', 'cb_normal']}

>>> signal shape: (21, 100)

 ---------------------------------------------------------------------------------------------------- 

{'age': 66,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00085',
 'signal': array([[-46, -48, -54, ...,  -2,  -3,  -2],
       [ -1,  -1,  -4, ..., -37, -38, -39],
       [-11, -10, -11, ...,   0,  -1,  -2],
       ...,
       [  8,   8,  11, ...,   8,   7,   8],
       [  8,   4,   5, ...,   4,  -7,  -3],
       [  0,   1,   2, ...,   0,   0,  -1]]),
 'symptom': ['normal', 'cb_normal']}

>>> signal shape: (21, 100)

 -------------------------------

### Random crop with multiple crops per sample

In [17]:
# Take two random 200-sample crops per access; 'signal' then becomes a list of crops.
transform = EegRandomCrop(crop_length=200, multiple=2)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task2',
    load_event=False,
    file_format='feather',
    transform=transform,
)
for i in range(2):
    d = train_dataset[0]
    pprint.pprint(d)
    print('\n>>> signal shape:', [crop.shape for crop in d['signal']])
    print('\n', '-' * 100, '\n')

{'age': 66,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00085',
 'signal': [array([[ 47,  48,  48, ...,  12,  12,  12],
       [ 54,  55,  56, ...,  21,  19,  19],
       [  6,   6,   7, ...,   8,   6,   4],
       ...,
       [-19, -18, -17, ...,  -1,  -2,  -4],
       [ -9,   0,   5, ...,  -5,   6,   5],
       [ -1,  -1,  -1, ...,  -1,   0,  -1]]),
            array([[ 24,  24,  26, ...,  17,  15,  15],
       [ 22,  22,  24, ..., -27, -27, -26],
       [  1,   0,   1, ...,   3,   5,   7],
       ...,
       [  3,   1,   0, ...,   7,   9,   9],
       [  4,   6,  -5, ...,   8,   4,  -5],
       [  0,   0,   1, ...,   0,   0,   0]])],
 'symptom': ['normal', 'cb_normal']}

>>> signal shape: [(21, 200), (21, 200)]

 ---------------------------------------------------------------------------------------------------- 

{'age': 66,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00085',
 'signal': [array([[  6,   7,   7, ..., -15, -14, -14],
       [ 78,  79,  81, ...,

### Random crop with multiple crops and latency

In [18]:
# Three 300-sample crops per access. return_timing=True adds a 'start_point' list
# (one offset per crop) to each sample. In the recorded output every offset lies at or
# after latency=50000. NOTE(review): confirm the exact latency semantics in EegRandomCrop.
transform = EegRandomCrop(crop_length=300, multiple=3, latency=50000, return_timing=True)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task2',
    load_event=False,
    file_format='feather',
    transform=transform,
)
for i in range(2):
    d = train_dataset[0]
    pprint.pprint(d)
    print('\n>>> signal shape:', [crop.shape for crop in d['signal']])
    print('\n', '-' * 100, '\n')

{'age': 66,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00085',
 'signal': [array([[-14, -15, -17, ...,  23,  22,  20],
       [ 17,  17,  17, ...,  16,  15,  13],
       [-13, -14, -14, ...,  -8,  -7,  -8],
       ...,
       [  5,   4,   5, ...,   4,   3,   4],
       [  9,  -5,  -8, ...,   9,  -5,  -1],
       [  1,   0,   1, ...,   0,   1,   1]]),
            array([[ -6,  -6, -10, ..., -21, -25, -26],
       [ -2,  -1,  -4, ..., -14, -13, -11],
       [ 14,  16,  17, ...,   5,   5,   3],
       ...,
       [  1,   0,   0, ...,   0,   1,   1],
       [-10,  -2,  11, ...,  -9,   2,   8],
       [  0,   1,   3, ...,   0,   1,   3]]),
            array([[ 72,  70,  69, ..., -13, -16, -20],
       [-19, -21, -23, ..., -49, -50, -51],
       [  9,   6,   5, ..., -26, -26, -26],
       ...,
       [ -1,  -3,  -6, ...,   2,   1,   0],
       [ -3,   6,   0, ...,  -1,   4,  -6],
       [  1,   1,   0, ...,   0,   0,  -1]])],
 'start_point': [60409, 76085, 115349],
 'symptom': [

### Max Length

In [19]:
# Truncate recordings with EegLimitMaxLength(50300) before cropping. With latency=50000
# this leaves little room for a 200-sample crop, and the recorded start points all fall
# just after 50000.
transform = transforms.Compose([
    EegLimitMaxLength(50300),
    EegRandomCrop(crop_length=200, multiple=3, latency=50000, return_timing=True),
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task1',
    load_event=False,
    file_format='feather',
    transform=transform,
)
for i in range(2):
    d = train_dataset[0]
    pprint.pprint(d)
    print('\n>>> signal shape:', [crop.shape for crop in d['signal']])
    print('\n', '-' * 100, '\n')

{'age': 81,
 'class_label': 2,
 'class_name': 'Dementia',
 'serial': '01222',
 'signal': [array([[114, 111, 106, ...,  85,  75,  73],
       [ 19,  18,  14, ...,  13,  10,   8],
       [  3,   5,   4, ...,  -8,  -7,  -6],
       ...,
       [-19, -18, -20, ..., -22, -21, -20],
       [  2,   0,  -2, ..., -14, -10,  -4],
       [  0,  -2,  -2, ...,   0,   0,  -3]]),
            array([[  8,  21,  36, ...,  13,  36,  54],
       [-14, -13, -11, ...,   5,   8,  11],
       [  7,   5,   5, ...,   7,   5,   5],
       ...,
       [ -7,  -8, -10, ..., -17, -18, -19],
       [ 32,  31,  27, ...,   3,   1,  -1],
       [ -2,   0,   0, ...,  -3,  -3,   0]]),
            array([[ 21,  36,  57, ...,  36,  54,  67],
       [-13, -11,  -5, ...,   8,  11,  14],
       [  5,   5,   6, ...,   5,   5,   4],
       ...,
       [ -8, -10, -12, ..., -18, -19, -19],
       [ 31,  27,  27, ...,   1,  -1,  -5],
       [  0,   0,   2, ...,  -3,   0,   0]])],
 'start_point': [50021, 50005, 50006],
 'symptom': 

### To Tensor

In [20]:
# Compare the first recording before and after EegToTensor (NumPy arrays -> torch tensors).
config_data, full_eeg_dataset = load_caueeg_full_dataset(
    dataset_path=data_path,
    load_event=False,
    file_format='feather',
    transform=None,
)
print('Before:')
pprint.pprint(full_eeg_dataset[0])

print('\n' + '-' * 100 + '\n')

config_data, full_eeg_dataset = load_caueeg_full_dataset(
    dataset_path=data_path,
    load_event=False,
    file_format='feather',
    transform=EegToTensor(),
)
print('After:')
pprint.pprint(full_eeg_dataset[0])

Before:
{'age': 78,
 'serial': '00001',
 'signal': array([[  0, -11, -13, ...,  18,  21,  22],
       [ 29,  33,  34, ...,  -7,  -4,  -4],
       [ -3,  -6,  -3, ...,  -1,   1,   2],
       ...,
       [ -4,  -2,   1, ...,   0,  -1,  -1],
       [112,  67,  76, ..., -13, -15, -11],
       [ -1,  -1,  -1, ...,  -1,  -1,  -1]]),
 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}

----------------------------------------------------------------------------------------------------

After:
{'age': tensor(78.),
 'serial': '00001',
 'signal': tensor([[  0., -11., -13.,  ...,  18.,  21.,  22.],
        [ 29.,  33.,  34.,  ...,  -7.,  -4.,  -4.],
        [ -3.,  -6.,  -3.,  ...,  -1.,   1.,   2.],
        ...,
        [ -4.,  -2.,   1.,  ...,   0.,  -1.,  -1.],
        [112.,  67.,  76.,  ..., -13., -15., -11.],
        [ -1.,  -1.,  -1.,  ...,  -1.,  -1.,  -1.]]),
 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}


### Compose the above all in one

In [21]:
# Full preprocessing pipeline, applied in order:
#   1. drop the Photic channel;
#   2. cap each recording at 200*60*10 samples (10 minutes, assuming 200 samples/s);
#   3. take 4 random crops of 200*10 samples (10 s), using a 200*10-sample (10 s) latency;
#   4. convert everything to torch tensors.
transform = transforms.Compose([
    EegDropChannels(channel_photic),
    EegLimitMaxLength(max_length=200 * 60 * 10),
    EegRandomCrop(crop_length=200 * 10, multiple=4, latency=200 * 10),
    EegToTensor(),
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task1',
    load_event=False,
    file_format='feather',
    transform=transform,
)

pprint.pprint(train_dataset[0])

{'age': tensor(81.),
 'class_label': tensor(2),
 'class_name': 'Dementia',
 'serial': '01222',
 'signal': [tensor([[ 14.,  15.,  15.,  ...,  22.,  21.,  19.],
        [ -8.,  -8.,  -9.,  ...,  -1.,  -2.,  -3.],
        [ -6.,  -8., -10.,  ...,  -3.,  -3.,  -4.],
        ...,
        [ 11.,  11.,  10.,  ...,  10.,  10.,   8.],
        [ 13.,  13.,  13.,  ...,  -3.,  -2.,  -2.],
        [-12., -18., -20.,  ..., -51., -45., -40.]]),
            tensor([[ 67.,  67.,  64.,  ..., -13., -14., -15.],
        [  9.,  10.,   9.,  ...,  -7.,  -4.,  -4.],
        [-21., -18., -18.,  ...,  -3.,  -2.,  -2.],
        ...,
        [ -5.,  -3.,  -2.,  ...,   0.,   2.,   3.],
        [  2.,   1.,   1.,  ...,  -3.,  -2.,   0.],
        [  9.,  25.,  45.,  ...,  -9.,  -9.,  -5.]]),
            tensor([[-17., -20., -24.,  ...,   4.,   0.,  -1.],
        [ -8.,  -9.,  -9.,  ...,  11.,   9.,  10.],
        [  9.,  10.,  11.,  ...,  13.,  11.,  13.],
        ...,
        [  4.,   4.,   5.,  ...,  -6.,  -6.,  

---

## PyTorch DataLoader

In [22]:
# DataLoader settings shared by the loaders below.
# Pinned host memory is only requested when the selected device is CUDA.
pin_memory = device.type == 'cuda'
# Always load in the main process, on every device: a number other than 0 causes an error.
num_workers = 0

In [23]:
# Task-agnostic loader. With 2 crops per recording, a batch of 4 recordings is collated
# into 8 entries (note the repeated serials in the output).
transform = transforms.Compose([
    EegDropChannels(channel_photic),
    EegLimitMaxLength(max_length=200 * 60 * 10),  # 10m
    EegRandomCrop(crop_length=200 * 10, multiple=2, latency=200 * 10),  # crop: 10s, latency: 10s
    EegToTensor(),
])

config_data, full_eeg_dataset = load_caueeg_full_dataset(
    dataset_path=data_path,
    load_event=False,
    file_format='feather',
    transform=transform,
)

full_loader = DataLoader(
    full_eeg_dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)

# Show only the first batch.
for i_batch, sample_batched in enumerate(full_loader):
    pprint.pprint(sample_batched)
    break

{'age': tensor([55., 55., 60., 60., 76., 76., 90., 90.]),
 'serial': ['01289',
            '01289',
            '01141',
            '01141',
            '00724',
            '00724',
            '00581',
            '00581'],
 'signal': tensor([[[ -7.,  -8., -13.,  ...,  -3.,  -3.,  -3.],
         [  2.,  -1.,  -6.,  ...,   1.,   2.,   2.],
         [  3.,   2.,   2.,  ...,  -1.,   0.,   0.],
         ...,
         [ -4.,  -6.,  -7.,  ...,   3.,   3.,   4.],
         [  0.,  -1.,  -2.,  ...,   5.,   7.,   7.],
         [-11., -18., -24.,  ...,   4.,   7.,   9.]],

        [[ 62.,  61.,  56.,  ...,  -7.,  -5.,  -2.],
         [ 21.,  20.,  23.,  ..., -12.,  -9.,  -4.],
         [  2.,  -1.,  -5.,  ...,   2.,   3.,   3.],
         ...,
         [ 23.,  22.,  22.,  ...,   0.,  -1.,  -3.],
         [ -6.,  -6.,  -6.,  ...,  -3.,  -2.,  -2.],
         [ 36.,  31.,  26.,  ...,  23.,  40.,  20.]],

        [[ 18.,  24.,  28.,  ..., -47., -42., -38.],
         [ -4.,   0.,   1.,  ..., -11.,  

In [24]:
# Task1 training loader. With 2 crops per recording, a batch of 8 recordings is collated
# into 16 entries (note the repeated serials in the output).
transform = transforms.Compose([
    EegDropChannels(channel_photic),
    EegLimitMaxLength(max_length=200 * 60 * 10),  # 10m
    EegRandomCrop(crop_length=200 * 10, multiple=2, latency=200 * 10),  # crop: 10s, latency: 10s
    EegToTensor(),
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task1',
    load_event=False,
    file_format='feather',
    transform=transform,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)

# Show only the first batch.
for i_batch, sample_batched in enumerate(train_loader):
    pprint.pprint(sample_batched, width=250)
    break

{'age': tensor([39., 39., 77., 77., 86., 86., 73., 73., 82., 82., 84., 84., 85., 85.,
        85., 85.]),
 'class_label': tensor([0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 2, 2, 2, 2]),
 'class_name': ['Normal', 'Normal', 'MCI', 'MCI', 'MCI', 'MCI', 'MCI', 'MCI', 'Dementia', 'Dementia', 'MCI', 'MCI', 'Dementia', 'Dementia', 'Dementia', 'Dementia'],
 'serial': ['01031', '01031', '00303', '00303', '01242', '01242', '01347', '01347', '01181', '01181', '01243', '01243', '00292', '00292', '00252', '00252'],
 'signal': tensor([[[ 44.,  40.,  47.,  ..., -38., -36., -34.],
         [ 25.,  26.,  33.,  ..., -26., -29., -28.],
         [ 19.,  23.,  24.,  ..., -13., -14., -12.],
         ...,
         [ -4.,  -2.,  -1.,  ...,  -1.,   0.,   1.],
         [  1.,   2.,   2.,  ...,   3.,   5.,   5.],
         [  1.,  -1., -10.,  ...,  -9.,   3.,  -4.]],

        [[-28., -30., -23.,  ...,  -1.,  -5., -14.],
         [-12., -16., -13.,  ..., -10., -11., -11.],
         [  1.,   2.,   1.,  ...,   4.,   4.,  

---

## Preprocessing steps implemented as PyTorch modules

In [25]:
# Same input transform as for task 1: drop photic channel, cap at 10 minutes,
# random 10 s crops after a 10 s latency, convert to tensors.
transform = transforms.Compose([
    EegDropChannels(channel_photic),
    EegLimitMaxLength(max_length=200 * 60 * 10),  # 10m
    EegRandomCrop(crop_length=200 * 10, multiple=2, latency=200 * 10),  # crop: 10s, latency: 10s
    EegToTensor(),
])

# Task 2 splits; this small-batch train_loader is reused by the preprocessing demos below.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task2',
    load_event=False,
    file_format='feather',
    transform=transform,
)

loader_options = dict(
    batch_size=2,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)
train_loader = DataLoader(train_dataset, **loader_options)

### Move the batch to the GPU device (if available)

In [26]:
print('device:', device)
print()

# A single-step pipeline that only moves the batch to `device`; the Compose list is
# wrapped in nn.Sequential so the whole pipeline can itself be moved with .to(device).
to_device_steps = transforms.Compose([EegToDevice(device=device)]).transforms
preprocess_train = torch.nn.Sequential(*to_device_steps).to(device)
pprint.pprint(preprocess_train)

separator = '-' * 100
for sample_batched in train_loader:
    print('- Before -')
    pprint.pprint(sample_batched)

    print()
    print(separator)
    print()

    # The return value is discarded: the pipeline is relied on to update the batch dict in place.
    preprocess_train(sample_batched)

    print('- After -')
    pprint.pprint(sample_batched)
    break

device: cuda

Sequential(
  (0): EegToDevice(device=cuda)
)
- Before -
{'age': tensor([72., 72., 85., 85.]),
 'class_label': tensor([0, 0, 1, 1]),
 'class_name': ['Normal', 'Normal', 'Abnormal', 'Abnormal'],
 'serial': ['00696', '00696', '00657', '00657'],
 'signal': tensor([[[ -56.,  -55.,  -60.,  ...,   16.,   16.,   15.],
         [   1.,    3.,   13.,  ...,   28.,   30.,   31.],
         [  25.,   30.,   20.,  ...,   -4.,    0.,    4.],
         ...,
         [  18.,   17.,   18.,  ...,   -8.,   -7.,   -6.],
         [  11.,   15.,   15.,  ...,  -13.,  -11.,   -7.],
         [  18.,   20.,   17.,  ...,  -76.,  -83.,  -87.]],

        [[  77.,   78.,   76.,  ...,   10.,   13.,   12.],
         [  -1.,    2.,    9.,  ...,   12.,   16.,   20.],
         [ -11.,   -9.,   -5.,  ...,   -5.,   -4.,   -4.],
         ...,
         [  -4.,   -3.,   -3.,  ...,   10.,    9.,   10.],
         [   4.,   -1.,   -4.,  ...,    8.,    6.,    3.],
         [-122., -126., -137.,  ...,   78.,   34.,  -

### Normalization per signal

In [27]:
# Move the batch to `device`, then normalize each signal with its own statistics.
preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegNormalizePerSignal(),
).to(device)
pprint.pprint(preprocess_train)

for sample_batched in train_loader:
    # Per-channel statistics over the last (time) axis, before normalization.
    print('- Before -')
    print('Mean:', sample_batched['signal'].mean(dim=-1))
    print()
    print('Std:', sample_batched['signal'].std(dim=-1))

    print()
    print('-' * 100)
    print()

    # The return value is discarded; the batch dict is relied on to be updated in place.
    preprocess_train(sample_batched)

    print('- After -')
    print('Mean:', sample_batched['signal'].mean(dim=-1))
    print()
    print('Std:', sample_batched['signal'].std(dim=-1))
    break

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizePerSignal(eps=1e-08)
)
- Before -
Mean: tensor([[ 3.6975,  1.7315,  0.2125, -0.9015,  1.4420,  0.8715,  0.4200, -0.7245,
         -0.5970, -0.3370, -0.6790,  0.8970, -0.9285, -0.8055, -0.9575, -1.5605,
          3.5805, -0.2895, -0.5280, -1.5705],
        [12.0375,  4.4140,  1.2955, -0.2610, -1.2930,  5.5185, -0.1050, -2.3265,
         -1.5660, -0.9565,  4.7895,  2.9220, -0.6145, -4.6300, -1.4195, -1.9960,
          3.2125, -0.5285, -0.9555,  0.0365],
        [ 0.3860,  0.4420,  0.2560,  0.0980, -0.3285,  0.4215,  0.9185,  0.1725,
          0.1980, -0.1870, -0.4770,  0.2625,  0.1625, -1.2885, -0.2020, -0.4685,
          0.2830, -0.1530,  0.5610,  0.5525],
        [-1.4555, -1.2505,  0.7580,  0.5830,  1.2935, -0.1330, -3.2185,  0.0795,
          0.2380,  0.5530, -0.7655,  0.3045,  0.7995,  1.1770,  0.4395,  0.1050,
          0.6425,  0.6660, -0.7175,  0.2405]])

Std: tensor([[ 40.0912,  17.8688,   6.4726,   6.1523,   9.1699,

### Signal normalization using the specified mean and std values

In [28]:
# Estimate per-channel signal mean/std from the training loader (repeats=1), then
# standardize batches with those fixed statistics.
signal_mean, signal_std = calculate_signal_statistics(
    train_loader,
    repeats=1,
    verbose=True,
)

preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
).to(device)
pprint.pprint(preprocess_train)

for sample_batched in train_loader:
    print('- Before -')
    print('Mean:', sample_batched['signal'].mean(dim=-1))
    print()
    print('Std:', sample_batched['signal'].std(dim=-1))

    print()
    print('-' * 100)
    print()

    # The return value is discarded; the batch dict is relied on to be updated in place.
    preprocess_train(sample_batched)

    print('- After -')
    print('Mean:', sample_batched['signal'].mean(dim=-1))
    print()
    print('Std:', sample_batched['signal'].std(dim=-1))
    break

Mean and standard deviation for signal:
tensor([[[-0.1357],
         [-0.2511],
         [ 0.0018],
         [ 0.0711],
         [-0.0479],
         [-0.2520],
         [-0.2120],
         [ 0.0005],
         [-0.0007],
         [-0.0591],
         [ 0.1582],
         [ 0.3231],
         [ 0.0052],
         [-0.0274],
         [ 0.1218],
         [ 0.0294],
         [-0.0802],
         [-0.0193],
         [-0.0197],
         [-0.0166]]])
-
tensor([[[46.2526],
         [20.3984],
         [11.6559],
         [11.7465],
         [15.2970],
         [50.3807],
         [20.2444],
         [10.7235],
         [11.5641],
         [15.6016],
         [20.7594],
         [15.2747],
         [13.4793],
         [21.4928],
         [16.8901],
         [15.2875],
         [19.7100],
         [11.2538],
         [11.2400],
         [98.3843]]])

----------------------------------------------------------------------------------------------------

Sequential(
  (0): EegToDevice(device=cuda)
  (1): 

### Age normalization

In [29]:
# Age statistics from the training loader, used to standardize the 'age' field of each batch.
age_mean, age_std = calculate_age_statistics(train_loader, verbose=True)

preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegNormalizeAge(mean=age_mean, std=age_std),
).to(device)
pprint.pprint(preprocess_train)

for sample_batched in train_loader:
    print('- Before -')
    pprint.pprint(sample_batched['age'])

    print()
    print('-' * 100)
    print()

    # The return value is discarded; the batch dict is relied on to be updated in place.
    preprocess_train(sample_batched)

    print('- After -')
    pprint.pprint(sample_batched['age'])
    break

Age mean and standard deviation:
tensor([70.6025]) tensor([6.3509])

----------------------------------------------------------------------------------------------------

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.6025]),std=tensor([6.3509]),eps=1e-08)
)
- Before -
tensor([75., 75., 86., 86.])

----------------------------------------------------------------------------------------------------

- After -
tensor([0.6924, 0.6924, 2.4245, 2.4245], device='cuda:0')


### Short-time Fourier transform (STFT, or spectrogram)

In [30]:
# Move the batch to `device` and replace the time-domain signal with its STFT
# (n_fft=200, complex_mode='as_real'); only the tensor shapes are compared here.
preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegSpectrogram(n_fft=200, complex_mode='as_real'),
).to(device)
pprint.pprint(preprocess_train)

for sample_batched in train_loader:
    print('- Before -')
    pprint.pprint(sample_batched['signal'].shape)

    print()
    print('-' * 100)
    print()

    # The return value is discarded; the batch dict is relied on to be updated in place.
    preprocess_train(sample_batched)

    print('- After -')
    pprint.pprint(sample_batched['signal'].shape)
    break

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegSpectrogram(n_fft=200, complex_mode=as_real, stft_kwargs={})
)
- Before -
torch.Size([4, 20, 2000])

----------------------------------------------------------------------------------------------------

- After -
torch.Size([4, 40, 101, 41])


### Signal normalization after STFT

In [31]:
# Stage 1: move to device and compute the STFT.
preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegSpectrogram(n_fft=200, complex_mode='as_real'),
).to(device)

# Statistics of the STFT output, measured by running stage 1 over the training loader.
signal_2d_mean, signal_2d_std = calculate_signal_statistics(train_loader, preprocess_train)

# Stage 2: standardize the STFT output with those statistics.
preprocess_train2 = torch.nn.Sequential(
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
).to(device)

pprint.pprint(preprocess_train)
pprint.pprint(preprocess_train2)

for sample_batched in train_loader:
    # "Before" means after the STFT but before stage-2 normalization.
    print('- Before -')
    preprocess_train(sample_batched)

    print('Mean:', sample_batched['signal'].mean(dim=-1))
    print()
    print('Std:', sample_batched['signal'].std(dim=-1))

    print()
    print('-' * 100)
    print()

    print('- After -')
    preprocess_train2(sample_batched)

    print('Mean:', sample_batched['signal'].mean(dim=-1))
    print()
    print('Std:', sample_batched['signal'].std(dim=-1))
    break

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegSpectrogram(n_fft=200, complex_mode=as_real, stft_kwargs={})
)
Sequential(
  (0): EegNormalizeMeanStd(mean=tensor([[ 3.6224e+01,  8.1955e-01, -1.3452e-01,  ..., -6.3493e-02,
           -5.8977e-02, -8.1556e-02],
          [ 2.6344e+01,  3.9039e-01,  5.7765e-02,  ..., -1.2983e-03,
            5.7594e-05, -1.2575e-02],
          [-4.3636e+00,  8.2101e-02, -2.9571e-02,  ...,  1.6749e-03,
            3.5939e-03,  6.1787e-03],
          ...,
          [ 0.0000e+00,  1.1858e-01,  2.4689e-02,  ...,  1.4142e-04,
           -5.3444e-04, -1.3321e-08],
          [ 0.0000e+00,  6.1904e-01,  2.7609e-01,  ...,  1.2308e-03,
           -1.3162e-04, -1.4377e-08],
          [ 0.0000e+00, -6.7974e-01,  7.8550e-02,  ...,  7.2703e-04,
           -1.1213e-02, -1.2512e-07]], device='cuda:0'),std=tensor([[7.4391e+03, 1.5245e+03, 7.9377e+02,  ..., 2.9074e+01, 2.9047e+01,
           2.9265e+01],
          [3.3342e+03, 6.0361e+02, 3.0638e+02,  ..., 1.3517e+01

---

## Speed check without STFT

### `EDF`

In [32]:
%%time

transform = transforms.Compose([
    EegDropChannels(channel_photic), 
    EegLimitMaxLength(max_length=200*60*10),  # 10m
    EegRandomCrop(crop_length=200*10, multiple=2, latency=200*10),  # crop: 10s, latency: 10s
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='edf',
                                                                                  transform=transform)
 
train_loader = DataLoader(train_dataset,
                          batch_size=32,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegDropChannels(drop_index=20)
    EegLimitMaxLength(max_length=120000)
    EegRandomCrop(crop_length=2000, multiple=2, latency=2000, return_timing=False)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.6025]),std=tensor([6.3509]),eps=1e-08)
  (2): EegNormalizeMeanStd(mean=tensor([-0.1357, -0.2511,  0.0018,  0.0711, -0.0479, -0.2520, -0.2120,  0.0005,
          -0.0007, -0.0591,  0.1582,  0.3231,  0.0052, -0.0274,  0.1218,  0.0294,
          -0.0802, -0.0193, -0.0197, -0.0166]),std=tensor([46.2526, 20.3984, 11.6559, 11.7465, 15.2970, 50.3807, 20.2444, 10.7235,
          11.5641, 15.6016, 20.7594, 15.2747, 13.4793, 21.4928, 16.8901, 15.2875,
          19.7100, 11.2538, 11.2400, 98.3843]),eps=1e-08)
)
CPU times: total: 31min 13s
Wall time: 2min 37s


### `Feather`

In [33]:
%%time

transform = transforms.Compose([
    EegDropChannels(channel_photic), 
    EegLimitMaxLength(max_length=200*60*10),  # 10m
    EegRandomCrop(crop_length=200*10, multiple=2, latency=200*10),  # crop: 10s, latency: 10s
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=32,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegDropChannels(drop_index=20)
    EegLimitMaxLength(max_length=120000)
    EegRandomCrop(crop_length=2000, multiple=2, latency=2000, return_timing=False)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.6025]),std=tensor([6.3509]),eps=1e-08)
  (2): EegNormalizeMeanStd(mean=tensor([-0.1357, -0.2511,  0.0018,  0.0711, -0.0479, -0.2520, -0.2120,  0.0005,
          -0.0007, -0.0591,  0.1582,  0.3231,  0.0052, -0.0274,  0.1218,  0.0294,
          -0.0802, -0.0193, -0.0197, -0.0166]),std=tensor([46.2526, 20.3984, 11.6559, 11.7465, 15.2970, 50.3807, 20.2444, 10.7235,
          11.5641, 15.6016, 20.7594, 15.2747, 13.4793, 21.4928, 16.8901, 15.2875,
          19.7100, 11.2538, 11.2400, 98.3843]),eps=1e-08)
)
CPU times: total: 1min 47s
Wall time: 7.19 s


### `memmap`

In [34]:
%%time

transform = transforms.Compose([
    EegDropChannels(channel_photic), 
    EegLimitMaxLength(max_length=200*60*10),  # 10m
    EegRandomCrop(crop_length=200*10, multiple=2, latency=200*10),  # crop: 10s, latency: 10s
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)
 
train_loader = DataLoader(train_dataset,
                          batch_size=32,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegDropChannels(drop_index=20)
    EegLimitMaxLength(max_length=120000)
    EegRandomCrop(crop_length=2000, multiple=2, latency=2000, return_timing=False)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.6025]),std=tensor([6.3509]),eps=1e-08)
  (2): EegNormalizeMeanStd(mean=tensor([-0.1357, -0.2511,  0.0018,  0.0711, -0.0479, -0.2520, -0.2120,  0.0005,
          -0.0007, -0.0591,  0.1582,  0.3231,  0.0052, -0.0274,  0.1218,  0.0294,
          -0.0802, -0.0193, -0.0197, -0.0166]),std=tensor([46.2526, 20.3984, 11.6559, 11.7465, 15.2970, 50.3807, 20.2444, 10.7235,
          11.5641, 15.6016, 20.7594, 15.2747, 13.4793, 21.4928, 16.8901, 15.2875,
          19.7100, 11.2538, 11.2400, 98.3843]),eps=1e-08)
)
CPU times: total: 1min 23s
Wall time: 7.14 s
Compiler : 188 ms


---

## Speed check with STFT

### `EDF`

In [35]:
%%time

transform = transforms.Compose([
    EegDropChannels(channel_photic), 
    EegLimitMaxLength(max_length=200*60*10),  # 10m
    EegRandomCrop(crop_length=200*10, multiple=2, latency=200*10),  # crop: 10s, latency: 10s
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='edf',
                                                                                  transform=transform)
 
train_loader = DataLoader(train_dataset,
                          batch_size=32,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegSpectrogram(n_fft=200, complex_mode='as_real'),
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegDropChannels(drop_index=20)
    EegLimitMaxLength(max_length=120000)
    EegRandomCrop(crop_length=2000, multiple=2, latency=2000, return_timing=False)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.6025]),std=tensor([6.3509]),eps=1e-08)
  (2): EegSpectrogram(n_fft=200, complex_mode=as_real, stft_kwargs={})
  (3): EegNormalizeMeanStd(mean=tensor([[ 3.6224e+01,  8.1955e-01, -1.3452e-01,  ..., -6.3493e-02,
           -5.8977e-02, -8.1556e-02],
          [ 2.6344e+01,  3.9039e-01,  5.7765e-02,  ..., -1.2983e-03,
            5.7594e-05, -1.2575e-02],
          [-4.3636e+00,  8.2101e-02, -2.9571e-02,  ...,  1.6749e-03,
            3.5939e-03,  6.1787e-03],
          ...,
          [ 0.0000e+00,  1.1858e-01,  2.4689e-02,  ...,  1.4142e-04,
           -5.3444e-04, -1.3321e-08],
          [ 0.0000e+00,  6.1904e-01,  2.7609e-01,  ...,  1.2308e-03,
           -1.3162e-04, -1.4377e-08],
          [ 0.0000e+00, -6.7974e-01,  7

### `Feather`

In [36]:
%%time

transform = transforms.Compose([
    EegDropChannels(channel_photic), 
    EegLimitMaxLength(max_length=200*60*10),  # 10m
    EegRandomCrop(crop_length=200*10, multiple=2, latency=200*10),  # crop: 10s, latency: 10s
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=32,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegSpectrogram(n_fft=200, complex_mode='as_real'),
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegDropChannels(drop_index=20)
    EegLimitMaxLength(max_length=120000)
    EegRandomCrop(crop_length=2000, multiple=2, latency=2000, return_timing=False)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.6025]),std=tensor([6.3509]),eps=1e-08)
  (2): EegSpectrogram(n_fft=200, complex_mode=as_real, stft_kwargs={})
  (3): EegNormalizeMeanStd(mean=tensor([[ 3.6224e+01,  8.1955e-01, -1.3452e-01,  ..., -6.3493e-02,
           -5.8977e-02, -8.1556e-02],
          [ 2.6344e+01,  3.9039e-01,  5.7765e-02,  ..., -1.2983e-03,
            5.7594e-05, -1.2575e-02],
          [-4.3636e+00,  8.2101e-02, -2.9571e-02,  ...,  1.6749e-03,
            3.5939e-03,  6.1787e-03],
          ...,
          [ 0.0000e+00,  1.1858e-01,  2.4689e-02,  ...,  1.4142e-04,
           -5.3444e-04, -1.3321e-08],
          [ 0.0000e+00,  6.1904e-01,  2.7609e-01,  ...,  1.2308e-03,
           -1.3162e-04, -1.4377e-08],
          [ 0.0000e+00, -6.7974e-01,  7

### `memmap`

In [37]:
%%time

transform = transforms.Compose([
    EegDropChannels(channel_photic), 
    EegLimitMaxLength(max_length=200*60*10),  # 10m
    EegRandomCrop(crop_length=200*10, multiple=2, latency=200*10),  # crop: 10s, latency: 10s
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=32,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegSpectrogram(n_fft=200, complex_mode='as_real'),
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegDropChannels(drop_index=20)
    EegLimitMaxLength(max_length=120000)
    EegRandomCrop(crop_length=2000, multiple=2, latency=2000, return_timing=False)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.6025]),std=tensor([6.3509]),eps=1e-08)
  (2): EegSpectrogram(n_fft=200, complex_mode=as_real, stft_kwargs={})
  (3): EegNormalizeMeanStd(mean=tensor([[ 3.6224e+01,  8.1955e-01, -1.3452e-01,  ..., -6.3493e-02,
           -5.8977e-02, -8.1556e-02],
          [ 2.6344e+01,  3.9039e-01,  5.7765e-02,  ..., -1.2983e-03,
            5.7594e-05, -1.2575e-02],
          [-4.3636e+00,  8.2101e-02, -2.9571e-02,  ...,  1.6749e-03,
            3.5939e-03,  6.1787e-03],
          ...,
          [ 0.0000e+00,  1.1858e-01,  2.4689e-02,  ...,  1.4142e-04,
           -5.3444e-04, -1.3321e-08],
          [ 0.0000e+00,  6.1904e-01,  2.7609e-01,  ...,  1.2308e-03,
           -1.3162e-04, -1.4377e-08],
          [ 0.0000e+00, -6.7974e-01,  7

---

## Test on longer sequences

In [38]:
%%time

longer_transform = transforms.Compose([
    EegDropChannels(channel_photic), 
    EegLimitMaxLength(max_length=200*60*10),  # 10m
    EegRandomCrop(crop_length=200*10*6, multiple=2, latency=200*10),  # crop: 1m, latency: 10s
    EegToTensor()
])
pprint.pprint(longer_transform)

config_data, longer_test_dataset = load_caueeg_task_split(dataset_path=data_path, 
                                                          task='task2', 
                                                          split='test',
                                                          load_event=False,
                                                          file_format='feather', 
                                                          transform=longer_transform)

longer_test_loader = DataLoader(longer_test_dataset,
                                batch_size=32,
                                shuffle=True,
                                drop_last=False,
                                num_workers=num_workers,
                                pin_memory=pin_memory,
                                collate_fn=eeg_collate_fn)
 
preprocess_test = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
    EegNormalizeAge(mean=age_mean, std=age_std),
])
preprocess_test = torch.nn.Sequential(*preprocess_test.transforms).to(device)
pprint.pprint(preprocess_test)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_test(sample_batched)

Compose(
    EegDropChannels(drop_index=20)
    EegLimitMaxLength(max_length=120000)
    EegRandomCrop(crop_length=12000, multiple=2, latency=2000, return_timing=False)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeMeanStd(mean=tensor([-0.1357, -0.2511,  0.0018,  0.0711, -0.0479, -0.2520, -0.2120,  0.0005,
          -0.0007, -0.0591,  0.1582,  0.3231,  0.0052, -0.0274,  0.1218,  0.0294,
          -0.0802, -0.0193, -0.0197, -0.0166]),std=tensor([46.2526, 20.3984, 11.6559, 11.7465, 15.2970, 50.3807, 20.2444, 10.7235,
          11.5641, 15.6016, 20.7594, 15.2747, 13.4793, 21.4928, 16.8901, 15.2875,
          19.7100, 11.2538, 11.2400, 98.3843]),eps=1e-08)
  (2): EegNormalizeAge(mean=tensor([70.6025]),std=tensor([6.3509]),eps=1e-08)
)
CPU times: total: 1min 19s
Wall time: 6.6 s
