# Dataset and DataLoader

This notebook loads the `CAUEEG` dataset, tests some useful preprocessing, and makes up the PyTorch DataLoader instances for the training.

-----

## Configurations

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
%cd ..

C:\Users\Minjae\Desktop\EEG_Project


In [2]:
# Load some packages
import os
import glob
import json
import pprint

import numpy as np
import random
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

# custom package
from datasets.caueeg_dataset import *
from datasets.caueeg_script import *
from datasets.pipeline import *

In [3]:
print('PyTorch version:', torch.__version__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available(): print('cuda is available.')
else: print('cuda is unavailable.')

PyTorch version: 1.11.0+cu113
cuda is available.


In [4]:
# Data file path
data_path = r'local/dataset/02_Curated_Data_220419/'

In [5]:
for task in ['annotation.json', 'abnormal.json', 'dementia.json']:
    task_path = os.path.join(data_path, task)
    with open(task_path, 'r') as json_file:
        task_dict = json.load(json_file)
        
    print('{')
    for k, v in task_dict.items():
        print(f'\t{k}:')
        if isinstance(v, list) and len(v) > 3:
            print(f'\t\t{v[0]}')
            print(f'\t\t{v[1]}')
            print(f'\t\t{v[2]}')
            print(f'\t\t.')
            print(f'\t\t.')
            print(f'\t\t.')
            print(f'\t\t{v[-1]}')
        else:
            print(f'\t\t{v}')
        print()
    print('}')

{
	dataset_name:
		CAUEEG dataset

	signal_header:
		Fp1-AVG
		F3-AVG
		C3-AVG
		.
		.
		.
		Photic

	data:
		{'serial': '00001', 'age': 78, 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}
		{'serial': '00002', 'age': 56, 'symptom': ['normal', 'smi']}
		{'serial': '00003', 'age': 93, 'symptom': ['mci', 'mci_vascular']}
		.
		.
		.
		{'serial': '01388', 'age': 73, 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_ef']}

}
{
	task_name:
		CAUEEG-Abnormal benchmark

	task_description:
		Classification of [Normal] and [Abnormal] symptoms

	class_label_to_name:
		['Normal', 'Abnormal']

	class_name_to_label:
		{'Normal': 0, 'Abnormal': 1}

	train_split:
		{'serial': '01258', 'age': 77, 'symptom': ['dementia', 'vd', 'sivd'], 'class_name': 'Abnormal', 'class_label': 1}
		{'serial': '00836', 'age': 80, 'symptom': ['normal', 'smi'], 'class_name': 'Normal', 'class_label': 0}
		{'serial': '00761', 'age': 75, 'symptom': ['dementia', 'ad', 'load'], 'class_name': 'Abnormal', 'class_label': 1}


In [13]:
import shutil

output_path = r'c:\Users/Minjae/Desktop/caueeg-dataset-test'

for task in ['abnormal.json', 'dementia.json']:
    task_path = os.path.join(data_path, task)
    with open(task_path, 'r') as json_file:
        task_dict = json.load(json_file)
    
    for data in task_dict['test_split']:
        for ext in [('edf', 'edf'), ('feather', 'feather'), ('memmap', 'dat')]:
            shutil.copyfile(os.path.join(data_path, f"signal/{ext[0]}/{data['serial']}.{ext[1]}"), 
                            os.path.join(output_path, f"signal/{ext[0]}/{data['serial']}.{ext[1]}")) 

    for data in task_dict['test_split']:
        shutil.copyfile(os.path.join(data_path, f"event/{data['serial']}.json"), 
                        os.path.join(output_path, f"event/{data['serial']}.json"))

In [14]:
for task in ['abnormal.json', 'dementia.json']:
    task_path = os.path.join(data_path, task)
    with open(task_path, 'r') as json_file:
        task_dict = json.load(json_file)
    
    print(len(task_dict['test_split']))

136
118


-----

## Load the CAUEEG dataset

### Load the whole CAUEEG data as a PyTorch dataset instance without considering the target task (no train/val/test sets and no class label).

In [6]:
config_data, full_eeg_dataset = load_caueeg_full_dataset(dataset_path=data_path, 
                                                         load_event=False, 
                                                         file_format='edf',
                                                         transform=None)

pprint.pprint(config_data, width=250)
print('\n', '-' * 100, '\n')

pprint.pprint(full_eeg_dataset[0])
print('\n', '-' * 100, '\n')

pprint.pprint(full_eeg_dataset[1])
print('\n', '-' * 100, '\n')

{'dataset_name': 'CAUEEG dataset',
 'signal_header': ['Fp1-AVG', 'F3-AVG', 'C3-AVG', 'P3-AVG', 'O1-AVG', 'Fp2-AVG', 'F4-AVG', 'C4-AVG', 'P4-AVG', 'O2-AVG', 'F7-AVG', 'T3-AVG', 'T5-AVG', 'F8-AVG', 'T4-AVG', 'T6-AVG', 'FZ-AVG', 'CZ-AVG', 'PZ-AVG', 'EKG', 'Photic']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 78,
 'serial': '00001',
 'signal': array([[  0., -11., -13., ...,   0.,   0.,   0.],
       [ 29.,  33.,  34., ...,   0.,   0.,   0.],
       [ -3.,  -6.,  -3., ...,   0.,   0.,   0.],
       ...,
       [ -4.,  -2.,   1., ...,   0.,   0.,   0.],
       [112.,  67.,  76., ...,   0.,   0.,   0.],
       [ -1.,  -1.,  -1., ...,   0.,   0.,   0.]]),
 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 56,
 'serial': '00002',
 'signal': array([[  39.,   58.,   72., ...,    0.,    0.,    0.],
       [   4.,

### Load the CAUEEG task1 datasets as the PyTorch dataset instances.

In [7]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False, 
                                                                                  file_format='edf', 
                                                                                  transform=None)
pprint.pprint(config_data)
print('\n', '-' * 100, '\n')

pprint.pprint(train_dataset[0])
print('\n', '-' * 100, '\n')

pprint.pprint(val_dataset[0])
print('\n', '-' * 100, '\n')

pprint.pprint(test_dataset[0])

{'class_label_to_name': ['Normal', 'Abnormal'],
 'class_name_to_label': {'Abnormal': 1, 'Normal': 0},
 'task_description': 'Classification of [Normal] and [Abnormal] symptoms',
 'task_name': 'CAUEEG-Abnormal benchmark'}

 ---------------------------------------------------------------------------------------------------- 

{'age': 77,
 'class_label': 1,
 'class_name': 'Abnormal',
 'serial': '01258',
 'signal': array([[ 3., -1., -5., ...,  0.,  0.,  0.],
       [ 7., 15.,  9., ...,  0.,  0.,  0.],
       [-6., -5., -3., ...,  0.,  0.,  0.],
       ...,
       [ 4.,  6.,  5., ...,  0.,  0.,  0.],
       [62., 54., 53., ...,  0.,  0.,  0.],
       [ 0.,  0.,  1., ...,  0.,  0.,  0.]]),
 'symptom': ['dementia', 'vd', 'sivd']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 81,
 'class_label': 1,
 'class_name': 'Abnormal',
 'serial': '00152',
 'signal': array([[ 14.,  10.,   2., ...,   0.,   0.,   0.],
       [  8.,   6.,   0.

### Load the CAUEEG task2 datasets as the PyTorch dataset instances.

In [8]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='edf', 
                                                                                  transform=None)
pprint.pprint(config_data)
print('\n', '-' * 100, '\n')

pprint.pprint(train_dataset[0])
print('\n', '-' * 100, '\n')

pprint.pprint(val_dataset[0])
print('\n', '-' * 100, '\n')

pprint.pprint(test_dataset[0])

{'class_label_to_name': ['Normal', 'MCI', 'Dementia'],
 'class_name_to_label': {'Dementia': 2, 'MCI': 1, 'Normal': 0},
 'task_description': 'Classification of [Normal], [MCI], and [Dementia] '
                     'symptoms.',
 'task_name': 'CAUEEG-Dementia benchmark'}

 ---------------------------------------------------------------------------------------------------- 

{'age': 53,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00587',
 'signal': array([[30., 15., 18., ...,  0.,  0.,  0.],
       [-3.,  4.,  5., ...,  0.,  0.,  0.],
       [-2.,  7.,  8., ...,  0.,  0.,  0.],
       ...,
       [ 0.,  5.,  7., ...,  0.,  0.,  0.],
       [27., 27., 34., ...,  0.,  0.,  0.],
       [ 0.,  0., -1., ...,  0.,  0.,  0.]]),
 'symptom': ['normal', 'cb_normal']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 80,
 'class_label': 2,
 'class_name': 'Dementia',
 'serial': '00341',
 'signal': array([[ -9.,  -2.,  -3., ...

In [9]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False, 
                                                                                  file_format='memmap', 
                                                                                  transform=None)

num_train = [0, 0]
for d in train_dataset:
    num_train[d['class_label']] += 1
print('train', num_train, sum(num_train))

num_val = [0, 0]
for d in val_dataset:
    num_val[d['class_label']] += 1
print('val', num_val, sum(num_val))
        
num_test = [0, 0]
for d in test_dataset:
    num_test[d['class_label']] += 1
print('test', num_test, sum(num_test))
           
print()
print('total', [num1 + num2 + num3 for num1, num2, num3 in zip(num_train, num_val, num_test)], sum(num_train + num_val + num_test))

train [367, 740] 1107
val [46, 90] 136
test [46, 90] 136

total [459, 920] 1379


In [10]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap', 
                                                                                  transform=None)

num_train = [0, 0, 0]
for d in train_dataset:
    num_train[d['class_label']] += 1
print('train', num_train, sum(num_train))

num_val = [0, 0, 0]
for d in val_dataset:
    num_val[d['class_label']] += 1
print('val', num_val, sum(num_val))
        
num_test = [0, 0, 0]
for d in test_dataset:
    num_test[d['class_label']] += 1
print('test', num_test, sum(num_test))
           
print()
print('total', [num1 + num2 + num3 for num1, num2, num3 in zip(num_train, num_val, num_test)], sum(num_train + num_val + num_test))

train [367, 334, 249] 950
val [46, 42, 31] 119
test [46, 41, 31] 118

total [459, 417, 311] 1187


### Event information

In [11]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=True, 
                                                                                  file_format='edf', 
                                                                                  transform=None)
pprint.pprint(train_dataset[0])

{'age': 77,
 'class_label': 1,
 'class_name': 'Abnormal',
 'event': [[0, 'Start Recording'],
           [0, 'New Montage - Montage 002'],
           [1358, 'Eyes Open'],
           [6146, 'Eyes Closed'],
           [11984, 'Eyes Open'],
           [22232, 'Eyes Closed'],
           [27650, 'Eyes Open'],
           [30170, 'Eyes Closed'],
           [36218, 'Eyes Open'],
           [42098, 'Eyes Closed'],
           [48692, 'Eyes Open'],
           [54236, 'Eyes Closed'],
           [65324, 'Eyes Open'],
           [66668, 'Eyes Closed'],
           [72548, 'Eyes Open'],
           [77408, 'Move'],
           [78848, 'Eyes Closed'],
           [84056, 'Eyes Open'],
           [90272, 'Eyes Closed'],
           [96782, 'Eyes Open'],
           [102326, 'Eyes Closed'],
           [108583, 'Eyes Open'],
           [114170, 'Eyes Closed'],
           [121438, 'Photic On - 3.0 Hz'],
           [123454, 'Photic Off'],
           [125511, 'Photic On - 6.0 Hz'],
           [127528, 'Photic Off'

### Data Format: `EDF`

In [12]:
%%time
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='edf', 
                                                                                  transform=None)

print(train_dataset[0])
print(train_dataset[1])

{'serial': '00587', 'age': 53, 'symptom': ['normal', 'cb_normal'], 'class_name': 'Normal', 'class_label': 0, 'signal': array([[30., 15., 18., ...,  0.,  0.,  0.],
       [-3.,  4.,  5., ...,  0.,  0.,  0.],
       [-2.,  7.,  8., ...,  0.,  0.,  0.],
       ...,
       [ 0.,  5.,  7., ...,  0.,  0.,  0.],
       [27., 27., 34., ...,  0.,  0.,  0.],
       [ 0.,  0., -1., ...,  0.,  0.,  0.]])}
{'serial': '01301', 'age': 88, 'symptom': ['dementia', 'ad', 'load'], 'class_name': 'Dementia', 'class_label': 2, 'signal': array([[ 18.,  15.,  13., ...,   0.,   0.,   0.],
       [ -2.,   1.,   1., ...,   0.,   0.,   0.],
       [  5.,   0.,  -1., ...,   0.,   0.,   0.],
       ...,
       [  2.,  -3.,  -3., ...,   0.,   0.,   0.],
       [ 51., 115., 103., ...,   0.,   0.,   0.],
       [  0.,  -1.,  -1., ...,   0.,   0.,   0.]])}
CPU times: total: 156 ms
Wall time: 145 ms


### Data Format: `NumPy Memmap`

In [13]:
%%time
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap', 
                                                                                  transform=None)

print(train_dataset[0])
print(train_dataset[1])

{'serial': '00587', 'age': 53, 'symptom': ['normal', 'cb_normal'], 'class_name': 'Normal', 'class_label': 0, 'signal': memmap([[ 30,  15,  18, ..., -42, -40, -38],
        [ -3,   4,   5, ..., -16, -14, -13],
        [ -2,   7,   8, ...,  -3,  -7,  -9],
        ...,
        [  0,   5,   7, ..., -10, -12, -12],
        [ 27,  27,  34, ..., -25, -27, -27],
        [  0,   0,  -1, ...,   1,   1,  -1]])}
{'serial': '01301', 'age': 88, 'symptom': ['dementia', 'ad', 'load'], 'class_name': 'Dementia', 'class_label': 2, 'signal': memmap([[ 18,  15,  13, ...,  36,  37,  39],
        [ -2,   1,   1, ...,  20,  19,  13],
        [  5,   0,  -1, ...,  -5,  -3,   1],
        ...,
        [  2,  -3,  -3, ...,   9,   9,  11],
        [ 51, 115, 103, ...,  -6, -14,  -8],
        [  0,  -1,  -1, ...,  -1,   0,  -1]])}
CPU times: total: 0 ns
Wall time: 4 ms


---

## PyTorch Transforms

### Random crop

In [14]:
transform = EegRandomCrop(crop_length=100)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path,
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)
for i in range(2):
    d = train_dataset[0]
    pprint.pprint(d)
    print()
    print('>>> signal shape:', d['signal'].shape)
    print('\n', '-' * 100, '\n')

{'age': 53,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00587',
 'signal': memmap([[-111, -111, -112, ...,  -80,  -83,  -90],
        [   1,   -2,   -5, ...,    1,   -2,   -5],
        [   4,    4,    4, ...,    5,    5,    5],
        ...,
        [   0,    1,    5, ...,    3,    5,    7],
        [ -28,  -28,  -28, ...,   95,   97,   89],
        [   0,    0,    1, ...,   -1,    0,    0]]),
 'symptom': ['normal', 'cb_normal']}

>>> signal shape: (21, 100)

 ---------------------------------------------------------------------------------------------------- 

{'age': 53,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00587',
 'signal': memmap([[ 40,  46,  51, ...,  17,  18,  22],
        [  3,  10,  14, ..., -11, -10,  -9],
        [  2,   4,   8, ...,  -4,  -3,  -2],
        ...,
        [  1,  -2,  -1, ...,   8,   5,   4],
        [ 72,  66,  62, ...,  -6,  -7, -10],
        [ -1,  -1,  -1, ...,   0,   0,   0]]),
 'symptom': ['normal', 'cb_normal']}

>>> signal 

### Random crop with multiple cropping

In [15]:
transform = EegRandomCrop(crop_length=200, multiple=2)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)
for i in range(2):
    d = train_dataset[0]
    pprint.pprint(d)
    print()
    print('>>> signal shape:', [signal.shape for signal in d['signal']])
    print('\n', '-' * 100, '\n')

{'age': 53,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00587',
 'signal': [memmap([[-102, -100, -100, ...,  -26,  -26,  -23],
        [   0,    3,    5, ...,   13,   16,   19],
        [   4,    3,    3, ...,    0,    3,    6],
        ...,
        [   5,    3,    1, ...,  -10,  -12,  -13],
        [  11,    8,    5, ...,   33,   29,   26],
        [  -1,   -1,    2, ...,    1,    0,    0]]),
            memmap([[  5,   2,  -6, ..., -84, -84, -85],
        [ -6, -10, -14, ...,  -7, -10, -10],
        [ -8,  -7,  -4, ...,   4,   0,  -2],
        ...,
        [ -4,  -2,   1, ...,  -2,  -4,  -6],
        [ 84,  82,  75, ...,   4,  17,  31],
        [  0,   0,   1, ...,   3,   3,  -1]])],
 'symptom': ['normal', 'cb_normal']}

>>> signal shape: [(21, 200), (21, 200)]

 ---------------------------------------------------------------------------------------------------- 

{'age': 53,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00587',
 'signal': [memmap([[-38, -40, -4

### Random crop with multiple cropping and latency

In [16]:
transform = EegRandomCrop(crop_length=300, multiple=3, latency=50000, return_timing=True)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)
for i in range(2): 
    d = train_dataset[0]
    pprint.pprint(d)
    print()
    print('>>> signal shape:', [signal.shape for signal in d['signal']])
    print('\n', '-' * 100, '\n')

{'age': 53,
 'class_label': 0,
 'class_name': 'Normal',
 'crop_timing': [81946, 65630, 86176],
 'serial': '00587',
 'signal': [memmap([[-30, -33, -29, ..., -15, -17, -17],
        [ -4,  -5,  -5, ...,   7,   3,  -1],
        [  0,   1,   0, ...,   4,   4,   2],
        ...,
        [  3,   4,   3, ...,   2,   1,  -2],
        [ 63,  61,  59, ...,  -8,  -9, -10],
        [  2,   0,  -1, ...,   1,  -1,   0]]),
            memmap([[-18, -19, -20, ...,  99,  98,  96],
        [ -3,  -3,  -3, ...,  20,  19,  16],
        [ -7,  -5,  -4, ...,  -5,  -5,  -6],
        ...,
        [ -6,  -4,  -4, ..., -13, -11, -11],
        [-37, -38, -36, ...,  14,   8,   4],
        [ -1,   3,   3, ...,  -1,   3,   3]]),
            memmap([[230, 212, 201, ..., 112, 111, 105],
        [ 44,  41,  37, ...,   7,  10,  11],
        [  0,   0,   0, ...,  -6,  -7,  -5],
        ...,
        [ -7,  -6,  -6, ...,  -6,  -6,  -5],
        [ 12,   9,   6, ..., -22, -24, -25],
        [ -1,  -1,   0, ...,  -1,  -1,   

### Random crop with multiple cropping, latency, and max length limit

In [17]:
transform = transforms.Compose([
    EegRandomCrop(crop_length=200, 
                  length_limit=50300,
                  multiple=3, 
                  latency=50000, 
                  return_timing=True)
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)
for i in range(2):
    d = train_dataset[0]
    pprint.pprint(d)
    print()
    print('>>> signal shape:', [signal.shape for signal in d['signal']])
    print('\n', '-' * 100, '\n')

{'age': 77,
 'class_label': 1,
 'class_name': 'Abnormal',
 'crop_timing': [50049, 50078, 50083],
 'serial': '01258',
 'signal': [memmap([[103,  96,  89, ..., -13,  -1,   1],
        [ 58,  68,  75, ...,   4,  16,   9],
        [-25, -30, -29, ..., -42, -43, -40],
        ...,
        [-13, -16, -13, ...,  -3,  -3,  -4],
        [ 25,  15,   4, ..., -22, -27, -37],
        [  0,   0,   0, ...,  -1,  -1,  -1]]),
            memmap([[ 11,  13,   7, ...,  -7,   2,   3],
        [ 42,  19,  16, ...,  16,  24,  21],
        [-27, -23, -28, ..., -43, -44, -40],
        ...,
        [  0,   2,   1, ...,   7,   4,   3],
        [-37, -36, -39, ..., 107, 106, 100],
        [  0,   0,  -1, ...,   0,   0,  -1]]),
            memmap([[ -7, -13, -15, ...,  -1,  -7,  -4],
        [ 27,  13,  22, ...,  14,  -4,  -7],
        [-30, -31, -33, ..., -29, -29, -30],
        ...,
        [  2,   2,   1, ...,   9,  10,  10],
        [-68, -75, -83, ...,  89,  68,  72],
        [  3,  -1,   0, ...,   1,   2, 

### Drop channel(s)

In [18]:
anno_path = os.path.join(data_path, 'annotation.json')
with open(anno_path, 'r') as json_file:
    annotation = json.load(json_file)
signal_headers = annotation['signal_header']
del annotation
print(signal_headers)

channel_ekg = signal_headers.index('EKG')
print('channel_ekg: ', channel_ekg)

channel_photic = signal_headers.index('Photic')
print('channel_photic: ', channel_photic)

['Fp1-AVG', 'F3-AVG', 'C3-AVG', 'P3-AVG', 'O1-AVG', 'Fp2-AVG', 'F4-AVG', 'C4-AVG', 'P4-AVG', 'O2-AVG', 'F7-AVG', 'T3-AVG', 'T5-AVG', 'F8-AVG', 'T4-AVG', 'T6-AVG', 'FZ-AVG', 'CZ-AVG', 'PZ-AVG', 'EKG', 'Photic']
channel_ekg:  19
channel_photic:  20


In [19]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False, 
                                                                                  file_format='memmap', 
                                                                                  transform=None)
print('before:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

print()
print('-' * 100)
print()
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False, 
                                                                                  file_format='memmap', 
                                                                                  transform=EegDropChannels(channel_ekg))
print('after:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

before: (21, 183200)
[[   3   -1   -5 ...   -5   -2   -9]
 [   7   15    9 ...    8  -10  -15]
 [  -6   -5   -3 ...   -2    0    3]
 ...
 [   4    6    5 ...    2    6    8]
 [  62   54   53 ...  -22 -137 -282]
 [   0    0    1 ...    0   -1    0]]

----------------------------------------------------------------------------------------------------

after: (20, 183200)
[[  3  -1  -5 ...  -5  -2  -9]
 [  7  15   9 ...   8 -10 -15]
 [ -6  -5  -3 ...  -2   0   3]
 ...
 [  0   6   3 ...  10   7   5]
 [  4   6   5 ...   2   6   8]
 [  0   0   1 ...   0  -1   0]]


In [20]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=None)
print('before:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

print()
print('-' * 100)
print()
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=EegDropChannels(channel_photic))
print('after:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

before: (21, 183200)
[[   3   -1   -5 ...   -5   -2   -9]
 [   7   15    9 ...    8  -10  -15]
 [  -6   -5   -3 ...   -2    0    3]
 ...
 [   4    6    5 ...    2    6    8]
 [  62   54   53 ...  -22 -137 -282]
 [   0    0    1 ...    0   -1    0]]

----------------------------------------------------------------------------------------------------

after: (20, 183200)
[[   3   -1   -5 ...   -5   -2   -9]
 [   7   15    9 ...    8  -10  -15]
 [  -6   -5   -3 ...   -2    0    3]
 ...
 [   0    6    3 ...   10    7    5]
 [   4    6    5 ...    2    6    8]
 [  62   54   53 ...  -22 -137 -282]]


In [21]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=None)
print('before:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

print()
print('-' * 100)
print()
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=EegDropChannels([channel_ekg, channel_photic]))
print('after:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

before: (21, 183200)
[[   3   -1   -5 ...   -5   -2   -9]
 [   7   15    9 ...    8  -10  -15]
 [  -6   -5   -3 ...   -2    0    3]
 ...
 [   4    6    5 ...    2    6    8]
 [  62   54   53 ...  -22 -137 -282]
 [   0    0    1 ...    0   -1    0]]

----------------------------------------------------------------------------------------------------

after: (19, 183200)
[[  3  -1  -5 ...  -5  -2  -9]
 [  7  15   9 ...   8 -10 -15]
 [ -6  -5  -3 ...  -2   0   3]
 ...
 [-30 -27 -27 ...  22  18  14]
 [  0   6   3 ...  10   7   5]
 [  4   6   5 ...   2   6   8]]


### To Tensor

In [22]:
config_data, full_eeg_dataset = load_caueeg_full_dataset(dataset_path=data_path, 
                                                         load_event=False, 
                                                         file_format='memmap',
                                                         transform=None)
print('Before:')
pprint.pprint(full_eeg_dataset[0])

print()
print('-' * 100)
print()

config_data, full_eeg_dataset = load_caueeg_full_dataset(dataset_path=data_path, 
                                                         load_event=False, 
                                                         file_format='memmap',
                                                         transform=EegToTensor())
print('After:')
pprint.pprint(full_eeg_dataset[0])

Before:
{'age': 78,
 'serial': '00001',
 'signal': memmap([[  0, -11, -13, ...,  18,  21,  22],
        [ 29,  33,  34, ...,  -7,  -4,  -4],
        [ -3,  -6,  -3, ...,  -1,   1,   2],
        ...,
        [ -4,  -2,   1, ...,   0,  -1,  -1],
        [112,  67,  76, ..., -13, -15, -11],
        [ -1,  -1,  -1, ...,  -1,  -1,  -1]]),
 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}

----------------------------------------------------------------------------------------------------

After:
{'age': tensor(78.),
 'serial': '00001',
 'signal': tensor([[  0., -11., -13.,  ...,  18.,  21.,  22.],
        [ 29.,  33.,  34.,  ...,  -7.,  -4.,  -4.],
        [ -3.,  -6.,  -3.,  ...,  -1.,   1.,   2.],
        ...,
        [ -4.,  -2.,   1.,  ...,   0.,  -1.,  -1.],
        [112.,  67.,  76.,  ..., -13., -15., -11.],
        [ -1.,  -1.,  -1.,  ...,  -1.,  -1.,  -1.]]),
 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}


### Compose the above all in one

In [23]:
transform = transforms.Compose([
    EegRandomCrop(crop_length=200*10,       # crop: 10s
                  length_limit=200*60*10,   # length: 10m
                  multiple=4, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)

pprint.pprint(train_dataset[0])

{'age': tensor(77.),
 'class_label': tensor(1),
 'class_name': 'Abnormal',
 'serial': '01258',
 'signal': [tensor([[ 60.,  70.,  65.,  ...,  -8.,  -7.,   0.],
        [ 19.,  35.,  38.,  ...,  25.,  12.,   1.],
        [ 17.,  16.,  17.,  ..., -36., -34., -34.],
        ...,
        [ 18.,  17.,  16.,  ..., -10., -10., -12.],
        [  0.,   0.,   1.,  ...,   0.,   2.,   3.],
        [  3.,   5.,  11.,  ..., -51., -28., -25.]]),
            tensor([[-307., -306., -306.,  ...,  -31.,  -26.,  -33.],
        [ -16.,  -18.,  -26.,  ...,  -13.,   -2.,    1.],
        [  -2.,   -4.,   -6.,  ...,   11.,   13.,   17.],
        ...,
        [  -2.,   -1.,    0.,  ...,    4.,    5.,    2.],
        [  -2.,   -3.,   -4.,  ...,   12.,   13.,   14.],
        [ -15.,  -17.,  -16.,  ...,   23.,   44.,   51.]]),
            tensor([[ 18.,  21.,  24.,  ...,  12.,  13.,  14.],
        [ 12.,   9.,   4.,  ...,   7.,  11.,  10.],
        [-39., -37., -34.,  ...,  12.,  10.,  10.],
        ...,
        [ 

---

## PyTorch DataLoader

In [24]:
if device.type == 'cuda':
    num_workers = 0  # A number other than 0 causes an error
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False

In [25]:
transform = transforms.Compose([
    EegRandomCrop(crop_length=200*10,       # crop: 10s
                  length_limit=200*60*10,   # length: 10m
                  multiple=2, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])

config_data, full_eeg_dataset = load_caueeg_full_dataset(dataset_path=data_path, 
                                                         load_event=False, 
                                                         file_format='memmap',
                                                         transform=transform)

full_loader = DataLoader(full_eeg_dataset,
                         batch_size=4,
                         shuffle=True,
                         drop_last=True,
                         num_workers=num_workers,
                         pin_memory=pin_memory,
                         collate_fn=eeg_collate_fn)

for i_batch, sample_batched in enumerate(full_loader):
    pprint.pprint(sample_batched)
    break

{'age': tensor([77., 77., 63., 63., 82., 82., 77., 77.]),
 'serial': ['00691',
            '00691',
            '00570',
            '00570',
            '00460',
            '00460',
            '00718',
            '00718'],
 'signal': tensor([[[ 251.,  247.,  231.,  ...,    5.,   23.,   18.],
         [  35.,   34.,   43.,  ...,  -28.,  -29.,  -28.],
         [   0.,    2.,    5.,  ...,   -4.,   -8.,  -12.],
         ...,
         [  10.,   12.,   15.,  ...,   -2.,   -1.,    0.],
         [ -10.,   -9.,  -10.,  ...,   -3.,   -2.,   -1.],
         [ -11.,  -21.,  -32.,  ...,  129.,  139.,  146.]],

        [[  55.,   54.,   49.,  ...,  547.,  535.,  549.],
         [  26.,   23.,   20.,  ...,   86.,   95.,   92.],
         [  16.,   17.,   16.,  ...,  -40.,  -39.,  -43.],
         ...,
         [  15.,   14.,   12.,  ...,  -54.,  -52.,  -51.],
         [   9.,    9.,    9.,  ...,  -40.,  -40.,  -41.],
         [  86.,   90.,   93.,  ...,  237.,  239.,  234.]],

        [[  37.,   38.

In [26]:
transform = transforms.Compose([
    EegRandomCrop(crop_length=200*10,       # crop: 10s
                  length_limit=200*60*10,   # length: 10m
                  multiple=2, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='abnormal',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=8,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for i_batch, sample_batched in enumerate(train_loader):
    pprint.pprint(sample_batched, width=250)
    break

{'age': tensor([66., 66., 83., 83., 60., 60., 54., 54., 55., 55., 81., 81., 72., 72.,
        80., 80.]),
 'class_label': tensor([0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1]),
 'class_name': ['Normal', 'Normal', 'Abnormal', 'Abnormal', 'Normal', 'Normal', 'Normal', 'Normal', 'Normal', 'Normal', 'Normal', 'Normal', 'Abnormal', 'Abnormal', 'Abnormal', 'Abnormal'],
 'serial': ['00085', '00085', '01050', '01050', '01006', '01006', '01135', '01135', '00324', '00324', '00721', '00721', '00942', '00942', '00440', '00440'],
 'signal': tensor([[[   9.,   10.,   13.,  ...,   11.,   13.,   13.],
         [  32.,   32.,   34.,  ...,   48.,   49.,   50.],
         [  11.,   12.,   15.,  ...,    3.,    4.,    5.],
         ...,
         [   3.,    3.,    2.,  ...,   19.,   16.,   15.],
         [ -13.,  -12.,  -13.,  ...,   -9.,  -10.,   -9.],
         [  -3.,   -1.,    9.,  ...,   -5.,    6.,    9.]],

        [[ -15.,  -16.,  -19.,  ...,  -25.,  -26.,  -28.],
         [  94.,   90.,   91.,  ..

---

## Preprocessing steps run by the PyTorch Modules

In [27]:
transform = transforms.Compose([
    EegRandomCrop(crop_length=200*10,       # crop: 10s
                  length_limit=200*60*10,   # length: 10m
                  multiple=2, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=2,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

### To GPU device if it is possible

In [28]:
print('device:', device)
print()

preprocess_train = transforms.Compose([EegToDevice(device=device)])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    pprint.pprint(sample_batched)

    print()
    print('-' * 100)
    print()
    
    preprocess_train(sample_batched)
    
    print('- After -')
    pprint.pprint(sample_batched)
    break

device: cuda

Sequential(
  (0): EegToDevice(device=cuda)
)
- Before -
{'age': tensor([57., 57., 70., 70.]),
 'class_label': tensor([0, 0, 0, 0]),
 'class_name': ['Normal', 'Normal', 'Normal', 'Normal'],
 'serial': ['00497', '00497', '00103', '00103'],
 'signal': tensor([[[   3.,    1.,   -4.,  ...,  -10.,   -7.,   -6.],
         [  -3.,   -2.,   -1.,  ...,   -3.,   -1.,    0.],
         [   1.,    4.,    3.,  ...,   -3.,   -3.,   -4.],
         ...,
         [  -5.,   -4.,   -7.,  ...,   -6.,   -5.,   -5.],
         [  -5.,   -4.,   -5.,  ...,    3.,    5.,    3.],
         [  44.,   81.,   98.,  ...,  -64.,  -28.,  -12.]],

        [[ -44.,  -46.,  -48.,  ...,    1.,   -3.,   -3.],
         [ -17.,  -17.,  -18.,  ...,   -2.,   -5.,   -5.],
         [  -5.,   -5.,   -5.,  ...,    7.,    9.,    8.],
         ...,
         [  -5.,   -3.,   -3.,  ...,   -5.,   -3.,   -3.],
         [   2.,    4.,    4.,  ...,   -9.,   -5.,   -4.],
         [  63.,   77.,  102.,  ...,  -63.,  -38.,  -37.]

### Normalization per signal

In [29]:
preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizePerSignal()
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    print('Mean:', torch.mean(sample_batched['signal'], axis=-1))
    print()
    print('Std:', torch.std(sample_batched['signal'], axis=-1))

    print()
    print('-' * 100)
    print()
    
    preprocess_train(sample_batched)
    
    print('- After -')
    print('Mean:', torch.mean(sample_batched['signal'], axis=-1))
    print()
    print('Std:', torch.std(sample_batched['signal'], axis=-1))
    break

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizePerSignal(eps=1e-08)
)
- Before -
Mean: tensor([[ 4.0060,  0.6150,  0.7960,  1.3665,  0.3060,  0.4480, -1.6430, -0.5550,
         -0.2125,  0.1675, -0.1275,  1.3010,  0.3540, -0.5905, -3.8095, -0.4695,
          0.9610,  0.1305,  0.0095, -0.2920],
        [ 0.7465,  0.3265, -0.4710, -0.3020, -0.5850, -1.3090, -0.9440, -0.7555,
         -0.9220, -0.5595,  2.9785,  0.3035, -0.6420, -1.8255,  0.1035, -0.7305,
         -0.2445,  0.7690,  0.8820,  0.3430],
        [-4.2775, -1.3325,  0.6585, -0.0600, -1.0890, -1.2195,  1.1875, -1.2030,
         -0.1740,  0.1605,  0.0220,  0.6605,  0.0855,  1.6840,  1.2235,  0.9715,
         -3.1590, -0.1015,  0.4140, -0.0960],
        [ 4.4520,  1.2520, -0.2580,  0.3045,  0.6555, -2.3015,  0.0335, -0.6465,
          0.2775, -0.3430, -0.8940,  1.4865,  0.0755, -2.5560, -0.8160, -0.0795,
          1.1965,  0.6850, -0.3305,  0.5715]])

Std: tensor([[ 10.0899,   7.9725,   5.3464,   6.3704,   9.3541,

### Signal normalization using the specified mean and std values

In [30]:
signal_mean, signal_std = calculate_signal_statistics(train_loader, repeats=1, verbose=True)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    print('Mean:', torch.mean(sample_batched['signal'], axis=-1))
    print()
    print('Std:', torch.std(sample_batched['signal'], axis=-1))

    print()
    print('-' * 100)
    print()
    
    preprocess_train(sample_batched)
    
    print('- After -')
    print('Mean:', torch.mean(sample_batched['signal'], axis=-1))
    print()
    print('Std:', torch.std(sample_batched['signal'], axis=-1))
    break

Mean and standard deviation for signal:
tensor([[[-0.0359],
         [ 0.0304],
         [-0.1009],
         [ 0.0230],
         [-0.0015],
         [-0.0228],
         [ 0.1051],
         [-0.1024],
         [-0.0823],
         [ 0.1393],
         [ 0.0105],
         [ 0.1269],
         [ 0.2591],
         [ 0.0407],
         [-0.0728],
         [-0.0767],
         [-0.2653],
         [ 0.0130],
         [-0.0072],
         [ 0.0019]]])
-
tensor([[[44.5781],
         [19.7245],
         [11.6635],
         [12.1446],
         [14.6613],
         [47.3535],
         [19.4451],
         [10.3273],
         [11.5484],
         [15.2525],
         [19.9712],
         [13.9477],
         [13.4642],
         [21.2484],
         [17.3686],
         [14.2186],
         [19.4747],
         [10.9813],
         [11.1437],
         [94.1633]]])

----------------------------------------------------------------------------------------------------

Sequential(
  (0): EegToDevice(device=cuda)
  (1): 

### Age normalization

In [31]:
age_mean, age_std = calculate_age_statistics(train_loader, verbose=True)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    pprint.pprint(sample_batched['age'])

    print()
    print('-' * 100)
    print()
    
    preprocess_train(sample_batched)
    
    print('- After -')
    pprint.pprint(sample_batched['age'])
    break

Age mean and standard deviation:
tensor([71.2453]) tensor([6.2074])

----------------------------------------------------------------------------------------------------

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([71.2453]),std=tensor([6.2074]),eps=1e-08)
)
- Before -
tensor([63., 63., 76., 76.])

----------------------------------------------------------------------------------------------------

- After -
tensor([-1.3283, -1.3283,  0.7660,  0.7660], device='cuda:0')


### Short time Fourier transform (STFT or spectrogram)

In [32]:
preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegSpectrogram(n_fft=200, complex_mode='as_real')
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    pprint.pprint(sample_batched['signal'].shape)

    print()
    print('-' * 100)
    print()
    
    preprocess_train(sample_batched)
    
    print('- After -')
    pprint.pprint(sample_batched['signal'].shape)
    break

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegSpectrogram(n_fft=200, complex_mode=as_real, stft_kwargs={})
)
- Before -
torch.Size([4, 20, 2000])

----------------------------------------------------------------------------------------------------

- After -
torch.Size([4, 40, 101, 41])


### Signal normalization after STFT

In [33]:
preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegSpectrogram(n_fft=200, complex_mode='as_real')
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)

signal_2d_mean, signal_2d_std = calculate_signal_statistics(train_loader, preprocess_train)

preprocess_train2 = transforms.Compose([
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std)
])
preprocess_train2 = torch.nn.Sequential(*preprocess_train2.transforms).to(device)

pprint.pprint(preprocess_train)
pprint.pprint(preprocess_train2)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    preprocess_train(sample_batched)   
    
    print('Mean:', torch.mean(sample_batched['signal'], axis=-1))
    print()
    print('Std:', torch.std(sample_batched['signal'], axis=-1))
    
    print()
    print('-' * 100)
    print()
    
    print('- After -')
    preprocess_train2(sample_batched)
    
    print('Mean:', torch.mean(sample_batched['signal'], axis=-1))
    print()
    print('Std:', torch.std(sample_batched['signal'], axis=-1))
    break

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegSpectrogram(n_fft=200, complex_mode=as_real, stft_kwargs={})
)
Sequential(
  (0): EegNormalizeMeanStd(mean=tensor([[ 1.9803e+01, -1.1851e-01, -6.5353e-02,  ...,  5.7606e-02,
            6.0533e-02,  5.8555e-02],
          [ 5.1859e+00, -9.6052e-02, -1.2185e-01,  ..., -2.4899e-03,
           -3.4110e-03,  1.1990e-02],
          [-1.8010e+00, -5.8296e-02, -3.1889e-01,  ...,  3.3537e-03,
            2.3817e-03, -1.5474e-02],
          ...,
          [ 0.0000e+00,  3.1651e-01,  2.3858e-01,  ...,  1.4515e-03,
            2.5028e-03, -4.6628e-09],
          [ 0.0000e+00,  2.7054e-01,  2.6525e-01,  ..., -3.3427e-04,
           -9.2521e-04, -1.5035e-08],
          [ 0.0000e+00, -1.1111e+00, -2.6718e-01,  ..., -1.8196e-03,
           -1.3766e-02,  1.4751e-07]], device='cuda:0'),std=tensor([[7.5293e+03, 1.5227e+03, 7.7860e+02,  ..., 2.9198e+01, 2.9174e+01,
           2.9401e+01],
          [3.3492e+03, 5.9845e+02, 3.0433e+02,  ..., 1.3521e+01

---

## Speed check without STFT

In [34]:
crop_length = 200 * 10
multiple = 4
batch_size = 128

### `EDF`

In [35]:
%%time
transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='edf',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegRandomCrop(crop_length=2000, length_limit=120000, multiple=4, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([71.2453]),std=tensor([6.2074]),eps=1e-08)
  (2): EegNormalizeMeanStd(mean=tensor([-0.0359,  0.0304, -0.1009,  0.0230, -0.0015, -0.0228,  0.1051, -0.1024,
          -0.0823,  0.1393,  0.0105,  0.1269,  0.2591,  0.0407, -0.0728, -0.0767,
          -0.2653,  0.0130, -0.0072,  0.0019]),std=tensor([44.5781, 19.7245, 11.6635, 12.1446, 14.6613, 47.3535, 19.4451, 10.3273,
          11.5484, 15.2525, 19.9712, 13.9477, 13.4642, 21.2484, 17.3686, 14.2186,
          19.4747, 10.9813, 11.1437, 94.1633]),eps=1e-08)
)
CPU times: total: 24min 2s
Wall time: 2min 1s


### `memmap`

In [36]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)
 
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegRandomCrop(crop_length=2000, length_limit=120000, multiple=4, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([71.2453]),std=tensor([6.2074]),eps=1e-08)
  (2): EegNormalizeMeanStd(mean=tensor([-0.0359,  0.0304, -0.1009,  0.0230, -0.0015, -0.0228,  0.1051, -0.1024,
          -0.0823,  0.1393,  0.0105,  0.1269,  0.2591,  0.0407, -0.0728, -0.0767,
          -0.2653,  0.0130, -0.0072,  0.0019]),std=tensor([44.5781, 19.7245, 11.6635, 12.1446, 14.6613, 47.3535, 19.4451, 10.3273,
          11.5484, 15.2525, 19.9712, 13.9477, 13.4642, 21.2484, 17.3686, 14.2186,
          19.4747, 10.9813, 11.1437, 94.1633]),eps=1e-08)
)
CPU times: total: 15.4 s
Wall time: 1.27 s


### `memmap` (Drop â†’ Crop)

In [37]:
%%time

transform = transforms.Compose([
    EegDropChannels(channel_photic), 
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)
 
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

Compose(
    EegDropChannels(drop_index=20)
    EegRandomCrop(crop_length=2000, length_limit=120000, multiple=4, latency=2000, return_timing=False)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([71.2453]),std=tensor([6.2074]),eps=1e-08)
  (2): EegNormalizeMeanStd(mean=tensor([-0.0359,  0.0304, -0.1009,  0.0230, -0.0015, -0.0228,  0.1051, -0.1024,
          -0.0823,  0.1393,  0.0105,  0.1269,  0.2591,  0.0407, -0.0728, -0.0767,
          -0.2653,  0.0130, -0.0072,  0.0019]),std=tensor([44.5781, 19.7245, 11.6635, 12.1446, 14.6613, 47.3535, 19.4451, 10.3273,
          11.5484, 15.2525, 19.9712, 13.9477, 13.4642, 21.2484, 17.3686, 14.2186,
          19.4747, 10.9813, 11.1437, 94.1633]),eps=1e-08)
)
CPU times: total: 1min 14s
Wall time: 6.22 s


---

## Speed check with STFT

In [38]:
crop_length = 300 * 10
n_fft, hop_length, seq_len_2d = calculate_stft_params(seq_length=crop_length, verbose=True)
multiple = 2
batch_size = 128

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegSpectrogram(n_fft=n_fft, hop_length=hop_length, complex_mode='as_real')
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
signal_2d_mean, signal_2d_std = calculate_signal_statistics(train_loader, preprocess_train)

Input sequence length: (3000) would become (78, 77) after the STFT with n_fft (155) and hop_length (39).

----------------------------------------------------------------------------------------------------



### `EDF`

In [39]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='edf',
                                                                                  transform=transform)
 
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegSpectrogram(n_fft=n_fft, hop_length=hop_length, complex_mode='as_real'),
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)
    size = sample_batched['signal'].size()
    
print(size)

Compose(
    EegRandomCrop(crop_length=3000, length_limit=120000, multiple=2, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([71.2453]),std=tensor([6.2074]),eps=1e-08)
  (2): EegSpectrogram(n_fft=155, complex_mode=as_real, stft_kwargs={'hop_length': 39})
  (3): EegNormalizeMeanStd(mean=tensor([[-1.2195e+01, -5.3448e-02, -1.5063e-01,  ..., -4.7347e-02,
           -4.8993e-02, -4.8371e-02],
          [-5.4980e+00,  4.6251e-02, -1.8641e-02,  ..., -1.0346e-03,
            1.4972e-02,  1.0688e-02],
          [ 5.0655e+00,  5.3347e-03, -6.6118e-02,  ...,  1.2595e-02,
            1.2875e-02,  1.2540e-02],
          ...,
          [ 0.0000e+00, -2.5794e-02,  1.0741e-02,  ..., -2.9688e-03,
            1.2981e-03,  4.6372e-04],
          [ 0.0000e+00, -5.5231e-01, -2.4670e-01,  ..., -2.3832e-03,
            1.2946e-04,  1.5863e-04],
          [ 0.0000e+00,  1.1617e+00, -2.201

### `memmap`

In [40]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegSpectrogram(n_fft=n_fft, hop_length=hop_length, complex_mode='as_real'),
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)
    size = sample_batched['signal'].size()
    
print(size)

Compose(
    EegRandomCrop(crop_length=3000, length_limit=120000, multiple=2, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([71.2453]),std=tensor([6.2074]),eps=1e-08)
  (2): EegSpectrogram(n_fft=155, complex_mode=as_real, stft_kwargs={'hop_length': 39})
  (3): EegNormalizeMeanStd(mean=tensor([[-1.2195e+01, -5.3448e-02, -1.5063e-01,  ..., -4.7347e-02,
           -4.8993e-02, -4.8371e-02],
          [-5.4980e+00,  4.6251e-02, -1.8641e-02,  ..., -1.0346e-03,
            1.4972e-02,  1.0688e-02],
          [ 5.0655e+00,  5.3347e-03, -6.6118e-02,  ...,  1.2595e-02,
            1.2875e-02,  1.2540e-02],
          ...,
          [ 0.0000e+00, -2.5794e-02,  1.0741e-02,  ..., -2.9688e-03,
            1.2981e-03,  4.6372e-04],
          [ 0.0000e+00, -5.5231e-01, -2.4670e-01,  ..., -2.3832e-03,
            1.2946e-04,  1.5863e-04],
          [ 0.0000e+00,  1.1617e+00, -2.201

---

## Test on longer sequence

In [41]:
%%time
longer_transform = transforms.Compose([
    EegRandomCrop(crop_length=200*10*6,     # crop: 1m
                  length_limit=200*60*10,   # length: 10m
                  multiple=2, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(longer_transform)

config_data, longer_test_dataset = load_caueeg_task_split(dataset_path=data_path, 
                                                          task='dementia', 
                                                          split='test',
                                                          load_event=False,
                                                          file_format='memmap', 
                                                          transform=longer_transform)

longer_test_loader = DataLoader(longer_test_dataset,
                                batch_size=32,
                                shuffle=True,
                                drop_last=False,
                                num_workers=num_workers,
                                pin_memory=pin_memory,
                                collate_fn=eeg_collate_fn)
 
preprocess_test = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
    EegNormalizeAge(mean=age_mean, std=age_std),
])
preprocess_test = torch.nn.Sequential(*preprocess_test.transforms).to(device)
pprint.pprint(preprocess_test)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_test(sample_batched)

Compose(
    EegRandomCrop(crop_length=12000, length_limit=120000, multiple=2, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeMeanStd(mean=tensor([-0.0359,  0.0304, -0.1009,  0.0230, -0.0015, -0.0228,  0.1051, -0.1024,
          -0.0823,  0.1393,  0.0105,  0.1269,  0.2591,  0.0407, -0.0728, -0.0767,
          -0.2653,  0.0130, -0.0072,  0.0019]),std=tensor([44.5781, 19.7245, 11.6635, 12.1446, 14.6613, 47.3535, 19.4451, 10.3273,
          11.5484, 15.2525, 19.9712, 13.9477, 13.4642, 21.2484, 17.3686, 14.2186,
          19.4747, 10.9813, 11.1437, 94.1633]),eps=1e-08)
  (2): EegNormalizeAge(mean=tensor([71.2453]),std=tensor([6.2074]),eps=1e-08)
)
CPU times: total: 1min 4s
Wall time: 12.5 s
