In [None]:


"""
Data functions for splitting the dataset into train-valid-test sets,
creating fixed-length windows, and channeling the inputs into pytorch DataLoaders

Each function has its own description in its docstring.

The methods defined in this file are:
    - train_valid_test_split
    - get_parameters_for_model
    - create_windows
    - transform_target
    - create_dataloaders
    - data_transform
"""

In [None]:
# import packages

import numpy as np
# import pandas as pd
import matplotlib.pyplot as plt
import mne
from braindecode.datasets.tuh import TUHAbnormal, TUH
from braindecode.models import get_output_shape
from braindecode.preprocessing import create_fixed_length_windows
from torch.utils.data import DataLoader
from multiprocessing import Pool, cpu_count

In [None]:

def train_valid_test_split(tuh_preproc, train_size=0.9):
    """
    Function for splitting the dataset into train, validation,
    and test sets.

    The dataset is first split on its "train" description column into a
    train part (key 'True') and a test part (key 'False'). The train part
    is then divided by position: the first ``int(n * train_size)``
    recordings stay in the training set and the remaining ones form the
    validation set.

    Parameters
    ----------
    tuh_preproc : preprocessed TUH Abnormal dataset
    train_size : fraction (between 0 and 1) of the train part that is kept
        for training; the rest becomes the validation set

    Returns
    ----------
    tuh_train : train set
    tuh_val : validation set
    tuh_test : test set

    Raises
    ----------
    ValueError : if train_size is not within [0, 1]
    """

    if not 0 <= train_size <= 1:
        raise ValueError(f"train_size must be between 0 and 1, got {train_size}")

    train_test_splits = tuh_preproc.split("train")
    tuh_train, tuh_test = train_test_splits['True'], train_test_splits['False']

    n_recordings = len(tuh_train.datasets)
    # Number of recordings kept for training. This is a count, not a last
    # index, so no "-1" is applied: 10 recordings at 0.9 -> 9 train, 1 val.
    n_train = int(n_recordings * train_size)

    # Contiguous, ordered index lists keep the split deterministic.
    new_train_inds = list(range(n_train))
    new_val_inds = list(range(n_train, n_recordings))

    # NOTE(review): a fraction that leaves one of the two lists empty may be
    # rejected by the dataset's split method - confirm before using 0 or 1.
    train_val_splits = tuh_train.split({'train': new_train_inds, 'val': new_val_inds})
    tuh_train, tuh_val = train_val_splits['train'], train_val_splits['val']

    return tuh_train, tuh_val, tuh_test

In [None]:

def get_parameters_for_model(tuh_train):
    """
    Read the channel count and the per-input sample count from
    the first item of the training set.

    Parameters
    ----------
    tuh_train : training set; indexing it with 0 must yield a sequence
        whose first element exposes a two-element ``shape``

    Returns
    ----------
    n_chans : number of channels (first entry of the shape)
    input_size_samples : size of an input in the dataset (second entry of the shape)
    """

    first_item = tuh_train[0]
    signal = first_item[0]
    # Unpacking enforces that the shape has exactly two dimensions.
    n_chans, input_size_samples = signal.shape

    return n_chans, input_size_samples

In [None]:

def create_windows(model, tuh_train, tuh_val, tuh_test, n_jobs, input_window_samples=6000, in_chans=21):
    """
    Function for creating equal-sized windows from
    the data. Based on create_fixed_length_windows
    function from braindecode package.

    The same windowing settings are applied to the train, validation
    and test sets. The stride between consecutive windows is set to the
    number of predictions the model produces for one input window.

    Parameters
    ----------
    model : model to be trained later on the dataset
    tuh_train : train set
    tuh_val : validation set
    tuh_test : test set
    n_jobs : number of jobs used for parallel execution
    input_window_samples : size of the windows to be created (in samples)
    in_chans: number of channels in the dataset

    Returns
    ----------
    train_set : train set containing the newly generated windows as data points
    val_set : validation set containing the newly generated windows as data points
    test_set : test set containing the newly generated windows as data points
    n_preds_per_input : number of predictions per one input
    """

    # Third entry of the model's output shape for one input window.
    n_preds_per_input = get_output_shape(model, in_chans, input_window_samples)[2]

    def _make_windows(dataset):
        # One shared definition of the windowing settings for all three sets.
        return create_fixed_length_windows(
            dataset,
            start_offset_samples=0,
            stop_offset_samples=None,
            window_size_samples=input_window_samples,
            window_stride_samples=n_preds_per_input,
            drop_last_window=False,
            preload=False,
            n_jobs=n_jobs,
            mapping={False: 0, True: 1},  # map non-digit targets
        )

    train_set = _make_windows(tuh_train)
    val_set = _make_windows(tuh_val)
    test_set = _make_windows(tuh_test)

    return train_set, val_set, test_set, n_preds_per_input

In [None]:

def transform_target(train_set, val_set, test_set, n_preds_per_input):
    """
    Attach a target transform to each of the three sets so that a single
    target value is repeated once per prediction of an input.

    The sets are modified in place through their ``target_transform``
    attribute.

    Parameters
    ----------
    train_set : training set
    val_set : validation set
    test_set : test set
    n_preds_per_input : number of predictions per one input

    Returns
    ----------
    -
    """

    def _repeat_target(target):
        # 1-D array of length n_preds_per_input, every entry equal to target.
        return np.full(n_preds_per_input, fill_value=target)

    for dataset in (train_set, val_set, test_set):
        dataset.target_transform = _repeat_target

In [None]:

def create_dataloaders(train_set, val_set, test_set, batch_size=64):
    """
    Wrap the training, validation and test sets in PyTorch DataLoaders
    that all use the same batch size.

    Parameters
    ----------
    train_set : training set
    val_set : validation set
    test_set : test set
    batch_size : batch size

    Returns
    ----------
    train_loader : DataLoader containing the training set
    val_loader : DataLoader containing the validation set
    test_loader : DataLoader containing the test set

    """
    # NOTE(review): shuffling is enabled for the validation and test loaders
    # as well as for training - confirm this is intended.
    train_loader, val_loader, test_loader = (
        DataLoader(dataset, batch_size=batch_size, shuffle=True)
        for dataset in (train_set, val_set, test_set)
    )

    return train_loader, val_loader, test_loader

In [None]:



def data_transform(tuh_train, tuh_val, tuh_test, model, n_jobs, train_size=0.9, batch_size=64, input_window_samples=6000, in_chans=21):
    """
    Wrapper function that chains create_windows, transform_target and
    create_dataloaders on already-split train, validation and test sets.

    Parameters
    ----------
    tuh_train : train set
    tuh_val : validation set
    tuh_test : test set
    model : model to be trained later on the dataset
    n_jobs : number of jobs used for parallel execution
    train_size : accepted but not used by this function
    batch_size : batch size of the generated DataLoaders
    input_window_samples : size of the windows to be created
    in_chans : number of channels in the dataset

    Returns
    ----------
    train_loader : DataLoader containing the training set
    val_loader : DataLoader containing the validation set
    test_loader : DataLoader containing the test set

    """

    # Step 1: cut every set into fixed-length windows.
    windowed = create_windows(
        model,
        tuh_train,
        tuh_val,
        tuh_test,
        n_jobs,
        input_window_samples=input_window_samples,
        in_chans=in_chans,
    )
    train_windows, val_windows, test_windows, preds_per_window = windowed

    # Step 2: install the target transform on the windowed sets (in place).
    transform_target(train_windows, val_windows, test_windows, preds_per_window)

    # Step 3: wrap the windowed sets in DataLoaders.
    train_loader, val_loader, test_loader = create_dataloaders(
        train_windows, val_windows, test_windows, batch_size=batch_size
    )

    return train_loader, val_loader, test_loader