# Dataset and DataLoader

This notebook loads the `CAUEEG` dataset, demonstrates some useful preprocessing transforms, and builds the PyTorch DataLoader instances used for training.

-----

## Configurations

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
%cd ..

C:\Users\Minjae\Desktop\EEG_Project


In [2]:
# Load some packages
import os
import glob
import json
import pprint

import numpy as np
import random
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

# custom package
from datasets.caueeg_dataset import *
from datasets.caueeg_script import *
from datasets.pipeline import *

In [3]:
print('PyTorch version:', torch.__version__)

# Query CUDA once and derive both the device and the status message from it.
cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if cuda_ok else 'cpu')
print('cuda is available.' if cuda_ok else 'cuda is unavailable.')

PyTorch version: 1.12.1+cu113
cuda is available.


In [4]:
# Data file path (relative to the project root the first cell moved into)
data_path = r'local/dataset/caueeg-dataset/'

# Fix the random seeds so the random crops and DataLoader shuffling below are reproducible
# under Restart & Run All.
# NOTE(review): assumes the custom transforms draw from `random`, NumPy, or PyTorch RNGs — confirm.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [5]:
def print_json_preview(json_dict, num_head=3):
    """Print a brace-delimited overview of a JSON-derived dict.

    Each key is printed on its own line, followed by its value indented below it.
    A list value longer than ``num_head`` items is abbreviated to its first
    ``num_head`` items, a three-line vertical ellipsis, and its last item.

    Args:
        json_dict: dict loaded from a JSON file.
        num_head: number of leading list items to show before abbreviating.
    """
    print('{')
    for k, v in json_dict.items():
        print(f'\t{k}:')
        if isinstance(v, list) and len(v) > num_head:
            for item in v[:num_head]:
                print(f'\t\t{item}')
            for _ in range(3):
                print('\t\t.')
            print(f'\t\t{v[-1]}')
        else:
            print(f'\t\t{v}')
        print()
    print('}')


for task in ['annotation.json', 'abnormal.json', 'dementia.json']:
    task_path = os.path.join(data_path, task)
    # JSON text is UTF-8 by specification. Pass the encoding explicitly so the read does
    # not depend on the OS locale (the default encoding differs on Windows).
    with open(task_path, 'r', encoding='utf-8') as json_file:
        task_dict = json.load(json_file)
    print_json_preview(task_dict)

