# Dataset and DataLoader

This notebook loads the `CAUEEG` dataset, tries out some useful preprocessing steps, and builds the PyTorch `DataLoader` instances used for training.

-----

## Configurations

In [1]:
# Enable IPython's autoreload extension so that imported modules are reloaded
# before each cell runs (mode 2 = reload all modules).
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
# Move the working directory up one level (presumably to the project root, so
# that the relative data path and the `datasets` package resolve — TODO confirm).
# NOTE(review): `%cd ..` is relative to the current directory, so re-running
# this cell climbs one more level each time; run it once per kernel session.
%cd ..

C:\Users\Minjae\Desktop\EEG_Project


In [2]:
# Load some packages
import os
import glob
import json
import pprint

import numpy as np
import random
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

# custom package
from datasets.caueeg_dataset import *
from datasets.caueeg_script import *
from datasets.pipeline import *

In [3]:
# Report the installed PyTorch build and pick the compute device.
print('PyTorch version:', torch.__version__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# `device` already encodes the outcome of the availability check above.
if device.type == 'cuda':
    print('cuda is available.')
else:
    print('cuda is unavailable.')

PyTorch version: 1.11.0+cu113
cuda is available.


In [4]:
# Directory (relative to the current working directory) holding the curated
# dataset files that the cells below open via os.path.join(data_path, ...).
data_path = 'local/dataset/02_Curated_Data_220419/'

-----

## Benchmark 1

In [24]:
# Benchmark 1 split definition; later cells read its 'train_split',
# 'validation_split' and 'test_split' entries.
task1_json_path = os.path.join(data_path, 'task1.json')
with open(task1_json_path, 'r') as task_file:
    task1_dict = json.load(task_file)

# Annotation table; later cells read the 'edfname' field of its 'data' entries.
annotation_json_path = os.path.join(data_path, 'annotation_debug.json')
with open(annotation_json_path, 'r') as annotation_file:
    annotation_debug = json.load(annotation_file)

In [23]:
# Collect the 'serial' field of every record, one list per split of task 1.
train_serials, val_serials, test_serials = (
    [record['serial'] for record in task1_dict[split_key]]
    for split_key in ('train_split', 'validation_split', 'test_split')
)

In [37]:
# For each split, map every serial to the part of its 'edfname' that precedes
# the first underscore (presumably a per-subject identifier — TODO confirm).
# NOTE(review): the lookup is positional (index = int(serial) - 1), so it
# assumes annotation_debug['data'] is ordered by serial starting at 1 — confirm.
train_edfname, val_edfname, test_edfname = (
    [
        annotation_debug['data'][int(serial) - 1]['edfname'].split('_')[0]
        for serial in split_serials
    ]
    for split_serials in (train_serials, val_serials, test_serials)
)

In [38]:
# Per split: number of recordings, then number of distinct edfname prefixes.
for split_label, split_names in (
    ('Train:', train_edfname),
    ('Val:', val_edfname),
    ('Test:', test_edfname),
):
    print(split_label, len(split_names), len(set(split_names)))

Train: 1110 947
Val: 139 135
Test: 139 138


In [39]:
# Number of edfname prefixes shared by each pair of splits
# (a non-zero count means the same prefix appears in both splits).
for pair_label, first_names, second_names in (
    ('Train & Val:', train_edfname, val_edfname),
    ('Train & Test:', train_edfname, test_edfname),
    ('Val & Test:', val_edfname, test_edfname),
):
    print(pair_label, len(set(first_names).intersection(second_names)))

Train & Val: 27
Train & Test: 28
Val & Test: 3


-----

## Benchmark 2

In [40]:
# Benchmark 2 split definition; later cells read its 'train_split',
# 'validation_split' and 'test_split' entries.
task2_json_path = os.path.join(data_path, 'task2.json')
with open(task2_json_path, 'r') as task_file:
    task2_dict = json.load(task_file)

# Annotation table; later cells read the 'edfname' field of its 'data' entries.
annotation_json_path = os.path.join(data_path, 'annotation_debug.json')
with open(annotation_json_path, 'r') as annotation_file:
    annotation_debug = json.load(annotation_file)

In [41]:
# Collect the 'serial' field of every record, one list per split of task 2.
train_serials, val_serials, test_serials = (
    [record['serial'] for record in task2_dict[split_key]]
    for split_key in ('train_split', 'validation_split', 'test_split')
)

In [42]:
# For each split, map every serial to the part of its 'edfname' that precedes
# the first underscore (presumably a per-subject identifier — TODO confirm).
# NOTE(review): the lookup is positional (index = int(serial) - 1), so it
# assumes annotation_debug['data'] is ordered by serial starting at 1 — confirm.
train_edfname, val_edfname, test_edfname = (
    [
        annotation_debug['data'][int(serial) - 1]['edfname'].split('_')[0]
        for serial in split_serials
    ]
    for split_serials in (train_serials, val_serials, test_serials)
)

In [43]:
# Per split: number of recordings, then number of distinct edfname prefixes.
for split_label, split_names in (
    ('Train:', train_edfname),
    ('Val:', val_edfname),
    ('Test:', test_edfname),
):
    print(split_label, len(split_names), len(set(split_names)))

Train: 950 806
Val: 119 117
Test: 118 116


In [44]:
# Number of edfname prefixes shared by each pair of splits
# (a non-zero count means the same prefix appears in both splits).
for pair_label, first_names, second_names in (
    ('Train & Val:', train_edfname, val_edfname),
    ('Train & Test:', train_edfname, test_edfname),
    ('Val & Test:', val_edfname, test_edfname),
):
    print(pair_label, len(set(first_names).intersection(second_names)))

Train & Val: 35
Train & Test: 28
Val & Test: 3


In [46]:
# Display the edfname prefixes that occur in both the train and test splits.
# NOTE(review): these values may be subject identifiers — consider clearing
# this cell's output before sharing the notebook.
set(train_edfname).intersection(test_edfname)

{'00048377',
 '00287432',
 '00480292',
 '00635487',
 '00646912',
 '00666711',
 '00671212',
 '00671744',
 '00759679',
 '00805584',
 '00824216',
 '00883719',
 '00978569',
 '00988278',
 '01063007',
 '01080162',
 '01081922',
 '01132092',
 '01135534',
 '01135545',
 '01139924',
 '01159816',
 '01225385',
 '01235034',
 '01256391',
 '01261352',
 '01274934',
 '01344212'}