# Dataset and DataLoader

This notebook loads the `CAUEEG` dataset, tries out some useful preprocessing steps, and builds the PyTorch DataLoader instances used for training.

-----

## Configurations

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
# `%autoreload 2` re-imports edited modules before each cell executes, so changes
# to the project packages imported below are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2
# Step up one directory; the cell output shows this lands in the repository root.
# Presumably this is what lets the `datasets` imports and the relative `data_path`
# resolve -- confirm against the repository layout.
# NOTE(review): the path is relative, so re-running this cell moves up one more
# level each time; run it only once per kernel session.
%cd ..

C:\Users\Minjae\Documents\GitHub\caueeg-ceednet


In [2]:
# Load some packages
import os
import json
import pprint

import numpy as np
import random
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

# custom package
from datasets.caueeg_dataset import *
from datasets.caueeg_script import *
from datasets.pipeline import *

In [3]:
# Report the PyTorch build, then pick the GPU when one is visible and fall back to the CPU.
print('PyTorch version:', torch.__version__)

cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')

print('cuda is available.' if cuda_available else 'cuda is unavailable.')

PyTorch version: 1.12.1
cuda is available.


In [4]:
# Directory holding the CAUEEG files, relative to the current working directory.
# The JSON files opened by later cells are joined onto this path.
data_path = 'local/dataset/caueeg-dataset/'

In [5]:
# Print an abbreviated view of each JSON file found under `data_path`.
for task in ('annotation.json', 'abnormal.json', 'dementia.json'):
    task_path = os.path.join(data_path, task)
    with open(task_path, 'r') as json_file:
        task_dict = json.load(json_file)

    print('{')
    for key, value in task_dict.items():
        print(f'\t{key}:')
        is_long_list = isinstance(value, list) and len(value) > 3
        if not is_long_list:
            # Scalars, dicts, and short lists are printed in full.
            print(f'\t\t{value}')
        else:
            # Long lists are abbreviated: first two entries, three dots, last entry.
            for entry in value[:2]:
                print(f'\t\t{entry}')
            for _dot in range(3):
                print('\t\t.')
            print(f'\t\t{value[-1]}')
        print()
    print('}')
    print('\n' + '-' * 100 + '\n')