{
	dataset_name:
		CAUEEG dataset

	signal_header:
		Fp1-AVG
		F3-AVG
		C3-AVG
		.
		.
		.
		Photic

	data:
		{'serial': '00001', 'age': 78, 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}
		{'serial': '00002', 'age': 56, 'symptom': ['normal', 'smi']}
		{'serial': '00003', 'age': 93, 'symptom': ['mci', 'mci_vascular']}
		.
		.
		.
		{'serial': '01388', 'age': 73, 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_ef']}

}
{
	task_name:
		CAUEEG-Abnormal benchmark

	task_description:
		Classification of [Normal] and [Abnormal] symptoms

	class_label_to_name:
		['Normal', 'Abnormal']

	class_name_to_label:
		{'Normal': 0, 'Abnormal': 1}

	train_split:
		{'serial': '01258', 'age': 77, 'symptom': ['dementia', 'vd', 'sivd'], 'class_name': 'Abnormal', 'class_label': 1}
		{'serial': '00836', 'age': 80, 'symptom': ['normal', 'smi'], 'class_name': 'Normal', 'class_label': 0}
		{'serial': '00761', 'age': 75, 'symptom': ['dementia', 'ad', 'load'], 'class_name': 'Abnormal', 'class_label': 1}


-----

## Load the CAUEEG dataset

### Load the whole CAUEEG dataset as a single PyTorch dataset instance, independent of the target task (no train/val/test split and no class labels).

In [6]:
config_data, full_eeg_dataset = load_caueeg_full_dataset(
    dataset_path=data_path,
    load_event=False,
    file_format='edf',
    transform=None,
)

# The same three pieces are printed after every item to separate the outputs.
separator = ('\n', '-' * 100, '\n')

pprint.pprint(config_data, width=250)
print(*separator)

# Show the first two recordings, each followed by a separator.
for index in (0, 1):
    pprint.pprint(full_eeg_dataset[index])
    print(*separator)

{'dataset_name': 'CAUEEG dataset',
 'signal_header': ['Fp1-AVG', 'F3-AVG', 'C3-AVG', 'P3-AVG', 'O1-AVG', 'Fp2-AVG', 'F4-AVG', 'C4-AVG', 'P4-AVG', 'O2-AVG', 'F7-AVG', 'T3-AVG', 'T5-AVG', 'F8-AVG', 'T4-AVG', 'T6-AVG', 'FZ-AVG', 'CZ-AVG', 'PZ-AVG', 'EKG', 'Photic']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 78,
 'serial': '00001',
 'signal': array([[  0., -11., -13., ...,   0.,   0.,   0.],
       [ 29.,  33.,  34., ...,   0.,   0.,   0.],
       [ -3.,  -6.,  -3., ...,   0.,   0.,   0.],
       ...,
       [ -4.,  -2.,   1., ...,   0.,   0.,   0.],
       [112.,  67.,  76., ...,   0.,   0.,   0.],
       [ -1.,  -1.,  -1., ...,   0.,   0.,   0.]]),
 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 56,
 'serial': '00002',
 'signal': array([[  39.,   58.,   72., ...,    0.,    0.,    0.],
       [   4.,

### Load the CAUEEG task 1 (abnormal) datasets as PyTorch dataset instances.

In [7]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
    file_format='edf',
    transform=None,
)

# Print the task configuration, then the first sample of each split, separated by rules.
pprint.pprint(config_data)
for split in (train_dataset, val_dataset, test_dataset):
    print('\n', '-' * 100, '\n')
    pprint.pprint(split[0])

{'class_label_to_name': ['Normal', 'Abnormal'],
 'class_name_to_label': {'Abnormal': 1, 'Normal': 0},
 'task_description': 'Classification of [Normal] and [Abnormal] symptoms',
 'task_name': 'CAUEEG-Abnormal benchmark'}

 ---------------------------------------------------------------------------------------------------- 

{'age': 77,
 'class_label': 1,
 'class_name': 'Abnormal',
 'serial': '01258',
 'signal': array([[ 3., -1., -5., ...,  0.,  0.,  0.],
       [ 7., 15.,  9., ...,  0.,  0.,  0.],
       [-6., -5., -3., ...,  0.,  0.,  0.],
       ...,
       [ 4.,  6.,  5., ...,  0.,  0.,  0.],
       [62., 54., 53., ...,  0.,  0.,  0.],
       [ 0.,  0.,  1., ...,  0.,  0.,  0.]]),
 'symptom': ['dementia', 'vd', 'sivd']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 81,
 'class_label': 1,
 'class_name': 'Abnormal',
 'serial': '00152',
 'signal': array([[ 14.,  10.,   2., ...,   0.,   0.,   0.],
       [  8.,   6.,   0.

### Load the CAUEEG task 2 (dementia) datasets as PyTorch dataset instances.

In [8]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    file_format='edf',
    transform=None,
)

# Print the task configuration, then the first sample of each split, separated by rules.
pprint.pprint(config_data)
for split in (train_dataset, val_dataset, test_dataset):
    print('\n', '-' * 100, '\n')
    pprint.pprint(split[0])

{'class_label_to_name': ['Normal', 'MCI', 'Dementia'],
 'class_name_to_label': {'Dementia': 2, 'MCI': 1, 'Normal': 0},
 'task_description': 'Classification of [Normal], [MCI], and [Dementia] '
                     'symptoms.',
 'task_name': 'CAUEEG-Dementia benchmark'}

 ---------------------------------------------------------------------------------------------------- 

{'age': 53,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00587',
 'signal': array([[30., 15., 18., ...,  0.,  0.,  0.],
       [-3.,  4.,  5., ...,  0.,  0.,  0.],
       [-2.,  7.,  8., ...,  0.,  0.,  0.],
       ...,
       [ 0.,  5.,  7., ...,  0.,  0.,  0.],
       [27., 27., 34., ...,  0.,  0.,  0.],
       [ 0.,  0., -1., ...,  0.,  0.,  0.]]),
 'symptom': ['normal', 'cb_normal']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 80,
 'class_label': 2,
 'class_name': 'Dementia',
 'serial': '00341',
 'signal': array([[ -9.,  -2.,  -3., ...

In [9]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
    file_format='memmap',
    transform=None,
)

# Count the samples of each class in every split. The number of classes comes from the
# task configuration instead of being hard-coded, so this works unchanged for any task.
num_classes = len(config_data['class_label_to_name'])
class_counts = {}
for split_name, split_dataset in [('train', train_dataset), ('val', val_dataset), ('test', test_dataset)]:
    counts = [0] * num_classes
    for d in split_dataset:
        counts[d['class_label']] += 1
    class_counts[split_name] = counts
    print(split_name, counts, sum(counts))
num_train, num_val, num_test = class_counts['train'], class_counts['val'], class_counts['test']

# Per-class totals over the three splits, then the grand total.
total = [sum(per_class) for per_class in zip(*class_counts.values())]
print()
print('total', total, sum(total))

train [367, 740] 1107
val [46, 90] 136
test [46, 90] 136

total [459, 920] 1379


In [10]:
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    file_format='memmap',
    transform=None,
)

# Count the samples of each class in every split. The number of classes comes from the
# task configuration instead of being hard-coded, so this works unchanged for any task.
num_classes = len(config_data['class_label_to_name'])
class_counts = {}
for split_name, split_dataset in [('train', train_dataset), ('val', val_dataset), ('test', test_dataset)]:
    counts = [0] * num_classes
    for d in split_dataset:
        counts[d['class_label']] += 1
    class_counts[split_name] = counts
    print(split_name, counts, sum(counts))
num_train, num_val, num_test = class_counts['train'], class_counts['val'], class_counts['test']

# Per-class totals over the three splits, then the grand total.
total = [sum(per_class) for per_class in zip(*class_counts.values())]
print()
print('total', total, sum(total))

train [367, 334, 249] 950
val [46, 42, 31] 119
test [46, 41, 31] 118

total [459, 417, 311] 1187


### Event information

In [11]:
# Same abnormal-task splits, but with the event annotations loaded into each sample.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=True,
    file_format='edf',
    transform=None,
)
pprint.pprint(train_dataset[0])

{'age': 77,
 'class_label': 1,
 'class_name': 'Abnormal',
 'event': [[0, 'Start Recording'],
           [0, 'New Montage - Montage 002'],
           [1358, 'Eyes Open'],
           [6146, 'Eyes Closed'],
           [11984, 'Eyes Open'],
           [22232, 'Eyes Closed'],
           [27650, 'Eyes Open'],
           [30170, 'Eyes Closed'],
           [36218, 'Eyes Open'],
           [42098, 'Eyes Closed'],
           [48692, 'Eyes Open'],
           [54236, 'Eyes Closed'],
           [65324, 'Eyes Open'],
           [66668, 'Eyes Closed'],
           [72548, 'Eyes Open'],
           [77408, 'Move'],
           [78848, 'Eyes Closed'],
           [84056, 'Eyes Open'],
           [90272, 'Eyes Closed'],
           [96782, 'Eyes Open'],
           [102326, 'Eyes Closed'],
           [108583, 'Eyes Open'],
           [114170, 'Eyes Closed'],
           [121438, 'Photic On - 3.0 Hz'],
           [123454, 'Photic Off'],
           [125511, 'Photic On - 6.0 Hz'],
           [127528, 'Photic Off'

### Data Format: `EDF`

In [12]:
%%time
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='edf', 
                                                                                  transform=None)

print(train_dataset[0])
print(train_dataset[1])

{'serial': '00587', 'age': 53, 'symptom': ['normal', 'cb_normal'], 'class_name': 'Normal', 'class_label': 0, 'signal': array([[30., 15., 18., ...,  0.,  0.,  0.],
       [-3.,  4.,  5., ...,  0.,  0.,  0.],
       [-2.,  7.,  8., ...,  0.,  0.,  0.],
       ...,
       [ 0.,  5.,  7., ...,  0.,  0.,  0.],
       [27., 27., 34., ...,  0.,  0.,  0.],
       [ 0.,  0., -1., ...,  0.,  0.,  0.]])}
{'serial': '01301', 'age': 88, 'symptom': ['dementia', 'ad', 'load'], 'class_name': 'Dementia', 'class_label': 2, 'signal': array([[ 18.,  15.,  13., ...,   0.,   0.,   0.],
       [ -2.,   1.,   1., ...,   0.,   0.,   0.],
       [  5.,   0.,  -1., ...,   0.,   0.,   0.],
       ...,
       [  2.,  -3.,  -3., ...,   0.,   0.,   0.],
       [ 51., 115., 103., ...,   0.,   0.,   0.],
       [  0.,  -1.,  -1., ...,   0.,   0.,   0.]])}
CPU times: total: 141 ms
Wall time: 152 ms


### Data Format: `NumPy Memmap`

In [13]:
%%time
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap', 
                                                                                  transform=None)

print(train_dataset[0])
print(train_dataset[1])

{'serial': '00587', 'age': 53, 'symptom': ['normal', 'cb_normal'], 'class_name': 'Normal', 'class_label': 0, 'signal': memmap([[ 30,  15,  18, ..., -42, -40, -38],
        [ -3,   4,   5, ..., -16, -14, -13],
        [ -2,   7,   8, ...,  -3,  -7,  -9],
        ...,
        [  0,   5,   7, ..., -10, -12, -12],
        [ 27,  27,  34, ..., -25, -27, -27],
        [  0,   0,  -1, ...,   1,   1,  -1]])}
{'serial': '01301', 'age': 88, 'symptom': ['dementia', 'ad', 'load'], 'class_name': 'Dementia', 'class_label': 2, 'signal': memmap([[ 18,  15,  13, ...,  36,  37,  39],
        [ -2,   1,   1, ...,  20,  19,  13],
        [  5,   0,  -1, ...,  -5,  -3,   1],
        ...,
        [  2,  -3,  -3, ...,   9,   9,  11],
        [ 51, 115, 103, ...,  -6, -14,  -8],
        [  0,  -1,  -1, ...,  -1,   0,  -1]])}
CPU times: total: 0 ns
Wall time: 4 ms


---

## PyTorch Transforms

### Random crop

In [14]:
transform = EegRandomCrop(crop_length=100)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    file_format='memmap',
    transform=transform,
)

# Fetch the *same* sample twice: the two printouts show different random crops.
for trial in range(2):
    d = train_dataset[0]
    pprint.pprint(d)
    print()
    print('>>> signal shape:', d['signal'].shape)
    print('\n', '-' * 100, '\n')

{'age': 53,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00587',
 'signal': memmap([[143, 139, 134, ...,  23,  24,  20],
        [ 31,  26,  22, ...,  11,  16,  16],
        [  6,   2,  -2, ...,   5,   7,   9],
        ...,
        [ -4,  -7,  -8, ...,   1,   2,   4],
        [ -5,  -4,  -3, ..., -43, -49, -58],
        [  0,   2,   2, ...,   0,  -1,  -1]]),
 'symptom': ['normal', 'cb_normal']}

>>> signal shape: (21, 100)

 ---------------------------------------------------------------------------------------------------- 

{'age': 53,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00587',
 'signal': memmap([[  65,   69,   66, ...,  103,   95,   94],
        [ -12,  -13,  -13, ...,    6,    2,    0],
        [  -6,   -6,   -8, ...,    0,    0,    0],
        ...,
        [   0,    2,    6, ...,    2,    5,    7],
        [ -17,  -18,  -17, ..., -105, -114, -121],
        [  -1,    0,    2, ...,   -1,    0,   -1]]),
 'symptom': ['normal', 'cb_normal']}

>>> signal 

### Random crop with multiple cropping

In [15]:
# With `multiple=2`, the 'signal' entry becomes a list holding two crops.
transform = EegRandomCrop(crop_length=200, multiple=2)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    file_format='memmap',
    transform=transform,
)

for trial in range(2):
    d = train_dataset[0]
    pprint.pprint(d)
    print()
    crop_shapes = [crop.shape for crop in d['signal']]
    print('>>> signal shape:', crop_shapes)
    print('\n', '-' * 100, '\n')

{'age': 53,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00587',
 'signal': [memmap([[ 79,  84,  87, ...,  92,  92,  86],
        [  7,   9,  11, ...,   6,   5,   7],
        [  1,   1,   4, ...,  -3,  -3,  -2],
        ...,
        [  4,   3,   2, ...,  -5,  -6,  -5],
        [114,  99,  95, ..., -16, -22, -31],
        [ -1,   0,   2, ...,   0,  -1,  -1]]),
            memmap([[  24,   24,   21, ...,  -13,  -12,  -10],
        [  -1,   -3,   -3, ...,   13,   13,   17],
        [  -2,   -2,   -2, ...,   -1,    1,    2],
        ...,
        [ -11,   -9,   -5, ...,   -6,   -5,   -8],
        [ 108,   97,   98, ..., -475, -263,    6],
        [   1,    1,   -3, ...,   -1,    0,    0]])],
 'symptom': ['normal', 'cb_normal']}

>>> signal shape: [(21, 200), (21, 200)]

 ---------------------------------------------------------------------------------------------------- 

{'age': 53,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00587',
 'signal': [memmap([[-109, -109, 

### Random crop with multiple cropping and latency

In [16]:
# Three crops per access. `return_timing=True` adds a 'crop_timing' entry; judging by the
# output, every timing is at least `latency` (TODO confirm exact semantics in EegRandomCrop).
transform = EegRandomCrop(crop_length=300, multiple=3, latency=50000, return_timing=True)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    file_format='memmap',
    transform=transform,
)

for trial in range(2):
    d = train_dataset[0]
    pprint.pprint(d)
    print()
    crop_shapes = [crop.shape for crop in d['signal']]
    print('>>> signal shape:', crop_shapes)
    print('\n', '-' * 100, '\n')

{'age': 53,
 'class_label': 0,
 'class_name': 'Normal',
 'crop_timing': [113726, 71569, 86599],
 'serial': '00587',
 'signal': [memmap([[-20, -22, -22, ..., 115, 110, 108],
        [  1,   4,   3, ...,  17,  14,  14],
        [  3,   3,   3, ..., -11, -10,  -9],
        ...,
        [  0,   2,   3, ..., -15, -11,  -7],
        [ 71,  68,  64, ..., -19, -18, -19],
        [ -1,   1,   2, ...,   0,   1,   1]]),
            memmap([[-41, -42, -45, ..., -63, -71, -74],
        [  2,   0,   1, ...,   8,   8,   8],
        [  8,   7,   4, ...,   5,   7,   8],
        ...,
        [ 10,   9,   9, ...,   2,   3,   2],
        [ 60,  59,  58, ...,  62,  59,  54],
        [  3,   0,  -1, ...,   3,  -1,  -1]]),
            memmap([[ 262,  261,  262, ...,  -15,  -12,  -14],
        [  35,   36,   37, ...,   -4,   -4,   -5],
        [  -2,   -2,   -1, ...,   -3,   -4,   -5],
        ...,
        [ -10,  -11,  -12, ...,   -6,   -6,   -4],
        [-562, -536, -352, ...,   93,   88,   84],
        [ 

### Random crop with multiple cropping, latency, and max length limit

In [17]:
# `length_limit` additionally caps the region crops may come from. In the output, every
# 'crop_timing' lies between `latency` and `length_limit - crop_length`
# (NOTE(review): confirm this reading of EegRandomCrop).
random_crop = EegRandomCrop(crop_length=200,
                            length_limit=50300,
                            multiple=3,
                            latency=50000,
                            return_timing=True)
transform = transforms.Compose([random_crop])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
    file_format='memmap',
    transform=transform,
)

for trial in range(2):
    d = train_dataset[0]
    pprint.pprint(d)
    print()
    crop_shapes = [crop.shape for crop in d['signal']]
    print('>>> signal shape:', crop_shapes)
    print('\n', '-' * 100, '\n')

{'age': 77,
 'class_label': 1,
 'class_name': 'Abnormal',
 'crop_timing': [50051, 50001, 50048],
 'serial': '01258',
 'signal': [memmap([[ 89,  76,  68, ...,   1,  -1,  -2],
        [ 75,  66,  55, ...,   9,  14,  21],
        [-29, -28, -27, ..., -40, -39, -41],
        ...,
        [-13,  -8,  -4, ...,  -4,  -5,  -4],
        [  4,   0,  10, ..., -37, -39, -26],
        [  0,   0,   0, ...,  -1,   0,  -1]]),
            memmap([[  6,   3,   9, ...,   0,   2,   5],
        [ 44,  38,  40, ...,  21,  21,  34],
        [-18, -19, -21, ..., -43, -41, -44],
        ...,
        [ -5,  -5,  -5, ...,  -4,  -7,  -8],
        [-15, -13, -22, ...,   4,   5,  21],
        [  3,   2,  -1, ...,  -1,   2,   2]]),
            memmap([[111, 103,  96, ...,  -9, -13,  -1],
        [ 47,  58,  68, ...,  12,   4,  16],
        [-20, -25, -30, ..., -41, -42, -43],
        ...,
        [-11, -13, -16, ...,  -3,  -3,  -3],
        [ 25,  25,  15, ..., -28, -22, -27],
        [  0,   0,   0, ...,   0,  -1, 

### Drop channel(s)

In [18]:
# Look up the row indices of the non-EEG channels (EKG, Photic) from the annotation header.
anno_path = os.path.join(data_path, 'annotation.json')
# JSON text is UTF-8 by specification. Do not rely on the OS-locale default encoding.
with open(anno_path, 'r', encoding='utf-8') as json_file:
    annotation = json.load(json_file)
signal_headers = annotation['signal_header']
del annotation  # only the header list is needed from here on
print(signal_headers)

channel_ekg = signal_headers.index('EKG')
print('channel_ekg: ', channel_ekg)

channel_photic = signal_headers.index('Photic')
print('channel_photic: ', channel_photic)

['Fp1-AVG', 'F3-AVG', 'C3-AVG', 'P3-AVG', 'O1-AVG', 'Fp2-AVG', 'F4-AVG', 'C4-AVG', 'P4-AVG', 'O2-AVG', 'F7-AVG', 'T3-AVG', 'T5-AVG', 'F8-AVG', 'T4-AVG', 'T6-AVG', 'FZ-AVG', 'CZ-AVG', 'PZ-AVG', 'EKG', 'Photic']
channel_ekg:  19
channel_photic:  20


In [19]:
loader_kwargs = dict(dataset_path=data_path,
                     task='abnormal',
                     load_event=False,
                     file_format='memmap')

# Reference: the untransformed sample with all channels.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(**loader_kwargs, transform=None)
print('before:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

print()
print('-' * 100)
print()

# Drop the EKG channel; the channel count shrinks by one.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(**loader_kwargs, transform=EegDropChannels(channel_ekg))
print('after:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

before: (21, 183200)
[[   3   -1   -5 ...   -5   -2   -9]
 [   7   15    9 ...    8  -10  -15]
 [  -6   -5   -3 ...   -2    0    3]
 ...
 [   4    6    5 ...    2    6    8]
 [  62   54   53 ...  -22 -137 -282]
 [   0    0    1 ...    0   -1    0]]

----------------------------------------------------------------------------------------------------

after: (20, 183200)
[[  3  -1  -5 ...  -5  -2  -9]
 [  7  15   9 ...   8 -10 -15]
 [ -6  -5  -3 ...  -2   0   3]
 ...
 [  0   6   3 ...  10   7   5]
 [  4   6   5 ...   2   6   8]
 [  0   0   1 ...   0  -1   0]]


In [20]:
loader_kwargs = dict(dataset_path=data_path,
                     task='abnormal',
                     load_event=False,
                     file_format='memmap')

# Reference: the untransformed sample with all channels.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(**loader_kwargs, transform=None)
print('before:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

print()
print('-' * 100)
print()

# Drop the Photic channel; since it is the last header entry, the last row disappears.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(**loader_kwargs, transform=EegDropChannels(channel_photic))
print('after:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

before: (21, 183200)
[[   3   -1   -5 ...   -5   -2   -9]
 [   7   15    9 ...    8  -10  -15]
 [  -6   -5   -3 ...   -2    0    3]
 ...
 [   4    6    5 ...    2    6    8]
 [  62   54   53 ...  -22 -137 -282]
 [   0    0    1 ...    0   -1    0]]

----------------------------------------------------------------------------------------------------

after: (20, 183200)
[[   3   -1   -5 ...   -5   -2   -9]
 [   7   15    9 ...    8  -10  -15]
 [  -6   -5   -3 ...   -2    0    3]
 ...
 [   0    6    3 ...   10    7    5]
 [   4    6    5 ...    2    6    8]
 [  62   54   53 ...  -22 -137 -282]]


In [21]:
loader_kwargs = dict(dataset_path=data_path,
                     task='abnormal',
                     load_event=False,
                     file_format='memmap')

# Reference: the untransformed sample with all channels.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(**loader_kwargs, transform=None)
print('before:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

print()
print('-' * 100)
print()

# EegDropChannels also takes a list of indices: drop both EKG and Photic at once.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(**loader_kwargs, transform=EegDropChannels([channel_ekg, channel_photic]))
print('after:', train_dataset[0]['signal'].shape)
print(train_dataset[0]['signal'])

before: (21, 183200)
[[   3   -1   -5 ...   -5   -2   -9]
 [   7   15    9 ...    8  -10  -15]
 [  -6   -5   -3 ...   -2    0    3]
 ...
 [   4    6    5 ...    2    6    8]
 [  62   54   53 ...  -22 -137 -282]
 [   0    0    1 ...    0   -1    0]]

----------------------------------------------------------------------------------------------------

after: (19, 183200)
[[  3  -1  -5 ...  -5  -2  -9]
 [  7  15   9 ...   8 -10 -15]
 [ -6  -5  -3 ...  -2   0   3]
 ...
 [-30 -27 -27 ...  22  18  14]
 [  0   6   3 ...  10   7   5]
 [  4   6   5 ...   2   6   8]]


### To Tensor

In [22]:
full_kwargs = dict(dataset_path=data_path,
                   load_event=False,
                   file_format='memmap')

# Raw sample: the signal is an integer memmap and the age a plain number.
config_data, full_eeg_dataset = load_caueeg_full_dataset(**full_kwargs, transform=None)
print('Before:')
pprint.pprint(full_eeg_dataset[0])

print()
print('-' * 100)
print()

# After EegToTensor, 'age' and 'signal' come back as float tensors; the other fields are unchanged.
config_data, full_eeg_dataset = load_caueeg_full_dataset(**full_kwargs, transform=EegToTensor())
print('After:')
pprint.pprint(full_eeg_dataset[0])

Before:
{'age': 78,
 'serial': '00001',
 'signal': memmap([[  0, -11, -13, ...,  18,  21,  22],
        [ 29,  33,  34, ...,  -7,  -4,  -4],
        [ -3,  -6,  -3, ...,  -1,   1,   2],
        ...,
        [ -4,  -2,   1, ...,   0,  -1,  -1],
        [112,  67,  76, ..., -13, -15, -11],
        [ -1,  -1,  -1, ...,  -1,  -1,  -1]]),
 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}

----------------------------------------------------------------------------------------------------

After:
{'age': tensor(78.),
 'serial': '00001',
 'signal': tensor([[  0., -11., -13.,  ...,  18.,  21.,  22.],
        [ 29.,  33.,  34.,  ...,  -7.,  -4.,  -4.],
        [ -3.,  -6.,  -3.,  ...,  -1.,   1.,   2.],
        ...,
        [ -4.,  -2.,   1.,  ...,   0.,  -1.,  -1.],
        [112.,  67.,  76.,  ..., -13., -15., -11.],
        [ -1.,  -1.,  -1.,  ...,  -1.,  -1.,  -1.]]),
 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}


### Compose all of the above into one transform

In [23]:
sampling_rate = 200  # samples per second: 200 * 10 samples correspond to 10 s

transform = transforms.Compose([
    EegRandomCrop(crop_length=sampling_rate * 10,          # crop: 10 s
                  length_limit=sampling_rate * 60 * 10,    # length limit: 10 min
                  multiple=4,
                  latency=sampling_rate * 10),             # latency: 10 s
    EegDropChannels(channel_photic),
    EegToTensor(),
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
    file_format='memmap',
    transform=transform,
)

pprint.pprint(train_dataset[0])

{'age': tensor(77.),
 'class_label': tensor(1),
 'class_name': 'Abnormal',
 'serial': '01258',
 'signal': [tensor([[ 66.,  77.,  79.,  ..., -16., -13.,  -9.],
        [ 43.,  42.,  35.,  ...,   5., -11., -27.],
        [-35., -31., -26.,  ..., -36., -32., -28.],
        ...,
        [ -9.,  -9.,  -6.,  ..., -11., -10.,  -9.],
        [-10.,  -8.,  -6.,  ...,   4.,   7.,   9.],
        [-28., -20.,   2.,  ...,  -7.,   4., -20.]]),
            tensor([[ -9., -19., -26.,  ...,  29.,  24.,  10.],
        [ -8., -11., -14.,  ...,  18.,  24.,  20.],
        [ 15.,  18.,  20.,  ..., -27., -27., -27.],
        ...,
        [  3.,   4.,   4.,  ...,  -4.,  -2.,   2.],
        [  8.,   9.,   7.,  ...,  -1.,   0.,   4.],
        [-22., -30., -11.,  ...,  31.,  27.,   9.]]),
            tensor([[ 23.,  16.,  10.,  ..., -53., -51., -53.],
        [-22.,  -7.,  -1.,  ..., -28., -21., -20.],
        [ 29.,  18.,  13.,  ..., -32., -35., -32.],
        ...,
        [  5.,   1.,  -1.,  ..., -12., -12., -

---

## PyTorch DataLoader

In [24]:
# Data is always loaded in the main process: per the original note, any
# num_workers other than 0 caused an error here. Pinned host memory is
# requested only when batches will be copied to a CUDA device.
num_workers = 0
pin_memory = device.type == 'cuda'

In [25]:
# Same cropping pipeline, but with two random crops per recording.
transform = transforms.Compose(
    [
        EegRandomCrop(
            crop_length=10 * 200,          # crop: 10 s
            length_limit=10 * 60 * 200,    # length limit: 10 min
            multiple=2,                    # crops per recording
            latency=10 * 200,              # latency: 10 s
        ),
        EegDropChannels(channel_photic),
        EegToTensor(),
    ]
)

# The full (unsplit) dataset; its samples carry no class labels.
config_data, full_eeg_dataset = load_caueeg_full_dataset(
    dataset_path=data_path,
    load_event=False,
    file_format='memmap',
    transform=transform,
)

# eeg_collate_fn stacks every crop as its own row: 4 recordings x 2 crops give
# 8 rows (note the duplicated serials in the output).
full_loader = DataLoader(
    full_eeg_dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)

# Peek at a single shuffled batch.
for i_batch, sample_batched in enumerate(full_loader):
    pprint.pprint(sample_batched)
    break

{'age': tensor([75., 75., 74., 74., 82., 82., 80., 80.]),
 'serial': ['01275',
            '01275',
            '00855',
            '00855',
            '00326',
            '00326',
            '00007',
            '00007'],
 'signal': tensor([[[   0.,    0.,   -2.,  ...,    4.,    5.,    8.],
         [   0.,   -2.,   -3.,  ...,    5.,    7.,    8.],
         [  -2.,   -3.,   -4.,  ...,    5.,    7.,    6.],
         ...,
         [  -2.,    0.,    1.,  ...,   10.,    9.,    7.],
         [   0.,    0.,    0.,  ...,    1.,    0.,    0.],
         [  17.,   13.,   10.,  ...,  142.,  133.,  133.]],

        [[   0.,    1.,    0.,  ...,  -21.,  -15.,  -14.],
         [   1.,    2.,    1.,  ...,    0.,    2.,    4.],
         [  -4.,   -4.,   -4.,  ...,   18.,   20.,   21.],
         ...,
         [  -3.,   -1.,    0.,  ...,    9.,    7.,    6.],
         [ -10.,   -9.,   -8.,  ...,    5.,    5.,    5.],
         [  11.,   27.,   37.,  ...,   62.,   33.,    8.]],

        [[ -12.,   -9.

In [26]:
# Two random 10 s crops per recording for the 'abnormal' task.
transform = transforms.Compose(
    [
        EegRandomCrop(
            crop_length=10 * 200,          # crop: 10 s
            length_limit=10 * 60 * 200,    # length limit: 10 min
            multiple=2,                    # crops per recording
            latency=10 * 200,              # latency: 10 s
        ),
        EegDropChannels(channel_photic),
        EegToTensor(),
    ]
)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
    file_format='memmap',
    transform=transform,
)

# 8 recordings per batch; with 2 crops each the collated batch has 16 rows.
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)

# Print one batch on wide lines so the label/serial lists stay compact.
for i_batch, sample_batched in enumerate(train_loader):
    pprint.pprint(sample_batched, width=250)
    break

{'age': tensor([71., 71., 77., 77., 86., 86., 57., 57., 76., 76., 80., 80., 78., 78.,
        73., 73.]),
 'class_label': tensor([1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]),
 'class_name': ['Abnormal', 'Abnormal', 'Normal', 'Normal', 'Abnormal', 'Abnormal', 'Abnormal', 'Abnormal', 'Abnormal', 'Abnormal', 'Abnormal', 'Abnormal', 'Abnormal', 'Abnormal', 'Normal', 'Normal'],
 'serial': ['00671', '00671', '00672', '00672', '00405', '00405', '01218', '01218', '00763', '00763', '00258', '00258', '00288', '00288', '01152', '01152'],
 'signal': tensor([[[-2.4000e+01, -2.2000e+01, -2.4000e+01,  ...,  1.1000e+01,
           8.0000e+00,  6.0000e+00],
         [-2.0000e+00, -1.0000e+00, -3.0000e+00,  ...,  4.0000e+00,
           0.0000e+00, -1.0000e+00],
         [-5.0000e+00, -4.0000e+00, -3.0000e+00,  ...,  6.0000e+00,
           4.0000e+00,  1.0000e+00],
         ...,
         [ 2.0000e+00,  3.0000e+00,  1.0000e+00,  ...,  2.0000e+00,
           0.0000e+00, -1.0000e+00],
         [-3.0000

---

## Preprocessing steps run as PyTorch modules

In [27]:
# Two random 10 s crops per recording, now for the 'dementia' task. The small
# loader built here (batch_size=2) feeds the preprocessing demos below.
transform = transforms.Compose(
    [
        EegRandomCrop(
            crop_length=10 * 200,          # crop: 10 s
            length_limit=10 * 60 * 200,    # length limit: 10 min
            multiple=2,                    # crops per recording
            latency=10 * 200,              # latency: 10 s
        ),
        EegDropChannels(channel_photic),
        EegToTensor(),
    ]
)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    file_format='memmap',
    transform=transform,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)

### Move to the GPU device if possible

In [28]:
print('device:', device)
print()

# A one-stage pipeline that moves the batch tensors to `device`. Building the
# nn.Sequential directly is equivalent to unpacking a transforms.Compose list,
# and lets the whole pipeline be moved with .to(device).
preprocess_train = torch.nn.Sequential(EegToDevice(device=device)).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    pprint.pprint(sample_batched)

    print()
    print('-' * 100)
    print()

    # The pipeline updates the batch dict in place; its return value is unused.
    preprocess_train(sample_batched)

    print('- After -')
    pprint.pprint(sample_batched)
    break

device: cuda

Sequential(
  (0): EegToDevice(device=cuda)
)
- Before -
{'age': tensor([58., 58., 67., 67.]),
 'class_label': tensor([1, 1, 0, 0]),
 'class_name': ['MCI', 'MCI', 'Normal', 'Normal'],
 'serial': ['01176', '01176', '01320', '01320'],
 'signal': tensor([[[-46., -48., -46.,  ..., -52., -47., -50.],
         [ -6.,  -6.,  -5.,  ..., -14., -15., -15.],
         [ -3.,  -3.,  -4.,  ...,  -2.,  -2.,  -1.],
         ...,
         [ 20.,  19.,  18.,  ...,  -9.,  -9.,  -9.],
         [ -2.,  -1.,  -2.,  ..., -13., -13., -14.],
         [-79., -70., -54.,  ..., -82., -80., -80.]],

        [[290., 287., 295.,  ...,   7.,   0.,  -6.],
         [ 42.,  35.,  29.,  ...,   5.,   3.,   5.],
         [ -5.,  -6.,  -9.,  ...,   6.,   8.,   8.],
         ...,
         [-11., -12., -14.,  ...,  -4.,  -4.,  -5.],
         [-10.,  -8.,  -7.,  ...,   0.,   1.,   3.],
         [-35., -44., -43.,  ...,  81.,  -9., -73.]],

        [[  0.,   0.,   0.,  ...,  14.,  13.,  13.],
         [-15., -16.,

### Normalization per signal

In [29]:
# Device transfer followed by EegNormalizePerSignal, which (per its name)
# normalizes each signal by its own statistics rather than dataset-wide ones.
preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegNormalizePerSignal(),
).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    # Statistics along the last (time) axis, before normalization.
    print('- Before -')
    print('Mean:', sample_batched['signal'].mean(dim=-1))
    print()
    print('Std:', sample_batched['signal'].std(dim=-1))

    print()
    print('-' * 100)
    print()

    preprocess_train(sample_batched)  # modifies the batch dict in place

    # Same statistics after normalization.
    print('- After -')
    print('Mean:', sample_batched['signal'].mean(dim=-1))
    print()
    print('Std:', sample_batched['signal'].std(dim=-1))
    break

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizePerSignal(eps=1e-08)
)
- Before -
Mean: tensor([[ 2.5840e+00,  7.3250e-01, -8.6700e-01, -1.1030e+00, -2.5000e-01,
          3.5350e-01, -4.0000e-03, -1.2645e+00, -3.5600e-01, -7.4000e-02,
          4.6300e-01, -8.4550e-01, -9.5000e-03, -7.7550e-01,  3.9500e-02,
          6.6800e-01,  2.3065e+00, -5.1000e-02, -2.6750e-01, -1.1885e+00],
        [ 1.0295e+00,  3.5005e+00,  2.0060e+00,  1.4000e+00,  8.8900e-01,
         -4.7830e+00, -1.2810e+00, -1.0960e+00,  1.3050e-01,  2.9650e-01,
          5.4350e-01,  1.8190e+00,  1.2685e+00, -4.4585e+00, -1.1405e+00,
          2.1350e-01, -2.6330e+00, -1.4085e+00,  8.6300e-01,  9.8350e-01],
        [-2.7827e+01, -7.0015e+00, -2.3280e+00, -5.0900e-01,  4.4250e-01,
         -1.3829e+01,  2.1150e-01,  8.5050e-01,  4.3100e+00,  3.0140e+00,
         -3.2100e+00, -5.8150e-01, -4.9750e-01,  1.4759e+01,  6.0050e+00,
          7.9075e+00,  2.1950e-01, -6.1400e-01,  3.3075e+00,  2.0575e+00],
      

### Signal normalization using the specified mean and std values

In [30]:
# Per-channel signal mean/std estimated from the training loader (repeats=1);
# verbose=True prints the resulting tensors.
signal_mean, signal_std = calculate_signal_statistics(train_loader, repeats=1, verbose=True)

# Device transfer, then standardization with the fixed statistics above.
preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    print('Mean:', sample_batched['signal'].mean(dim=-1))
    print()
    print('Std:', sample_batched['signal'].std(dim=-1))

    print()
    print('-' * 100)
    print()

    preprocess_train(sample_batched)  # modifies the batch dict in place

    print('- After -')
    print('Mean:', sample_batched['signal'].mean(dim=-1))
    print()
    print('Std:', sample_batched['signal'].std(dim=-1))
    break

Mean and standard deviation for signal:
tensor([[[ 0.3848],
         [ 0.0929],
         [ 0.0679],
         [ 0.1235],
         [ 0.2628],
         [-0.1592],
         [-0.0261],
         [-0.0366],
         [-0.1227],
         [ 0.1542],
         [-0.1571],
         [-0.0203],
         [ 0.0513],
         [-0.3129],
         [-0.1915],
         [-0.1468],
         [ 0.1375],
         [-0.0133],
         [ 0.0477],
         [ 0.0199]]])
-
tensor([[[46.6161],
         [21.1672],
         [11.7478],
         [11.2979],
         [14.8326],
         [49.3067],
         [19.5424],
         [10.4230],
         [11.6122],
         [15.3615],
         [21.0841],
         [14.5436],
         [13.6893],
         [21.1969],
         [17.4971],
         [13.7927],
         [19.8298],
         [11.1799],
         [11.2053],
         [94.0091]]])

----------------------------------------------------------------------------------------------------

Sequential(
  (0): EegToDevice(device=cuda)
  (1): 

### Age normalization

In [31]:
# Mean/std of the 'age' values seen through the training loader (printed via verbose=True).
age_mean, age_std = calculate_age_statistics(train_loader, verbose=True)

# Device transfer, then age standardization. The output below matches
# (age - age_mean) / age_std, e.g. (74 - 71.2453) / 6.3654 ≈ 0.4328.
preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegNormalizeAge(mean=age_mean, std=age_std),
).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    pprint.pprint(sample_batched['age'])

    print()
    print('-' * 100)
    print()

    preprocess_train(sample_batched)  # modifies the batch dict in place

    print('- After -')
    pprint.pprint(sample_batched['age'])
    break

Age mean and standard deviation:
tensor([71.2453]) tensor([6.3654])

----------------------------------------------------------------------------------------------------

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([71.2453]),std=tensor([6.3654]),eps=1e-08)
)
- Before -
tensor([74., 74., 83., 83.])

----------------------------------------------------------------------------------------------------

- After -
tensor([0.4328, 0.4328, 1.8466, 1.8466], device='cuda:0')


### Short-time Fourier transform (STFT, or spectrogram)

In [32]:
# Device transfer, then an STFT with n_fft=200. With complex_mode='as_real'
# the result stays real-valued; in the output below the channel axis doubles
# (20 -> 40) and a frequency axis of 101 bins appears.
preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegSpectrogram(n_fft=200, complex_mode='as_real'),
).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    pprint.pprint(sample_batched['signal'].shape)

    print()
    print('-' * 100)
    print()

    preprocess_train(sample_batched)  # replaces sample_batched['signal'] in place

    print('- After -')
    pprint.pprint(sample_batched['signal'].shape)
    break

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegSpectrogram(n_fft=200, complex_mode=as_real, stft_kwargs={})
)
- Before -
torch.Size([4, 20, 2000])

----------------------------------------------------------------------------------------------------

- After -
torch.Size([4, 40, 101, 41])


### Signal normalization after STFT

In [33]:
# Stage 1: device transfer + STFT (same settings as the previous cell).
preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegSpectrogram(n_fft=200, complex_mode='as_real'),
).to(device)

# Spectrogram statistics, measured on batches that have gone through stage 1.
signal_2d_mean, signal_2d_std = calculate_signal_statistics(train_loader, preprocess_train)

# Stage 2: standardize the spectrogram with those statistics.
preprocess_train2 = torch.nn.Sequential(
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
).to(device)

pprint.pprint(preprocess_train)
pprint.pprint(preprocess_train2)

for i_batch, sample_batched in enumerate(train_loader):
    # "Before" here means after the STFT but before the normalization.
    print('- Before -')
    preprocess_train(sample_batched)

    print('Mean:', sample_batched['signal'].mean(dim=-1))
    print()
    print('Std:', sample_batched['signal'].std(dim=-1))

    print()
    print('-' * 100)
    print()

    print('- After -')
    preprocess_train2(sample_batched)

    print('Mean:', sample_batched['signal'].mean(dim=-1))
    print()
    print('Std:', sample_batched['signal'].std(dim=-1))
    break

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegSpectrogram(n_fft=200, complex_mode=as_real, stft_kwargs={})
)
Sequential(
  (0): EegNormalizeMeanStd(mean=tensor([[ 1.5429e+01,  6.4722e-01,  3.8627e-02,  ...,  2.4955e-02,
            2.2957e-02,  2.9620e-02],
          [ 7.0377e+00,  3.1176e-01, -5.6557e-02,  ...,  1.6799e-02,
            1.8693e-02,  1.0362e-02],
          [-8.2014e+00,  9.6021e-02, -6.2677e-02,  ..., -1.7618e-02,
           -1.6574e-02, -2.3030e-03],
          ...,
          [ 0.0000e+00,  1.1431e+00,  6.7176e-01,  ...,  1.7511e-03,
            2.2550e-03, -1.6554e-08],
          [ 0.0000e+00,  1.1292e+00,  7.2300e-01,  ...,  1.3086e-03,
            7.3269e-04, -6.1795e-09],
          [ 0.0000e+00,  2.4342e+00,  1.1617e+00,  ...,  6.9995e-04,
           -7.2238e-04, -4.8062e-08]], device='cuda:0'),std=tensor([[7.3918e+03, 1.5109e+03, 7.8873e+02,  ..., 2.8842e+01, 2.8815e+01,
           2.9036e+01],
          [3.2748e+03, 5.9179e+02, 3.0332e+02,  ..., 1.3358e+01

---

## Speed check without STFT

In [34]:
# Shared settings for the speed checks without STFT.
crop_length = 10 * 200   # 10 s crops
multiple = 4             # crops per recording
batch_size = 128         # recordings per batch

### `EDF`

In [None]:
%%time
transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='edf',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

### `memmap`

In [None]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)
 
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

### `memmap` (Drop → Crop)

In [None]:
%%time

transform = transforms.Compose([
    EegDropChannels(channel_photic), 
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)
 
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

---

## Speed check with STFT

In [36]:
# Settings for the speed checks with STFT: 3000-sample (15 s) crops, two per recording.
crop_length = 300 * 10
n_fft, hop_length, seq_len_2d = calculate_stft_params(seq_length=crop_length, verbose=True)
multiple = 2
batch_size = 128

# Build a training loader that matches these settings before estimating the
# spectrogram statistics. Previously this cell reused whatever `train_loader` an
# earlier cell had left behind (2000-sample crops, with a batch size and
# `multiple` that depended on which cells had been executed). Those statistics
# did not describe the 3000-sample crops that the STFT pipelines below see.
transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple,
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic),
    EegToTensor()
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path,
                                                                                  task='dementia',
                                                                                  load_event=False,
                                                                                  file_format='memmap',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

# Device transfer + STFT with the parameters derived from crop_length; the
# statistics are measured on batches that have passed through this pipeline.
preprocess_train = transforms.Compose([
    EegToDevice(device=device),
    EegSpectrogram(n_fft=n_fft, hop_length=hop_length, complex_mode='as_real')
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
signal_2d_mean, signal_2d_std = calculate_signal_statistics(train_loader, preprocess_train)

Input sequence length: (3000) would become (78, 77) after the STFT with n_fft (155) and hop_length (39).

----------------------------------------------------------------------------------------------------



### `EDF`

In [None]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='edf',
                                                                                  transform=transform)
 
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegSpectrogram(n_fft=n_fft, hop_length=hop_length, complex_mode='as_real'),
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)
    size = sample_batched['signal'].size()
    
print(size)

### `memmap`

In [None]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegSpectrogram(n_fft=n_fft, hop_length=hop_length, complex_mode='as_real'),
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)
    size = sample_batched['signal'].size()
    
print(size)

---

## Test on a longer sequence

In [None]:
%%time
longer_transform = transforms.Compose([
    EegRandomCrop(crop_length=200*10*6,     # crop: 1m
                  length_limit=200*60*10,   # length: 10m
                  multiple=2, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(longer_transform)

config_data, longer_test_dataset = load_caueeg_task_split(dataset_path=data_path, 
                                                          task='dementia', 
                                                          split='test',
                                                          load_event=False,
                                                          file_format='memmap', 
                                                          transform=longer_transform)

longer_test_loader = DataLoader(longer_test_dataset,
                                batch_size=32,
                                shuffle=True,
                                drop_last=False,
                                num_workers=num_workers,
                                pin_memory=pin_memory,
                                collate_fn=eeg_collate_fn)
 
preprocess_test = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
    EegNormalizeAge(mean=age_mean, std=age_std),
])
preprocess_test = torch.nn.Sequential(*preprocess_test.transforms).to(device)
pprint.pprint(preprocess_test)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_test(sample_batched)

---

## Resampling

In [70]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=200*200,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train1 = transforms.Compose([
    EegToDevice(device=device), 
    EegResample(orig_freq=200, new_freq=250, resampling_method='kaiser_best'),
    EegResample(orig_freq=250, new_freq=200, resampling_method='kaiser_best'),
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
])
preprocess_train1 = torch.nn.Sequential(*preprocess_train1.transforms).to(device)
pprint.pprint(preprocess_train1)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train2 = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
])
preprocess_train2 = torch.nn.Sequential(*preprocess_train2.transforms).to(device)
pprint.pprint(preprocess_train2)

diff = 0.0
for e in range(5):
    for i_batch, sample_batched in enumerate(train_loader):
        from copy import deepcopy
        sb1 = deepcopy(sample_batched)
        sb2 = deepcopy(sample_batched)

        preprocess_train1(sb1)
        preprocess_train2(sb2)
        
        diff += (torch.norm(sb1['signal'] - sb2['signal']) / torch.sqrt(torch.norm(sb1['signal'])) / torch.sqrt(torch.norm(sb1['signal']))).item()
        
print(diff)

Compose(
    EegRandomCrop(crop_length=40000, length_limit=120000, multiple=2, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegResample(resampling_method='kaiser_best', resampler=Resample())
  (2): EegResample(resampling_method='kaiser_best', resampler=Resample())
  (3): EegNormalizeAge(mean=tensor([71.2453]),std=tensor([6.3654]),eps=1e-08)
  (4): EegNormalizeMeanStd(mean=tensor([ 0.3848,  0.0929,  0.0679,  0.1235,  0.2628, -0.1592, -0.0261, -0.0366,
          -0.1227,  0.1542, -0.1571, -0.0203,  0.0513, -0.3129, -0.1915, -0.1468,
           0.1375, -0.0133,  0.0477,  0.0199]),std=tensor([46.6161, 21.1672, 11.7478, 11.2979, 14.8326, 49.3067, 19.5424, 10.4230,
          11.6122, 15.3615, 21.0841, 14.5436, 13.6893, 21.1969, 17.4971, 13.7927,
          19.8298, 11.1799, 11.2053, 94.0091]),eps=1e-08)
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([71.2453]),std=

In [71]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=200*200,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)

# Compare two preprocessing pipelines on identical batches:
#   * preprocess_train1 round-trips the signal 200 Hz -> 250 Hz -> 200 Hz with the
#     'kaiser_fast' resampler, then normalizes age and signal;
#   * preprocess_train2 only normalizes.
# The relative error ||s1 - s2|| / ||s1|| therefore measures how much the resampling
# round trip distorts the signal.
from copy import deepcopy

# A single loader is enough; both pipelines consume deep copies of the same batch.
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train1 = transforms.Compose([
    EegToDevice(device=device), 
    EegResample(orig_freq=200, new_freq=250, resampling_method='kaiser_fast'),
    EegResample(orig_freq=250, new_freq=200, resampling_method='kaiser_fast'),
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
])
preprocess_train1 = torch.nn.Sequential(*preprocess_train1.transforms).to(device)
pprint.pprint(preprocess_train1)

preprocess_train2 = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
])
preprocess_train2 = torch.nn.Sequential(*preprocess_train2.transforms).to(device)
pprint.pprint(preprocess_train2)

n_epochs = 5
diff = 0.0      # sum of per-batch relative errors (same quantity as before)
n_batches = 0
with torch.no_grad():  # pure evaluation: no autograd graph needed
    for _ in range(n_epochs):
        for sample_batched in train_loader:
            # Deep copies so neither pipeline can mutate tensors seen by the other;
            # use the returned sample rather than relying on in-place mutation.
            sb1 = preprocess_train1(deepcopy(sample_batched))
            sb2 = preprocess_train2(deepcopy(sample_batched))

            signal1 = sb1['signal']
            diff += (torch.linalg.vector_norm(signal1 - sb2['signal'])
                     / torch.linalg.vector_norm(signal1)).item()
            n_batches += 1

print(diff)
print(f'mean relative error over {n_batches} batches: {diff / max(n_batches, 1):.6f}')

Compose(
    EegRandomCrop(crop_length=40000, length_limit=120000, multiple=2, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegResample(resampling_method='kaiser_fast', resampler=Resample())
  (2): EegResample(resampling_method='kaiser_fast', resampler=Resample())
  (3): EegNormalizeAge(mean=tensor([71.2453]),std=tensor([6.3654]),eps=1e-08)
  (4): EegNormalizeMeanStd(mean=tensor([ 0.3848,  0.0929,  0.0679,  0.1235,  0.2628, -0.1592, -0.0261, -0.0366,
          -0.1227,  0.1542, -0.1571, -0.0203,  0.0513, -0.3129, -0.1915, -0.1468,
           0.1375, -0.0133,  0.0477,  0.0199]),std=tensor([46.6161, 21.1672, 11.7478, 11.2979, 14.8326, 49.3067, 19.5424, 10.4230,
          11.6122, 15.3615, 21.0841, 14.5436, 13.6893, 21.1969, 17.4971, 13.7927,
          19.8298, 11.1799, 11.2053, 94.0091]),eps=1e-08)
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([71.2453]),std=

In [72]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=200*200,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train1 = transforms.Compose([
    EegToDevice(device=device), 
    EegResample(orig_freq=200, new_freq=250),
    EegResample(orig_freq=250, new_freq=200),
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
])
preprocess_train1 = torch.nn.Sequential(*preprocess_train1.transforms).to(device)
pprint.pprint(preprocess_train1)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train2 = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
])
preprocess_train2 = torch.nn.Sequential(*preprocess_train2.transforms).to(device)
pprint.pprint(preprocess_train2)

diff = 0.0
for e in range(5):
    for i_batch, sample_batched in enumerate(train_loader):
        from copy import deepcopy
        sb1 = deepcopy(sample_batched)
        sb2 = deepcopy(sample_batched)

        preprocess_train1(sb1)
        preprocess_train2(sb2)
        
        diff += (torch.norm(sb1['signal'] - sb2['signal']) / torch.sqrt(torch.norm(sb1['signal'])) / torch.sqrt(torch.norm(sb1['signal']))).item()
        
print(diff)

Compose(
    EegRandomCrop(crop_length=40000, length_limit=120000, multiple=2, latency=2000, return_timing=False)
    EegDropChannels(drop_index=20)
    EegToTensor()
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegResample(resampling_method='sinc_interpolation', resampler=Resample())
  (2): EegResample(resampling_method='sinc_interpolation', resampler=Resample())
  (3): EegNormalizeAge(mean=tensor([71.2453]),std=tensor([6.3654]),eps=1e-08)
  (4): EegNormalizeMeanStd(mean=tensor([ 0.3848,  0.0929,  0.0679,  0.1235,  0.2628, -0.1592, -0.0261, -0.0366,
          -0.1227,  0.1542, -0.1571, -0.0203,  0.0513, -0.3129, -0.1915, -0.1468,
           0.1375, -0.0133,  0.0477,  0.0199]),std=tensor([46.6161, 21.1672, 11.7478, 11.2979, 14.8326, 49.3067, 19.5424, 10.4230,
          11.6122, 15.3615, 21.0841, 14.5436, 13.6893, 21.1969, 17.4971, 13.7927,
          19.8298, 11.1799, 11.2053, 94.0091]),eps=1e-08)
)
Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([