# Dataset

EEG 데이터를 PyTorch의 Dataset class와 DataLoader class로 처리해 보는 노트북입니다.

-----

## 환경 구성

In [1]:
# Enable the IPython autoreload extension so that external modules are
# re-imported automatically (mode 2) instead of needing a kernel restart.
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [2]:
# Load some packages
import os
import glob
import json

import matplotlib.pyplot as plt
import pprint

import numpy as np
import random
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

# custom package
from datasets.cau_eeg_dataset import *
from datasets.pipeline import *

In [3]:
# Other settings
# Inline figure rendering with the 'retina' figure format.
%matplotlib inline
%config InlineBackend.figure_format = 'retina' # cleaner text

# Matplotlib style sheet in use; the commented list below records alternative style names.
plt.style.use('default') 
# ['Solarize_Light2', '_classic_test_patch', 'bmh', 'classic', 'dark_background', 'fast', 
#  'fivethirtyeight', 'ggplot', 'grayscale', 'seaborn', 'seaborn-bright', 'seaborn-colorblind', 
#  'seaborn-dark', 'seaborn-dark-palette', 'seaborn-darkgrid', 'seaborn-deep', 'seaborn-muted', 
#  'seaborn-notebook', 'seaborn-paper', 'seaborn-pastel', 'seaborn-poster', 'seaborn-talk', 
#  'seaborn-ticks', 'seaborn-white', 'seaborn-whitegrid', 'tableau-colorblind10']

# Draw images without interpolation between pixels.
plt.rcParams['image.interpolation'] = 'nearest'
# NOTE(review): assumes the NanumGothic font is installed on the machine - confirm
# what happens (fallback / warning) on systems that do not have it.
plt.rcParams["font.family"] = 'NanumGothic' # for Hangul in Windows

In [4]:
# Report the installed PyTorch build and choose the compute device:
# CUDA when it is available, CPU otherwise.
print('PyTorch version:', torch.__version__)
cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if cuda_ok else 'cpu')

print('cuda is available.' if cuda_ok else 'cuda is unavailable.')

PyTorch version: 1.10.1+cu113
cuda is available.


In [5]:
# Directory that holds the dataset files (relative path, trailing slash included)
root_path = 'local/dataset/02_Curated_Data_220402/'

In [6]:
# Metadata file located directly under the dataset root.
meta_path = os.path.join(root_path, 'metadata_debug.json')

# Pass the encoding explicitly: without it, open() uses the platform's locale
# default, which is not UTF-8 on every system, so the read would depend on
# the machine the notebook runs on.
with open(meta_path, 'r', encoding='utf-8') as json_file:
    metadata = json.load(json_file)

# Show the first record as an example of the available fields.
pprint.pprint(metadata[0])

{'age': 78,
 'birth': '1940-06-02',
 'dx1': 'mci_rf',
 'edfname': '00001809_261018',
 'label': ['mci', 'mci_amnestic', 'mci_amnestic_rf'],
 'record': '2018-10-26T15:46:26',
 'serial': '00001'}


-----

## Data Filtering by Diagnosis

#### Non-Vascular Dementia, Non-Vascular MCI, Normal

In [7]:
# One rule per target class; the position in this list is the integer class label.
# A record belongs to a class when its labels contain at least one 'include' tag
# and none of the 'exclude' tags.
diagnosis_filter = [
    {
        'type': 'Normal',
        'include': ['normal'],
        'exclude': [],
    },
    {
        'type': 'Non-vascular MCI',
        'include': ['mci'],
        'exclude': ['mci_vascular'],
    },
    {
        'type': 'Non-vascular dementia',
        'include': ['dementia'],
        'exclude': ['vd'],
    },
]


def generate_class_label(label):
    """Map a collection of diagnosis tags to ``(class_index, class_name)``.

    The rules in ``diagnosis_filter`` are tried in order and the first match
    wins. A rule matches when ``label`` shares at least one tag with its
    'include' list and no tag with its 'exclude' list.

    Returns ``(-1, 'The others')`` when no rule matches.
    """
    tags = set(label)
    for class_index, rule in enumerate(diagnosis_filter):
        if tags.isdisjoint(rule['include']):
            continue  # none of the required tags present
        if not tags.isdisjoint(rule['exclude']):
            continue  # carries a tag that rules this class out
        return (class_index, rule['type'])
    return (-1, 'The others')


# Class names indexed by class label.
class_label_to_type = [rule['type'] for rule in diagnosis_filter]
print('class_label_to_type:', class_label_to_type)

class_label_to_type: ['Normal', 'Non-vascular MCI', 'Non-vascular dementia']


In [8]:
# One bucket per class defined in diagnosis_filter.
splitted_metadata = [[] for _rule in diagnosis_filter]

# Tag every record that matches a class with its name and label, and file it
# into the matching bucket; records that match no class are left out.
for record in metadata:
    class_index, class_name = generate_class_label(record['label'])
    if class_index < 0:
        continue
    record['class_type'] = class_name
    record['class_label'] = class_index
    splitted_metadata[class_index].append(record)

# Report the size of every bucket (or warn when a bucket stayed empty).
for i, split in enumerate(splitted_metadata):
    if split:
        print(f'- There are {len(split):} data belonging to {split[0]["class_type"]}')
    else:
        print(f'(Warning) Split group {i} has no data.')

- There are 458 data belonging to Normal
- There are 350 data belonging to Non-vascular MCI
- There are 233 data belonging to Non-vascular dementia


-----

## Configure the Train, Validation, and Test Splits

#### Split the filtered dataset and shuffle them

In [9]:
# Train : Val : Test = 8 : 1 : 1
ratio1 = 0.8
ratio2 = 0.1

