In [None]:


"""
Created on Mon May  2 17:36:56 2022

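Data functions for splitting the dataset into train/validation/test sets,
creating fixed-length windows, and wrapping the resulting datasets in PyTorch
DataLoaders.

Each function is described in its own header section.

The functions defined in this file are:
    train_valid_test_split, get_parameters_for_model, create_windows,
    transform_target, create_dataloaders, data_transform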

@author: Kitti
"""

In [None]:


import numpy as np
# import pandas as pd
import matplotlib.pyplot as plt
import mne
from braindecode.datasets.tuh import TUHAbnormal, TUH
from braindecode.models import get_output_shape
from braindecode.preprocessing import create_fixed_length_windows
from torch.utils.data import DataLoader
from multiprocessing import Pool, cpu_count

In [None]:



def train_valid_test_split(tuh_preproc, train_size=0.9):
    """Split a preprocessed TUH dataset into train, validation and test sets.

    The train/test partition comes from ``tuh_preproc.split("train")``, whose
    result is indexed with the string keys ``'True'`` (train) and ``'False'``
    (test). The first ``int(n * train_size)`` recordings of the train
    partition are kept for training. The remaining ones, in their original
    order, become the validation set. No shuffling is done.

    Parameters
    ----------
    tuh_preproc : braindecode BaseConcatDataset
        Presumably a preprocessed TUH/TUHAbnormal dataset whose description
        has a boolean ``train`` column -- TODO confirm against the caller.
    train_size : float, default 0.9
        Fraction of the official training recordings kept for training.
        At least one recording is always held out for validation.

    Returns
    -------
    tuh_train, tuh_val, tuh_test
        The three splits, as returned by ``split``.

    Raises
    ------
    ValueError
        If ``train_size`` leaves no recording for training.
    """
    train_test_splits = tuh_preproc.split("train")
    tuh_train, tuh_test = train_test_splits['True'], train_test_splits['False']

    n_recordings = len(tuh_train.datasets)
    # Floor of the requested share. Clamp so that at least one recording stays
    # in validation: an empty split cannot be turned into a dataset (torch's
    # ConcatDataset rejects an empty list).
    n_train = min(int(n_recordings * train_size), n_recordings - 1)
    if n_train < 1:
        raise ValueError(
            f"train_size={train_size} leaves no training recordings "
            f"out of {n_recordings}"
        )

    train_val_splits = tuh_train.split({
        'train': list(range(n_train)),
        'val': list(range(n_train, n_recordings)),
    })
    tuh_train, tuh_val = train_val_splits['train'], train_val_splits['val']

    return tuh_train, tuh_val, tuh_test

In [None]:



# TODO: rename later -- the name does not say which parameters are returned.

def get_parameters_for_model(tuh_train):
    """Return ``(n_chans, input_size_samples)`` from the first example of ``tuh_train``.

    ``tuh_train[0][0]`` is taken as the signal array of the first example. It
    must be 2-D; presumably (channels, time samples) -- TODO confirm.
    """
    first_example = tuh_train[0]
    first_signal = first_example[0]
    (channel_count, sample_count) = first_signal.shape
    return channel_count, sample_count

In [None]:



# Create windows with braindecode's create_fixed_length_windows; the
# parameters below define how each recording is cut into windows.

def create_windows(model, tuh_train, tuh_val, tuh_test, n_jobs,
                   input_window_samples=6000, in_chans=21):
    """Cut the train/validation/test recordings into fixed-length windows.

    The model is probed once with ``get_output_shape`` to find how many
    predictions it emits per input window (``n_preds_per_input``). That value
    is used as the window stride for all three splits. Presumably this makes
    the predictions of consecutive windows tile the recording (braindecode
    cropped-decoding convention) -- TODO confirm.

    Parameters
    ----------
    model : torch.nn.Module
        Model used only to compute ``n_preds_per_input``.
    tuh_train, tuh_val, tuh_test
        Datasets to window, e.g. the output of ``train_valid_test_split``.
    n_jobs : int
        Forwarded to ``create_fixed_length_windows``.
    input_window_samples : int, default 6000
        Window length in samples.
    in_chans : int, default 21
        Channel count passed to ``get_output_shape``. It should match the data.

    Returns
    -------
    train_set, val_set, test_set, n_preds_per_input
    """
    # Index 2 of the output shape is taken as the number of predictions per
    # window; assumes output shape (batch, n_classes, n_preds) -- TODO confirm.
    n_preds_per_input = get_output_shape(model, in_chans, input_window_samples)[2]

    def _windows(recordings):
        # One place for the settings so every split is windowed identically.
        return create_fixed_length_windows(
            recordings,
            start_offset_samples=0,
            stop_offset_samples=None,
            window_size_samples=input_window_samples,
            window_stride_samples=n_preds_per_input,
            drop_last_window=False,
            preload=False,
            n_jobs=n_jobs,
            mapping={False: 0, True: 1},  # map boolean targets to integer labels
        )

    train_set = _windows(tuh_train)
    val_set = _windows(tuh_val)
    test_set = _windows(tuh_test)

    return train_set, val_set, test_set, n_preds_per_input

In [None]:



def transform_target(train_set, val_set, test_set, n_preds_per_input):
    """Make each dataset return its scalar target repeated ``n_preds_per_input`` times.

    Sets ``target_transform`` on the three datasets in place and returns None.
    A target ``x`` becomes ``np.full(n_preds_per_input, x)``.
    """
    def _repeat_target(target):
        return np.full(n_preds_per_input, fill_value=target)

    for windowed in (train_set, val_set, test_set):
        windowed.target_transform = _repeat_target

In [None]:



def create_dataloaders(train_set, val_set, test_set, batch_size=64, shuffle_eval=False):
    """Wrap the windowed datasets in PyTorch DataLoaders.

    Only the training loader shuffles by default. The validation and test
    loaders iterate in dataset order. This keeps evaluation deterministic and
    lets batch predictions be matched back to the dataset's windows (e.g. for
    per-recording aggregation).

    Parameters
    ----------
    train_set, val_set, test_set
        Map-style datasets (e.g. the output of ``create_windows``).
    batch_size : int, default 64
        Batch size for all three loaders.
    shuffle_eval : bool, default False
        Set to True to also shuffle the validation and test loaders.

    Returns
    -------
    train_loader, val_loader, test_loader : torch.utils.data.DataLoader
    """
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=shuffle_eval)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=shuffle_eval)

    return train_loader, val_loader, test_loader

In [None]:



def data_transform(tuh_train, tuh_val, tuh_test, model, n_jobs, train_size=0.9, batch_size=64, input_window_samples=6000, in_chans=21):
    """Window the three splits and return train/validation/test DataLoaders.

    Pipeline: ``create_windows`` -> ``transform_target`` -> ``create_dataloaders``.

    NOTE(review): ``train_size`` is accepted but not used here. The splits
    arrive already separated (presumably from ``train_valid_test_split``).
    The parameter is kept only for interface compatibility.
    """
    train_windows, val_windows, test_windows, n_preds = create_windows(
        model,
        tuh_train,
        tuh_val,
        tuh_test,
        n_jobs,
        input_window_samples=input_window_samples,
        in_chans=in_chans,
    )

    # Repeat each scalar target n_preds times (n_preds comes from create_windows).
    transform_target(train_windows, val_windows, test_windows, n_preds)

    loaders = create_dataloaders(train_windows, val_windows, test_windows, batch_size=batch_size)
    train_loader, val_loader, test_loader = loaders
    return train_loader, val_loader, test_loader