{
	dataset_name:
		CAUEEG dataset

	signal_header:
		Fp1-AVG
		F3-AVG
		.
		.
		.
		Photic

	data:
		{'serial': '00001', 'age': 78, 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}
		{'serial': '00002', 'age': 56, 'symptom': ['normal', 'smi']}
		.
		.
		.
		{'serial': '01388', 'age': 73, 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_ef']}

}

----------------------------------------------------------------------------------------------------

{
	task_name:
		CAUEEG-Abnormal benchmark

	task_description:
		Classification of [Normal] and [Abnormal] symptoms

	class_label_to_name:
		['Normal', 'Abnormal']

	class_name_to_label:
		{'Normal': 0, 'Abnormal': 1}

	train_split:
		{'serial': '01258', 'age': 77, 'symptom': ['dementia', 'vd', 'sivd'], 'class_name': 'Abnormal', 'class_label': 1}
		{'serial': '00836', 'age': 80, 'symptom': ['normal', 'smi'], 'class_name': 'Normal', 'class_label': 0}
		.
		.
		.
		{'serial': '00105', 'age': 71, 'symptom': ['normal', 'smi'], 'class_name': 'N

-----

## Load the CAUEEG dataset

### Load the whole CAUEEG data as a PyTorch dataset instance without considering the target task (no train/val/test sets and no class label).

In [6]:
# Load every recording as a single dataset, with event annotations switched off.
config_data, full_eeg_dataset = load_caueeg_full_dataset(
    dataset_path=data_path,
    load_event=False,
    file_format='edf',  # can be omitted
    transform=None,
)

# Dataset-level configuration first, followed by a separator line.
pprint.pprint(config_data, width=250)
print('\n', '-' * 100, '\n')

# Then two example samples (indices 0 and 3), each followed by the same separator.
for sample_index in (0, 3):
    pprint.pprint(full_eeg_dataset[sample_index])
    print('\n', '-' * 100, '\n')

{'dataset_name': 'CAUEEG dataset',
 'signal_header': ['Fp1-AVG', 'F3-AVG', 'C3-AVG', 'P3-AVG', 'O1-AVG', 'Fp2-AVG', 'F4-AVG', 'C4-AVG', 'P4-AVG', 'O2-AVG', 'F7-AVG', 'T3-AVG', 'T5-AVG', 'F8-AVG', 'T4-AVG', 'T6-AVG', 'FZ-AVG', 'CZ-AVG', 'PZ-AVG', 'EKG', 'Photic']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 78,
 'serial': '00001',
 'signal': array([[  0., -11., -13., ...,   0.,   0.,   0.],
       [ 29.,  33.,  34., ...,   0.,   0.,   0.],
       [ -3.,  -6.,  -3., ...,   0.,   0.,   0.],
       ...,
       [ -4.,  -2.,   1., ...,   0.,   0.,   0.],
       [112.,  67.,  76., ...,   0.,   0.,   0.],
       [ -1.,  -1.,  -1., ...,   0.,   0.,   0.]]),
 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 78,
 'serial': '00004',
 'signal': array([[ 30.,  34.,  36., ...,   0.,   0.,   0.],
       [ 34.,  25., 

### When the `load_event` option is enabled, the dataset also loads the event data.

In [7]:
# Enable `load_event` explicitly. This cell exists to demonstrate that option, so it
# should not depend on whatever default `load_caueeg_full_dataset` happens to use.
config_data, full_eeg_dataset = load_caueeg_full_dataset(dataset_path=data_path,
                                                         load_event=True,
                                                         file_format='edf',
                                                         transform=None)

# The sample now carries an 'event' entry: a list of [number, description] pairs
# (the number is presumably the sample index where the event occurs -- TODO confirm).
pprint.pprint(full_eeg_dataset[0])

{'age': 78,
 'event': [[0, 'Start Recording'],
           [0, 'New Montage - Montage 002'],
           [36396, 'Eyes Open'],
           [72518, 'Eyes Closed'],
           [73862, 'Eyes Open'],
           [75248, 'Eyes Closed'],
           [76728, 'swallowing'],
           [77978, 'Eyes Open'],
           [79406, 'Eyes Closed'],
           [79996, 'Photic On - 3.0 Hz'],
           [80288, 'Eyes Open'],
           [81296, 'Eyes Closed'],
           [82054, 'Photic Off'],
           [84070, 'Photic On - 6.0 Hz'],
           [84488, 'Eyes Open'],
           [85538, 'Eyes Closed'],
           [86086, 'Photic Off'],
           [88144, 'Photic On - 9.0 Hz'],
           [90160, 'Photic Off'],
           [91458, 'Eyes Open'],
           [92218, 'Photic On - 12.0 Hz'],
           [92762, 'Eyes Closed'],
           [94198, 'Photic Off'],
           [94742, 'Eyes Open'],
           [95708, 'Eyes Closed'],
           [96256, 'Photic On - 15.0 Hz'],
           [98272, 'Photic Off'],
           [1003

-----

## Load the CAUEEG benchmarks

### Load the CAUEEG-Abnormal benchmark using the PyTorch dataset instances.

In [8]:
# Load the train/validation/test splits of the abnormal task, without event data.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
)

# Task configuration, a separator line, then the first test sample.
pprint.pprint(config_data)
print('\n', '-' * 100, '\n')
pprint.pprint(test_dataset[0])

{'class_label_to_name': ['Normal', 'Abnormal'],
 'class_name_to_label': {'Abnormal': 1, 'Normal': 0},
 'task_description': 'Classification of [Normal] and [Abnormal] symptoms',
 'task_name': 'CAUEEG-Abnormal benchmark'}

 ---------------------------------------------------------------------------------------------------- 

{'age': 57,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00560',
 'signal': array([[  7.,  41.,  49., ...,   0.,   0.,   0.],
       [ 23.,  26.,  27., ...,   0.,   0.,   0.],
       [-16., -20., -18., ...,   0.,   0.,   0.],
       ...,
       [ 32.,  31.,  31., ...,   0.,   0.,   0.],
       [177., 228., 142., ...,   0.,   0.,   0.],
       [  0.,  -1.,   0., ...,   0.,   0.,   0.]]),
 'symptom': ['normal', 'cb_normal']}


### Load the CAUEEG-Dementia benchmark using the PyTorch dataset instances.

In [9]:
# Load the train/validation/test splits of the dementia task, without event data.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    transform=None,
)

# Task configuration, a separator line, then the first test sample.
pprint.pprint(config_data)
print('\n', '-' * 100, '\n')
pprint.pprint(test_dataset[0])

{'class_label_to_name': ['Normal', 'MCI', 'Dementia'],
 'class_name_to_label': {'Dementia': 2, 'MCI': 1, 'Normal': 0},
 'task_description': 'Classification of [Normal], [MCI], and [Dementia] '
                     'symptoms.',
 'task_name': 'CAUEEG-Dementia benchmark'}

 ---------------------------------------------------------------------------------------------------- 

{'age': 62,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00789',
 'signal': array([[-87., -69., -70., ...,   0.,   0.,   0.],
       [-25., -18., -19., ...,   0.,   0.,   0.],
       [ -6.,   1.,   0., ...,   0.,   0.,   0.],
       ...,
       [  0.,  -5.,  -4., ...,   0.,   0.,   0.],
       [-31.,  -8.,  -7., ...,   0.,   0.,   0.],
       [  0.,   1.,  -1., ...,   0.,   0.,   0.]]),
 'symptom': ['normal', 'cb_normal']}


### With `load_event` enabled, the benchmark datasets also provide the event data.

In [10]:
# Same abnormal-task splits as before, now with the event annotations requested.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=True,
    file_format='edf',
    transform=None,
)

# The first test sample now carries an additional 'event' entry.
pprint.pprint(test_dataset[0])

{'age': 57,
 'class_label': 0,
 'class_name': 'Normal',
 'event': [[0, 'Start Recording'],
           [0, 'New Montage - Montage 002'],
           [5408, 'Eyes Open'],
           [6332, 'Eyes Closed'],
           [16832, 'Eyes Open'],
           [17756, 'Eyes Closed'],
           [24200, 'Paused'],
           [28400, 'Recording Resumed'],
           [67442, 'Eyes Open'],
           [68366, 'Eyes Closed'],
           [82856, 'Eyes Open'],
           [83738, 'Eyes Closed'],
           [109833, 'Eyes Open'],
           [109833, 'Eyes Closed'],
           [132600, 'Paused']],
 'serial': '00560',
 'signal': array([[  7.,  41.,  49., ...,   0.,   0.,   0.],
       [ 23.,  26.,  27., ...,   0.,   0.,   0.],
       [-16., -20., -18., ...,   0.,   0.,   0.],
       ...,
       [ 32.,  31.,  31., ...,   0.,   0.,   0.],
       [177., 228., 142., ...,   0.,   0.,   0.],
       [  0.,  -1.,   0., ...,   0.,   0.,   0.]]),
 'symptom': ['normal', 'cb_normal']}


### Utilizing `PyArrow.feather` is much faster than directly using `EDF`.

In [11]:
%%time
# Timing baseline for the EDF format. `%%time` is a cell magic (it must stay on the
# first line) and measures the whole cell: building the datasets plus the reads below.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False, 
                                                                                  file_format='edf', 
                                                                                  transform=None)
# Iterating fetches samples one at a time; the break fires once i reaches 31,
# so 32 test samples (indices 0..31) are read in total.
for i, d in enumerate(test_dataset):
    if i > 30:
        break

CPU times: total: 4.14 s
Wall time: 4.14 s


In [12]:
%time
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False, 
                                                                                  file_format='feather', 
                                                                                  transform=None)
for i, d in enumerate(test_dataset):
    if i > 30:
        break

CPU times: total: 0 ns
Wall time: 0 ns


---

## PyTorch Transforms

### Random crop

In [13]:
# Crop 100 time points from each recording whenever a sample is fetched.
transform = EegRandomCrop(crop_length=100)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    file_format='feather',
    transform=transform,
)

# Fetch the same test sample twice. The two printed signals differ in the output
# below, which shows that the crop is re-drawn on every access.
for attempt in range(2):
    cropped_sample = test_dataset[0]
    pprint.pprint(cropped_sample)
    print('\n>>> signal shape:', cropped_sample['signal'].shape)
    print('\n', '-' * 100, '\n')

{'age': 62,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00789',
 'signal': array([[ 55,  51,  51, ...,  33,  37,  42],
       [  7,  11,  18, ...,  -9,  -9,  -8],
       [ -3,   3,  11, ..., -18, -14,  -7],
       ...,
       [-15,  -9,  -3, ...,  -7,  -9, -11],
       [180, 203, 222, ..., -29, -34, -38],
       [  1,   1,   0, ...,  -1,   0,   1]]),
 'symptom': ['normal', 'cb_normal']}

>>> signal shape: (21, 100)

 ---------------------------------------------------------------------------------------------------- 

{'age': 62,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00789',
 'signal': array([[ 16,  16,  16, ...,  10,   8,  11],
       [  8,   9,   6, ...,   1,   2,   5],
       [ 16,  18,  19, ...,  -4,  -5,  -3],
       ...,
       [ -1,   4,   9, ...,  -3,  -4,  -7],
       [-36, -36, -36, ..., 151, 141, 130],
       [  1,   0,   1, ...,  -3,  -4,   0]]),
 'symptom': ['normal', 'cb_normal']}

>>> signal shape: (21, 100)

 -------------------------------

### Drop channel(s)

In [14]:
# Look up the row indices of the 'EKG' and 'Photic' channels from the channel names
# listed in the annotation file; the channel-dropping cells below use these indices.
anno_path = os.path.join(data_path, 'annotation.json')
with open(anno_path, 'r') as json_file:
    annotation = json.load(json_file)

signal_headers = annotation['signal_header']
channel_ekg, channel_photic = (signal_headers.index(name) for name in ('EKG', 'Photic'))

print('channel_ekg: ', channel_ekg)
print('channel_photic: ', channel_photic)

channel_ekg:  19
channel_photic:  20


In [15]:
# Reference load without any transform: shows the untouched signal.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
    file_format='feather',
    transform=None,
)
print('before:', test_dataset[0]['signal'].shape)
print(test_dataset[0]['signal'])

# Blank line, separator, blank line.
print('\n' + '-' * 100 + '\n')

# Same load, now dropping the EKG channel (a single channel index).
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
    file_format='feather',
    transform=EegDropChannels(channel_ekg),
)
print('after:', test_dataset[0]['signal'].shape)
print(test_dataset[0]['signal'])

before: (21, 128400)
[[  7  41  49 ...  24  27  29]
 [ 23  26  27 ... -60 -58 -57]
 [-16 -20 -18 ...  17  18  17]
 ...
 [ 32  31  31 ...  18  16  15]
 [177 228 142 ...   2   1   0]
 [  0  -1   0 ...  -1   0   1]]

----------------------------------------------------------------------------------------------------

after: (20, 128400)
[[  7  41  49 ...  24  27  29]
 [ 23  26  27 ... -60 -58 -57]
 [-16 -20 -18 ...  17  18  17]
 ...
 [  3   7   7 ...  14  14  12]
 [ 32  31  31 ...  18  16  15]
 [  0  -1   0 ...  -1   0   1]]


In [16]:
# Reference load without any transform: shows the untouched signal.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
    file_format='feather',
    transform=None,
)
print('before:', test_dataset[0]['signal'].shape)
print(test_dataset[0]['signal'])

# Blank line, separator, blank line.
print('\n' + '-' * 100 + '\n')

# Same load, now dropping the Photic channel (a single channel index).
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
    file_format='feather',
    transform=EegDropChannels(channel_photic),
)
print('after:', test_dataset[0]['signal'].shape)
print(test_dataset[0]['signal'])

