# Dataset

PyTorch의 EEG 데이터를 Dataset class 및 DataLoader class로 처리해보는 노트북

-----

## 환경 구성

In [31]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [32]:
# Load some packages
import os
import glob
import json

import matplotlib.pyplot as plt
import pprint

import numpy as np
import random
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

# custom package
from datasets.cau_eeg_dataset import *
from datasets.pipeline import *

In [3]:
# Other settings
%matplotlib inline
%config InlineBackend.figure_format = 'retina' # cleaner text

plt.style.use('default') 
# ['Solarize_Light2', '_classic_test_patch', 'bmh', 'classic', 'dark_background', 'fast', 
#  'fivethirtyeight', 'ggplot', 'grayscale', 'seaborn', 'seaborn-bright', 'seaborn-colorblind', 
#  'seaborn-dark', 'seaborn-dark-palette', 'seaborn-darkgrid', 'seaborn-deep', 'seaborn-muted', 
#  'seaborn-notebook', 'seaborn-paper', 'seaborn-pastel', 'seaborn-poster', 'seaborn-talk', 
#  'seaborn-ticks', 'seaborn-white', 'seaborn-whitegrid', 'tableau-colorblind10']

plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams["font.family"] = 'NanumGothic' # for Hangul in Windows

In [4]:
print('PyTorch version:', torch.__version__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available(): print('cuda is available.')
else: print('cuda is unavailable.')

PyTorch version: 1.11.0
cuda is available.


In [5]:
# Data file path
root_path = r'local/dataset/02_Curated_Data_220322/' # 02_Curated_Data_210705

In [6]:
meta_path = os.path.join(root_path, 'metadata_debug.json')
with open(meta_path, 'r') as json_file:
    metadata = json.load(json_file)

pprint.pprint(metadata[0])

{'age': 78,
 'birth': '1940-06-02',
 'dx1': 'mci_rf',
 'edfname': '00001809_261018',
 'label': ['mci', 'mci_amnestic', 'mci_amnestic_rf'],
 'record': '2018-10-26T15:46:26',
 'serial': '00001'}


-----

## Data Filtering by Diagnosis

#### Non-Vascular Dementia, Non-Vascular MCI, Normal

In [7]:
diagnosis_filter = [
    # Normal
    {'type': 'Normal',
     'include': ['normal'], 
     'exclude': []},
    # Non-vascular MCI
    {'type': 'Non-vascular MCI',
     'include': ['mci'], 
     'exclude': ['mci_vascular']},
    # Non-vascular dementia
    {'type': 'Non-vascular dementia',
     'include': ['dementia'], 
     'exclude': ['vd']},
]

def generate_class_label(label):
    for c, f in enumerate(diagnosis_filter):
        # inc = set(f['include']) & set(label) == set(f['include'])
        inc = len(set(f['include']) & set(label)) > 0        
        exc = len(set(f['exclude']) & set(label)) == 0
        if  inc and exc:
            return (c, f['type'])
    return (-1, 'The others')

class_label_to_type = [d_f['type'] for d_f in diagnosis_filter]
print('class_label_to_type:', class_label_to_type)

class_label_to_type: ['Normal', 'Non-vascular MCI', 'Non-vascular dementia']


In [8]:
splitted_metadata = [[] for i in diagnosis_filter]

for m in metadata:
    c, n = generate_class_label(m['label'])
    if c >= 0:
        m['class_type'] = n
        m['class_label'] = c
        splitted_metadata[c].append(m)
        
for i, split in enumerate(splitted_metadata):
    if len(split) == 0:
        print(f'(Warning) Split group {i} has no data.')
    else:
        print(f'- There are {len(split):} data belonging to {split[0]["class_type"]}')

- There are 458 data belonging to Normal
- There are 350 data belonging to Non-vascular MCI
- There are 233 data belonging to Non-vascular dementia


-----

## Configure the Train, Validation, and Test Splits

#### Split the filtered dataset and shuffle them

In [9]:
# Train : Val : Test = 8 : 1 : 1
ratio1 = 0.8
ratio2 = 0.1

metadata_train = []
metadata_val = []
metadata_test = []

for split in splitted_metadata:
    random.shuffle(split)
    
    n1 = round(len(split) * ratio1)
    n2 = n1 + round(len(split) * ratio2)

    metadata_train.extend(split[:n1])
    metadata_val.extend(split[n1:n2])
    metadata_test.extend(split[n2:])

random.shuffle(metadata_train)
random.shuffle(metadata_val)
random.shuffle(metadata_test)

print('Train data size\t\t:', len(metadata_train))
print('Validation data size\t:', len(metadata_val))
print('Test data size\t\t:', len(metadata_test))

print('\n', '--- Recheck ---', '\n')
train_class_nums = np.zeros((len(class_label_to_type)), dtype=np.int32)
for m in metadata_train:
    train_class_nums[m['class_label']] += 1

val_class_nums = np.zeros((len(class_label_to_type)), dtype=np.int32)
for m in metadata_val:
    val_class_nums[m['class_label']] += 1

test_class_nums = np.zeros((len(class_label_to_type)), dtype=np.int32)
for m in metadata_test:
    test_class_nums[m['class_label']] += 1

print('Train data label distribution\t:', train_class_nums, train_class_nums.sum())
print('Val data label distribution\t:', val_class_nums, val_class_nums.sum())
print('Test data label distribution\t:', test_class_nums, test_class_nums.sum())

Train data size		: 832
Validation data size	: 104
Test data size		: 105

 --- Recheck --- 

Train data label distribution	: [366 280 186] 832
Val data label distribution	: [46 35 23] 104
Test data label distribution	: [46 35 24] 105


-----

## Test TorchVision Transform

#### Data Format: `PyArrow Feather`

In [10]:
dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=True, file_format='feather')
dataset.get_data_frame(0)

