# Dataset and DataLoader

This notebook loads the `CAUEEG` dataset, demonstrates several useful preprocessing steps, and builds the PyTorch DataLoader instances used for training.

-----

## Configurations

In [None]:
# Automatically reload external (project) modules when their source changes,
# so edits to imported modules are picked up without restarting the kernel.
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
# Move to the parent directory; presumably so that the project-local `datasets`
# package and the relative `local/dataset/...` path used below resolve.
# NOTE(review): `%cd ..` is not idempotent — re-running this cell moves up again.
%cd ..

In [None]:
# Load some packages
import os
import json
import pprint

import numpy as np
import random
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

# custom package
from datasets.caueeg_dataset import *
from datasets.caueeg_script import *
from datasets.pipeline import *

In [None]:
# Report the PyTorch build and pick the compute device (GPU when available).
print('PyTorch version:', torch.__version__)
cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')

print('cuda is available.' if cuda_available else 'cuda is unavailable.')

In [None]:
# Root directory of the CAUEEG dataset, relative to the current working directory.
data_path = 'local/dataset/caueeg-dataset/'

In [None]:
# Print a compact preview of each task JSON file: every top-level key and its
# value, where lists longer than three items are abbreviated to the first two
# items, three dot lines, and the last item.
for task_file in ('annotation.json', 'abnormal.json', 'dementia.json'):
    with open(os.path.join(data_path, task_file), 'r') as json_file:
        task_dict = json.load(json_file)

    print('{')
    for key, value in task_dict.items():
        print(f'\t{key}:')
        if not (isinstance(value, list) and len(value) > 3):
            print(f'\t\t{value}')
        else:
            for head_item in (value[0], value[1]):
                print(f'\t\t{head_item}')
            for _ in range(3):
                print('\t\t.')
            print(f'\t\t{value[-1]}')
        print()
    print('}')
    print('\n' + '-' * 100 + '\n')

In [None]:
# Tally how often each diagnosis class and each symptom occurs across all three
# splits of the CAUEEG-Abnormal task. Records whose symptoms include
# 'parkinson_dementia' still count toward their class but their symptoms are skipped.
task_path = os.path.join(data_path, 'abnormal.json')
with open(task_path, 'r') as json_file:
    task_dict = json.load(json_file)

diagnosis_dict = {}
symptom_dict = {}
for split in ('train_split', 'validation_split', 'test_split'):
    for record in task_dict[split]:
        class_name = record['class_name']
        diagnosis_dict[class_name] = diagnosis_dict.get(class_name, 0) + 1

        if 'parkinson_dementia' not in record['symptom']:
            for symp in record['symptom']:
                symptom_dict[symp] = symptom_dict.get(symp, 0) + 1

pprint.pprint(diagnosis_dict)
pprint.pprint(symptom_dict)

-----

## Load the CAUEEG dataset

### Load the whole CAUEEG dataset as a single PyTorch dataset instance, independent of any target task (no train/val/test split and no class labels).

In [None]:
# Load every EEG record as one task-agnostic dataset and preview the
# configuration together with the first two samples.
config_data, full_eeg_dataset = load_caueeg_full_dataset(
    dataset_path=data_path,
    load_event=False,
    file_format='edf',
    transform=None,
)

separator = '-' * 100
pprint.pprint(config_data, width=250)
print('\n', separator, '\n')

for index in range(2):
    pprint.pprint(full_eeg_dataset[index])
    print('\n', separator, '\n')

### Load the CAUEEG-Abnormal benchmark as PyTorch dataset instances.

In [None]:
# Load the train/validation/test splits of CAUEEG-Abnormal and show the
# configuration plus the first sample of each split.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
)
pprint.pprint(config_data)

for split_dataset in (train_dataset, val_dataset, test_dataset):
    print('\n', '-' * 100, '\n')
    pprint.pprint(split_dataset[0])

In [None]:
# Load only the test split of CAUEEG-Abnormal and preview its configuration and
# first sample. Print the configuration returned by THIS call (`config`), not a
# `config_data` left over from an earlier cell, and reuse the shared `data_path`.
config, test_dataset = load_caueeg_task_split(dataset_path=data_path,
                                              task='abnormal', split='test', load_event=False)

pprint.pprint(config)
print('\n', '-' * 100, '\n')
pprint.pprint(test_dataset[0])

### Load the CAUEEG-Dementia benchmark as PyTorch dataset instances.

In [None]:
# Load the train/validation/test splits of CAUEEG-Dementia (EDF files, no
# transform) and show the configuration plus the first sample of each split.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    file_format='edf',
    transform=None,
)
pprint.pprint(config_data)

for split_dataset in (train_dataset, val_dataset, test_dataset):
    print('\n', '-' * 100, '\n')
    pprint.pprint(split_dataset[0])

In [None]:
# Count the samples of each class (2 classes for CAUEEG-Abnormal) per split and
# overall.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
    file_format='memmap',
    transform=None,
)

split_counts = {}
for split_name, split_dataset in (('train', train_dataset), ('val', val_dataset), ('test', test_dataset)):
    counts = [0] * 2
    for sample in split_dataset:
        counts[sample['class_label']] += 1
    split_counts[split_name] = counts
    print(split_name, counts, sum(counts))

print()
per_class_total = [sum(column) for column in zip(*split_counts.values())]
print('total', per_class_total, sum(per_class_total))

In [None]:
# Count the samples of each class (3 classes for CAUEEG-Dementia) per split and
# overall.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    file_format='memmap',
    transform=None,
)

split_counts = {}
for split_name, split_dataset in (('train', train_dataset), ('val', val_dataset), ('test', test_dataset)):
    counts = [0] * 3
    for sample in split_dataset:
        counts[sample['class_label']] += 1
    split_counts[split_name] = counts
    print(split_name, counts, sum(counts))

print()
per_class_total = [sum(column) for column in zip(*split_counts.values())]
print('total', per_class_total, sum(per_class_total))

### Event information

In [None]:
# Load CAUEEG-Abnormal with `load_event=True` and inspect the first training sample.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=True,
    file_format='edf',
    transform=None,
)
pprint.pprint(train_dataset[0])

### Data Format: `EDF`

In [None]:
%%time
# Load CAUEEG-Dementia from EDF files and read the first two training samples.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path, task='dementia', load_event=False,
    file_format='edf', transform=None)

for index in (0, 1):
    print(train_dataset[index])

### Data Format: `NumPy Memmap`

In [None]:
%%time
# Load CAUEEG-Dementia from NumPy memmap files and read the first two training samples.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path, task='dementia', load_event=False,
    file_format='memmap', transform=None)

for index in (0, 1):
    print(train_dataset[index])

---

## PyTorch Transforms

### Random crop

In [None]:
# Random crop of length 100. Sample 0 is fetched twice, presumably so the two
# printouts show that the crop position changes between calls.
transform = EegRandomCrop(crop_length=100)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    file_format='memmap',
    transform=transform,
)
for _ in range(2):
    sample = train_dataset[0]
    pprint.pprint(sample)
    print()
    print('>>> signal shape:', sample['signal'].shape)
    print('\n', '-' * 100, '\n')

### Random crop with multiple crops

In [None]:
# Random crops of length 200 with `multiple=2`; the signal field is iterated as
# a collection of crops. Sample 0 is fetched twice to compare draws.
transform = EegRandomCrop(crop_length=200, multiple=2)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    file_format='memmap',
    transform=transform,
)
for _ in range(2):
    sample = train_dataset[0]
    pprint.pprint(sample)
    print()
    print('>>> signal shape:', [crop.shape for crop in sample['signal']])
    print('\n', '-' * 100, '\n')

### Random crop with multiple crops and latency

In [None]:
# Three random crops of length 300 with a latency of 50000 samples, also
# returning crop timing. Sample 0 is fetched twice to compare draws.
transform = EegRandomCrop(crop_length=300, multiple=3, latency=50000, return_timing=True)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    file_format='memmap',
    transform=transform,
)
for _ in range(2):
    sample = train_dataset[0]
    pprint.pprint(sample)
    print()
    print('>>> signal shape:', [crop.shape for crop in sample['signal']])
    print('\n', '-' * 100, '\n')

### Random crop with multiple crops, latency, and a maximum length limit

In [None]:
# Three random crops of length 200 after a 50000-sample latency, with the
# usable signal capped by `length_limit=50300`.
crop_step = EegRandomCrop(crop_length=200,
                          length_limit=50300,
                          multiple=3,
                          latency=50000,
                          return_timing=True)
transform = transforms.Compose([crop_step])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
    file_format='memmap',
    transform=transform,
)
for _ in range(2):
    sample = train_dataset[0]
    pprint.pprint(sample)
    print()
    print('>>> signal shape:', [crop.shape for crop in sample['signal']])
    print('\n', '-' * 100, '\n')

### Random crop with event rejection

In [None]:
# Eight random crops of length 8000 with `reject_events=True`; events are loaded
# (`load_event=True`) so the crop step has them available. The first two
# training samples are shown.
crop_step = EegRandomCrop(crop_length=8000,
                          multiple=8,
                          latency=10000,
                          return_timing=True,
                          reject_events=True)
transform = transforms.Compose([crop_step])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=True,
    file_format='memmap',
    transform=transform,
)
for index in range(2):
    sample = train_dataset[index]
    pprint.pprint(sample)
    print()
    print('>>> signal shape:', [crop.shape for crop in sample['signal']])
    print('\n', '-' * 100, '\n')

### Drop channel(s)

In [None]:
# Read the signal channel names from annotation.json and locate the EKG and
# Photic channels so they can be dropped later.
with open(os.path.join(data_path, 'annotation.json'), 'r') as json_file:
    signal_headers = json.load(json_file)['signal_header']
print(signal_headers)

channel_ekg = signal_headers.index('EKG')
print('channel_ekg: ', channel_ekg)

channel_photic = signal_headers.index('Photic')
print('channel_photic: ', channel_photic)

In [None]:
# Compare the first training signal before and after dropping the EKG channel.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path, task='abnormal', load_event=False,
    file_format='memmap', transform=None)
signal_before = train_dataset[0]['signal']
print('before:', signal_before.shape)
print(signal_before)

print()
print('-' * 100)
print()
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path, task='abnormal', load_event=False,
    file_format='memmap', transform=EegDropChannels(channel_ekg))
signal_after = train_dataset[0]['signal']
print('after:', signal_after.shape)
print(signal_after)

In [None]:
# Compare the first training signal before and after dropping the Photic channel.
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path, task='abnormal', load_event=False,
    file_format='memmap', transform=None)
signal_before = train_dataset[0]['signal']
print('before:', signal_before.shape)
print(signal_before)

print()
print('-' * 100)
print()
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path, task='abnormal', load_event=False,
    file_format='memmap', transform=EegDropChannels(channel_photic))
signal_after = train_dataset[0]['signal']
print('after:', signal_after.shape)
print(signal_after)

In [None]:
# Compare the first training signal before and after dropping both the EKG and
# Photic channels at once (EegDropChannels also accepts a list of indices).
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path, task='abnormal', load_event=False,
    file_format='memmap', transform=None)
signal_before = train_dataset[0]['signal']
print('before:', signal_before.shape)
print(signal_before)

print()
print('-' * 100)
print()
config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path, task='abnormal', load_event=False,
    file_format='memmap', transform=EegDropChannels([channel_ekg, channel_photic]))
signal_after = train_dataset[0]['signal']
print('after:', signal_after.shape)
print(signal_after)

### To Tensor

In [None]:
# Show the first full-dataset sample without a transform, then with EegToTensor.
separator = '-' * 100

config_data, full_eeg_dataset = load_caueeg_full_dataset(
    dataset_path=data_path, load_event=False, file_format='memmap', transform=None)
print('Before:')
pprint.pprint(full_eeg_dataset[0])

print()
print(separator)
print()

config_data, full_eeg_dataset = load_caueeg_full_dataset(
    dataset_path=data_path, load_event=False, file_format='memmap', transform=EegToTensor())
print('After:')
pprint.pprint(full_eeg_dataset[0])

### Compose the above all in one

In [None]:
# Chain crop -> drop Photic channel -> tensor conversion into one transform.
transform = transforms.Compose([
    EegRandomCrop(crop_length=200 * 10,        # crop: 10s
                  length_limit=200 * 60 * 10,  # length: 10m
                  multiple=4,
                  latency=200 * 10),           # latency: 10s
    EegDropChannels(channel_photic),
    EegToTensor(),
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
    file_format='memmap',
    transform=transform,
)

pprint.pprint(train_dataset[0])

---

## PyTorch DataLoader

In [None]:
# DataLoader settings shared by the cells below. Worker processes stay disabled
# on every device (a non-zero value was reported to cause an error); pinned
# host memory is enabled only when the target device is CUDA.
num_workers = 0
pin_memory = device.type == 'cuda'

In [None]:
# Batch the task-agnostic full dataset and show the first collated batch.
transform = transforms.Compose([
    EegRandomCrop(crop_length=200 * 10,        # crop: 10s
                  length_limit=200 * 60 * 10,  # length: 10m
                  multiple=2,
                  latency=200 * 10),           # latency: 10s
    EegDropChannels(channel_photic),
    EegToTensor(),
])

config_data, full_eeg_dataset = load_caueeg_full_dataset(
    dataset_path=data_path,
    load_event=False,
    file_format='memmap',
    transform=transform,
)

full_loader = DataLoader(
    full_eeg_dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)

for sample_batched in full_loader:
    pprint.pprint(sample_batched)
    break

In [None]:
# Batch the CAUEEG-Abnormal training split and show the first collated batch.
transform = transforms.Compose([
    EegRandomCrop(crop_length=200 * 10,        # crop: 10s
                  length_limit=200 * 60 * 10,  # length: 10m
                  multiple=2,
                  latency=200 * 10),           # latency: 10s
    EegDropChannels(channel_photic),
    EegToTensor(),
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='abnormal',
    load_event=False,
    file_format='memmap',
    transform=transform,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)

for sample_batched in train_loader:
    pprint.pprint(sample_batched, width=250)
    break

---

## Preprocessing steps run as PyTorch modules

In [None]:
# Small-batch CAUEEG-Dementia training loader reused by the preprocessing demos below.
transform = transforms.Compose([
    EegRandomCrop(crop_length=200 * 10,        # crop: 10s
                  length_limit=200 * 60 * 10,  # length: 10m
                  multiple=2,
                  latency=200 * 10),           # latency: 10s
    EegDropChannels(channel_photic),
    EegToTensor(),
])

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    file_format='memmap',
    transform=transform,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)

### Move to the GPU device if available

In [None]:
print('device:', device)
print()

# A single-step module pipeline that moves a batch onto `device`.
preprocess_train = torch.nn.Sequential(EegToDevice(device=device)).to(device)
pprint.pprint(preprocess_train)

for sample_batched in train_loader:
    print('- Before -')
    pprint.pprint(sample_batched)

    print()
    print('-' * 100)
    print()

    # The return value is ignored, so the batch dict is presumably updated in place.
    preprocess_train(sample_batched)

    print('- After -')
    pprint.pprint(sample_batched)
    break

### Normalization per signal

In [None]:
# Per-signal normalization: print the mean/std along the last axis of one batch
# before and after running the pipeline.
preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegNormalizePerSignal(),
).to(device)
pprint.pprint(preprocess_train)

for sample_batched in train_loader:
    print('- Before -')
    print('Mean:', sample_batched['signal'].mean(dim=-1))
    print()
    print('Std:', sample_batched['signal'].std(dim=-1))

    print()
    print('-' * 100)
    print()

    preprocess_train(sample_batched)

    print('- After -')
    print('Mean:', sample_batched['signal'].mean(dim=-1))
    print()
    print('Std:', sample_batched['signal'].std(dim=-1))
    break

### Signal normalization using the specified mean and std values

In [None]:
# Estimate signal mean/std from the training loader, then normalize with those
# fixed statistics and compare one batch before and after.
signal_mean, signal_std = calculate_signal_statistics(train_loader, repeats=1, verbose=True)

preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
).to(device)
pprint.pprint(preprocess_train)

for sample_batched in train_loader:
    print('- Before -')
    print('Mean:', sample_batched['signal'].mean(dim=-1))
    print()
    print('Std:', sample_batched['signal'].std(dim=-1))

    print()
    print('-' * 100)
    print()

    preprocess_train(sample_batched)

    print('- After -')
    print('Mean:', sample_batched['signal'].mean(dim=-1))
    print()
    print('Std:', sample_batched['signal'].std(dim=-1))
    break

### Age normalization

In [None]:
# Estimate age mean/std from the training loader and normalize the `age` field.
age_mean, age_std = calculate_age_statistics(train_loader, verbose=True)

preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegNormalizeAge(mean=age_mean, std=age_std),
).to(device)
pprint.pprint(preprocess_train)

for sample_batched in train_loader:
    print('- Before -')
    pprint.pprint(sample_batched['age'])

    print()
    print('-' * 100)
    print()

    preprocess_train(sample_batched)

    print('- After -')
    pprint.pprint(sample_batched['age'])
    break

### Channel dropout

In [None]:
# Channel dropout with p=0.2: print one batch's signal before and after.
preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegChannelDropOut(p=0.2),
).to(device)
pprint.pprint(preprocess_train)

for sample_batched in train_loader:
    print('- Before -')
    print(sample_batched['signal'])

    print()
    print('-' * 100)
    print()

    preprocess_train(sample_batched)

    print('- After -')
    print(sample_batched['signal'])
    break

### Short time Fourier transform (STFT or spectrogram)

In [None]:
# STFT with n_fft=200 (complex output kept as real channels): compare the signal
# shape before and after.
preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegSpectrogram(n_fft=200, complex_mode='as_real'),
).to(device)
pprint.pprint(preprocess_train)

for sample_batched in train_loader:
    print('- Before -')
    pprint.pprint(sample_batched['signal'].shape)

    print()
    print('-' * 100)
    print()

    preprocess_train(sample_batched)

    print('- After -')
    pprint.pprint(sample_batched['signal'].shape)
    break

### Signal normalization after STFT

In [None]:
# Two-stage pipeline: (1) device move + STFT, (2) mean/std normalization with
# statistics measured on the STFT output of the training loader.
preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegSpectrogram(n_fft=200, complex_mode='as_real'),
).to(device)

signal_2d_mean, signal_2d_std = calculate_signal_statistics(train_loader, preprocess_train)

preprocess_train2 = torch.nn.Sequential(
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
).to(device)

pprint.pprint(preprocess_train)
pprint.pprint(preprocess_train2)

for sample_batched in train_loader:
    print('- Before -')
    preprocess_train(sample_batched)

    print('Mean:', sample_batched['signal'].mean(dim=-1))
    print()
    print('Std:', sample_batched['signal'].std(dim=-1))

    print()
    print('-' * 100)
    print()

    print('- After -')
    preprocess_train2(sample_batched)

    print('Mean:', sample_batched['signal'].mean(dim=-1))
    print()
    print('Std:', sample_batched['signal'].std(dim=-1))
    break

---

## Speed check without STFT

In [None]:
# Shared settings for the speed checks without STFT.
crop_length = 200 * 10  # crop: 10s
multiple = 4            # crops per sample
batch_size = 128

### `EDF`

In [None]:
%%time
# Time one full training epoch (loading + on-device preprocessing) from EDF files.
transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200 * 60 * 10,  # length: 10m
                  multiple=multiple,
                  latency=200 * 10),           # latency: 10s
    EegDropChannels(channel_photic),
    EegToTensor(),
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    file_format='edf',
    transform=transform,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)

preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegNormalizeAge(mean=age_mean, std=age_std),
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
).to(device)
pprint.pprint(preprocess_train)

for sample_batched in train_loader:
    preprocess_train(sample_batched)

### `memmap`

In [None]:
%%time

# Time one full training epoch (loading + on-device preprocessing) from memmap files.
transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200 * 60 * 10,  # length: 10m
                  multiple=multiple,
                  latency=200 * 10),           # latency: 10s
    EegDropChannels(channel_photic),
    EegToTensor(),
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    file_format='memmap',
    transform=transform,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)

preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegNormalizeAge(mean=age_mean, std=age_std),
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
).to(device)
pprint.pprint(preprocess_train)

for sample_batched in train_loader:
    preprocess_train(sample_batched)

### `memmap` (Drop → Crop)

In [None]:
%%time

# Same memmap speed check, but the Photic channel is dropped BEFORE cropping to
# compare the cost of the two transform orders.
transform = transforms.Compose([
    EegDropChannels(channel_photic),
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200 * 60 * 10,  # length: 10m
                  multiple=multiple,
                  latency=200 * 10),           # latency: 10s
    EegToTensor(),
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(
    dataset_path=data_path,
    task='dementia',
    load_event=False,
    file_format='memmap',
    transform=transform,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)

preprocess_train = torch.nn.Sequential(
    EegToDevice(device=device),
    EegNormalizeAge(mean=age_mean, std=age_std),
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
).to(device)
pprint.pprint(preprocess_train)

for sample_batched in train_loader:
    preprocess_train(sample_batched)

---

## Speed check with STFT

In [None]:
# Crop / batching settings used by every STFT cell below.
crop_length = 300 * 10
n_fft, hop_length, seq_len_2d = calculate_stft_params(seq_length=crop_length, verbose=True)
multiple = 2
batch_size = 128

# Build a loader from the settings above instead of reusing whatever `train_loader`
# an earlier cell left in the kernel. That loader was built with the previous
# crop length / `multiple` / batch size, so its spectrogram statistics would not
# match the STFT parameters computed here (and would depend on cell execution order).
stat_transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple,
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic),
    EegToTensor()
])
stat_splits = load_caueeg_task_datasets(dataset_path=data_path,
                                        task='dementia',
                                        load_event=False,
                                        file_format='memmap',
                                        transform=stat_transform)
# load_caueeg_task_datasets returns (config, train, validation, test); statistics use the train split.
stat_loader = DataLoader(stat_splits[1],
                         batch_size=batch_size,
                         shuffle=True,
                         drop_last=True,
                         num_workers=num_workers,
                         pin_memory=pin_memory,
                         collate_fn=eeg_collate_fn)

# Statistics are computed on raw (un-normalized) spectrograms, matching where
# EegNormalizeMeanStd is placed (after EegSpectrogram) in the cells below.
preprocess_train = transforms.Compose([
    EegToDevice(device=device),
    EegSpectrogram(n_fft=n_fft, hop_length=hop_length, complex_mode='as_real')
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
signal_2d_mean, signal_2d_std = calculate_signal_statistics(stat_loader, preprocess_train)

### `EDF`

In [None]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='edf',
                                                                                  transform=transform)
 
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegSpectrogram(n_fft=n_fft, hop_length=hop_length, complex_mode='as_real'),
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)
    size = sample_batched['signal'].size()
    
print(size)

### `memmap`

In [None]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=crop_length,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegSpectrogram(n_fft=n_fft, hop_length=hop_length, complex_mode='as_real'),
    EegNormalizeMeanStd(mean=signal_2d_mean, std=signal_2d_std),
])
preprocess_train = torch.nn.Sequential(*preprocess_train.transforms).to(device)
pprint.pprint(preprocess_train)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_train(sample_batched)
    size = sample_batched['signal'].size()
    
print(size)

---

## Test on a longer sequence (1-minute crops)

In [None]:
%%time
longer_transform = transforms.Compose([
    EegRandomCrop(crop_length=200*10*6,     # crop: 1m
                  length_limit=200*60*10,   # length: 10m
                  multiple=2, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(longer_transform)

config_data, longer_test_dataset = load_caueeg_task_split(dataset_path=data_path, 
                                                          task='dementia', 
                                                          split='test',
                                                          load_event=False,
                                                          file_format='memmap', 
                                                          transform=longer_transform)

longer_test_loader = DataLoader(longer_test_dataset,
                                batch_size=32,
                                shuffle=True,
                                drop_last=False,
                                num_workers=num_workers,
                                pin_memory=pin_memory,
                                collate_fn=eeg_collate_fn)
 
preprocess_test = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
    EegNormalizeAge(mean=age_mean, std=age_std),
])
preprocess_test = torch.nn.Sequential(*preprocess_test.transforms).to(device)
pprint.pprint(preprocess_test)

for i_batch, sample_batched in enumerate(train_loader):
    preprocess_test(sample_batched)

In [None]:
---

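## Resampling

Round-trip check: each batch is resampled 200 Hz → 250 Hz → 200 Hz and compared against the same batch processed without resampling. Each cell prints the relative L2 error ($\lVert x_1 - x_2 \rVert / \lVert x_1 \rVert$) summed over all batches of 5 epochs, for one resampling method (`kaiser_best`, `kaiser_fast`, and the default).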

In [None]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=200*200,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train1 = transforms.Compose([
    EegToDevice(device=device), 
    EegResample(orig_freq=200, new_freq=250, resampling_method='kaiser_best'),
    EegResample(orig_freq=250, new_freq=200, resampling_method='kaiser_best'),
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
])
preprocess_train1 = torch.nn.Sequential(*preprocess_train1.transforms).to(device)
pprint.pprint(preprocess_train1)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train2 = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
])
preprocess_train2 = torch.nn.Sequential(*preprocess_train2.transforms).to(device)
pprint.pprint(preprocess_train2)

diff = 0.0
for e in range(5):
    for i_batch, sample_batched in enumerate(train_loader):
        from copy import deepcopy
        sb1 = deepcopy(sample_batched)
        sb2 = deepcopy(sample_batched)

        preprocess_train1(sb1)
        preprocess_train2(sb2)
        
        diff += (torch.norm(sb1['signal'] - sb2['signal']) / torch.sqrt(torch.norm(sb1['signal'])) / torch.sqrt(torch.norm(sb1['signal']))).item()
        
print(diff)

In [None]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=200*200,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train1 = transforms.Compose([
    EegToDevice(device=device), 
    EegResample(orig_freq=200, new_freq=250, resampling_method='kaiser_fast'),
    EegResample(orig_freq=250, new_freq=200, resampling_method='kaiser_fast'),
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
])
preprocess_train1 = torch.nn.Sequential(*preprocess_train1.transforms).to(device)
pprint.pprint(preprocess_train1)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train2 = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
])
preprocess_train2 = torch.nn.Sequential(*preprocess_train2.transforms).to(device)
pprint.pprint(preprocess_train2)

diff = 0.0
for e in range(5):
    for i_batch, sample_batched in enumerate(train_loader):
        from copy import deepcopy
        sb1 = deepcopy(sample_batched)
        sb2 = deepcopy(sample_batched)

        preprocess_train1(sb1)
        preprocess_train2(sb2)
        
        diff += (torch.norm(sb1['signal'] - sb2['signal']) / torch.sqrt(torch.norm(sb1['signal'])) / torch.sqrt(torch.norm(sb1['signal']))).item()
        
print(diff)

In [None]:
%%time

transform = transforms.Compose([
    EegRandomCrop(crop_length=200*200,
                  length_limit=200*60*10,   # length: 10m
                  multiple=multiple, 
                  latency=200*10),          # latency: 10s
    EegDropChannels(channel_photic), 
    EegToTensor()
])
pprint.pprint(transform)

config_data, train_dataset, val_dataset, test_dataset = load_caueeg_task_datasets(dataset_path=data_path, 
                                                                                  task='dementia',
                                                                                  load_event=False, 
                                                                                  file_format='memmap',
                                                                                  transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train1 = transforms.Compose([
    EegToDevice(device=device), 
    EegResample(orig_freq=200, new_freq=250),
    EegResample(orig_freq=250, new_freq=200),
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
])
preprocess_train1 = torch.nn.Sequential(*preprocess_train1.transforms).to(device)
pprint.pprint(preprocess_train1)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

preprocess_train2 = transforms.Compose([
    EegToDevice(device=device), 
    EegNormalizeAge(mean=age_mean, std=age_std), 
    EegNormalizeMeanStd(mean=signal_mean, std=signal_std),
])
preprocess_train2 = torch.nn.Sequential(*preprocess_train2.transforms).to(device)
pprint.pprint(preprocess_train2)

diff = 0.0
for e in range(5):
    for i_batch, sample_batched in enumerate(train_loader):
        from copy import deepcopy
        sb1 = deepcopy(sample_batched)
        sb2 = deepcopy(sample_batched)

        preprocess_train1(sb1)
        preprocess_train2(sb2)
        
        diff += (torch.norm(sb1['signal'] - sb2['signal']) / torch.sqrt(torch.norm(sb1['signal'])) / torch.sqrt(torch.norm(sb1['signal']))).item()
        
print(diff)