before: (21, 128400)
[[  7  41  49 ...  24  27  29]
 [ 23  26  27 ... -60 -58 -57]
 [-16 -20 -18 ...  17  18  17]
 ...
 [ 32  31  31 ...  18  16  15]
 [177 228 142 ...   2   1   0]
 [  0  -1   0 ...  -1   0   1]]

----------------------------------------------------------------------------------------------------

after: (20, 128400)
[[  7  41  49 ...  24  27  29]
 [ 23  26  27 ... -60 -58 -57]
 [-16 -20 -18 ...  17  18  17]
 ...
 [  3   7   7 ...  14  14  12]
 [ 32  31  31 ...  18  16  15]
 [177 228 142 ...   2   1   0]]


In [17]:
# Reference load without any transform: shows the untouched signal.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
    file_format='feather',
    transform=None,
)
print('before:', test_dataset[0]['signal'].shape)
print(test_dataset[0]['signal'])

# Blank line, separator, blank line.
print('\n' + '-' * 100 + '\n')

# Same load, now dropping both channels at once by passing a list of indices.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
    file_format='feather',
    transform=EegDropChannels([channel_ekg, channel_photic]),
)
print('after:', test_dataset[0]['signal'].shape)
print(test_dataset[0]['signal'])

before: (21, 128400)
[[  7  41  49 ...  24  27  29]
 [ 23  26  27 ... -60 -58 -57]
 [-16 -20 -18 ...  17  18  17]
 ...
 [ 32  31  31 ...  18  16  15]
 [177 228 142 ...   2   1   0]
 [  0  -1   0 ...  -1   0   1]]

