# Dataset

EEG 데이터를 PyTorch의 Dataset class와 DataLoader class로 처리해 보는 노트북입니다.

-----

## 환경 구성

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [2]:
# Load some packages
import os
import glob
import json

import matplotlib.pyplot as plt
import pprint

import numpy as np
import random
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

# custom package
from utils.eeg_dataset import *

In [3]:
# Other settings
# Show matplotlib figures inline in the notebook and request the 'retina' figure format.
%matplotlib inline
%config InlineBackend.figure_format = 'retina' # cleaner text

# Use matplotlib's 'default' style sheet. The commented list below records other
# style-sheet names that can be passed to plt.style.use() instead.
plt.style.use('default') 
# ['Solarize_Light2', '_classic_test_patch', 'bmh', 'classic', 'dark_background', 'fast', 
#  'fivethirtyeight', 'ggplot', 'grayscale', 'seaborn', 'seaborn-bright', 'seaborn-colorblind', 
#  'seaborn-dark', 'seaborn-dark-palette', 'seaborn-darkgrid', 'seaborn-deep', 'seaborn-muted', 
#  'seaborn-notebook', 'seaborn-paper', 'seaborn-pastel', 'seaborn-poster', 'seaborn-talk', 
#  'seaborn-ticks', 'seaborn-white', 'seaborn-whitegrid', 'tableau-colorblind10']

# Global rcParams: image interpolation default and the font family used for text.
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams["font.family"] = 'NanumGothic' # for Hangul in Windows

In [4]:
# Report the PyTorch build and pick the compute device (GPU when CUDA is usable, CPU otherwise).
print('PyTorch version:', torch.__version__)

cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')

print('cuda is available.' if cuda_available else 'cuda is unavailable.')

PyTorch version: 1.10.1+cu113
cuda is available.


In [5]:
# Data file path
# Relative path to the dataset directory; later cells pass it to os.path.join and EEGDataset.
root_path = r'dataset/02_Curated_Data/'

In [6]:
# Load the per-recording metadata (a JSON list with one dict per recording).
meta_path = os.path.join(root_path, 'metadata_debug.json')

# JSON text is UTF-8 by specification. Naming the encoding keeps the decoding from
# depending on the platform's locale default (e.g. a legacy code page on Windows).
with open(meta_path, 'r', encoding='utf-8') as json_file:
    metadata = json.load(json_file)

# Show the first record as a quick sanity check of the schema.
pprint.pprint(metadata[0])

