# Dataset and DataLoader

This notebook loads the `CAUEEG` dataset, tests some useful preprocessing, and makes up the PyTorch DataLoader instances for the training.

-----

## Configurations

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [2]:
# Load some packages
import os
import glob
import json
import pprint

import numpy as np
import random
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

# custom package
from datasets.caueeg_dataset import *
from datasets.caueeg_script import *
from datasets.pipeline import *

In [3]:
print('PyTorch version:', torch.__version__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available(): print('cuda is available.')
else: print('cuda is unavailable.')

PyTorch version: 1.11.0+cu113
cuda is available.


In [4]:
# Data file path
data_path = r'local/dataset/02_Curated_Data_220419/'

-----

## Load the CAUEEG dataset

### Load the whole CAUEEG data as a PyTorch dataset instance without considering the target task (no train/val/test sets and no class label).

In [5]:
config_data, full_eeg_dataset = load_caueeg_full_dataset(dataset_path=data_path, 
                                                         load_event=False, 
                                                         file_format='edf',
                                                         transform=None)

pprint.pprint(config_data, width=250)
print('\n', '-' * 100, '\n')

pprint.pprint(full_eeg_dataset[0])
print('\n', '-' * 100, '\n')

pprint.pprint(full_eeg_dataset[1])
print('\n', '-' * 100, '\n')

{'dataset_name': 'CAUEEG dataset',
 'signal_header': ['Fp1-AVG', 'F3-AVG', 'C3-AVG', 'P3-AVG', 'O1-AVG', 'Fp2-AVG', 'F4-AVG', 'C4-AVG', 'P4-AVG', 'O2-AVG', 'F7-AVG', 'T3-AVG', 'T5-AVG', 'F8-AVG', 'T4-AVG', 'T6-AVG', 'FZ-AVG', 'CZ-AVG', 'PZ-AVG', 'EKG', 'Photic']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 78,
 'serial': '00001',
 'signal': array([[  0., -11., -13., ...,   0.,   0.,   0.],
       [ 29.,  33.,  34., ...,   0.,   0.,   0.],
       [ -3.,  -6.,  -3., ...,   0.,   0.,   0.],
       ...,
       [ -4.,  -2.,   1., ...,   0.,   0.,   0.],
       [112.,  67.,  76., ...,   0.,   0.,   0.],
       [ -1.,  -1.,  -1., ...,   0.,   0.,   0.]]),
 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 56,
 'serial': '00002',
 'signal': array([[  39.,   58.,   72., ...,    0.,    0.,    0.],
       [   4.,

### Load the CAUEEG task1 datasets as the PyTorch dataset instances.

In [6]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task1',
                                                                                  load_event=False, 
                                                                                  file_format='edf', 
                                                                                  transform=None)
pprint.pprint(config_data)
print('\n', '-' * 100, '\n')

pprint.pprint(train_dataset[0])
print('\n', '-' * 100, '\n')

pprint.pprint(val_dataset[0])
print('\n', '-' * 100, '\n')

pprint.pprint(test_dataset[0])

{'class_label_to_name': ['Normal', 'MCI', 'Dementia'],
 'class_name_to_label': {'Dementia': 2, 'MCI': 1, 'Normal': 0},
 'task_description': 'Classification of [Normal], [MCI], and [Dementia] '
                     'symptoms.',
 'task_name': 'CAUEEG-task1 benchmark'}

 ---------------------------------------------------------------------------------------------------- 

{'age': 88,
 'class_label': 2,
 'class_name': 'Dementia',
 'serial': '01379',
 'signal': array([[ -1.,  -4., -35., ...,   0.,   0.,   0.],
       [ 11.,  18., -31., ...,   0.,   0.,   0.],
       [ 18.,  30.,  -7., ...,   0.,   0.,   0.],
       ...,
       [ 20.,  33., -12., ...,   0.,   0.,   0.],
       [ 91., 138.,  23., ...,   0.,   0.,   0.],
       [  0.,  -1.,  -1., ...,   0.,   0.,   0.]]),
 'symptom': ['dementia', 'ad', 'load']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 69,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00970',
 'si

### Load the CAUEEG task2 datasets as the PyTorch dataset instances.

In [7]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='edf', 
                                                                                  transform=None)
pprint.pprint(config_data)
print('\n', '-' * 100, '\n')

pprint.pprint(train_dataset[0])
print('\n', '-' * 100, '\n')

pprint.pprint(val_dataset[0])
print('\n', '-' * 100, '\n')

pprint.pprint(test_dataset[0])

{'class_label_to_name': ['Normal', 'Abnormal'],
 'class_name_to_label': {'Abnormal': 1, 'Normal': 0},
 'task_description': 'Classification of [Normal] and [Abnormal] symptoms',
 'task_name': 'CAUEEG-task2 benchmark'}

 ---------------------------------------------------------------------------------------------------- 

{'age': 60,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00137',
 'signal': array([[ 13.,  21.,  22., ...,   0.,   0.,   0.],
       [  9.,   8.,   8., ...,   0.,   0.,   0.],
       [  5.,   8.,   9., ...,   0.,   0.,   0.],
       ...,
       [ -5.,  -6.,  -3., ...,   0.,   0.,   0.],
       [-62., -48., -36., ...,   0.,   0.,   0.],
       [ -1.,  -1.,   0., ...,   0.,   0.,   0.]]),
 'symptom': ['normal', 'smi']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 72,
 'class_label': 1,
 'class_name': 'Abnormal',
 'serial': '00709',
 'signal': array([[-67., -55., -54., ...,   0.,   0.,   0.],
  

### Event information

In [8]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task1',
                                                                                  load_event=True, 
                                                                                  file_format='edf', 
                                                                                  transform=None)
pprint.pprint(train_dataset[0])

{'age': 88,
 'class_label': 2,
 'class_name': 'Dementia',
 'event': [[0, 'Start Recording'],
           [0, 'New Montage - Montage 005'],
           [36596, 'Eyes Open'],
           [41382, 'Move'],
           [41740, 'Move'],
           [72212, 'Eyes Closed'],
           [75866, 'Eyes Open'],
           [76706, 'Eyes Closed'],
           [78512, 'Eyes Open'],
           [79352, 'Eyes Closed'],
           [80446, 'Photic On - 3.0 Hz'],
           [80738, 'Eyes Open'],
           [81494, 'Eyes Closed'],
           [82464, 'Photic Off'],
           [84522, 'Photic On - 6.0 Hz'],
           [84816, 'Eyes Open'],
           [85530, 'Eyes Closed'],
           [86538, 'Photic Off'],
           [88554, 'Photic On - 9.0 Hz'],
           [90570, 'Photic Off'],
           [90780, 'Eyes Open'],
           [91578, 'Eyes Closed'],
           [92628, 'Photic On - 12.0 Hz'],
           [94644, 'Photic Off'],
           [96702, 'Photic On - 15.0 Hz'],
           [98718, 'Photic Off'],
           [1007

### Data Format: `EDF`

In [9]:
%%time
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='edf', 
                                                                                  transform=None)

print(train_dataset[0])
print(train_dataset[1])

{'serial': '00137', 'age': 60, 'symptom': ['normal', 'smi'], 'class_name': 'Normal', 'class_label': 0, 'signal': array([[ 13.,  21.,  22., ...,   0.,   0.,   0.],
       [  9.,   8.,   8., ...,   0.,   0.,   0.],
       [  5.,   8.,   9., ...,   0.,   0.,   0.],
       ...,
       [ -5.,  -6.,  -3., ...,   0.,   0.,   0.],
       [-62., -48., -36., ...,   0.,   0.,   0.],
       [ -1.,  -1.,   0., ...,   0.,   0.,   0.]])}
{'serial': '00526', 'age': 73, 'symptom': ['dementia', 'ad', 'load'], 'class_name': 'Abnormal', 'class_label': 1, 'signal': array([[-21.,  -3.,   6., ...,   0.,   0.,   0.],
       [ 13.,  10.,  11., ...,   0.,   0.,   0.],
       [ -7., -10., -12., ...,   0.,   0.,   0.],
       ...,
       [  0.,   0.,  -1., ...,   0.,   0.,   0.],
       [ 13.,  22.,  14., ...,   0.,   0.,   0.],
       [  1.,   2.,   2., ...,   0.,   0.,   0.]])}
CPU times: total: 234 ms
Wall time: 251 ms


### Data Format: `PyArrow Feather`

In [10]:
%%time
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='feather', 
                                                                                  transform=None)

print(train_dataset[0])
print(train_dataset[1])

{'serial': '00137', 'age': 60, 'symptom': ['normal', 'smi'], 'class_name': 'Normal', 'class_label': 0, 'signal': array([[ 13,  21,  22, ...,  32,  30,  27],
       [  9,   8,   8, ...,  -4,  -3,  -4],
       [  5,   8,   9, ...,  -4,  -6,  -9],
       ...,
       [ -5,  -6,  -3, ...,  16,  15,  13],
       [-62, -48, -36, ..., -78, -77, -76],
       [ -1,  -1,   0, ...,   0,   0,   0]])}
{'serial': '00526', 'age': 73, 'symptom': ['dementia', 'ad', 'load'], 'class_name': 'Abnormal', 'class_label': 1, 'signal': array([[-21,  -3,   6, ...,  -7, -11, -10],
       [ 13,  10,  11, ...,   5,   5,   7],
       [ -7, -10, -12, ...,   2,   1,   1],
       ...,
       [  0,   0,  -1, ...,   3,   4,   3],
       [ 13,  22,  14, ...,  -3,  -8,  -9],
       [  1,   2,   2, ...,   0,   1,   1]])}
CPU times: total: 359 ms
Wall time: 244 ms


### Data Format: `NumPy Memmap`

In [11]:
%%time
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='memmap', 
                                                                                  transform=None)

print(train_dataset[0])
print(train_dataset[1])

{'serial': '00137', 'age': 60, 'symptom': ['normal', 'smi'], 'class_name': 'Normal', 'class_label': 0, 'signal': memmap([[ 13,  21,  22, ...,  32,  30,  27],
        [  9,   8,   8, ...,  -4,  -3,  -4],
        [  5,   8,   9, ...,  -4,  -6,  -9],
        ...,
        [ -5,  -6,  -3, ...,  16,  15,  13],
        [-62, -48, -36, ..., -78, -77, -76],
        [ -1,  -1,   0, ...,   0,   0,   0]])}
{'serial': '00526', 'age': 73, 'symptom': ['dementia', 'ad', 'load'], 'class_name': 'Abnormal', 'class_label': 1, 'signal': memmap([[-21,  -3,   6, ...,  -7, -11, -10],
        [ 13,  10,  11, ...,   5,   5,   7],
        [ -7, -10, -12, ...,   2,   1,   1],
        ...,
        [  0,   0,  -1, ...,   3,   4,   3],
        [ 13,  22,  14, ...,  -3,  -8,  -9],
        [  1,   2,   2, ...,   0,   1,   1]])}
CPU times: total: 0 ns
Wall time: 4 ms


---

## PyTorch Transforms

### Random crop

In [12]:
transform = EegRandomCrop(crop_length=100)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path,
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=transform)
for i in range(2):
    d = train_dataset[0]
    pprint.pprint(d)
    print()
    print('>>> signal shape:', d['signal'].shape)
    print('\n', '-' * 100, '\n')

{'age': 60,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00137',
 'signal': array([[  53,   56,   58, ...,   75,   72,   73],
       [  -3,   -2,   -2, ...,   -1,   -3,   -4],
       [ -21,  -23,  -22, ...,  -21,  -22,  -22],
       ...,
       [ -16,  -17,  -18, ...,  -16,  -17,  -18],
       [ -39,  -36,  -32, ...,   52,   57,   57],
       [ 411,  779,  751, ..., -128, -126, -124]]),
 'symptom': ['normal', 'smi']}

>>> signal shape: (21, 100)

 ---------------------------------------------------------------------------------------------------- 

{'age': 60,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00137',
 'signal': array([[  9,   8,   8, ...,  56,  56,  58],
       [-11, -12, -10, ...,  -7,  -6,  -6],
       [-16, -15, -14, ..., -24, -24, -22],
       ...,
       [ -7,  -6,  -5, ..., -26, -25, -24],
       [-48, -55, -61, ...,  69,  63,  59],
       [679, 422, -40, ..., 347, 342, 283]]),
 'symptom': ['normal', 'smi']}

>>> signal shape: (21, 100)

 -------

### Random crop with multiple cropping

In [13]:
transform = EegRandomCrop(crop_length=200, multiple=2)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=transform)
for i in range(2):
    d = train_dataset[0]
    pprint.pprint(d)
    print()
    print('>>> signal shape:', [signal.shape for signal in d['signal']])
    print('\n', '-' * 100, '\n')

{'age': 60,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00137',
 'signal': [array([[  22,   20,   19, ...,   29,   27,   25],
       [  -8,  -10,   -8, ...,   12,    8,    6],
       [ -13,  -13,  -12, ...,    1,    1,    0],
       ...,
       [  -9,   -9,   -9, ...,  -15,   -6,    2],
       [  -4,   -4,   -4, ..., -507, -754, -818],
       [  -1,   -1,   -1, ...,    0,   -1,   -1]]),
            array([[ -8,  -6,  -8, ...,   3,   2,   2],
       [ -3,  -4,  -3, ...,  -3,  -4,  -1],
       [  8,   5,   3, ...,  -1,  -1,   2],
       ...,
       [  6,  -1,  -6, ...,  10,  12,  13],
       [ 33,  28,  24, ..., -34, -25, -17],
       [  0,   0,   2, ...,  -3,  -1,   0]])],
 'symptom': ['normal', 'smi']}

>>> signal shape: [(21, 200), (21, 200)]

 ---------------------------------------------------------------------------------------------------- 

{'age': 60,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00137',
 'signal': [array([[  1,  -5,  -8, ...,   3,   3,   2

### Random crop with multiple cropping and latency

In [14]:
transform = EegRandomCrop(crop_length=300, multiple=3, latency=50000, return_timing=True)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=transform)
for i in range(2): 
    d = train_dataset[0]
    pprint.pprint(d)
    print()
    print('>>> signal shape:', [signal.shape for signal in d['signal']])
    print('\n', '-' * 100, '\n')

{'age': 60,
 'class_label': 0,
 'class_name': 'Normal',
 'crop_timing': [168177, 156969, 101014],
 'serial': '00137',
 'signal': [array([[ 47,  45,  44, ...,  -4,  -2,  -1],
       [ 21,  18,  14, ..., -12, -14, -14],
       [ 30,  29,  23, ...,  -2,  -4,  -4],
       ...,
       [-21, -19, -17, ..., -12, -12, -14],
       [-14, -16, -15, ..., -14, -13, -14],
       [  0,  -1,  -1, ...,   0,   0,   0]]),
            array([[  1,  -2,  -3, ...,  -4,  -4,  -3],
       [-18, -18, -18, ...,  -6,  -6,  -6],
       [-15, -16, -16, ..., -13,  -9,  -7],
       ...,
       [ -3,  -3,  -5, ...,   3,   0,  -1],
       [ 13,  14,  13, ...,  57,  40,  19],
       [ -1,  -1,  -1, ...,   0,   0,   0]]),
            array([[  -11,   -11,    -9, ...,     9,     9,    10],
       [  -12,    -8,    -5, ...,    -6,    -5,    -3],
       [   -3,    -3,    -1, ...,     2,     3,     5],
       ...,
       [    8,     7,     7, ...,    -6,    -6,    -8],
       [  -40,   -15,    12, ...,    -7,    -6,    -7]

### Random crop with multiple cropping, latency, and max length limit

In [15]:
transform = transforms.Compose([
    EegRandomCrop(crop_length=200, 
                  length_limit=50300,
                  multiple=3, 
                  latency=50000, 
                  return_timing=True)
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task1',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=transform)
for i in range(2):
    d = train_dataset[0]
    pprint.pprint(d)
    print()
    print('>>> signal shape:', [signal.shape for signal in d['signal']])
    print('\n', '-' * 100, '\n')

{'age': 88,
 'class_label': 2,
 'class_name': 'Dementia',
 'crop_timing': [50029, 50000, 50044],
 'serial': '01379',
 'signal': [array([[-26, -26, -27, ..., -69, -68, -66],
       [-66, -64, -67, ..., -58, -58, -61],
       [  5,   7,   1, ...,   9,   9,   4],
       ...,
       [  1,  -1,  -5, ...,   6,   5,  -1],
       [ 22,  22,  14, ...,  19,  18,  10],
       [ -1,   0,  -1, ...,  -1,  -1,   0]]),
            array([[-28, -25, -24, ..., -70, -71, -66],
       [-60, -63, -64, ..., -62, -63, -63],
       [ -9, -13, -13, ...,   5,   8,   9],
       ...,
       [  1,  -4,  -5, ...,  11,  10,  10],
       [-35, -33, -28, ...,  -5,  -6,   2],
       [  0,  -1,   0, ...,  -1,  -1,  -1]]),
            array([[-28, -27, -19, ..., -57, -55, -46],
       [-75, -72, -64, ..., -52, -50, -48],
       [-12, -12,  -8, ...,   5,   4,   5],
       ...,
       [ -2,  -3,  -3, ...,   1,   2,   5],
       [  4,   2,   6, ...,   2,   0,   6],
       [ -1,  -1,  -1, ...,   2,   0,  -1]])],
 'symptom': 

### Drop channel(s)

In [16]:
anno_path = os.path.join(data_path, 'annotation.json')
with open(anno_path, 'r') as json_file:
    annotation = json.load(json_file)
signal_headers = annotation['signal_header']
del annotation
print(signal_headers)

channel_ekg = signal_headers.index('EKG')
print('channel_ekg: ', channel_ekg)

channel_photic = signal_headers.index('Photic')
print('channel_photic: ', channel_photic)

['Fp1-AVG', 'F3-AVG', 'C3-AVG', 'P3-AVG', 'O1-AVG', 'Fp2-AVG', 'F4-AVG', 'C4-AVG', 'P4-AVG', 'O2-AVG', 'F7-AVG', 'T3-AVG', 'T5-AVG', 'F8-AVG', 'T4-AVG', 'T6-AVG', 'FZ-AVG', 'CZ-AVG', 'PZ-AVG', 'EKG', 'Photic']
channel_ekg:  19
channel_photic:  20


In [17]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task1',
                                                                                  load_event=False, 
                                                                                  file_format='feather', 
                                                                                  transform=None)
print('before:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

print()
print('-' * 100)
print()
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task1',
                                                                                  load_event=False, 
                                                                                  file_format='feather', 
                                                                                  transform=EegDropChannels(channel_ekg))
print('after:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

before: (21, 173000)
[[ -1  -4 -35 ... -20 -16 -11]
 [ 11  18 -31 ...  -1   3   6]
 [ 18  30  -7 ...  -4  -1   0]
 ...
 [ 20  33 -12 ...   2   4   4]
 [ 91 138  23 ...  13  16  21]
 [  0  -1  -1 ...   0   0   0]]

----------------------------------------------------------------------------------------------------

after: (20, 173000)
[[ -1  -4 -35 ... -20 -16 -11]
 [ 11  18 -31 ...  -1   3   6]
 [ 18  30  -7 ...  -4  -1   0]
 ...
 [ 22  38   0 ... -13 -10 -10]
 [ 20  33 -12 ...   2   4   4]
 [  0  -1  -1 ...   0   0   0]]


In [18]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task1',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=None)
print('before:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

print()
print('-' * 100)
print()
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task1',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=EegDropChannels(channel_photic))
print('after:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

before: (21, 173000)
[[ -1  -4 -35 ... -20 -16 -11]
 [ 11  18 -31 ...  -1   3   6]
 [ 18  30  -7 ...  -4  -1   0]
 ...
 [ 20  33 -12 ...   2   4   4]
 [ 91 138  23 ...  13  16  21]
 [  0  -1  -1 ...   0   0   0]]

----------------------------------------------------------------------------------------------------

after: (20, 173000)
[[ -1  -4 -35 ... -20 -16 -11]
 [ 11  18 -31 ...  -1   3   6]
 [ 18  30  -7 ...  -4  -1   0]
 ...
 [ 22  38   0 ... -13 -10 -10]
 [ 20  33 -12 ...   2   4   4]
 [ 91 138  23 ...  13  16  21]]


In [19]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task1',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=None)
print('before:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

print()
print('-' * 100)
print()
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task1',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=EegDropChannels([channel_ekg, channel_photic]))
print('after:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

before: (21, 173000)
[[ -1  -4 -35 ... -20 -16 -11]
 [ 11  18 -31 ...  -1   3   6]
 [ 18  30  -7 ...  -4  -1   0]
 ...
 [ 20  33 -12 ...   2   4   4]
 [ 91 138  23 ...  13  16  21]
 [  0  -1  -1 ...   0   0   0]]

----------------------------------------------------------------------------------------------------

after: (19, 173000)
[[ -1  -4 -35 ... -20 -16 -11]
 [ 11  18 -31 ...  -1   3   6]
 [ 18  30  -7 ...  -4  -1   0]
 ...
 [ 13  20 -17 ...  -5   1   3]
 [ 22  38   0 ... -13 -10 -10]
 [ 20  33 -12 ...   2   4   4]]


### To Tensor

In [20]:
config_data, full_eeg_dataset = load_caueeg_full_dataset(dataset_path=data_path, 
                                                         load_event=False, 
                                                         file_format='feather',
                                                         transform=None)
print('Before:')
pprint.pprint(full_eeg_dataset[0])

print()
print('-' * 100)
print()

config_data, full_eeg_dataset = load_caueeg_full_dataset(dataset_path=data_path, 
                                                         load_event=False, 
                                                         file_format='feather',
                                                         transform=EegToTensor())
print('After:')
pprint.pprint(full_eeg_dataset[0])

Before:
{'age': 78,
 'serial': '00001',
 'signal': array([[  0, -11, -13, ...,  18,  21,  22],
       [ 29,  33,  34, ...,  -7,  -4,  -4],
       [ -3,  -6,  -3, ...,  -1,   1,   2],
       ...,
       [ -4,  -2,   1, ...,   0,  -1,  -1],
       [112,  67,  76, ..., -13, -15, -11],
       [ -1,  -1,  -1, ...,  -1,  -1,  -1]]),
 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}

----------------------------------------------------------------------------------------------------

After:
{'age': tensor(78.),
 'serial': '00001',
 'signal': tensor([[  0., -11., -13.,  ...,  18.,  21.,  22.],
        [ 29.,  33.,  34.,  ...,  -7.,  -4.,  -4.],
        [ -3.,  -6.,  -3.,  ...,  -1.,   1.,   2.],
        ...,
        [ -4.,  -2.,   1.,  ...,   0.,  -1.,  -1.],
        [112.,  67.,  76.,  ..., -13., -15., -11.],
        [ -1.,  -1.,  -1.,  ...,  -1.,  -1.,  -1.]]),
 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}


### Compose the above all in one

In [21]:
transform = transforms.Compose([
    EegRandomCrop(crop_length=200*10,       # crop: 10s
                  length_limit=200*60*10,   # length: 10m
                  multiple=4, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task1',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=transform)

pprint.pprint(train_dataset[0])

{'age': tensor(88.),
 'class_label': tensor(2),
 'class_name': 'Dementia',
 'serial': '01379',
 'signal': [tensor([[ -20.,  -27.,  -31.,  ...,  -28.,  -29.,  -29.],
        [ -11.,  -16.,  -17.,  ...,   23.,   24.,   26.],
        [   3.,    6.,    6.,  ...,  -10.,   -6.,   -2.],
        ...,
        [  -2.,    1.,    2.,  ...,   12.,   11.,   14.],
        [  -7.,   -7.,   -7.,  ...,   -1.,    1.,    2.],
        [  61.,  -75., -281.,  ..., -510., -517., -368.]]),
            tensor([[13., 15., 15.,  ..., 37., 38., 37.],
        [ 6., 11., 12.,  ..., 26., 28., 31.],
        [-1.,  2.,  2.,  ..., -4.,  1.,  4.],
        ...,
        [20., 22., 19.,  ..., -5.,  0.,  6.],
        [ 8.,  7.,  2.,  ..., -7., -4., -6.],
        [47., 52., 50.,  ..., 22., 36., 42.]]),
            tensor([[-55., -54., -55.,  ..., -14., -12.,  -7.],
        [ 51.,  50.,  47.,  ...,  14.,  18.,  25.],
        [  2.,  -3., -10.,  ...,  -5.,  -2.,   1.],
        ...,
        [ 14.,  13.,   8.,  ...,   2.,   3.,  

---

## PyTorch DataLoader

In [22]:
if device.type == 'cuda':
    num_workers = 0  # A number other than 0 causes an error
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False

In [23]:
transform = transforms.Compose([
    EegRandomCrop(crop_length=200*10,       # crop: 10s
                  length_limit=200*60*10,   # length: 10m
                  multiple=2, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])

config_data, full_eeg_dataset = load_caueeg_full_dataset(dataset_path=data_path, 
                                                         load_event=False, 
                                                         file_format='memmap',
                                                         transform=transform)

full_loader = DataLoader(full_eeg_dataset,
                         batch_size=4,
                         shuffle=True,
                         drop_last=True,
                         num_workers=num_workers,
                         pin_memory=pin_memory,
                         collate_fn=eeg_collate_fn)

for i_batch, sample_batched in enumerate(full_loader):
    pprint.pprint(sample_batched)
    break

{'age': tensor([70., 70., 79., 79., 83., 83., 67., 67.]),
 'serial': ['00389',
            '00389',
            '00723',
            '00723',
            '01035',
            '01035',
            '00939',
            '00939'],
 'signal': tensor([[[ 13.,   7.,  14.,  ..., -10.,  24.,  45.],
         [  9.,   4.,   6.,  ...,  15.,  16.,   6.],
         [ -2.,  -6.,  -9.,  ...,   0.,  -1.,   2.],
         ...,
         [  2.,  -1.,  -3.,  ...,  -2.,  -5.,   0.],
         [ -1.,  -1.,   0.,  ...,  -5., -11.,  -6.],
         [-45., -45., -41.,  ...,  -6., -11., -15.]],

        [[ -6., -12., -10.,  ..., -33., -35., -36.],
         [  6.,   6.,   5.,  ..., -20., -20., -17.],
         [  5.,   6.,   7.,  ...,   5.,  -2.,  -7.],
         ...,
         [  9.,  11.,  12.,  ...,  -3.,  -7.,  -7.],
         [ 15.,  15.,  16.,  ...,  22.,  21.,  15.],
         [-66., -72., -78.,  ...,  78.,  83.,  85.]],

        [[-36., -31., -27.,  ...,  15.,  15.,  16.],
         [-11., -11., -10.,  ...,  -3.,  

In [24]:
transform = transforms.Compose([
    EegRandomCrop(crop_length=200*10,       # crop: 10s
                  length_limit=200*60*10,   # length: 10m
                  multiple=2, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task1',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=8,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for i_batch, sample_batched in enumerate(train_loader):
    pprint.pprint(sample_batched, width=250)
    break

{'age': tensor([85., 85., 73., 73., 92., 92., 73., 73., 66., 66., 75., 75., 61., 61.,
        84., 84.]),
 'class_label': tensor([2, 2, 0, 0, 2, 2, 1, 1, 0, 0, 1, 1, 0, 0, 2, 2]),
 'class_name': ['Dementia', 'Dementia', 'Normal', 'Normal', 'Dementia', 'Dementia', 'MCI', 'MCI', 'Normal', 'Normal', 'MCI', 'MCI', 'Normal', 'Normal', 'Dementia', 'Dementia'],
 'serial': ['00630', '00630', '00903', '00903', '01128', '01128', '00091', '00091', '00704', '00704', '00811', '00811', '00767', '00767', '00039', '00039'],
 'signal': tensor([[[-22., -25., -25.,  ...,  28.,  23.,  36.],
         [-18., -17., -11.,  ...,  17.,  16.,  11.],
         [  9.,  11.,  11.,  ...,   9.,   9.,  10.],
         ...,
         [ 14.,   9.,   5.,  ...,  -6.,  -6.,  -9.],
         [ 14.,  13.,  14.,  ..., -11., -12., -12.],
         [ -9.,  -2.,   4.,  ...,  12.,  12.,  19.]],

        [[  2.,   3.,   6.,  ..., -13., -13., -14.],
         [ 12.,  14.,  18.,  ..., -11., -12., -11.],
         [  1.,   2.,   5.,  ...,  

---

## Preprocessing steps run by the PyTorch Modules

In [25]:
transform = transforms.Compose([
    EegRandomCrop(crop_length=200*10,       # crop: 10s
                  length_limit=200*60*10,   # length: 10m
                  multiple=2, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=2,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

### To GPU device if it is possible

In [26]:
print('device:', device)
print()

preprocess_train = transforms.Compose([EegToDevice(device=device)])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    pprint.pprint(sample_batched)

    print()
    print('-' * 100)
    print()
    
    preprocess_train(sample_batched)
    
    print('- After -')
    pprint.pprint(sample_batched)
    break

device: cuda

Sequential(
  (0): EegToDevice(device=cuda)
)
- Before -
{'age': tensor([73., 73., 76., 76.]),
 'class_label': tensor([0, 0, 1, 1]),
 'class_name': ['Normal', 'Normal', 'Abnormal', 'Abnormal'],
 'serial': ['00086', '00086', '00218', '00218'],
 'signal': tensor([[[ -4.,  -2.,  -2.,  ...,  10.,   7.,   4.],
         [-15., -16., -17.,  ..., -33., -36., -38.],
         [-18., -15., -13.,  ...,   7.,   1.,  -1.],
         ...,
         [-10., -12., -12.,  ..., -26., -27., -31.],
         [  1.,   1.,   3.,  ...,   0.,   0.,  -2.],
         [-24.,  44., -41.,  ..., -52., -17.,  45.]],

        [[ -7.,  -9., -14.,  ...,  -5., -10., -13.],
         [-12., -16., -20.,  ...,  -4.,  -7.,  -8.],
         [ 12.,  12.,  13.,  ...,  -1.,  -2.,  -3.],
         ...,
         [  0.,  -6., -12.,  ...,  -3.,  -4.,  -4.],
         [  9.,  11.,  11.,  ...,  -3.,  -1.,   0.],
         [-40., -12.,  54.,  ..., -49.,  13.,   0.]],

        [[  1.,  -2.,   1.,  ..., -37., -32., -43.],
         [-

### Normalization per signal

In [27]:
preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizePerSignal()
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    print('Mean:', torch.mean(sample_batched['signal'], axis=-1))
    print()
    print('Std:', torch.std(sample_batched['signal'], axis=-1))

    print()
    print('-' * 100)
    print()
    
    preprocess_train(sample_batched)
    
    print('- After -')
    print('Mean:', torch.mean(sample_batched['signal'], axis=-1))
    print()
    print('Std:', torch.std(sample_batched['signal'], axis=-1))
    break

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizePerSignal(eps=1e-08)
)
- Before -
Mean: tensor([[-3.4385, -6.0560, -0.6960, -0.1600,  0.4090, -4.8755, -1.3905, -1.3640,
         -0.1860,  0.5235,  1.2925, -0.0290,  0.2905,  1.5840, -2.2205,  0.6550,
          6.2580,  2.0415, -0.9345, -0.2405],
        [ 2.4240,  3.6905, -0.5770, -0.3915, -0.2865, -2.7200, -0.5115,  0.8420,
         -0.5460, -0.0955,  0.1885, -1.9045,  0.3300, -2.4885,  2.3225,  0.0200,
          1.9420, -2.5160, -0.0480,  0.0365],
        [ 2.0325,  1.3565, -0.5185, -0.2350, -0.2260,  1.4070,  2.4470, -1.9080,
          0.5040,  0.5730, -0.4885,  0.6505,  1.3050, -3.2440, -1.8865,  1.0555,
         -0.9300,  2.0435, -0.4770,  0.2650],
        [-2.3570,  2.4550, -0.3360,  0.0635, -0.4625, -5.5255, -1.3445,  1.5420,
          0.0700, -0.4360, -4.1820,  0.4955, -0.4190, -1.8970,  2.6785, -0.2130,
          0.6095,  1.1950,  0.1090,  1.0620]])

Std: tensor([[ 63.7045,  20.2194,  10.6914,   8.9998,  10.3299,

### Signal normalization using the specified mean and std values

In [28]:
signal_mean, signal_std = calculate_signal_statistics(train_loader, repeats=1, verbose=True)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    print('Mean:', torch.mean(sample_batched['signal'], axis=-1))
    print()
    print('Std:', torch.std(sample_batched['signal'], axis=-1))

    print()
    print('-' * 100)
    print()
    
    preprocess_train(sample_batched)
    
    print('- After -')
    print('Mean:', torch.mean(sample_batched['signal'], axis=-1))
    print()
    print('Std:', torch.std(sample_batched['signal'], axis=-1))
    break

Mean and standard deviation for signal:
tensor([[[ 0.2844],
         [ 0.2253],
         [-0.0821],
         [ 0.0339],
         [-0.0259],
         [ 0.4161],
         [ 0.1093],
         [-0.0760],
         [-0.1932],
         [-0.1694],
         [-0.0005],
         [ 0.0662],
         [-0.0377],
         [ 0.1298],
         [-0.2924],
         [-0.1287],
         [ 0.2746],
         [ 0.0489],
         [-0.0198],
         [-0.0386]]])
-
tensor([[[44.3354],
         [19.7560],
         [11.2161],
         [10.9679],
         [15.1231],
         [45.9360],
         [19.8389],
         [10.3051],
         [11.4575],
         [15.2745],
         [20.2380],
         [13.9408],
         [12.9799],
         [21.6595],
         [17.7961],
         [14.4634],
         [19.4985],
         [10.7546],
         [10.9364],
         [96.2515]]])

----------------------------------------------------------------------------------------------------

Sequential(
  (0): EegToDevice(device=cuda)
  (1): 

### Age normalization

In [29]:
age_mean, age_std = calculate_age_statistics(train_loader, verbose=True)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    pprint.pprint(sample_batched['age'])

    print()
    print('-' * 100)
    print()
    
    preprocess_train(sample_batched)
    
    print('- After -')
    pprint.pprint(sample_batched['age'])
    break

Age mean and standard deviation:
tensor([70.6000]) tensor([6.2312])

----------------------------------------------------------------------------------------------------

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.6000]),std=tensor([6.2312]),eps=1e-08)
)
- Before -
tensor([69., 69., 65., 65.])

----------------------------------------------------------------------------------------------------

- After -
tensor([-0.2568, -0.2568, -0.8987, -0.8987], device='cuda:0')


### Short time Fourier transform (STFT or spectrogram)

In [30]:
preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegSpectrogram(n_fft=200, complex_mode='as_real')
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    pprint.pprint(sample_batched['signal'].shape)

    print()
    print('-' * 100)
    print()
    
    preprocess_train(sample_batched)
    
    print('- After -')
    pprint.pprint(sample_batched['signal'].shape)
    break

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegSpectrogram(n_fft=200, complex_mode=as_real, stft_kwargs={})
)
- Before -
torch.Size([4, 20, 2000])

----------------------------------------------------------------------------------------------------

- After -
torch.Size([4, 40, 101, 41])


### Signal normalization after STFT

In [31]:
preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegSpectrogram(n_fft=200, complex_mode='as_real')
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)

signal_2d_mean, signal_2d_std = calculate_signal_statistics(train_loader, preprocess_train)

preprocess_train2 = transforms.Compose([
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std)
])
preprocess_train2 = torch.nn.Sequential(*preprocess_train2.transforms).to(device)

pprint.pprint(preprocess_train)
pprint.pprint(preprocess_train2)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    preprocess_train(sample_batched)   
    
    print('Mean:', torch.mean(sample_batched['signal'], axis=-1))
    print()
    print('Std:', torch.std(sample_batched['signal'], axis=-1))
    
    print()
    print('-' * 100)
    print()
    
    print('- After -')
    preprocess_train2(sample_batched)
    
    print('Mean:', torch.mean(sample_batched['signal'], axis=-1))
    print()
    print('Std:', torch.std(sample_batched['signal'], axis=-1))
    break

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegSpectrogram(n_fft=200, complex_mode=as_real, stft_kwargs={})
)
Sequential(
  (0): EegNormalizeMeanStd(mean=tensor([[ 4.9821e+01,  1.4474e+00, -7.2085e-02,  ...,  9.7853e-02,
            1.0077e-01,  9.0708e-02],
          [ 9.2374e+00,  1.0931e-01,  1.5210e-01,  ...,  4.1250e-02,
            4.1423e-02,  5.0705e-02],
          [-1.6948e+00,  2.6213e-01, -1.0001e-01,  ...,  7.0841e-03,
            7.5616e-03, -1.3990e-02],
          ...,
          [ 0.0000e+00,  4.0039e-01,  2.2454e-01,  ..., -2.7715e-04,
           -1.2458e-03, -1.4335e-08],
          [ 0.0000e+00,  7.9169e-01,  4.4377e-01,  ...,  1.3421e-03,
            2.7339e-03, -2.6603e-09],
          [ 0.0000e+00, -1.8461e+00, -5.4384e-01,  ..., -8.0143e-04,
           -1.3102e-03, -8.9033e-08]], device='cuda:0'),std=tensor([[7.4402e+03, 1.5207e+03, 8.0498e+02,  ..., 2.9127e+01, 2.9083e+01,
           2.9303e+01],
          [3.2419e+03, 5.8653e+02, 3.0239e+02,  ..., 1.3144e+01

---

## Speed check without STFT

In [32]:
crop_length = 200 * 10
multiple = 4
batch_size = 128

### `EDF`

In [33]:
%%time
transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='edf',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegRandomCrop(crop_length=2000, length_limit=120000, multiple=4, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.6000]),std=tensor([6.2312]),eps=1e-08)
  (2): EegNormalizeMeanStd(mean=tensor([ 0.2844,  0.2253, -0.0821,  0.0339, -0.0259,  0.4161,  0.1093, -0.0760,
          -0.1932, -0.1694, -0.0005,  0.0662, -0.0377,  0.1298, -0.2924, -0.1287,
           0.2746,  0.0489, -0.0198, -0.0386]),std=tensor([44.3354, 19.7560, 11.2161, 10.9679, 15.1231, 45.9360, 19.8389, 10.3051,
          11.4575, 15.2745, 20.2380, 13.9408, 12.9799, 21.6595, 17.7961, 14.4634,
          19.4985, 10.7546, 10.9364, 96.2515]),eps=1e-08)
)
CPU times: total: 26min 59s
Wall time: 2min 15s


### `Feather`

In [34]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegRandomCrop(crop_length=2000, length_limit=120000, multiple=4, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.6000]),std=tensor([6.2312]),eps=1e-08)
  (2): EegNormalizeMeanStd(mean=tensor([ 0.2844,  0.2253, -0.0821,  0.0339, -0.0259,  0.4161,  0.1093, -0.0760,
          -0.1932, -0.1694, -0.0005,  0.0662, -0.0377,  0.1298, -0.2924, -0.1287,
           0.2746,  0.0489, -0.0198, -0.0386]),std=tensor([44.3354, 19.7560, 11.2161, 10.9679, 15.1231, 45.9360, 19.8389, 10.3051,
          11.4575, 15.2745, 20.2380, 13.9408, 12.9799, 21.6595, 17.7961, 14.4634,
          19.4985, 10.7546, 10.9364, 96.2515]),eps=1e-08)
)
CPU times: total: 1min 11s
Wall time: 4.47 s


### `memmap`

In [35]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)
 
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegRandomCrop(crop_length=2000, length_limit=120000, multiple=4, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.6000]),std=tensor([6.2312]),eps=1e-08)
  (2): EegNormalizeMeanStd(mean=tensor([ 0.2844,  0.2253, -0.0821,  0.0339, -0.0259,  0.4161,  0.1093, -0.0760,
          -0.1932, -0.1694, -0.0005,  0.0662, -0.0377,  0.1298, -0.2924, -0.1287,
           0.2746,  0.0489, -0.0198, -0.0386]),std=tensor([44.3354, 19.7560, 11.2161, 10.9679, 15.1231, 45.9360, 19.8389, 10.3051,
          11.4575, 15.2745, 20.2380, 13.9408, 12.9799, 21.6595, 17.7961, 14.4634,
          19.4985, 10.7546, 10.9364, 96.2515]),eps=1e-08)
)
CPU times: total: 17.8 s
Wall time: 1.47 s


### `memmap` (Drop → Crop)

In [36]:
%%time

transform = transforms.Compose([
    EegDropChannels(channel_photic), 
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)
 
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegDropChannels(drop_index=20)
    EegRandomCrop(crop_length=2000, length_limit=120000, multiple=4, latency=2000, return_timing=False)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.6000]),std=tensor([6.2312]),eps=1e-08)
  (2): EegNormalizeMeanStd(mean=tensor([ 0.2844,  0.2253, -0.0821,  0.0339, -0.0259,  0.4161,  0.1093, -0.0760,
          -0.1932, -0.1694, -0.0005,  0.0662, -0.0377,  0.1298, -0.2924, -0.1287,
           0.2746,  0.0489, -0.0198, -0.0386]),std=tensor([44.3354, 19.7560, 11.2161, 10.9679, 15.1231, 45.9360, 19.8389, 10.3051,
          11.4575, 15.2745, 20.2380, 13.9408, 12.9799, 21.6595, 17.7961, 14.4634,
          19.4985, 10.7546, 10.9364, 96.2515]),eps=1e-08)
)
CPU times: total: 1min 17s
Wall time: 6.47 s


---

## Speed check with STFT

In [37]:
crop_length = 300 * 10
n_fft, hop_length, seq_len_2d = calculate_stft_params(seq_length=crop_length, verbose=True)
multiple = 2
batch_size = 128

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegSpectrogram(n_fft=n_fft, hop_length=hop_length, complex_mode='as_real')
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
signal_2d_mean, signal_2d_std = calculate_signal_statistics(train_loader, preprocess_train)

Input sequence length: (3000) would become (78, 77) after the STFT with n_fft (155) and hop_length (39).

----------------------------------------------------------------------------------------------------



### `EDF`

In [38]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='edf',
                                                                                  transform=transform)
 
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegSpectrogram(n_fft=n_fft, hop_length=hop_length, complex_mode='as_real'),
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)
    size = sample_batched['signal'].size()
    
print(size)

Compose(
    EegRandomCrop(crop_length=3000, length_limit=120000, multiple=2, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.6000]),std=tensor([6.2312]),eps=1e-08)
  (2): EegSpectrogram(n_fft=155, complex_mode=as_real, stft_kwargs={'hop_length': 39})
  (3): EegNormalizeMeanStd(mean=tensor([[ 3.1320e+01,  3.4730e-01,  1.6575e-01,  ..., -3.2248e-02,
           -3.3008e-02, -3.1671e-02],
          [-1.2258e+01,  2.0957e-02,  7.7337e-02,  ..., -1.5162e-03,
           -6.2796e-03, -5.6553e-03],
          [-1.3421e-01,  1.5528e-02,  2.7589e-02,  ...,  4.3992e-04,
           -1.0055e-03, -1.1822e-03],
          ...,
          [ 0.0000e+00,  7.0140e-01,  3.1880e-01,  ..., -6.5730e-04,
            1.5833e-03, -2.0423e-04],
          [ 0.0000e+00,  2.2840e-01,  9.8081e-02,  ..., -1.5298e-03,
            1.5692e-03, -6.6039e-05],
          [ 0.0000e+00,  5.5736e-02,  8.682

### `Feather`

In [39]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegSpectrogram(n_fft=n_fft, hop_length=hop_length, complex_mode='as_real'),
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)
    size = sample_batched['signal'].size()
    
print(size)

Compose(
    EegRandomCrop(crop_length=3000, length_limit=120000, multiple=2, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.6000]),std=tensor([6.2312]),eps=1e-08)
  (2): EegSpectrogram(n_fft=155, complex_mode=as_real, stft_kwargs={'hop_length': 39})
  (3): EegNormalizeMeanStd(mean=tensor([[ 3.1320e+01,  3.4730e-01,  1.6575e-01,  ..., -3.2248e-02,
           -3.3008e-02, -3.1671e-02],
          [-1.2258e+01,  2.0957e-02,  7.7337e-02,  ..., -1.5162e-03,
           -6.2796e-03, -5.6553e-03],
          [-1.3421e-01,  1.5528e-02,  2.7589e-02,  ...,  4.3992e-04,
           -1.0055e-03, -1.1822e-03],
          ...,
          [ 0.0000e+00,  7.0140e-01,  3.1880e-01,  ..., -6.5730e-04,
            1.5833e-03, -2.0423e-04],
          [ 0.0000e+00,  2.2840e-01,  9.8081e-02,  ..., -1.5298e-03,
            1.5692e-03, -6.6039e-05],
          [ 0.0000e+00,  5.5736e-02,  8.682

### `memmap`

In [40]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegSpectrogram(n_fft=n_fft, hop_length=hop_length, complex_mode='as_real'),
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)
    size = sample_batched['signal'].size()
    
print(size)

Compose(
    EegRandomCrop(crop_length=3000, length_limit=120000, multiple=2, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.6000]),std=tensor([6.2312]),eps=1e-08)
  (2): EegSpectrogram(n_fft=155, complex_mode=as_real, stft_kwargs={'hop_length': 39})
  (3): EegNormalizeMeanStd(mean=tensor([[ 3.1320e+01,  3.4730e-01,  1.6575e-01,  ..., -3.2248e-02,
           -3.3008e-02, -3.1671e-02],
          [-1.2258e+01,  2.0957e-02,  7.7337e-02,  ..., -1.5162e-03,
           -6.2796e-03, -5.6553e-03],
          [-1.3421e-01,  1.5528e-02,  2.7589e-02,  ...,  4.3992e-04,
           -1.0055e-03, -1.1822e-03],
          ...,
          [ 0.0000e+00,  7.0140e-01,  3.1880e-01,  ..., -6.5730e-04,
            1.5833e-03, -2.0423e-04],
          [ 0.0000e+00,  2.2840e-01,  9.8081e-02,  ..., -1.5298e-03,
            1.5692e-03, -6.6039e-05],
          [ 0.0000e+00,  5.5736e-02,  8.682

---

## Test on longer sequence

In [41]:
%%time
longer_transform = transforms.Compose([
    EegRandomCrop(crop_length=200*10*6,     # crop: 1m
                  length_limit=200*60*10,   # length: 10m
                  multiple=2, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(longer_transform)

config_data, longer_test_dataset = load_caueeg_task_split(dataset_path=data_path, 
                                                          task='task2', 
                                                          split='test',
                                                          load_event=False,
                                                          file_format='feather', 
                                                          transform=longer_transform)

longer_test_loader = DataLoader(longer_test_dataset,
                                batch_size=32,
                                shuffle=True,
                                drop_last=False,
                                num_workers=num_workers,
                                pin_memory=pin_memory,
                                collate_fn=eeg_collate_fn)
 
preprocess_test = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
    EegNormalizeAge(mean=age_mean, std=age_std),
])
preprocess_test = torch.nn.Sequential(*preprocess_test.transforms).to(device)
pprint.pprint(preprocess_test)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_test(sample_batched)

Compose(
    EegRandomCrop(crop_length=12000, length_limit=120000, multiple=2, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeMeanStd(mean=tensor([ 0.2844,  0.2253, -0.0821,  0.0339, -0.0259,  0.4161,  0.1093, -0.0760,
          -0.1932, -0.1694, -0.0005,  0.0662, -0.0377,  0.1298, -0.2924, -0.1287,
           0.2746,  0.0489, -0.0198, -0.0386]),std=tensor([44.3354, 19.7560, 11.2161, 10.9679, 15.1231, 45.9360, 19.8389, 10.3051,
          11.4575, 15.2745, 20.2380, 13.9408, 12.9799, 21.6595, 17.7961, 14.4634,
          19.4985, 10.7546, 10.9364, 96.2515]),eps=1e-08)
  (2): EegNormalizeAge(mean=tensor([70.6000]),std=tensor([6.2312]),eps=1e-08)
)
CPU times: total: 12 s
Wall time: 1.02 s