----------------------------------------------------------------------------------------------------

after: (19, 128400)
[[  7  41  49 ...  24  27  29]
 [ 23  26  27 ... -60 -58 -57]
 [-16 -20 -18 ...  17  18  17]
 ...
 [-26 -11 -11 ...  18  18  17]
 [  3   7   7 ...  14  14  12]
 [ 32  31  31 ...  18  16  15]]


### To Tensor

In [18]:
# Reference load without any transform: the sample holds a NumPy signal and a plain age.
config_data, full_eeg_dataset = load_caueeg_full_dataset(
    dataset_path=data_path,
    load_event=False,
    file_format='feather',
    transform=None,
)
print('Before:')
pprint.pprint(full_eeg_dataset[0])

# Blank line, separator, blank line.
print('\n' + '-' * 100 + '\n')

# Same load with EegToTensor applied: the output below shows 'signal' and 'age' as tensors.
config_data, full_eeg_dataset = load_caueeg_full_dataset(
    dataset_path=data_path,
    load_event=False,
    file_format='feather',
    transform=EegToTensor(),
)
print('After:')
pprint.pprint(full_eeg_dataset[0])

Before:
{'age': 78,
 'serial': '00001',
 'signal': array([[  0, -11, -13, ...,  18,  21,  22],
       [ 29,  33,  34, ...,  -7,  -4,  -4],
       [ -3,  -6,  -3, ...,  -1,   1,   2],
       ...,
       [ -4,  -2,   1, ...,   0,  -1,  -1],
       [112,  67,  76, ..., -13, -15, -11],
       [ -1,  -1,  -1, ...,  -1,  -1,  -1]]),
 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}

----------------------------------------------------------------------------------------------------

After:
{'age': tensor(78.),
 'serial': '00001',
 'signal': tensor([[  0., -11., -13.,  ...,  18.,  21.,  22.],
        [ 29.,  33.,  34.,  ...,  -7.,  -4.,  -4.],
        [ -3.,  -6.,  -3.,  ...,  -1.,   1.,   2.],
        ...,
        [ -4.,  -2.,   1.,  ...,   0.,  -1.,  -1.],
        [112.,  67.,  76.,  ..., -13., -15., -11.],
        [ -1.,  -1.,  -1.,  ...,  -1.,  -1.,  -1.]]),
 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}


### Compose all of the above into a single transform

