# Dataset and DataLoader

This notebook loads the `CAUEEG` dataset, tests some useful preprocessing, and makes up the PyTorch DataLoader instances for the training.

-----

## Configurations

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
%cd ..

C:\Users\Minjae\Documents\GitHub\caueeg-ceednet


In [2]:
# Load some packages
import os
import json
import pprint

import numpy as np
import random
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

# custom package
from datasets.caueeg_dataset import *
from datasets.caueeg_script import *
from datasets.pipeline import *

In [3]:
print('PyTorch version:', torch.__version__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available(): print('cuda is available.')
else: print('cuda is unavailable.')

PyTorch version: 1.12.1+cu113
cuda is available.


In [4]:
# Data file path
data_path = r'local/dataset/caueeg-dataset/'

In [5]:
for task in ['annotation.json', 'abnormal.json', 'dementia.json']:
    task_path = os.path.join(data_path, task)
    with open(task_path, 'r') as json_file:
        task_dict = json.load(json_file)
        
    print('{')
    for k, v in task_dict.items():
        print(f'\t{k}:')
        if isinstance(v, list) and len(v) > 3:
            print(f'\t\t{v[0]}')
            print(f'\t\t{v[1]}')
            print(f'\t\t.')
            print(f'\t\t.')
            print(f'\t\t.')
            print(f'\t\t{v[-1]}')
        else:
            print(f'\t\t{v}')
        print()
    print('}')
    print('\n' + '-' * 100 + '\n')

{
	dataset_name:
		CAUEEG dataset

	signal_header:
		Fp1-AVG
		F3-AVG
		.
		.
		.
		Photic

	data:
		{'serial': '00001', 'age': 78, 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}
		{'serial': '00002', 'age': 56, 'symptom': ['normal', 'smi']}
		.
		.
		.
		{'serial': '01388', 'age': 73, 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_ef']}

}

----------------------------------------------------------------------------------------------------

{
	task_name:
		CAUEEG-Abnormal benchmark

	task_description:
		Classification of [Normal] and [Abnormal] symptoms

	class_label_to_name:
		['Normal', 'Abnormal']

	class_name_to_label:
		{'Normal': 0, 'Abnormal': 1}

	train_split:
		{'serial': '01258', 'age': 77, 'symptom': ['dementia', 'vd', 'sivd'], 'class_name': 'Abnormal', 'class_label': 1}
		{'serial': '00836', 'age': 80, 'symptom': ['normal', 'smi'], 'class_name': 'Normal', 'class_label': 0}
		.
		.
		.
		{'serial': '00105', 'age': 71, 'symptom': ['normal', 'smi'], 'class_name': 'N

-----

## Load the CAUEEG dataset

### Load the whole CAUEEG data as a PyTorch dataset instance without considering the target task (no train/val/test sets and no class label).

In [6]:
config_data, full_eeg_dataset = load_caueeg_full_dataset(dataset_path=data_path, 
                                                         load_event=False, 
                                                         file_format='edf',  # can be ommitted
                                                         transform=None)

pprint.pprint(config_data, width=250)
print('\n', '-' * 100, '\n')

pprint.pprint(full_eeg_dataset[0])
print('\n', '-' * 100, '\n')

pprint.pprint(full_eeg_dataset[3])
print('\n', '-' * 100, '\n')

{'dataset_name': 'CAUEEG dataset',
 'signal_header': ['Fp1-AVG', 'F3-AVG', 'C3-AVG', 'P3-AVG', 'O1-AVG', 'Fp2-AVG', 'F4-AVG', 'C4-AVG', 'P4-AVG', 'O2-AVG', 'F7-AVG', 'T3-AVG', 'T5-AVG', 'F8-AVG', 'T4-AVG', 'T6-AVG', 'FZ-AVG', 'CZ-AVG', 'PZ-AVG', 'EKG', 'Photic']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 78,
 'serial': '00001',
 'signal': array([[  0., -11., -13., ...,   0.,   0.,   0.],
       [ 29.,  33.,  34., ...,   0.,   0.,   0.],
       [ -3.,  -6.,  -3., ...,   0.,   0.,   0.],
       ...,
       [ -4.,  -2.,   1., ...,   0.,   0.,   0.],
       [112.,  67.,  76., ...,   0.,   0.,   0.],
       [ -1.,  -1.,  -1., ...,   0.,   0.,   0.]]),
 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 78,
 'serial': '00004',
 'signal': array([[ 30.,  34.,  36., ...,   0.,   0.,   0.],
       [ 34.,  25., 

### When triggered `load_event` option, the dataset also loads event data.

In [7]:
config_data, full_eeg_dataset = load_caueeg_full_dataset(dataset_path=data_path, 
                                                         file_format='edf',
                                                         transform=None)
pprint.pprint(full_eeg_dataset[0])

{'age': 78,
 'event': [[0, 'Start Recording'],
           [0, 'New Montage - Montage 002'],
           [36396, 'Eyes Open'],
           [72518, 'Eyes Closed'],
           [73862, 'Eyes Open'],
           [75248, 'Eyes Closed'],
           [76728, 'swallowing'],
           [77978, 'Eyes Open'],
           [79406, 'Eyes Closed'],
           [79996, 'Photic On - 3.0 Hz'],
           [80288, 'Eyes Open'],
           [81296, 'Eyes Closed'],
           [82054, 'Photic Off'],
           [84070, 'Photic On - 6.0 Hz'],
           [84488, 'Eyes Open'],
           [85538, 'Eyes Closed'],
           [86086, 'Photic Off'],
           [88144, 'Photic On - 9.0 Hz'],
           [90160, 'Photic Off'],
           [91458, 'Eyes Open'],
           [92218, 'Photic On - 12.0 Hz'],
           [92762, 'Eyes Closed'],
           [94198, 'Photic Off'],
           [94742, 'Eyes Open'],
           [95708, 'Eyes Closed'],
           [96256, 'Photic On - 15.0 Hz'],
           [98272, 'Photic Off'],
           [1003

-----

## Load the CAUEEG benchmarks

### Load the CAUEEG-Abnormal benchmark using the PyTorch dataset instances.

In [8]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False)
pprint.pprint(config_data)
print('\n', '-' * 100, '\n')
pprint.pprint(test_dataset[0])

{'class_label_to_name': ['Normal', 'Abnormal'],
 'class_name_to_label': {'Abnormal': 1, 'Normal': 0},
 'task_description': 'Classification of [Normal] and [Abnormal] symptoms',
 'task_name': 'CAUEEG-Abnormal benchmark'}

 ---------------------------------------------------------------------------------------------------- 

{'age': 57,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00560',
 'signal': array([[  7.,  41.,  49., ...,   0.,   0.,   0.],
       [ 23.,  26.,  27., ...,   0.,   0.,   0.],
       [-16., -20., -18., ...,   0.,   0.,   0.],
       ...,
       [ 32.,  31.,  31., ...,   0.,   0.,   0.],
       [177., 228., 142., ...,   0.,   0.,   0.],
       [  0.,  -1.,   0., ...,   0.,   0.,   0.]]),
 'symptom': ['normal', 'cb_normal']}


### Load the CAUEEG-Dementia benchmark using the PyTorch dataset instances.

In [9]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  transform=None)
pprint.pprint(config_data)
print('\n', '-' * 100, '\n')
pprint.pprint(test_dataset[0])

{'class_label_to_name': ['Normal', 'MCI', 'Dementia'],
 'class_name_to_label': {'Dementia': 2, 'MCI': 1, 'Normal': 0},
 'task_description': 'Classification of [Normal], [MCI], and [Dementia] '
                     'symptoms.',
 'task_name': 'CAUEEG-Dementia benchmark'}

 ---------------------------------------------------------------------------------------------------- 

{'age': 62,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00789',
 'signal': array([[-87., -69., -70., ...,   0.,   0.,   0.],
       [-25., -18., -19., ...,   0.,   0.,   0.],
       [ -6.,   1.,   0., ...,   0.,   0.,   0.],
       ...,
       [  0.,  -5.,  -4., ...,   0.,   0.,   0.],
       [-31.,  -8.,  -7., ...,   0.,   0.,   0.],
       [  0.,   1.,  -1., ...,   0.,   0.,   0.]]),
 'symptom': ['normal', 'cb_normal']}


### Load the CAUEEG-Abnormal (no-overlap version) benchmark using the PyTorch dataset instances.

In [10]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal-no-overlap',
                                                                                  load_event=False)
pprint.pprint(config_data)
print('\n', '-' * 100, '\n')
pprint.pprint(test_dataset[0])

{'class_label_to_name': ['Normal', 'Abnormal'],
 'class_name_to_label': {'Abnormal': 1, 'Normal': 0},
 'task_description': 'Classification of [Normal] and [Abnormal] symptoms',
 'task_name': 'CAUEEG-Abnormal benchmark'}

 ---------------------------------------------------------------------------------------------------- 

{'age': 57,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00560',
 'signal': array([[  7.,  41.,  49., ...,   0.,   0.,   0.],
       [ 23.,  26.,  27., ...,   0.,   0.,   0.],
       [-16., -20., -18., ...,   0.,   0.,   0.],
       ...,
       [ 32.,  31.,  31., ...,   0.,   0.,   0.],
       [177., 228., 142., ...,   0.,   0.,   0.],
       [  0.,  -1.,   0., ...,   0.,   0.,   0.]]),
 'symptom': ['normal', 'cb_normal']}


### Load the CAUEEG-Dementia (no-overlap version) benchmark using the PyTorch dataset instances.

In [11]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia-no-overlap',
                                                                                  load_event=False, 
                                                                                  transform=None)
pprint.pprint(config_data)
print('\n', '-' * 100, '\n')
pprint.pprint(test_dataset[0])

{'class_label_to_name': ['Normal', 'MCI', 'Dementia'],
 'class_name_to_label': {'Dementia': 2, 'MCI': 1, 'Normal': 0},
 'task_description': 'Classification of [Normal], [MCI], and [Dementia] '
                     'symptoms.',
 'task_name': 'CAUEEG-Dementia benchmark'}

 ---------------------------------------------------------------------------------------------------- 

{'age': 70,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00142',
 'signal': array([[-38., -24., -18., ...,   0.,   0.,   0.],
       [ -7.,  -4.,  -2., ...,   0.,   0.,   0.],
       [ -4.,  -6.,  -7., ...,   0.,   0.,   0.],
       ...,
       [ -4.,  -6.,  -7., ...,   0.,   0.,   0.],
       [-33.,  13.,  22., ...,   0.,   0.,   0.],
       [ -1.,   0.,   1., ...,   0.,   0.,   0.]]),
 'symptom': ['normal', 'cb_normal']}


### With `load_event` triggered, the benchmark can use the event data.

In [12]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=True, 
                                                                                  file_format='edf', 
                                                                                  transform=None)
pprint.pprint(test_dataset[0])

{'age': 57,
 'class_label': 0,
 'class_name': 'Normal',
 'event': [[0, 'Start Recording'],
           [0, 'New Montage - Montage 002'],
           [5408, 'Eyes Open'],
           [6332, 'Eyes Closed'],
           [16832, 'Eyes Open'],
           [17756, 'Eyes Closed'],
           [24200, 'Paused'],
           [28400, 'Recording Resumed'],
           [67442, 'Eyes Open'],
           [68366, 'Eyes Closed'],
           [82856, 'Eyes Open'],
           [83738, 'Eyes Closed'],
           [109833, 'Eyes Open'],
           [109833, 'Eyes Closed'],
           [132600, 'Paused']],
 'serial': '00560',
 'signal': array([[  7.,  41.,  49., ...,   0.,   0.,   0.],
       [ 23.,  26.,  27., ...,   0.,   0.,   0.],
       [-16., -20., -18., ...,   0.,   0.,   0.],
       ...,
       [ 32.,  31.,  31., ...,   0.,   0.,   0.],
       [177., 228., 142., ...,   0.,   0.,   0.],
       [  0.,  -1.,   0., ...,   0.,   0.,   0.]]),
 'symptom': ['normal', 'cb_normal']}


### Utilizing `PyArrow.feather` is much faster than directly using `EDF`.

In [13]:
%%time
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False, 
                                                                                  file_format='edf', 
                                                                                  transform=None)
for i, d in enumerate(test_dataset):
    if i > 30:
        break

CPU times: total: 3.84 s
Wall time: 4.66 s


In [14]:
%time
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False, 
                                                                                  file_format='feather', 
                                                                                  transform=None)
for i, d in enumerate(test_dataset):
    if i > 30:
        break

CPU times: total: 0 ns
Wall time: 0 ns


---

## PyTorch Transforms

### Random crop

In [15]:
transform = EegRandomCrop(crop_length=100)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path,
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=transform)
for i in range(2):
    d = test_dataset[0]
    pprint.pprint(d)
    print()
    print('>>> signal shape:', d['signal'].shape)
    print('\n', '-' * 100, '\n')

{'age': 62,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00789',
 'signal': array([[  -26,   -26,   -24, ...,   -25,   -31,   -32],
       [  -26,   -30,   -25, ...,   -15,    -9,    -9],
       [  -12,    -6,    -1, ...,     5,    10,    11],
       ...,
       [    7,     8,     5, ...,     7,     0,    -5],
       [   61,    52,    43, ...,  -743, -1006, -1129],
       [    1,    -1,     0, ...,    -2,    -4,     0]]),
 'symptom': ['normal', 'cb_normal']}

>>> signal shape: (21, 100)

 ---------------------------------------------------------------------------------------------------- 

{'age': 62,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00789',
 'signal': array([[ 11,   2,  -8, ..., -21, -22, -17],
       [ 33,  24,  14, ...,   0,   8,  13],
       [  7,  -3, -14, ...,   0,   6,   8],
       ...,
       [ 13,   7,   0, ...,  -2,  -2,  -4],
       [ 73,  83,  91, ...,  10,  23,  43],
       [ -1,   0,   2, ...,  -2,  -2,   0]]),
 'symptom': ['normal', 'cb_

### Drop channel(s)

In [16]:
anno_path = os.path.join(data_path, 'annotation.json')
with open(anno_path, 'r') as json_file:
    annotation = json.load(json_file)
signal_headers = annotation['signal_header']

channel_ekg = signal_headers.index('EKG')
channel_photic = signal_headers.index('Photic')

print('channel_ekg: ', channel_ekg)
print('channel_photic: ', channel_photic)

channel_ekg:  19
channel_photic:  20


In [17]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False, 
                                                                                  file_format='feather', 
                                                                                  transform=None)
print('before:', test_dataset[0]['signal'].shape)
print(test_dataset[0]['signal'])

print()
print('-' * 100)
print()
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False, 
                                                                                  file_format='feather', 
                                                                                  transform=EegDropChannels(channel_ekg))
print('after:', test_dataset[0]['signal'].shape)
print(test_dataset[0]['signal'])

before: (21, 128400)
[[  7  41  49 ...  24  27  29]
 [ 23  26  27 ... -60 -58 -57]
 [-16 -20 -18 ...  17  18  17]
 ...
 [ 32  31  31 ...  18  16  15]
 [177 228 142 ...   2   1   0]
 [  0  -1   0 ...  -1   0   1]]

----------------------------------------------------------------------------------------------------

after: (20, 128400)
[[  7  41  49 ...  24  27  29]
 [ 23  26  27 ... -60 -58 -57]
 [-16 -20 -18 ...  17  18  17]
 ...
 [  3   7   7 ...  14  14  12]
 [ 32  31  31 ...  18  16  15]
 [  0  -1   0 ...  -1   0   1]]


In [18]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=None)
print('before:', test_dataset[0]['signal'].shape)
print(test_dataset[0]['signal'])

print()
print('-' * 100)
print()
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=EegDropChannels(channel_photic))
print('after:', test_dataset[0]['signal'].shape)
print(test_dataset[0]['signal'])

before: (21, 128400)
[[  7  41  49 ...  24  27  29]
 [ 23  26  27 ... -60 -58 -57]
 [-16 -20 -18 ...  17  18  17]
 ...
 [ 32  31  31 ...  18  16  15]
 [177 228 142 ...   2   1   0]
 [  0  -1   0 ...  -1   0   1]]

----------------------------------------------------------------------------------------------------

after: (20, 128400)
[[  7  41  49 ...  24  27  29]
 [ 23  26  27 ... -60 -58 -57]
 [-16 -20 -18 ...  17  18  17]
 ...
 [  3   7   7 ...  14  14  12]
 [ 32  31  31 ...  18  16  15]
 [177 228 142 ...   2   1   0]]


In [19]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=None)
print('before:', test_dataset[0]['signal'].shape)
print(test_dataset[0]['signal'])