# Dedicated, seeded generator: the split is the same on every run and does not
# depend on (or disturb) the state of the global `random` module.
split_seed = 0
split_rng = random.Random(split_seed)

metadata_train = []
metadata_val = []
metadata_test = []

# Split every class separately so each subset keeps the class proportions.
for split in splitted_metadata:
    # Shuffle a copy: the grouped lists stay in their original order, so
    # re-running this cell reproduces exactly the same split.
    shuffled = list(split)
    split_rng.shuffle(shuffled)

    n1 = round(len(shuffled) * ratio1)
    n2 = n1 + round(len(shuffled) * ratio2)

    metadata_train.extend(shuffled[:n1])
    metadata_val.extend(shuffled[n1:n2])
    metadata_test.extend(shuffled[n2:])

# Mix the classes inside every subset.
split_rng.shuffle(metadata_train)
split_rng.shuffle(metadata_val)
split_rng.shuffle(metadata_test)

# Every filtered record must land in exactly one subset.
assert (len(metadata_train) + len(metadata_val) + len(metadata_test)
        == sum(len(split) for split in splitted_metadata))

print('Train data size\t\t:', len(metadata_train))
print('Validation data size\t:', len(metadata_val))
print('Test data size\t\t:', len(metadata_test))


def _count_class_labels(records):
    """Return the number of records per class, indexed by their 'class_label'."""
    counts = np.zeros((len(class_label_to_type)), dtype=np.int32)
    for record in records:
        counts[record['class_label']] += 1
    return counts


print('\n', '--- Recheck ---', '\n')
train_class_nums = _count_class_labels(metadata_train)
val_class_nums = _count_class_labels(metadata_val)
test_class_nums = _count_class_labels(metadata_test)

print('Train data label distribution\t:', train_class_nums, train_class_nums.sum())
print('Val data label distribution\t:', val_class_nums, val_class_nums.sum())
print('Test data label distribution\t:', test_class_nums, test_class_nums.sum())

Train data size		: 832
Validation data size	: 104
Test data size		: 105

 --- Recheck --- 

Train data label distribution	: [366 280 186] 832
Val data label distribution	: [46 35 23] 104
Test data label distribution	: [46 35 24] 105


-----

## Test TorchVision Transform

#### Data Format: `EDF`

In [10]:
# Training-set dataset with file_format='edf' (events loaded); print the first sample.
dataset = CauEegDataset(
    root_path,
    metadata_train,
    load_event=True,
    file_format='edf',
)
print(dataset[0])

