In [82]:
import numpy as np
import os
import pickle
from data.data_utils import *
from data.dataloader_ssl import load_dataset_ssl
from data.dataloader_detection import load_dataset_detection
from constants import *

In [None]:
# You have to modify the path below.

In [10]:
adj_mat_dir = './data/electrode_graph/adj_mx_3d.pkl'

with open(adj_mat_dir, 'rb') as pf:
    adj_mat = pickle.load(pf)
    adj_mat = adj_mat[-1]

In [41]:
_, datasets, scaler = load_dataset_ssl(
    input_dir='/data/TUSZ_reasmple',
    raw_data_dir='/data/TUSZ',
    train_batch_size=7,
    test_batch_size=7,
    time_step_size=1,
    input_len=12,
    output_len=12,
    standardize=True,
    num_workers=8,
    augmentation=False,
    adj_mat_dir='./data/electrode_graph/adj_mx_3d.pkl',
    graph_type='combined',
    top_k=3,
    filter_type='laplacian',
    use_fft=False,
    preproc_dir=None
)


In [88]:
_, detection_datasets, _ =load_dataset_detection(
    input_dir='/data/TUSZ_reasmple',
    raw_data_dir='/data/TUSZ',
    train_batch_size=7,
    test_batch_size=7,
    time_step_size=1,
    max_seq_len=12,
    standardize=True,
    num_workers=8,
    augmentation=False,
    adj_mat_dir='./data/electrode_graph/adj_mx_3d.pkl',
    graph_type='combined',
    top_k=3,
    filter_type='laplacian',
    use_fft=False,
    sampling_ratio=1,
    seed=10,
    preproc_dir=None,
)

number of seizure files:  13646
Number of clips in train: 27292
Number of clips in dev: 28057
Number of clips in test: 44959


In [42]:
datasets

{'train': <data.dataloader_ssl.SeizureDataset at 0x7fbd8fec0c10>,
 'dev': <data.dataloader_ssl.SeizureDataset at 0x7fbeb36b36d0>,
 'test': <data.dataloader_ssl.SeizureDataset at 0x7fbd8fec0760>}

In [89]:
detection_datasets

{'train': <data.dataloader_detection.SeizureDataset at 0x7fbeb2b75580>,
 'dev': <data.dataloader_detection.SeizureDataset at 0x7fbd4905deb0>,
 'test': <data.dataloader_detection.SeizureDataset at 0x7fbeb34786d0>}

In [43]:
train_set_ssl = datasets['train']
dev_set_ssl = datasets['dev']
test_set_ssl = datasets['test']

In [90]:
train_set_det = detection_datasets['train']
dev_set_det = detection_datasets['dev']
test_set_det = detection_datasets['test']

training set: 27292
validation set: 28057
testing set: 44959


In [93]:
train_set_det[0][0].size()

torch.Size([12, 19, 200])

# Statistical Information
## SSL dataset:
### Sample size:
1. training: 2700 * 12 * 19 * 200
2. validation: 300 * 12 * 19 * 200

## Train&Test&Val dataset:
### Sample size:
1. training: 2400 * 12 * 19 * 200
2. validation: 600 * 12 * 19 * 200
3. testing: 3900 * 12 * 19 * 200

# EEG database generation

In [99]:
train_set_det[0][0]

array([[[-35.745934  , -40.66693   , -35.91564   , ...,   8.490601  ,
           0.46872663,  -7.520901  ],
        [-46.802982  , -46.37525   , -41.00268   , ...,   7.1955004 ,
          -1.1872213 , -11.204208  ],
        [-51.609562  , -52.843307  , -48.513218  , ..., -21.573282  ,
         -30.002481  , -36.658726  ],
        ...,
        [-72.727234  , -72.62722   , -67.358475  , ..., -42.114433  ,
         -51.35033   , -59.054707  ],
        [-42.76837   , -41.847393  , -36.6317    , ..., -36.58768   ,
         -41.499382  , -46.518955  ],
        [ -1.1577948 ,  -0.8255753 ,   4.6585774 , ...,  -1.4473364 ,
          -4.146161  ,  -6.230305  ]],

       [[-11.692073  , -20.343641  , -19.421991  , ...,  33.943066  ,
          46.822598  ,  59.09726   ],
        [-16.464197  , -23.827961  , -24.651245  , ...,  29.848948  ,
          41.259216  ,  53.955677  ],
        [-40.128345  , -45.840836  , -44.30438   , ...,  20.112703  ,
          29.609285  ,  41.033188  ],
        ...,


In [100]:
ssl_train_idx = np.random.choice(len(train_set_ssl), 2700, replace=False)
ssl_val_idx = np.random.choice(len(dev_set_ssl), 300, replace=False)

train_idx = np.random.choice(len(train_set_det), 2400, replace=False)
val_idx = np.random.choice(len(dev_set_det), 600, replace=False)
test_idx = np.random.choice(len(test_set_det), 3900, replace=False)

In [123]:
import time
start_time=time.time()

ssl_train_data = [train_set_ssl[i][0].numpy() for i in ssl_train_idx]
ssl_val_data = [dev_set_ssl[i][0].numpy() for i in ssl_val_idx]

train_data = [train_set_det[i][0].numpy() for i in train_idx]
train_label = [train_set_det[i][1].numpy() for i in train_idx]

val_data = [dev_set_det[i][0].numpy() for i in val_idx]
val_label = [dev_set_det[i][1].numpy() for i in val_idx]

test_data = [test_set_det[i][0].numpy() for i in test_idx]
test_label = [test_set_det[i][1].numpy() for i in test_idx]

ssl done: 208.72148847579956
train done: 349.81030559539795
val done: 385.46712350845337
test done: 653.7382099628448


In [128]:
save_dir = '/data/EEG_database/'
np.savez_compressed(os.path.join(save_dir, 'ssl_train_data.npz'), a=np.array(ssl_train_data))
np.savez_compressed(os.path.join(save_dir, 'ssl_val_data.npz'), a=np.array(ssl_val_data))

np.savez_compressed(os.path.join(save_dir, 'train_data.npz'), a=np.array(train_data))
np.savez_compressed(os.path.join(save_dir, 'train_label.npz'), a=np.array(train_label))

np.savez_compressed(os.path.join(save_dir, 'val_data.npz'), a=np.array(val_data))
np.savez_compressed(os.path.join(save_dir, 'val_label.npz'), a=np.array(val_label))

np.savez_compressed(os.path.join(save_dir, 'test_data.npz'), a=np.array(test_data))
np.savez_compressed(os.path.join(save_dir, 'test_label.npz'), a=np.array(test_label))