print()
print('-' * 100)
print()
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=EegDropChannels([channel_ekg, channel_photic]))
print('after:', test_dataset[0]['signal'].shape)
print(test_dataset[0]['signal'])

before: (21, 128400)
[[  7  41  49 ...  24  27  29]
 [ 23  26  27 ... -60 -58 -57]
 [-16 -20 -18 ...  17  18  17]
 ...
 [ 32  31  31 ...  18  16  15]
 [177 228 142 ...   2   1   0]
 [  0  -1   0 ...  -1   0   1]]

----------------------------------------------------------------------------------------------------

after: (19, 128400)
[[  7  41  49 ...  24  27  29]
 [ 23  26  27 ... -60 -58 -57]
 [-16 -20 -18 ...  17  18  17]
 ...
 [-26 -11 -11 ...  18  18  17]
 [  3   7   7 ...  14  14  12]
 [ 32  31  31 ...  18  16  15]]


### To Tensor

In [20]:
config_data, full_eeg_dataset = load_caueeg_full_dataset(dataset_path=data_path, 
                                                         load_event=False, 
                                                         file_format='feather',
                                                         transform=None)
print('Before:')
pprint.pprint(full_eeg_dataset[0])

print()
print('-' * 100)
print()

config_data, full_eeg_dataset = load_caueeg_full_dataset(dataset_path=data_path, 
                                                         load_event=False, 
                                                         file_format='feather',
                                                         transform=EegToTensor())