Unnamed: 0,Fp1-AVG,F3-AVG,C3-AVG,P3-AVG,O1-AVG,Fp2-AVG,F4-AVG,C4-AVG,P4-AVG,O2-AVG,...,T3-AVG,T5-AVG,F8-AVG,T4-AVG,T6-AVG,FZ-AVG,CZ-AVG,PZ-AVG,EKG,Photic
0,-1,15,-1,-4,-9,7,0,-8,-12,-7,...,-6,-3,2,19,-6,5,-8,-8,-42,0
1,-2,15,0,-1,-8,5,-5,-6,-7,-3,...,-5,-3,-5,20,1,4,-7,-5,-39,0
2,-4,13,-1,-1,-9,2,-9,-7,-5,-2,...,-9,-5,-8,22,2,2,-7,-5,-19,0
3,-7,14,-1,-2,-11,0,-5,-3,-4,-2,...,-12,-9,-8,25,2,3,-6,-5,-11,0
4,-4,15,-2,-3,-12,4,0,0,-2,-2,...,-12,-11,-2,23,4,4,-7,-5,-16,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
179795,-12,6,-4,-6,-16,-4,19,5,3,-10,...,-12,-13,-12,1,-3,9,11,-2,59,-1
179796,-11,4,-3,-3,-11,-3,19,6,6,-5,...,-13,-12,-9,5,1,9,12,2,26,0
179797,-13,-1,-6,-2,-8,-6,16,5,7,-3,...,-17,-12,-11,4,2,3,9,3,15,0
179798,-16,-6,-8,0,-5,-8,12,3,8,0,...,-18,-11,-11,3,5,-4,5,4,15,0


#### Data Format: `NumPy memmap`

In [11]:
dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=True, file_format='memmap')
print(dataset[0])

