# Dataset and DataLoader

This notebook loads the `CAUEEG` dataset, demonstrates several useful preprocessing transforms, and builds the PyTorch DataLoader instances used for training.

-----

## Configurations

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
# Mode 2 reloads imported modules before each cell runs, so edits to the
# project package imported in the next cell (datasets.*) take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Load some packages
import os
import glob
import json
import pprint

import numpy as np
import random
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

# custom package
from datasets.caueeg_dataset import *
from datasets.caueeg_script import *
from datasets.pipeline import *

In [3]:
print('PyTorch version:', torch.__version__)

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if cuda_ok else 'cpu')
print('cuda is available.' if cuda_ok else 'cuda is unavailable.')

PyTorch version: 1.11.0+cu113
cuda is available.


In [4]:
# Data file path (relative to the notebook's working directory)
data_path = r'local/dataset/02_Curated_Data_220412/'

# Seed the global RNGs up front so that the random crops and the shuffled
# DataLoader batches below are reproducible on Restart & Run All.
# NOTE(review): assumes EegRandomCrop draws from the numpy / torch / `random`
# global RNGs; confirm in datasets.pipeline.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

-----

## Load the CAUEEG dataset

### Load the whole CAUEEG dataset as a single PyTorch dataset instance, independent of any target task (no train/val/test split and no class labels).

In [5]:
config_data, full_eeg_dataset = load_caueeg_full_dataset(
    dataset_path=data_path,
    load_event=False,
    file_format='edf',
    transform=None,
)

# Dataset-level config first, then the first two samples; every printout is
# followed by a horizontal rule.
separator = '-' * 100
pprint.pprint(config_data, width=250)
print('\n', separator, '\n')

for index in range(2):
    pprint.pprint(full_eeg_dataset[index])
    print('\n', separator, '\n')