{'age': 78,
 'birth': '1940-06-02',
 'dx1': 'mci_rf',
 'edfname': '00001809_261018',
 'events': [[0, 'Start Recording'],
            [0, 'New Montage - Montage 002'],
            [36396, 'Eyes Open'],
            [72518, 'Eyes Closed'],
            [73862, 'Eyes Open'],
            [75248, 'Eyes Closed'],
            [76728, 'swallowing'],
            [77978, 'Eyes Open'],
            [79406, 'Eyes Closed'],
            [79996, 'Photic On - 3.0 Hz'],
            [80288, 'Eyes Open'],
            [81296, 'Eyes Closed'],
            [82054, 'Photic Off'],
            [84070, 'Photic On - 6.0 Hz'],
            [84488, 'Eyes Open'],
            [85538, 'Eyes Closed'],
            [86086, 'Photic Off'],
            [88144, 'Photic On - 9.0 Hz'],
            [90160, 'Photic Off'],
            [91458, 'Eyes Open'],
            [92218, 'Photic On - 12.0 Hz'],
            [92762, 'Eyes Closed'],
            [94198, 'Photic Off'],
            [94742, 'Eyes Open'],
            [95708, 'Eyes Close

-----

## Data Filtering by Diagnosis

#### Non-Vascular Dementia, Non-Vascular MCI, Normal

In [7]:
# Ordered class definitions. A recording belongs to the first entry for which at least
# one 'include' tag is present in its label and no 'exclude' tag is present.
# The position of an entry in this list is its integer class label.
diagnosis_filter = [
    # Normal
    {'type': 'Normal', 'include': ['normal'], 'exclude': []},
    # Non-vascular MCI
    {'type': 'Non-vascular MCI', 'include': ['mci'], 'exclude': ['mci_vascular']},
    # Non-vascular dementia
    {'type': 'Non-vascular dementia', 'include': ['dementia'], 'exclude': ['vd']},
]


def generate_class_label(label):
    """Map a collection of diagnosis tags to a class.

    Args:
        label: iterable of diagnosis tags attached to one recording.

    Returns:
        A tuple ``(class_index, class_name)`` taken from the first matching entry of
        ``diagnosis_filter``, or ``(-1, 'The others')`` when no entry matches.
    """
    tags = set(label)
    for class_index, rule in enumerate(diagnosis_filter):
        has_included_tag = not tags.isdisjoint(rule['include'])
        has_excluded_tag = not tags.isdisjoint(rule['exclude'])
        if has_included_tag and not has_excluded_tag:
            return (class_index, rule['type'])
    return (-1, 'The others')


# Lookup table: class index -> human-readable class name.
class_label_to_type = [rule['type'] for rule in diagnosis_filter]
print('class_label_to_type:', class_label_to_type)

class_label_to_type: ['Normal', 'Non-vascular MCI', 'Non-vascular dementia']


In [8]:
# One bucket per class defined in diagnosis_filter, indexed by class label.
splitted_metadata = [[] for _unused in diagnosis_filter]

# Tag every recording that falls into one of the classes and drop it into its bucket.
# Recordings mapped to class -1 ("The others") are skipped.
for m in metadata:
    class_label, class_type = generate_class_label(m['label'])
    if class_label < 0:
        continue
    m.update(class_type=class_type, class_label=class_label)
    splitted_metadata[class_label].append(m)

# Report the size of each bucket and warn about empty ones.
for group_index, group in enumerate(splitted_metadata):
    if group:
        print(f'- There are {len(group)} data belonging to {group[0]["class_type"]}')
    else:
        print(f'(Warning) Split group {group_index} has no data.')

- There are 463 data belonging to Normal
- There are 347 data belonging to Non-vascular MCI
- There are 229 data belonging to Non-vascular dementia


-----

## Configure the Train, Validation, and Test Splits

#### Split the filtered dataset and shuffle them

In [9]:
# Train : Val : Test = 8 : 1 : 1
ratio1 = 0.8
ratio2 = 0.1

# A private, seeded generator makes the split reproducible across kernel restarts
# without touching the global `random` state used elsewhere.
split_seed = 0
split_rng = random.Random(split_seed)

metadata_train = []
metadata_val = []
metadata_test = []

# Stratified split: each class bucket is shuffled and cut with the same ratios.
for split in splitted_metadata:
    # Shuffle a copy so that re-running this cell starts from the same ordering
    # and therefore reproduces the same split.
    shuffled = list(split)
    split_rng.shuffle(shuffled)

    n1 = round(len(shuffled) * ratio1)
    n2 = n1 + round(len(shuffled) * ratio2)

    metadata_train.extend(shuffled[:n1])
    metadata_val.extend(shuffled[n1:n2])
    metadata_test.extend(shuffled[n2:])  # the remainder goes to the test set

# Mix the classes inside each subset.
for subset in (metadata_train, metadata_val, metadata_test):
    split_rng.shuffle(subset)

print('Train data size\t\t:', len(metadata_train))
print('Validation data size\t:', len(metadata_val))
print('Test data size\t\t:', len(metadata_test))


def count_class_labels(subset):
    """Return an int32 array with the number of records of each class in `subset`."""
    counts = np.zeros((len(class_label_to_type)), dtype=np.int32)
    for m in subset:
        counts[m['class_label']] += 1
    return counts


print('\n', '--- Recheck ---', '\n')
train_class_nums = count_class_labels(metadata_train)
val_class_nums = count_class_labels(metadata_val)
test_class_nums = count_class_labels(metadata_test)

print('Train data label distribution\t:', train_class_nums, train_class_nums.sum())
print('Val data label distribution\t:', val_class_nums, val_class_nums.sum())
print('Test data label distribution\t:', test_class_nums, test_class_nums.sum())

Train data size		: 831
Validation data size	: 104
Test data size		: 104

 --- Recheck --- 

Train data label distribution	: [370 278 183] 831
Val data label distribution	: [46 35 23] 104
Test data label distribution	: [47 34 23] 104


-----

## Test TorchVision Transform

#### Random crop

In [10]:
# Build the dataset twice and print the first item's signal each time, so the two
# printed crops can be compared.
for trial in range(2):
    dataset = EEGDataset(root_path, metadata_train, EEGRandomCrop(crop_length=3))
    print(dataset[0]['signal'])
    # Separator: blank lines, a 100-dash rule, blank lines.
    print('\n', '-' * 100, '\n', sep='\n')

[[-44. -41. -34.]
 [-10. -11. -11.]
 [ 17.   6. -14.]
 [ 36.  34.  14.]
 [ 21.  34.  35.]
 [  4.   9.  26.]
 [-24. -28. -18.]
 [ 13.  -9. -24.]
 [ 21.  19.  15.]
 [  8.  18.  24.]
 [-21. -10.   3.]
 [-20.  -7.   5.]
 [  3.  20.  28.]
 [-28. -21.  -2.]
 [-30. -20.   0.]
 [-16.  -4.  12.]
 [-27. -37. -34.]
 [ 26.   0. -28.]
 [ 29.  17.  -3.]
 [ 45.  40.  32.]
 [724. 652. 429.]]


----------------------------------------------------------------------------------------------------


[[  8.   4.   3.]
 [  2.   1.   1.]
 [  3.   7.  10.]
 [  4.   9.  11.]
 [ -7.  -6.  -5.]
 [ -5.  -9. -13.]
 [  2.   0.  -1.]
 [ -3.   5.  12.]
 [ -8.  -4.  -3.]
 [ -3.  -5.  -6.]
 [  1.  -3.  -6.]
 [  8.   3.  -3.]
 [-15. -20. -21.]
 [ 13.   4.  -1.]
 [  0. -10. -19.]
 [  4.  -5. -13.]
 [  0.   7.  12.]
 [ -1.  11.  18.]
 [ -1.   8.  12.]
 [102.  90.  81.]
 [ -1.   0.  -1.]]


----------------------------------------------------------------------------------------------------




In [11]:
# Same check as above with multiple=2: the printed 'signal' entry is a list of crops.
separator = '\n'.join(['\n', '-' * 100, '\n'])
for trial in range(2):
    crop = EEGRandomCrop(crop_length=3, multiple=2)
    dataset = EEGDataset(root_path, metadata_train, crop)
    sample = dataset[0]
    print(sample['signal'])
    print(separator)

[array([[ 80.,  87., 104.],
       [ 38.,  40.,  53.],
       [  2.,   1.,  -6.],
       [ -2.,  -4.,  -1.],
       [ 33.,  35.,  36.],
       [162., 203., 277.],
       [-25., -19.,   3.],
       [ 20.,  17.,   7.],
       [ -3.,  -2.,  -3.],
       [ -8.,  -6.,  -9.],
       [-20., -18., -18.],
       [-26., -24., -28.],
       [-37., -35., -35.],
       [-14., -11.,  -4.],
       [-15., -19., -23.],
       [ -6.,  -2.,  -2.],
       [ 43.,  45.,  37.],
       [ 19.,  11.,   4.],
       [  1.,  -9., -12.],
       [ 20.,  18.,  22.],
       [  2.,   2.,   0.]], dtype=float32), array([[ 21.,  21.,  17.],
       [ 16.,  14.,  10.],
       [  2.,  -4.,  -5.],
       [ -1.,  -3.,  -2.],
       [-10., -10.,  -8.],
       [ 16.,  21.,  21.],
       [  2.,   3.,   2.],
       [  2.,   4.,   3.],
       [  2.,   4.,   6.],
       [ -6.,  -4.,  -2.],
       [  1.,   0.,  -4.],
       [  8.,   4.,   2.],
       [ 11.,  11.,  11.],
       [-14.,  -8.,  -4.],
       [ -5.,  -1.,   3.],
       [ -

#### Normalization per signal

In [12]:
# Apply per-signal normalization and print the first sample.
dataset = EEGDataset(root_path, metadata_train, EEGNormalizePerSignal())
print(dataset[0])

print('', '-' * 100, '', sep='\n')

# Statistics along axis 1 of the normalized signal.
for caption, statistic in (('Mean:', np.mean), ('Std:', np.std)):
    print(caption, statistic(dataset[0]['signal'], axis=1))

{'signal': array([[ 8.8514559e-02,  6.4116552e-02,  4.4598151e-02, ...,
         9.8273762e-02,  5.4357354e-02,  2.9959347e-02],
       [ 2.6334402e-01,  4.3795574e-01,  3.2154793e-01, ...,
        -5.8061254e-01, -6.3881642e-01, -6.0971451e-01],
       [ 1.2442030e+00,  1.3330292e+00,  1.3330292e+00, ...,
        -7.9879951e-01, -7.0997328e-01, -3.5466847e-01],
       ...,
       [ 4.0791610e-01,  7.3437029e-01,  9.7921097e-01, ...,
         8.9759737e-01,  1.1424381e+00,  1.1424381e+00],
       [-1.6785289e-01, -1.6109554e-02, -7.2015002e-02, ...,
        -1.4389342e-01, -1.5187991e-01, -1.6785289e-01],
       [ 8.5052430e-05,  5.2320277e-03,  1.0379002e-02, ...,
         8.5052430e-05,  8.5052430e-05, -5.0619226e-03]], dtype=float32), 'age': 59, 'class_label': 2, 'metadata': {'serial': '00130', 'edfname': '00430821_031116', 'birth': '1957-08-25', 'record': '2016-11-03T09:47:13', 'age': 59, 'dx1': 'eoad', 'label': ['dementia', 'ad', 'eoad'], 'events': [[0, 'Start Recording'], [0, 'Ne

#### Age normalization

In [13]:
# Age statistics are computed from the training split only.
ages = np.array([m['age'] for m in metadata_train])
age_mean, age_std = np.mean(ages), np.std(ages)

print('Age mean and standard deviation:\t', age_mean, ',\t', age_std)

rule = '-' * 100
print(f'\n{rule}\n')

# Ages of the first five samples as stored in the dataset.
print('before:')
dataset = EEGDataset(root_path, metadata_train, None)
for index in range(5):
    sample = dataset[index]
    print(sample['age'])

print(f'\n{rule}\n')

# The same five samples after EEGNormalizeAge with the training-set statistics.
print('after:')
dataset = EEGDataset(root_path, metadata_train, EEGNormalizeAge(mean=age_mean, std=age_std))
for index in range(5):
    sample = dataset[index]
    print(sample['age'])

Age mean and standard deviation:	 70.26594464500602 ,	 9.855772571874425

----------------------------------------------------------------------------------------------------

before:
59
64
62
76
88

----------------------------------------------------------------------------------------------------

after:
-1.1430808240974444
-0.6357639234218535
-0.8386906836920899
0.5817966381995646
1.7993571998209827


#### Drop EKG channel

In [14]:
# Compare the first sample's signal without any transform and with EEGDropEKGChannel.
divider = '-' * 100

dataset = EEGDataset(root_path, metadata_train, None)
print(f"before: {dataset[0]['signal'].shape}")
print(dataset[0]['signal'])

print(f'\n{divider}\n')

dataset = EEGDataset(root_path, metadata_train, EEGDropEKGChannel())
print(f"after: {dataset[0]['signal'].shape}")
print(dataset[0]['signal'])

before: (21, 173800)
[[ 18.  13.   9. ...  20.  11.   6.]
 [  9.  15.  11. ... -20. -22. -21.]
 [ 14.  15.  15. ...  -9.  -8.  -4.]
 ...
 [  5.   9.  12. ...  11.  14.  14.]
 [-21.  -2.  -9. ... -18. -19. -21.]
 [  0.   1.   2. ...   0.   0.  -1.]]

----------------------------------------------------------------------------------------------------

after: (20, 173800)
[[ 18.  13.   9. ...  20.  11.   6.]
 [  9.  15.  11. ... -20. -22. -21.]
 [ 14.  15.  15. ...  -9.  -8.  -4.]
 ...
 [  2.   0.   1. ...   0.   1.   4.]
 [  5.   9.  12. ...  11.  14.  14.]
 [  0.   1.   2. ...   0.   0.  -1.]]


#### Drop photic stimulation channel

In [15]:
# Compare the first sample's signal without any transform and with EEGDropPhoticChannel.
# The transform is instantiated lazily so it is created only when its stage starts.
stages = (('before:', None), ('after:', EEGDropPhoticChannel))
for stage_index, (tag, make_transform) in enumerate(stages):
    if stage_index > 0:
        print()
        print('-' * 100)
        print()
    transform = None if make_transform is None else make_transform()
    dataset = EEGDataset(root_path, metadata_train, transform)
    print(tag, dataset[0]['signal'].shape)
    print(dataset[0]['signal'])

before: (21, 173800)
[[ 18.  13.   9. ...  20.  11.   6.]
 [  9.  15.  11. ... -20. -22. -21.]
 [ 14.  15.  15. ...  -9.  -8.  -4.]
 ...
 [  5.   9.  12. ...  11.  14.  14.]
 [-21.  -2.  -9. ... -18. -19. -21.]
 [  0.   1.   2. ...   0.   0.  -1.]]

----------------------------------------------------------------------------------------------------

after: (20, 173800)
[[ 18.  13.   9. ...  20.  11.   6.]
 [  9.  15.  11. ... -20. -22. -21.]
 [ 14.  15.  15. ...  -9.  -8.  -4.]
 ...
 [  2.   0.   1. ...   0.   1.   4.]
 [  5.   9.  12. ...  11.  14.  14.]
 [-21.  -2.  -9. ... -18. -19. -21.]]


#### To Tensor

In [16]:
# Print the first sample as stored, then again after EEGToTensor.
dataset = EEGDataset(root_path, metadata_train, None)
print('before:')
sample = dataset[0]
print(sample)

print('\n'.join(['', '-' * 100, '']))

dataset = EEGDataset(root_path, metadata_train, EEGToTensor())
print('after:')
sample = dataset[0]
print(sample)

before:
{'signal': array([[ 18.,  13.,   9., ...,  20.,  11.,   6.],
       [  9.,  15.,  11., ..., -20., -22., -21.],
       [ 14.,  15.,  15., ...,  -9.,  -8.,  -4.],
       ...,
       [  5.,   9.,  12., ...,  11.,  14.,  14.],
       [-21.,  -2.,  -9., ..., -18., -19., -21.],
       [  0.,   1.,   2., ...,   0.,   0.,  -1.]], dtype=float32), 'age': 59, 'class_label': 2, 'metadata': {'serial': '00130', 'edfname': '00430821_031116', 'birth': '1957-08-25', 'record': '2016-11-03T09:47:13', 'age': 59, 'dx1': 'eoad', 'label': ['dementia', 'ad', 'eoad'], 'events': [[0, 'Start Recording'], [0, 'New Montage - Montage 002'], [1214, 'Eyes Open'], [3320, 'Eyes Closed'], [14869, 'Eyes Open'], [15919, 'Eyes Closed'], [19910, 'Eyes Open'], [21128, 'Eyes Closed'], [27470, 'HV - Dur: 157.3 sec. - On'], [40070, 'Eyes Open'], [41330, 'Eyes Closed'], [57336, 'Eyes Open'], [58386, 'Eyes Closed'], [58932, 'HV - Off'], [59141, 'HV fair'], [60779, 'Eyes Open'], [61788, 'Eyes Closed'], [64979, 'Eyes Open']

#### Short time Fourier transform (STFT or spectrogram)

In [17]:
# Run the STFT transform once per complex_mode and print, for the first sample,
# the resulting shape / dtype / type and the slice at index 10 of the last axis.
for mode_index, complex_mode in enumerate(('as_real', 'power', 'remove')):
    if mode_index > 0:
        print('', '-' * 100, '', sep='\n')

    composed = transforms.Compose([EEGToTensor(),
                                   EEGSpectrogram(n_fft=200, complex_mode=complex_mode)])
    dataset = EEGDataset(root_path, metadata_train, composed)
    print(dataset[0]['signal'].shape, dataset[0]['signal'].dtype, type(dataset[0]['signal']))
    print(dataset[0]['signal'][:, :, 10])

torch.Size([42, 101, 3477]) torch.float32 <class 'torch.Tensor'>
tensor([[-2.8540e+03, -2.5239e+02, -1.3384e+02,  ...,  4.9411e+00,
          9.5575e+00,  1.2000e+01],
        [-4.7620e+03, -1.0710e+02, -2.0297e+02,  ...,  5.1596e+00,
          2.4075e+00,  8.0000e+00],
        [-2.0290e+03, -5.7747e+01, -9.6218e+01,  ..., -8.3456e+00,
         -9.1753e+00, -3.0000e+00],
        ...,
        [ 0.0000e+00,  3.8263e+02, -6.8193e+01,  ..., -2.0900e-01,
          5.2983e+00,  0.0000e+00],
        [ 0.0000e+00,  4.1983e+03,  6.2755e+03,  ..., -2.8503e+00,
         -7.9858e-01,  0.0000e+00],
        [ 0.0000e+00, -2.1638e+00, -1.5501e+01,  ...,  1.4234e+00,
          1.2831e+00,  0.0000e+00]])

----------------------------------------------------------------------------------------------------

torch.Size([21, 101, 3477]) torch.float32 <class 'torch.Tensor'>
tensor([[2.8540e+03, 6.9932e+02, 1.3459e+02,  ..., 4.9506e+00, 9.5628e+00,
         1.2000e+01],
        [4.7620e+03, 3.2330e+02, 2.633

#### Compose several transforms at once

In [18]:
# Transform pipeline, applied in the listed order.
pipeline = [
    EEGNormalizeAge(mean=age_mean, std=age_std),
    EEGDropPhoticChannel(),
    EEGRandomCrop(crop_length=200*60), # 1 minute
    EEGNormalizePerSignal(),
    EEGToTensor(),
]
composed = transforms.Compose(pipeline)

# The same pipeline is used for all three splits.
train_dataset, val_dataset, test_dataset = (
    EEGDataset(root_path, subset, composed)
    for subset in (metadata_train, metadata_val, metadata_test)
)

# Print the signal shape and the full first sample of each split.
for position, split_dataset in enumerate((train_dataset, val_dataset, test_dataset)):
    if position > 0:
        print()
        print('-' * 100)
        print()
    print(split_dataset[0]['signal'].shape)
    print(split_dataset[0])

torch.Size([20, 12000])
{'signal': tensor([[-4.2218e-02, -1.3023e-02, -7.1412e-02,  ..., -6.8450e-01,
         -7.1369e-01, -5.6772e-01],
        [-4.2921e-01, -4.8261e-01, -6.4281e-01,  ..., -8.0301e-01,
         -8.0301e-01, -5.8941e-01],
        [ 7.5856e-03, -6.9897e-01, -1.4055e+00,  ...,  7.5856e-03,
          4.7862e-01,  3.6086e-01],
        ...,
        [ 2.3564e+00,  1.6056e+00,  7.4754e-01,  ..., -1.1050e-01,
          8.5480e-01,  8.5480e-01],
        [ 3.3584e+00,  2.9515e+00,  2.3410e+00,  ..., -9.1455e-01,
          1.0852e-03, -1.0065e-01],
        [ 1.5178e+00,  1.4114e+00,  1.3213e+00,  ...,  7.0717e-01,
          1.1248e+00,  1.4851e+00]]), 'age': tensor(-1.1431), 'class_label': tensor(2), 'metadata': {'serial': '00130', 'edfname': '00430821_031116', 'birth': '1957-08-25', 'record': '2016-11-03T09:47:13', 'age': 59, 'dx1': 'eoad', 'label': ['dementia', 'ad', 'eoad'], 'events': [[0, 'Start Recording'], [0, 'New Montage - Montage 002'], [1214, 'Eyes Open'], [3320, 'Eye

In [19]:
# Same pipeline as before, but with two 20-second crops per recording.
pipeline = [
    EEGNormalizeAge(mean=age_mean, std=age_std),
    EEGDropPhoticChannel(),
    EEGRandomCrop(crop_length=200*20, multiple=2), # 20 seconds
    EEGNormalizePerSignal(),
    EEGToTensor(),
]
composed = transforms.Compose(pipeline)

train_dataset = EEGDataset(root_path, metadata_train, composed)
val_dataset = EEGDataset(root_path, metadata_val, composed)
test_dataset = EEGDataset(root_path, metadata_test, composed)

divider = '\n' + '-' * 100 + '\n'

# For each of the two crops: shape, then mean and std along the last axis.
# Every train_dataset[0] access fetches the item anew.
for crop_index in (0, 1):
    print(train_dataset[0]['signal'][crop_index].shape)
    print(torch.mean(train_dataset[0]['signal'][crop_index], axis=-1))
    print(torch.std(train_dataset[0]['signal'][crop_index], axis=-1))
    print(divider)

print(train_dataset[0]['signal'][1].shape)
print(train_dataset[0])

# First sample of the validation and test datasets.
for other_dataset in (val_dataset, test_dataset):
    print(divider)
    print(other_dataset[0])

torch.Size([20, 4000])
tensor([-9.5367e-10,  2.0027e-08,  5.7220e-09, -3.8147e-09,  4.0531e-09,
         5.9605e-09,  2.8610e-09,  6.7949e-09, -8.3447e-09,  6.4373e-09,
         1.2875e-08,  7.3910e-09,  1.2994e-08, -3.3379e-09, -1.4067e-08,
         1.4156e-08, -1.2875e-08, -5.9605e-09,  9.5367e-09,  4.5300e-09])
tensor([1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001,
        1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001,
        1.0001, 1.0001])

----------------------------------------------------------------------------------------------------

torch.Size([20, 4000])
tensor([-6.5565e-09,  4.4107e-09, -1.3113e-09,  1.1921e-08,  7.1526e-09,
        -5.3644e-09, -4.5896e-09, -8.5831e-09, -2.1458e-09,  1.4305e-08,
        -5.2452e-09, -1.2636e-08,  6.6757e-09, -7.7486e-10,  6.6757e-09,
        -1.4782e-08, -5.6028e-09,  6.0797e-09, -6.1989e-09, -2.9802e-09])
tensor([1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001,


In [20]:
# Time-domain preprocessing followed by a power spectrogram (hop of half the FFT size).
pipeline = [
    EEGNormalizeAge(mean=age_mean, std=age_std),
    EEGDropPhoticChannel(),
    EEGRandomCrop(crop_length=200*60), # 1 minute
    EEGNormalizePerSignal(),
    EEGToTensor(),
    EEGSpectrogram(n_fft=200, complex_mode='power', hop_length=200 // 2),
]
composed = transforms.Compose(pipeline)
dataset = EEGDataset(root_path, metadata_train, composed)

# Shape / dtype / type of the first sample's signal, then the slice at index 10 of the last axis.
print(dataset[0]['signal'].shape,
      dataset[0]['signal'].dtype,
      type(dataset[0]['signal']))
print(dataset[0]['signal'][:, :, 10])

torch.Size([20, 101, 121]) torch.float32 <class 'torch.Tensor'>
tensor([[9.0801e+01, 1.1099e+01, 4.6555e+00,  ..., 1.4541e-01, 7.3999e-02,
         1.2837e-01],
        [8.6186e+01, 7.7931e+00, 1.3505e+01,  ..., 1.0877e-01, 2.7455e-01,
         0.0000e+00],
        [4.0296e+01, 2.2620e+01, 1.8139e+01,  ..., 3.7149e-01, 2.6048e-01,
         2.9259e-01],
        ...,
        [2.4472e+01, 1.7197e+01, 1.9991e+01,  ..., 1.2253e+00, 8.6219e-01,
         7.5294e-01],
        [5.0536e+01, 1.8511e+01, 1.5670e+01,  ..., 1.1077e+00, 1.4598e+00,
         3.7300e-01],
        [6.5496e+00, 3.7919e+01, 4.9597e+01,  ..., 1.7998e-01, 1.7055e-01,
         1.6366e-01]])


#### Data loader test

In [21]:
%%time

print('Current PyTorch device:', device)
if device.type == 'cuda':
    num_workers = 0 # A number other than 0 causes an error
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False
    
composed = transforms.Compose([
    EEGRandomCrop(crop_length=200*20, multiple=2), # 20 seconds
    EEGNormalizeAge(mean=age_mean, std=age_std),
    EEGDropPhoticChannel(),
    EEGNormalizePerSignal(),
    EEGToTensor()
])

train_dataset = EEGDataset(root_path, metadata_train, composed)
val_dataset = EEGDataset(root_path, metadata_val, composed)
test_dataset = EEGDataset(root_path, metadata_test, composed)

train_loader = DataLoader(train_dataset, 
                          batch_size=32, # Random crop will inplate the minibatch size
                          shuffle=True, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for i_batch, sample_batched in enumerate(train_loader):
    sample_batched['signal'].to(device)
    sample_batched['age'].to(device)
    sample_batched['class_label'].to(device)
    
    print(i_batch, 
          sample_batched['signal'].shape, 
          sample_batched['age'].shape, 
          sample_batched['class_label'].shape, 
          len(sample_batched['metadata']))
    
    if i_batch > 3:
        break

Current PyTorch device: cuda
0 torch.Size([64, 20, 4000]) torch.Size([64]) torch.Size([64]) 64
1 torch.Size([64, 20, 4000]) torch.Size([64]) torch.Size([64]) 64
2 torch.Size([64, 20, 4000]) torch.Size([64]) torch.Size([64]) 64
3 torch.Size([64, 20, 4000]) torch.Size([64]) torch.Size([64]) 64
4 torch.Size([64, 20, 4000]) torch.Size([64]) torch.Size([64]) 64
CPU times: total: 11.7 s
Wall time: 996 ms


#### Train, validation, test dataloaders

In [22]:
# Options shared by the three loaders.
common_loader_options = dict(batch_size=32,
                             num_workers=num_workers,
                             pin_memory=pin_memory,
                             collate_fn=eeg_collate_fn)

# Training batches are shuffled and the last incomplete batch is dropped.
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **common_loader_options)

# Validation and test keep the dataset order and every sample.
val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False, **common_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, drop_last=False, **common_loader_options)

In [23]:
# Inspect only the first training batch.
for batch_i, sample_batched in enumerate(train_loader):
    # Move the batch fields to the compute device.
    x, age, target = (sample_batched[field].to(device)
                      for field in ('signal', 'age', 'class_label'))

    for moved in (x, age, target):
        print(moved)

    break

tensor([[[ 1.6559e+00,  1.6458e+00,  1.6458e+00,  ..., -1.5815e-01,
          -1.8855e-01, -2.7976e-01],
         [ 1.2338e+00,  1.3052e+00,  1.2338e+00,  ...,  2.0896e-02,
           2.0896e-02,  5.6571e-02],
         [-9.5145e-01, -1.1137e+00, -1.1948e+00,  ...,  4.2773e-01,
           4.2773e-01,  3.4660e-01],
         ...,
         [-2.6364e+00, -2.7283e+00, -3.0037e+00,  ...,  3.0184e-01,
           2.1002e-01,  2.1002e-01],
         [-9.4228e-01, -8.3814e-01, -7.8606e-01,  ...,  4.6370e-01,
           4.6370e-01,  4.6370e-01],
         [ 7.6880e-01,  9.6164e-01,  8.6015e-01,  ...,  7.5865e-01,
           8.1955e-01,  7.5865e-01]],

        [[ 5.5445e-01,  5.2110e-01,  6.5449e-01,  ..., -5.7934e-01,
          -5.1265e-01, -4.1261e-01],
         [-6.3392e-01, -6.3392e-01, -5.0142e-01,  ..., -1.2964e+00,
          -8.9890e-01, -6.3392e-01],
         [-6.8440e-01, -4.7576e-01, -2.6711e-01,  ..., -1.1017e+00,
          -1.3103e+00, -1.5190e+00],
         ...,
         [-4.3910e-01, -3