{'signal': array([[-1.,  9., 14., ...,  0.,  0.,  0.],
       [ 0.,  3.,  4., ...,  0.,  0.,  0.],
       [ 1.,  1., -1., ...,  0.,  0.,  0.],
       ...,
       [ 1.,  1., -2., ...,  0.,  0.,  0.],
       [37.,  7.,  0., ...,  0.,  0.,  0.],
       [ 0., -1., -1., ...,  0.,  0.,  0.]]), 'age': 72, 'class_label': 0, 'metadata': {'serial': '00664', 'edfname': '00956048_010715', 'birth': '1942-09-13', 'record': '2015-07-01T13:45:25', 'age': 72, 'dx1': 'cb_normal', 'label': ['normal', 'cb_normal'], 'class_type': 'Normal', 'class_label': 0, 'event': [[0, 'Start Recording'], [0, 'New Montage - Montage 002'], [526, 'Eyes Open'], [6154, 'Eyes Closed'], [12286, 'Eyes Open'], [19972, 'Eyes Closed'], [25054, 'Eyes Open'], [30220, 'Eyes Closed'], [36268, 'Eyes Open'], [42274, 'Eyes Closed'], [48322, 'Eyes Open'], [54117, 'Eyes Closed'], [61341, 'Eyes Open'], [69658, 'Eyes Closed'], [75748, 'Eyes Open'], [82174, 'Eyes Closed'], [89818, 'Eyes Closed'], [101578, 'Eyes Open'], [107206, 'Eyes Closed']

#### Data Format: `PyArrow Feather`

In [11]:
# Training-set dataset with file_format='feather' (events loaded); print the first sample.
dataset = CauEegDataset(
    root_path,
    metadata_train,
    load_event=True,
    file_format='feather',
)
print(dataset[0])

# Last expression of the cell: display the data frame of the first recording.
dataset.get_data_frame(0)

{'signal': array([[ -1,   9,  14, ...,  -6,  -2,   0],
       [  0,   3,   4, ...,   5,   8,  10],
       [  1,   1,  -1, ...,  13,  12,  10],
       ...,
       [  1,   1,  -2, ...,   1,  -1,  -1],
       [ 37,   7,   0, ..., -10, -11,  -9],
       [  0,  -1,  -1, ...,  -1,   2,   4]]), 'age': 72, 'class_label': 0, 'metadata': {'serial': '00664', 'edfname': '00956048_010715', 'birth': '1942-09-13', 'record': '2015-07-01T13:45:25', 'age': 72, 'dx1': 'cb_normal', 'label': ['normal', 'cb_normal'], 'class_type': 'Normal', 'class_label': 0, 'event': [[0, 'Start Recording'], [0, 'New Montage - Montage 002'], [526, 'Eyes Open'], [6154, 'Eyes Closed'], [12286, 'Eyes Open'], [19972, 'Eyes Closed'], [25054, 'Eyes Open'], [30220, 'Eyes Closed'], [36268, 'Eyes Open'], [42274, 'Eyes Closed'], [48322, 'Eyes Open'], [54117, 'Eyes Closed'], [61341, 'Eyes Open'], [69658, 'Eyes Closed'], [75748, 'Eyes Open'], [82174, 'Eyes Closed'], [89818, 'Eyes Closed'], [101578, 'Eyes Open'], [107206, 'Eyes Closed']

Unnamed: 0,Fp1-AVG,F3-AVG,C3-AVG,P3-AVG,O1-AVG,Fp2-AVG,F4-AVG,C4-AVG,P4-AVG,O2-AVG,...,T3-AVG,T5-AVG,F8-AVG,T4-AVG,T6-AVG,FZ-AVG,CZ-AVG,PZ-AVG,EKG,Photic
0,-1,0,1,-3,-1,-6,-8,-3,2,-1,...,-8,-4,26,0,0,0,-6,1,37,0
1,9,3,1,-7,-3,-1,-13,-2,5,2,...,-7,-6,27,2,5,2,-5,1,7,-1
2,14,4,-1,-8,-4,-3,-6,-2,6,5,...,-7,-6,29,3,7,2,-7,-2,0,-1
3,12,2,-3,-5,-2,3,-4,-6,3,6,...,-6,-4,30,4,7,1,-9,-4,8,-1
4,10,3,-2,-4,0,4,-6,-9,-1,3,...,-4,-3,31,3,5,2,-7,-3,7,-3
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
120595,-8,9,17,6,2,-48,-22,-6,-3,-2,...,20,14,-45,-7,-4,-15,0,8,-31,0
120596,-10,6,15,3,2,-41,-23,-5,0,3,...,21,14,-43,-3,1,-18,-3,5,-13,0
120597,-6,5,13,1,3,-34,-26,-5,0,7,...,22,14,-40,-4,2,-16,-4,1,-10,-1
120598,-2,8,12,1,4,-35,-32,-9,-2,8,...,24,15,-40,-7,2,-13,-4,-1,-11,2


#### Data Format: `NumPy Memmap`

In [13]:
# Training-set dataset with file_format='memmap' (events loaded); print the first sample.
dataset = CauEegDataset(
    root_path,
    metadata_train,
    load_event=True,
    file_format='memmap',
)
print(dataset[0])

{'signal': memmap([[ -1,   9,  14, ...,  -6,  -2,   0],
        [  0,   3,   4, ...,   5,   8,  10],
        [  1,   1,  -1, ...,  13,  12,  10],
        ...,
        [  1,   1,  -2, ...,   1,  -1,  -1],
        [ 37,   7,   0, ..., -10, -11,  -9],
        [  0,  -1,  -1, ...,  -1,   2,   4]]), 'age': 72, 'class_label': 0, 'metadata': {'serial': '00664', 'edfname': '00956048_010715', 'birth': '1942-09-13', 'record': '2015-07-01T13:45:25', 'age': 72, 'dx1': 'cb_normal', 'label': ['normal', 'cb_normal'], 'class_type': 'Normal', 'class_label': 0, 'event': [[0, 'Start Recording'], [0, 'New Montage - Montage 002'], [526, 'Eyes Open'], [6154, 'Eyes Closed'], [12286, 'Eyes Open'], [19972, 'Eyes Closed'], [25054, 'Eyes Open'], [30220, 'Eyes Closed'], [36268, 'Eyes Open'], [42274, 'Eyes Closed'], [48322, 'Eyes Open'], [54117, 'Eyes Closed'], [61341, 'Eyes Open'], [69658, 'Eyes Closed'], [75748, 'Eyes Open'], [82174, 'Eyes Closed'], [89818, 'Eyes Closed'], [101578, 'Eyes Open'], [107206, 'Eyes C

#### Max Length

In [30]:
# Build the dataset twice with EegLimitMaxLength(50) and print the shape of
# the first signal each time.
for trial in range(2):
    dataset = CauEegDataset(
        root_path,
        metadata_train,
        load_event=True,
        file_format='feather',
        transform=EegLimitMaxLength(50),
    )
    print(dataset[0]['signal'].shape)

(21, 50)
(21, 50)


#### Random crop

In [31]:
# Build the dataset twice with a 3-sample random crop and print the first
# sample each time, followed by a separator line.
for trial in range(2):
    dataset = CauEegDataset(
        root_path,
        metadata_train,
        load_event=True,
        file_format='feather',
        transform=EegRandomCrop(crop_length=3),
    )
    print(dataset[0])
    print('\n\n' + '-' * 100 + '\n\n')

{'signal': array([[ -3,  -4,  -3],
       [  5,   5,   4],
       [ 10,  14,  16],
       [  3,   5,   6],
       [ -4,  -4,  -4],
       [-40, -33, -34],
       [ -9, -11, -12],
       [  3,   0,   0],
       [  4,   3,   3],
       [ -6,  -8,  -9],
       [ 13,  10,  10],
       [  8,   9,   8],
       [  2,   2,   2],
       [-20, -19, -19],
       [ -5,  -7,  -7],
       [  4,   3,   2],
       [ -6,  -7,  -9],
       [ -1,   0,   1],
       [  1,   5,   7],
       [ -9, -25, -22],
       [  0,   0,   0]]), 'age': 72, 'class_label': 0, 'metadata': {'serial': '00664', 'edfname': '00956048_010715', 'birth': '1942-09-13', 'record': '2015-07-01T13:45:25', 'age': 72, 'dx1': 'cb_normal', 'label': ['normal', 'cb_normal'], 'class_type': 'Normal', 'class_label': 0, 'event': [[0, 'Start Recording'], [0, 'New Montage - Montage 002'], [526, 'Eyes Open'], [6154, 'Eyes Closed'], [12286, 'Eyes Open'], [19972, 'Eyes Closed'], [25054, 'Eyes Open'], [30220, 'Eyes Closed'], [36268, 'Eyes Open'], [422

In [32]:
# Same check with multiple=2 (events not loaded): print only the 'signal'
# entry of the first sample, followed by a separator line.
for trial in range(2):
    dataset = CauEegDataset(
        root_path,
        metadata_train,
        load_event=False,
        file_format='feather',
        transform=EegRandomCrop(crop_length=3, multiple=2),
    )
    print(dataset[0]['signal'])
    print('\n\n' + '-' * 100 + '\n\n')

[array([[-44, -46, -46],
       [  2,   2,   4],
       [ -5,  -1,   2],
       [  4,   4,   3],
       [ -1,  -3,  -5],
       [-21, -24, -21],
       [-11, -10,  -8],
       [ -2,  -2,  -1],
       [  0,  -2,  -2],
       [  6,   5,   2],
       [ 18,  18,  19],
       [ -7,  -7,  -7],
       [ -2,  -3,  -5],
       [-27, -27, -26],
       [ -5,  -6,  -6],
       [  2,   0,  -2],
       [ 17,  16,  17],
       [  5,   7,   8],
       [  5,   6,   6],
       [ 53,  63,  52],
       [ -1,   2,   3]]), array([[24, 22, 19],
       [15, 10,  6],
       [-2, -2, -4],
       [-4, -3, -2],
       [-8, -6, -6],
       [54, 53, 56],
       [10,  8,  8],
       [ 1, -1, -2],
       [-4, -4, -4],
       [-6, -4, -2],
       [ 7,  6,  4],
       [-1, -1, -1],
       [-6, -4, -2],
       [-6, -6, -4],
       [-6, -3, -3],
       [-8, -5, -2],
       [19, 16, 14],
       [ 2,  1,  1],
       [-4, -3, -2],
       [17, 15, 26],
       [ 0, -1, -2]])]


------------------------------------------------

In [33]:
# Debug variant of the random crop (crop_length=3, multiple=2, latency=50000):
# print the whole first sample twice, each followed by a separator line.
for trial in range(2):
    dataset = CauEegDataset(
        root_path,
        metadata_train,
        load_event=True,
        file_format='feather',
        transform=EegRandomCropDebug(crop_length=3, multiple=2, latency=50000),
    )
    print(dataset[0])
    print('\n\n' + '-' * 100 + '\n\n')

{'signal': [array([[-49, -46, -43],
       [ -5,  -3,  -2],
       [ -6,  -4,  -5],
       [-10, -12, -12],
       [ -3,  -5,  -4],
       [-50, -47, -47],
       [-12,  -9,  -5],
       [ -2,  -2,  -3],
       [  8,   5,   2],
       [ 12,   9,   5],
       [  2,   4,   7],
       [  9,   6,   6],
       [  7,   6,   7],
       [  7,  10,  14],
       [ 10,  10,  12],
       [ 12,  11,  11],
       [-19, -18, -16],
       [ -7,  -6,  -7],
       [ -3,  -4,  -8],
       [211, 166, 160],
       [ -2,  -3,   0]]), array([[-17, -15, -17],
       [ -2,  -3,  -3],
       [ -9, -14, -18],
       [ -7, -11, -13],
       [  8,   6,   5],
       [  8,  11,  11],
       [ -4,  -2,  -3],
       [ -8,  -4,  -1],
       [  4,   6,   8],
       [ 17,  15,  14],
       [ -4,  -1,  -1],
       [ -3,  -1,   0],
       [  2,   2,   2],
       [  4,   6,   8],
       [  8,   7,   7],
       [ 15,  15,  14],
       [  0,   0,  -1],
       [ -9,  -9,  -7],
       [-11, -12, -11],
       [ 72,  63,  37],
  

#### Drop EKG channel

In [34]:
# Reference: first signal without any transform.
dataset = CauEegDataset(root_path, metadata_train, load_event=False, transform=None)
print('before:', dataset[0]['signal'].shape)
print(dataset[0]['signal'])

print('\n' + '-' * 100 + '\n')

# Same signal with EegDropEKGChannel applied.
dataset = CauEegDataset(root_path, metadata_train, load_event=False,
                        transform=EegDropEKGChannel())
print('after:', dataset[0]['signal'].shape)
print(dataset[0]['signal'])

before: (21, 121800)
[[-1.  9. 14. ...  0.  0.  0.]
 [ 0.  3.  4. ...  0.  0.  0.]
 [ 1.  1. -1. ...  0.  0.  0.]
 ...
 [ 1.  1. -2. ...  0.  0.  0.]
 [37.  7.  0. ...  0.  0.  0.]
 [ 0. -1. -1. ...  0.  0.  0.]]

----------------------------------------------------------------------------------------------------

after: (20, 121800)
[[-1.  9. 14. ...  0.  0.  0.]
 [ 0.  3.  4. ...  0.  0.  0.]
 [ 1.  1. -1. ...  0.  0.  0.]
 ...
 [-6. -5. -7. ...  0.  0.  0.]
 [ 1.  1. -2. ...  0.  0.  0.]
 [ 0. -1. -1. ...  0.  0.  0.]]


#### Drop photic stimulation channel

In [35]:
# Reference: first signal without any transform.
dataset = CauEegDataset(root_path, metadata_train, load_event=False, transform=None)
print('before:', dataset[0]['signal'].shape)
print(dataset[0]['signal'])

print('\n' + '-' * 100 + '\n')

# Same signal with EegDropPhoticChannel applied.
dataset = CauEegDataset(root_path, metadata_train, load_event=False,
                        transform=EegDropPhoticChannel())
print('after:', dataset[0]['signal'].shape)
print(dataset[0]['signal'])

before: (21, 121800)
[[-1.  9. 14. ...  0.  0.  0.]
 [ 0.  3.  4. ...  0.  0.  0.]
 [ 1.  1. -1. ...  0.  0.  0.]
 ...
 [ 1.  1. -2. ...  0.  0.  0.]
 [37.  7.  0. ...  0.  0.  0.]
 [ 0. -1. -1. ...  0.  0.  0.]]

----------------------------------------------------------------------------------------------------

after: (20, 121800)
[[-1.  9. 14. ...  0.  0.  0.]
 [ 0.  3.  4. ...  0.  0.  0.]
 [ 1.  1. -1. ...  0.  0.  0.]
 ...
 [-6. -5. -7. ...  0.  0.  0.]
 [ 1.  1. -2. ...  0.  0.  0.]
 [37.  7.  0. ...  0.  0.  0.]]


#### To Tensor

In [36]:
# Reference: first sample without any transform.
dataset = CauEegDataset(root_path, metadata_train, load_event=False, transform=None)
print('before:')
print(dataset[0])

print('\n' + '-' * 100 + '\n')

# Same sample with EegToTensor applied.
dataset = CauEegDataset(root_path, metadata_train, load_event=False,
                        transform=EegToTensor())
print('after:')
print(dataset[0])

before:
{'signal': array([[-1.,  9., 14., ...,  0.,  0.,  0.],
       [ 0.,  3.,  4., ...,  0.,  0.,  0.],
       [ 1.,  1., -1., ...,  0.,  0.,  0.],
       ...,
       [ 1.,  1., -2., ...,  0.,  0.,  0.],
       [37.,  7.,  0., ...,  0.,  0.,  0.],
       [ 0., -1., -1., ...,  0.,  0.,  0.]]), 'age': 72, 'class_label': 0, 'metadata': {'serial': '00664', 'edfname': '00956048_010715', 'birth': '1942-09-13', 'record': '2015-07-01T13:45:25', 'age': 72, 'dx1': 'cb_normal', 'label': ['normal', 'cb_normal'], 'class_type': 'Normal', 'class_label': 0}}

----------------------------------------------------------------------------------------------------

after:
{'signal': tensor([[-1.,  9., 14.,  ...,  0.,  0.,  0.],
        [ 0.,  3.,  4.,  ...,  0.,  0.,  0.],
        [ 1.,  1., -1.,  ...,  0.,  0.,  0.],
        ...,
        [ 1.,  1., -2.,  ...,  0.,  0.,  0.],
        [37.,  7.,  0.,  ...,  0.,  0.,  0.],
        [ 0., -1., -1.,  ...,  0.,  0.,  0.]]), 'age': tensor(72.), 'class_label': t

#### Normalization per signal

In [37]:
# Tensor conversion followed by per-signal normalization.
composed = transforms.Compose([
    EegToTensor(),
    EegNormalizePerSignal(),
])
dataset = CauEegDataset(root_path, metadata_train, load_event=False, transform=composed)

print(dataset[0])

print('\n' + '-' * 100 + '\n')

# Statistics along dimension 1 of the normalized signal.
print('Mean:', torch.mean(dataset[0]['signal'], dim=1))
print('Std:', torch.std(dataset[0]['signal'], dim=1))

{'signal': tensor([[-2.1657e-02,  1.7518e-01,  2.7360e-01,  ..., -1.9734e-03,
         -1.9734e-03, -1.9734e-03],
        [-8.7267e-04,  2.0686e-01,  2.7611e-01,  ..., -8.7267e-04,
         -8.7267e-04, -8.7267e-04],
        [ 1.2075e-01,  1.2075e-01, -1.1930e-01,  ...,  7.2823e-04,
          7.2823e-04,  7.2823e-04],
        ...,
        [ 1.4415e-01,  1.4415e-01, -2.8889e-01,  ..., -1.9673e-04,
         -1.9673e-04, -1.9673e-04],
        [ 3.6194e-01,  6.8453e-02, -2.8353e-05,  ..., -2.8353e-05,
         -2.8353e-05, -2.8353e-05],
        [ 1.0690e-03, -7.1040e-01, -7.1040e-01,  ...,  1.0690e-03,
          1.0690e-03,  1.0690e-03]]), 'age': tensor(72.), 'class_label': tensor(0), 'metadata': {'serial': '00664', 'edfname': '00956048_010715', 'birth': '1942-09-13', 'record': '2015-07-01T13:45:25', 'age': 72, 'dx1': 'cb_normal', 'label': ['normal', 'cb_normal'], 'class_type': 'Normal', 'class_label': 0}}

-----------------------------------------------------------------------------------

#### Age normalization

In [38]:
# Age statistics of the training set only.
ages = np.array([record['age'] for record in metadata_train])
age_mean = ages.mean()
age_std = ages.std()

print('Age mean and standard deviation:\t', age_mean, ',\t', age_std)

print('\n' + '-' * 100 + '\n')

# Ages of the first five samples without any transform.
print('before:')
dataset = CauEegDataset(root_path, metadata_train, load_event=False, transform=None)
for index in range(5):
    print(dataset[index]['age'])

print('\n' + '-' * 100 + '\n')

# Ages of the same samples after EegNormalizeAge with the statistics above.
print('after:')
composed = transforms.Compose([
    EegToTensor(),
    EegNormalizeAge(mean=age_mean, std=age_std),
])
dataset = CauEegDataset(root_path, metadata_train, load_event=False, transform=composed)
for index in range(5):
    print(dataset[index]['age'])

Age mean and standard deviation:	 70.32692307692308 ,	 9.844895502524205

----------------------------------------------------------------------------------------------------

before:
72
72
71
59
77

----------------------------------------------------------------------------------------------------

after:
tensor(0.1699)
tensor(0.1699)
tensor(0.0684)
tensor(-1.1505)
tensor(0.6778)


#### Short time Fourier transform (STFT or spectrogram)

In [39]:
# Run the same spectrogram check for every complex_mode instead of repeating
# the stanza three times. The first sample is fetched once per mode, so the
# file is loaded and transformed once rather than once per printed attribute.
for mode_index, complex_mode in enumerate(('as_real', 'power', 'remove')):
    if mode_index > 0:
        # separator between the outputs of consecutive modes
        print()
        print('-' * 100)
        print()

    composed = transforms.Compose([EegToTensor(), 
                                   EegSpectrogram(n_fft=200, complex_mode=complex_mode)])
    dataset = CauEegDataset(root_path, metadata_train, 
                            load_event=False, transform=composed)

    signal = dataset[0]['signal']
    print(signal.shape, signal.dtype, type(signal))
    print(signal[:, :, 10])  # slice at index 10 of the last axis

torch.Size([42, 101, 2437]) torch.float32 <class 'torch.Tensor'>
tensor([[-4.0560e+03,  1.6568e+03,  5.5016e+02,  ...,  5.2105e+01,
          5.1527e+01,  4.4000e+01],
        [-1.3200e+03,  5.6830e+02, -4.2169e+01,  ...,  2.0512e+01,
          1.9415e+01,  2.6000e+01],
        [ 4.0600e+02,  2.3340e+02,  8.0225e+01,  ...,  3.4422e+00,
          3.5751e+00,  0.0000e+00],
        ...,
        [ 0.0000e+00, -1.5536e+02, -1.5174e+02,  ..., -2.9735e+00,
         -3.4531e-01,  0.0000e+00],
        [ 0.0000e+00, -1.8225e+03,  1.1826e+03,  ...,  3.4226e+00,
         -2.8571e+00,  0.0000e+00],
        [ 0.0000e+00, -2.3688e+01,  4.7708e+01,  ..., -1.8115e-01,
         -5.0074e-01,  0.0000e+00]])

----------------------------------------------------------------------------------------------------

torch.Size([21, 101, 2437]) torch.float32 <class 'torch.Tensor'>
tensor([[4.0560e+03, 1.9314e+03, 6.9652e+02,  ..., 5.2154e+01, 5.1528e+01,
         4.4000e+01],
        [1.3200e+03, 7.7577e+02, 2.337

#### Transform Composition

In [43]:
# Variant 1: the whole pipeline runs inside the dataset.
composed = transforms.Compose([EegDropPhoticChannel(), 
                               EegRandomCrop(crop_length=200*60, latency=200*10), # 1 minute
                               EegToTensor(), 
                               EegNormalizeAge(mean=age_mean, std=age_std), 
                               EegNormalizePerSignal()])
train_dataset = CauEegDataset(root_path, metadata_train, load_event=False, 
                              file_format='memmap', transform=composed)
print(train_dataset[0]['signal'].shape)
print(train_dataset[0])

print()
print('-' * 100)
print()

# Variant 2: the same steps split in two stages - `composed_np` runs inside
# the dataset, `composed_pt` is applied to the returned sample afterwards.
composed_np = transforms.Compose([EegDropPhoticChannel(), 
                                  EegRandomCrop(crop_length=200*60, latency=200*10), # 1 minute
                                  EegToTensor()])
composed_pt = transforms.Compose([EegNormalizeAge(mean=age_mean, std=age_std), 
                                  EegNormalizePerSignal()])
# NOTE(review): wrapping in nn.Sequential presumably requires these transforms
# to be torch.nn.Module instances - confirm in datasets.pipeline.
composed_pt = torch.nn.Sequential(*composed_pt.transforms)

# The dataset must use the first stage only. Passing the full `composed`
# pipeline here would normalize inside the dataset and then a second time in
# `composed_pt`, and would leave `composed_np` unused.
train_dataset = CauEegDataset(root_path, metadata_train, 
                              load_event=False, transform=composed_np)

sample = composed_pt(train_dataset[0])
print(sample['signal'].shape)
print(sample)

torch.Size([20, 12000])
{'signal': tensor([[ 0.3132,  0.3654,  0.4349,  ...,  1.7211,  1.6342,  1.5647],
        [ 1.1582,  1.2969,  1.6435,  ...,  1.0889,  0.8116,  0.4649],
        [-0.5548, -0.9008, -0.6701,  ..., -0.3241, -0.5548, -0.9008],
        ...,
        [-0.1085,  0.1097,  0.5461,  ...,  0.1097, -0.2176, -0.7631],
        [-1.0020, -1.0020, -0.8546,  ...,  0.9144,  0.9144,  0.4722],
        [-0.2637, -0.2931, -0.3028,  ..., -0.0683, -0.0585, -0.1171]]), 'age': tensor(0.1699), 'class_label': tensor(0), 'metadata': {'serial': '00664', 'edfname': '00956048_010715', 'birth': '1942-09-13', 'record': '2015-07-01T13:45:25', 'age': 72, 'dx1': 'cb_normal', 'label': ['normal', 'cb_normal'], 'class_type': 'Normal', 'class_label': 0}}

----------------------------------------------------------------------------------------------------

torch.Size([20, 12000])
{'signal': tensor([[ 0.1516,  0.1881,  0.1698,  ...,  0.0236, -0.0495, -0.0678],
        [ 0.0981, -0.4306, -0.1001,  ...,  0.23

#### Train, Validation, Test Datasets

In [44]:
# Shared pipeline for the three subsets: drop the photic channel, random crop,
# convert to tensors (no normalization here).
composed = transforms.Compose([EegDropPhoticChannel(),
                               EegRandomCrop(crop_length=200*60, latency=200*10), # 1 minute
                               EegToTensor()])

train_dataset = CauEegDataset(root_path, metadata_train, 
                              load_event=False, transform=composed)
val_dataset = CauEegDataset(root_path, metadata_val, 
                            load_event=False, transform=composed)
test_dataset = CauEegDataset(root_path, metadata_test, 
                             load_event=False, transform=composed)

# Fetch the training sample once. Each `train_dataset[0]` access goes through
# EegRandomCrop again (it appears to draw a new crop per access), so reading
# shape, mean and std from separate accesses would describe different crops.
train_sample = train_dataset[0]
train_signal = train_sample['signal']

# Statistics of channel 0
print(train_signal[0].shape)
print(torch.mean(train_signal[0], axis=-1))
print(torch.std(train_signal[0], axis=-1))

print()
print('-' * 100)
print()

# Statistics of channel 1
print(train_signal[1].shape)
print(torch.mean(train_signal[1], axis=-1))
print(torch.std(train_signal[1], axis=-1))

print()
print('-' * 100)
print()

# The complete training sample the statistics above were computed from
print(train_signal[1].shape)
print(train_sample)

print()
print('-' * 100)
print()

print(val_dataset[0])

print()
print('-' * 100)
print()

print(test_dataset[0])

torch.Size([12000])
tensor(-0.4210)
tensor(48.9648)

----------------------------------------------------------------------------------------------------

torch.Size([12000])
tensor(-0.1077)
tensor(13.4308)

----------------------------------------------------------------------------------------------------

torch.Size([12000])
{'signal': tensor([[ 61.,  58.,  52.,  ..., -14., -14., -19.],
        [ 25.,  22.,  17.,  ..., -13.,  -6.,   1.],
        [  5.,   4.,   1.,  ...,  -8.,  -2.,   4.],
        ...,
        [ 12.,  11.,   8.,  ...,   7.,  13.,  22.],
        [  4.,   2.,  -1.,  ...,  -6.,  -3.,   2.],
        [-21., -54., -91.,  ...,  62.,  45., -34.]]), 'age': tensor(72.), 'class_label': tensor(0), 'metadata': {'serial': '00664', 'edfname': '00956048_010715', 'birth': '1942-09-13', 'record': '2015-07-01T13:45:25', 'age': 72, 'dx1': 'cb_normal', 'label': ['normal', 'cb_normal'], 'class_type': 'Normal', 'class_label': 0}}

-----------------------------------------------------------

In [46]:
# Crop, normalize, then compute the power spectrogram with a half-window hop.
composed = transforms.Compose([EegDropPhoticChannel(),
                               EegRandomCrop(crop_length=200*60, latency=200*10), # 1 minute
                               EegToTensor(),
                               EegNormalizePerSignal(),
                               EegNormalizeAge(mean=age_mean, std=age_std),
                               EegSpectrogram(n_fft=200, complex_mode='power', hop_length=200 // 2)])
dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=False, transform=composed)

# Fetch once: every `dataset[0]` access re-runs the pipeline (including the
# random crop), so the printed attributes and the slice below would otherwise
# come from different crops and cost four full loads.
signal = dataset[0]['signal']
print(signal.shape, signal.dtype, type(signal))
print(signal[:, :, 10])  # slice at index 10 of the last axis

torch.Size([20, 101, 121]) torch.float32 <class 'torch.Tensor'>
tensor([[ 6.3965,  2.5732,  3.6047,  ...,  0.0834,  0.1731,  0.0705],
        [ 1.5608,  7.2968, 15.7876,  ...,  0.2967,  0.1937,  0.0838],
        [22.6381, 15.1222, 10.8481,  ...,  0.1824,  0.1819,  0.2878],
        ...,
        [19.2723, 18.6043, 20.1806,  ...,  0.4046,  0.6141,  0.2121],
        [16.9573, 15.8265, 17.8164,  ...,  0.1465,  0.3879,  0.3033],
        [ 5.7490, 25.2681, 16.1617,  ...,  0.1256,  0.0360,  0.0680]])


#### Data loader test

In [54]:
%%time

print('Current PyTorch device:', device)
if device.type == 'cuda':
    num_workers = 0 # A number other than 0 causes an error
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False
    
composed = transforms.Compose([
    EegDropPhoticChannel(),
    EegRandomCrop(crop_length=200*60, latency=200*10), # 20 seconds
    EegToTensor(),
    EegNormalizeAge(mean=age_mean, std=age_std),
    EegNormalizePerSignal(),
])

train_dataset = CauEegDataset(root_path, metadata_train, load_event=False, 
                              file_format='edf', transform=composed)

train_loader = DataLoader(train_dataset, 
                          batch_size=32, # Random crop will inflate the minibatch size
                          shuffle=True, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for k in range(5):
    for i_batch, sample_batched in enumerate(train_loader):
        sample_batched['signal'].to(device)
        sample_batched['age'].to(device)
        sample_batched['class_label'].to(device)
        
# pprint.pprint(sample_batched)

Current PyTorch device: cuda
CPU times: total: 1h 41min 46s
Wall time: 8min 29s


In [55]:
%%time

print('Current PyTorch device:', device)
if device.type == 'cuda':
    num_workers = 0 # A number other than 0 causes an error
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False
    
composed = transforms.Compose([
    EegDropPhoticChannel(),
    EegRandomCrop(crop_length=200*60, latency=200*10), # 20 seconds
    EegToTensor(),
    EegNormalizeAge(mean=age_mean, std=age_std),
    EegNormalizePerSignal(),
])

train_dataset = CauEegDataset(root_path, metadata_train, load_event=False, 
                              file_format='feather', transform=composed)

train_loader = DataLoader(train_dataset, 
                          batch_size=32, # Random crop will inflate the minibatch size
                          shuffle=True, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for k in range(5):
    for i_batch, sample_batched in enumerate(train_loader):
        sample_batched['signal'].to(device)
        sample_batched['age'].to(device)
        sample_batched['class_label'].to(device)
        
# pprint.pprint(sample_batched)

Current PyTorch device: cuda
CPU times: total: 5min 29s
Wall time: 20.6 s


In [56]:
%%time

print('Current PyTorch device:', device)
if device.type == 'cuda':
    num_workers = 0 # A number other than 0 causes an error
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False
    
composed = transforms.Compose([
    EegDropPhoticChannel(),
    EegRandomCrop(crop_length=200*60, latency=200*10), # 60 seconds
    EegToTensor(),
    EegNormalizeAge(mean=age_mean, std=age_std),
    EegNormalizePerSignal(),
])

train_dataset = CauEegDataset(root_path, metadata_train, load_event=False, 
                              file_format='memmap', transform=composed)

train_loader = DataLoader(train_dataset, 
                          batch_size=32, # Random crop will inflate the minibatch size
                          shuffle=True, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for k in range(5):
    for i_batch, sample_batched in enumerate(train_loader):
        sample_batched['signal'].to(device)
        sample_batched['age'].to(device)
        sample_batched['class_label'].to(device)
        
# pprint.pprint(sample_batched)

Current PyTorch device: cuda
CPU times: total: 1min 23s
Wall time: 6.98 s


In [57]:
%%time

print('Current PyTorch device:', device)
if device.type == 'cuda':
    num_workers = 0 # A number other than 0 causes an error
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False
    
composed_np = transforms.Compose([EegDropPhoticChannel(), 
                                  EegRandomCrop(crop_length=200*60, latency=200*10), # 20 seconds
                                  EegToTensor()])
preprocess = transforms.Compose([EegNormalizePerSignal(), 
                                 EegNormalizeAge(mean=age_mean, std=age_std)])
preprocess = torch.nn.Sequential(*preprocess.transforms)

train_dataset = CauEegDataset(root_path, metadata_train, load_event=False, 
                              file_format='memmap', transform=composed_np)

train_loader = DataLoader(train_dataset, 
                          batch_size=32, # Random crop will inflate the minibatch size
                          shuffle=True, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for k in range(5):
    for i_batch, sample_batched in enumerate(train_loader):
        # pull up the batch data
        x = sample_batched['signal'] = sample_batched['signal'].to(device)
        age = sample_batched['age'] = sample_batched['age'].to(device)
        target = sample_batched['class_label'] = sample_batched['class_label'].to(device)
        
        preprocess(sample_batched)
        
# pprint.pprint(sample_batched)

Current PyTorch device: cuda
CPU times: total: 1min 1s
Wall time: 5.16 s


In [58]:
%%time

print('Current PyTorch device:', device)
if device.type == 'cuda':
    num_workers = 0 # A number other than 0 causes an error
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False
    
composed_np = transforms.Compose([EegDropPhoticChannel(), 
                                  EegRandomCrop(crop_length=200*60, latency=200*10), # 20 seconds
                                  EegToTensor()])
preprocess = transforms.Compose([EegNormalizePerSignal(), 
                                 EegNormalizeAge(mean=age_mean, std=age_std), 
                                 EegToDevice(device=device)])
preprocess = torch.nn.Sequential(*preprocess.transforms)

train_dataset = CauEegDataset(root_path, metadata_train, load_event=False, 
                              file_format='memmap', transform=composed_np)

train_loader = DataLoader(train_dataset, 
                          batch_size=32, # Random crop will inflate the minibatch size
                          shuffle=True, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for k in range(5):
    for i_batch, sample_batched in enumerate(train_loader):
        preprocess(sample_batched)
        
print(sample_batched)

Current PyTorch device: cuda
{'signal': tensor([[[ 0.9220,  0.7702,  0.6943,  ..., -0.0267, -0.2164, -0.4062],
         [-1.2627, -1.1418, -1.5647,  ..., -0.6586, -0.4169, -0.4774],
         [ 0.0697, -0.0857, -0.0857,  ..., -0.3964, -0.3964, -0.3964],
         ...,
         [ 1.0843,  1.2414,  1.0843,  ...,  1.0843,  0.9272,  0.6130],
         [ 1.7394,  1.7394,  2.0624,  ...,  1.5778,  1.7394,  1.7394],
         [ 1.4858,  1.5502,  1.9773,  ..., -1.6000, -2.5346, -3.9687]],

        [[-1.7006, -1.5556, -1.2654,  ..., -1.8699, -1.3621, -1.1687],
         [ 0.7225,  0.8934,  0.8934,  ..., -0.8587, -0.9014, -1.4142],
         [ 0.0817,  0.0285,  0.1348,  ...,  0.0817, -0.1309, -0.0246],
         ...,
         [ 2.0449,  1.8987,  1.8256,  ...,  0.0712, -0.2212,  0.0712],
         [ 0.4372,  0.3493,  0.2614,  ...,  1.3162,  1.4041,  1.6679],
         [ 0.2128,  0.1728,  0.1728,  ..., -0.1998, -0.1599, -0.1865]],

        [[ 0.2190,  0.2651,  0.4494,  ...,  0.9101,  1.0944,  1.3247],
     