{'dataset_name': 'CAUEEG dataset',
 'signal_header': ['Fp1-AVG', 'F3-AVG', 'C3-AVG', 'P3-AVG', 'O1-AVG', 'Fp2-AVG', 'F4-AVG', 'C4-AVG', 'P4-AVG', 'O2-AVG', 'F7-AVG', 'T3-AVG', 'T5-AVG', 'F8-AVG', 'T4-AVG', 'T6-AVG', 'FZ-AVG', 'CZ-AVG', 'PZ-AVG', 'EKG', 'Photic']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 78,
 'serial': '00001',
 'signal': array([[  0., -11., -13., ...,   0.,   0.,   0.],
       [ 29.,  33.,  34., ...,   0.,   0.,   0.],
       [ -3.,  -6.,  -3., ...,   0.,   0.,   0.],
       ...,
       [ -4.,  -2.,   1., ...,   0.,   0.,   0.],
       [112.,  67.,  76., ...,   0.,   0.,   0.],
       [ -1.,  -1.,  -1., ...,   0.,   0.,   0.]]),
 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 56,
 'serial': '00002',
 'signal': array([[  39.,   58.,   72., ...,    0.,    0.,    0.],
       [   4.,

### Load the CAUEEG task1 datasets (train/val/test) as PyTorch dataset instances.

In [6]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task1',
    load_event=False,
    file_format='edf',
    transform=None,
)

# Task config first, then the first sample of each split. Each sample is
# preceded by a horizontal rule, so no rule follows the last one.
pprint.pprint(config_data)
for split in (train_dataset, val_dataset, test_dataset):
    print('\n', '-' * 100, '\n')
    pprint.pprint(split[0])

{'class_label_to_name': ['Normal', 'MCI', 'Dementia'],
 'class_name_to_label': {'Dementia': 2, 'MCI': 1, 'Normal': 0},
 'task_description': 'Classification of [Normal], [MCI], and [Dementia] '
                     'symptoms.',
 'task_name': 'CAUEEG-task1 benchmark'}

 ---------------------------------------------------------------------------------------------------- 

{'age': 81,
 'class_label': 2,
 'class_name': 'Dementia',
 'serial': '01222',
 'signal': array([[  1.,   1.,  -1., ...,   0.,   0.,   0.],
       [-24., -14., -14., ...,   0.,   0.,   0.],
       [ -6.,  -5.,  -6., ...,   0.,   0.,   0.],
       ...,
       [ -1.,   6.,   6., ...,   0.,   0.,   0.],
       [-42., -62., -59., ...,   0.,   0.,   0.],
       [  0.,   0.,  -1., ...,   0.,   0.,   0.]]),
 'symptom': ['dementia', 'ad', 'load']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 74,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00857',
 'si

### Load the CAUEEG task2 datasets (train/val/test) as PyTorch dataset instances.

In [7]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task2',
    load_event=False,
    file_format='edf',
    transform=None,
)

# Task config first, then the first sample of each split. Each sample is
# preceded by a horizontal rule, so no rule follows the last one.
pprint.pprint(config_data)
for split in (train_dataset, val_dataset, test_dataset):
    print('\n', '-' * 100, '\n')
    pprint.pprint(split[0])

{'class_label_to_name': ['Normal', 'Abnormal'],
 'class_name_to_label': {'Abnormal': 1, 'Normal': 0},
 'task_description': 'Classification of [Normal] and [Abnormal] symptoms',
 'task_name': 'CAUEEG-task2 benchmark'}

 ---------------------------------------------------------------------------------------------------- 

{'age': 66,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00085',
 'signal': array([[ 39.,  87.,  85., ...,   0.,   0.,   0.],
       [ 17.,  64.,  67., ...,   0.,   0.,   0.],
       [ -1.,   9.,  12., ...,   0.,   0.,   0.],
       ...,
       [ 10.,   3.,   3., ...,   0.,   0.,   0.],
       [ 30.,  46., -30., ...,   0.,   0.,   0.],
       [  2.,   2.,   1., ...,   0.,   0.,   0.]]),
 'symptom': ['normal', 'cb_normal']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 79,
 'class_label': 1,
 'class_name': 'Abnormal',
 'serial': '01203',
 'signal': array([[117., 165., 163., ...,   0.,   0.,   0

### Event information

In [8]:
# Reload task1 with load_event=True so that each sample also carries an
# 'event' list of recording annotations.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task1',
    load_event=True,
    file_format='edf',
    transform=None,
)
first_train_sample = train_dataset[0]
pprint.pprint(first_train_sample)

{'age': 81,
 'class_label': 2,
 'class_name': 'Dementia',
 'event': [[0, 'Start Recording'],
           [0, 'New Montage - Montage 002'],
           [13926, 'Eyes Open'],
           [14562, 'Eyes Closed'],
           [16000, 'Paused'],
           [18000, 'Recording Resumed'],
           [20984, 'Eyes Open'],
           [21898, 'Eyes Closed'],
           [23258, 'Eyes Open'],
           [23866, 'Eyes Closed'],
           [38064, 'Eyes Open'],
           [53800, 'Paused'],
           [55600, 'Recording Resumed'],
           [62090, 'artifact'],
           [62800, 'Paused'],
           [64600, 'Recording Resumed'],
           [66332, 'Move'],
           [70290, 'Move'],
           [71122, 'Move'],
           [73030, 'Move'],
           [73762, 'Move'],
           [74600, 'Paused'],
           [86600, 'Recording Resumed'],
           [88970, 'swallowing'],
           [90020, 'Eyes Closed'],
           [91490, 'Eyes Open'],
           [92708, 'Eyes Closed'],
           [95124, 'Eyes Open'],

### Data Format: `EDF`

In [9]:
%%time
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='edf', 
                                                                                  transform=None)

print(train_dataset[0])
print(train_dataset[1])

{'serial': '00085', 'age': 66, 'symptom': ['normal', 'cb_normal'], 'class_name': 'Normal', 'class_label': 0, 'signal': array([[ 39.,  87.,  85., ...,   0.,   0.,   0.],
       [ 17.,  64.,  67., ...,   0.,   0.,   0.],
       [ -1.,   9.,  12., ...,   0.,   0.,   0.],
       ...,
       [ 10.,   3.,   3., ...,   0.,   0.,   0.],
       [ 30.,  46., -30., ...,   0.,   0.,   0.],
       [  2.,   2.,   1., ...,   0.,   0.,   0.]])}
{'serial': '00790', 'age': 59, 'symptom': ['normal', 'cb_normal'], 'class_name': 'Normal', 'class_label': 0, 'signal': array([[  45.,   -7.,   -9., ...,    0.,    0.,    0.],
       [   1.,   -2.,  -11., ...,    0.,    0.,    0.],
       [   2.,   -6.,   -5., ...,    0.,    0.,    0.],
       ...,
       [  -4.,    2.,    4., ...,    0.,    0.,    0.],
       [ -71., -115.,  -98., ...,    0.,    0.,    0.],
       [   0.,    0.,   -1., ...,    0.,    0.,    0.]])}
CPU times: total: 188 ms
Wall time: 189 ms


### Data Format: `PyArrow Feather`

In [10]:
%%time
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='feather', 
                                                                                  transform=None)

print(train_dataset[0])
print(train_dataset[1])

{'serial': '00085', 'age': 66, 'symptom': ['normal', 'cb_normal'], 'class_name': 'Normal', 'class_label': 0, 'signal': array([[ 39,  87,  85, ..., -52, -49, -45],
       [ 17,  64,  67, ..., -62, -58, -56],
       [ -1,   9,  12, ...,  -8,  -7,  -6],
       ...,
       [ 10,   3,   3, ...,  -1,   0,   2],
       [ 30,  46, -30, ...,  -6, -17, -15],
       [  2,   2,   1, ...,   1,   0,   0]])}
{'serial': '00790', 'age': 59, 'symptom': ['normal', 'cb_normal'], 'class_name': 'Normal', 'class_label': 0, 'signal': array([[  45,   -7,   -9, ...,   15,   18,   16],
       [   1,   -2,  -11, ...,  -15,  -14,  -14],
       [   2,   -6,   -5, ...,   12,   10,   11],
       ...,
       [  -4,    2,    4, ...,   -2,   -1,   -1],
       [ -71, -115,  -98, ...,    3,    5,   -2],
       [   0,    0,   -1, ...,    1,   -1,    0]])}
CPU times: total: 234 ms
Wall time: 245 ms


### Data Format: `NumPy Memmap`

In [11]:
%%time
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='memmap', 
                                                                                  transform=None)

print(train_dataset[0])
print(train_dataset[1])

{'serial': '00085', 'age': 66, 'symptom': ['normal', 'cb_normal'], 'class_name': 'Normal', 'class_label': 0, 'signal': memmap([[ 39,  87,  85, ..., -52, -49, -45],
        [ 17,  64,  67, ..., -62, -58, -56],
        [ -1,   9,  12, ...,  -8,  -7,  -6],
        ...,
        [ 10,   3,   3, ...,  -1,   0,   2],
        [ 30,  46, -30, ...,  -6, -17, -15],
        [  2,   2,   1, ...,   1,   0,   0]])}
{'serial': '00790', 'age': 59, 'symptom': ['normal', 'cb_normal'], 'class_name': 'Normal', 'class_label': 0, 'signal': memmap([[  45,   -7,   -9, ...,   15,   18,   16],
        [   1,   -2,  -11, ...,  -15,  -14,  -14],
        [   2,   -6,   -5, ...,   12,   10,   11],
        ...,
        [  -4,    2,    4, ...,   -2,   -1,   -1],
        [ -71, -115,  -98, ...,    3,    5,   -2],
        [   0,    0,   -1, ...,    1,   -1,    0]])}
CPU times: total: 0 ns
Wall time: 3.51 ms


---

## PyTorch Transforms

### Random crop

In [12]:
# Crop a random 100-sample window. The same sample is indexed twice so that
# the two printouts can be compared (they differ from one access to the next).
transform = EegRandomCrop(crop_length=100)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task2',
    load_event=False,
    file_format='feather',
    transform=transform,
)

for _ in range(2):
    d = train_dataset[0]
    pprint.pprint(d)
    print()
    print('>>> signal shape:', d['signal'].shape)
    print('\n', '-' * 100, '\n')

{'age': 66,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00085',
 'signal': array([[-33, -32, -33, ...,  28,  29,  29],
       [-34, -36, -37, ..., -30, -23, -19],
       [-12, -10, -10, ..., -11, -11, -10],
       ...,
       [-11, -11, -10, ...,  -6,  -4,  -2],
       [ -5,   8,   4, ...,  -2,   7,  -6],
       [ -1,  -1,   0, ...,   0,   0,  -1]]),
 'symptom': ['normal', 'cb_normal']}

>>> signal shape: (21, 100)

 ---------------------------------------------------------------------------------------------------- 

{'age': 66,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00085',
 'signal': array([[-65, -65, -65, ..., -40, -39, -41],
       [ 31,  33,  33, ...,  37,  36,  36],
       [ -7,  -6,  -3, ...,  19,  19,  20],
       ...,
       [-11, -10,  -9, ..., -11,  -9,  -9],
       [ -2,   5,  -2, ...,  -1,   6,  -5],
       [  0,  -1,   0, ...,   0,  -1,   2]]),
 'symptom': ['normal', 'cb_normal']}

>>> signal shape: (21, 100)

 -------------------------------

### Random crop with multiple cropping

In [13]:
# With multiple=2 the 'signal' entry becomes a list of two 200-sample crops
# taken from the same recording.
transform = EegRandomCrop(crop_length=200, multiple=2)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task2',
    load_event=False,
    file_format='feather',
    transform=transform,
)

for _ in range(2):
    d = train_dataset[0]
    pprint.pprint(d)
    print()
    crop_shapes = [crop.shape for crop in d['signal']]
    print('>>> signal shape:', crop_shapes)
    print('\n', '-' * 100, '\n')

{'age': 66,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00085',
 'signal': [array([[ 85,  91,  92, ..., -32, -35, -34],
       [-18, -15, -16, ...,   3,   0,   1],
       [-16, -16, -15, ..., -16, -16, -15],
       ...,
       [ -7,  -8,  -8, ...,  16,  16,  15],
       [ -5,  -8,   2, ...,  -6,   1,  10],
       [  0,   0,   0, ...,  -1,   0,  -1]]),
            array([[ 12,  13,  15, ..., -24, -23, -22],
       [ -9,  -7,  -7, ..., -58, -56, -55],
       [  6,   7,   8, ..., -12, -10,  -7],
       ...,
       [ 10,   9,   9, ...,  -7,  -6,  -6],
       [  8,  -1, -10, ...,  10,  -3,  -2],
       [  0,   0,   0, ...,  -1,   0,   1]])],
 'symptom': ['normal', 'cb_normal']}

>>> signal shape: [(21, 200), (21, 200)]

 ---------------------------------------------------------------------------------------------------- 

{'age': 66,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00085',
 'signal': [array([[-76, -75, -69, ..., -24, -24, -20],
       [  4,   5,   8, ...,

### Random crop with multiple cropping and latency

In [14]:
# Three 300-sample crops. return_timing=True adds a 'start_point' list, and in
# the printed output every start point is >= latency (50000).
# NOTE(review): presumably latency sets the earliest allowed crop start;
# confirm in datasets.pipeline.
transform = EegRandomCrop(crop_length=300, multiple=3, latency=50000, return_timing=True)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task2',
    load_event=False,
    file_format='feather',
    transform=transform,
)

for _ in range(2):
    d = train_dataset[0]
    pprint.pprint(d)
    print()
    crop_shapes = [crop.shape for crop in d['signal']]
    print('>>> signal shape:', crop_shapes)
    print('\n', '-' * 100, '\n')

{'age': 66,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00085',
 'signal': [array([[ -9,  -9,  -8, ...,  23,  23,  23],
       [-30, -32, -33, ..., -28, -28, -28],
       [ -3,  -5,  -7, ...,  -3,  -4,  -5],
       ...,
       [  0,   0,  -2, ...,   0,   0,  -1],
       [  6,  12,  -2, ...,   6,   5,  -8],
       [ -3,   1,   1, ...,  -3,   1,   0]]),
            array([[-21, -21, -20, ...,  -5,  -5,  -9],
       [ 84,  83,  84, ...,  10,   9,   7],
       [  8,   8,   9, ...,   7,   6,   5],
       ...,
       [  8,   6,   6, ...,   2,   1,   2],
       [ -1,   7,   1, ...,  -2,   4,  -5],
       [ -1,   0,  -1, ...,  -2,  -3,   0]]),
            array([[-23, -25, -25, ..., -47, -46, -45],
       [  9,   8,   9, ...,  -2,  -2,  -3],
       [  5,   7,   8, ...,   9,   9,   7],
       ...,
       [ -4,  -2,   0, ...,   7,   7,   7],
       [ -8,   2,   6, ...,  -7,   4,   3],
       [ -4,   0,   1, ...,   0,   1,   2]])],
 'start_point': [95773, 69063, 109772],
 'symptom': [

### Random crop with multiple cropping, latency, and max length limit

In [15]:
# Same as above, plus length_limit=50300. In the printed output every start
# point lies between latency (50000) and length_limit - crop_length (50100).
# NOTE(review): presumably the recording is truncated to length_limit samples
# before cropping; confirm in datasets.pipeline.
random_crop = EegRandomCrop(
    crop_length=200,
    length_limit=50300,
    multiple=3,
    latency=50000,
    return_timing=True,
)
transform = transforms.Compose([random_crop])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task1',
    load_event=False,
    file_format='feather',
    transform=transform,
)

for _ in range(2):
    d = train_dataset[0]
    pprint.pprint(d)
    print()
    crop_shapes = [crop.shape for crop in d['signal']]
    print('>>> signal shape:', crop_shapes)
    print('\n', '-' * 100, '\n')

{'age': 81,
 'class_label': 2,
 'class_name': 'Dementia',
 'serial': '01222',
 'signal': [array([[-16, -20, -23, ...,  -4,  -6,  -6],
       [-13, -14, -13, ..., -15, -15, -15],
       [ 14,  12,  11, ...,  -5,  -4,  -2],
       ...,
       [ -9,  -9,  -8, ..., -11, -10,  -8],
       [ -4,  -7,  -6, ..., -15, -13, -10],
       [  0,   1,   0, ...,   2,   0,   1]]),
            array([[-58, -60, -62, ..., -45, -42, -41],
       [-14, -17, -18, ..., -19, -16, -14],
       [ 17,  16,  13, ...,   5,   8,   8],
       ...,
       [ -9, -10, -12, ...,   3,   1,  -2],
       [  3,   3,   7, ...,  61,  27,  20],
       [ -2,  -2,   0, ...,   1,  -3,  -3]]),
            array([[-60, -62, -63, ..., -28, -31, -30],
       [-16, -15, -17, ..., -18, -21, -21],
       [ 14,  12,  10, ...,  -3,  -6,  -6],
       ...,
       [  0,   0,  -1, ...,  -2,  -3,  -4],
       [-10, -19, -16, ..., -23, -24, -24],
       [  2,   0,   1, ...,   2,   2,   0]])],
 'start_point': [50037, 50085, 50054],
 'symptom': 

### Drop channel(s)

In [16]:
# Look up the row indices of the 'EKG' and 'Photic' channels in the dataset's
# signal header so the following cells can drop them.
anno_path = os.path.join(data_path, 'annotation.json')
# JSON text is UTF-8 by specification. Passing the encoding explicitly keeps
# the read independent of the platform's locale default (e.g. on Windows).
with open(anno_path, 'r', encoding='utf-8') as json_file:
    annotation = json.load(json_file)
signal_headers = annotation['signal_header']
del annotation  # only the header list is needed; release the rest
print(signal_headers)

channel_ekg = signal_headers.index('EKG')
print('channel_ekg: ', channel_ekg)

channel_photic = signal_headers.index('Photic')
print('channel_photic: ', channel_photic)

['Fp1-AVG', 'F3-AVG', 'C3-AVG', 'P3-AVG', 'O1-AVG', 'Fp2-AVG', 'F4-AVG', 'C4-AVG', 'P4-AVG', 'O2-AVG', 'F7-AVG', 'T3-AVG', 'T5-AVG', 'F8-AVG', 'T4-AVG', 'T6-AVG', 'FZ-AVG', 'CZ-AVG', 'PZ-AVG', 'EKG', 'Photic']
channel_ekg:  19
channel_photic:  20


In [17]:
# Compare the first training sample before and after dropping the EKG channel.
# The two loads differ only in the transform.
task1_feather_kwargs = dict(
    dataset_path=data_path,
    task='task1',
    load_event=False,
    file_format='feather',
)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    **task1_feather_kwargs, transform=None)
print('before:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

print()
print('-' * 100)
print()
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    **task1_feather_kwargs, transform=EegDropChannels(channel_ekg))
print('after:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

before: (21, 156600)
[[  1   1  -1 ...   2   3   3]
 [-24 -14 -14 ... -12  -9  -8]
 [ -6  -5  -6 ...  -3  -1  -1]
 ...
 [ -1   6   6 ...  -1  -2  -2]
 [-42 -62 -59 ... -12 -14 -16]
 [  0   0  -1 ...  -1   0   0]]

----------------------------------------------------------------------------------------------------

after: (20, 156600)
[[  1   1  -1 ...   2   3   3]
 [-24 -14 -14 ... -12  -9  -8]
 [ -6  -5  -6 ...  -3  -1  -1]
 ...
 [ 14  22  23 ...  -5  -4  -3]
 [ -1   6   6 ...  -1  -2  -2]
 [  0   0  -1 ...  -1   0   0]]


In [18]:
# Compare the first training sample before and after dropping the photic
# channel. The two loads differ only in the transform.
task1_feather_kwargs = dict(
    dataset_path=data_path,
    task='task1',
    load_event=False,
    file_format='feather',
)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    **task1_feather_kwargs, transform=None)
print('before:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

print()
print('-' * 100)
print()
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    **task1_feather_kwargs, transform=EegDropChannels(channel_photic))
print('after:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

before: (21, 156600)
[[  1   1  -1 ...   2   3   3]
 [-24 -14 -14 ... -12  -9  -8]
 [ -6  -5  -6 ...  -3  -1  -1]
 ...
 [ -1   6   6 ...  -1  -2  -2]
 [-42 -62 -59 ... -12 -14 -16]
 [  0   0  -1 ...  -1   0   0]]

----------------------------------------------------------------------------------------------------

after: (20, 156600)
[[  1   1  -1 ...   2   3   3]
 [-24 -14 -14 ... -12  -9  -8]
 [ -6  -5  -6 ...  -3  -1  -1]
 ...
 [ 14  22  23 ...  -5  -4  -3]
 [ -1   6   6 ...  -1  -2  -2]
 [-42 -62 -59 ... -12 -14 -16]]


In [19]:
# Compare the first training sample before and after dropping both the EKG
# and photic channels (a list of indices is passed to EegDropChannels).
task1_feather_kwargs = dict(
    dataset_path=data_path,
    task='task1',
    load_event=False,
    file_format='feather',
)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    **task1_feather_kwargs, transform=None)
print('before:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

print()
print('-' * 100)
print()
channels_to_drop = [channel_ekg, channel_photic]
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    **task1_feather_kwargs, transform=EegDropChannels(channels_to_drop))
print('after:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

before: (21, 156600)
[[  1   1  -1 ...   2   3   3]
 [-24 -14 -14 ... -12  -9  -8]
 [ -6  -5  -6 ...  -3  -1  -1]
 ...
 [ -1   6   6 ...  -1  -2  -2]
 [-42 -62 -59 ... -12 -14 -16]
 [  0   0  -1 ...  -1   0   0]]

----------------------------------------------------------------------------------------------------

after: (19, 156600)
[[  1   1  -1 ...   2   3   3]
 [-24 -14 -14 ... -12  -9  -8]
 [ -6  -5  -6 ...  -3  -1  -1]
 ...
 [  9  18  19 ...  -3  -2  -1]
 [ 14  22  23 ...  -5  -4  -3]
 [ -1   6   6 ...  -1  -2  -2]]


### To Tensor

In [20]:
# Compare the first full-dataset sample before and after EegToTensor. In the
# printed output, 'age' and 'signal' become float tensors.
full_feather_kwargs = dict(
    dataset_path=data_path,
    load_event=False,
    file_format='feather',
)

config_data, full_eeg_dataset = load_caueeg_full_dataset(**full_feather_kwargs, transform=None)
print('Before:')
pprint.pprint(full_eeg_dataset[0])

print()
print('-' * 100)
print()

config_data, full_eeg_dataset = load_caueeg_full_dataset(**full_feather_kwargs, transform=EegToTensor())
print('After:')
pprint.pprint(full_eeg_dataset[0])

Before:
{'age': 78,
 'serial': '00001',
 'signal': array([[  0, -11, -13, ...,  18,  21,  22],
       [ 29,  33,  34, ...,  -7,  -4,  -4],
       [ -3,  -6,  -3, ...,  -1,   1,   2],
       ...,
       [ -4,  -2,   1, ...,   0,  -1,  -1],
       [112,  67,  76, ..., -13, -15, -11],
       [ -1,  -1,  -1, ...,  -1,  -1,  -1]]),
 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}

----------------------------------------------------------------------------------------------------

After:
{'age': tensor(78.),
 'serial': '00001',
 'signal': tensor([[  0., -11., -13.,  ...,  18.,  21.,  22.],
        [ 29.,  33.,  34.,  ...,  -7.,  -4.,  -4.],
        [ -3.,  -6.,  -3.,  ...,  -1.,   1.,   2.],
        ...,
        [ -4.,  -2.,   1.,  ...,   0.,  -1.,  -1.],
        [112.,  67.,  76.,  ..., -13., -15., -11.],
        [ -1.,  -1.,  -1.,  ...,  -1.,  -1.,  -1.]]),
 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}


### Compose all of the above into a single transform

In [21]:
# Chain crop -> drop photic channel -> tensor conversion.
# The time annotations below imply 200 samples per second.
SAMPLES_PER_SECOND = 200
transform = transforms.Compose([
    EegRandomCrop(crop_length=10 * SAMPLES_PER_SECOND,            # crop: 10 s
                  length_limit=10 * 60 * SAMPLES_PER_SECOND,      # length: 10 min
                  multiple=4,
                  latency=10 * SAMPLES_PER_SECOND),               # latency: 10 s
    EegDropChannels(channel_photic),
    EegToTensor(),
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task1',
    load_event=False,
    file_format='feather',
    transform=transform,
)

pprint.pprint(train_dataset[0])

{'age': tensor(81.),
 'class_label': tensor(2),
 'class_name': 'Dementia',
 'serial': '01222',
 'signal': [tensor([[-59., -59., -55.,  ..., -58., -65., -65.],
        [-20., -19., -17.,  ..., -49., -51., -53.],
        [  6.,   7.,   9.,  ...,  10.,   8.,   9.],
        ...,
        [  2.,   3.,   4.,  ...,   7.,   8.,   9.],
        [ 16.,  17.,  16.,  ...,  18.,  19.,  20.],
        [ 17.,  28.,  20.,  ...,  -5.,  -7.,  -5.]]),
            tensor([[ -6.,  -9., -15.,  ..., -44., -47., -49.],
        [  2.,   1.,  -1.,  ..., -21., -19., -20.],
        [ -8., -10., -12.,  ...,   6.,   7.,   7.],
        ...,
        [  0.,   2.,   4.,  ...,  -4.,  -2.,   1.],
        [  7.,   7.,   7.,  ...,  -6.,  -7.,  -5.],
        [  4.,   0.,  -7.,  ...,  -1.,   6.,   9.]]),
            tensor([[ 17.,  10.,   9.,  ..., -25., -23., -31.],
        [  0.,  -1.,  -1.,  ...,   1.,  -1.,   0.],
        [-28., -27., -27.,  ...,  -2.,  -1.,  -3.],
        ...,
        [  4.,   3.,   2.,  ...,   6.,   4.,  

---

## PyTorch DataLoader

In [22]:
# DataLoader settings shared by the loaders below.
# Worker processes are disabled on every device.
# NOTE(review): a nonzero num_workers reportedly causes an error here, but the
# cause is not recorded. TODO investigate before raising it.
num_workers = 0
# Pinned host memory is requested only when running on CUDA.
pin_memory = device.type == 'cuda'

In [23]:
# Two random 10 s crops per recording (200 samples/s, per the annotations),
# photic channel dropped, converted to tensors.
transform = transforms.Compose([
    EegRandomCrop(crop_length=10 * 200,             # crop: 10s
                  length_limit=10 * 60 * 200,       # length: 10m
                  multiple=2,
                  latency=10 * 200),                # latency: 10s
    EegDropChannels(channel_photic),
    EegToTensor(),
])

config_data, full_eeg_dataset = load_caueeg_full_dataset(
    dataset_path=data_path,
    load_event=False,
    file_format='memmap',
    transform=transform,
)

full_loader = DataLoader(
    full_eeg_dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)

# Show only the first batch. In the printed output each of the 4 recordings
# contributes multiple=2 crops, so the batch holds 8 entries.
for i_batch, sample_batched in enumerate(full_loader):
    pprint.pprint(sample_batched)
    break

{'age': tensor([69., 69., 69., 69., 55., 55., 76., 76.]),
 'serial': ['00399',
            '00399',
            '00823',
            '00823',
            '00816',
            '00816',
            '00940',
            '00940'],
 'signal': tensor([[[  -1.,    0.,    2.,  ...,    5.,    6.,    8.],
         [ -11.,   -9.,   -9.,  ...,    7.,   11.,   14.],
         [  12.,   12.,   12.,  ...,    9.,   11.,   14.],
         ...,
         [  -3.,   -1.,    1.,  ...,   -7.,   -4.,   -1.],
         [  -3.,   -3.,   -3.,  ...,   -4.,   -3.,   -3.],
         [ 120.,  118.,  115.,  ...,  106.,   99.,   91.]],

        [[  25.,   27.,   28.,  ...,  -18.,  -24.,  -25.],
         [ -66.,  -65.,  -63.,  ...,  -27.,  -29.,  -29.],
         [  10.,   12.,   14.,  ...,   23.,   23.,   24.],
         ...,
         [   6.,    6.,    7.,  ...,    5.,    6.,    7.],
         [   9.,    8.,    7.,  ...,    5.,    5.,    4.],
         [  -5.,   -9.,  -10.,  ...,  114.,  116.,  119.]],

        [[  13.,   11.

In [24]:
# Two random 10 s crops per recording (200 samples/s, per the annotations),
# photic channel dropped, converted to tensors.
transform = transforms.Compose([
    EegRandomCrop(crop_length=10 * 200,             # crop: 10s
                  length_limit=10 * 60 * 200,       # length: 10m
                  multiple=2,
                  latency=10 * 200),                # latency: 10s
    EegDropChannels(channel_photic),
    EegToTensor(),
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task1',
    load_event=False,
    file_format='memmap',
    transform=transform,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)

# Show only the first batch. In the printed output each of the 8 recordings
# contributes multiple=2 crops, so the batch holds 16 entries.
for i_batch, sample_batched in enumerate(train_loader):
    pprint.pprint(sample_batched, width=250)
    break

{'age': tensor([77., 77., 93., 93., 68., 68., 59., 59., 84., 84., 78., 78., 64., 64.,
        59., 59.]),
 'class_label': tensor([0, 0, 2, 2, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0]),
 'class_name': ['Normal', 'Normal', 'Dementia', 'Dementia', 'Normal', 'Normal', 'Dementia', 'Dementia', 'Normal', 'Normal', 'Normal', 'Normal', 'Normal', 'Normal', 'Normal', 'Normal'],
 'serial': ['00697', '00697', '00928', '00928', '00675', '00675', '01006', '01006', '01013', '01013', '00050', '00050', '01010', '01010', '00864', '00864'],
 'signal': tensor([[[-22., -19., -17.,  ...,  21.,  21.,  22.],
         [ -5.,  -5.,  -6.,  ..., -11., -10.,  -7.],
         [ 13.,  11.,   9.,  ...,   0.,   3.,   6.],
         ...,
         [ -1.,  -6.,  -9.,  ..., -18., -20., -18.],
         [  5.,   1.,  -2.,  ...,   3.,   3.,   3.],
         [ -5.,  -1.,   9.,  ..., -12., -20., -15.]],

        [[ 30.,  25.,  25.,  ...,   8.,   2.,   5.],
         [ 16.,  13.,   9.,  ...,   2.,   3.,   1.],
         [  7.,   4.,   2.,

---

## Preprocessing steps implemented as PyTorch modules

In [25]:
# Task 2 splits read from feather files; this small-batch loader feeds the
# preprocessing demonstrations that follow.
transform = transforms.Compose(
    [
        EegRandomCrop(
            crop_length=200 * 10,        # crop: 10s
            length_limit=200 * 60 * 10,  # length: 10m
            multiple=2,
            latency=200 * 10,            # latency: 10s
        ),
        EegDropChannels(channel_photic),
        EegToTensor(),
    ]
)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='task2',
    load_event=False,
    file_format='feather',
    transform=transform,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)

### Move the batch to the GPU device if available

In [26]:
print('device:', device, end='\n\n')

# Preprocessing pipeline as an nn.Sequential placed on the target device.
preprocess_train = torch.nn.Sequential(EegToDevice(device=device)).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    pprint.pprint(sample_batched)
    print('\n' + '-' * 100 + '\n')

    # The pipeline updates the batch dict in place; its return value is not used.
    preprocess_train(sample_batched)

    print('- After -')
    pprint.pprint(sample_batched)
    break

device: cuda

Sequential(
  (0): EegToDevice(device=cuda)
)
- Before -
{'age': tensor([82., 82., 74., 74.]),
 'class_label': tensor([1, 1, 1, 1]),
 'class_name': ['Abnormal', 'Abnormal', 'Abnormal', 'Abnormal'],
 'serial': ['00580', '00580', '00563', '00563'],
 'signal': tensor([[[ -43.,  -43.,  -45.,  ...,  -10.,  -11.,  -16.],
         [  -4.,   -7.,  -12.,  ...,   14.,   14.,   13.],
         [   1.,    2.,    0.,  ...,    1.,    3.,    3.],
         ...,
         [  15.,   11.,    6.,  ...,   -6.,   -5.,   -4.],
         [  11.,   14.,   14.,  ...,   -7.,   -8.,   -8.],
         [ -90., -112., -118.,  ...,   61.,   72.,   76.]],

        [[ -12.,  -13.,  -16.,  ...,  -10.,  -12.,  -12.],
         [   6.,    2.,   -2.,  ...,  -13.,  -19.,  -23.],
         [  -8.,  -10.,  -12.,  ...,  -10.,  -14.,  -14.],
         ...,
         [   1.,   -1.,   -3.,  ...,   -3.,   -3.,   -1.],
         [  12.,   15.,   17.,  ...,  -12.,  -13.,   -9.],
         [  60.,   43.,   23.,  ...,  -44.,  -44.

### Normalization per signal

In [27]:
# Normalize each signal by its own statistics and compare per-channel mean/std before and after.
preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegNormalizePerSignal(),
).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    print('Mean:', sample_batched['signal'].mean(dim=-1), end='\n\n')
    print('Std:', sample_batched['signal'].std(dim=-1))
    print('\n' + '-' * 100 + '\n')

    preprocess_train(sample_batched)

    print('- After -')
    print('Mean:', sample_batched['signal'].mean(dim=-1), end='\n\n')
    print('Std:', sample_batched['signal'].std(dim=-1))
    break

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizePerSignal(eps=1e-08)
)
- Before -
Mean: tensor([[  6.4520,   4.1240,   5.4440,  -2.4650,  -0.5790,   8.6005,   4.8075,
         -13.1735,  -7.1335,  -3.3060,   5.6695,  -2.9920,   3.6250,  -0.0810,
           8.9990,   0.6855,   0.8340,  -1.2190,  -3.2995,   0.1295],
        [  0.6530,   2.8260,   0.8645,   1.1610,   1.0980,   1.6790,  -1.1485,
          -3.1490,  -2.7160,  -1.8175,  -0.3910,   0.5745,   1.4915,   1.1840,
           3.7535,  -0.3780,  -2.7430,  -0.0540,  -0.5885,  -0.0520],
        [  2.8785,   2.8445,   0.8880,   0.0995,   0.1575,  -5.0670,  -1.5020,
          -0.2590,  -0.5920,  -0.2060,   1.3025,   1.0405,   0.5200,  -5.5765,
           0.1480,  -0.9880,   0.7060,   1.0980,   0.2615,  -1.3680],
        [ -3.5925,   0.4550,   0.1645,   0.2800,   1.1640,  -7.1715,  -3.1490,
          -1.2005,   0.1525,   1.4280,   1.5335,   1.1865,   1.2040,  -1.0580,
           0.0375,   0.7515,  -3.2665,   0.5960,  -0.24

### Signal normalization using the specified mean and std values

In [28]:
# Signal mean/std estimated from the training loader (repeats=1); the printed output
# shows one value per channel.
signal_mean, signal_std = calculate_signal_statistics(train_loader, repeats=1, verbose=True)

preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    print('Mean:', sample_batched['signal'].mean(dim=-1), end='\n\n')
    print('Std:', sample_batched['signal'].std(dim=-1))
    print('\n' + '-' * 100 + '\n')

    preprocess_train(sample_batched)

    print('- After -')
    print('Mean:', sample_batched['signal'].mean(dim=-1), end='\n\n')
    print('Std:', sample_batched['signal'].std(dim=-1))
    break

Mean and standard deviation for signal:
tensor([[[ 0.0112],
         [-0.0979],
         [-0.1854],
         [-0.0261],
         [ 0.2286],
         [ 0.2173],
         [ 0.0367],
         [-0.0019],
         [ 0.0108],
         [ 0.0986],
         [-0.3425],
         [-0.0903],
         [ 0.0697],
         [-0.0039],
         [ 0.3391],
         [ 0.0294],
         [ 0.1369],
         [-0.0167],
         [-0.0780],
         [-0.0110]]])
-
tensor([[[46.4446],
         [20.7298],
         [12.0584],
         [12.2802],
         [16.0450],
         [51.3411],
         [21.0485],
         [10.9341],
         [11.9656],
         [16.5442],
         [21.6827],
         [15.0109],
         [13.8311],
         [22.3359],
         [17.9284],
         [15.6722],
         [20.1076],
         [11.5512],
         [13.0562],
         [97.4733]]])

----------------------------------------------------------------------------------------------------

Sequential(
  (0): EegToDevice(device=cuda)
  (1): 

### Age normalization

In [29]:
# Age statistics from the training loader, then standardize the 'age' field of a batch.
age_mean, age_std = calculate_age_statistics(train_loader, verbose=True)

preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegNormalizeAge(mean=age_mean, std=age_std),
).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    pprint.pprint(sample_batched['age'])
    print('\n' + '-' * 100 + '\n')

    preprocess_train(sample_batched)

    print('- After -')
    pprint.pprint(sample_batched['age'])
    break

Age mean and standard deviation:
tensor([70.6025]) tensor([6.4713])

----------------------------------------------------------------------------------------------------

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.6025]),std=tensor([6.4713]),eps=1e-08)
)
- Before -
tensor([73., 73., 62., 62.])

----------------------------------------------------------------------------------------------------

- After -
tensor([ 0.3705,  0.3705, -1.3293, -1.3293], device='cuda:0')


### Short-time Fourier transform (STFT or spectrogram)

In [30]:
# Apply the STFT on the device and compare the signal tensor shape before and after.
preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegSpectrogram(n_fft=200, complex_mode='as_real'),
).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    pprint.pprint(sample_batched['signal'].shape)
    print('\n' + '-' * 100 + '\n')

    preprocess_train(sample_batched)

    print('- After -')
    pprint.pprint(sample_batched['signal'].shape)
    break

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegSpectrogram(n_fft=200, complex_mode=as_real, stft_kwargs={})
)
- Before -
torch.Size([4, 20, 2000])

----------------------------------------------------------------------------------------------------

- After -
torch.Size([4, 40, 101, 41])


### Signal normalization after STFT

In [31]:
# Stage 1: device transfer + STFT.
preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegSpectrogram(n_fft=200, complex_mode='as_real'),
).to(device)

# Spectrogram statistics; `preprocess_train` is passed so they are presumably taken on
# its output — confirm in calculate_signal_statistics.
signal_2d_mean, signal_2d_std = calculate_signal_statistics(train_loader, preprocess_train)

# Stage 2: normalize the spectrogram with those statistics.
preprocess_train2 = torch.nn.Sequential(
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
).to(device)

pprint.pprint(preprocess_train)
pprint.pprint(preprocess_train2)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    preprocess_train(sample_batched)

    print('Mean:', sample_batched['signal'].mean(dim=-1), end='\n\n')
    print('Std:', sample_batched['signal'].std(dim=-1))
    print('\n' + '-' * 100 + '\n')

    print('- After -')
    preprocess_train2(sample_batched)

    print('Mean:', sample_batched['signal'].mean(dim=-1), end='\n\n')
    print('Std:', sample_batched['signal'].std(dim=-1))
    break

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegSpectrogram(n_fft=200, complex_mode=as_real, stft_kwargs={})
)
Sequential(
  (0): EegNormalizeMeanStd(mean=tensor([[-3.1690e+01, -3.7667e-01,  1.5687e-03,  ..., -8.1782e-03,
           -8.0691e-03, -1.3191e-02],
          [-4.7364e+00, -1.3662e-03,  9.4536e-03,  ...,  5.5986e-03,
            2.6490e-03,  1.5448e-02],
          [ 3.3595e-01, -7.8127e-02, -1.9920e-03,  ...,  8.7425e-03,
            7.8993e-03, -7.4004e-03],
          ...,
          [ 0.0000e+00, -3.0341e-01, -1.0968e-01,  ...,  4.6328e-04,
            6.6100e-04, -2.7577e-09],
          [ 0.0000e+00, -3.7279e-01, -1.8398e-01,  ...,  7.3652e-05,
           -1.0610e-04, -2.2872e-08],
          [ 0.0000e+00, -9.4285e-01, -8.8958e-01,  ...,  2.8520e-03,
           -4.1753e-03,  1.4242e-07]], device='cuda:0'),std=tensor([[7.6877e+03, 1.5605e+03, 8.1076e+02,  ..., 2.9823e+01, 2.9799e+01,
           3.0013e+01],
          [3.2892e+03, 6.0161e+02, 3.0946e+02,  ..., 1.3446e+01

---

## Speed check without STFT

In [32]:
# Speed-check settings (no STFT): 10 s crops at 200 samples/s, 4 crops per recording,
# 128 recordings per loader batch.
crop_length, multiple, batch_size = 200 * 10, 4, 128

### `EDF`

In [33]:
%%time
transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='edf',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegRandomCrop(crop_length=2000, length_limit=120000, multiple=4, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.6025]),std=tensor([6.4713]),eps=1e-08)
  (2): EegNormalizeMeanStd(mean=tensor([ 0.0112, -0.0979, -0.1854, -0.0261,  0.2286,  0.2173,  0.0367, -0.0019,
           0.0108,  0.0986, -0.3425, -0.0903,  0.0697, -0.0039,  0.3391,  0.0294,
           0.1369, -0.0167, -0.0780, -0.0110]),std=tensor([46.4446, 20.7298, 12.0584, 12.2802, 16.0450, 51.3411, 21.0485, 10.9341,
          11.9656, 16.5442, 21.6827, 15.0109, 13.8311, 22.3359, 17.9284, 15.6722,
          20.1076, 11.5512, 13.0562, 97.4733]),eps=1e-08)
)
CPU times: total: 26min 55s
Wall time: 2min 15s


### `Feather`

In [34]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegRandomCrop(crop_length=2000, length_limit=120000, multiple=4, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.6025]),std=tensor([6.4713]),eps=1e-08)
  (2): EegNormalizeMeanStd(mean=tensor([ 0.0112, -0.0979, -0.1854, -0.0261,  0.2286,  0.2173,  0.0367, -0.0019,
           0.0108,  0.0986, -0.3425, -0.0903,  0.0697, -0.0039,  0.3391,  0.0294,
           0.1369, -0.0167, -0.0780, -0.0110]),std=tensor([46.4446, 20.7298, 12.0584, 12.2802, 16.0450, 51.3411, 21.0485, 10.9341,
          11.9656, 16.5442, 21.6827, 15.0109, 13.8311, 22.3359, 17.9284, 15.6722,
          20.1076, 11.5512, 13.0562, 97.4733]),eps=1e-08)
)
CPU times: total: 1min 3s
Wall time: 4.01 s


### `memmap`

In [35]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)
 
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegRandomCrop(crop_length=2000, length_limit=120000, multiple=4, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.6025]),std=tensor([6.4713]),eps=1e-08)
  (2): EegNormalizeMeanStd(mean=tensor([ 0.0112, -0.0979, -0.1854, -0.0261,  0.2286,  0.2173,  0.0367, -0.0019,
           0.0108,  0.0986, -0.3425, -0.0903,  0.0697, -0.0039,  0.3391,  0.0294,
           0.1369, -0.0167, -0.0780, -0.0110]),std=tensor([46.4446, 20.7298, 12.0584, 12.2802, 16.0450, 51.3411, 21.0485, 10.9341,
          11.9656, 16.5442, 21.6827, 15.0109, 13.8311, 22.3359, 17.9284, 15.6722,
          20.1076, 11.5512, 13.0562, 97.4733]),eps=1e-08)
)
CPU times: total: 18 s
Wall time: 1.5 s


### `memmap` (Drop → Crop)

In [36]:
%%time

transform = transforms.Compose([
    EegDropChannels(channel_photic), 
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)
 
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegDropChannels(drop_index=20)
    EegRandomCrop(crop_length=2000, length_limit=120000, multiple=4, latency=2000, return_timing=False)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.6025]),std=tensor([6.4713]),eps=1e-08)
  (2): EegNormalizeMeanStd(mean=tensor([ 0.0112, -0.0979, -0.1854, -0.0261,  0.2286,  0.2173,  0.0367, -0.0019,
           0.0108,  0.0986, -0.3425, -0.0903,  0.0697, -0.0039,  0.3391,  0.0294,
           0.1369, -0.0167, -0.0780, -0.0110]),std=tensor([46.4446, 20.7298, 12.0584, 12.2802, 16.0450, 51.3411, 21.0485, 10.9341,
          11.9656, 16.5442, 21.6827, 15.0109, 13.8311, 22.3359, 17.9284, 15.6722,
          20.1076, 11.5512, 13.0562, 97.4733]),eps=1e-08)
)
CPU times: total: 1min 21s
Wall time: 6.77 s


---

## Speed check with STFT

In [37]:
# Speed-check settings (with STFT): 10 s crops at 200 samples/s, 2 crops per recording,
# 128 recordings per loader batch.
crop_length, multiple, batch_size = 200 * 10, 2, 128

### `EDF`

In [38]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='edf',
                                                                                  transform=transform)
 
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegSpectrogram(n_fft=200, complex_mode='as_real'),
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegRandomCrop(crop_length=2000, length_limit=120000, multiple=2, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.6025]),std=tensor([6.4713]),eps=1e-08)
  (2): EegSpectrogram(n_fft=200, complex_mode=as_real, stft_kwargs={})
  (3): EegNormalizeMeanStd(mean=tensor([[-3.1690e+01, -3.7667e-01,  1.5687e-03,  ..., -8.1782e-03,
           -8.0691e-03, -1.3191e-02],
          [-4.7364e+00, -1.3662e-03,  9.4536e-03,  ...,  5.5986e-03,
            2.6490e-03,  1.5448e-02],
          [ 3.3595e-01, -7.8127e-02, -1.9920e-03,  ...,  8.7425e-03,
            7.8993e-03, -7.4004e-03],
          ...,
          [ 0.0000e+00, -3.0341e-01, -1.0968e-01,  ...,  4.6328e-04,
            6.6100e-04, -2.7577e-09],
          [ 0.0000e+00, -3.7279e-01, -1.8398e-01,  ...,  7.3652e-05,
           -1.0610e-04, -2.2872e-08],
          [ 0.0000e+00, -9.4285e-01, -8.8958e-01,  ...,  2.

### `Feather`

In [39]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegSpectrogram(n_fft=200, complex_mode='as_real'),
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegRandomCrop(crop_length=2000, length_limit=120000, multiple=2, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.6025]),std=tensor([6.4713]),eps=1e-08)
  (2): EegSpectrogram(n_fft=200, complex_mode=as_real, stft_kwargs={})
  (3): EegNormalizeMeanStd(mean=tensor([[-3.1690e+01, -3.7667e-01,  1.5687e-03,  ..., -8.1782e-03,
           -8.0691e-03, -1.3191e-02],
          [-4.7364e+00, -1.3662e-03,  9.4536e-03,  ...,  5.5986e-03,
            2.6490e-03,  1.5448e-02],
          [ 3.3595e-01, -7.8127e-02, -1.9920e-03,  ...,  8.7425e-03,
            7.8993e-03, -7.4004e-03],
          ...,
          [ 0.0000e+00, -3.0341e-01, -1.0968e-01,  ...,  4.6328e-04,
            6.6100e-04, -2.7577e-09],
          [ 0.0000e+00, -3.7279e-01, -1.8398e-01,  ...,  7.3652e-05,
           -1.0610e-04, -2.2872e-08],
          [ 0.0000e+00, -9.4285e-01, -8.8958e-01,  ...,  2.

### `memmap`

In [40]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='task2',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegSpectrogram(n_fft=200, complex_mode='as_real'),
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegRandomCrop(crop_length=2000, length_limit=120000, multiple=2, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.6025]),std=tensor([6.4713]),eps=1e-08)
  (2): EegSpectrogram(n_fft=200, complex_mode=as_real, stft_kwargs={})
  (3): EegNormalizeMeanStd(mean=tensor([[-3.1690e+01, -3.7667e-01,  1.5687e-03,  ..., -8.1782e-03,
           -8.0691e-03, -1.3191e-02],
          [-4.7364e+00, -1.3662e-03,  9.4536e-03,  ...,  5.5986e-03,
            2.6490e-03,  1.5448e-02],
          [ 3.3595e-01, -7.8127e-02, -1.9920e-03,  ...,  8.7425e-03,
            7.8993e-03, -7.4004e-03],
          ...,
          [ 0.0000e+00, -3.0341e-01, -1.0968e-01,  ...,  4.6328e-04,
            6.6100e-04, -2.7577e-09],
          [ 0.0000e+00, -3.7279e-01, -1.8398e-01,  ...,  7.3652e-05,
           -1.0610e-04, -2.2872e-08],
          [ 0.0000e+00, -9.4285e-01, -8.8958e-01,  ...,  2.

---

## Test on a longer sequence

In [42]:
%%time
longer_transform = transforms.Compose([
    EegRandomCrop(crop_length=200*10*6,     # crop: 1m
                  length_limit=200*60*10,   # length: 10m
                  multiple=2, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(longer_transform)

config_data, longer_test_dataset = load_caueeg_task_split(dataset_path=data_path, 
                                                          task='task2', 
                                                          split='test',
                                                          load_event=False,
                                                          file_format='feather', 
                                                          transform=longer_transform)

longer_test_loader = DataLoader(longer_test_dataset,
                                batch_size=32,
                                shuffle=True,
                                drop_last=False,
                                num_workers=num_workers,
                                pin_memory=pin_memory,
                                collate_fn=eeg_collate_fn)
 
preprocess_test = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
    EegNormalizeAge(mean=age_mean, std=age_std),
])
preprocess_test = torch.nn.Sequential(*preprocess_test.transforms).to(device)
pprint.pprint(preprocess_test)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_test(sample_batched)

Compose(
    EegRandomCrop(crop_length=12000, length_limit=120000, multiple=2, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeMeanStd(mean=tensor([ 0.0112, -0.0979, -0.1854, -0.0261,  0.2286,  0.2173,  0.0367, -0.0019,
           0.0108,  0.0986, -0.3425, -0.0903,  0.0697, -0.0039,  0.3391,  0.0294,
           0.1369, -0.0167, -0.0780, -0.0110]),std=tensor([46.4446, 20.7298, 12.0584, 12.2802, 16.0450, 51.3411, 21.0485, 10.9341,
          11.9656, 16.5442, 21.6827, 15.0109, 13.8311, 22.3359, 17.9284, 15.6722,
          20.1076, 11.5512, 13.0562, 97.4733]),eps=1e-08)
  (2): EegNormalizeAge(mean=tensor([70.6025]),std=tensor([6.4713]),eps=1e-08)
)
CPU times: total: 10.8 s
Wall time: 973 ms
