# Dataset

PyTorch의 EEG 데이터를 Dataset class 및 DataLoader class로 처리해보는 노트북

-----

## 환경 구성

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [2]:
# Load some packages
import os
import glob
import json

import matplotlib.pyplot as plt
import pprint

import numpy as np
import random
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

# custom package
from datasets.cau_eeg_dataset import *
from datasets.pipeline import *

In [3]:
# Other settings
%matplotlib inline
%config InlineBackend.figure_format = 'retina' # cleaner text

plt.style.use('default') 
# ['Solarize_Light2', '_classic_test_patch', 'bmh', 'classic', 'dark_background', 'fast', 
#  'fivethirtyeight', 'ggplot', 'grayscale', 'seaborn', 'seaborn-bright', 'seaborn-colorblind', 
#  'seaborn-dark', 'seaborn-dark-palette', 'seaborn-darkgrid', 'seaborn-deep', 'seaborn-muted', 
#  'seaborn-notebook', 'seaborn-paper', 'seaborn-pastel', 'seaborn-poster', 'seaborn-talk', 
#  'seaborn-ticks', 'seaborn-white', 'seaborn-whitegrid', 'tableau-colorblind10']

plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams["font.family"] = 'NanumGothic' # for Hangul in Windows

In [4]:
print('PyTorch version:', torch.__version__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available(): print('cuda is available.')
else: print('cuda is unavailable.')

PyTorch version: 1.11.0
cuda is available.


In [5]:
# Data file path
root_path = r'local/dataset/02_Curated_Data_220317/' # 02_Curated_Data_210705

In [6]:
meta_path = os.path.join(root_path, 'metadata_debug.json')
with open(meta_path, 'r') as json_file:
    metadata = json.load(json_file)

pprint.pprint(metadata[0])

{'age': 78,
 'birth': '1940-06-02',
 'dx1': 'mci_rf',
 'edfname': '00001809_261018',
 'label': ['mci', 'mci_amnestic', 'mci_amnestic_rf'],
 'record': '2018-10-26T15:46:26',
 'serial': '00001'}


-----

## Data Filtering by Diagnosis

#### Non-Vascular Dementia, Non-Vascular MCI, Normal

In [7]:
diagnosis_filter = [
    # Normal
    {'type': 'Normal',
     'include': ['normal'], 
     'exclude': []},
    # Non-vascular MCI
    {'type': 'Non-vascular MCI',
     'include': ['mci'], 
     'exclude': ['mci_vascular']},
    # Non-vascular dementia
    {'type': 'Non-vascular dementia',
     'include': ['dementia'], 
     'exclude': ['vd']},
]

def generate_class_label(label):
    for c, f in enumerate(diagnosis_filter):
        # inc = set(f['include']) & set(label) == set(f['include'])
        inc = len(set(f['include']) & set(label)) > 0        
        exc = len(set(f['exclude']) & set(label)) == 0
        if  inc and exc:
            return (c, f['type'])
    return (-1, 'The others')

class_label_to_type = [d_f['type'] for d_f in diagnosis_filter]
print('class_label_to_type:', class_label_to_type)

class_label_to_type: ['Normal', 'Non-vascular MCI', 'Non-vascular dementia']


In [8]:
splitted_metadata = [[] for i in diagnosis_filter]

for m in metadata:
    c, n = generate_class_label(m['label'])
    if c >= 0:
        m['class_type'] = n
        m['class_label'] = c
        splitted_metadata[c].append(m)
        
for i, split in enumerate(splitted_metadata):
    if len(split) == 0:
        print(f'(Warning) Split group {i} has no data.')
    else:
        print(f'- There are {len(split):} data belonging to {split[0]["class_type"]}')

- There are 458 data belonging to Normal
- There are 350 data belonging to Non-vascular MCI
- There are 233 data belonging to Non-vascular dementia


-----

## Configure the Train, Validation, and Test Splits

#### Split the filtered dataset and shuffle them

In [9]:
# Train : Val : Test = 8 : 1 : 1
ratio1 = 0.8
ratio2 = 0.1

metadata_train = []
metadata_val = []
metadata_test = []

for split in splitted_metadata:
    random.shuffle(split)
    
    n1 = round(len(split) * ratio1)
    n2 = n1 + round(len(split) * ratio2)

    metadata_train.extend(split[:n1])
    metadata_val.extend(split[n1:n2])
    metadata_test.extend(split[n2:])

random.shuffle(metadata_train)
random.shuffle(metadata_val)
random.shuffle(metadata_test)

print('Train data size\t\t:', len(metadata_train))
print('Validation data size\t:', len(metadata_val))
print('Test data size\t\t:', len(metadata_test))

print('\n', '--- Recheck ---', '\n')
train_class_nums = np.zeros((len(class_label_to_type)), dtype=np.int32)
for m in metadata_train:
    train_class_nums[m['class_label']] += 1

val_class_nums = np.zeros((len(class_label_to_type)), dtype=np.int32)
for m in metadata_val:
    val_class_nums[m['class_label']] += 1

test_class_nums = np.zeros((len(class_label_to_type)), dtype=np.int32)
for m in metadata_test:
    test_class_nums[m['class_label']] += 1

print('Train data label distribution\t:', train_class_nums, train_class_nums.sum())
print('Val data label distribution\t:', val_class_nums, val_class_nums.sum())
print('Test data label distribution\t:', test_class_nums, test_class_nums.sum())

Train data size		: 832
Validation data size	: 104
Test data size		: 105

 --- Recheck --- 

Train data label distribution	: [366 280 186] 832
Val data label distribution	: [46 35 23] 104
Test data label distribution	: [46 35 24] 105


-----

## Test TorchVision Transform

#### Random crop

In [10]:
for i in range(2):
    dataset = CauEegDataset(root_path, metadata_train, load_event=True, 
                            transform=EegRandomCrop(crop_length=3))
    print(dataset[0])
    print('\n')
    print('-' * 100)
    print('\n')

{'signal': array([[-27, -28, -26],
       [  5,   4,   6],
       [ -8,  -8,  -8],
       [ 13,  14,  15],
       [  6,   8,  10],
       [-40, -42, -44],
       [-15, -10, -13],
       [ 18,  15,   6],
       [  6,   5,   2],
       [ 13,  13,  13],
       [ -3,  -2,   0],
       [  4,   6,  13],
       [  5,  11,  18],
       [-20, -22, -23],
       [ 17,   9,  11],
       [ 12,  12,  12],
       [-49, -51, -52],
       [  0,   0,  -2],
       [ -4,  -4,  -6],
       [ 12,  30,  17],
       [ -1,   0,   0]]), 'age': 76, 'class_label': 2, 'metadata': {'serial': '01203', 'edfname': '01312293_120417', 'birth': '1941-02-26', 'record': '2017-04-12T13:48:31', 'age': 76, 'dx1': 'load', 'label': ['dementia', 'ad', 'load'], 'class_type': 'Non-vascular dementia', 'class_label': 2, 'channel': ['Fp1-AVG', 'F3-AVG', 'C3-AVG', 'P3-AVG', 'O1-AVG', 'Fp2-AVG', 'F4-AVG', 'C4-AVG', 'P4-AVG', 'O2-AVG', 'F7-AVG', 'T3-AVG', 'T5-AVG', 'F8-AVG', 'T4-AVG', 'T6-AVG', 'FZ-AVG', 'CZ-AVG', 'PZ-AVG', 'EKG', 'Phot

In [11]:
for i in range(2):
    dataset = CauEegDataset(root_path, metadata_train, load_event=False, 
                            transform=EegRandomCrop(crop_length=3, multiple=2))
    print(dataset[0]['signal'])
    print('\n')
    print('-' * 100)
    print('\n')

[array([[  3,   7,  10],
       [ 11,  13,  13],
       [  7,   9,   8],
       [ -6,  -6,  -6],
       [-15, -16, -16],
       [ 16,  16,  17],
       [  2,   2,   4],
       [ 10,  11,  16],
       [ -7,  -6,  -5],
       [ -5,  -4,  -4],
       [ 78,  80,  79],
       [-13, -16, -17],
       [-11, -14, -15],
       [-13, -14, -15],
       [  0,  -8, -12],
       [  1,   0,  -2],
       [-19, -17, -17],
       [-13, -10,  -8],
       [ -6,  -4,  -3],
       [ 17,  42,  65],
       [  1,   1,   0]]), array([[  3,   4,   3],
       [  7,   7,   7],
       [-16, -12, -10],
       [-21, -19, -17],
       [ -5,  -6,  -4],
       [  3,   3,   2],
       [  6,   7,   6],
       [ -5,  -7,  -7],
       [ -1,  -1,  -1],
       [  5,   6,   5],
       [  9,   8,   7],
       [  4,   0,   1],
       [ -8, -11, -11],
       [ 18,  17,  15],
       [ 10,  12,   9],
       [  1,  -2,  -4],
       [-13, -11, -10],
       [ 10,  12,  14],
       [ -2,  -1,   1],
       [108,  96, 101],
       [  0, 

#### Normalization per signal

In [12]:
dataset = CauEegDataset(root_path, metadata_train, load_event=False, 
                        transform=EegNormalizePerSignal())

print(dataset[0])

print()
print('-' * 100)
print()

print('Mean:', np.mean(dataset[0]['signal'], axis=1))
print('Std:', np.std(dataset[0]['signal'], axis=1))

{'signal': array([[ 1.02104732e+00,  9.58473534e-01,  9.58473534e-01, ...,
        -4.27070916e-02, -2.18491620e-02, -4.27070916e-02],
       [ 1.78706960e+00,  1.66245676e+00,  2.20244571e+00, ...,
         7.07091698e-01,  8.73242144e-01,  7.90166921e-01],
       [ 3.34353951e+00,  3.46288107e+00,  3.22419795e+00, ...,
        -1.01242755e+00, -1.01242755e+00, -9.52756771e-01],
       ...,
       [ 2.67948738e+00,  2.84720273e+00,  3.01491807e+00, ...,
        -6.74819496e-01, -5.07104152e-01, -5.07104152e-01],
       [ 1.03733280e+00,  7.56602549e-01,  3.41610004e-01, ...,
        -8.55882048e-02,  2.42627630e-02,  9.74967416e-02],
       [ 1.04657419e-02,  1.04657419e-02,  7.36901822e-05, ...,
         7.36901822e-05,  1.56617677e-02,  1.56617677e-02]]), 'age': 76, 'class_label': 2, 'metadata': {'serial': '01203', 'edfname': '01312293_120417', 'birth': '1941-02-26', 'record': '2017-04-12T13:48:31', 'age': 76, 'dx1': 'load', 'label': ['dementia', 'ad', 'load'], 'class_type': 'Non-va

#### Age normalization

In [13]:
ages = []
for m in metadata_train:
    ages.append(m['age'])

ages = np.array(ages)
age_mean = np.mean(ages)
age_std = np.std(ages)

print('Age mean and standard deviation:\t', age_mean, ',\t', age_std)

print()
print('-' * 100)
print()

print('before:')
dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=False, transform=None)
for i in range(5):
    print(dataset[i]['age'])

print()
print('-' * 100)
print()

print('after:')
dataset = CauEegDataset(root_path, metadata_train, load_event=False, 
                        transform=EegNormalizeAge(mean=age_mean, std=age_std))
for i in range(5):
    print(dataset[i]['age'])

Age mean and standard deviation:	 70.12860576923077 ,	 9.986242891827843

----------------------------------------------------------------------------------------------------

before:
76
83
84
87
74

----------------------------------------------------------------------------------------------------

after:
0.5879482692829902
1.288912592784349
1.389050353284543
1.6894636347851253
0.387672748282602


#### Drop EKG channel

In [14]:
dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=False, transform=None)
print('before:', dataset[0]['signal'].shape)
print(dataset[0]['signal'])

print()
print('-' * 100)
print()

dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=False, transform=EegDropEKGChannel())
print('after:', dataset[0]['signal'].shape)
print(dataset[0]['signal'])

before: (21, 173600)
[[ 49  46  46 ...  -2  -1  -2]
 [ 43  40  53 ...  17  21  19]
 [ 56  58  54 ... -17 -17 -16]
 ...
 [ 16  17  18 ...  -4  -3  -3]
 [ 85  62  28 ...  -7   2   8]
 [  2   2   0 ...   0   3   3]]

----------------------------------------------------------------------------------------------------

after: (20, 173600)
[[ 49  46  46 ...  -2  -1  -2]
 [ 43  40  53 ...  17  21  19]
 [ 56  58  54 ... -17 -17 -16]
 ...
 [ 24  25  27 ... -13 -12 -10]
 [ 16  17  18 ...  -4  -3  -3]
 [  2   2   0 ...   0   3   3]]


#### Drop photic stimulation channel

In [15]:
dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=False, transform=None)
print('before:', dataset[0]['signal'].shape)
print(dataset[0]['signal'])

print()
print('-' * 100)
print()

dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=False, transform=EegDropPhoticChannel())
print('after:', dataset[0]['signal'].shape)
print(dataset[0]['signal'])

before: (21, 173600)
[[ 49  46  46 ...  -2  -1  -2]
 [ 43  40  53 ...  17  21  19]
 [ 56  58  54 ... -17 -17 -16]
 ...
 [ 16  17  18 ...  -4  -3  -3]
 [ 85  62  28 ...  -7   2   8]
 [  2   2   0 ...   0   3   3]]

----------------------------------------------------------------------------------------------------

after: (20, 173600)
[[ 49  46  46 ...  -2  -1  -2]
 [ 43  40  53 ...  17  21  19]
 [ 56  58  54 ... -17 -17 -16]
 ...
 [ 24  25  27 ... -13 -12 -10]
 [ 16  17  18 ...  -4  -3  -3]
 [ 85  62  28 ...  -7   2   8]]


#### To Tensor

In [16]:
dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=False, transform=None)
print('before:')
print(dataset[0])

print()
print('-' * 100)
print()

dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=False, transform=EegToTensor())
print('after:')
print(dataset[0])

before:
{'signal': array([[ 49,  46,  46, ...,  -2,  -1,  -2],
       [ 43,  40,  53, ...,  17,  21,  19],
       [ 56,  58,  54, ..., -17, -17, -16],
       ...,
       [ 16,  17,  18, ...,  -4,  -3,  -3],
       [ 85,  62,  28, ...,  -7,   2,   8],
       [  2,   2,   0, ...,   0,   3,   3]]), 'age': 76, 'class_label': 2, 'metadata': {'serial': '01203', 'edfname': '01312293_120417', 'birth': '1941-02-26', 'record': '2017-04-12T13:48:31', 'age': 76, 'dx1': 'load', 'label': ['dementia', 'ad', 'load'], 'class_type': 'Non-vascular dementia', 'class_label': 2, 'channel': ['Fp1-AVG', 'F3-AVG', 'C3-AVG', 'P3-AVG', 'O1-AVG', 'Fp2-AVG', 'F4-AVG', 'C4-AVG', 'P4-AVG', 'O2-AVG', 'F7-AVG', 'T3-AVG', 'T5-AVG', 'F8-AVG', 'T4-AVG', 'T6-AVG', 'FZ-AVG', 'CZ-AVG', 'PZ-AVG', 'EKG', 'Photic']}}

----------------------------------------------------------------------------------------------------

after:
{'signal': tensor([[ 49.,  46.,  46.,  ...,  -2.,  -1.,  -2.],
        [ 43.,  40.,  53.,  ...,  17.,  

#### Short time Fourier transform (STFT or spectrogram)

In [17]:
composed = transforms.Compose([EegToTensor(), EegSpectrogram(n_fft=200, complex_mode='as_real')])
dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=False, transform=composed)
print(dataset[0]['signal'].shape, dataset[0]['signal'].dtype, type(dataset[0]['signal']))
print(dataset[0]['signal'][:, :, 10])

print()
print('-' * 100)
print()

composed = transforms.Compose([EegToTensor(), EegSpectrogram(n_fft=200, complex_mode='power')])
dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=False, transform=composed)
print(dataset[0]['signal'].shape, dataset[0]['signal'].dtype, type(dataset[0]['signal']))
print(dataset[0]['signal'][:, :, 10])

print()
print('-' * 100)
print()

composed = transforms.Compose([EegToTensor(), EegSpectrogram(n_fft=200, complex_mode='remove')])
dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=False, transform=composed)
print(dataset[0]['signal'].shape, dataset[0]['signal'].dtype, type(dataset[0]['signal']))
print(dataset[0]['signal'][:, :, 10])

torch.Size([42, 101, 3473]) torch.float32 <class 'torch.Tensor'>
tensor([[-8.8490e+03,  7.7273e+02,  3.9726e+01,  ..., -1.8506e+01,
         -1.6672e+01, -2.3000e+01],
        [-7.4850e+03,  2.8430e+02, -6.1568e+01,  ..., -5.0885e+00,
         -1.1210e+00,  5.0000e+00],
        [ 9.1100e+02, -4.7487e+01, -1.7347e+02,  ...,  5.7645e+00,
         -2.9938e+00,  3.0000e+00],
        ...,
        [ 0.0000e+00, -3.5455e+02, -2.7713e+02,  ...,  9.7990e-01,
         -6.4330e+00,  0.0000e+00],
        [ 0.0000e+00,  1.3124e+03, -8.3939e+02,  ..., -4.4814e+00,
         -1.3693e+00,  0.0000e+00],
        [ 0.0000e+00,  9.8977e+00, -1.1967e+01,  ...,  2.9848e+00,
          5.0672e-01,  0.0000e+00]])

----------------------------------------------------------------------------------------------------

torch.Size([21, 101, 3473]) torch.float32 <class 'torch.Tensor'>
tensor([[8.8490e+03, 1.3565e+03, 4.9238e+02,  ..., 1.8506e+01, 1.6763e+01,
         2.3000e+01],
        [7.4850e+03, 3.5771e+02, 1.783

#### Compose some at once

In [18]:
composed = transforms.Compose([EegNormalizeAge(mean=age_mean, std=age_std),
                               EegDropPhoticChannel(),
                               EegRandomCrop(crop_length=200*60), # 1 minute
                               EegNormalizePerSignal(),
                               EegToTensor()])

train_dataset = CauEegDataset(root_path, metadata_train, 
                              load_event=False, transform=composed)
val_dataset = CauEegDataset(root_path, metadata_val, 
                            load_event=False, transform=composed)
test_dataset = CauEegDataset(root_path, metadata_test, 
                             load_event=False, transform=composed)

print(train_dataset[0]['signal'].shape)
print(train_dataset[0])

print()
print('-' * 100)
print()

print(val_dataset[0]['signal'].shape)
print(val_dataset[0])

print()
print('-' * 100)
print()

print(test_dataset[0]['signal'].shape)
print(test_dataset[0])

torch.Size([20, 12000])
{'signal': tensor([[-1.1255, -1.1464, -1.1883,  ..., -0.4138, -0.4766, -0.4347],
        [-2.0817, -1.8845, -2.4270,  ..., -0.3556, -0.6022, -0.6515],
        [ 0.5920,  0.5278,  0.5920,  ...,  0.3351,  0.2708,  0.2066],
        ...,
        [ 0.4626,  0.4626,  0.3357,  ...,  1.3505,  1.3505,  1.3505],
        [ 2.3842,  2.5376,  2.3842,  ...,  0.8498,  0.8498,  0.6964],
        [ 0.4344,  0.5190,  0.6519,  ..., -0.6286, -0.7252, -0.5924]]), 'age': tensor(0.5879), 'class_label': tensor(2), 'metadata': {'serial': '01203', 'edfname': '01312293_120417', 'birth': '1941-02-26', 'record': '2017-04-12T13:48:31', 'age': 76, 'dx1': 'load', 'label': ['dementia', 'ad', 'load'], 'class_type': 'Non-vascular dementia', 'class_label': 2, 'channel': ['Fp1-AVG', 'F3-AVG', 'C3-AVG', 'P3-AVG', 'O1-AVG', 'Fp2-AVG', 'F4-AVG', 'C4-AVG', 'P4-AVG', 'O2-AVG', 'F7-AVG', 'T3-AVG', 'T5-AVG', 'F8-AVG', 'T4-AVG', 'T6-AVG', 'FZ-AVG', 'CZ-AVG', 'PZ-AVG', 'EKG', 'Photic']}}

-------------------

In [19]:
composed = transforms.Compose([EegNormalizeAge(mean=age_mean, std=age_std),
                               EegDropPhoticChannel(),
                               EegRandomCrop(crop_length=200*20, multiple=2), # 20 seconds
                               EegNormalizePerSignal(),
                               EegToTensor()
                              ])

train_dataset = CauEegDataset(root_path, metadata_train, 
                              load_event=False, transform=composed)
val_dataset = CauEegDataset(root_path, metadata_val, 
                            load_event=False, transform=composed)
test_dataset = CauEegDataset(root_path, metadata_test, 
                             load_event=False, transform=composed)

print(train_dataset[0]['signal'][0].shape)
print(torch.mean(train_dataset[0]['signal'][0], axis=-1))
print(torch.std(train_dataset[0]['signal'][0], axis=-1))

print()
print('-' * 100)
print()

print(train_dataset[0]['signal'][1].shape)
print(torch.mean(train_dataset[0]['signal'][1], axis=-1))
print(torch.std(train_dataset[0]['signal'][1], axis=-1))

print()
print('-' * 100)
print()

print(train_dataset[0]['signal'][1].shape)
print(train_dataset[0])

print()
print('-' * 100)
print()

print(val_dataset[0])

print()
print('-' * 100)
print()

print(test_dataset[0])

torch.Size([20, 4000])
tensor([-5.4836e-09,  2.4915e-08,  2.3842e-09,  1.5259e-08, -7.1526e-09,
         4.6492e-09, -1.0729e-08, -7.9274e-09, -5.2452e-09,  4.5300e-09,
         5.3048e-09,  2.3842e-09, -2.3842e-10, -1.9073e-09, -1.4305e-09,
        -3.3379e-09, -4.7684e-09,  1.0371e-08, -9.0599e-09,  7.1526e-10])
tensor([1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001,
        1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001,
        1.0001, 1.0001])

----------------------------------------------------------------------------------------------------

torch.Size([20, 4000])
tensor([-1.6689e-09,  1.4305e-09, -9.0599e-09, -1.2040e-08, -1.1086e-08,
        -3.5167e-09,  2.6226e-09, -9.0599e-09, -4.7684e-09,  9.5367e-10,
         4.0531e-09,  4.5300e-09,  1.1921e-08, -8.8215e-09,  7.8678e-09,
        -4.7684e-09,  8.5831e-09, -3.2187e-09, -1.1921e-09, -1.1921e-09])
tensor([1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001,


In [20]:
composed = transforms.Compose([EegNormalizeAge(mean=age_mean, std=age_std),
                               EegDropPhoticChannel(),
                               EegRandomCrop(crop_length=200*60), # 1 minutes
                               EegNormalizePerSignal(),
                               EegToTensor(),
                               EegSpectrogram(n_fft=200, complex_mode='power', hop_length=200 // 2)])
dataset = CauEegDataset(root_path, metadata_train, 
                        load_event=False, transform=composed)
print(dataset[0]['signal'].shape, dataset[0]['signal'].dtype, type(dataset[0]['signal']))
print(dataset[0]['signal'][:, :, 10])

torch.Size([20, 101, 121]) torch.float32 <class 'torch.Tensor'>
tensor([[6.4232e+01, 2.0913e+01, 7.2749e+00,  ..., 3.5906e-01, 3.8525e-01,
         4.0100e-01],
        [3.7788e+02, 6.1324e+01, 1.9016e+01,  ..., 9.4126e-01, 7.6489e-01,
         7.4011e-01],
        [1.4823e+02, 2.5660e+01, 1.2974e+01,  ..., 8.1064e-01, 5.4665e-01,
         6.3792e-01],
        ...,
        [2.0942e+02, 1.9779e+01, 6.6603e+00,  ..., 5.0642e-01, 6.9401e-01,
         5.6524e-01],
        [2.1950e+01, 2.6849e+01, 1.7988e+01,  ..., 8.0672e-01, 9.7667e-01,
         2.0130e+00],
        [1.1486e+01, 2.4720e+01, 7.5637e+00,  ..., 3.3016e-01, 3.7622e-01,
         4.2370e-01]])


#### Data loader test

In [41]:
%%time

print('Current PyTorch device:', device)
if device.type == 'cuda':
    num_workers = 0 # A number other than 0 causes an error
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False
    
composed = transforms.Compose([
    EegRandomCrop(crop_length=200*20, multiple=2), # 20 seconds
    EegNormalizeAge(mean=age_mean, std=age_std),
    EegDropPhoticChannel(),
    EegNormalizePerSignal(),
    EegToTensor()
])

train_dataset = CauEegDataset(root_path, metadata_train, 
                              load_event=False, transform=composed)
val_dataset = CauEegDataset(root_path, metadata_val, 
                            load_event=False, transform=composed)
test_dataset = CauEegDataset(root_path, metadata_test, 
                             load_event=False, transform=composed)

train_loader = DataLoader(train_dataset, 
                          batch_size=32, # Random crop will inflate the minibatch size
                          shuffle=True, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for k in range(5):
    for i_batch, sample_batched in enumerate(train_loader):
        sample_batched['signal'].to(device)
        sample_batched['age'].to(device)
        sample_batched['class_label'].to(device)
        
# pprint.pprint(sample_batched)

Current PyTorch device: cuda
CPU times: total: 4min 45s
Wall time: 38.7 s


In [40]:
%%time

print('Current PyTorch device:', device)
if device.type == 'cuda':
    num_workers = 0 # A number other than 0 causes an error
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False
    
composed = transforms.Compose([
    EegRandomCrop(crop_length=200*20, multiple=2), # 20 seconds
    EegNormalizeAge(mean=age_mean, std=age_std),
    EegDropPhoticChannel(),
    EegNormalizePerSignal(),
    EegToTensor()
])

train_dataset = CauEegDataset(root_path, metadata_train, 
                              load_event=True, transform=composed)
val_dataset = CauEegDataset(root_path, metadata_val, 
                            load_event=True, transform=composed)
test_dataset = CauEegDataset(root_path, metadata_test, 
                             load_event=True, transform=composed)

train_loader = DataLoader(train_dataset, 
                          batch_size=32, # Random crop will inflate the minibatch size
                          shuffle=True, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for k in range(5):
    for i_batch, sample_batched in enumerate(train_loader):
        sample_batched['signal'].to(device)
        sample_batched['age'].to(device)
        sample_batched['class_label'].to(device)
        
# pprint.pprint(sample_batched)

Current PyTorch device: cuda
CPU times: total: 5min 2s
Wall time: 41.5 s


#### Train, validation, test dataloaders

In [23]:
train_loader = DataLoader(train_dataset, 
                          batch_size=32, 
                          shuffle=True, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

val_loader = DataLoader(val_dataset, 
                        batch_size=32, 
                        shuffle=False, 
                        drop_last=False,
                        num_workers=num_workers, 
                        pin_memory=pin_memory,
                        collate_fn=eeg_collate_fn)

test_loader = DataLoader(test_dataset, 
                         batch_size=32, 
                         shuffle=False, 
                         drop_last=False,
                         num_workers=num_workers, 
                         pin_memory=pin_memory,
                         collate_fn=eeg_collate_fn)

In [24]:
for batch_i, sample_batched in enumerate(train_loader):
    # pull up the batch data
    x = sample_batched['signal'].to(device)
    age = sample_batched['age'].to(device)
    target = sample_batched['class_label'].to(device)
    
    print(x)
    print(age)
    print(target)
    
    break

tensor([[[ 2.5979e+00,  2.7526e+00,  2.8651e+00,  ...,  1.9264e-01,
           1.9264e-01,  1.6451e-01],
         [ 1.8266e+00,  1.9315e+00,  2.1762e+00,  ...,  2.1836e-01,
           1.8340e-01,  1.1348e-01],
         [-3.1936e-01, -1.3874e-01,  4.0313e-01,  ..., -1.0419e+00,
          -1.0419e+00, -9.5154e-01],
         ...,
         [-1.7913e+00, -1.7913e+00, -1.6425e+00,  ..., -1.0472e+00,
          -8.9843e-01, -7.4962e-01],
         [-2.1856e+00, -2.2934e+00, -2.2934e+00,  ..., -1.5386e+00,
          -1.5386e+00, -1.4307e+00],
         [-3.3934e-01, -3.3934e-01, -6.7292e-01,  ...,  1.3286e+00,
          -3.3934e-01, -1.3401e+00]],

        [[-1.2820e-01, -1.4334e-01, -1.7362e-01,  ..., -1.3546e+00,
          -1.2941e+00, -1.2638e+00],
         [ 2.9552e-01,  2.2524e-01,  1.9010e-01,  ..., -3.3703e-01,
          -4.7759e-01, -4.7759e-01],
         [ 5.0425e-01,  3.4115e-01,  3.4115e-01,  ...,  5.8580e-01,
           5.8580e-01,  6.6734e-01],
         ...,
         [ 1.6797e+00,  1