In [19]:
# Chain the three transforms shown above; they are applied in the listed order.
transform = transforms.Compose([
    EegRandomCrop(
        crop_length=200*10,  # crop: 10s
        latency=200*10,      # latency: 10s
    ),
    EegDropChannels(channel_photic),
    EegToTensor(),
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
    file_format='feather',
    transform=transform,
)

# First training sample after cropping, channel removal, and tensor conversion.
pprint.pprint(train_dataset[0])

{'age': tensor(77.),
 'class_label': tensor(1),
 'class_name': 'Abnormal',
 'serial': '01258',
 'signal': tensor([[-12., -10., -17.,  ...,   6.,   3.,   7.],
        [ 20.,  15.,  11.,  ...,  -1.,  -2., -12.],
        [ 11.,   7.,   8.,  ...,  -2.,  -7.,  -6.],
        ...,
        [ -8., -12., -16.,  ...,   5.,   3.,   0.],
        [ -2.,  -4.,  -4.,  ...,   1.,  -1.,  -1.],
        [ 19.,  19.,  24.,  ..., -70., -58., -51.]]),
 'symptom': ['dementia', 'vd', 'sivd']}


---

## PyTorch DataLoader

In [20]:
# DataLoader settings shared by the loaders built below.
# Worker processes stay disabled on every device.
num_workers = 0  # A number other than 0 causes an error
# Page-locked host memory is requested only when the selected device is CUDA.
pin_memory = device.type == 'cuda'

In [21]:
# Per-sample pipeline: random crop, drop the Photic channel, convert to tensors.
transform = transforms.Compose([
    EegRandomCrop(
        crop_length=200*10,  # crop: 10s
        latency=200*10,      # latency: 10s
    ),
    EegDropChannels(channel_photic),
    EegToTensor(),
])

config_data, full_eeg_dataset = load_caueeg_full_dataset(
    dataset_path=data_path,
    load_event=False,
    file_format='feather',
    transform=transform,
)

# Batch the whole dataset; `eeg_collate_fn` assembles the individual samples into a batch.
full_loader = DataLoader(
    full_eeg_dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)

# Display only the first batch.
for i_batch, sample_batched in enumerate(full_loader):
    pprint.pprint(sample_batched)
    break

{'age': tensor([81., 84., 66., 69.]),
 'serial': ['01093', '01011', '00084', '01087'],
 'signal': tensor([[[ 257.,  252.,  257.,  ...,   23.,   21.,   20.],
         [  59.,   58.,   57.,  ...,    4.,    3.,    2.],
         [  -3.,   -3.,   -1.,  ...,   11.,   11.,    9.],
         ...,
         [ -13.,  -14.,  -13.,  ...,    4.,    8.,   10.],
         [ -19.,  -16.,  -12.,  ...,   11.,   12.,    9.],
         [-132., -133., -125.,  ...,  -59.,  -63.,  -68.]],

        [[  34.,   30.,   25.,  ...,    6.,    4.,    3.],
         [  31.,   25.,   20.,  ...,    1.,    2.,    1.],
         [  -1.,    2.,    4.,  ...,   -6.,   -6.,   -5.],
         ...,
         [   9.,   13.,   14.,  ...,   -5.,   -6.,   -5.],
         [ -13.,  -10.,   -8.,  ...,   -4.,   -6.,   -6.],
         [  62.,   65.,   69.,  ...,  -19.,  -24.,  -29.]],

        [[  -6.,   -5.,   -5.,  ...,  -24.,  -24.,  -23.],
         [   7.,    8.,    8.,  ...,  -30.,  -26.,  -26.],
         [  11.,    7.,    0.,  ...,   -6., 

In [22]:
# Per-sample pipeline: random crop, drop the Photic channel, convert to tensors.
transform = transforms.Compose([
    EegRandomCrop(
        crop_length=200*10,  # crop: 10s
        latency=200*10,      # latency: 10s
    ),
    EegDropChannels(channel_photic),
    EegToTensor(),
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
    file_format='feather',
    transform=transform,
)

# NOTE(review): shuffle=True and drop_last=True are fine for this display-only demo,
# but a loader used for real evaluation would normally keep every sample in a fixed
# order -- confirm before reusing these settings.
test_loader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)

# Display only the first batch; the wide layout keeps each list on a single line.
for i_batch, sample_batched in enumerate(test_loader):
    pprint.pprint(sample_batched, width=250)
    break

{'age': tensor([62., 77., 65., 59., 76., 78., 60., 80.]),
 'class_label': tensor([0, 1, 1, 0, 1, 0, 0, 1]),
 'class_name': ['Normal', 'Abnormal', 'Abnormal', 'Normal', 'Abnormal', 'Normal', 'Normal', 'Abnormal'],
 'serial': ['00098', '01021', '01219', '00787', '00722', '00298', '01139', '00259'],
 'signal': tensor([[[ -31.,  -33.,  -33.,  ...,  716.,  663.,  672.],
         [   0.,    0.,    0.,  ...,   81.,   76.,   34.],
         [   7.,   10.,   12.,  ...,   45.,   30.,  -11.],
         ...,
         [ -16.,  -14.,  -10.,  ...,   11.,   20.,   16.],
         [  -9.,   -7.,   -6.,  ...,   -9.,    1.,    3.],
         [ -29.,  -36.,  -41.,  ...,   20.,   25.,   29.]],

        [[  18.,   19.,   18.,  ...,    8.,   11.,   16.],
         [  -5.,   -4.,   -3.,  ...,    9.,   12.,   17.],
         [  -3.,   -7.,  -11.,  ...,   -4.,   -6.,   -7.],
         ...,
         [  -5.,   -5.,   -4.,  ...,   -2.,   -2.,   -4.],
         [ -11.,  -14.,  -14.,  ...,   -6.,   -5.,   -6.],
         [ -

---

## Some preprocessing steps can be implemented via the PyTorch Modules

In [23]:
# Per-sample pipeline: random crop, drop the Photic channel, convert to tensors.
transform = transforms.Compose([
    EegRandomCrop(
        crop_length=200*10,  # crop: 10s
        latency=200*10,      # latency: 10s
    ),
    EegDropChannels(channel_photic),
    EegToTensor(),
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    file_format='feather',
    transform=transform,
)

# Training loader for the dementia task; the preprocessing cells below draw batches from it.
train_loader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)

### Move the batch to the GPU device when one is available

In [24]:
print('device:', device, end='\n\n')

# Wrap the device-transfer step in an nn.Sequential so the batch-level preprocessing
# can be invoked with a single call.
preprocess_train = torch.nn.Sequential(EegToDevice(device=device)).to(device)
pprint.pprint(preprocess_train)

# Only the first batch is inspected.
for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    pprint.pprint(sample_batched)

    print('\n' + '-' * 100 + '\n')

    # The return value is discarded: the batch dict is expected to be updated in place
    # (presumably by EegToDevice -- confirm in datasets.pipeline).
    preprocess_train(sample_batched)

    print('- After -')
    pprint.pprint(sample_batched)
    break

device: cuda

Sequential(
  (0): EegToDevice(device=cuda)
)
- Before -
{'age': tensor([70., 68.]),
 'class_label': tensor([1, 0]),
 'class_name': ['MCI', 'Normal'],
 'serial': ['00388', '01022'],
 'signal': tensor([[[-10., -11., -12.,  ..., -15., -17., -17.],
         [ 12.,   9.,   8.,  ...,  14.,  15.,  16.],
         [  0.,  -4.,  -7.,  ...,  -8.,  -7.,  -6.],
         ...,
         [ 10.,  11.,  11.,  ...,   6.,   8.,   9.],
         [ -6.,  -6.,  -6.,  ...,  -7.,  -7.,  -6.],
         [ 31.,  30.,  29.,  ..., -11., -18., -21.]],

        [[ 19.,  21.,  17.,  ..., -11.,  -8.,  -4.],
         [ -5.,   2.,   2.,  ...,  -6.,  -8.,  -7.],
         [  1.,  -1.,  -2.,  ...,  -3.,  -4.,  -2.],
         ...,
         [  5.,   7.,   8.,  ...,  -5.,  -3.,   1.],
         [  0.,   0.,   1.,  ...,  -2.,   0.,   0.],
         [  0.,   5.,   4.,  ..., -27., -25., -21.]]]),
 'symptom': [['mci', 'mci_vascular'], ['normal', 'smi']]}

-----------------------------------------------------------------

### Normalization per signal

In [25]:
# Batch-level preprocessing: move to the device, then apply per-signal normalization.
preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegNormalizePerSignal(),
).to(device)
pprint.pprint(preprocess_train)

# Compare the statistics along the last axis of the first batch, before and after.
for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    print('Mean:', sample_batched['signal'].mean(dim=-1), end='\n\n')
    print('Std:', sample_batched['signal'].std(dim=-1))

    print('\n' + '-' * 100 + '\n')

    # The return value is discarded: the batch dict is expected to be updated in place.
    preprocess_train(sample_batched)

    # Read the 'signal' entry again so the statistics reflect the processed batch.
    print('- After -')
    print('Mean:', sample_batched['signal'].mean(dim=-1), end='\n\n')
    print('Std:', sample_batched['signal'].std(dim=-1))
    break

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizePerSignal(eps=1e-08)
)
- Before -
Mean: tensor([[-1.5015, -1.2540,  0.9750, -0.2035, -0.0120, -1.4020, -0.4920,  0.0750,
          0.2650, -0.0150, -0.9570,  1.1425, -0.8590, -0.0855,  3.0465, -0.0670,
          0.0280, -1.1910, -0.3915, -0.4005],
        [-3.2060, -0.2100,  0.4410,  0.3770,  0.2915,  0.4115, -1.1685,  0.4630,
          0.7025,  0.6765, -0.7380,  0.0970, -0.2030, -1.7420, -1.9505,  0.0985,
          1.9050, -1.0475, -0.0145,  0.7985]])

Std: tensor([[10.2681,  7.4200,  7.0835,  7.6284,  8.2955, 10.5996,  9.6177,  7.0683,
          6.9477,  6.7667,  9.3398, 10.4793, 12.0134,  8.6347, 10.9687,  8.9374,
          7.0422,  6.1566,  6.0609, 21.0069],
        [20.7061, 16.6624,  6.5889,  7.0615,  9.6049, 19.0360, 11.9814,  5.1844,
          5.9414,  9.9172, 14.4610,  9.2566, 10.6721, 15.0691, 12.9404,  6.6295,
         24.6432, 11.8071,  7.4599, 65.7220]])

--------------------------------------------------------

### Signal normalization using the specified mean and std values

In [26]:
# Estimate the signal mean and std from the training loader; the verbose flag prints them.
signal_mean, signal_std = calculate_signal_statistics(train_loader, repeats=30, verbose=True)

# Batch-level preprocessing: move to the device, then normalize with those fixed statistics.
preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
).to(device)
pprint.pprint(preprocess_train)

# Compare the statistics along the last axis of the first batch, before and after.
for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    print('Mean:', sample_batched['signal'].mean(dim=-1), end='\n\n')
    print('Std:', sample_batched['signal'].std(dim=-1))

    print('\n' + '-' * 100 + '\n')

    # The return value is discarded: the batch dict is expected to be updated in place.
    preprocess_train(sample_batched)

    # Read the 'signal' entry again so the statistics reflect the processed batch.
    print('- After -')
    print('Mean:', sample_batched['signal'].mean(dim=-1), end='\n\n')
    print('Std:', sample_batched['signal'].std(dim=-1))
    break

Mean and standard deviation for signal:
tensor([[[-0.0340],
         [ 0.0634],
         [ 0.0308],
         [-0.0197],
         [-0.0238],
         [-0.0111],
         [-0.0413],
         [ 0.0295],
         [ 0.0091],
         [ 0.0030],
         [ 0.0294],
         [-0.0206],
         [-0.0196],
         [-0.0190],
         [-0.0312],
         [-0.0034],
         [-0.0349],
         [ 0.0046],
         [ 0.0249],
         [-0.0021]]])
-
tensor([[[45.1173],
         [20.5836],
         [11.8230],
         [11.7496],
         [15.7130],
         [48.4388],
         [19.9356],
         [10.5975],
         [11.8556],
         [16.0715],
         [20.9259],
         [14.4848],
         [13.7052],
         [21.9231],
         [16.7736],
         [14.9407],
         [19.5530],
         [11.3869],
         [11.5357],
         [93.8896]]])

----------------------------------------------------------------------------------------------------

Sequential(
  (0): EegToDevice(device=cuda)
  (1): 

### Age normalization

In [27]:
# Age statistics over the training loader (printed because verbose=True).
age_mean, age_std = calculate_age_statistics(train_loader, verbose=True)

# Preprocessing chain: move the batch to `device`, then normalize the age field.
preprocess_train = torch.nn.Sequential(
    *transforms.Compose([
        EegToDevice(device=device),
        EegNormalizeAge(mean=age_mean, std=age_std),
    ]).transforms
).to(device)
pprint.pprint(preprocess_train)

# Show the ages of the first batch before and after the preprocessing.
for sample_batched in train_loader:
    print('- Before -')
    pprint.pprint(sample_batched['age'])

    print('\n' + '-' * 100 + '\n')

    # The return value is not used; the normalized ages are read back from the same dict.
    preprocess_train(sample_batched)

    print('- After -')
    pprint.pprint(sample_batched['age'])
    break

Age mean and standard deviation:
tensor([71.2453]) tensor([7.8228])

----------------------------------------------------------------------------------------------------

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([71.2453]),std=tensor([7.8228]),eps=1e-08)
)
- Before -
tensor([67., 68.])

----------------------------------------------------------------------------------------------------

- After -
tensor([-0.5427, -0.4148], device='cuda:0')


### Short-time Fourier transform (STFT or spectrogram)

In [28]:
# Preprocessing chain: move the batch to `device`, then replace the signal with its
# spectrogram (n_fft=200, complex values handled in 'as_real' mode).
preprocess_train = torch.nn.Sequential(
    *transforms.Compose([
        EegToDevice(device=device),
        EegSpectrogram(n_fft=200, complex_mode='as_real'),
    ]).transforms
).to(device)
pprint.pprint(preprocess_train)

# Compare the signal shape of the first batch before and after the STFT.
for sample_batched in train_loader:
    print('- Before -')
    pprint.pprint(sample_batched['signal'].shape)

    print('\n' + '-' * 100 + '\n')

    # The return value is not used; the transformed signal is read back from the same dict.
    preprocess_train(sample_batched)

    print('- After -')
    pprint.pprint(sample_batched['signal'].shape)
    break

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegSpectrogram(n_fft=200, complex_mode=as_real, stft_kwargs={})
)
- Before -
torch.Size([2, 20, 2000])

----------------------------------------------------------------------------------------------------

- After -
torch.Size([2, 40, 101, 41])


---

## Speed check without STFT

In [29]:
# Settings read by the EDF / feather timing cells that follow.
batch_size = 128
crop_length = 10 * 200  # number of samples per random crop

### `EDF`

In [30]:
%%time
transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='edf',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegRandomCrop(crop_length=2000, length_limit=10000000, multiple=1, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([71.2453]),std=tensor([7.8228]),eps=1e-08)
  (2): EegNormalizeMeanStd(mean=tensor([-0.0340,  0.0634,  0.0308, -0.0197, -0.0238, -0.0111, -0.0413,  0.0295,
           0.0091,  0.0030,  0.0294, -0.0206, -0.0196, -0.0190, -0.0312, -0.0034,
          -0.0349,  0.0046,  0.0249, -0.0021]),std=tensor([45.1173, 20.5836, 11.8230, 11.7496, 15.7130, 48.4388, 19.9356, 10.5975,
          11.8556, 16.0715, 20.9259, 14.4848, 13.7052, 21.9231, 16.7736, 14.9407,
          19.5530, 11.3869, 11.5357, 93.8896]),eps=1e-08)
)
CPU times: total: 23min 8s
Wall time: 1min 55s


### `feather`

In [31]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=transform)
 
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegRandomCrop(crop_length=2000, length_limit=10000000, multiple=1, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([71.2453]),std=tensor([7.8228]),eps=1e-08)
  (2): EegNormalizeMeanStd(mean=tensor([-0.0340,  0.0634,  0.0308, -0.0197, -0.0238, -0.0111, -0.0413,  0.0295,
           0.0091,  0.0030,  0.0294, -0.0206, -0.0196, -0.0190, -0.0312, -0.0034,
          -0.0349,  0.0046,  0.0249, -0.0021]),std=tensor([45.1173, 20.5836, 11.8230, 11.7496, 15.7130, 48.4388, 19.9356, 10.5975,
          11.8556, 16.0715, 20.9259, 14.4848, 13.7052, 21.9231, 16.7736, 14.9407,
          19.5530, 11.3869, 11.5357, 93.8896]),eps=1e-08)
)
CPU times: total: 50 s
Wall time: 3.28 s


---

## Speed check with STFT

In [32]:
# Settings for the STFT timing cells that follow.
batch_size = 128
crop_length = 10 * 300  # number of samples per random crop
n_fft, hop_length, seq_len_2d = calculate_stft_params(seq_length=crop_length, verbose=True)

# Preprocessing chain used only to gather statistics of the spectrogram output.
preprocess_train = torch.nn.Sequential(
    *transforms.Compose([
        EegToDevice(device=device),
        EegSpectrogram(n_fft=n_fft, hop_length=hop_length, complex_mode='as_real'),
    ]).transforms
).to(device)

# NOTE(review): `train_loader` is whatever loader the previously executed cell left behind.
# In notebook order that is the feather loader built while crop_length was 200 * 10, so
# these statistics would come from 2000-sample crops, not the 3000 set above. Confirm
# this is intended.
signal_2d_mean, signal_2d_std = calculate_signal_statistics(train_loader, preprocess_train)

Input sequence length: (3000) would become (78, 77) after the STFT with n_fft (155) and hop_length (39).

----------------------------------------------------------------------------------------------------



### `EDF`

In [33]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='edf',
                                                                                  transform=transform)
 
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegSpectrogram(n_fft=n_fft, hop_length=hop_length, complex_mode='as_real'),
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)
    size = sample_batched['signal'].size()
    
print(size)

Compose(
    EegRandomCrop(crop_length=3000, length_limit=10000000, multiple=1, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([71.2453]),std=tensor([7.8228]),eps=1e-08)
  (2): EegSpectrogram(n_fft=155, complex_mode=as_real, stft_kwargs={'hop_length': 39})
  (3): EegNormalizeMeanStd(mean=tensor([[-3.9695e+01, -4.8631e-01, -1.2728e-01,  ..., -2.1061e-03,
            2.9554e-02,  2.1025e-02],
          [ 1.3874e+01, -1.0662e-01, -7.5962e-02,  ...,  6.4494e-03,
            8.2820e-03,  6.8985e-03],
          [ 9.4372e+00,  1.4145e-01, -9.2469e-03,  ...,  2.0075e-02,
            1.7569e-02,  1.7965e-02],
          ...,
          [ 0.0000e+00,  9.7692e-01,  4.0125e-01,  ..., -8.8930e-03,
            8.9706e-03,  2.4872e-03],
          [ 0.0000e+00,  1.3422e+00,  5.8539e-01,  ...,  4.4170e-03,
           -4.2220e-03,  3.9863e-04],
          [ 0.0000e+00, -1.3657e+00, -3.2

### `PyArrow.feather`

In [34]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegSpectrogram(n_fft=n_fft, hop_length=hop_length, complex_mode='as_real'),
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)
    size = sample_batched['signal'].size()
    
print(size)

Compose(
    EegRandomCrop(crop_length=3000, length_limit=10000000, multiple=1, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([71.2453]),std=tensor([7.8228]),eps=1e-08)
  (2): EegSpectrogram(n_fft=155, complex_mode=as_real, stft_kwargs={'hop_length': 39})
  (3): EegNormalizeMeanStd(mean=tensor([[-3.9695e+01, -4.8631e-01, -1.2728e-01,  ..., -2.1061e-03,
            2.9554e-02,  2.1025e-02],
          [ 1.3874e+01, -1.0662e-01, -7.5962e-02,  ...,  6.4494e-03,
            8.2820e-03,  6.8985e-03],
          [ 9.4372e+00,  1.4145e-01, -9.2469e-03,  ...,  2.0075e-02,
            1.7569e-02,  1.7965e-02],
          ...,
          [ 0.0000e+00,  9.7692e-01,  4.0125e-01,  ..., -8.8930e-03,
            8.9706e-03,  2.4872e-03],
          [ 0.0000e+00,  1.3422e+00,  5.8539e-01,  ...,  4.4170e-03,
           -4.2220e-03,  3.9863e-04],
          [ 0.0000e+00, -1.3657e+00, -3.2