# Dataset and DataLoader

This notebook loads the `CAUEEG` dataset, tries out several useful preprocessing transforms, and builds the PyTorch DataLoader instances used for training.

-----

## Configurations

In [1]:
# Automatically reload external (project) modules when their source changes,
# so edits to imported code take effect without restarting the kernel.
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
# Move the working directory up one level; presumably this is what makes the
# `datasets` package imports and the relative `data_path` used later resolve.
# NOTE(review): `%cd ..` is not idempotent -- re-running this cell climbs one
# more directory level, so run it exactly once per kernel session.
%cd ..

C:\Users\Minjae\Desktop\EEG_Project


In [2]:
# Load some packages
import os
import json
import pprint

import numpy as np
import random
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

# custom package
from datasets.caueeg_dataset import *
from datasets.caueeg_script import *
from datasets.pipeline import *

In [3]:
# Report the PyTorch version and select the compute device once, up front.
print('PyTorch version:', torch.__version__)

has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if has_cuda else torch.device('cpu')

# Tell the reader which device the rest of the notebook will run on.
print('cuda is available.' if has_cuda else 'cuda is unavailable.')

PyTorch version: 1.12.1+cu113
cuda is available.


In [4]:
# Root directory of the CAUEEG dataset, relative to the current working directory.
data_path = 'local/dataset/caueeg-dataset/'

In [5]:
# Print an abbreviated view of each dataset/task JSON file: lists with more than
# three entries are shown as their first two items, an ellipsis, and the last item.
for task in ('annotation.json', 'abnormal.json', 'dementia.json'):
    with open(os.path.join(data_path, task), 'r') as json_file:
        task_dict = json.load(json_file)

    print('{')
    for key, value in task_dict.items():
        print(f'\t{key}:')
        if isinstance(value, list) and len(value) > 3:
            shown_items = [value[0], value[1], '.', '.', '.', value[-1]]
        else:
            shown_items = [value]
        for item in shown_items:
            print(f'\t\t{item}')
        print()
    print('}')
    print('\n' + '-' * 100 + '\n')

