# EEG Data File Format Test

Given the metadata generated in the `01_Data_Curation1` notebook, this notebook tests several file formats for reading and writing the EEG dataset and decides which one is best.  
The following file formats are tested: `NumPy pickle`, `NumPy memmap`, `Feather`, `Parquet`, `Jay`, and `HDF5`.

-----

## Load Packages and Configure Notebook Environments

In [1]:
# Auto-reload external modules, so edits to imported modules are picked up
# without restarting the kernel.
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
# Change the working directory to the parent folder -- presumably the project
# root, so that the `datasets` package and the relative `local/dataset/...`
# paths used below resolve (confirm against the repository layout).
# NOTE: `%cd ..` is relative; re-running this cell moves up one more level
# each time, so run it only once per kernel session.
%cd ..

In [2]:
# Load some packages
import os
import re
import copy
import glob
from openpyxl import load_workbook, Workbook, styles
import json

import pyedflib
import datetime
from dateutil.relativedelta import relativedelta

import pprint
import warnings
import ctypes
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import pandas as pd
import pyarrow.feather as feather
import datatable as dt
import h5py

import numpy as np
import random
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# custom package
from datasets.caueeg_dataset import *
from datasets.caueeg_data_curation import *
from datasets.pipeline import *

In [3]:
# Report the installed PyTorch build and pick the compute device
# (the GPU when CUDA is usable, the CPU otherwise).
print('PyTorch version:', torch.__version__)

if torch.cuda.is_available():
    device = torch.device('cuda')
    print('cuda is available.')
else:
    device = torch.device('cpu')
    print('cuda is unavailable.')

PyTorch version: 1.10.1+cu113
cuda is available.


-----

## Load and Check Data

In [4]:
# Data locations, relative to the current working directory:
#   original_path -- source folder holding the EDF recordings and spreadsheets
#   curated_path  -- output folder for the files written by this notebook
original_path = 'local/dataset/01_Original_Data_220307'
curated_path = 'local/dataset/02_Curated_Data_Temp'

# The output folder may already exist from an earlier run.
os.makedirs(curated_path, exist_ok=True)

In [5]:
# Read the 'metadata' worksheet row by row into a list of dictionaries.
meta_file = os.path.join(original_path, 'new_DB_list.xlsx')
ws = load_workbook(meta_file, data_only=True)['metadata']

metadata = []

num = 2  # reading starts at the second row
while True:
    # columns 1-4: EDF file name, diagnosis string, birth date, anomaly note
    edfname, dx1, birth, anomaly_note = [
        ws.cell(row=num, column=col).value for col in (1, 2, 3, 4)
    ]
    num += 1

    # an empty file-name cell marks the end of the table
    if edfname is None:
        break

    # rows carrying any anomaly note are left out
    if anomaly_note is not None:
        continue

    metadata.append({'edfname': edfname,
                     'dx1': dx1,
                     'birth': birth,
                     'anomaly': False})

print('Size:', len(metadata))
print()
print('Loaded metadata (first three displayed):')
print(json.dumps(metadata[:3], indent=4))

Size: 1390

Loaded metadata (first three displayed):
[
    {
        "edfname": "00001809_261018",
        "dx1": "mci_rf",
        "birth": 400602,
        "anomaly": false
    },
    {
        "edfname": "00029426_020817",
        "dx1": "smi",
        "birth": 601204,
        "anomaly": false
    },
    {
        "edfname": "00047327_090718",
        "dx1": "vascular mci",
        "birth": 241019,
        "anomaly": false
    }
]


In [6]:
# Read the first recording; its signal headers are kept in `refer_headers`
# as the reference that other recordings are compared against.
m = metadata[0]
edf_file = os.path.join(original_path, f"{m['edfname']}.edf")
signals, signal_headers, edf_header = pyedflib.highlevel.read_edf(edf_file)

refer_headers = signal_headers

# check the signal is discrete: number of distinct values vs. the value range
distinct_shape = np.unique(signals.ravel()).shape
value_range = signals.max() - signals.min()
print(distinct_shape, value_range)

# pretty-print the metadata row, the file header and the per-channel headers
for item in (m, edf_header, signal_headers):
    print()
    pprint.pp(item)

(2820,) 4843.0

{'edfname': '00001809_261018',
 'dx1': 'mci_rf',
 'birth': 400602,
 'anomaly': False}

{'technician': '',
 'recording_additional': '',
 'patientname': '',
 'patient_additional': '',
 'patientcode': '',
 'equipment': '',
 'admincode': '',
 'gender': '',
 'startdate': datetime.datetime(2018, 10, 26, 15, 46, 26),
 'birthdate': '',
 'annotations': []}

