# EEG Data File Format Test

Using the metadata generated in the `01_Data_Curation1` notebook, this notebook benchmarks several file formats for writing and reading the EEG dataset and decides which one is best.  
The following signal formats are tested: NumPy binary (`.npy`), NumPy `memmap`, `Feather`, `Parquet`, `Jay`, and `HDF5`. Event annotations are tested as `JSON` and `Feather`.

-----

## Load Packages and Configure the Notebook Environment

In [None]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
# IPython magics: mode 2 re-imports all modules before executing code, so edits
# to the local `datasets` package take effect without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
# Load some packages
import os
import re
import copy
import glob
from openpyxl import load_workbook, Workbook, styles
import json

import pyedflib
import datetime
from dateutil.relativedelta import relativedelta

import pprint
import warnings
import ctypes
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import pandas as pd
import pyarrow.feather as feather
import datatable as dt
import h5py

import numpy as np
import random
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# custom package
from datasets.cau_eeg_dataset import *
from datasets.utils import *
from datasets.pipeline import *

In [None]:
# Report the PyTorch build and choose the compute device (GPU when present)
print('PyTorch version:', torch.__version__)
cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if cuda_ok else 'cpu')
print('cuda is available.' if cuda_ok else 'cuda is unavailable.')

-----

## Load and Check Data

In [None]:
# Locations of the raw inputs (EDF + XLSX) and of the curated output folder
original_path = 'local/dataset/01_Original_Data_220307'
curated_path = 'local/dataset/02_Curated_Data_temp'

# create the output folder up front; a no-op when it already exists
os.makedirs(curated_path, exist_ok=True)

In [None]:
# Read the master sheet: one recording per row from row 2 on (row 1 is the header).
# Columns: 1 = EDF file name, 2 = diagnosis string, 3 = birth, 4 = anomaly mark.
meta_file = os.path.join(original_path, r'new_DB_list.xlsx')
ws = load_workbook(meta_file, data_only=True)['metadata']

metadata = []
row = 2
while True:
    edfname, dx1, birth, anomaly_mark = (ws.cell(row=row, column=col).value for col in range(1, 5))
    row += 1

    if edfname is None:  # the first empty row terminates the sheet
        break
    if anomaly_mark is not None:  # any content in column 4 flags the row as anomalous
        continue

    metadata.append({'edfname': edfname, 'dx1': dx1, 'birth': birth, 'anomaly': False})

print('Size:', len(metadata))
print()
print('Loaded metadata (first three displayed):')
print(json.dumps(metadata[:3], indent=4))

In [None]:
# Inspect the first recording; its channel headers become the reference layout
# that every other EDF file is compared against during curation.
m = metadata[0]
edf_file = os.path.join(original_path, m['edfname'] + '.edf')
signals, signal_headers, edf_header = pyedflib.highlevel.read_edf(edf_file)
refer_headers = signal_headers

# number of distinct sample values vs. the overall value span (checks the signal is discrete)
n_unique = np.unique(signals.reshape(-1)).shape
value_span = signals.max() - signals.min()
print(n_unique, value_span)

for obj in (m, edf_header, signal_headers):
    print()
    pprint.pp(obj)

---

## Arrange and Save the Data

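1. Skip EDF files whose signal headers differ from the reference (first) recording, or whose age cannot be computed
2. Extract the EDF signals and save them, after reformatting, in every tested file format
3. Save all metadata as JSON and XLSX, where JSON is for further use and XLSX is for human inspection (in this notebook only `metadata_public` is written, as JSON)
    - `metadata_debug`: Full metadata, including personal information, for debugging
    - `metadata_public`: Metadata without personal information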

In [None]:
start_latency_length = 10 * 200  # 10 seconds at 200 Hz


def _confirm_delete_all(path):
    """Ask whether every curated file under *path* may be deleted; return True for "yes".

    Windows shows a Yes/No message box. Other platforms have no ``ctypes.windll``,
    so they fall back to a console prompt instead of crashing.
    """
    text = f'Delete ALL files in {path}?'
    if hasattr(ctypes, 'windll'):
        # MessageBoxW(hWnd, text, caption, type): type 4 = MB_YESNO, result 6 = IDYES.
        # (MessageBoxExW, used previously, needs a fifth language-id argument.)
        return ctypes.windll.user32.MessageBoxW(0, text, 'Question', 4) == 6
    return input(text + ' [y/N] ').strip().lower() in ('y', 'yes')


def _read_events(event_file, record_date):
    """Read one recording's annotation sheet as a list of ``(sample_offset, event)`` pairs.

    Column 3 holds the time of day and column 4 the event text, from row 2 up to
    the first empty time cell. Offsets are in samples at 200 Hz, relative to the
    FIRST listed event.
    """
    wb = load_workbook(event_file, data_only=True)
    ws = wb[wb.sheetnames[0]]
    date_prefix = record_date.strftime('%Y%m%d')

    events = []
    first_time = None
    row = 2
    while True:
        t = ws.cell(row=row, column=3).value
        e = ws.cell(row=row, column=4).value
        if t is None:
            break

        # NOTE(review): assumes t is a string such as ' HH:MM:SS.ffffff' (leading space
        # included) and that the recording does not cross midnight -- confirm.
        t = datetime.datetime.strptime(date_prefix + t, '%Y%m%d %H:%M:%S.%f')
        if first_time is None:
            first_time = t

        events.append((int(np.floor((t - first_time).total_seconds() * 200)), e))
        row += 1
    return events


if _confirm_delete_all(curated_path):
    for f in glob.glob(os.path.join(curated_path, '*/*')):
        os.remove(f)
    for f in glob.glob(os.path.join(curated_path, '*.*')):
        os.remove(f)

signal_dir = os.path.join(curated_path, 'signal')
event_dir = os.path.join(curated_path, 'event')
os.makedirs(signal_dir, exist_ok=True)
os.makedirs(event_dir, exist_ok=True)

metadata_debug = []
metadata_public = []

# catch_warnings() restores the previous warning filters even if a file fails midway.
# The `with h5py.File(...)` block guarantees signal.h5 is flushed and closed, so the
# HDF5 dataset class below can reopen it (the handle used to be left open for writing).
with warnings.catch_warnings():
    warnings.filterwarnings(action='ignore')
    with h5py.File(os.path.join(signal_dir, 'signal.h5'), 'w',
                   rdcc_nbytes=(1024**2)*15, rdcc_nslots=int(1e6)) as fh5:  # slot count must be an int
        for m in tqdm(metadata):
            # EDF file check: every recording must share the reference channel layout
            edf_file = os.path.join(original_path, m['edfname'] + '.edf')
            signals, signal_headers, edf_header = pyedflib.highlevel.read_edf(edf_file)
            if refer_headers != signal_headers:
                print('- Signal header differs from the majority:', m['edfname'])
                continue

            # age at the time of recording
            age = calculate_age(birth_to_datetime(m['birth']), edf_header['startdate'])
            if age is None:
                print('- The age information is unknown:', m['edfname'])
                continue

            # EDF recording events
            # NOTE(review): event offsets are not shifted by start_latency_length even though
            # the signal below drops its first start_latency_length samples -- confirm intended.
            event = _read_events(os.path.join(original_path, m['edfname'] + '.xlsx'),
                                 edf_header['startdate'])

            # metadata_debug: full record, including personal information
            m2 = {
                'serial': f'{len(metadata_debug) + 1:05}',
                'edfname': m['edfname'],
                'birth': birth_to_datetime(m['birth']),
                'record': edf_header['startdate'],
                'age': age,
                'dx1': m['dx1'],
                'label': MultiLabel.load_from_string(m['dx1']),
                'event': event,
            }
            metadata_debug.append(m2)
            serial = m2['serial']

            # metadata_public: no personal information
            metadata_public.append({'serial': serial, 'age': age, 'label': m2['label']})

            # EDF signal: rows are channels (same order as signal_headers)
            signals = trim_tailing_zeros(signals)        # trim garbage zeros
            signals = signals[:, start_latency_length:]  # throw away some signal right after starting recording
            signals = signals.astype('int32')
            df = pd.DataFrame(data=signals.T, columns=[s_h['label'] for s_h in signal_headers], dtype=np.int32)

            # NumPy binary (.npy; np.save appends the extension)
            np.save(os.path.join(signal_dir, serial), signals)

            # NumPy memmap: raw int32 bytes, so the reader must know the channel count
            fp = np.memmap(os.path.join(signal_dir, serial + '.dat'),
                           dtype='int32', mode='w+', shape=signals.shape)
            fp[:] = signals[:]
            fp.flush()
            del fp  # drop the map so the .dat file is released

            # feather / parquet / jay store channels as columns
            df.to_feather(os.path.join(signal_dir, serial + '.feather'))
            df.to_parquet(os.path.join(signal_dir, serial + '.parquet'))
            dt.Frame(df).to_jay(os.path.join(signal_dir, serial + '.jay'))

            # hdf5: one dataset per recording inside the shared signal.h5
            fh5.create_dataset(serial, shape=signals.shape, dtype='int32', chunks=True, data=signals)

            # event: JSON and feather copies of the same list
            with open(os.path.join(event_dir, serial + '.json'), 'w') as json_file:
                json.dump(event, json_file, indent=4, default=serialize_json)
            event_df = pd.DataFrame(data=event, columns=['timing', 'event'])
            event_df.to_feather(os.path.join(event_dir, serial + '.feather'))

print('Done.')
print()
print(f'Among {len(metadata)}, {len(metadata_public)} data were saved.')

In [None]:
# Persist the de-identified metadata (serial, age, label) as JSON
path = os.path.join(curated_path, 'metadata_public.json')
with open(path, mode='w') as out_file:
    json.dump(metadata_public, out_file, default=serialize_json, indent=4)

In [None]:
# Reload the public metadata from disk; the cells below work on this copy
meta_path = os.path.join(curated_path, 'metadata_public.json')
with open(meta_path) as meta_fp:
    metadata = json.load(meta_fp)

pprint.pprint(metadata[0])

-----

## Data Filtering by Diagnosis

#### Non-Vascular Dementia, Non-Vascular MCI, Normal

In [None]:
# Ordered class definitions: the list index is the class label. A recording belongs
# to a class when it carries any 'include' tag and none of the 'exclude' tags.
diagnosis_filter = [
    dict(type='Normal', include=['normal'], exclude=[]),
    dict(type='Non-vascular MCI', include=['mci'], exclude=['mci_vascular']),
    dict(type='Non-vascular dementia', include=['dementia'], exclude=['vd']),
]

def generate_class_label(label, filters=None):
    """Map a recording's diagnosis tags to a ``(class index, class name)`` pair.

    Filters are checked in order and the first match wins. A filter matches when
    ``label`` shares at least one tag with its ``include`` list and none with its
    ``exclude`` list.

    Args:
        label (iterable of str): Diagnosis tags of one recording.
        filters (list of dict, optional): Filters with ``type``, ``include`` and
            ``exclude`` keys. Defaults to the module-level ``diagnosis_filter``.

    Returns:
        tuple: ``(index, type)`` of the first matching filter, or
        ``(-1, 'The others')`` when no filter matches.
    """
    if filters is None:
        filters = diagnosis_filter
    tags = set(label)  # built once (previously rebuilt twice per filter)
    for c, f in enumerate(filters):
        included = not tags.isdisjoint(f['include'])
        excluded = not tags.isdisjoint(f['exclude'])
        if included and not excluded:
            return (c, f['type'])
    return (-1, 'The others')

# class index -> diagnosis name, in the same order as diagnosis_filter
class_label_to_type = [entry['type'] for entry in diagnosis_filter]
print('class_label_to_type:', class_label_to_type)

In [None]:
# One bucket per diagnosis class; recordings that match no class are dropped.
splitted_metadata = [[] for _ in diagnosis_filter]

for m in metadata:
    c, n = generate_class_label(m['label'])
    if c < 0:
        continue
    m['class_type'] = n
    m['class_label'] = c
    splitted_metadata[c].append(m)

for i, split in enumerate(splitted_metadata):
    if not split:
        print(f'(Warning) Split group {i} has no data.')
        continue
    print(f'- There are {len(split):} data belonging to {split[0]["class_type"]}')

-----

## Configure the Train, Validation, and Test Splits

#### Split each diagnosis group of the filtered dataset and shuffle the resulting splits

In [None]:
# Train : Val : Test = 8 : 1 : 1, applied per class (stratified)
ratio1 = 0.8
ratio2 = 0.1

metadata_train, metadata_val, metadata_test = [], [], []

for split in splitted_metadata:
    random.shuffle(split)
    n_total = len(split)
    n1 = round(n_total * ratio1)
    n2 = n1 + round(n_total * ratio2)
    for target, chunk in ((metadata_train, split[:n1]),
                          (metadata_val, split[n1:n2]),
                          (metadata_test, split[n2:])):
        target.extend(chunk)

# mix the classes within each subset
for subset in (metadata_train, metadata_val, metadata_test):
    random.shuffle(subset)

print('Train data size\t\t:', len(metadata_train))
print('Validation data size\t:', len(metadata_val))
print('Test data size\t\t:', len(metadata_test))

print('\n', '--- Recheck ---', '\n')


def _class_histogram(subset):
    """Count how many records of *subset* fall in each class label."""
    counts = np.zeros((len(class_label_to_type)), dtype=np.int32)
    for record in subset:
        counts[record['class_label']] += 1
    return counts


train_class_nums = _class_histogram(metadata_train)
val_class_nums = _class_histogram(metadata_val)
test_class_nums = _class_histogram(metadata_test)

print('Train data label distribution\t:', train_class_nums, train_class_nums.sum())
print('Val data label distribution\t:', val_class_nums, val_class_nums.sum())
print('Test data label distribution\t:', test_class_nums, test_class_nums.sum())

-----

## Check Signal Loading Time by Data Format

In [None]:
class EegDataset_np(Dataset):
    """PyTorch dataset serving EEG signals stored as NumPy ``.npy`` files.

    The signal of each record is loaded from ``<root_dir>/signal/<serial>.npy``.

    Args:
        root_dir (str): Curated data directory containing the ``signal`` folder.
        metadata (list of dict): One dictionary per recording with at least
            ``serial``, ``age`` and ``class_label``.
        transform (callable, optional): Applied to every sample dictionary when given.
    """

    def __init__(self, root_dir, metadata, transform=None):
        self.root_dir, self.metadata, self.transform = root_dir, metadata, transform

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        # samplers may hand over tensor indices; convert them to plain Python values
        index = idx.tolist() if torch.is_tensor(idx) else idx
        record = self.metadata[index]
        path = os.path.join(self.root_dir, 'signal', record['serial'] + '.npy')
        sample = {
            'signal': np.load(path),
            'age': record['age'],
            'class_label': record['class_label'],
            'event': [],
            'metadata': record,
        }
        return self.transform(sample) if self.transform else sample

In [None]:
class EegDataset_np_memmap(Dataset):
    """PyTorch dataset serving EEG signals stored as raw int32 NumPy memmap files.

    The signal of each record is mapped read-only from ``<root_dir>/signal/<serial>.dat``.
    Raw ``.dat`` files carry no shape information, so the channel count must be
    supplied. It defaults to 21, the value previously hard-coded here.

    Args:
        root_dir (str): Curated data directory containing the ``signal`` folder.
        metadata (list of dict): One dictionary per recording with at least
            ``serial``, ``age`` and ``class_label``.
        transform (callable, optional): Applied to every sample dictionary when given.
        num_channels (int, optional): Number of rows (channels) each file was written with.
    """

    def __init__(self, root_dir, metadata, transform=None, num_channels=21):
        self.root_dir = root_dir
        self.metadata = metadata
        self.transform = transform
        self.num_channels = num_channels

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        m = self.metadata[idx]
        fname = os.path.join(self.root_dir, 'signal', m['serial'] + '.dat')
        # the flat int32 buffer is reshaped to (channels, samples)
        signal = np.memmap(fname, dtype='int32', mode='r').reshape(self.num_channels, -1)
        sample = {'signal': signal,
                  'age': m['age'],
                  'class_label': m['class_label'],
                  'event': [],
                  'metadata': m}
        if self.transform:
            sample = self.transform(sample)
        return sample

In [None]:
class EegDataset_feather(Dataset):
    """PyTorch dataset serving EEG signals stored as Feather tables.

    The signal of each record is read from ``<root_dir>/signal/<serial>.feather``.

    Args:
        root_dir (str): Curated data directory containing the ``signal`` folder.
        metadata (list of dict): One dictionary per recording with at least
            ``serial``, ``age`` and ``class_label``.
        transform (callable, optional): Applied to every sample dictionary when given.
    """

    def __init__(self, root_dir, metadata, transform=None):
        self.root_dir, self.metadata, self.transform = root_dir, metadata, transform

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        # samplers may hand over tensor indices; convert them to plain Python values
        index = idx.tolist() if torch.is_tensor(idx) else idx
        record = self.metadata[index]
        path = os.path.join(self.root_dir, 'signal', record['serial'] + '.feather')
        # channels are stored as columns, so transpose back to one row per channel
        table = feather.read_feather(path)
        sample = {
            'signal': table.values.T,
            'age': record['age'],
            'class_label': record['class_label'],
            'event': [],
            'metadata': record,
        }
        return self.transform(sample) if self.transform else sample

In [None]:
class EegDataset_parquet(Dataset):
    """PyTorch dataset serving EEG signals stored as Parquet tables.

    The signal of each record is read from ``<root_dir>/signal/<serial>.parquet``.

    Args:
        root_dir (str): Curated data directory containing the ``signal`` folder.
        metadata (list of dict): One dictionary per recording with at least
            ``serial``, ``age`` and ``class_label``.
        transform (callable, optional): Applied to every sample dictionary when given.
    """

    def __init__(self, root_dir, metadata, transform=None):
        self.root_dir, self.metadata, self.transform = root_dir, metadata, transform

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        # samplers may hand over tensor indices; convert them to plain Python values
        index = idx.tolist() if torch.is_tensor(idx) else idx
        record = self.metadata[index]
        path = os.path.join(self.root_dir, 'signal', record['serial'] + '.parquet')
        # channels are stored as columns, so transpose back to one row per channel
        table = pd.read_parquet(path)
        sample = {
            'signal': table.values.T,
            'age': record['age'],
            'class_label': record['class_label'],
            'event': [],
            'metadata': record,
        }
        return self.transform(sample) if self.transform else sample

In [None]:
class EegDataset_jay(Dataset):
    """PyTorch dataset serving EEG signals stored as datatable ``.jay`` files.

    The signal of each record is read from ``<root_dir>/signal/<serial>.jay``.

    Args:
        root_dir (str): Curated data directory containing the ``signal`` folder.
        metadata (list of dict): One dictionary per recording with at least
            ``serial``, ``age`` and ``class_label``.
        transform (callable, optional): Applied to every sample dictionary when given.
    """

    def __init__(self, root_dir, metadata, transform=None):
        self.root_dir, self.metadata, self.transform = root_dir, metadata, transform

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        # samplers may hand over tensor indices; convert them to plain Python values
        index = idx.tolist() if torch.is_tensor(idx) else idx
        record = self.metadata[index]
        path = os.path.join(self.root_dir, 'signal', record['serial'] + '.jay')
        # channels are stored as columns, so transpose back to one row per channel
        frame = dt.fread(path)
        sample = {
            'signal': frame.to_numpy().T,
            'age': record['age'],
            'class_label': record['class_label'],
            'event': [],
            'metadata': record,
        }
        return self.transform(sample) if self.transform else sample

In [None]:
class EegDataset_h5(Dataset):
    """PyTorch dataset serving EEG signals from one shared HDF5 file.

    Each recording is an HDF5 dataset named by its ``serial`` inside
    ``<root_dir>/signal/signal.h5``.

    The file is opened lazily on first access rather than in ``__init__``, and the
    handle is dropped when the dataset is pickled. h5py file objects cannot be
    pickled and should not be shared across forked processes, so with this design
    every DataLoader worker opens its own handle.

    Args:
        root_dir (str): Curated data directory containing the ``signal`` folder.
        metadata (list of dict): One dictionary per recording with at least
            ``serial``, ``age`` and ``class_label``.
        transform (callable, optional): Applied to every sample dictionary when given.
    """

    def __init__(self, root_dir, metadata, transform=None):
        self.root_dir = root_dir
        self.metadata = metadata
        self.transform = transform
        self.f_handle = None  # opened on first __getitem__ (per process)

    def _file(self):
        """Return the open HDF5 handle, opening it read-only on first use."""
        if self.f_handle is None:
            self.f_handle = h5py.File(os.path.join(self.root_dir, 'signal', 'signal.h5'), 'r')
        return self.f_handle

    def __getstate__(self):
        # never ship an open h5py handle to another process
        state = self.__dict__.copy()
        state['f_handle'] = None
        return state

    def close(self):
        """Close the HDF5 handle if it is open; safe to call repeatedly."""
        if self.f_handle is not None:
            self.f_handle.close()
            self.f_handle = None

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        m = self.metadata[idx]
        signal = self._file()[m['serial']][:]  # [:] reads the whole dataset into memory
        sample = {'signal': signal,
                  'age': m['age'],
                  'class_label': m['class_label'],
                  'event': [],
                  'metadata': m}
        if self.transform:
            sample = self.transform(sample)
        return sample

In [None]:
# Use the GPU when available (uncomment the next line to force the CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')

print('Current PyTorch device:', device)

# DataLoader settings: workers stay at 0 on every device (a non-zero value was
# reported to cause an error on CUDA); pinned host memory only helps GPU copies.
num_workers = 0
pin_memory = device.type == 'cuda'

# Training-style pipeline shared by all signal-format benchmarks
composed = transforms.Compose([
    EegRandomCrop(crop_length=200*20, multiple=2),  # 20 seconds
    EegDropPhoticChannel(),
    EegToTensor(),
])

### Test NumPy `.npy`

In [None]:
%%time

# Benchmark: one full pass over the training split, loading `.npy` files.
# shuffle=False and the same metadata_train are used in every format benchmark;
# the `.to(device)` results are discarded, only loading + transfer time matters.
train_dataset = EegDataset_np(curated_path, metadata_train, composed)
train_loader = DataLoader(train_dataset, 
                          batch_size=32, # Random crop will inflate the minibatch size
                          shuffle=False, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for i_batch, sample_batched in enumerate(train_loader):
    sample_batched['signal'].to(device)
    sample_batched['age'].to(device)
    sample_batched['class_label'].to(device)

In [None]:
# Sanity check: show one transformed sample and its dtype. Each indexing call
# re-runs the transform pipeline (including the random crop), so the two lines
# may look at different crops.
pprint.pprint(train_dataset[0]['signal'])
pprint.pprint(train_dataset[0]['signal'][0].dtype)

### Test NumPy Memmap

In [None]:
%%time

# Benchmark: one full pass over the training split, mapping raw `.dat` memmap files.
# Loader settings match the other format benchmarks so the timings are comparable.
train_dataset = EegDataset_np_memmap(curated_path, metadata_train, composed)
train_loader = DataLoader(train_dataset, 
                          batch_size=32, # Random crop will inflate the minibatch size
                          shuffle=False, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for i_batch, sample_batched in enumerate(train_loader):
    sample_batched['signal'].to(device)
    sample_batched['age'].to(device)
    sample_batched['class_label'].to(device)

In [None]:
# Sanity check: show one transformed sample and its dtype. Each indexing call
# re-runs the transform pipeline (including the random crop), so the two lines
# may look at different crops.
pprint.pprint(train_dataset[0]['signal'])
pprint.pprint(train_dataset[0]['signal'][0].dtype)

### Test Feather

In [None]:
%%time

# Benchmark: one full pass over the training split, reading Feather tables.
# Loader settings match the other format benchmarks so the timings are comparable.
train_dataset = EegDataset_feather(curated_path, metadata_train, composed)
train_loader = DataLoader(train_dataset, 
                          batch_size=32, # Random crop will inflate the minibatch size
                          shuffle=False, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for i_batch, sample_batched in enumerate(train_loader):
    sample_batched['signal'].to(device)
    sample_batched['age'].to(device)
    sample_batched['class_label'].to(device)

In [None]:
# Sanity check: show one transformed sample and its dtype. Each indexing call
# re-runs the transform pipeline (including the random crop), so the two lines
# may look at different crops.
pprint.pprint(train_dataset[0]['signal'])
pprint.pprint(train_dataset[0]['signal'][0].dtype)

### Test Parquet

In [None]:
%%time

# Benchmark: one full pass over the training split, reading Parquet tables.
# Loader settings match the other format benchmarks so the timings are comparable.
train_dataset = EegDataset_parquet(curated_path, metadata_train, composed)
train_loader = DataLoader(train_dataset, 
                          batch_size=32, # Random crop will inflate the minibatch size
                          shuffle=False, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for i_batch, sample_batched in enumerate(train_loader):
    sample_batched['signal'].to(device)
    sample_batched['age'].to(device)
    sample_batched['class_label'].to(device)

In [None]:
# Sanity check: show one transformed sample and its dtype. Each indexing call
# re-runs the transform pipeline (including the random crop), so the two lines
# may look at different crops.
pprint.pprint(train_dataset[0]['signal'])
pprint.pprint(train_dataset[0]['signal'][0].dtype)

### Test Jay

In [None]:
%%time

# Benchmark: one full pass over the training split, reading datatable `.jay` files.
# Loader settings match the other format benchmarks so the timings are comparable.
train_dataset = EegDataset_jay(curated_path, metadata_train, composed)
train_loader = DataLoader(train_dataset, 
                          batch_size=32, # Random crop will inflate the minibatch size
                          shuffle=False, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for i_batch, sample_batched in enumerate(train_loader):
    sample_batched['signal'].to(device)
    sample_batched['age'].to(device)
    sample_batched['class_label'].to(device)

In [None]:
# Sanity check: show one transformed sample and its dtype. Each indexing call
# re-runs the transform pipeline (including the random crop), so the two lines
# may look at different crops.
pprint.pprint(train_dataset[0]['signal'])
pprint.pprint(train_dataset[0]['signal'][0].dtype)

### Test HDF5

In [None]:
%%time

# Benchmark: one full pass over the training split, reading each recording from
# the shared `signal/signal.h5` through EegDataset_h5.
# Loader settings match the other format benchmarks so the timings are comparable.
train_dataset = EegDataset_h5(curated_path, metadata_train, composed)
train_loader = DataLoader(train_dataset, 
                          batch_size=32, # Random crop will inflate the minibatch size
                          shuffle=False, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for i_batch, sample_batched in enumerate(train_loader):
    sample_batched['signal'].to(device)
    sample_batched['age'].to(device)
    sample_batched['class_label'].to(device)

In [None]:
# Sanity check: show one transformed sample and its dtype. Each indexing call
# re-runs the transform pipeline (including the random crop), so the two lines
# may look at different crops.
pprint.pprint(train_dataset[0]['signal'])
pprint.pprint(train_dataset[0]['signal'][0].dtype)

-----

## Check Event Loading Time by Data Format

In [None]:
class EegDataset_event_json(Dataset):
    """PyTorch dataset for timing event loading from JSON files.

    Events are read from ``<root_dir>/event/<serial>.json``. The signal is a fixed
    all-zero placeholder, so only event I/O is measured.

    Args:
        root_dir (str): Curated data directory containing the ``event`` folder.
        metadata (list of dict): One dictionary per recording with at least
            ``serial``, ``age`` and ``class_label``.
        transform (callable, optional): Applied to every sample dictionary when given.
    """

    def __init__(self, root_dir, metadata, transform=None):
        self.root_dir, self.metadata, self.transform = root_dir, metadata, transform

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        # samplers may hand over tensor indices; convert them to plain Python values
        index = idx.tolist() if torch.is_tensor(idx) else idx
        record = self.metadata[index]
        path = os.path.join(self.root_dir, 'event', record['serial'] + '.json')
        with open(path, 'r') as event_fp:
            loaded_events = json.load(event_fp)
        sample = {
            'signal': np.zeros((20, 100)),
            'event': loaded_events,
            'age': record['age'],
            'class_label': record['class_label'],
            'metadata': record,
        }
        return self.transform(sample) if self.transform else sample

In [None]:
class EegDataset_event_feather(Dataset):
    """PyTorch dataset for timing event loading from Feather files.

    Events are read from ``<root_dir>/event/<serial>.feather`` as a pandas
    DataFrame. The signal is a fixed all-zero placeholder, so only event I/O
    is measured.

    Args:
        root_dir (str): Curated data directory containing the ``event`` folder.
        metadata (list of dict): One dictionary per recording with at least
            ``serial``, ``age`` and ``class_label``.
        transform (callable, optional): Applied to every sample dictionary when given.
    """

    def __init__(self, root_dir, metadata, transform=None):
        self.root_dir, self.metadata, self.transform = root_dir, metadata, transform

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        # samplers may hand over tensor indices; convert them to plain Python values
        index = idx.tolist() if torch.is_tensor(idx) else idx
        record = self.metadata[index]
        path = os.path.join(self.root_dir, 'event', record['serial'] + '.feather')
        event_table = pd.read_feather(path)
        sample = {
            'signal': np.zeros((20, 100)),
            'event': event_table,
            'age': record['age'],
            'class_label': record['class_label'],
            'metadata': record,
        }
        return self.transform(sample) if self.transform else sample

In [None]:
# Event benchmarks only need tensor conversion: no cropping or channel dropping
composed = transforms.Compose([EegToTensor()])

### Test JSON

In [None]:
%%time

# Benchmark: 10 passes over the training split loading events from JSON.
# The signal is a zero placeholder; the repetition presumably makes the very
# small event files measurable -- confirm.
train_dataset = EegDataset_event_json(curated_path, metadata_train, composed)
train_loader = DataLoader(train_dataset, 
                          batch_size=32, # no random crop in this pipeline, so batches keep this size
                          shuffle=False, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for k in range(10):
    for i_batch, sample_batched in enumerate(train_loader):
        sample_batched['signal'].to(device)
        sample_batched['age'].to(device)
        sample_batched['class_label'].to(device)

In [None]:
# Sanity check: the zero placeholder signal and the events loaded from JSON
pprint.pprint(train_dataset[0]['signal'])
pprint.pprint(train_dataset[0]['event'])

### Test Feather

In [None]:
%%time

# Benchmark: 10 passes over the training split loading events from Feather.
# Same settings and repetition count as the JSON event benchmark.
train_dataset = EegDataset_event_feather(curated_path, metadata_train, composed)
train_loader = DataLoader(train_dataset, 
                          batch_size=32, # no random crop in this pipeline, so batches keep this size
                          shuffle=False, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for k in range(10):
    for i_batch, sample_batched in enumerate(train_loader):
        sample_batched['signal'].to(device)
        sample_batched['age'].to(device)
        sample_batched['class_label'].to(device)

In [None]:
# Sanity check: the zero placeholder signal and the events loaded from Feather
pprint.pprint(train_dataset[0]['signal'])
pprint.pprint(train_dataset[0]['event'])