print('After:')
pprint.pprint(full_eeg_dataset[0])

Before:
{'age': 78,
 'serial': '00001',
 'signal': array([[  0, -11, -13, ...,  18,  21,  22],
       [ 29,  33,  34, ...,  -7,  -4,  -4],
       [ -3,  -6,  -3, ...,  -1,   1,   2],
       ...,
       [ -4,  -2,   1, ...,   0,  -1,  -1],
       [112,  67,  76, ..., -13, -15, -11],
       [ -1,  -1,  -1, ...,  -1,  -1,  -1]]),
 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}

----------------------------------------------------------------------------------------------------

After:
{'age': tensor(78.),
 'serial': '00001',
 'signal': tensor([[  0., -11., -13.,  ...,  18.,  21.,  22.],
        [ 29.,  33.,  34.,  ...,  -7.,  -4.,  -4.],
        [ -3.,  -6.,  -3.,  ...,  -1.,   1.,   2.],
        ...,
        [ -4.,  -2.,   1.,  ...,   0.,  -1.,  -1.],
        [112.,  67.,  76.,  ..., -13., -15., -11.],
        [ -1.,  -1.,  -1.,  ...,  -1.,  -1.,  -1.]]),
 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}


### Compose the above all in one

In [21]:
transform = transforms.Compose([
    EegRandomCrop(crop_length=200*10,       # crop: 10s
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=transform)

pprint.pprint(train_dataset[0])

{'age': tensor(77.),
 'class_label': tensor(1),
 'class_name': 'Abnormal',
 'serial': '01258',
 'signal': tensor([[-36., -39., -39.,  ...,  -3., -10., -10.],
        [  3.,   3.,  -4.,  ...,  -6.,   7.,   7.],
        [-19., -19., -21.,  ...,  -9.,  -8.,  -5.],
        ...,
        [ -3.,  -6.,  -9.,  ...,  -9.,  -7.,  -7.],
        [  8.,   7.,   7.,  ...,   7.,   5.,   4.],
        [-12., -12., -11.,  ...,  19.,   3.,  -4.]]),
 'symptom': ['dementia', 'vd', 'sivd']}


---

## PyTorch DataLoader

In [22]:
if device.type == 'cuda':
    num_workers = 0  # A number other than 0 causes an error
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False

In [23]:
transform = transforms.Compose([
    EegRandomCrop(crop_length=200*10,       # crop: 10s
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])

config_data, full_eeg_dataset = load_caueeg_full_dataset(dataset_path=data_path, 
                                                         load_event=False, 
                                                         file_format='feather',
                                                         transform=transform)

full_loader = DataLoader(full_eeg_dataset,
                         batch_size=4,
                         shuffle=True,
                         drop_last=True,
                         num_workers=num_workers,
                         pin_memory=pin_memory,
                         collate_fn=eeg_collate_fn)

for i_batch, sample_batched in enumerate(full_loader):
    pprint.pprint(sample_batched)
    break

{'age': tensor([76., 85., 80., 39.]),
 'serial': ['00375', '00252', '00651', '01107'],
 'signal': tensor([[[  15.,   14.,   19.,  ...,    0.,    3.,    3.],
         [   0.,   -1.,    0.,  ...,   19.,   15.,   12.],
         [  -3.,   -1.,    1.,  ...,    8.,    7.,    5.],
         ...,
         [ -11.,  -10.,   -8.,  ...,   -6.,   -5.,   -4.],
         [  -5.,   -5.,   -5.,  ...,   -6.,   -6.,   -6.],
         [  -7.,  -18.,  -26.,  ...,  -43.,  -39.,  -40.]],

        [[-131., -126., -127.,  ...,   58.,   59.,   51.],
         [  27.,   24.,   21.,  ...,   20.,   16.,   13.],
         [ -45.,  -46.,  -47.,  ...,   13.,    7.,    3.],
         ...,
         [ -31.,  -31.,  -31.,  ...,  -11.,  -13.,  -15.],
         [ -64.,  -58.,  -53.,  ...,   -6.,   -6.,   -4.],
         [  -4.,    7.,   16.,  ...,  -17.,  -17.,  -30.]],

        [[ -23.,  -20.,  -18.,  ...,    3.,    0.,    1.],
         [ -18.,  -15.,  -14.,  ...,    6.,    6.,    9.],
         [  14.,   15.,   16.,  ...,  -27., 

In [24]:
transform = transforms.Compose([
    EegRandomCrop(crop_length=200*10,       # crop: 10s
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=transform)

test_loader = DataLoader(test_dataset,
                          batch_size=8,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for i_batch, sample_batched in enumerate(test_loader):
    pprint.pprint(sample_batched, width=250)
    break

{'age': tensor([78., 66., 75., 84., 80., 61., 82., 54.]),
 'class_label': tensor([1, 1, 0, 1, 1, 0, 1, 0]),
 'class_name': ['Abnormal', 'Abnormal', 'Normal', 'Abnormal', 'Abnormal', 'Normal', 'Abnormal', 'Normal'],
 'serial': ['01134', '01016', '00093', '00039', '00481', '00900', '00229', '00851'],
 'signal': tensor([[[ 267.,  264.,  265.,  ...,    6.,    9.,    9.],
         [  11.,    6.,    3.,  ...,    3.,   -6.,   -1.],
         [   2.,    0.,   -3.,  ...,   13.,   11.,   11.],
         ...,
         [  -8.,  -10.,  -13.,  ...,  -12.,   -7.,   -6.],
         [  10.,   11.,   11.,  ...,   -3.,    1.,    3.],
         [  22.,   32.,   19.,  ...,   19.,    4.,    1.]],

        [[  13.,   11.,   10.,  ...,   57.,   55.,   53.],
         [  12.,   10.,    9.,  ...,   15.,   14.,   14.],
         [  -7.,   -5.,   -7.,  ...,   -8.,   -6.,   -7.],
         ...,
         [  -5.,   -3.,    0.,  ...,   -1.,    1.,    1.],
         [  -1.,    1.,    1.,  ...,   -4.,   -2.,   -3.],
         [

---

## Some preprocessing steps can be implemented via the PyTorch Modules

In [25]:
transform = transforms.Compose([
    EegRandomCrop(crop_length=200*10,       # crop: 10s
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=2,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

### To GPU device if it is possible

In [26]:
print('device:', device)
print()

preprocess_train = transforms.Compose([EegToDevice(device=device)])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    pprint.pprint(sample_batched)

    print()
    print('-' * 100)
    print()
    
    preprocess_train(sample_batched)
    
    print('- After -')
    pprint.pprint(sample_batched)
    break

device: cuda

Sequential(
  (0): EegToDevice(device=cuda)
)
- Before -
{'age': tensor([76., 64.]),
 'class_label': tensor([1, 2]),
 'class_name': ['MCI', 'Dementia'],
 'serial': ['00378', '00633'],
 'signal': tensor([[[-89., -89., -84.,  ..., -15., -15., -15.],
         [-79., -73., -70.,  ..., -11., -12., -13.],
         [ 29.,  25.,  25.,  ...,   7.,   7.,   6.],
         ...,
         [ 38.,  38.,  40.,  ...,   7.,   9.,   9.],
         [ 32.,  33.,  35.,  ...,   0.,   1.,   0.],
         [  0.,  -4., -11.,  ..., -60., -58., -53.]],

        [[ -8., -11., -12.,  ...,   5.,   6.,   8.],
         [ -5.,  -6.,  -7.,  ...,   4.,   4.,   5.],
         [  3.,   3.,   3.,  ...,   0.,   0.,  -1.],
         ...,
         [ -3.,  -3.,  -4.,  ...,   5.,   5.,   6.],
         [ -6.,  -5.,  -5.,  ...,  -5.,  -5.,  -5.],
         [-14.,   1.,   8.,  ...,  -4.,  -5.,  -6.]]]),
 'symptom': [['mci', 'mci_vascular'], ['dementia', 'ad', 'eoad']]}

------------------------------------------------------

### Normalization per signal

In [27]:
preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizePerSignal()
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    print('Mean:', torch.mean(sample_batched['signal'], axis=-1))
    print()
    print('Std:', torch.std(sample_batched['signal'], axis=-1))

    print()
    print('-' * 100)
    print()
    
    preprocess_train(sample_batched)
    
    print('- After -')
    print('Mean:', torch.mean(sample_batched['signal'], axis=-1))
    print()
    print('Std:', torch.std(sample_batched['signal'], axis=-1))
    break

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizePerSignal(eps=1e-08)
)
- Before -
Mean: tensor([[-6.8880, -4.5790, -0.2805, -1.0875,  1.2830, -5.1060, -0.4220,  1.0950,
          2.6655,  1.0315, -3.3815,  0.6835,  0.9020,  4.0190,  0.7710,  1.1295,
         -0.5065,  0.9630,  1.2840,  0.5850],
        [ 1.3710,  0.9205,  0.1395, -0.2905, -0.7080,  5.2385,  1.1425, -0.0530,
         -2.1810, -0.1820,  1.0195, -0.3765, -0.8045,  1.4955, -0.8655, -0.5900,
          1.6540,  0.3090, -0.6355,  0.1305]])

Std: tensor([[ 68.6684,  27.2979,   6.1905,   9.9731,  14.1283,  59.1413,  15.8267,
           7.1583,  10.5741,  14.3079,  25.5349,   8.8883,  14.1888,  13.6482,
          13.8810,  15.3526,  12.5770,   6.4207,  10.6819, 121.9661],
        [ 31.4508,  15.2986,   8.7377,  10.5960,  23.9189,  31.0532,  12.5155,
           8.9206,  10.3289,  27.6918,  11.0976,   7.0854,  13.5130,  10.4345,
           8.0944,  16.5765,  14.7400,  11.9985,  10.2284,  32.6769]])

----------------

### Signal normalization using the specified mean and std values

In [28]:
signal_mean, signal_std = calculate_signal_statistics(train_loader, repeats=30, verbose=True)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    print('Mean:', torch.mean(sample_batched['signal'], axis=-1))
    print()
    print('Std:', torch.std(sample_batched['signal'], axis=-1))

    print()
    print('-' * 100)
    print()
    
    preprocess_train(sample_batched)
    
    print('- After -')
    print('Mean:', torch.mean(sample_batched['signal'], axis=-1))
    print()
    print('Std:', torch.std(sample_batched['signal'], axis=-1))
    break

Mean and standard deviation for signal:
tensor([[[ 0.0233],
         [-0.0022],
         [-0.0132],
         [ 0.0054],
         [-0.0396],
         [ 0.0262],
         [-0.0222],
         [-0.0088],
         [ 0.0500],
         [ 0.0818],
         [-0.0316],
         [ 0.0135],
         [-0.0011],
         [-0.0490],
         [ 0.0508],
         [-0.0007],
         [ 0.0144],
         [ 0.0059],
         [-0.0014],
         [ 0.0075]]])
-
tensor([[[45.4992],
         [20.4970],
         [11.7411],
         [11.8701],
         [15.6259],
         [48.7971],
         [19.9753],
         [10.6373],
         [11.8386],
         [15.8078],
         [20.8406],
         [14.4188],
         [13.7448],
         [21.7461],
         [16.8655],
         [14.9501],
         [19.5106],
         [11.4795],
         [11.7924],
         [94.0883]]])

----------------------------------------------------------------------------------------------------

Sequential(
  (0): EegToDevice(device=cuda)
  (1): 

### Age normalization

In [29]:
age_mean, age_std = calculate_age_statistics(train_loader, verbose=True)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    pprint.pprint(sample_batched['age'])

    print()
    print('-' * 100)
    print()
    
    preprocess_train(sample_batched)
    
    print('- After -')
    pprint.pprint(sample_batched['age'])
    break

Age mean and standard deviation:
tensor([71.2453]) tensor([7.7365])

----------------------------------------------------------------------------------------------------

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([71.2453]),std=tensor([7.7365]),eps=1e-08)
)
- Before -
tensor([61., 69.])

----------------------------------------------------------------------------------------------------

- After -
tensor([-1.3243, -0.2902], device='cuda:0')


### Short time Fourier transform (STFT or spectrogram)

In [30]:
preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegSpectrogram(n_fft=200, complex_mode='as_real')
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    pprint.pprint(sample_batched['signal'].shape)

    print()
    print('-' * 100)
    print()
    
    preprocess_train(sample_batched)
    
    print('- After -')
    pprint.pprint(sample_batched['signal'].shape)
    break

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegSpectrogram(n_fft=200, complex_mode=as_real, stft_kwargs={})
)
- Before -
torch.Size([2, 20, 2000])

----------------------------------------------------------------------------------------------------

- After -
torch.Size([2, 40, 101, 41])


---

## Speed check without STFT

In [31]:
crop_length = 200 * 10
batch_size = 128

### `EDF`

In [32]:
%%time
transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='edf',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegRandomCrop(crop_length=2000, length_limit=10000000, multiple=1, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([71.2453]),std=tensor([7.7365]),eps=1e-08)
  (2): EegNormalizeMeanStd(mean=tensor([ 0.0233, -0.0022, -0.0132,  0.0054, -0.0396,  0.0262, -0.0222, -0.0088,
           0.0500,  0.0818, -0.0316,  0.0135, -0.0011, -0.0490,  0.0508, -0.0007,
           0.0144,  0.0059, -0.0014,  0.0075]),std=tensor([45.4992, 20.4970, 11.7411, 11.8701, 15.6259, 48.7971, 19.9753, 10.6373,
          11.8386, 15.8078, 20.8406, 14.4188, 13.7448, 21.7461, 16.8655, 14.9501,
          19.5106, 11.4795, 11.7924, 94.0883]),eps=1e-08)
)
CPU times: total: 18min 44s
Wall time: 2min 9s


### `PyArrow.feather`

In [33]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=transform)
 
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegRandomCrop(crop_length=2000, length_limit=10000000, multiple=1, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([71.2453]),std=tensor([7.7365]),eps=1e-08)
  (2): EegNormalizeMeanStd(mean=tensor([ 0.0233, -0.0022, -0.0132,  0.0054, -0.0396,  0.0262, -0.0222, -0.0088,
           0.0500,  0.0818, -0.0316,  0.0135, -0.0011, -0.0490,  0.0508, -0.0007,
           0.0144,  0.0059, -0.0014,  0.0075]),std=tensor([45.4992, 20.4970, 11.7411, 11.8701, 15.6259, 48.7971, 19.9753, 10.6373,
          11.8386, 15.8078, 20.8406, 14.4188, 13.7448, 21.7461, 16.8655, 14.9501,
          19.5106, 11.4795, 11.7924, 94.0883]),eps=1e-08)
)
CPU times: total: 4.59 s
Wall time: 3.66 s


---

## Speed check with STFT

In [34]:
crop_length = 300 * 10
n_fft, hop_length, seq_len_2d = calculate_stft_params(seq_length=crop_length, verbose=True)
batch_size = 128

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegSpectrogram(n_fft=n_fft, hop_length=hop_length, complex_mode='as_real')
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
signal_2d_mean, signal_2d_std = calculate_signal_statistics(train_loader, preprocess_train)

Input sequence length: (3000) would become (78, 77) after the STFT with n_fft (155) and hop_length (39).

----------------------------------------------------------------------------------------------------



### `EDF`

In [35]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='edf',
                                                                                  transform=transform)
 
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegSpectrogram(n_fft=n_fft, hop_length=hop_length, complex_mode='as_real'),
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)
    size = sample_batched['signal'].size()
    
print(size)

Compose(
    EegRandomCrop(crop_length=3000, length_limit=10000000, multiple=1, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([71.2453]),std=tensor([7.7365]),eps=1e-08)
  (2): EegSpectrogram(n_fft=155, complex_mode=as_real, stft_kwargs={'hop_length': 39})
  (3): EegNormalizeMeanStd(mean=tensor([[ 2.2792e+01,  3.1216e-01,  1.0842e-01,  ...,  1.3747e-01,
            1.2862e-01,  1.2759e-01],
          [ 2.7824e+01, -4.2494e-02,  5.1770e-02,  ...,  3.1540e-02,
            2.3955e-02,  2.7768e-02],
          [ 7.3945e+00,  1.5329e-01, -6.2711e-02,  ..., -2.3968e-02,
            7.5325e-03, -9.5869e-04],
          ...,
          [ 0.0000e+00,  4.6296e-01,  3.1164e-01,  ...,  1.0917e-02,
           -1.0444e-02, -3.8881e-04],
          [ 0.0000e+00,  5.7710e-01,  3.2173e-01,  ..., -2.2850e-03,
            1.9137e-03, -5.3399e-04],
          [ 0.0000e+00,  9.0587e-01,  7.9

### `PyArrow.feather`

In [36]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='feather',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegSpectrogram(n_fft=n_fft, hop_length=hop_length, complex_mode='as_real'),
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)
    size = sample_batched['signal'].size()
    
print(size)

Compose(
    EegRandomCrop(crop_length=3000, length_limit=10000000, multiple=1, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([71.2453]),std=tensor([7.7365]),eps=1e-08)
  (2): EegSpectrogram(n_fft=155, complex_mode=as_real, stft_kwargs={'hop_length': 39})
  (3): EegNormalizeMeanStd(mean=tensor([[ 2.2792e+01,  3.1216e-01,  1.0842e-01,  ...,  1.3747e-01,
            1.2862e-01,  1.2759e-01],
          [ 2.7824e+01, -4.2494e-02,  5.1770e-02,  ...,  3.1540e-02,
            2.3955e-02,  2.7768e-02],
          [ 7.3945e+00,  1.5329e-01, -6.2711e-02,  ..., -2.3968e-02,
            7.5325e-03, -9.5869e-04],
          ...,
          [ 0.0000e+00,  4.6296e-01,  3.1164e-01,  ...,  1.0917e-02,
           -1.0444e-02, -3.8881e-04],
          [ 0.0000e+00,  5.7710e-01,  3.2173e-01,  ..., -2.2850e-03,
            1.9137e-03, -5.3399e-04],
          [ 0.0000e+00,  9.0587e-01,  7.9