# EEG Data File Format Test

Given the metadata generated in the `01_Data_Curation1` notebook, this notebook benchmarks several file formats for reading and writing the EEG dataset and selects the best one.  
The following file formats are tested: `NumPy pickle`, `Feather`, `Parquet`, `Jay`, and `HDF5`.

-----

## Load Packages and Configure Notebook Environments

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
# Mode 2 asks IPython to reload imported modules before each cell executes, so
# edits to the project's `datasets` package are picked up without restarting
# the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Load some packages
import os
import re
import copy
import glob
from openpyxl import load_workbook, Workbook, styles
import json

import pyedflib
import datetime
from dateutil.relativedelta import relativedelta

import pprint
import warnings
import ctypes
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import pandas as pd
import pyarrow.feather as feather
import datatable as dt

import numpy as np
import random
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# custom package
from datasets.cau_eeg_dataset import *
from datasets.utils import *
from datasets.pipeline import *

In [3]:
# Report the installed PyTorch build and choose the compute device:
# the GPU when CUDA is usable, the CPU otherwise.
print('PyTorch version:', torch.__version__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The device type already records the outcome of the availability check.
print('cuda is available.' if device.type == 'cuda' else 'cuda is unavailable.')

PyTorch version: 1.11.0
cuda is available.


-----

## Load and Check Data

In [4]:
# Data file path
# Input folder: the metadata workbook, the EDF recordings and the per-recording
# event workbooks are all read from here by the cells below.
original_path = r'local/dataset/01_Original_Data_220307'
# Output folder: every curated file produced by this notebook is written here.
curated_path = r'local/dataset/02_Curated_Data_temp'

# Create the output folder when it is missing; exist_ok makes re-running the
# cell harmless.
os.makedirs(curated_path, exist_ok=True)

In [5]:
# Read the per-recording metadata table from the 'metadata' sheet.
meta_file = os.path.join(original_path, r'new_DB_list.xlsx')
ws = load_workbook(meta_file, data_only=True)['metadata']

metadata = []

# Scanning starts at row 2 and stops at the first row whose first column is
# empty. Rows flagged as anomalous are read but not kept.
num = 2
while True:
    m = {
        'edfname': ws.cell(row=num, column=1).value,
        'dx1': ws.cell(row=num, column=2).value,
        'birth': ws.cell(row=num, column=3).value,
        # any non-empty value in column 4 marks the recording as an anomaly
        'anomaly': ws.cell(row=num, column=4).value is not None,
    }
    num += 1  # advance to the next row before deciding what to do with this one

    if m['edfname'] is None:
        # empty row: end of the table
        break
    if not m['anomaly']:
        metadata.append(m)

print('Size:', len(metadata))
print()
print('Loaded metadata (first three displayed):')
print(json.dumps(metadata[:3], indent=4))

Size: 1390

Loaded metadata (first three displayed):
[
    {
        "edfname": "00001809_261018",
        "dx1": "mci_rf",
        "birth": 400602,
        "anomaly": false
    },
    {
        "edfname": "00029426_020817",
        "dx1": "smi",
        "birth": 601204,
        "anomaly": false
    },
    {
        "edfname": "00047327_090718",
        "dx1": "vascular mci",
        "birth": 241019,
        "anomaly": false
    }
]


In [6]:
# Read the first recording; its signal headers become the reference that every
# other recording is compared against during curation.
m = metadata[0]
edf_file = os.path.join(original_path, m['edfname'] + '.edf')
signals, signal_headers, edf_header = pyedflib.highlevel.read_edf(edf_file)

refer_headers = signal_headers

# Number of distinct sample values and the overall value range: a quick check
# that the signal is discrete.
print(np.unique(signals).shape, signals.max() - signals.min())

# Show the metadata entry, the EDF header and the signal headers, each
# preceded by a blank line.
for shown in (m, edf_header, signal_headers):
    print()
    pprint.pp(shown)

(2820,) 4843.0

{'edfname': '00001809_261018',
 'dx1': 'mci_rf',
 'birth': 400602,
 'anomaly': False}

{'technician': '',
 'recording_additional': '',
 'patientname': '',
 'patient_additional': '',
 'patientcode': '',
 'equipment': '',
 'admincode': '',
 'gender': '',
 'startdate': datetime.datetime(2018, 10, 26, 15, 46, 26),
 'birthdate': '',
 'annotations': []}

[{'label': 'Fp1-AVG',
  'dimension': 'uV',
  'sample_rate': 200.0,
  'physical_max': 32767.0,
  'physical_min': -32768.0,
  'digital_max': 32767,
  'digital_min': -32768,
  'prefilter': '',
  'transducer': 'E'},
 {'label': 'F3-AVG',
  'dimension': 'uV',
  'sample_rate': 200.0,
  'physical_max': 32767.0,
  'physical_min': -32768.0,
  'digital_max': 32767,
  'digital_min': -32768,
  'prefilter': '',
  'transducer': 'E'},
 {'label': 'C3-AVG',
  'dimension': 'uV',
  'sample_rate': 200.0,
  'physical_max': 32767.0,
  'physical_min': -32768.0,
  'digital_max': 32767,
  'digital_min': -32768,
  'prefilter': '',
  'transducer': 'E'},

---

## Arrange and Save the Data

1. Exclude EDF files with an irregular signal header
2. Extract each EDF signal and save the data after reformatting
3. Save all metadata as JSON and XLSX, where JSON is for further processing and XLSX is for human inspection
    - `metadata_debug`: Full inclusion metadata for debugging
    - `metadata_public`: Metadata without personal information

In [7]:
start_latency_length = 10 * 200  # 10 seconds

# Silence library warnings while exporting. The `finally` clause at the bottom
# switches them back on even when the export stops with an error.
warnings.filterwarnings(action='ignore')

try:
    # Ask before wiping the output folder.
    text = f'Delete ALL files in {curated_path}?'
    if hasattr(ctypes, 'windll'):
        # Windows: Yes/No message box (type 4); a return value of 6 means "Yes"
        delete_confirmed = ctypes.windll.user32.MessageBoxExW(0, text, 'Question', 4) == 6
    else:
        # ctypes.windll only exists on Windows; ask on the console elsewhere
        delete_confirmed = input(text + ' [y/N] ').strip().lower() in ('y', 'yes')

    if delete_confirmed:
        for f in glob.glob(os.path.join(curated_path, '*/*')):
            os.remove(f)
        for f in glob.glob(os.path.join(curated_path, '*.*')):
            os.remove(f)

    # The writers below do not create missing folders, so make sure the two
    # output sub-folders exist (needed on a fresh run).
    for sub_dir in ('signal', 'event'):
        os.makedirs(os.path.join(curated_path, sub_dir), exist_ok=True)

    metadata_debug = []
    metadata_public = []

    for m in tqdm(metadata):
        # EDF file check
        edf_file = os.path.join(original_path, m['edfname'] + '.edf')
        signals, signal_headers, edf_header = pyedflib.highlevel.read_edf(edf_file)

        # skip recordings whose signal headers differ from the reference recording
        if refer_headers != signal_headers:
            print('- Signal header differs from the majority:', m['edfname'])
            continue

        # calculate age
        age = calculate_age(birth_to_datetime(m['birth']),
                            edf_header['startdate'])

        if age is None:
            print('- The age information is unknown:', m['edfname'])
            continue

        # EDF recording events: read from the first sheet of the per-recording workbook
        event_file = os.path.join(original_path, m['edfname'] + '.xlsx')
        wb = load_workbook(event_file, data_only=True)
        ws = wb[wb.sheetnames[0]]

        num = 2
        event = []

        # rows are read from row 2 until the time column (column 3) is empty
        while True:
            t = ws.cell(row=num, column=3).value
            e = ws.cell(row=num, column=4).value

            if t is None:
                break

            # The sheet stores a time of day only; prefix the recording date.
            # NOTE(review): an event logged after midnight would get the wrong
            # date here and therefore a negative offset - confirm this cannot occur.
            t = edf_header['startdate'].strftime('%Y%m%d') + t
            t = datetime.datetime.strptime(t, '%Y%m%d %H:%M:%S.%f')

            # the first event row is the time origin for all offsets
            if num == 2:
                startTime = t

            # offset from the first event, converted to samples (200 per second)
            t = int(np.floor((t - startTime).total_seconds() * 200))
            event.append((t, e))
            num += 1

        # metadata_debug
        m2 = {}
        m2['serial'] = f'{len(metadata_debug) + 1:05}'
        m2['edfname'] = m['edfname']
        m2['birth'] = birth_to_datetime(m['birth'])
        m2['record'] = edf_header['startdate']
        m2['age'] = age
        m2['dx1'] = m['dx1']
        m2['label'] = MultiLabel.load_from_string(m['dx1'])
        m2['event'] = event
        metadata_debug.append(m2)

        # metadata_public
        m3 = {}
        m3['serial'] = m2['serial']
        m3['age'] = age
        m3['label'] = m2['label']
        metadata_public.append(m3)

        # EDF signal
        # NOTE(review): the first `start_latency_length` samples are dropped
        # here, but the event offsets computed above are not shifted
        # accordingly - confirm that downstream code accounts for this.
        signals = trim_tailing_zeros(signals)        # trim garbage zeros
        signals = signals[:, start_latency_length:]  # throw away some signal right after starting recording
        signals = signals.astype('int32')
        df = pd.DataFrame(data=signals.T, columns=[s_h['label'] for s_h in signal_headers], dtype=np.int32)

        # write the same signal once per candidate format
        np.save(os.path.join(curated_path, 'signal', m2['serial']), signals)
        df.to_feather(os.path.join(curated_path, 'signal', m2['serial'] + '.feather'))
        df.to_parquet(os.path.join(curated_path, 'signal', m2['serial']+ '.parquet'))
        dt.Frame(df).to_jay(os.path.join(curated_path, 'signal', m2['serial'] + '.jay'))
        # every recording goes into one shared HDF5 file, keyed by its serial;
        # `key` is passed by keyword because newer pandas deprecates the positional form
        df.to_hdf(os.path.join(curated_path, 'signal', 'signal.h5'), key=m2['serial'])

        # event
        df = pd.DataFrame(data=event, columns=['timing', 'event'])
        with open(os.path.join(curated_path, 'event', m2['serial']) + '.json', 'w') as json_file:
            json.dump(event, json_file, indent=4, default=serialize_json)
        df.to_feather(os.path.join(curated_path, 'event', m2['serial']) + '.feather')

    print('Done.')
    print()
    print(f'Among {len(metadata)}, {len(metadata_public)} data were saved.')
finally:
    warnings.filterwarnings(action='default')

  0%|          | 0/1390 [00:00<?, ?it/s]

- The age information is unknown: 00602793_300518
- The age information is unknown: 00850537_061014
Done.

Among 1390, 1388 data were saved.


In [8]:
# Write the de-identified metadata list to metadata_public.json in the curated
# folder; values json cannot encode natively are converted by serialize_json.
path = os.path.join(curated_path, 'metadata_public.json')
with open(path, 'w') as json_file:
    json_file.write(json.dumps(metadata_public, indent=4, default=serialize_json))

In [9]:
# Read the public metadata back from disk.
# NOTE: this rebinds `metadata`, which previously held the raw rows of the
# metadata workbook; re-run the workbook cell before re-running the curation cell.
meta_path = os.path.join(curated_path, 'metadata_public.json')
with open(meta_path, 'r') as json_file:
    metadata = json.loads(json_file.read())

# show the first entry as a sanity check
pprint.pprint(metadata[0])

{'age': 78,
 'label': ['mci', 'mci_amnestic', 'mci_amnestic_rf'],
 'serial': '00001'}


-----

## Data Filtering by Diagnosis

#### Non-Vascular Dementia, Non-Vascular MCI, Normal

In [10]:
# Class definitions in class-label order (the list index is the class label).
# 'include' lists tags of which at least one must be present; 'exclude' lists
# tags of which none may be present.
diagnosis_filter = [
    dict(type='Normal', include=['normal'], exclude=[]),
    dict(type='Non-vascular MCI', include=['mci'], exclude=['mci_vascular']),
    dict(type='Non-vascular dementia', include=['dementia'], exclude=['vd']),
]


def generate_class_label(label):
    """Map a recording's diagnosis tags to a class.

    The rules in ``diagnosis_filter`` are tried in order and the first match
    wins. A rule matches when ``label`` shares at least one tag with the
    rule's ``include`` list and no tag with its ``exclude`` list.

    Args:
        label (iterable of str): Diagnosis tags of one recording.

    Returns:
        tuple: ``(class_index, class_name)`` of the first matching rule, or
        ``(-1, 'The others')`` when no rule matches.
    """
    tags = set(label)
    for class_index, rule in enumerate(diagnosis_filter):
        has_included_tag = not tags.isdisjoint(rule['include'])
        has_excluded_tag = not tags.isdisjoint(rule['exclude'])
        if has_included_tag and not has_excluded_tag:
            return (class_index, rule['type'])
    return (-1, 'The others')


# class label (list index) -> human-readable class name
class_label_to_type = [rule['type'] for rule in diagnosis_filter]
print('class_label_to_type:', class_label_to_type)

class_label_to_type: ['Normal', 'Non-vascular MCI', 'Non-vascular dementia']


In [11]:
# One bucket per diagnosis class, indexed by class label.
splitted_metadata = [[] for _ in diagnosis_filter]

# Tag every classifiable recording with its class and put it in its bucket;
# recordings that match no class (negative label) are left out.
# Note that the metadata entries themselves are annotated in place.
for m in metadata:
    c, n = generate_class_label(m['label'])
    if c < 0:
        continue
    m.update(class_type=n, class_label=c)
    splitted_metadata[c].append(m)

# Report the size of each bucket.
for i, split in enumerate(splitted_metadata):
    if split:
        print(f'- There are {len(split):} data belonging to {split[0]["class_type"]}')
    else:
        print(f'(Warning) Split group {i} has no data.')

- There are 458 data belonging to Normal
- There are 350 data belonging to Non-vascular MCI
- There are 233 data belonging to Non-vascular dementia


-----

## Configure the Train, Validation, and Test Splits

#### Split the filtered dataset and shuffle each split

In [12]:
# Train : Val : Test = 8 : 1 : 1
ratio1 = 0.8
ratio2 = 0.1

# A dedicated, seeded generator makes the split reproducible across kernel
# restarts without touching the global `random` state used elsewhere.
split_seed = 0
split_rng = random.Random(split_seed)


def count_class_labels(items, num_classes):
    """Count how many entries of ``items`` carry each class label.

    Args:
        items (list of dict): Metadata entries with a ``class_label`` key.
        num_classes (int): Number of classes, i.e. the length of the result.

    Returns:
        numpy.ndarray: int32 array whose k-th element is the number of
        entries with ``class_label == k``.
    """
    counts = np.zeros(num_classes, dtype=np.int32)
    for item in items:
        counts[item['class_label']] += 1
    return counts


metadata_train = []
metadata_val = []
metadata_test = []

# Split every class separately so that each subset keeps the class proportions.
for split in splitted_metadata:
    # shuffle a copy: the per-class lists stay untouched, so re-running this
    # cell always starts from the same input
    shuffled = list(split)
    split_rng.shuffle(shuffled)

    n1 = round(len(shuffled) * ratio1)
    n2 = n1 + round(len(shuffled) * ratio2)

    metadata_train.extend(shuffled[:n1])
    metadata_val.extend(shuffled[n1:n2])
    metadata_test.extend(shuffled[n2:])  # whatever remains after train and validation

# mix the classes within each subset
split_rng.shuffle(metadata_train)
split_rng.shuffle(metadata_val)
split_rng.shuffle(metadata_test)

print('Train data size\t\t:', len(metadata_train))
print('Validation data size\t:', len(metadata_val))
print('Test data size\t\t:', len(metadata_test))

print('\n', '--- Recheck ---', '\n')
train_class_nums = count_class_labels(metadata_train, len(class_label_to_type))
val_class_nums = count_class_labels(metadata_val, len(class_label_to_type))
test_class_nums = count_class_labels(metadata_test, len(class_label_to_type))

print('Train data label distribution\t:', train_class_nums, train_class_nums.sum())
print('Val data label distribution\t:', val_class_nums, val_class_nums.sum())
print('Test data label distribution\t:', test_class_nums, test_class_nums.sum())

Train data size		: 832
Validation data size	: 104
Test data size		: 105

 --- Recheck --- 

Train data label distribution	: [366 280 186] 832
Val data label distribution	: [46 35 23] 104
Test data label distribution	: [46 35 24] 105


-----

## Check Signal Loading Time by Data Format

In [22]:
class EegDataset_np(Dataset):
    """PyTorch dataset that reads each EEG signal from a NumPy ``.npy`` file.

    The signal of a recording is loaded from
    ``<root_dir>/signal/<serial>.npy`` and converted to ``float32``.

    Args:
        root_dir (str): Folder that contains the ``signal`` sub-folder.
        metadata (list of dict): One entry per recording; each entry must
            provide the ``serial``, ``age`` and ``class_label`` keys.
        transform (callable, optional): Applied to the sample dictionary
            before it is returned.
    """

    def __init__(self, root_dir, metadata, transform=None):
        self.transform = transform
        self.metadata = metadata
        self.root_dir = root_dir

    def __len__(self):
        """Return the number of recordings."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Load the recording at position ``idx`` as a sample dictionary."""
        # a tensor index is converted to its plain Python value first
        index = idx.tolist() if torch.is_tensor(idx) else idx

        record = self.metadata[index]
        signal_path = os.path.join(self.root_dir, 'signal', record['serial'] + '.npy')
        sample = dict(
            signal=np.load(signal_path).astype('float32'),
            age=record['age'],
            class_label=record['class_label'],
            event=[],
            metadata=record,
        )
        return self.transform(sample) if self.transform else sample

In [23]:
class EegDataset_feather(Dataset):
    """PyTorch dataset that reads each EEG signal from a Feather file.

    The signal of a recording is loaded from
    ``<root_dir>/signal/<serial>.feather`` and the stored table is
    transposed after conversion to a NumPy array.

    Args:
        root_dir (str): Folder that contains the ``signal`` sub-folder.
        metadata (list of dict): One entry per recording; each entry must
            provide the ``serial``, ``age`` and ``class_label`` keys.
        transform (callable, optional): Applied to the sample dictionary
            before it is returned.
    """

    def __init__(self, root_dir, metadata, transform=None):
        self.transform = transform
        self.metadata = metadata
        self.root_dir = root_dir

    def __len__(self):
        """Return the number of recordings."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Load the recording at position ``idx`` as a sample dictionary."""
        # a tensor index is converted to its plain Python value first
        index = idx.tolist() if torch.is_tensor(idx) else idx

        record = self.metadata[index]
        signal_path = os.path.join(self.root_dir, 'signal', record['serial'] + '.feather')
        sample = dict(
            signal=pd.read_feather(signal_path).to_numpy().T,
            age=record['age'],
            class_label=record['class_label'],
            event=[],
            metadata=record,
        )
        return self.transform(sample) if self.transform else sample

In [24]:
class EegDataset_parquet(Dataset):
    """PyTorch dataset that reads each EEG signal from a Parquet file.

    The signal of a recording is loaded from
    ``<root_dir>/signal/<serial>.parquet`` and the stored table is
    transposed after conversion to a NumPy array.

    Args:
        root_dir (str): Folder that contains the ``signal`` sub-folder.
        metadata (list of dict): One entry per recording; each entry must
            provide the ``serial``, ``age`` and ``class_label`` keys.
        transform (callable, optional): Applied to the sample dictionary
            before it is returned.
    """

    def __init__(self, root_dir, metadata, transform=None):
        self.transform = transform
        self.metadata = metadata
        self.root_dir = root_dir

    def __len__(self):
        """Return the number of recordings."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Load the recording at position ``idx`` as a sample dictionary."""
        # a tensor index is converted to its plain Python value first
        index = idx.tolist() if torch.is_tensor(idx) else idx

        record = self.metadata[index]
        signal_path = os.path.join(self.root_dir, 'signal', record['serial'] + '.parquet')
        sample = dict(
            signal=pd.read_parquet(signal_path).to_numpy().T,
            age=record['age'],
            class_label=record['class_label'],
            event=[],
            metadata=record,
        )
        return self.transform(sample) if self.transform else sample

In [25]:
class EegDataset_jay(Dataset):
    """PyTorch dataset that reads each EEG signal from a Jay (datatable) file.

    The signal of a recording is loaded from
    ``<root_dir>/signal/<serial>.jay`` and the stored table is transposed
    after conversion to a NumPy array.

    Args:
        root_dir (str): Folder that contains the ``signal`` sub-folder.
        metadata (list of dict): One entry per recording; each entry must
            provide the ``serial``, ``age`` and ``class_label`` keys.
        transform (callable, optional): Applied to the sample dictionary
            before it is returned.
    """

    def __init__(self, root_dir, metadata, transform=None):
        self.transform = transform
        self.metadata = metadata
        self.root_dir = root_dir

    def __len__(self):
        """Return the number of recordings."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Load the recording at position ``idx`` as a sample dictionary."""
        # a tensor index is converted to its plain Python value first
        index = idx.tolist() if torch.is_tensor(idx) else idx

        record = self.metadata[index]
        signal_path = os.path.join(self.root_dir, 'signal', record['serial'] + '.jay')
        sample = dict(
            signal=dt.fread(signal_path).to_numpy().T,
            age=record['age'],
            class_label=record['class_label'],
            event=[],
            metadata=record,
        )
        return self.transform(sample) if self.transform else sample

In [26]:
class EegDataset_h5(Dataset):
    """PyTorch dataset that reads each EEG signal from one shared HDF5 file.

    All recordings live in ``<root_dir>/signal/signal.h5``; the signal of a
    recording is stored under its ``serial`` as the key. The stored table is
    transposed after conversion to a NumPy array.

    Args:
        root_dir (str): Folder that contains the ``signal`` sub-folder.
        metadata (list of dict): One entry per recording; each entry must
            provide the ``serial``, ``age`` and ``class_label`` keys.
        transform (callable, optional): Applied to the sample dictionary
            before it is returned.
    """

    def __init__(self, root_dir, metadata, transform=None):
        self.transform = transform
        self.metadata = metadata
        self.root_dir = root_dir

    def __len__(self):
        """Return the number of recordings."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Load the recording at position ``idx`` as a sample dictionary."""
        # a tensor index is converted to its plain Python value first
        index = idx.tolist() if torch.is_tensor(idx) else idx

        record = self.metadata[index]
        store_path = os.path.join(self.root_dir, 'signal', 'signal.h5')
        sample = dict(
            signal=pd.read_hdf(store_path, record['serial']).to_numpy().T,
            age=record['age'],
            class_label=record['class_label'],
            event=[],
            metadata=record,
        )
        return self.transform(sample) if self.transform else sample

In [27]:
# Use the GPU when CUDA is usable, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')

print('Current PyTorch device:', device)

# Loading always runs in the main process (a number other than 0 causes an
# error on CUDA); pinned memory is requested only when the target is CUDA.
num_workers = 0
pin_memory = device.type == 'cuda'

# Preprocessing applied to every signal sample, in this order.
composed = transforms.Compose([
    EegRandomCrop(crop_length=20 * 200, multiple=2),  # 20 seconds
    EegDropPhoticChannel(),
    EegNormalizePerSignal(),
    EegToTensor(),
])

Current PyTorch device: cuda


### Test NumPy Pickle

In [38]:
%%time

# Benchmark: time one full pass over the training split with the .npy reader.
train_dataset = EegDataset_np(curated_path, metadata_train, composed)
train_loader = DataLoader(train_dataset, 
                          batch_size=32, # Random crop will inflate the minibatch size
                          shuffle=False, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

# The results of .to(device) are discarded: the loop exists only so that
# loading, collating and the device transfer are included in the timing.
for i_batch, sample_batched in enumerate(train_loader):
    sample_batched['signal'].to(device)
    sample_batched['age'].to(device)
    sample_batched['class_label'].to(device)

CPU times: total: 3min 2s
Wall time: 34.1 s


In [39]:
# Fetch the first sample once. Every `train_dataset[0]` access re-reads the
# file and re-runs the random-crop pipeline, so indexing twice would print the
# values of one crop and the dtype of a different one.
first_sample = train_dataset[0]
pprint.pprint(first_sample['signal'])
pprint.pprint(first_sample['signal'][0].dtype)

[tensor([[-3.3522, -3.2946, -3.3152,  ..., -0.1242, -0.1325, -0.1119],
        [-4.3566, -4.3079, -4.3404,  ..., -0.1373, -0.1129, -0.1048],
        [-1.0906, -1.0153, -1.0906,  ...,  0.0758,  0.0758,  0.0758],
        ...,
        [-0.9271, -0.9396, -0.9647,  ...,  0.0127,  0.0127,  0.0127],
        [-0.0912, -0.0912, -0.1094,  ...,  0.2360,  0.1996,  0.1815],
        [ 0.3789,  0.2249,  0.1223,  ...,  0.1325, -0.0626, -0.1858]]),
 tensor([[ 0.4156,  0.4156,  0.4853,  ..., -0.5610, -0.2820, -0.2123],
        [ 0.4940,  0.4940,  0.5827,  ..., -0.2156,  0.0505,  0.2279],
        [-0.9845, -0.8280, -0.6714,  ..., -1.1411, -0.9845, -0.6714],
        ...,
        [ 0.1537,  0.2702,  0.5030,  ...,  1.2016,  1.3180,  1.4345],
        [-1.1479, -1.2525, -1.2525,  ...,  0.2113,  0.4204,  0.5249],
        [-0.1101, -0.1766, -0.0686,  ...,  1.3192,  1.0616,  1.1197]])]
torch.float32


### Test Feather

In [40]:
%%time

train_dataset = EegDataset_feather(curated_path, metadata_train, composed)
train_loader = DataLoader(train_dataset, 
                          batch_size=32, # Random crop will inflate the minibatch size
                          shuffle=False, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for i_batch, sample_batched in enumerate(train_loader):
    sample_batched['signal'].to(device)
    sample_batched['age'].to(device)
    sample_batched['class_label'].to(device)

CPU times: total: 1min 39s
Wall time: 16.4 s


In [41]:
# Fetch the first sample once. Every `train_dataset[0]` access re-reads the
# file and re-runs the random-crop pipeline, so indexing twice would print the
# values of one crop and the dtype of a different one.
first_sample = train_dataset[0]
pprint.pprint(first_sample['signal'])
pprint.pprint(first_sample['signal'][0].dtype)

[tensor([[-0.6269, -0.6983, -0.8173,  ...,  1.0148,  0.9910,  0.9672],
        [-0.2354, -0.2626, -0.3442,  ..., -0.5074, -0.5618, -0.5618],
        [-0.2798, -0.2453, -0.1762,  ...,  0.8943,  0.9288,  0.8598],
        ...,
        [ 0.0498,  0.0498, -0.0515,  ...,  1.5016,  1.5016,  1.4678],
        [ 0.1102,  0.1486,  0.1486,  ...,  1.3389,  1.3389,  1.3773],
        [ 1.0724,  1.1464,  1.1570,  ...,  0.3751,  0.3434,  0.3223]]),
 tensor([[ 0.5909,  0.4657,  0.4657,  ...,  2.6199,  2.6951,  2.7452],
        [ 0.5798,  0.3845,  0.5147,  ...,  0.8402,  0.9053,  1.1005],
        [ 0.7264,  0.8804,  0.5724,  ..., -0.8139, -0.3518,  0.1103],
        ...,
        [ 0.1455,  0.1455,  0.0241,  ..., -1.0680, -0.9466, -0.9466],
        [-0.8366, -0.8366, -0.8366,  ..., -0.9195, -0.6707, -0.5878],
        [ 1.4154,  1.0992,  0.7115,  ..., -1.0431, -0.9512, -0.7982]])]
torch.float32


### Test Parquet

In [42]:
%%time

train_dataset = EegDataset_parquet(curated_path, metadata_train, composed)
train_loader = DataLoader(train_dataset, 
                          batch_size=32, # Random crop will inflate the minibatch size
                          shuffle=False, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for i_batch, sample_batched in enumerate(train_loader):
    sample_batched['signal'].to(device)
    sample_batched['age'].to(device)
    sample_batched['class_label'].to(device)

CPU times: total: 2min 47s
Wall time: 27.9 s


In [43]:
# Fetch the first sample once. Every `train_dataset[0]` access re-reads the
# file and re-runs the random-crop pipeline, so indexing twice would print the
# values of one crop and the dtype of a different one.
first_sample = train_dataset[0]
pprint.pprint(first_sample['signal'])
pprint.pprint(first_sample['signal'][0].dtype)

[tensor([[-0.5593, -0.5781, -0.7854,  ...,  1.8718,  1.7587,  1.9283],
        [-0.8838, -0.5368, -1.0573,  ..., -0.7971, -1.1441, -1.1441],
        [-0.4284, -1.0146,  2.2681,  ...,  2.7371,  2.6198,  2.5026],
        ...,
        [ 2.0997,  1.8857,  2.2067,  ...,  1.1366,  1.0296,  1.0296],
        [ 0.6551,  0.4951,  0.6551,  ...,  0.3351,  0.3351,  0.6551],
        [-0.7340, -0.7025, -0.7025,  ..., -0.2418, -0.2732, -0.1685]]),
 tensor([[ 0.6115,  0.6545,  0.7405,  ...,  0.9124,  1.0413,  1.2132],
        [ 0.5572,  0.6246,  0.7596,  ...,  1.9065,  2.0414,  2.1763],
        [ 0.9045,  0.7507,  0.7507,  ...,  1.8275,  1.9813,  1.9813],
        ...,
        [ 0.6922,  0.6922,  0.5610,  ..., -0.2265, -0.2265, -0.4890],
        [ 0.3906,  0.4735,  0.3076,  ...,  0.3906,  0.3906,  0.3906],
        [ 0.9183,  0.9948,  0.8200,  ...,  1.0604,  1.1259,  1.1587]])]
torch.float32


### Test Jay

In [44]:
%%time

train_dataset = EegDataset_jay(curated_path, metadata_train, composed)
train_loader = DataLoader(train_dataset, 
                          batch_size=32, # Random crop will inflate the minibatch size
                          shuffle=False, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

for i_batch, sample_batched in enumerate(train_loader):
    sample_batched['signal'].to(device)
    sample_batched['age'].to(device)
    sample_batched['class_label'].to(device)

CPU times: total: 6min 14s
Wall time: 47.5 s


In [45]:
# Fetch the first sample once. Every `train_dataset[0]` access re-reads the
# file and re-runs the random-crop pipeline, so indexing twice would print the
# values of one crop and the dtype of a different one.
first_sample = train_dataset[0]
pprint.pprint(first_sample['signal'])
pprint.pprint(first_sample['signal'][0].dtype)

[tensor([[-3.6729, -3.6630, -3.6530,  ..., -0.0396, -0.0446, -0.0146],
        [-3.1607, -3.3226, -3.3449,  ...,  0.0215,  0.0104,  0.0160],
        [ 4.5313,  4.7697,  5.8935,  ...,  0.0362, -0.2022, -0.2703],
        ...,
        [ 3.7753,  3.9661,  4.0933,  ..., -0.1369, -0.1369, -0.1687],
        [ 4.0867,  4.1835,  4.3447,  ..., -0.2027, -0.2995, -0.3317],
        [ 0.3593,  0.4749,  0.3908,  ...,  2.4919,  2.8175,  2.7440]]),
 tensor([[ 0.3633,  0.3633,  0.4955,  ...,  0.0990, -0.2975, -0.0772],
        [-0.3637, -0.2852,  0.0287,  ..., -0.4422, -0.1283,  0.0287],
        [ 0.6135,  0.4735,  0.3335,  ..., -2.1864, -1.9064, -1.3464],
        ...,
        [ 0.4991,  0.4991,  0.4991,  ..., -1.9070, -1.6063, -1.6063],
        [ 1.2605,  1.1668,  0.7920,  ...,  0.6983,  0.9794,  1.2605],
        [ 1.7590,  1.6566,  1.6873,  ..., -0.6785, -0.7604, -0.7707]])]
torch.float32


### Test HDF5

In [46]:
%%time

# Benchmark: time one full pass over the training split with the HDF5 reader.
# Warnings are silenced for the duration of the pass and switched back to the
# 'default' action afterwards.
warnings.filterwarnings(action='ignore')
train_dataset = EegDataset_h5(curated_path, metadata_train, composed)
train_loader = DataLoader(train_dataset, 
                          batch_size=32, # Random crop will inflate the minibatch size
                          shuffle=False, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

# The results of .to(device) are discarded: the loop exists only so that
# loading, collating and the device transfer are included in the timing.
for i_batch, sample_batched in enumerate(train_loader):
    sample_batched['signal'].to(device)
    sample_batched['age'].to(device)
    sample_batched['class_label'].to(device)
    
warnings.filterwarnings(action='default')

CPU times: total: 4min 43s
Wall time: 50.8 s


In [47]:
# Fetch the first sample once. Every `train_dataset[0]` access re-reads the
# file and re-runs the random-crop pipeline, so indexing twice would print the
# values of one crop and the dtype of a different one.
first_sample = train_dataset[0]
pprint.pprint(first_sample['signal'])
pprint.pprint(first_sample['signal'][0].dtype)

[tensor([[-2.4124, -2.4764, -2.5084,  ...,  0.4368,  0.3087,  0.1807],
        [-1.3673, -1.3673, -1.3673,  ..., -0.6582, -0.8608, -0.8608],
        [-0.1513,  0.0243,  0.1121,  ..., -0.8538, -0.6782, -0.7660],
        ...,
        [ 0.5907,  0.1172, -0.5141,  ...,  1.5378,  1.3799,  0.7486],
        [ 0.7186,  0.9050,  0.7186,  ...,  0.7186,  0.9050,  0.5321],
        [ 0.0874, -0.1146, -0.2528,  ...,  0.7463,  0.7782,  0.6932]]),
 tensor([[-1.7081, -1.6583, -1.6583,  ...,  0.4017,  0.3685,  0.3519],
        [-0.0171,  0.1241, -0.6760,  ..., -0.0642, -0.1113, -0.2054],
        [ 0.1934,  0.4106, -0.0962,  ..., -0.2411, -0.0962, -0.0962],
        ...,
        [-0.2234, -0.1130, -0.2234,  ...,  0.5497,  0.2184, -0.1130],
        [ 0.1776,  0.1904,  0.1776,  ...,  0.0237,  0.0109, -0.0148],
        [ 4.9668,  5.3080,  4.8256,  ...,  0.0544,  0.1095,  0.1187]])]
torch.float32


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  (oid, self.atom, self.shape, self._v_chunkshape) = self._open_array()
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  (oid, self.atom, self.shape, self._v_chunkshape) = self._open_array()


-----

## Check Event Loading Time by Data Format

In [51]:
class EegDataset_event_json(Dataset):
    """PyTorch dataset that reads the event list of each recording from JSON.

    Events are loaded from ``<root_dir>/event/<serial>.json``. No signal file
    is read: the ``signal`` entry is a ``(20, 100)`` array of zeros that acts
    as a placeholder.

    Args:
        root_dir (str): Folder that contains the ``event`` sub-folder.
        metadata (list of dict): One entry per recording; each entry must
            provide the ``serial``, ``age`` and ``class_label`` keys.
        transform (callable, optional): Applied to the sample dictionary
            before it is returned.
    """

    def __init__(self, root_dir, metadata, transform=None):
        self.transform = transform
        self.metadata = metadata
        self.root_dir = root_dir

    def __len__(self):
        """Return the number of recordings."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Load the events of the recording at position ``idx``."""
        # a tensor index is converted to its plain Python value first
        index = idx.tolist() if torch.is_tensor(idx) else idx

        record = self.metadata[index]
        event_path = os.path.join(self.root_dir, 'event', record['serial'] + '.json')
        with open(event_path, 'r') as json_file:
            events = json.load(json_file)

        sample = dict(
            signal=np.zeros((20, 100)),
            event=events,
            age=record['age'],
            class_label=record['class_label'],
            metadata=record,
        )
        return self.transform(sample) if self.transform else sample

In [58]:
class EegDataset_event_feather(Dataset):
    """PyTorch dataset that reads the event table of each recording from Feather.

    Events are loaded from ``<root_dir>/event/<serial>.feather`` as a pandas
    DataFrame. No signal file is read: the ``signal`` entry is a
    ``(20, 100)`` array of zeros that acts as a placeholder.

    Args:
        root_dir (str): Folder that contains the ``event`` sub-folder.
        metadata (list of dict): One entry per recording; each entry must
            provide the ``serial``, ``age`` and ``class_label`` keys.
        transform (callable, optional): Applied to the sample dictionary
            before it is returned.
    """

    def __init__(self, root_dir, metadata, transform=None):
        self.transform = transform
        self.metadata = metadata
        self.root_dir = root_dir

    def __len__(self):
        """Return the number of recordings."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Load the events of the recording at position ``idx``."""
        # a tensor index is converted to its plain Python value first
        index = idx.tolist() if torch.is_tensor(idx) else idx

        record = self.metadata[index]
        event_path = os.path.join(self.root_dir, 'event', record['serial'] + '.feather')
        events = pd.read_feather(event_path)

        sample = dict(
            signal=np.zeros((20, 100)),
            event=events,
            age=record['age'],
            class_label=record['class_label'],
            metadata=record,
        )
        return self.transform(sample) if self.transform else sample

In [59]:
# The event benchmarks only convert samples to tensors.
# NOTE: this rebinds `composed`, replacing the signal pipeline defined
# earlier; re-run that configuration cell before repeating a signal benchmark.
composed = transforms.Compose([EegToTensor()])

### Test JSON

In [64]:
%%time

# Benchmark: time ten full passes over the training split with the JSON
# event reader.
train_dataset = EegDataset_event_json(curated_path, metadata_train, composed)
train_loader = DataLoader(train_dataset, 
                          batch_size=32, # NOTE(review): the signal benchmarks say random crop inflates the minibatch; `composed` here is only EegToTensor - confirm
                          shuffle=False, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

# The results of .to(device) are discarded: the loops exist only so that
# loading, collating and the device transfer are included in the timing.
for k in range(10):
    for i_batch, sample_batched in enumerate(train_loader):
        sample_batched['signal'].to(device)
        sample_batched['age'].to(device)
        sample_batched['class_label'].to(device)

CPU times: total: 25.1 s
Wall time: 4.33 s


In [65]:
# Fetch the first sample once instead of indexing the dataset twice: every
# `train_dataset[0]` access re-reads the event file and re-runs the transform.
first_sample = train_dataset[0]
pprint.pprint(first_sample['signal'])
pprint.pprint(first_sample['event'])

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
[[0, 'Start Recording'],
 [0, 'New Montage - Montage 002'],
 [2674, 'Eyes Open'],
 [4396, 'Eyes Closed'],
 [15362, 'artifact'],
 [16912, 'Eyes Open'],
 [18844, 'Eyes Closed'],
 [28556, 'Eyes Open'],
 [30698, 'Eyes Closed'],
 [39182, 'Eyes Open'],
 [40778, 'Eyes Closed'],
 [49052, 'Photic On - 3.0 Hz'],
 [51068, 'Photic Off'],
 [53126, 'Photic On - 6.0 Hz'],
 [55141, 'Photic Off'],
 [57200, 'Photic On - 9.0 Hz'],
 [59216, 'Photic Off'],
 [61232, 'Photic On - 12.0 Hz'],
 [63248, 'Photic Off'],
 [63667, 'Eyes Open'],
 [65054, 'Eyes Closed'],
 [65305, 'Photic On - 15.0 Hz'],
 [67322, 'Photic Off'],
 [69380, 'Photic On - 18.0 Hz'],
 [71396, 'Photic Off'],
 [73412, 'Photic On - 21.0 Hz'],
 [75428, 'Photic Off'],
 [77486, 'Photic On - 24.0 Hz'

### Test Feather

In [66]:
%%time

# Benchmark: time ten full passes over the training split with the Feather
# event reader.
train_dataset = EegDataset_event_feather(curated_path, metadata_train, composed)
train_loader = DataLoader(train_dataset, 
                          batch_size=32, # NOTE(review): the signal benchmarks say random crop inflates the minibatch; `composed` here is only EegToTensor - confirm
                          shuffle=False, 
                          drop_last=True,
                          num_workers=num_workers, 
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

# The results of .to(device) are discarded: the loops exist only so that
# loading, collating and the device transfer are included in the timing.
for k in range(10):
    for i_batch, sample_batched in enumerate(train_loader):
        sample_batched['signal'].to(device)
        sample_batched['age'].to(device)
        sample_batched['class_label'].to(device)

CPU times: total: 1min 6s
Wall time: 10.9 s


In [67]:
# Fetch the first sample once instead of indexing the dataset twice: every
# `train_dataset[0]` access re-reads the event file and re-runs the transform.
first_sample = train_dataset[0]
pprint.pprint(first_sample['signal'])
pprint.pprint(first_sample['event'])

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
    timing                      event
0        0            Start Recording
1        0  New Montage - Montage 002
2     2674                  Eyes Open
3     4396                Eyes Closed
4    15362                   artifact
5    16912                  Eyes Open
6    18844                Eyes Closed
7    28556                  Eyes Open
8    30698                Eyes Closed
9    39182                  Eyes Open
10   40778                Eyes Closed
11   49052         Photic On - 3.0 Hz
12   51068                 Photic Off
13   53126         Photic On - 6.0 Hz
14   55141                 Photic Off
15   57200         Photic On - 9.0 Hz
16   59216                 Photic Off
17   61232        Photic On - 12.0 Hz
18   63248              