[{'label': 'Fp1-AVG',
  'dimension': 'uV',
  'sample_rate': 200.0,
  'physical_max': 32767.0,
  'physical_min': -32768.0,
  'digital_max': 32767,
  'digital_min': -32768,
  'prefilter': '',
  'transducer': 'E'},
 {'label': 'F3-AVG',
  'dimension': 'uV',
  'sample_rate': 200.0,
  'physical_max': 32767.0,
  'physical_min': -32768.0,
  'digital_max': 32767,
  'digital_min': -32768,
  'prefilter': '',
  'transducer': 'E'},
 {'label': 'C3-AVG',
  'dimension': 'uV',
  'sample_rate': 200.0,
  'physical_max': 32767.0,
  'physical_min': -32768.0,
  'digital_max': 32767,
  'digital_min': -32768,
  'prefilter': '',
  'transducer': 'E'},

---

## Arrange and Save the Data

1. Remove EDF files with irregular signal header
2. Extract EDF signal and save the data after reformatting
3. Save all metadata as JSON and XLSX, where the JSON is for further use and the XLSX is for human inspection
    - `metadata_debug`: Full inclusion metadata for debugging
    - `metadata_public`: Metadata without personal information

In [7]:
def _confirm(question):
    """Ask a yes/no question and return True only for an explicit "yes".

    On Windows the native Yes/No message box is shown. `ctypes.windll` only
    exists on Windows, so other platforms fall back to a console prompt.
    """
    if os.name == 'nt':
        # 4 = MB_YESNO (Yes/No buttons), 6 = IDYES (the user clicked "Yes")
        return ctypes.windll.user32.MessageBoxExW(0, question, 'Question', 4) == 6
    return input(question + ' [y/N] ').strip().lower() in ('y', 'yes')


def _read_events(event_file, record_start, sampling_rate=200):
    """Read the event annotations of one recording from its XLSX file.

    Rows of the first worksheet are read from row 2 until the time cell
    (column 3) is empty; column 4 holds the event text.

    Args:
        event_file (str): Path to the XLSX file.
        record_start (datetime.datetime): Recording start; only its date is used.
        sampling_rate (int): Samples per second used to convert elapsed time
            into a sample offset.

    Returns:
        list of tuple: `(offset, event)` pairs, where `offset` is the number of
        samples elapsed since the first listed event.
    """
    wb = load_workbook(event_file, data_only=True)
    ws = wb[wb.sheetnames[0]]
    day = record_start.strftime('%Y%m%d')

    events = []
    first_time = None
    num = 2
    while True:
        t = ws.cell(row=num, column=3).value
        e = ws.cell(row=num, column=4).value

        if t is None:
            break

        # the format string expects a space between the date and the cell text
        t = datetime.datetime.strptime(day + t, '%Y%m%d %H:%M:%S.%f')
        if first_time is None:
            first_time = t

        events.append((int(np.floor((t - first_time).total_seconds() * sampling_rate)), e))
        num += 1
    return events


warnings.filterwarnings(action='ignore')
try:
    text = f'Delete ALL files in {curated_path}?'
    if _confirm(text):
        for f in glob.glob(os.path.join(curated_path, '*/*')):
            os.remove(f)
        for f in glob.glob(os.path.join(curated_path, '*.*')):
            os.remove(f)

    os.makedirs(os.path.join(curated_path, 'signal'), exist_ok=True)
    os.makedirs(os.path.join(curated_path, 'event'), exist_ok=True)

    metadata_debug = []
    metadata_public = []

    # The context manager closes (and therefore flushes) the HDF5 file even if
    # the loop fails, so the file can be opened for reading afterwards.
    with h5py.File(os.path.join(curated_path, 'signal', 'signal.h5'), 'w',
                   rdcc_nbytes=(1024**2)*15, rdcc_nslots=10**6) as fh5:
        for m in tqdm(metadata):
            # EDF file check
            edf_file = os.path.join(original_path, m['edfname'] + '.edf')
            signals, signal_headers, edf_header = pyedflib.highlevel.read_edf(edf_file)

            if refer_headers != signal_headers:
                print('- Signal header differs from the majority:', m['edfname'])
                continue

            # calculate age
            age = calculate_age(birth_to_datetime(m['birth']),
                                edf_header['startdate'])

            if age is None:
                print('- The age information is unknown:', m['edfname'])
                continue

            # EDF recording events
            event_file = os.path.join(original_path, m['edfname'] + '.xlsx')
            event = _read_events(event_file, edf_header['startdate'])

            # metadata_debug: full record, including personal information
            m2 = {}
            m2['serial'] = f'{len(metadata_debug) + 1:05}'
            m2['edfname'] = m['edfname']
            m2['birth'] = birth_to_datetime(m['birth'])
            m2['record'] = edf_header['startdate']
            m2['age'] = age
            m2['dx1'] = m['dx1']
            m2['label'] = MultiLabel.load_from_string(m['dx1'])
            m2['event'] = event
            metadata_debug.append(m2)

            # metadata_public: serial number, age and label only
            m3 = {}
            m3['serial'] = m2['serial']
            m3['age'] = age
            m3['label'] = m2['label']
            metadata_public.append(m3)

            # EDF signal
            signals = trim_trailing_zeros(signals)        # trim garbage zeros
            signals = signals.astype('int32')
            df = pd.DataFrame(data=signals.T, columns=[s_h['label'] for s_h in signal_headers], dtype=np.int32)

            # numpy pickle
            np.save(os.path.join(curated_path, 'signal', m2['serial']), signals)

            # numpy memmap
            fp = np.memmap(os.path.join(curated_path, 'signal', m2['serial'] + '.dat'),
                           dtype='int32', mode='w+', shape=signals.shape)
            fp[:] = signals[:]
            fp.flush()

            # feather
            feather.write_feather(df, os.path.join(curated_path, 'signal', m2['serial'] + '.feather'))

            # parquet
            df.to_parquet(os.path.join(curated_path, 'signal', m2['serial'] + '.parquet'))

            # jay
            dt.Frame(df).to_jay(os.path.join(curated_path, 'signal', m2['serial'] + '.jay'))

            # hdf5: one dataset per recording, keyed by the serial number
            fh5.create_dataset(m2['serial'], shape=signals.shape, dtype='int32', chunks=True, data=signals)

            # event: written both as JSON and as Feather
            event_df = pd.DataFrame(data=event, columns=['timing', 'event'])
            with open(os.path.join(curated_path, 'event', m2['serial']) + '.json', 'w') as json_file:
                json.dump(event, json_file, indent=4, default=serialize_json)
            event_df.to_feather(os.path.join(curated_path, 'event', m2['serial']) + '.feather')

    print('Done.')
    print()
    print(f'Among {len(metadata)}, {len(metadata_public)} data were saved.')
finally:
    # re-enable warnings even when the export is interrupted or fails
    warnings.filterwarnings(action='default')

  0%|          | 0/1390 [00:00<?, ?it/s]

- The age information is unknown: 00602793_300518
- The age information is unknown: 00850537_061014
Done.

Among 1390, 1388 data were saved.


In [8]:
# save metadata_public as JSON
path = os.path.join(curated_path, 'metadata.json')
with open(path, 'w') as json_file:
    json.dump(metadata_public, json_file, indent=4, default=serialize_json)

In [9]:
# load metadata_public
meta_path = os.path.join(curated_path, 'metadata.json')
with open(meta_path, 'r') as json_file:
    metadata = json.load(json_file)

pprint.pprint(metadata[0])

{'age': 78,
 'label': ['mci', 'mci_amnestic', 'mci_amnestic_rf'],
 'serial': '00001'}


-----

## Data Filtering by Diagnosis

#### Non-Vascular Dementia, Non-Vascular MCI, Normal

In [10]:
# Class definitions, in priority order. A recording belongs to the first class
# for which it carries at least one 'include' tag and no 'exclude' tag.
diagnosis_filter = [
    dict(type='Normal',
         include=['normal'],
         exclude=[]),
    dict(type='Non-vascular MCI',
         include=['mci'],
         exclude=['mci_vascular']),
    dict(type='Non-vascular dementia',
         include=['dementia'],
         exclude=['vd']),
]


def generate_class_label(label):
    """Map a list of diagnosis tags to a class.

    Args:
        label (iterable of str): Diagnosis tags of one recording.

    Returns:
        tuple: `(class_index, class_name)` of the first matching entry of
        `diagnosis_filter`, or `(-1, 'The others')` when none matches.
    """
    tags = set(label)
    for class_index, rule in enumerate(diagnosis_filter):
        if tags.isdisjoint(rule['include']):
            continue  # none of the required tags is present
        if not tags.isdisjoint(rule['exclude']):
            continue  # a forbidden tag is present
        return (class_index, rule['type'])
    return (-1, 'The others')


class_label_to_name = [rule['type'] for rule in diagnosis_filter]
print('class_label_to_name:', class_label_to_name)

class_label_to_name: ['Normal', 'Non-vascular MCI', 'Non-vascular dementia']


In [11]:
# One bucket per class defined in diagnosis_filter
splitted_metadata = [[] for _ in diagnosis_filter]

for m in metadata:
    c, n = generate_class_label(m['label'])
    if c < 0:
        continue  # recordings outside the defined classes are dropped
    # annotate the record in place with its class name and index
    m.update(class_type=n, class_label=c)
    splitted_metadata[c].append(m)

for i, split in enumerate(splitted_metadata):
    if split:
        print(f'- There are {len(split):} data belonging to {split[0]["class_type"]}')
    else:
        print(f'(Warning) Split group {i} has no data.')

- There are 458 data belonging to Normal
- There are 350 data belonging to Non-vascular MCI
- There are 233 data belonging to Non-vascular dementia


-----

## Configure the Train, Validation, and Test Splits

#### Split the filtered dataset into train, validation, and test sets, and shuffle each set

In [12]:
# Train : Val : Test = 8 : 1 : 1
ratio1 = 0.8
ratio2 = 0.1

# A dedicated, seeded generator makes the split reproducible across kernel
# restarts without touching the global `random` state used elsewhere.
split_seed = 0
split_rng = random.Random(split_seed)

metadata_train = []
metadata_val = []
metadata_test = []

for split in splitted_metadata:
    # shuffle a copy so that re-running this cell starts from the same order
    shuffled = list(split)
    split_rng.shuffle(shuffled)

    # splitting class by class keeps the class proportions in every subset
    n1 = round(len(shuffled) * ratio1)
    n2 = n1 + round(len(shuffled) * ratio2)

    metadata_train.extend(shuffled[:n1])
    metadata_val.extend(shuffled[n1:n2])
    metadata_test.extend(shuffled[n2:])

# mix the classes inside each subset
split_rng.shuffle(metadata_train)
split_rng.shuffle(metadata_val)
split_rng.shuffle(metadata_test)

print('Train data size\t\t:', len(metadata_train))
print('Validation data size\t:', len(metadata_val))
print('Test data size\t\t:', len(metadata_test))


def _count_per_class(subset):
    """Return an int32 array with the number of records of each class label."""
    counts = np.zeros((len(class_label_to_name)), dtype=np.int32)
    for record in subset:
        counts[record['class_label']] += 1
    return counts


print('\n', '--- Recheck ---', '\n')
train_class_nums = _count_per_class(metadata_train)
val_class_nums = _count_per_class(metadata_val)
test_class_nums = _count_per_class(metadata_test)

print('Train data label distribution\t:', train_class_nums, train_class_nums.sum())
print('Val data label distribution\t:', val_class_nums, val_class_nums.sum())
print('Test data label distribution\t:', test_class_nums, test_class_nums.sum())

Train data size		: 832
Validation data size	: 104
Test data size		: 105

 --- Recheck --- 

Train data label distribution	: [366 280 186] 832
Val data label distribution	: [46 35 23] 104
Test data label distribution	: [46 35 24] 105


-----

## Check Signal Loading Time by Data Format

In [13]:
class EegDataset_np(Dataset):
    """PyTorch dataset that reads each EEG signal from a NumPy ``.npy`` file.

    Args:
        root_dir (str): Directory containing a ``signal`` sub-folder with one
            ``{serial}.npy`` file per recording.
        metadata (list of dict): One dictionary per recording; the keys
            ``serial``, ``age`` and ``class_label`` are read.
        transform (callable, optional): Optional transform to be applied on each data.
    """

    def __init__(self, root_dir, metadata, transform=None):
        self.root_dir, self.metadata, self.transform = root_dir, metadata, transform

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        # a tensor index is converted to a plain Python value first
        index = idx.tolist() if torch.is_tensor(idx) else idx
        record = self.metadata[index]

        path = os.path.join(self.root_dir, 'signal', record['serial'] + '.npy')
        sample = dict(signal=np.load(path),
                      age=record['age'],
                      class_label=record['class_label'],
                      event=[],
                      metadata=record)
        return self.transform(sample) if self.transform else sample

In [14]:
class EegDataset_np_memmap(Dataset):
    """PyTorch dataset that memory-maps each EEG signal from a raw ``.dat`` file.

    The ``.dat`` files hold bare int32 values without shape information, so
    the number of channels must be known in order to restore the 2-D layout.

    Args:
        root_dir (str): Directory containing a ``signal`` sub-folder with one
            ``{serial}.dat`` file per recording.
        metadata (list of dict): One dictionary per recording; the keys
            ``serial``, ``age`` and ``class_label`` are read.
        transform (callable, optional): Optional transform to be applied on each data.
        num_channels (int, optional): Number of rows the flat file is reshaped
            into. Defaults to 21, the value that was previously hard-coded.
    """

    def __init__(self, root_dir, metadata, transform=None, num_channels=21):
        self.root_dir = root_dir
        self.metadata = metadata
        self.transform = transform
        self.num_channels = num_channels

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        m = self.metadata[idx]
        fname = os.path.join(self.root_dir, 'signal', m['serial'] + '.dat')
        # read-only mapping of the whole file, viewed as (num_channels, samples)
        signal = np.memmap(fname, dtype='int32', mode='r').reshape(self.num_channels, -1)
        sample = {'signal': signal,
                  'age': m['age'],
                  'class_label': m['class_label'],
                  'event': [],
                  'metadata': m}
        if self.transform:
            sample = self.transform(sample)
        return sample

In [15]:
class EegDataset_feather(Dataset):
    """PyTorch dataset that reads each EEG signal from a Feather file.

    Args:
        root_dir (str): Directory containing a ``signal`` sub-folder with one
            ``{serial}.feather`` file per recording.
        metadata (list of dict): One dictionary per recording; the keys
            ``serial``, ``age`` and ``class_label`` are read.
        transform (callable, optional): Optional transform to be applied on each data.
    """

    def __init__(self, root_dir, metadata, transform=None):
        self.root_dir, self.metadata, self.transform = root_dir, metadata, transform

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        index = idx.tolist() if torch.is_tensor(idx) else idx
        record = self.metadata[index]

        path = os.path.join(self.root_dir, 'signal', record['serial'] + '.feather')
        # the table is transposed after loading, so its columns become rows
        table = feather.read_feather(path)
        sample = dict(signal=table.values.T,
                      age=record['age'],
                      class_label=record['class_label'],
                      event=[],
                      metadata=record)
        return self.transform(sample) if self.transform else sample

In [16]:
class EegDataset_feather2(Dataset):
    """Feather-backed EEG dataset that also returns the channel names.

    Args:
        root_dir (str): Directory containing a ``signal`` sub-folder with one
            ``{serial}.feather`` file per recording.
        metadata (list of dict): One dictionary per recording; the keys
            ``serial``, ``age`` and ``class_label`` are read.
        transform (callable, optional): Optional transform to be applied on each data.
    """

    def __init__(self, root_dir, metadata, transform=None):
        self.root_dir, self.metadata, self.transform = root_dir, metadata, transform

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        index = idx.tolist() if torch.is_tensor(idx) else idx
        record = self.metadata[index]

        path = os.path.join(self.root_dir, 'signal', record['serial'] + '.feather')
        df = feather.read_feather(path)
        sample = dict(signal=df.values.T,               # columns become rows
                      channel=df.columns.to_list(),     # column names of the table
                      age=record['age'],
                      class_label=record['class_label'],
                      event=[],
                      metadata=record)
        return self.transform(sample) if self.transform else sample

In [17]:
class EegDataset_parquet(Dataset):
    """PyTorch dataset that reads each EEG signal from a Parquet file.

    Args:
        root_dir (str): Directory containing a ``signal`` sub-folder with one
            ``{serial}.parquet`` file per recording.
        metadata (list of dict): One dictionary per recording; the keys
            ``serial``, ``age`` and ``class_label`` are read.
        transform (callable, optional): Optional transform to be applied on each data.
    """

    def __init__(self, root_dir, metadata, transform=None):
        self.root_dir, self.metadata, self.transform = root_dir, metadata, transform

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        index = idx.tolist() if torch.is_tensor(idx) else idx
        record = self.metadata[index]

        path = os.path.join(self.root_dir, 'signal', record['serial'] + '.parquet')
        # the table is transposed after loading, so its columns become rows
        table = pd.read_parquet(path)
        sample = dict(signal=table.values.T,
                      age=record['age'],
                      class_label=record['class_label'],
                      event=[],
                      metadata=record)
        return self.transform(sample) if self.transform else sample

In [18]:
class EegDataset_jay(Dataset):
    """PyTorch dataset that reads each EEG signal from a datatable Jay file.

    Args:
        root_dir (str): Directory containing a ``signal`` sub-folder with one
            ``{serial}.jay`` file per recording.
        metadata (list of dict): One dictionary per recording; the keys
            ``serial``, ``age`` and ``class_label`` are read.
        transform (callable, optional): Optional transform to be applied on each data.
    """

    def __init__(self, root_dir, metadata, transform=None):
        self.root_dir, self.metadata, self.transform = root_dir, metadata, transform

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        index = idx.tolist() if torch.is_tensor(idx) else idx
        record = self.metadata[index]

        path = os.path.join(self.root_dir, 'signal', record['serial'] + '.jay')
        # the frame is transposed after conversion, so its columns become rows
        frame = dt.fread(path)
        sample = dict(signal=frame.to_numpy().T,
                      age=record['age'],
                      class_label=record['class_label'],
                      event=[],
                      metadata=record)
        return self.transform(sample) if self.transform else sample

In [19]:
class EegDataset_h5(Dataset):
    """PyTorch dataset that reads each EEG signal from one shared HDF5 file.

    All recordings live in ``signal/signal.h5`` under ``root_dir``, one HDF5
    dataset per recording, keyed by the serial number.

    The file is opened on first access rather than in the constructor, so a
    dataset object that has not been read from yet holds no open handle.
    Call `close()` to release the handle when the dataset is no longer needed.

    Args:
        root_dir (str): Directory containing ``signal/signal.h5``.
        metadata (list of dict): One dictionary per recording; the keys
            ``serial``, ``age`` and ``class_label`` are read.
        transform (callable, optional): Optional transform to be applied on each data.
    """

    def __init__(self, root_dir, metadata, transform=None):
        self.root_dir = root_dir
        self.f_handle = None  # opened lazily by _file()
        self.metadata = metadata
        self.transform = transform

    def _file(self):
        """Return the read-only HDF5 handle, opening the file on first use."""
        if self.f_handle is None:
            self.f_handle = h5py.File(os.path.join(self.root_dir, 'signal', 'signal.h5'), 'r')
        return self.f_handle

    def close(self):
        """Close the HDF5 file if it is open; it is reopened on the next read."""
        if self.f_handle is not None:
            self.f_handle.close()
            self.f_handle = None

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        m = self.metadata[idx]
        # `[:]` copies the whole HDF5 dataset into an in-memory NumPy array
        signal = self._file()[m['serial']][:]
        sample = {'signal': signal,
                  'age': m['age'],
                  'class_label': m['class_label'],
                  'event': [],
                  'metadata': m}
        if self.transform:
            sample = self.transform(sample)
        return sample

In [44]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')

print('Current PyTorch device:', device)

# Worker processes stay disabled on every device
# (A number other than 0 causes an error)
num_workers = 0
# pinned memory is requested only when the target device is a GPU
pin_memory = device.type == 'cuda'

# Per-sample preprocessing shared by all the loading tests below
composed = transforms.Compose([
    EegRandomCrop(crop_length=200*20, multiple=2), # 20 seconds
    EegDropPhoticChannel(),
    EegToTensor(),
])

Current PyTorch device: cuda


### Test NumPy Pickle

In [45]:
%%time

# One full pass over the training split, read from NumPy `.npy` files
train_dataset = EegDataset_np(curated_path, metadata_train, composed)
loader_kwargs = dict(
    batch_size=32,  # Random crop will inflate the minibatch size
    shuffle=False,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)
train_loader = DataLoader(train_dataset, **loader_kwargs)

# move every field of each mini-batch to the target device
for i_batch, sample_batched in enumerate(train_loader):
    for field in ('signal', 'age', 'class_label'):
        sample_batched[field].to(device)

CPU times: total: 47.6 s
Wall time: 3.97 s


In [46]:
# Show the `signal` entry of one sample and the dtype of its first element.
# Each `train_dataset[0]` access calls `__getitem__` again, so the two lines
# below look at separately loaded and transformed samples.
pprint.pprint(train_dataset[0]['signal'])
pprint.pprint(train_dataset[0]['signal'][0].dtype)

[tensor([[ -6.,  -2.,   3.,  ...,   0.,  -2.,  -5.],
        [ -6.,  -4.,  -1.,  ...,   3.,   2.,   0.],
        [ -6.,  -2.,   0.,  ...,  -2.,  -3.,  -5.],
        ...,
        [  7.,  10.,  11.,  ...,  -6.,  -8.,  -7.],
        [ -6.,  -7.,  -8.,  ...,   4.,   2.,   4.],
        [-96., -15.,  35.,  ...,  38., -31.,  21.]]),
 tensor([[ -4.,  -3.,  -1.,  ...,   1.,   2.,   3.],
        [  1.,   2.,   3.,  ...,  -5.,  -5.,  -5.],
        [ -6.,  -5.,  -3.,  ...,  -2.,   0.,   2.],
        ...,
        [ -8.,  -5.,  -4.,  ...,  -8.,  -7.,  -6.],
        [  4.,   5.,   5.,  ...,  -5.,  -3.,  -2.],
        [-23., -10.,  82.,  ...,  30.,  12.,  78.]])]
torch.float32


### Test NumPy Memmap

In [47]:
%%time

# One full pass over the training split, read through NumPy memory maps
train_dataset = EegDataset_np_memmap(curated_path, metadata_train, composed)
loader_kwargs = dict(
    batch_size=32,  # Random crop will inflate the minibatch size
    shuffle=False,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)
train_loader = DataLoader(train_dataset, **loader_kwargs)

# move every field of each mini-batch to the target device
for i_batch, sample_batched in enumerate(train_loader):
    for field in ('signal', 'age', 'class_label'):
        sample_batched[field].to(device)

CPU times: total: 10.5 s
Wall time: 876 ms


In [48]:
# Show the `signal` entry of one sample and the dtype of its first element.
# Each `train_dataset[0]` access calls `__getitem__` again, so the two lines
# below look at separately loaded and transformed samples.
pprint.pprint(train_dataset[0]['signal'])
pprint.pprint(train_dataset[0]['signal'][0].dtype)

[tensor([[-49., -50., -51.,  ...,  29.,  30.,  31.],
        [-21., -24., -25.,  ...,  -5.,  -5.,  -4.],
        [ -2.,  -5.,  -7.,  ..., -12., -10., -10.],
        ...,
        [ -6.,  -7.,  -8.,  ..., -32., -32., -30.],
        [-10.,  -8.,  -9.,  ...,   3.,   4.,   6.],
        [ 42.,  -2.,  29.,  ..., -92., -21.,  24.]]),
 tensor([[ -2.,  -1.,  -3.,  ..., -15., -13., -13.],
        [-14., -13., -14.,  ..., -16., -12., -13.],
        [  5.,   5.,   6.,  ...,   2.,   3.,   4.],
        ...,
        [ -9.,  -8.,  -7.,  ...,  -2.,   0.,   2.],
        [  7.,   8.,   8.,  ...,   0.,   1.,   3.],
        [-10., -15., -45.,  ...,  86.,  14.,  45.]])]
torch.float32


### Test Feather

In [49]:
%%time

# One full pass over the training split, read from Feather files
train_dataset = EegDataset_feather(curated_path, metadata_train, composed)
loader_kwargs = dict(
    batch_size=32,  # Random crop will inflate the minibatch size
    shuffle=False,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)
train_loader = DataLoader(train_dataset, **loader_kwargs)

# move every field of each mini-batch to the target device
for i_batch, sample_batched in enumerate(train_loader):
    for field in ('signal', 'age', 'class_label'):
        sample_batched[field].to(device)

CPU times: total: 55.2 s
Wall time: 3.55 s


In [50]:
# Show the `signal` entry of one sample and the dtype of its first element.
# Each `train_dataset[0]` access calls `__getitem__` again, so the two lines
# below look at separately loaded and transformed samples.
pprint.pprint(train_dataset[0]['signal'])
pprint.pprint(train_dataset[0]['signal'][0].dtype)

[tensor([[ -7.,  -4.,  -3.,  ...,   4.,   5.,   5.],
        [ 13.,  12.,   9.,  ..., -11., -11., -13.],
        [  2.,  -2.,  -4.,  ...,   4.,   5.,   6.],
        ...,
        [  8.,   6.,   3.,  ...,  15.,  14.,  12.],
        [  7.,   5.,   3.,  ...,   2.,   1.,   0.],
        [-72., -43.,  12.,  ...,  23., -27., -93.]]),
 tensor([[  5.,   9.,  10.,  ..., -34., -34., -36.],
        [-11., -13., -16.,  ..., -11., -10.,  -9.],
        [  4.,   2.,   0.,  ...,  -2.,   0.,   4.],
        ...,
        [  4.,   3.,   1.,  ...,   0.,   0.,   2.],
        [  2.,   1.,   1.,  ...,  -8.,  -7.,  -5.],
        [-50., -79., -34.,  ...,  48.,  92., -12.]])]
torch.float32


In [51]:
%%time

# One full pass over the training split, read from Feather files with the
# channel names included in every sample
train_dataset = EegDataset_feather2(curated_path, metadata_train, composed)
loader_kwargs = dict(
    batch_size=32,  # Random crop will inflate the minibatch size
    shuffle=False,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)
train_loader = DataLoader(train_dataset, **loader_kwargs)

# move every field of each mini-batch to the target device
for i_batch, sample_batched in enumerate(train_loader):
    for field in ('signal', 'age', 'class_label'):
        sample_batched[field].to(device)

CPU times: total: 57.3 s
Wall time: 3.56 s


In [52]:
# Show the `signal` entry of one sample and the dtype of its first element.
# Each `train_dataset[0]` access calls `__getitem__` again, so the two lines
# below look at separately loaded and transformed samples.
pprint.pprint(train_dataset[0]['signal'])
pprint.pprint(train_dataset[0]['signal'][0].dtype)

[tensor([[  1.,  -1.,   0.,  ..., -10.,  -6.,  -5.],
        [  5.,   3.,   7.,  ...,  -4.,  -1.,  -2.],
        [-11.,  -9., -10.,  ...,   5.,   7.,   8.],
        ...,
        [  8.,  10.,  10.,  ..., -11., -10.,  -9.],
        [  0.,   1.,   0.,  ...,   6.,   4.,   5.],
        [-25., -43.,  -3.,  ..., -20.,   6.,  26.]]),
 tensor([[ -7.,  -6.,  -2.,  ..., -79., -78., -77.],
        [  2.,   3.,   3.,  ..., -29., -28., -28.],
        [ -1.,  -2.,  -2.,  ...,   1.,   3.,   3.],
        ...,
        [ -4.,  -5.,  -5.,  ...,   4.,   6.,   8.],
        [ -3.,  -4.,  -5.,  ...,   6.,   8.,   8.],
        [-64.,  71.,  91.,  ...,  68., -67., -91.]])]
torch.float32


### Test Parquet

In [29]:
%%time

# One full pass over the training split, read from Parquet files
train_dataset = EegDataset_parquet(curated_path, metadata_train, composed)
loader_kwargs = dict(
    batch_size=32,  # Random crop will inflate the minibatch size
    shuffle=False,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)
train_loader = DataLoader(train_dataset, **loader_kwargs)

# move every field of each mini-batch to the target device
for i_batch, sample_batched in enumerate(train_loader):
    for field in ('signal', 'age', 'class_label'):
        sample_batched[field].to(device)

CPU times: total: 2min 45s
Wall time: 13.1 s


In [30]:
# Show the `signal` entry of one sample and the dtype of its first element.
# Each `train_dataset[0]` access calls `__getitem__` again, so the two lines
# below look at separately loaded and transformed samples.
pprint.pprint(train_dataset[0]['signal'])
pprint.pprint(train_dataset[0]['signal'][0].dtype)

[tensor([[  6.,   1.,   1.,  ..., -18., -20., -21.],
        [  8.,   6.,   5.,  ..., -19., -21., -22.],
        [ -6.,  -6.,  -6.,  ...,   6.,   4.,   4.],
        ...,
        [  0.,   0.,   1.,  ...,   5.,   5.,   6.],
        [ -3.,  -1.,   1.,  ...,  -2.,   0.,   0.],
        [ -3.,  52.,  -9.,  ...,  41., -84., -54.]]),
 tensor([[  5.,   8.,  13.,  ...,  -1.,   1.,  -2.],
        [ 12.,  14.,  16.,  ..., -13., -12., -13.],
        [ -2.,   1.,  -2.,  ...,   4.,   4.,   5.],
        ...,
        [ -3.,  -3.,  -2.,  ...,  18.,  19.,  21.],
        [  1.,  -1.,  -1.,  ...,  -1.,   0.,   1.],
        [-73., -25.,  46.,  ...,   8., -33., -90.]])]
torch.float32


### Test Jay

In [31]:
%%time

# One full pass over the training split, read from Jay files
train_dataset = EegDataset_jay(curated_path, metadata_train, composed)
loader_kwargs = dict(
    batch_size=32,  # Random crop will inflate the minibatch size
    shuffle=False,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)
train_loader = DataLoader(train_dataset, **loader_kwargs)

# move every field of each mini-batch to the target device
for i_batch, sample_batched in enumerate(train_loader):
    for field in ('signal', 'age', 'class_label'):
        sample_batched[field].to(device)

CPU times: total: 2min 57s
Wall time: 7.81 s


In [32]:
# Show the `signal` entry of one sample and the dtype of its first element.
# Each `train_dataset[0]` access calls `__getitem__` again, so the two lines
# below look at separately loaded and transformed samples.
pprint.pprint(train_dataset[0]['signal'])
pprint.pprint(train_dataset[0]['signal'][0].dtype)

[tensor([[ -10.,  -12.,  -11.,  ...,  -12.,  -13.,  -14.],
        [  -7.,   -7.,   -6.,  ...,   -4.,   -5.,   -6.],
        [  -6.,   -7.,   -8.,  ...,   -6.,   -6.,   -5.],
        ...,
        [   9.,    9.,   10.,  ...,   11.,   11.,    9.],
        [   6.,    6.,    6.,  ...,    5.,    4.,    5.],
        [  46.,  -81.,  -56.,  ...,   53.,  -18., -106.]]),
 tensor([[ 35.,  25.,  17.,  ..., -41., -40., -40.],
        [  8.,   7.,   6.,  ...,   1.,   0.,  -2.],
        [-14., -12., -11.,  ...,   1.,   3.,   1.],
        ...,
        [  4.,   5.,   5.,  ...,   0.,   1.,  -1.],
        [ -7.,  -4.,  -2.,  ...,   3.,   4.,   5.],
        [-16., -85., -53.,  ...,   2.,  51., -36.]])]
torch.float32


### Test HDF5

In [33]:
%%time

# One full pass over the training split, read from the shared HDF5 file
train_dataset = EegDataset_h5(curated_path, metadata_train, composed)
loader_kwargs = dict(
    batch_size=32,  # Random crop will inflate the minibatch size
    shuffle=False,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)
train_loader = DataLoader(train_dataset, **loader_kwargs)

# move every field of each mini-batch to the target device
for i_batch, sample_batched in enumerate(train_loader):
    for field in ('signal', 'age', 'class_label'):
        sample_batched[field].to(device)

CPU times: total: 1min 45s
Wall time: 8.64 s


In [34]:
# Show the `signal` entry of one sample and the dtype of its first element.
# Each `train_dataset[0]` access calls `__getitem__` again, so the two lines
# below look at separately loaded and transformed samples.
pprint.pprint(train_dataset[0]['signal'])
pprint.pprint(train_dataset[0]['signal'][0].dtype)

[tensor([[-11., -13., -12.,  ...,  -9., -13., -14.],
        [  1.,   1.,   0.,  ...,  -7.,  -8.,  -7.],
        [  3.,   3.,   2.,  ...,   3.,   4.,   6.],
        ...,
        [  5.,   3.,   2.,  ...,  -1.,   1.,   2.],
        [  3.,   3.,   3.,  ...,   1.,   3.,   4.],
        [  9.,  83.,  54.,  ..., -32., -58.,  20.]]),
 tensor([[-27., -27., -27.,  ..., -39., -39., -40.],
        [ -8.,  -9.,  -9.,  ..., -16., -14., -14.],
        [  0.,  -2.,  -2.,  ...,   2.,   2.,   4.],
        ...,
        [ -8.,  -8.,  -9.,  ...,  -6.,  -6.,  -6.],
        [  2.,   3.,   4.,  ...,   8.,   9.,  10.],
        [ -4.,  51.,  -6.,  ...,  69.,  68.,  -4.]])]
torch.float32


-----

## Check Event Loading Time by Data Format

In [35]:
class EegDataset_event_json(Dataset):
    """PyTorch dataset that loads each recording's events from a JSON file.

    The signal of every sample is a zero-filled ``(20, 100)`` placeholder, so
    iterating this dataset exercises only the event-file loading path.

    Args:
        root_dir (str): Directory that contains the ``event`` sub-folder with
            one ``<serial>.json`` file per recording.
        metadata (list of dict): One dict per recording; the keys ``serial``,
            ``age`` and ``class_label`` are read.
        transform (callable, optional): Applied to each sample dict when given.
    """

    def __init__(self, root_dir, metadata, transform=None):
        self.transform = transform
        self.metadata = metadata
        self.root_dir = root_dir

    def __len__(self):
        """Return the number of metadata records."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Build the sample dict for record ``idx`` (int or tensor index)."""
        index = idx.tolist() if torch.is_tensor(idx) else idx
        record = self.metadata[index]

        # <root_dir>/event/<serial>.json holds the decoded event list.
        event_path = os.path.join(self.root_dir, 'event', record['serial'] + '.json')
        with open(event_path, 'r') as handle:
            events = json.load(handle)

        sample = {
            'signal': np.zeros((20, 100)),  # dummy signal; no EEG data is read
            'event': events,
            'age': record['age'],
            'class_label': record['class_label'],
            'metadata': record,
        }
        if not self.transform:
            return sample
        return self.transform(sample)

In [36]:
class EegDataset_event_feather(Dataset):
    """PyTorch dataset that loads each recording's events from a Feather file.

    The signal of every sample is a zero-filled ``(20, 100)`` placeholder, so
    iterating this dataset exercises only the event-file loading path.

    Args:
        root_dir (str): Directory that contains the ``event`` sub-folder with
            one ``<serial>.feather`` file per recording.
        metadata (list of dict): One dict per recording; the keys ``serial``,
            ``age`` and ``class_label`` are read.
        transform (callable, optional): Applied to each sample dict when given.
    """

    def __init__(self, root_dir, metadata, transform=None):
        self.transform = transform
        self.metadata = metadata
        self.root_dir = root_dir

    def __len__(self):
        """Return the number of metadata records."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Build the sample dict for record ``idx`` (int or tensor index)."""
        index = idx.tolist() if torch.is_tensor(idx) else idx
        record = self.metadata[index]

        # <root_dir>/event/<serial>.feather is read with pandas.
        event_path = os.path.join(self.root_dir, 'event', record['serial'] + '.feather')
        events = pd.read_feather(event_path)

        sample = {
            'signal': np.zeros((20, 100)),  # dummy signal; no EEG data is read
            'event': events,
            'age': record['age'],
            'class_label': record['class_label'],
            'metadata': record,
        }
        if not self.transform:
            return sample
        return self.transform(sample)

In [37]:
# Transform pipeline for the event-loading tests: EegToTensor is the only step.
composed = transforms.Compose([EegToTensor()])

### Test JSON

In [38]:
%%time

# Dataset that reads events from JSON files, wrapped in a DataLoader.
train_dataset = EegDataset_event_json(curated_path, metadata_train, composed)
train_loader = DataLoader(
    train_dataset,
    batch_size=32,  # nominal size; this section's `composed` applies no random crop
    shuffle=False,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)

# Ten full passes over the loader. The `.to(device)` results are discarded on
# purpose: only the time spent loading and transferring each batch matters.
for k in range(10):
    for i_batch, sample_batched in enumerate(train_loader):
        for key in ('signal', 'age', 'class_label'):
            sample_batched[key].to(device)

CPU times: total: 16.2 s
Wall time: 1.36 s


In [39]:
# Fetch the first sample once. Each `train_dataset[0]` lookup re-opens and
# re-parses the event JSON file, so indexing twice only repeats the same work.
first_sample = train_dataset[0]
pprint.pprint(first_sample['signal'])
pprint.pprint(first_sample['event'])

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
[[0, 'Start Recording'],
 [0, 'New Montage - Montage 002'],
 [1236, 'Eyes Open'],
 [2202, 'Eyes Closed'],
 [5268, 'Eyes Open'],
 [6654, 'Eyes Closed'],
 [13332, 'Eyes Open'],
 [14466, 'Eyes Closed'],
 [19800, 'Eyes Open'],
 [20808, 'Eyes Closed'],
 [34554, 'Eyes Open'],
 [35310, 'Eyes Closed'],
 [40686, 'Eyes Open'],
 [41736, 'Eyes Closed'],
 [48246, 'Eyes Open'],
 [49338, 'Eyes Closed'],
 [58367, 'Eyes Open'],
 [59208, 'Eyes Closed'],
 [64667, 'Eyes Open'],
 [65592, 'Eyes Closed'],
 [70674, 'Eyes Open'],
 [71514, 'Eyes Closed'],
 [81468, 'Eyes Open'],
 [82350, 'Eyes Closed'],
 [88692, 'Eyes Open'],
 [89616, 'Eyes Closed'],
 [96210, 'Eyes Open'],
 [97176, 'Eyes Closed'],
 [103602, 'Eyes Open'],
 [104609, 'Eyes Closed'],
 [110994, 'Eyes 

### Test Feather

In [40]:
%%time

# Dataset that reads events from Feather files, wrapped in a DataLoader.
train_dataset = EegDataset_event_feather(curated_path, metadata_train, composed)
train_loader = DataLoader(
    train_dataset,
    batch_size=32,  # nominal size; this section's `composed` applies no random crop
    shuffle=False,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=eeg_collate_fn,
)

# Ten full passes over the loader. The `.to(device)` results are discarded on
# purpose: only the time spent loading and transferring each batch matters.
for k in range(10):
    for i_batch, sample_batched in enumerate(train_loader):
        for key in ('signal', 'age', 'class_label'):
            sample_batched[key].to(device)

CPU times: total: 1min 26s
Wall time: 7.16 s


In [41]:
# Fetch the first sample once. Each `train_dataset[0]` lookup re-reads the
# event feather file, so indexing twice only repeats the same work.
first_sample = train_dataset[0]
pprint.pprint(first_sample['signal'])
pprint.pprint(first_sample['event'])

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
    timing                      event
0        0            Start Recording
1        0  New Montage - Montage 002
2     1236                  Eyes Open
3     2202                Eyes Closed
4     5268                  Eyes Open
5     6654                Eyes Closed
6    13332                  Eyes Open
7    14466                Eyes Closed
8    19800                  Eyes Open
9    20808                Eyes Closed
10   34554                  Eyes Open
11   35310                Eyes Closed
12   40686                  Eyes Open
13   41736                Eyes Closed
14   48246                  Eyes Open
15   49338                Eyes Closed
16   58367                  Eyes Open
17   59208                Eyes Closed
18   64667              