{
	dataset_name:
		CAUEEG dataset

	signal_header:
		Fp1-AVG
		F3-AVG
		.
		.
		.
		Photic

	data:
		{'serial': '00001', 'age': 78, 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}
		{'serial': '00002', 'age': 56, 'symptom': ['normal', 'smi']}
		.
		.
		.
		{'serial': '01388', 'age': 73, 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_ef']}

}

----------------------------------------------------------------------------------------------------

{
	task_name:
		CAUEEG-Abnormal benchmark

	task_description:
		Classification of [Normal] and [Abnormal] symptoms

	class_label_to_name:
		['Normal', 'Abnormal']

	class_name_to_label:
		{'Normal': 0, 'Abnormal': 1}

	train_split:
		{'serial': '01258', 'age': 77, 'symptom': ['dementia', 'vd', 'sivd'], 'class_name': 'Abnormal', 'class_label': 1}
		{'serial': '00836', 'age': 80, 'symptom': ['normal', 'smi'], 'class_name': 'Normal', 'class_label': 0}
		.
		.
		.
		{'serial': '00105', 'age': 71, 'symptom': ['normal', 'smi'], 'class_name': 'N

In [6]:
# Tally class labels and symptom tags over every split of the CAUEEG-Abnormal task.
task = 'abnormal.json'
task_path = os.path.join(data_path, task)
with open(task_path, 'r') as json_file:
    task_dict = json.load(json_file)

diagnosis_dict = {}
symptom_dict = {}
for split in ('train_split', 'validation_split', 'test_split'):
    for data in task_dict[split]:
        class_name = data['class_name']
        diagnosis_dict[class_name] = 1 + diagnosis_dict.get(class_name, 0)

        # Records tagged 'parkinson_dementia' still count toward the diagnosis
        # tally, but none of their symptom tags are added to the symptom tally.
        if 'parkinson_dementia' not in data['symptom']:
            for symp in data['symptom']:
                symptom_dict[symp] = 1 + symptom_dict.get(symp, 0)

pprint.pprint(diagnosis_dict)
pprint.pprint(symptom_dict)

{'Abnormal': 920, 'Normal': 459}
{'ad': 230,
 'ad_vd_mixed': 1,
 'bvftd': 10,
 'cb_normal': 260,
 'dementia': 310,
 'eoad': 41,
 'ftd': 14,
 'hc_normal': 11,
 'load': 189,
 'mci': 417,
 'mci_ad': 17,
 'mci_amnestic': 334,
 'mci_amnestic_ef': 156,
 'mci_amnestic_rf': 114,
 'mci_multi_domain': 10,
 'mci_non_amnestic': 18,
 'mci_vascular': 70,
 'non_fluent_aphasia': 2,
 'normal': 459,
 'nph': 13,
 'parkinson_disease': 63,
 'parkinson_synd': 93,
 'semantic_aphasia': 2,
 'sivd': 79,
 'smi': 186,
 'tga': 72,
 'vd': 79}


-----

## Load the CAUEEG dataset

### Load the whole CAUEEG data as a PyTorch dataset instance without considering the target task (no train/val/test sets and no class label).

In [7]:
# Load every recording from the EDF files, without task splits or class labels.
config_data, full_eeg_dataset = load_caueeg_full_dataset(
    dataset_path=data_path,
    load_event=False,
    file_format='edf',
    transform=None,
)

pprint.pprint(config_data, width=250)
print('\n', '-' * 100, '\n')

# Show the first two samples, each followed by a separator line.
for sample_index in (0, 1):
    pprint.pprint(full_eeg_dataset[sample_index])
    print('\n', '-' * 100, '\n')

{'dataset_name': 'CAUEEG dataset',
 'signal_header': ['Fp1-AVG', 'F3-AVG', 'C3-AVG', 'P3-AVG', 'O1-AVG', 'Fp2-AVG', 'F4-AVG', 'C4-AVG', 'P4-AVG', 'O2-AVG', 'F7-AVG', 'T3-AVG', 'T5-AVG', 'F8-AVG', 'T4-AVG', 'T6-AVG', 'FZ-AVG', 'CZ-AVG', 'PZ-AVG', 'EKG', 'Photic']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 78,
 'serial': '00001',
 'signal': array([[  0., -11., -13., ...,   0.,   0.,   0.],
       [ 29.,  33.,  34., ...,   0.,   0.,   0.],
       [ -3.,  -6.,  -3., ...,   0.,   0.,   0.],
       ...,
       [ -4.,  -2.,   1., ...,   0.,   0.,   0.],
       [112.,  67.,  76., ...,   0.,   0.,   0.],
       [ -1.,  -1.,  -1., ...,   0.,   0.,   0.]]),
 'symptom': ['mci', 'mci_amnestic', 'mci_amnestic_rf']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 56,
 'serial': '00002',
 'signal': array([[  39.,   58.,   72., ...,    0.,    0.,    0.],
       [   4.,

### Load the CAUEEG-Abnormal benchmark using the PyTorch dataset instances.

In [8]:
# Load the train/validation/test splits of the CAUEEG-Abnormal benchmark.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
)
pprint.pprint(config_data)

# Show the first sample of each split, each preceded by a separator line.
for split_dataset in (train_dataset, val_dataset, test_dataset):
    print('\n', '-' * 100, '\n')
    pprint.pprint(split_dataset[0])

{'class_label_to_name': ['Normal', 'Abnormal'],
 'class_name_to_label': {'Abnormal': 1, 'Normal': 0},
 'task_description': 'Classification of [Normal] and [Abnormal] symptoms',
 'task_name': 'CAUEEG-Abnormal benchmark'}

 ---------------------------------------------------------------------------------------------------- 

{'age': 77,
 'class_label': 1,
 'class_name': 'Abnormal',
 'serial': '01258',
 'signal': array([[ 3., -1., -5., ...,  0.,  0.,  0.],
       [ 7., 15.,  9., ...,  0.,  0.,  0.],
       [-6., -5., -3., ...,  0.,  0.,  0.],
       ...,
       [ 4.,  6.,  5., ...,  0.,  0.,  0.],
       [62., 54., 53., ...,  0.,  0.,  0.],
       [ 0.,  0.,  1., ...,  0.,  0.,  0.]]),
 'symptom': ['dementia', 'vd', 'sivd']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 81,
 'class_label': 1,
 'class_name': 'Abnormal',
 'serial': '00152',
 'signal': array([[ 14.,  10.,   2., ...,   0.,   0.,   0.],
       [  8.,   6.,   0.

In [9]:
# Load only the test split of the CAUEEG-Abnormal benchmark.
# Use the shared `data_path` setting rather than repeating the literal path,
# so this cell follows the same configuration as every other loader call.
config, test_dataset = load_caueeg_task_split(dataset_path=data_path,
                                              task='abnormal', split='test', load_event=False)

# Print the config returned by THIS call. The cell previously printed
# `config_data`, a stale variable left over from an earlier cell, which would
# silently show the wrong task (or fail on a fresh kernel run out of order).
pprint.pprint(config)
print('\n', '-' * 100, '\n')
pprint.pprint(test_dataset[0])

{'class_label_to_name': ['Normal', 'Abnormal'],
 'class_name_to_label': {'Abnormal': 1, 'Normal': 0},
 'task_description': 'Classification of [Normal] and [Abnormal] symptoms',
 'task_name': 'CAUEEG-Abnormal benchmark'}

 ---------------------------------------------------------------------------------------------------- 

{'age': 57,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00560',
 'signal': array([[  7.,  41.,  49., ...,   0.,   0.,   0.],
       [ 23.,  26.,  27., ...,   0.,   0.,   0.],
       [-16., -20., -18., ...,   0.,   0.,   0.],
       ...,
       [ 32.,  31.,  31., ...,   0.,   0.,   0.],
       [177., 228., 142., ...,   0.,   0.,   0.],
       [  0.,  -1.,   0., ...,   0.,   0.,   0.]]),
 'symptom': ['normal', 'cb_normal']}


### Load the CAUEEG-Dementia benchmark using the PyTorch dataset instances.

In [10]:
# Load the train/validation/test splits of the CAUEEG-Dementia benchmark from EDF files.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    file_format='edf',
    transform=None,
)
pprint.pprint(config_data)

# Show the first sample of each split, each preceded by a separator line.
for split_dataset in (train_dataset, val_dataset, test_dataset):
    print('\n', '-' * 100, '\n')
    pprint.pprint(split_dataset[0])

{'class_label_to_name': ['Normal', 'MCI', 'Dementia'],
 'class_name_to_label': {'Dementia': 2, 'MCI': 1, 'Normal': 0},
 'task_description': 'Classification of [Normal], [MCI], and [Dementia] '
                     'symptoms.',
 'task_name': 'CAUEEG-Dementia benchmark'}

 ---------------------------------------------------------------------------------------------------- 

{'age': 53,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00587',
 'signal': array([[30., 15., 18., ...,  0.,  0.,  0.],
       [-3.,  4.,  5., ...,  0.,  0.,  0.],
       [-2.,  7.,  8., ...,  0.,  0.,  0.],
       ...,
       [ 0.,  5.,  7., ...,  0.,  0.,  0.],
       [27., 27., 34., ...,  0.,  0.,  0.],
       [ 0.,  0., -1., ...,  0.,  0.,  0.]]),
 'symptom': ['normal', 'cb_normal']}

 ---------------------------------------------------------------------------------------------------- 

{'age': 80,
 'class_label': 2,
 'class_name': 'Dementia',
 'serial': '00341',
 'signal': array([[ -9.,  -2.,  -3., ...

In [11]:
# Count the samples per class in each split of the CAUEEG-Abnormal benchmark.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
    file_format='feather',
    transform=None,
)

# Derive the number of classes from the task config instead of hard-coding it,
# so the same counting code works for any task.
num_classes = len(config_data['class_label_to_name'])

split_counts = {}
for split_name, split_dataset in (('train', train_dataset),
                                  ('val', val_dataset),
                                  ('test', test_dataset)):
    counts = [0] * num_classes
    # NOTE(review): iterating the dataset presumably loads every recording just
    # to read its label, so this loop may be slow on large splits.
    for sample in split_dataset:
        counts[sample['class_label']] += 1
    split_counts[split_name] = counts
    print(split_name, counts, sum(counts))

# Keep the original per-split names available for any later cell that reads them.
num_train, num_val, num_test = (split_counts[name] for name in ('train', 'val', 'test'))

print()
total_counts = [sum(per_class) for per_class in zip(num_train, num_val, num_test)]
print('total', total_counts, sum(total_counts))

train [367, 740] 1107
val [46, 90] 136
test [46, 90] 136

total [459, 920] 1379


In [12]:
# Count the samples per class in each split of the CAUEEG-Dementia benchmark.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    file_format='memmap',
    transform=None,
)

# Derive the number of classes from the task config instead of hard-coding it,
# so the same counting code works for any task.
num_classes = len(config_data['class_label_to_name'])

split_counts = {}
for split_name, split_dataset in (('train', train_dataset),
                                  ('val', val_dataset),
                                  ('test', test_dataset)):
    counts = [0] * num_classes
    # NOTE(review): iterating the dataset presumably loads every recording just
    # to read its label, so this loop may be slow on large splits.
    for sample in split_dataset:
        counts[sample['class_label']] += 1
    split_counts[split_name] = counts
    print(split_name, counts, sum(counts))

# Keep the original per-split names available for any later cell that reads them.
num_train, num_val, num_test = (split_counts[name] for name in ('train', 'val', 'test'))

print()
total_counts = [sum(per_class) for per_class in zip(num_train, num_val, num_test)]
print('total', total_counts, sum(total_counts))

train [367, 334, 249] 950
val [46, 42, 31] 119
test [46, 41, 31] 118

total [459, 417, 311] 1187


### Event information

In [13]:
# Reload the CAUEEG-Abnormal splits with `load_event=True`; the printed sample
# then carries an 'event' entry alongside the signal.
abnormal_splits = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=True,
    file_format='edf',
    transform=None,
)
config_data, train_dataset, val_dataset, test_dataset = abnormal_splits

first_train_sample = train_dataset[0]
pprint.pprint(first_train_sample)

{'age': 77,
 'class_label': 1,
 'class_name': 'Abnormal',
 'event': [[0, 'Start Recording'],
           [0, 'New Montage - Montage 002'],
           [1358, 'Eyes Open'],
           [6146, 'Eyes Closed'],
           [11984, 'Eyes Open'],
           [22232, 'Eyes Closed'],
           [27650, 'Eyes Open'],
           [30170, 'Eyes Closed'],
           [36218, 'Eyes Open'],
           [42098, 'Eyes Closed'],
           [48692, 'Eyes Open'],
           [54236, 'Eyes Closed'],
           [65324, 'Eyes Open'],
           [66668, 'Eyes Closed'],
           [72548, 'Eyes Open'],
           [77408, 'Move'],
           [78848, 'Eyes Closed'],
           [84056, 'Eyes Open'],
           [90272, 'Eyes Closed'],
           [96782, 'Eyes Open'],
           [102326, 'Eyes Closed'],
           [108583, 'Eyes Open'],
           [114170, 'Eyes Closed'],
           [121438, 'Photic On - 3.0 Hz'],
           [123454, 'Photic Off'],
           [125511, 'Photic On - 6.0 Hz'],
           [127528, 'Photic Off'

### Data Format: `EDF`

In [14]:
%%time
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='edf', 
                                                                                  transform=None)

print(train_dataset[0])
print(train_dataset[1])

{'serial': '00587', 'age': 53, 'symptom': ['normal', 'cb_normal'], 'class_name': 'Normal', 'class_label': 0, 'signal': array([[30., 15., 18., ...,  0.,  0.,  0.],
       [-3.,  4.,  5., ...,  0.,  0.,  0.],
       [-2.,  7.,  8., ...,  0.,  0.,  0.],
       ...,
       [ 0.,  5.,  7., ...,  0.,  0.,  0.],
       [27., 27., 34., ...,  0.,  0.,  0.],
       [ 0.,  0., -1., ...,  0.,  0.,  0.]])}
{'serial': '01301', 'age': 88, 'symptom': ['dementia', 'ad', 'load'], 'class_name': 'Dementia', 'class_label': 2, 'signal': array([[ 18.,  15.,  13., ...,   0.,   0.,   0.],
       [ -2.,   1.,   1., ...,   0.,   0.,   0.],
       [  5.,   0.,  -1., ...,   0.,   0.,   0.],
       ...,
       [  2.,  -3.,  -3., ...,   0.,   0.,   0.],
       [ 51., 115., 103., ...,   0.,   0.,   0.],
       [  0.,  -1.,  -1., ...,   0.,   0.,   0.]])}
CPU times: total: 141 ms
Wall time: 147 ms


### Data Format: `NumPy Memmap`

In [15]:
%%time
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap', 
                                                                                  transform=None)

print(train_dataset[0])
print(train_dataset[1])

{'serial': '00587', 'age': 53, 'symptom': ['normal', 'cb_normal'], 'class_name': 'Normal', 'class_label': 0, 'signal': memmap([[ 30,  15,  18, ..., -42, -40, -38],
        [ -3,   4,   5, ..., -16, -14, -13],
        [ -2,   7,   8, ...,  -3,  -7,  -9],
        ...,
        [  0,   5,   7, ..., -10, -12, -12],
        [ 27,  27,  34, ..., -25, -27, -27],
        [  0,   0,  -1, ...,   1,   1,  -1]])}
{'serial': '01301', 'age': 88, 'symptom': ['dementia', 'ad', 'load'], 'class_name': 'Dementia', 'class_label': 2, 'signal': memmap([[ 18,  15,  13, ...,  36,  37,  39],
        [ -2,   1,   1, ...,  20,  19,  13],
        [  5,   0,  -1, ...,  -5,  -3,   1],
        ...,
        [  2,  -3,  -3, ...,   9,   9,  11],
        [ 51, 115, 103, ...,  -6, -14,  -8],
        [  0,  -1,  -1, ...,  -1,   0,  -1]])}
CPU times: total: 0 ns
Wall time: 5 ms


---

## PyTorch Transforms

### Random crop

In [16]:
# Randomly crop each recording down to 100 time steps.
transform = EegRandomCrop(crop_length=100)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    file_format='memmap',
    transform=transform,
)

# Read the SAME recording twice on purpose: the transform is applied on every
# access, so the two printed crops are expected to differ.
for repeat in range(2):
    sample = train_dataset[0]
    pprint.pprint(sample)
    print()
    print('>>> signal shape:', sample['signal'].shape)
    print('\n', '-' * 100, '\n')

{'age': 53,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00587',
 'signal': memmap([[ 56,  53,  53, ...,  45,  39,  40],
        [ 14,  13,  12, ...,   4,   3,   3],
        [  7,   5,   5, ...,   1,   1,   1],
        ...,
        [  1,   2,   2, ...,   1,   0,   0],
        [ 55,  49,  47, ..., -25, -25, -24],
        [ -1,   1,   1, ...,   0,  -1,  -1]]),
 'symptom': ['normal', 'cb_normal']}

>>> signal shape: (21, 100)

 ---------------------------------------------------------------------------------------------------- 

{'age': 53,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00587',
 'signal': memmap([[-49, -46, -42, ..., -51, -56, -57],
        [-12,  -8,  -4, ...,   1,   3,   3],
        [ -4,  -5,  -6, ...,   0,   2,   3],
        ...,
        [  9,   3,  -2, ...,   0,   1,   2],
        [-24, -26, -27, ..., 187, 121,  87],
        [  0,   0,   0, ...,   0,   0,   0]]),
 'symptom': ['normal', 'cb_normal']}

>>> signal shape: (21, 100)

 -----------------

### Random crop with multiple cropping

In [17]:
# Take two random crops of 200 time steps from each recording;
# 'signal' then becomes a list with one array per crop.
transform = EegRandomCrop(crop_length=200, multiple=2)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    file_format='memmap',
    transform=transform,
)

# Read the SAME recording twice on purpose to show that the crops are re-drawn
# on every access.
for repeat in range(2):
    sample = train_dataset[0]
    pprint.pprint(sample)
    print()
    crop_shapes = [crop.shape for crop in sample['signal']]
    print('>>> signal shape:', crop_shapes)
    print('\n', '-' * 100, '\n')

{'age': 53,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00587',
 'signal': [memmap([[ 73,  78,  80, ...,  76,  82,  86],
        [-30, -24, -18, ...,  -3,  -4,  -3],
        [-11, -16, -19, ...,   2,   0,  -2],
        ...,
        [ -3,  -9, -12, ...,   5,   4,   2],
        [-33, -39, -41, ..., -17, -17, -17],
        [  0,  -1,  -1, ...,  -1,   0,  -1]]),
            memmap([[-71, -71, -65, ...,  27,  29,  35],
        [-16, -16, -14, ...,   7,   5,   4],
        [-10, -11, -10, ...,   7,   5,   5],
        ...,
        [  7,   4,   3, ...,   5,   7,   7],
        [ -1,  -5,  -5, ..., -27, -28, -29],
        [ -1,  -1,   0, ...,  -1,  -1,  -1]])],
 'symptom': ['normal', 'cb_normal']}

>>> signal shape: [(21, 200), (21, 200)]

 ---------------------------------------------------------------------------------------------------- 

{'age': 53,
 'class_label': 0,
 'class_name': 'Normal',
 'serial': '00587',
 'signal': [memmap([[  30,   35,   39, ...,   36,   32,   28],
      

### Random crop with multiple cropping and latency

In [18]:
# Three random crops of 300 time steps per recording, with a latency of 50000.
# `return_timing=True` adds a 'crop_timing' entry to each sample.
transform = EegRandomCrop(
    crop_length=300,
    multiple=3,
    latency=50000,
    return_timing=True,
)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    file_format='memmap',
    transform=transform,
)

# Read the SAME recording twice on purpose to show that the crops are re-drawn
# on every access.
for repeat in range(2):
    sample = train_dataset[0]
    pprint.pprint(sample)
    print()
    crop_shapes = [crop.shape for crop in sample['signal']]
    print('>>> signal shape:', crop_shapes)
    print('\n', '-' * 100, '\n')

{'age': 53,
 'class_label': 0,
 'class_name': 'Normal',
 'crop_timing': [54875, 101594, 75938],
 'serial': '00587',
 'signal': [memmap([[  7,   7,   8, ...,  50,  52,  51],
        [  6,   5,   6, ...,   3,   5,   8],
        [  7,   9,   9, ...,   5,   6,   6],
        ...,
        [  6,   7,   7, ...,   2,   2,   1],
        [-18, -17, -16, ...,   5,   2,  -1],
        [  2,  -1,   1, ...,   1,  -1,   1]]),
            memmap([[-14, -11,  -6, ...,  54,  64,  75],
        [-16, -13, -11, ...,  -1,   0,   1],
        [  0,   1,   1, ...,   1,  -1,  -2],
        ...,
        [ -1,   0,   0, ...,   1,   0,   0],
        [ -5,  -6,  -5, ...,   2,   0,  -2],
        [  1,   0,   0, ...,  -1,   0,   0]]),
            memmap([[132, 138, 141, ...,  68,  60,  47],
        [ 30,  30,  29, ...,  13,  11,  11],
        [ -3,  -2,  -2, ...,   0,  -1,  -3],
        ...,
        [-20, -20, -18, ...,  -3,  -3,  -3],
        [100, 103,  93, ...,   2,  -1,  -3],
        [ -1,  -1,   2, ...,  -1,  -1,  

### Random crop with multiple cropping, latency, and max length limit

In [19]:
# Three random crops of 200 time steps with latency 50000 and a length limit of 50300.
# NOTE(review): presumably latency and length_limit bound where a crop may start
# and end -- confirm against EegRandomCrop.
random_crop = EegRandomCrop(
    crop_length=200,
    length_limit=50300,
    multiple=3,
    latency=50000,
    return_timing=True,
)
transform = transforms.Compose([random_crop])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
    file_format='memmap',
    transform=transform,
)

# Read the SAME recording twice on purpose to show that the crops are re-drawn
# on every access.
for repeat in range(2):
    sample = train_dataset[0]
    pprint.pprint(sample)
    print()
    crop_shapes = [crop.shape for crop in sample['signal']]
    print('>>> signal shape:', crop_shapes)
    print('\n', '-' * 100, '\n')

{'age': 77,
 'class_label': 1,
 'class_name': 'Abnormal',
 'crop_timing': [50032, 50010, 50008],
 'serial': '01258',
 'signal': [memmap([[  35,   42,   52, ...,   41,   33,   24],
        [  54,   49,   43, ...,   33,   14,   -3],
        [ -23,  -24,  -23, ...,  -46,  -47,  -46],
        ...,
        [  -2,   -2,   -2, ...,   -2,   -2,   -5],
        [-384, -259,  -58, ...,   45,   50,   67],
        [   0,   -1,   -1, ...,   -2,   -3,   -1]]),
            memmap([[  11,    8,   14, ...,   23,   22,   18],
        [  13,   32,   41, ...,   23,   21,   19],
        [ -17,  -18,  -20, ...,  -42,  -41,  -42],
        ...,
        [   1,   -2,   -4, ...,   -3,   -5,   -6],
        [  53,   64,   48, ..., -424, -264,  -42],
        [   2,    3,   -1, ...,   -1,    2,    2]]),
            memmap([[   9,   10,   11, ...,    2,   12,   23],
        [  60,   38,   13, ...,   18,   32,   23],
        [ -21,  -20,  -17, ...,  -40,  -43,  -42],
        ...,
        [  -5,   -1,    1, ...,    3,  

### Random crop with event rejection

In [104]:
# 256 random crops of 2000 time steps per recording, with event rejection enabled.
# The datasets are loaded with `load_event=True` so event annotations are available.
# NOTE(review): presumably reject_events makes the crop avoid annotated events --
# confirm against EegRandomCrop.
random_crop = EegRandomCrop(
    crop_length=2000,
    multiple=256,
    latency=10000,
    return_timing=True,
    reject_events=True,
)
transform = transforms.Compose([random_crop])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=True,
    file_format='memmap',
    transform=transform,
)

# Unlike the earlier crop demos, this one prints the first two DIFFERENT recordings.
for sample_index in range(2):
    sample = train_dataset[sample_index]
    pprint.pprint(sample)
    print()
    crop_shapes = [crop.shape for crop in sample['signal']]
    print('>>> signal shape:', crop_shapes)
    print('\n', '-' * 100, '\n')

{'age': 77,
 'class_label': 1,
 'class_name': 'Abnormal',
 'crop_timing': [69599,
                 57216,
                 22531,
                 73052,
                 180464,
                 22536,
                 178740,
                 86541,
                 85152,
                 68598,
                 117913,
                 15770,
                 79671,
                 115948,
                 116904,
                 110014,
                 118749,
                 72567,
                 99756,
                 36381,
                 25634,
                 81190,
                 67240,
                 116908,
                 114917,
                 30430,
                 110539,
                 161118,
                 61649,
                 177592,
                 51520,
                 58402,
                 68691,
                 24901,
                 54264,
                 57419,
                 22458,
                 59173,
                 2

In [109]:
# Scratch check of negative slicing: a start of -5 + 1 == -4 keeps the last four elements.
np.arange(20)[-4:].shape

(4,)

### Drop channel(s)

In [None]:
# Read the channel names from the annotation file and look up the indices of
# the two non-EEG channels (EKG and Photic).
anno_path = os.path.join(data_path, 'annotation.json')
with open(anno_path, 'r') as json_file:
    # Keep only the header list; the rest of the annotation is not needed here.
    signal_headers = json.load(json_file)['signal_header']
print(signal_headers)

channel_ekg = signal_headers.index('EKG')
channel_photic = signal_headers.index('Photic')

print('channel_ekg: ', channel_ekg)
print('channel_photic: ', channel_photic)

In [None]:
# Compare the first training signal before and after dropping the EKG channel.
load_kwargs = dict(dataset_path=data_path,
                   task='abnormal',
                   load_event=False,
                   file_format='memmap')

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(transform=None, **load_kwargs)
signal_before = train_dataset[0]['signal']
print('before:', signal_before.shape)
print(signal_before)

print()
print('-' * 100)
print()

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    transform=EegDropChannels(channel_ekg), **load_kwargs)
signal_after = train_dataset[0]['signal']
print('after:', signal_after.shape)
print(signal_after)

In [None]:
# Compare the first training signal before and after dropping the Photic channel.
load_kwargs = dict(dataset_path=data_path,
                   task='abnormal',
                   load_event=False,
                   file_format='memmap')

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(transform=None, **load_kwargs)
signal_before = train_dataset[0]['signal']
print('before:', signal_before.shape)
print(signal_before)

print()
print('-' * 100)
print()

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    transform=EegDropChannels(channel_photic), **load_kwargs)
signal_after = train_dataset[0]['signal']
print('after:', signal_after.shape)
print(signal_after)

In [None]:
# Compare the first training signal before and after dropping both the EKG and
# Photic channels in a single transform.
load_kwargs = dict(dataset_path=data_path,
                   task='abnormal',
                   load_event=False,
                   file_format='memmap')

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(transform=None, **load_kwargs)
signal_before = train_dataset[0]['signal']
print('before:', signal_before.shape)
print(signal_before)

print()
print('-' * 100)
print()

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    transform=EegDropChannels([channel_ekg, channel_photic]), **load_kwargs)
signal_after = train_dataset[0]['signal']
print('after:', signal_after.shape)
print(signal_after)

### To Tensor

In [None]:
# Show the first sample of the full dataset before and after applying EegToTensor.
full_kwargs = dict(dataset_path=data_path,
                   load_event=False,
                   file_format='memmap')

config_data, full_eeg_dataset = load_caueeg_full_dataset(transform=None, **full_kwargs)
print('Before:')
pprint.pprint(full_eeg_dataset[0])

print()
print('-' * 100)
print()

config_data, full_eeg_dataset = load_caueeg_full_dataset(transform=EegToTensor(), **full_kwargs)
print('After:')
pprint.pprint(full_eeg_dataset[0])

### Compose all of the above into one pipeline

In [None]:
# Chain the transforms shown so far: random crop -> drop the Photic channel -> to tensor.
random_crop = EegRandomCrop(
    crop_length=200*10,       # crop: 10s
    length_limit=200*60*10,   # length: 10m
    multiple=4,
    latency=200*10,           # latency: 10s
)
transform = transforms.Compose([
    random_crop,
    EegDropChannels(channel_photic),
    EegToTensor(),
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
    file_format='memmap',
    transform=transform,
)

first_train_sample = train_dataset[0]
pprint.pprint(first_train_sample)

---

## PyTorch DataLoader

In [None]:
# DataLoader settings shared by every loader below.
num_workers = 0  # A number other than 0 causes an error
# Pinned host memory is requested only when the selected device is CUDA.
pin_memory = device.type == 'cuda'

In [None]:
# Wrap the full dataset in a DataLoader (batches of 4) and print the first
# collated batch.
transform = transforms.Compose([
    EegRandomCrop(
        crop_length=200*10,       # crop: 10s
        length_limit=200*60*10,   # length: 10m
        multiple=2,
        latency=200*10,           # latency: 10s
    ),
    EegDropChannels(channel_photic),
    EegToTensor(),
])

config_data, full_eeg_dataset = load_caueeg_full_dataset(
    dataset_path=data_path,
    load_event=False,
    file_format='memmap',
    transform=transform,
)

full_loader = DataLoader(
    full_eeg_dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)

# Only the first batch is inspected.
for i_batch, sample_batched in enumerate(full_loader):
    pprint.pprint(sample_batched)
    break

In [None]:
# Same pipeline, this time on the 'abnormal' task splits: build a training
# DataLoader (batches of 8) and print the first collated batch.
transform = transforms.Compose([
    EegRandomCrop(
        crop_length=200*10,       # crop: 10s
        length_limit=200*60*10,   # length: 10m
        multiple=2,
        latency=200*10,           # latency: 10s
    ),
    EegDropChannels(channel_photic),
    EegToTensor(),
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
    file_format='memmap',
    transform=transform,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)

# Only the first batch is inspected; the wide width keeps each entry on one line.
for i_batch, sample_batched in enumerate(train_loader):
    pprint.pprint(sample_batched, width=250)
    break

---

## Preprocessing steps run by the PyTorch Modules

In [None]:
# Build the 'dementia' task splits and a small training loader (batches of 2).
# The preprocessing demos in the following cells all draw from `train_loader`.
transform = transforms.Compose([
    EegRandomCrop(
        crop_length=200*10,       # crop: 10s
        length_limit=200*60*10,   # length: 10m
        multiple=2,
        latency=200*10,           # latency: 10s
    ),
    EegDropChannels(channel_photic),
    EegToTensor(),
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    file_format='memmap',
    transform=transform,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)

### Move to the GPU device if available

In [None]:
# Demonstrate EegToDevice: print one batch before and after the module runs.
print('device:', device)
print()

# The transform list is re-wrapped as an nn.Sequential so it can be called
# directly on a batch.
preprocess_train = transforms.Compose([
    EegToDevice(device=device),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    pprint.pprint(sample_batched)

    print('\n' + '-' * 100 + '\n')

    # The return value is ignored: this cell relies on the module updating
    # `sample_batched` in place.
    preprocess_train(sample_batched)

    print('- After -')
    pprint.pprint(sample_batched)
    break

### Normalization per signal

In [None]:
# Demonstrate EegNormalizePerSignal: compare the mean/std along the last axis
# of one batch's signal before and after the module runs.
preprocess_train = transforms.Compose([
    EegToDevice(device=device),
    EegNormalizePerSignal(),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    print('Mean:', torch.mean(sample_batched['signal'], axis=-1))
    print()
    print('Std:', torch.std(sample_batched['signal'], axis=-1))

    print('\n' + '-' * 100 + '\n')

    # Return value ignored: the batch dict is expected to be updated in place.
    preprocess_train(sample_batched)

    print('- After -')
    print('Mean:', torch.mean(sample_batched['signal'], axis=-1))
    print()
    print('Std:', torch.std(sample_batched['signal'], axis=-1))
    break

### Signal normalization using the specified mean and std values

In [None]:
# Compute signal statistics from the training loader, then demonstrate
# EegNormalizeMeanStd using them. `signal_mean` / `signal_std` are reused by
# later cells.
signal_mean, signal_std = calculate_signal_statistics(train_loader, repeats=1, verbose=True)

preprocess_train = transforms.Compose([
    EegToDevice(device=device),
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    print('Mean:', torch.mean(sample_batched['signal'], axis=-1))
    print()
    print('Std:', torch.std(sample_batched['signal'], axis=-1))

    print('\n' + '-' * 100 + '\n')

    # Return value ignored: the batch dict is expected to be updated in place.
    preprocess_train(sample_batched)

    print('- After -')
    print('Mean:', torch.mean(sample_batched['signal'], axis=-1))
    print()
    print('Std:', torch.std(sample_batched['signal'], axis=-1))
    break

### Age normalization

In [None]:
# Compute age statistics from the training loader, then demonstrate
# EegNormalizeAge using them. `age_mean` / `age_std` are reused by later cells.
age_mean, age_std = calculate_age_statistics(train_loader, verbose=True)

preprocess_train = transforms.Compose([
    EegToDevice(device=device),
    EegNormalizeAge(mean=age_mean, std=age_std),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    pprint.pprint(sample_batched['age'])

    print('\n' + '-' * 100 + '\n')

    # Return value ignored: the batch dict is expected to be updated in place.
    preprocess_train(sample_batched)

    print('- After -')
    pprint.pprint(sample_batched['age'])
    break

### Short time Fourier transform (STFT or spectrogram)

In [None]:
# Demonstrate EegSpectrogram: print the signal shape of one batch before and
# after the module runs.
preprocess_train = transforms.Compose([
    EegToDevice(device=device),
    EegSpectrogram(n_fft=200, complex_mode='as_real'),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    print('- Before -')
    pprint.pprint(sample_batched['signal'].shape)

    print('\n' + '-' * 100 + '\n')

    # Return value ignored: the batch dict is expected to be updated in place.
    preprocess_train(sample_batched)

    print('- After -')
    pprint.pprint(sample_batched['signal'].shape)
    break

### Signal normalization after STFT

In [None]:
# Two-stage demo: stage 1 moves the batch to the device and applies the STFT;
# stage 2 normalizes the result with statistics computed through stage 1.
preprocess_train = transforms.Compose([
    EegToDevice(device=device),
    EegSpectrogram(n_fft=200, complex_mode='as_real'),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)

# Statistics are gathered with the stage-1 modules passed alongside the loader.
signal_2d_mean, signal_2d_std = calculate_signal_statistics(train_loader, preprocess_train)

preprocess_train2 = transforms.Compose([
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
])
preprocess_train2 = torch.nn.Sequential(*preprocess_train2.transforms).to(device)

pprint.pprint(preprocess_train)
pprint.pprint(preprocess_train2)

for i_batch, sample_batched in enumerate(train_loader):
    # "Before" here means after the STFT but before normalization.
    print('- Before -')
    preprocess_train(sample_batched)

    print('Mean:', torch.mean(sample_batched['signal'], axis=-1))
    print()
    print('Std:', torch.std(sample_batched['signal'], axis=-1))

    print('\n' + '-' * 100 + '\n')

    print('- After -')
    preprocess_train2(sample_batched)

    print('Mean:', torch.mean(sample_batched['signal'], axis=-1))
    print()
    print('Std:', torch.std(sample_batched['signal'], axis=-1))
    break

---

## Speed check without STFT

In [None]:
# Shared settings for the three timing cells below.
crop_length, multiple, batch_size = 200 * 10, 4, 128

### `EDF`

In [None]:
%%time
transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='edf',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

### `memmap`

In [None]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)
 
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

### `memmap` (Drop → Crop)

In [None]:
%%time

transform = transforms.Compose([
    EegDropChannels(channel_photic), 
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)
 
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std)
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)

---

## Speed check with STFT

In [None]:
# Shared settings for the STFT timing cells below.
crop_length = 300 * 10
multiple = 2
batch_size = 128
n_fft, hop_length, seq_len_2d = calculate_stft_params(seq_length=crop_length, verbose=True)

# Spectrogram statistics, computed through a device-move + STFT stage.
# NOTE(review): `train_loader` here is whatever the previously executed cell
# left behind (in run-all order, a loader built with the earlier
# crop_length of 200 * 10), so these statistics are not gathered on crops of
# the `crop_length` set above -- confirm this is intended.
preprocess_train = transforms.Compose([
    EegToDevice(device=device),
    EegSpectrogram(n_fft=n_fft, hop_length=hop_length, complex_mode='as_real'),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
signal_2d_mean, signal_2d_std = calculate_signal_statistics(train_loader, preprocess_train)

### `EDF`

In [None]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='edf',
                                                                                  transform=transform)
 
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegSpectrogram(n_fft=n_fft, hop_length=hop_length, complex_mode='as_real'),
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)
    size = sample_batched['signal'].size()
    
print(size)

### `memmap`

In [None]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegSpectrogram(n_fft=n_fft, hop_length=hop_length, complex_mode='as_real'),
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)
    size = sample_batched['signal'].size()
    
print(size)

---

## Test on longer sequence

In [None]:
%%time
longer_transform = transforms.Compose([
    EegRandomCrop(crop_length=200*10*6,     # crop: 1m
                  length_limit=200*60*10,   # length: 10m
                  multiple=2, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(longer_transform)

config_data, longer_test_dataset = load_caueeg_task_split(dataset_path=data_path, 
                                                          task='dementia', 
                                                          split='test',
                                                          load_event=False,
                                                          file_format='memmap', 
                                                          transform=longer_transform)

longer_test_loader = DataLoader(longer_test_dataset,
                                batch_size=32,
                                shuffle=True,
                                drop_last=False,
                                num_workers=num_workers,
                                pin_memory=pin_memory,
                                collate_fn=eeg_collate_fn)
 
preprocess_test = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
    EegNormalizeAge(mean=age_mean, std=age_std),
])
preprocess_test = torch.nn.Sequential(*preprocess_test.transforms).to(device)
pprint.pprint(preprocess_test)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_test(sample_batched)

In [None]:
---

## Resampling

In [None]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=200*200,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train1 = transforms.Compose([
    EegToDevice(device=device), 
    EegResample(orig_freq=200, new_freq=250, resampling_method='kaiser_best'),
    EegResample(orig_freq=250, new_freq=200, resampling_method='kaiser_best'),
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
])
preprocess_train1 = torch.nn.Sequential(*preprocess_train1.transforms).to(device)
pprint.pprint(preprocess_train1)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train2 = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
])
preprocess_train2 = torch.nn.Sequential(*preprocess_train2.transforms).to(device)
pprint.pprint(preprocess_train2)

diff = 0.0
for e in range(5):
    for i_batch, sample_batched in enumerate(train_loader):
        from copy import deepcopy
        sb1 = deepcopy(sample_batched)
        sb2 = deepcopy(sample_batched)

        preprocess_train1(sb1)
        preprocess_train2(sb2)
        
        diff += (torch.norm(sb1['signal'] - sb2['signal']) / torch.sqrt(torch.norm(sb1['signal'])) / torch.sqrt(torch.norm(sb1['signal']))).item()
        
print(diff)

In [None]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=200*200,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train1 = transforms.Compose([
    EegToDevice(device=device), 
    EegResample(orig_freq=200, new_freq=250, resampling_method='kaiser_fast'),
    EegResample(orig_freq=250, new_freq=200, resampling_method='kaiser_fast'),
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
])
preprocess_train1 = torch.nn.Sequential(*preprocess_train1.transforms).to(device)
pprint.pprint(preprocess_train1)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train2 = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
])
preprocess_train2 = torch.nn.Sequential(*preprocess_train2.transforms).to(device)
pprint.pprint(preprocess_train2)

diff = 0.0
for e in range(5):
    for i_batch, sample_batched in enumerate(train_loader):
        from copy import deepcopy
        sb1 = deepcopy(sample_batched)
        sb2 = deepcopy(sample_batched)

        preprocess_train1(sb1)
        preprocess_train2(sb2)
        
        diff += (torch.norm(sb1['signal'] - sb2['signal']) / torch.sqrt(torch.norm(sb1['signal'])) / torch.sqrt(torch.norm(sb1['signal']))).item()
        
print(diff)

In [None]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=200*200,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train1 = transforms.Compose([
    EegToDevice(device=device), 
    EegResample(orig_freq=200, new_freq=250),
    EegResample(orig_freq=250, new_freq=200),
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
])
preprocess_train1 = torch.nn.Sequential(*preprocess_train1.transforms).to(device)
pprint.pprint(preprocess_train1)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train2 = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
])
preprocess_train2 = torch.nn.Sequential(*preprocess_train2.transforms).to(device)
pprint.pprint(preprocess_train2)

diff = 0.0
for e in range(5):
    for i_batch, sample_batched in enumerate(train_loader):
        from copy import deepcopy
        sb1 = deepcopy(sample_batched)
        sb2 = deepcopy(sample_batched)

        preprocess_train1(sb1)
        preprocess_train2(sb2)
        
        diff += (torch.norm(sb1['signal'] - sb2['signal']) / torch.sqrt(torch.norm(sb1['signal'])) / torch.sqrt(torch.norm(sb1['signal']))).item()
        
print(diff)