{'signal': memmap([[ -1,  -2,  -4, ..., -13, -16, -15],
        [ 15,  15,  13, ...,  -1,  -6,  -8],
        [ -1,   0,  -1, ...,  -6,  -8,  -6],
        ...,
        [ -8,  -5,  -5, ...,   3,   4,   7],
        [-42, -39, -19, ...,  15,  15,   1],
        [  0,   0,   0, ...,   0,   0,   0]]), 'age': 76, 'class_label': 1, 'metadata': {'serial': '00465', 'edfname': '00779511_050319', 'birth': '1942-09-08', 'record': '2019-03-05T08:58:32', 'age': 76, 'dx1': 'mci ef', 'label': ['mci', 'mci_amnestic', 'mci_amnestic_ef'], 'class_type': 'Non-vascular MCI', 'class_label': 1, 'event': [[0, 'Start Recording'], [0, 'New Montage - Montage 002'], [38730, 'Eyes Open'], [69992, 'Move'], [71302, 'Move'], [72792, 'Eyes Closed'], [73326, 'Move'], [74598, 'Eyes Open'], [76992, 'Eyes Closed'], [78716, 'Photic On - 3.0 Hz'], [79050, 'Eyes Open'], [79554, 'Eyes Closed'], [80046, 'Move'], [80732, 'Photic Off'], [82750, 'Photic On - 6.0 Hz'], [84766, 'Photic Off'], [86824, 'Photic On - 9.0 Hz'], [88844, 'Ph

#### Random crop

In [12]:
for i in range(2):
    dataset = CauEegDataset(root_path, metadata_train, load_event=True, 
                            transform=EegRandomCrop(crop_length=3))
    print(dataset[0])
    print('\n')
    print('-' * 100)
    print('\n')

{'signal': array([[ 30,  29,  31],
       [  4,   6,   7],
       [ -4,  -3,  -5],
       [ -2,   1,   0],
       [ -1,   4,   6],
       [ 17,  18,  16],
       [ -3,  -5,  -6],
       [ -9,  -9,  -9],
       [ -1,   4,   3],
       [ -6,   1,   3],
       [ 17,  17,  19],
       [  2,   2,   1],
       [ -4,  -3,  -1],
       [ -7,  -8,  -9],
       [-12, -11, -13],
       [ -4,   0,   0],
       [  1,   2,   3],
       [ -6,  -4,  -4],
       [  3,   5,   3],
       [-27, -27, -29],
       [157, 174, 145]]), 'age': 76, 'class_label': 1, 'metadata': {'serial': '00465', 'edfname': '00779511_050319', 'birth': '1942-09-08', 'record': '2019-03-05T08:58:32', 'age': 76, 'dx1': 'mci ef', 'label': ['mci', 'mci_amnestic', 'mci_amnestic_ef'], 'class_type': 'Non-vascular MCI', 'class_label': 1, 'event': [[0, 'Start Recording'], [0, 'New Montage - Montage 002'], [38730, 'Eyes Open'], [69992, 'Move'], [71302, 'Move'], [72792, 'Eyes Closed'], [73326, 'Move'], [74598, 'Eyes Open'], [76992, 'Eyes Cl

In [13]:
for i in range(2):
    dataset = CauEegDataset(root_path, metadata_train, load_event=False, 
                            transform=EegRandomCrop(crop_length=3, multiple=2))
    print(dataset[0]['signal'])
    print('\n')
    print('-' * 100)
    print('\n')

[array([[-23, -23, -20],
       [-13, -10,  -7],
       [  4,   8,   8],
       [  4,   6,   5],
       [  7,   9,   7],
       [-21, -20, -18],
       [-12,  -8,  -4],
       [  6,   9,  10],
       [  7,  10,   8],
       [ 11,  11,   7],
       [  9,   7,   8],
       [  4,   6,   6],
       [ -6,  -4,  -1],
       [-21, -20, -21],
       [-19, -21, -25],
       [ 10,   9,   5],
       [-19, -15, -11],
       [ -1,   3,   5],
       [  8,  11,   8],
       [-52, -53, -49],
       [ -1,   0,  -1]]), array([[-142, -139, -148],
       [ -10,   -7,  -13],
       [  14,   12,    9],
       [  13,   13,   11],
       [  15,   15,   14],
       [ -96,  -93,  -92],
       [  -9,   -9,  -11],
       [   4,    5,    5],
       [   6,    9,    9],
       [   8,   12,   12],
       [ -31,  -20,  -20],
       [  18,   16,   11],
       [  16,   17,   16],
       [   9,   11,    9],
       [  30,   32,   30],
       [  13,   16,   18],
       [ -26,  -23,  -26],
       [  -3,    0,   -2],
       

#### Drop EKG channel

In [14]:
dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=False, transform=None)
print('before:', dataset[0]['signal'].shape)
print(dataset[0]['signal'])

print()
print('-' * 100)
print()

dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=False, transform=EegDropEKGChannel())
print('after:', dataset[0]['signal'].shape)
print(dataset[0]['signal'])

before: (21, 179800)
[[ -1  -2  -4 ... -13 -16 -15]
 [ 15  15  13 ...  -1  -6  -8]
 [ -1   0  -1 ...  -6  -8  -6]
 ...
 [ -8  -5  -5 ...   3   4   7]
 [-42 -39 -19 ...  15  15   1]
 [  0   0   0 ...   0   0   0]]

----------------------------------------------------------------------------------------------------

after: (20, 179800)
[[ -1  -2  -4 ... -13 -16 -15]
 [ 15  15  13 ...  -1  -6  -8]
 [ -1   0  -1 ...  -6  -8  -6]
 ...
 [ -8  -7  -7 ...   9   5   3]
 [ -8  -5  -5 ...   3   4   7]
 [  0   0   0 ...   0   0   0]]


#### Drop photic stimulation channel

In [15]:
dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=False, transform=None)
print('before:', dataset[0]['signal'].shape)
print(dataset[0]['signal'])

print()
print('-' * 100)
print()

dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=False, transform=EegDropPhoticChannel())
print('after:', dataset[0]['signal'].shape)
print(dataset[0]['signal'])

before: (21, 179800)
[[ -1  -2  -4 ... -13 -16 -15]
 [ 15  15  13 ...  -1  -6  -8]
 [ -1   0  -1 ...  -6  -8  -6]
 ...
 [ -8  -5  -5 ...   3   4   7]
 [-42 -39 -19 ...  15  15   1]
 [  0   0   0 ...   0   0   0]]

----------------------------------------------------------------------------------------------------

after: (20, 179800)
[[ -1  -2  -4 ... -13 -16 -15]
 [ 15  15  13 ...  -1  -6  -8]
 [ -1   0  -1 ...  -6  -8  -6]
 ...
 [ -8  -7  -7 ...   9   5   3]
 [ -8  -5  -5 ...   3   4   7]
 [-42 -39 -19 ...  15  15   1]]


#### To Tensor

In [16]:
dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=False, transform=None)
print('before:')
print(dataset[0])

print()
print('-' * 100)
print()

dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=False, transform=EegToTensor())
print('after:')
print(dataset[0])

before:
{'signal': array([[ -1,  -2,  -4, ..., -13, -16, -15],
       [ 15,  15,  13, ...,  -1,  -6,  -8],
       [ -1,   0,  -1, ...,  -6,  -8,  -6],
       ...,
       [ -8,  -5,  -5, ...,   3,   4,   7],
       [-42, -39, -19, ...,  15,  15,   1],
       [  0,   0,   0, ...,   0,   0,   0]]), 'age': 76, 'class_label': 1, 'metadata': {'serial': '00465', 'edfname': '00779511_050319', 'birth': '1942-09-08', 'record': '2019-03-05T08:58:32', 'age': 76, 'dx1': 'mci ef', 'label': ['mci', 'mci_amnestic', 'mci_amnestic_ef'], 'class_type': 'Non-vascular MCI', 'class_label': 1}}

----------------------------------------------------------------------------------------------------

after:
{'signal': tensor([[ -1.,  -2.,  -4.,  ..., -13., -16., -15.],
        [ 15.,  15.,  13.,  ...,  -1.,  -6.,  -8.],
        [ -1.,   0.,  -1.,  ...,  -6.,  -8.,  -6.],
        ...,
        [ -8.,  -5.,  -5.,  ...,   3.,   4.,   7.],
        [-42., -39., -19.,  ...,  15.,  15.,   1.],
        [  0.,   0.,   0.,  

#### Normalization per signal

In [17]:
composed = transforms.Compose([EegToTensor(), EegNormalizePerSignal()])
dataset = CauEegDataset(root_path, metadata_train, load_event=False, transform=composed)

print(dataset[0])

print()
print('-' * 100)
print()

print('Mean:', torch.mean(dataset[0]['signal'], axis=1))
print('Std:', torch.std(dataset[0]['signal'], axis=1))

{'signal': tensor([[-2.9969e-02, -6.0315e-02, -1.2101e-01,  ..., -3.9413e-01,
         -4.8517e-01, -4.5482e-01],
        [ 9.0846e-01,  9.0846e-01,  7.8735e-01,  ..., -6.0414e-02,
         -3.6319e-01, -4.8430e-01],
        [-1.3070e-01,  2.2208e-04, -1.3070e-01,  ..., -7.8528e-01,
         -1.0471e+00, -7.8528e-01],
        ...,
        [-1.2635e+00, -7.9031e-01, -7.9031e-01,  ...,  4.7165e-01,
          6.2940e-01,  1.1026e+00],
        [-3.2224e-01, -2.9923e-01, -1.4587e-01,  ...,  1.1485e-01,
          1.1485e-01,  7.4959e-03],
        [ 7.0046e-05,  7.0046e-05,  7.0046e-05,  ...,  7.0046e-05,
          7.0046e-05,  7.0046e-05]]), 'age': tensor(76.), 'class_label': tensor(1), 'metadata': {'serial': '00465', 'edfname': '00779511_050319', 'birth': '1942-09-08', 'record': '2019-03-05T08:58:32', 'age': 76, 'dx1': 'mci ef', 'label': ['mci', 'mci_amnestic', 'mci_amnestic_ef'], 'class_type': 'Non-vascular MCI', 'class_label': 1}}

---------------------------------------------------------

#### Age normalization

In [18]:
ages = []
for m in metadata_train:
    ages.append(m['age'])

ages = np.array(ages)
age_mean = np.mean(ages)
age_std = np.std(ages)

print('Age mean and standard deviation:\t', age_mean, ',\t', age_std)

print()
print('-' * 100)
print()

print('before:')
dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=False, transform=None)
for i in range(5):
    print(dataset[i]['age'])

print()
print('-' * 100)
print()

print('after:')
composed = transforms.Compose([EegToTensor(), EegNormalizeAge(mean=age_mean, std=age_std)])
dataset = CauEegDataset(root_path, metadata_train, load_event=False, transform=composed)
for i in range(5):
    print(dataset[i]['age'])

Age mean and standard deviation:	 70.16586538461539 ,	 9.861365378877553

----------------------------------------------------------------------------------------------------

before:
76
69
77
65
85

----------------------------------------------------------------------------------------------------

after:
tensor(0.5916)
tensor(-0.1182)
tensor(0.6930)
tensor(-0.5238)
tensor(1.5043)


#### Short time Fourier transform (STFT or spectrogram)

In [19]:
composed = transforms.Compose([EegToTensor(), EegSpectrogram(n_fft=200, complex_mode='as_real')])
dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=False, transform=composed)
print(dataset[0]['signal'].shape, dataset[0]['signal'].dtype, type(dataset[0]['signal']))
print(dataset[0]['signal'][:, :, 10])

print()
print('-' * 100)
print()

composed = transforms.Compose([EegToTensor(), EegSpectrogram(n_fft=200, complex_mode='power')])
dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=False, transform=composed)
print(dataset[0]['signal'].shape, dataset[0]['signal'].dtype, type(dataset[0]['signal']))
print(dataset[0]['signal'][:, :, 10])

print()
print('-' * 100)
print()

composed = transforms.Compose([EegToTensor(), EegSpectrogram(n_fft=200, complex_mode='remove')])
dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=False, transform=composed)
print(dataset[0]['signal'].shape, dataset[0]['signal'].dtype, type(dataset[0]['signal']))
print(dataset[0]['signal'][:, :, 10])

torch.Size([42, 101, 3597]) torch.float32 <class 'torch.Tensor'>
tensor([[ 1.8630e+03, -3.6824e+01, -4.5683e+01,  ...,  4.0228e-01,
         -1.9465e-01,  3.0000e+00],
        [ 5.3000e+01, -3.8482e+01,  7.6352e+01,  ...,  1.6589e+00,
          1.8816e+00, -3.0000e+00],
        [ 3.0700e+02, -6.2419e+00, -8.3077e+01,  ...,  6.0720e+00,
         -1.5512e+00, -1.0000e+00],
        ...,
        [ 0.0000e+00, -9.4149e+01, -1.6262e+02,  ..., -1.1109e+00,
          1.9103e+00,  0.0000e+00],
        [ 0.0000e+00, -3.4713e+02,  2.0332e+03,  ...,  4.0216e+00,
          1.5967e-01,  0.0000e+00],
        [ 0.0000e+00,  5.0961e+00, -2.7539e+01,  ...,  8.1982e-01,
          2.3206e+00,  0.0000e+00]])

----------------------------------------------------------------------------------------------------

torch.Size([21, 101, 3597]) torch.float32 <class 'torch.Tensor'>
tensor([[1.8630e+03, 3.9918e+01, 1.1786e+02,  ..., 4.5592e-01, 2.6543e+00,
         3.0000e+00],
        [5.3000e+01, 4.6436e+02, 2.178

#### Transform Composition

In [20]:
composed = transforms.Compose([EegDropPhoticChannel(), 
                               EegRandomCrop(crop_length=200*60), # 1 minute
                               EegToTensor(), 
                               EegNormalizeAge(mean=age_mean, std=age_std), 
                               EegNormalizePerSignal()])
train_dataset = CauEegDataset(root_path, metadata_train, load_event=False, 
                              file_format='memmap', transform=composed)
print(train_dataset[0]['signal'].shape)
print(train_dataset[0])

print()
print('-' * 100)
print()

composed_np = transforms.Compose([EegDropPhoticChannel(), 
                                  EegRandomCrop(crop_length=200*60), # 1 minute
                                  EegToTensor()])
composed_pt = transforms.Compose([EegNormalizeAge(mean=age_mean, std=age_std), 
                                  EegNormalizePerSignal()])
composed_pt = torch.nn.Sequential(*composed_pt.transforms)

train_dataset = CauEegDataset(root_path, metadata_train, 
                              load_event=False, transform=composed)

sample = composed_pt(train_dataset[0])
print(sample['signal'].shape)
print(sample)

torch.Size([20, 12000])
{'signal': tensor([[-1.1198, -1.0176, -0.5067,  ..., -0.9155, -1.2220, -1.0176],
        [-0.3820, -0.0156,  0.2287,  ..., -1.7256, -1.9699, -1.9699],
        [ 0.1712,  0.8142,  0.9749,  ..., -0.6324, -0.7932, -0.6324],
        ...,
        [ 0.1425,  0.2934,  0.2934,  ...,  0.8967,  0.8967,  0.8967],
        [-0.1751,  0.0038,  0.0038,  ...,  0.0038,  0.1828,  0.5407],
        [-0.3319, -0.4575, -0.5920,  ..., -0.4306, -0.5203, -0.6369]]), 'age': tensor(0.5916), 'class_label': tensor(1), 'metadata': {'serial': '00465', 'edfname': '00779511_050319', 'birth': '1942-09-08', 'record': '2019-03-05T08:58:32', 'age': 76, 'dx1': 'mci ef', 'label': ['mci', 'mci_amnestic', 'mci_amnestic_ef'], 'class_type': 'Non-vascular MCI', 'class_label': 1}}

----------------------------------------------------------------------------------------------------

torch.Size([20, 12000])
{'signal': tensor([[-2.0040, -2.0040, -1.5269,  ...,  0.3815,  0.7632,  0.9540],
        [-0.2464, -0.

#### Train, Validation, Test Datasets

In [21]:
composed = transforms.Compose([EegDropPhoticChannel(),
                               EegRandomCrop(crop_length=200*60), # 1 minute
                               EegToTensor()])

train_dataset = CauEegDataset(root_path, metadata_train, 
                              load_event=False, transform=composed)
val_dataset = CauEegDataset(root_path, metadata_val, 
                            load_event=False, transform=composed)
test_dataset = CauEegDataset(root_path, metadata_test, 
                             load_event=False, transform=composed)

print(train_dataset[0]['signal'][0].shape)
print(torch.mean(train_dataset[0]['signal'][0], axis=-1))
print(torch.std(train_dataset[0]['signal'][0], axis=-1))

print()
print('-' * 100)
print()

print(train_dataset[0]['signal'][1].shape)
print(torch.mean(train_dataset[0]['signal'][1], axis=-1))
print(torch.std(train_dataset[0]['signal'][1], axis=-1))

print()
print('-' * 100)
print()

print(train_dataset[0]['signal'][1].shape)
print(train_dataset[0])

print()
print('-' * 100)
print()

print(val_dataset[0])

print()
print('-' * 100)
print()

print(test_dataset[0])

torch.Size([12000])
tensor(-0.8292)
tensor(50.1309)

----------------------------------------------------------------------------------------------------

torch.Size([12000])
tensor(-0.1476)
tensor(7.8639)

----------------------------------------------------------------------------------------------------

torch.Size([12000])
{'signal': tensor([[-23., -24., -20.,  ..., -28., -17., -38.],
        [-16., -11.,  -4.,  ...,   0., -12.,  -5.],
        [  4.,   7.,  11.,  ...,   5.,   1.,   9.],
        ...,
        [-13., -13., -10.,  ...,   3.,   8.,  13.],
        [  9.,   8.,   8.,  ...,   1.,   6.,  12.],
        [570., 456., 179.,  ..., -85., -75., -82.]]), 'age': tensor(76.), 'class_label': tensor(1), 'metadata': {'serial': '00465', 'edfname': '00779511_050319', 'birth': '1942-09-08', 'record': '2019-03-05T08:58:32', 'age': 76, 'dx1': 'mci ef', 'label': ['mci', 'mci_amnestic', 'mci_amnestic_ef'], 'class_type': 'Non-vascular MCI', 'class_label': 1}}

----------------------------------

In [22]:
composed = transforms.Compose([EegNormalizeAge(mean=age_mean, std=age_std),
                               EegDropPhoticChannel(),
                               EegRandomCrop(crop_length=200*60), # 1 minutes
                               EegToTensor(),
                               EegNormalizePerSignal(),
                               EegSpectrogram(n_fft=200, complex_mode='power', hop_length=200 // 2)])
dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=False, transform=composed)
print(dataset[0]['signal'].shape, dataset[0]['signal'].dtype, type(dataset[0]['signal']))
print(dataset[0]['signal'][:, :, 10])

torch.Size([20, 101, 121]) torch.float32 <class 'torch.Tensor'>
tensor([[1.0207e+01, 6.6656e+01, 1.4424e+01,  ..., 8.1594e-01, 7.3818e-01,
         9.4948e-01],
        [4.6719e+01, 6.2834e+01, 5.0437e+01,  ..., 1.0002e+00, 7.8660e-01,
         2.6535e-01],
        [5.1688e+01, 5.4097e+01, 1.7246e+01,  ..., 3.1529e-01, 1.2848e-01,
         1.9073e-06],
        ...,
        [6.1089e+01, 3.9688e+01, 2.0640e+01,  ..., 3.5023e-01, 6.5073e-01,
         9.0234e-01],
        [2.4466e+01, 5.5986e+01, 7.2037e+01,  ..., 1.8135e-01, 8.2947e-01,
         7.2446e-01],
        [6.2215e+00, 2.0280e+01, 1.2054e+01,  ..., 1.2320e-01, 1.7671e-01,
         2.1417e-01]])


#### Data loader test

In [23]:
%%time

print('Current PyTorch device:', device)
if device.type == 'cuda':
    num_workers = 0 # A number other than 0 causes an error
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False
    
composed = transforms.Compose([
    EegDropPhoticChannel(),
    EegRandomCrop(crop_length=200*60), # 20 seconds
    EegToTensor(),
    EegNormalizeAge(mean=age_mean, std=age_std),
    EegNormalizePerSignal(),
])

train_dataset = CauEegDataset(root_path, metadata_train, load_event=False, 
                              file_format='feather', transform=composed)

train_loader = DataLoader(train_dataset, 
                          batch_size=32, # Random crop will inflate the minibatch size
                          shuffle=True, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for k in range(5):
    for i_batch, sample_batched in enumerate(train_loader):
        sample_batched['signal'].to(device)
        sample_batched['age'].to(device)
        sample_batched['class_label'].to(device)
        
# pprint.pprint(sample_batched)

Current PyTorch device: cuda
CPU times: total: 5min 34s
Wall time: 48.8 s


In [24]:
%%time

print('Current PyTorch device:', device)
if device.type == 'cuda':
    num_workers = 0 # A number other than 0 causes an error
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False
    
composed = transforms.Compose([
    EegDropPhoticChannel(),
    EegRandomCrop(crop_length=200*60), # 20 seconds
    EegToTensor(),
    EegNormalizeAge(mean=age_mean, std=age_std),
    EegNormalizePerSignal(),
])

train_dataset = CauEegDataset(root_path, metadata_train, load_event=False, 
                              file_format='memmap', transform=composed)

train_loader = DataLoader(train_dataset, 
                          batch_size=32, # Random crop will inflate the minibatch size
                          shuffle=True, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for k in range(5):
    for i_batch, sample_batched in enumerate(train_loader):
        sample_batched['signal'].to(device)
        sample_batched['age'].to(device)
        sample_batched['class_label'].to(device)
        
# pprint.pprint(sample_batched)

Current PyTorch device: cuda
CPU times: total: 1min 22s
Wall time: 22.8 s


In [25]:
%%time

print('Current PyTorch device:', device)
if device.type == 'cuda':
    num_workers = 0 # A number other than 0 causes an error
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False
    
composed_np = transforms.Compose([EegDropPhoticChannel(), 
                                  EegRandomCrop(crop_length=200*60), # 20 seconds
                                  EegToTensor()])
preprocess = transforms.Compose([EegNormalizePerSignal(), 
                                 EegNormalizeAge(mean=age_mean, std=age_std)])
preprocess = torch.nn.Sequential(*preprocess.transforms)

train_dataset = CauEegDataset(root_path, metadata_train, load_event=False, 
                              file_format='memmap', transform=composed_np)

train_loader = DataLoader(train_dataset, 
                          batch_size=32, # Random crop will inflate the minibatch size
                          shuffle=True, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for k in range(5):
    for i_batch, sample_batched in enumerate(train_loader):
        # pull up the batch data
        x = sample_batched['signal'] = sample_batched['signal'].to(device)
        age = sample_batched['age'] = sample_batched['age'].to(device)
        target = sample_batched['class_label'] = sample_batched['class_label'].to(device)
        
        preprocess(sample_batched)
        
# pprint.pprint(sample_batched)

Current PyTorch device: cuda
CPU times: total: 57.1 s
Wall time: 18.4 s


In [35]:
%%time

print('Current PyTorch device:', device)
if device.type == 'cuda':
    num_workers = 0 # A number other than 0 causes an error
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False
    
composed_np = transforms.Compose([EegDropPhoticChannel(), 
                                  EegRandomCrop(crop_length=200*60), # 20 seconds
                                  EegToTensor()])
preprocess = transforms.Compose([EegNormalizePerSignal(), 
                                 EegNormalizeAge(mean=age_mean, std=age_std), 
                                 EegToDevice(device=device)])
preprocess = torch.nn.Sequential(*preprocess.transforms)

train_dataset = CauEegDataset(root_path, metadata_train, load_event=False, 
                              file_format='memmap', transform=composed_np)

train_loader = DataLoader(train_dataset, 
                          batch_size=32, # Random crop will inflate the minibatch size
                          shuffle=True, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for k in range(5):
    for i_batch, sample_batched in enumerate(train_loader):
        preprocess(sample_batched)
        
print(sample_batched)

Current PyTorch device: cuda
{'signal': tensor([[[-1.5289e-01, -1.5289e-01, -1.6889e-01,  ...,  1.2236e+00,
           1.4156e+00,  1.5597e+00],
         [ 1.1189e+00,  1.0120e+00,  9.5855e-01,  ...,  7.4469e-01,
           7.9816e-01,  8.5162e-01],
         [ 1.6361e+00,  1.9089e+00,  1.9089e+00,  ..., -8.1911e-01,
          -9.5551e-01, -9.5551e-01],
         ...,
         [-3.4103e-02, -3.4103e-02, -2.1034e-01,  ...,  1.4214e-01,
          -3.4103e-02, -2.1034e-01],
         [ 1.3943e+00,  1.5690e+00,  1.0451e+00,  ..., -5.2646e-01,
          -1.2249e+00, -1.3996e+00],
         [ 1.6148e+00,  1.5080e+00,  1.4813e+00,  ..., -3.4402e-02,
          -6.1109e-02, -4.7755e-02]],

        [[ 7.9576e-01,  1.1035e+00,  1.5650e+00,  ..., -4.3508e-01,
          -3.0320e-01, -3.9112e-01],
         [-1.4294e+00, -1.1375e+00, -4.8064e-01,  ...,  2.4917e-01,
           3.2215e-01,  1.7619e-01],
         [-2.5128e+00, -2.5128e+00, -2.3635e+00,  ..., -2.7296e-01,
           2.5683e-02,  2.5683e-02],

In [59]:
age_means = []
age_stds = []

for sample in train_loader:
    age = sample['age']
    std, mean = torch.std_mean(age, dim=-1, keepdims=True)
    age_means.append(mean)
    age_stds.append(std)

age_mean = torch.mean(torch.cat(age_means, dim=0), dim=0, keepdims=True)
age_std = torch.mean(torch.cat(age_stds, dim=0), dim=0, keepdims=True)

print(age_mean, age_std)

tensor([70.1659]) tensor([9.7979])


In [60]:
age_means = []
age_stds = []

for sample in train_loader:
    age = sample['age']
    age = (age - age_mean) / (age_std + 1e-8)
    std, mean = torch.std_mean(age, dim=-1, keepdims=True)
    age_means.append(mean)
    age_stds.append(std)

age_mean = torch.mean(torch.cat(age_means, dim=0), dim=0, keepdims=True)
age_std = torch.mean(torch.cat(age_stds, dim=0), dim=0, keepdims=True)

print(age_mean, age_std)

tensor([2.4071e-07]) tensor([0.9983])
