In [55]:
import numpy as np
import pandas as pd
from tqdm import tqdm

In [56]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [57]:
%%capture
!pip install mne
#!pip install braindecode
!pip install git+https://github.com/TNTLFreiburg/braindecode.git

In [58]:
import numpy as np
from sklearn.metrics import roc_auc_score
import numpy as np
from numpy.random import RandomState
from braindecode.datautil.iterators import _compute_start_stop_block_inds

class CroppedDiagnosisMonitor(object):
    """
    Compute trialwise misclasses from predictions for crops for non-dense predictions.
    Parameters
    ----------
    input_time_length: int
        Temporal length of one input to the model.
    """

    def __init__(self, input_time_length, n_preds_per_input):
        self.input_time_length = input_time_length
        self.n_preds_per_input = n_preds_per_input

    def monitor_epoch(self, ):
        return

    def monitor_set(self, setname, all_preds, all_losses,
                    all_batch_sizes, all_targets, dataset):
        """Assuming one hot encoding for now"""
        preds_per_trial = compute_preds_per_trial(
            all_preds, dataset, input_time_length=self.input_time_length,
            n_stride=self.n_preds_per_input)

        mean_preds_per_trial = [np.mean(preds, axis=(0, 2)) for preds in
                                preds_per_trial]
        mean_preds_per_trial = np.array(mean_preds_per_trial)

        pred_labels_per_trial = np.argmax(mean_preds_per_trial, axis=1)
        assert pred_labels_per_trial.shape == dataset.y.shape
        accuracy = np.mean(pred_labels_per_trial == dataset.y)
        misclass = 1 - accuracy
        column_name = "{:s}_misclass".format(setname)
        out = {column_name: float(misclass)}
        y = dataset.y

        n_true_positive = np.sum((y == 1) & (pred_labels_per_trial == 1))
        n_positive = np.sum(y == 1)
        if n_positive > 0:
            sensitivity = n_true_positive / float(n_positive)
        else:
            sensitivity = np.nan
        column_name = "{:s}_sensitivity".format(setname)
        out.update({column_name: float(sensitivity)})

        n_true_negative = np.sum((y == 0) & (pred_labels_per_trial == 0))
        n_negative = np.sum(y == 0)
        if n_negative > 0:
            specificity = n_true_negative / float(n_negative)
        else:
            specificity = np.nan
        column_name = "{:s}_specificity".format(setname)
        out.update({column_name: float(specificity)})
        if (n_negative > 0) and (n_positive > 0):
            auc = roc_auc_score(y, mean_preds_per_trial[:,1])
        else:
            auc = np.nan
        column_name = "{:s}_auc".format(setname)
        out.update({column_name: float(auc)})
        return out

def compute_preds_per_trial(preds_per_batch, dataset, input_time_length,
                            n_stride):
    n_trials = len(dataset.X)
    i_pred_starts = [input_time_length -
                     n_stride] * n_trials
    i_pred_stops = [t.shape[1] for t in dataset.X]

    start_stop_block_inds_per_trial = _compute_start_stop_block_inds(
        i_pred_starts,
        i_pred_stops, input_time_length, n_stride,
        False)

    n_rows_per_trial = [len(block_inds) for block_inds in
                        start_stop_block_inds_per_trial]

    all_preds_arr = np.concatenate(preds_per_batch, axis=0)
    i_row = 0
    preds_per_trial = []
    for n_rows in n_rows_per_trial:
        preds_per_trial.append(all_preds_arr[i_row:i_row + n_rows])
        i_row += n_rows
    assert i_row == len(all_preds_arr)
    return preds_per_trial


class CroppedNonDenseTrialMisclassMonitor(object):
    """
    Compute trialwise misclasses from predictions for crops for non-dense predictions.
    Parameters
    ----------
    input_time_length: int
        Temporal length of one input to the model.
    """

    def __init__(self, input_time_length, n_preds_per_input):
        self.input_time_length = input_time_length
        self.n_preds_per_input = n_preds_per_input

    def monitor_epoch(self, ):
        return

    def monitor_set(self, setname, all_preds, all_losses,
                    all_batch_sizes, all_targets, dataset):
        """Assuming one hot encoding for now"""
        n_trials = len(dataset.X)
        i_pred_starts = [self.input_time_length -
                         self.n_preds_per_input] * n_trials
        i_pred_stops = [t.shape[1] for t in dataset.X]

        start_stop_block_inds_per_trial = _compute_start_stop_block_inds(
            i_pred_starts,
            i_pred_stops, self.input_time_length, self.n_preds_per_input,
            False)

        n_rows_per_trial = [len(block_inds) for block_inds in
                            start_stop_block_inds_per_trial]

        all_preds_arr = np.concatenate(all_preds, axis=0)
        i_row = 0
        preds_per_trial = []
        for n_rows in n_rows_per_trial:
            preds_per_trial.append(all_preds_arr[i_row:i_row + n_rows])
            i_row += n_rows

        mean_preds_per_trial = [np.mean(preds, axis=(0, 2)) for preds in
                                preds_per_trial]
        mean_preds_per_trial = np.array(mean_preds_per_trial)

        pred_labels_per_trial = np.argmax(mean_preds_per_trial, axis=1)
        assert pred_labels_per_trial.shape == dataset.y.shape
        accuracy = np.mean(pred_labels_per_trial == dataset.y)
        misclass = 1 - accuracy
        column_name = "{:s}_misclass".format(setname)
        return {column_name: float(misclass)}

In [99]:
# There should always be a 'train' and 'eval' folder directly
# below these given folders
# Folders should contain all normal and abnormal data files without duplications
class Config():
    data_folders = [
        'drive/MyDrive/Data/normal/',
        'drive/MyDrive/Data/abnormal/']
    n_recordings = 600  # number of edf files to analyse, if you want to restrict the set size
    sensor_types = ["EEG"]
    n_chans = 21
    max_recording_mins = 35  # exclude larger recordings from training set
    sec_to_cut = 60  # cut away at start of each recording
    duration_recording_mins = 5#20  # how many minutes to use per recording
    test_recording_mins = 5#20
    max_abs_val = 800  # for clipping
    sampling_freq = 100
    divisor = 10  # divide signal by this
    test_on_eval = True  # teston evaluation set or on training set
    # in case of test on eval, n_folds and i_testfold determine
    # validation fold in training set for training until first stop
    n_folds = 10
    i_test_fold = 9
    shuffle = True
    model_name = 'deep'
    n_start_chans = 25
    n_chan_factor = 2  # relevant for deep model only
    input_time_length = 30000
    final_conv_length = 'auto'
    model_constraint = 'defaultnorm'
    init_lr = 1e-4
    batch_size = 16
    max_epochs = 50 # until first stop, the continue train on train+valid
    cuda = True # False
    n_classes = 2

In [61]:
import re
import numpy as np
import glob
import os.path
import mne

def session_key(file_name):
    """ sort the file name by session """
    return re.findall(r'(s\d{2})', file_name)


def natural_key(file_name):
    """ provides a human-like sorting key of a string """
    key = [int(token) if token.isdigit() else None
           for token in re.split(r'(\d+)', file_name)]
    return key

def time_key(file_name):
    """ provides a time-based sorting key """
    splits = file_name.split('/')
    print(re.findall(r'(\d{4}_\d{2}_\d{2})', splits[-2]))
    [date] = re.findall(r'(\d{4}_\d{2}_\d{2})', splits[-2])
    date_id = [int(token) for token in date.split('_')]
    recording_id = natural_key(splits[-1])
    session_id = session_key(splits[-2])

    return date_id + session_id + recording_id


def read_all_file_names(path, extension, key="time"):
    """ read all files with specified extension from given path
    :param path: parent directory holding the files directly or in subdirectories
    :param extension: the type of the file, e.g. '.txt' or '.edf'
    :param key: the sorting of the files. natural e.g. 1, 2, 12, 21 (machine 1, 12, 2, 21) or by time since this is
    important for cv. time is specified in the edf file names
    """
    file_paths = glob.glob(path + '**/*' + extension, recursive=True)

    if key == 'time':
        return sorted(file_paths, key=time_key)

    elif key == 'natural':
        return sorted(file_paths, key=natural_key)

def get_info_with_mne(file_path):
    """ read info from the edf file without loading the data. loading data is done in multiprocessing since it takes
    some time. getting info is done before because some files had corrupted headers or weird sampling frequencies
    that caused the multiprocessing workers to crash. therefore get and check e.g. sampling frequency and duration
    beforehand
    :param file_path: path of the recording file
    :return: file name, sampling frequency, number of samples, number of signals, signal names, duration of the rec
    """
    try:
        edf_file = mne.io.read_raw_edf(file_path, verbose='error')
    except ValueError:
        return None, None, None, None, None, None
        
    # some recordings have a very weird sampling frequency. check twice before skipping the file
    sampling_frequency = int(edf_file.info['sfreq'])
    if sampling_frequency < 10:
        sampling_frequency = 1 / (edf_file.times[1] - edf_file.times[0])
        if sampling_frequency < 10:
            return None, sampling_frequency, None, None, None, None

    n_samples = edf_file.n_times
    signal_names = edf_file.ch_names
    n_signals = len(signal_names)
    # some weird sampling frequencies are at 1 hz or below, which results in division by zero
    duration = n_samples / max(sampling_frequency, 1)

    # TODO: return rec object?
    return edf_file, sampling_frequency, n_samples, n_signals, signal_names, duration


def get_recording_length(file_path):
    """ some recordings were that huge that simply opening them with mne caused the program to crash. therefore, open
    the edf as bytes and only read the header. parse the duration from there and check if the file can safely be opened
    :param file_path: path of the directory
    :return: the duration of the recording
    """
    f = open(file_path, 'rb')
    header = f.read(256)
    f.close()

    return int(header[236:244].decode('ascii'))


def load_data(fname, preproc_functions, sensor_types=['EEG']):
    cnt, sfreq, n_samples, n_channels, chan_names, n_sec = get_info_with_mne(fname)
    cnt.load_data()
    selected_ch_names = []
    
    wanted_elecs = ['A1', 'A2', 'C3', 'C4', 'CZ', 'F3', 'F4', 'F7', 'F8', 'FP1',
                    'FP2', 'FZ', 'O1', 'O2',
                    'P3', 'P4', 'PZ', 'T3', 'T4', 'T5', 'T6']

    for wanted_part in wanted_elecs:
        wanted_found_name = []
        for ch_name in cnt.ch_names:
            if ' ' + wanted_part + '-' in ch_name:
                wanted_found_name.append(ch_name)
        selected_ch_names.append(wanted_found_name[0])


    cnt = cnt.pick_channels(selected_ch_names)
    n_sensors = 0
    if 'EEG' in sensor_types:
        n_sensors += 21
    if 'EKG' in sensor_types:
        n_sensors += 1

    assert len(cnt.ch_names)  == n_sensors, (
        "Expected {:d} channel names, got {:d} channel names".format(
            n_sensors, len(cnt.ch_names)))

    # change from volt to mikrovolt
    data = (cnt.get_data() * 1e6).astype(np.float32)
    fs = cnt.info['sfreq']
    for fn in preproc_functions:
        data, fs = fn(data, fs)
        data = data.astype(np.float32)
        fs = float(fs)
    return data


def get_all_sorted_file_names_and_labels(train_or_eval, folders):
    all_file_names = []
    for dirname, _, filenames in os.walk(folders[0]):
        for filename in filenames:
            all_file_names.append(os.path.join(dirname, filename))
    for dirname, _, filenames in os.walk(folders[1]):
        for filename in filenames:
            all_file_names.append(os.path.join(dirname, filename))       
    labels = ['/abnormal/' in f for f in all_file_names]
    labels = np.array(labels).astype(np.int64)
    return all_file_names, labels


class DiagnosisSet(object):
    def __init__(self, n_recordings, max_recording_mins, preproc_functions,
                 data_folders,
                 train_or_eval='train', sensor_types=['EEG'],):
        self.n_recordings = n_recordings
        self.max_recording_mins = max_recording_mins
        self.preproc_functions = preproc_functions
        self.train_or_eval = train_or_eval
        self.sensor_types = sensor_types
        self.data_folders = data_folders

    def load(self, only_return_labels=False):
        all_file_names, labels = get_all_sorted_file_names_and_labels(
            train_or_eval=self.train_or_eval, folders=self.data_folders)
        
        if self.max_recording_mins is not None:
            assert 'train' == self.train_or_eval
            lengths = [get_recording_length(fname) for fname in all_file_names]
            lengths = np.array(lengths)
            mask = lengths < self.max_recording_mins * 60
            cleaned_file_names = np.array(all_file_names)[mask]
            cleaned_labels = labels[mask]
        else:
            cleaned_file_names = np.array(all_file_names)
            cleaned_labels = labels
        if only_return_labels:
            return cleaned_labels
        
        X = []
        y = []
        n_files = len(cleaned_file_names[:self.n_recordings])
        for i_fname, fname in enumerate(tqdm(cleaned_file_names[:self.n_recordings])):
            x = load_data(fname, preproc_functions=self.preproc_functions, sensor_types=self.sensor_types)
            assert x is not None
            X.append(x)
            y.append(cleaned_labels[i_fname])
            
        y = np.array(y)
        return X, y

In [103]:
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from braindecode.models.deep4 import Deep4Net
import torch
from torch.autograd import Variable


class LSTM(nn.Module):
    def __init__(self, output_size, input_size, hidden_size, num_layers):
        super(LSTM, self).__init__()

        self.num_layers = num_layers
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.seq_length = 10
        self.lstm = nn.LSTM(input_size=input_size[0], hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True, 
                            bidirectional=True)
        self.linear = nn.Linear(hidden_size*2, output_size)

    def forward(self, x):
        h0 = Variable(torch.zeros(self.num_layers*2, x.size(0), 
            self.hidden_size))
        c0 = Variable(torch.zeros(self.num_layers*2, x.size(0), 
            self.hidden_size))

        lstm_out, (hn, cn) = self.lstm(x, (h0, c0))
        out = lstm_out[:, -1, :]
        out = self.linear(out)
    
        return out
        

def build_model(config):
    if config.model_name == 'shallow':
        model = ShallowFBCSPNet(in_chans=config.n_chans, 
                                n_classes=config.n_classes,
                                input_time_length=config.input_time_length,
                                final_conv_length=config.final_conv_length).create_network()
        
    elif config.model_name == 'deep':
        model = Deep4Net(in_chans=config.n_chans, n_classes=config.n_classes,
                         #n_filters_time=config.n_start_chans,
                         #n_filters_spat=config.n_start_chans,
                         input_time_length=config.input_time_length,
                         n_filters_2 = int(config.n_start_chans * config.n_chan_factor),
                         n_filters_3 = int(config.n_start_chans * (config.n_chan_factor ** 2.0)),
                         n_filters_4 = int(config.n_start_chans * (config.n_chan_factor ** 3.0)),
                         final_conv_length=config.final_conv_length,
                         stride_before_pool=True).create_network()
        
    elif (config.model_name == 'deep_smac'):
        if config.model_name == 'deep_smac':
            do_batch_norm = False
        else:
            assert config.model_name == 'deep_smac_bnorm'
            do_batch_norm = True
        double_time_convs = False
        drop_prob = 0.244445
        filter_length_2 = 12
        filter_length_3 = 14
        filter_length_4 = 12
        filter_time_length = 21
        final_conv_length = 1
        first_nonlin = elu
        first_pool_mode = 'mean'
        first_pool_nonlin = identity
        later_nonlin = elu
        later_pool_mode = 'mean'
        later_pool_nonlin = identity
        n_filters_factor = 1.679066
        n_filters_start = 32
        pool_time_length = 1
        pool_time_stride = 2
        split_first_layer = True
        n_chan_factor = n_filters_factor
        n_start_chans = n_filters_start
        model = Deep4Net(config.n_chans, config.n_classes,
                 n_filters_time=n_start_chans,
                 n_filters_spat=n_start_chans,
                 input_time_length=config.input_time_length,
                 n_filters_2=int(n_start_chans * n_chan_factor),
                 n_filters_3=int(n_start_chans * (n_chan_factor ** 2.0)),
                 n_filters_4=int(n_start_chans * (n_chan_factor ** 3.0)),
                 final_conv_length=config.final_conv_length,
                 batch_norm=do_batch_norm,
                 double_time_convs=double_time_convs,
                 drop_prob=drop_prob,
                 filter_length_2=filter_length_2,
                 filter_length_3=filter_length_3,
                 filter_length_4=filter_length_4,
                 filter_time_length=filter_time_length,
                 first_nonlin=first_nonlin,
                 first_pool_mode=first_pool_mode,
                 first_pool_nonlin=first_pool_nonlin,
                 later_nonlin=later_nonlin,
                 later_pool_mode=later_pool_mode,
                 later_pool_nonlin=later_pool_nonlin,
                 pool_time_length=pool_time_length,
                 pool_time_stride=pool_time_stride,
                 split_first_layer=split_first_layer,
                 stride_before_pool=True).create_network()
        
    elif config.model_name == 'shallow_smac':
        conv_nonlin = identity
        do_batch_norm = True
        drop_prob = 0.328794
        filter_time_length = 56
        final_conv_length = 22
        n_filters_spat = 73
        n_filters_time = 24
        pool_mode = 'max'
        pool_nonlin = identity
        pool_time_length = 84
        pool_time_stride = 3
        split_first_layer = True
        model = ShallowFBCSPNet(in_chans=config.n_chans, n_classes=config.n_classes,
                                n_filters_time=n_filters_time,
                                n_filters_spat=n_filters_spat,
                                input_time_length=input_time_length,
                                final_conv_length=config.final_conv_length,
                                conv_nonlin=conv_nonlin,
                                batch_norm=do_batch_norm,
                                drop_prob=drop_prob,
                                filter_time_length=filter_time_length,
                                pool_mode=pool_mode,
                                pool_nonlin=pool_nonlin,
                                pool_time_length=pool_time_length,
                                pool_time_stride=pool_time_stride,
                                split_first_layer=split_first_layer,
                                ).create_network()
    elif config.model_name == 'linear':
        model = nn.Sequential()
        model.add_module("conv_classifier",
                         nn.Conv2d(config.n_chans, config.n_classes, (600,1)))
        model.add_module('Softmax', nn.Softmax(1))
        model.add_module('squeeze', Expression(lambda x: x.squeeze(3)))
        
    #to_dense_prediction_model(model)
    if config.cuda:
        model.cuda()
    return model  

In [63]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

__modifiedby__ = 'Mohamed Radwan'
__originalauthor__ = 'David Nahmias'
__credits__ = ['David Nahmias']
__maintainer__ = 'Mohamed Radwan'

import logging
import time
from copy import copy
import sys

from collections import Counter
import random
import numpy as np
from numpy.random import RandomState
import resampy
from torch import optim
import torch.nn.functional as F
import torch as th
from torch.nn.functional import elu
from torch import nn
import torch.backends.cudnn as cudnn

from braindecode.datautil.signal_target import SignalAndTarget
from braindecode.torch_ext.util import np_to_var
from braindecode.torch_ext.util import set_random_seeds
from braindecode.torch_ext.modules import Expression
from braindecode.experiments.experiment import Experiment
from braindecode.datautil.iterators import CropsFromTrialsIterator
from braindecode.experiments.monitors import (RuntimeMonitor, LossMonitor,
                                              MisclassMonitor)
from braindecode.experiments.stopcriteria import MaxEpochs
from braindecode.datautil.iterators import get_balanced_batches
from braindecode.torch_ext.constraints import MaxNormDefaultConstraint
from braindecode.torch_ext.util import var_to_np
from braindecode.torch_ext.functions import identity


def splitDataRandom(allData,allLabels,setNum=0,shuffle=0):
    numberEqSamples = min(Counter(allLabels).values())
    trainSamplesNum = int(np.ceil(numberEqSamples*0.75))
    testSamplesNum = numberEqSamples-trainSamplesNum

    labels0 = allLabels[allLabels == 0]
    labels1 = allLabels[allLabels == 1]
    data0 = np.array(allData)[allLabels == 0]
    data1 = np.array(allData)[allLabels == 1]

    fullRange = list(range(numberEqSamples))
    random.shuffle(fullRange)

    testIndices = fullRange[trainSamplesNum:]
    trainIndices = fullRange[:trainSamplesNum]
    
    allDataTrain = np.concatenate((data0[trainIndices],data1[trainIndices]),axis=0)
    allLabelsTrain = np.concatenate((labels0[trainIndices],labels1[trainIndices]),axis=0)

    allDataTest = np.concatenate((data0[testIndices],data1[testIndices]),axis=0)
    allLabelsTest = np.concatenate((labels0[testIndices],labels1[testIndices]),axis=0)

    return allDataTrain, allLabelsTrain, allDataTest, allLabelsTest


def create_set(X, y, inds):
    """
    X list and y nparray
    :return: 
    """
    new_X = []
    for i in inds:
        new_X.append(X[i])
    new_y = y[inds]
    return SignalAndTarget(new_X, new_y)


class TrainValidTestSplitter(object):
    def __init__(self, n_folds, i_test_fold, shuffle):
        self.n_folds = n_folds
        self.i_test_fold = i_test_fold
        self.rng = RandomState(404)
        self.shuffle = shuffle

    def split(self, X, y,):
        if len(X) < self.n_folds:
            raise ValueError("Less Trials: {:d} than folds: {:d}".format(
                len(X), self.n_folds
            ))
        folds = get_balanced_batches(len(X), self.rng, self.shuffle,
                                     n_batches=self.n_folds)
        test_inds = folds[self.i_test_fold]
        valid_inds = folds[self.i_test_fold - 1]
        all_inds = list(range(len(X)))
        train_inds = np.setdiff1d(all_inds, np.union1d(test_inds, valid_inds))
        train_set = create_set(X, y, train_inds)
        valid_set = create_set(X, y, valid_inds)
        test_set = create_set(X, y, test_inds)

        return train_set, valid_set, test_set


class TrainValidSplitter(object):
    def __init__(self, n_folds, i_valid_fold, shuffle):
        self.n_folds = n_folds
        self.i_valid_fold = i_valid_fold
        self.rng = RandomState(404)
        self.shuffle = shuffle

    def split(self, X, y):
        folds = get_balanced_batches(len(X), self.rng, self.shuffle,
                                     n_batches=self.n_folds)
        valid_inds = folds[self.i_valid_fold]
        all_inds = list(range(len(X)))
        train_inds = np.setdiff1d(all_inds, valid_inds)
        train_set = create_set(X, y, train_inds)
        valid_set = create_set(X, y, valid_inds)
        return train_set, valid_set

    
def preprocess(config):
    preproc_functions = []
    preproc_functions.append(
        lambda data, fs: (data[:, int(config.sec_to_cut * fs):-int(
            config.sec_to_cut * fs)], fs))
    preproc_functions.append(
        lambda data, fs: (data[:, :int(config.duration_recording_mins * 60 * fs)], fs))
    if config.max_abs_val is not None:
        preproc_functions.append(lambda data, fs:
                                 (np.clip(data, -config.max_abs_val, config.max_abs_val), fs))

    preproc_functions.append(lambda data, fs: (resampy.resample(data, fs,
                                                                config.sampling_freq,
                                                                axis=1,
                                                                filter='kaiser_fast'),
                                               config.sampling_freq))

    if config.divisor is not None:
        preproc_functions.append(lambda data, fs: (data / config.divisor, fs))
        
    return preproc_functions

    
def data_loader(config):
    cudnn.benchmark = True
    
    preproc_functions = preprocess(config)
    
    dataset = DiagnosisSet(n_recordings=config.n_recordings,
                           max_recording_mins=config.max_recording_mins,
                           preproc_functions=preproc_functions,
                           data_folders=config.data_folders,
                           train_or_eval='train',
                           sensor_types=config.sensor_types)
    
    if config.test_on_eval:
        test_recording_mins = config.duration_recording_mins
        test_preproc_functions = copy(preproc_functions)
        test_preproc_functions[1] = lambda data, fs: (
            data[:, :int(test_recording_mins * 60 * fs)], fs)
        test_dataset = DiagnosisSet(n_recordings=config.n_recordings,
                                max_recording_mins=None,
                                preproc_functions=test_preproc_functions,
                                data_folders=config.data_folders,
                                train_or_eval='val',
                                sensor_types=config.sensor_types)

    data, labels = dataset.load()
    X,y,test_X,test_y = splitDataRandom(data, labels,shuffle=0)
    
    if config.test_on_eval:
        max_shape = np.max([list(x.shape) for x in test_X], axis=0)
    if not config.test_on_eval:
        splitter = TrainValidTestSplitter(config.n_folds, config.i_test_fold,
                                          shuffle=config.shuffle)
        train_set, valid_set, test_set = splitter.split(X, y)
        
    else:
        splitter = TrainValidSplitter(config.n_folds, 
                                      i_valid_fold=config.i_test_fold,
                                      shuffle=config.shuffle)
        train_set, valid_set = splitter.split(X, y)
        test_set = SignalAndTarget(test_X, test_y)
        
    return train_set, valid_set, test_set

In [65]:
config = Config()
train_set, valid_set, test_set = data_loader(config)

model = build_model(config)
optimizer = optim.Adam(model.parameters(), lr=config.init_lr)

100%|██████████| 567/567 [10:38<00:00,  1.13s/it]


In [82]:
from braindecode.torch_ext.util import np_to_var, var_to_np
from braindecode.datautil.iterators import get_balanced_batches
import torch.nn.functional as F
from sklearn.metrics import accuracy_score

rng = RandomState((2018,8,7))


def evaluate_on_val(model, data, config):
    """print metrics each epoch"""
    model.eval()
    accuracies = []
    losses = []
    i_trials_in_batch = get_balanced_batches(len(data.X), rng, 
                                             shuffle=True, 
                                             batch_size=config.batch_size)
    for i_trials in i_trials_in_batch:
        batch_X = np.array(data.X)[i_trials][:,:,:,None]
        batch_y = np.array(data.y)[i_trials]
        batch_X = np_to_var(batch_X)   
        batch_y = np_to_var(batch_y)     
        if config.cuda:
            batch_X = batch_X.cuda()
            batch_y = batch_y.cuda()
        outputs = model(batch_X)
        loss = F.nll_loss(outputs, batch_y)
        outputs = outputs.cpu().detach().numpy()
        batch_y = batch_y.cpu().detach().numpy()
        predicted_labels = np.argmax(outputs, axis=1)
        accuracy = accuracy_score(batch_y, predicted_labels)
        accuracies.append(accuracy)
        losses.append(loss.item())
    
    print('Accuracy: ', np.mean(accuracies), ', Loss: ', np.mean(losses))
    

def evaluate_on_test(model, eeg_features, labels, config):
    """evaluate the model on test data"""
    accuracies = []
    model.eval()
    i_trials_in_batch = get_balanced_batches(len(eeg_features), rng, 
                                                 shuffle=True,
                                                 batch_size=config.batch_size)
    for i_trials in i_trials_in_batch:
        batch_X = np.array(eeg_features)[i_trials][:,:,:,None]
        batch_y = np.array(labels)[i_trials]
        batch_y = np_to_var(batch_y) 
        batch_X = np_to_var(batch_X) 
        if config.cuda:
            batch_X = batch_X.cuda()
            batch_y = batch_y.cuda()
        outputs = model(batch_X)
        outputs = outputs.cpu().detach().numpy()
        batch_y = batch_y.cpu().detach().numpy()
        predicted_labels = np.argmax(outputs, axis=1)
        accuracy = accuracy_score(batch_y, predicted_labels)
        accuracies.append(accuracy)
    print('Test Accuracy: ', np.mean(accuracies))


def train(config, model, optimizer, train_set, valid_set):
    for i_epoch in range(1, config.max_epochs):
        i_trials_in_batch = get_balanced_batches(len(train_set.X), rng, 
                                                 shuffle=True,
                                                 batch_size=config.batch_size)
        # Set model to training mode
        model.train()
        for i_trials in i_trials_in_batch:
        # Have to add empty fourth dimension to X
        
            batch_X = np.array(train_set.X)[i_trials][:,:,:,None]
            batch_y = np.array(train_set.y)[i_trials]
            batch_X = np_to_var(batch_X)
            if config.cuda:
                batch_X = batch_X.cuda()
            batch_y = np_to_var(batch_y)
            if config.cuda:
                batch_y = batch_y.cuda()
            # Remove gradients of last backward pass from all parameters
            optimizer.zero_grad()
            # Compute outputs of the network
            outputs = model(batch_X)
            # Compute the loss
            loss = F.nll_loss(outputs, batch_y)
            # Do the backpropagation
            loss.backward()
            # Update parameters with the optimizer
            optimizer.step()
        print('Epoch ', i_epoch)
        print('========')
        print('Training Metrics: ')
        evaluate_on_val(model, train_set, config)
        print('Validation Metrics: ')
        evaluate_on_val(model, valid_set, config)
    return model

## Deep4Net Model

In [100]:
config = Config()
model = build_model(config)
optimizer = optim.Adam(model.parameters(), lr=config.init_lr)
model = train(config, model, optimizer, train_set, valid_set)

Epoch  1
Training Metrics: 
Accuracy:  0.527536231884058 , Loss:  1.4153133086536243
Validation Metrics: 
Accuracy:  0.45 , Loss:  1.747484266757965
Epoch  2
Training Metrics: 
Accuracy:  0.5626811594202898 , Loss:  0.6958337400270544
Validation Metrics: 
Accuracy:  0.475 , Loss:  0.8347907364368439
Epoch  3
Training Metrics: 
Accuracy:  0.5166666666666666 , Loss:  1.3553521918213887
Validation Metrics: 
Accuracy:  0.45 , Loss:  1.6935781240463257
Epoch  4
Training Metrics: 
Accuracy:  0.5166666666666666 , Loss:  1.0842221949411475
Validation Metrics: 
Accuracy:  0.45 , Loss:  1.410382628440857
Epoch  5
Training Metrics: 
Accuracy:  0.5291666666666667 , Loss:  1.0384441717811252
Validation Metrics: 
Accuracy:  0.45 , Loss:  1.4167568683624268
Epoch  6
Training Metrics: 
Accuracy:  0.5576086956521739 , Loss:  0.7346386935399927
Validation Metrics: 
Accuracy:  0.45 , Loss:  1.030613511800766
Epoch  7
Training Metrics: 
Accuracy:  0.5358695652173913 , Loss:  0.8194335530633512
Validation 

In [123]:

"""in the original implimentation, the author start a new array as zeros,
and then add perturbations to the zeros array, to show the spectral importance.
Here, we use the same original data nad only perturb the required spectrum 
one at a time"""
from scipy import fftpack
import copy

def addDataNoise(origSignals,band=[],channels=[],srate=100, zeroing=False):
    np.random.seed(seed=404)
    signals = copy.deepcopy(origSignals)

    if (len(band)+len(channels)) == 0:
        return origSignals
    
    if (len(channels)>0) and (len(band)==0):
        for s in range(len(signals)):
            for c in channels:
                cleanSignal = origSignals[s][c,:]
                timeDomNoise = np.random.normal(np.mean(cleanSignal), 
                                                np.std(cleanSignal), 
                                                size=len(cleanSignal))
                signals[s][c,:] = np.float32(timeDomNoise)
                """add noise to all channels: cleanSignal + timeDomNoise"""

    if (len(band) == 2) and (type(band[0]) == int):
        if len(channels)==0:
            channels = range(signals[0].shape[0])
        numSamples = signals[0].shape[1]
        W = fftpack.rfftfreq(numSamples,d=1./srate)
        lowHz = next(x[0] for x in enumerate(W) if x[1] > band[0])
        highHz = next(x[0] for x in enumerate(W) if x[1] > band[1])
        for s in range(len(signals)):
            for c in channels: #loop through channels
                dataDFT = fftpack.rfft(origSignals[s][c,:])
                cleanDFT = dataDFT[lowHz:highHz]
                freqDomNoise = np.random.normal(np.mean(cleanDFT), 
                                                np.std(cleanDFT), 
                                                size=len(cleanDFT))
                dataDFT[lowHz:highHz] =  freqDomNoise#cleanDFT + freqDomNoise
                signals[s][c,:] = np.float32(fftpack.irfft(dataDFT))


    elif (len(band)>0) and (type(band) == list):
        if len(channels)==0:
            channels = range(origSignals[0].shape[0])
        
        numSamples = origSignals[0].shape[1]
        W = fftpack.rfftfreq(numSamples, d=1./srate)    
        
        for s in range(len(signals)):
            for c in channels: #loop through channels
                dataDFT_original = fftpack.rfft(origSignals[s][c,:])
                dataDFT_output = fftpack.rfft(signals[s][c,:])
                for b in band:
                    lowHz = next(x[0] for x in enumerate(W) if x[1] > b[0])
                    highHz = next(x[0] for x in enumerate(W) if x[1] > b[1])
                    cleanDFT = dataDFT_original[lowHz:highHz]
                    freqDomNoise = np.random.normal(np.mean(cleanDFT), 
                                                    np.std(cleanDFT), 
                                                    size=len(cleanDFT))
                    if zeroing:
                        dataDFT_output[lowHz:highHz] =  0 #no signal in this frequency
                    else:
                        dataDFT_output[lowHz:highHz] = freqDomNoise #freqDomNoise
                signals[s][c,:] = np.float32(fftpack.irfft(dataDFT_output))

    return signals

### Evaluation on Test Data using spectral removing (remove frequencey band) and perturbation

In [125]:
delta = [[0, 4]]
theta = [[4, 8]]
alpha = [[8, 12]]
mu = [[12, 16]]
beta = [[16, 25]]
gamma = [[25, 40]]

print('No removal of spectral data')
evaluate_on_test(model, test_set.X, test_set.y, config)

print('Delta band removed')
test_noisy_band = addDataNoise(test_set.X, band=delta, srate=100, zeroing=True)
evaluate_on_test(model, test_noisy_band, test_set.y, config)

print('Theta band removed')
test_noisy_band = addDataNoise(test_set.X, band=theta, srate=100, zeroing=True)
evaluate_on_test(model, test_noisy_band, test_set.y, config)

print('Alpha band removed')
test_noisy_band = addDataNoise(test_set.X, band=alpha, srate=100, zeroing=True)
evaluate_on_test(model, test_noisy_band, test_set.y, config)

print('Mu band removed')
test_noisy_band = addDataNoise(test_set.X, band=mu, srate=100, zeroing=True)
evaluate_on_test(model, test_noisy_band, test_set.y, config)

print('Beta band removed')
test_noisy_band = addDataNoise(test_set.X, band=beta, srate=100, zeroing=True)
evaluate_on_test(model, test_noisy_band, test_set.y, config)

print('Gamma band removed')
test_noisy_band = addDataNoise(test_set.X, band=gamma, srate=100, zeroing=True)
evaluate_on_test(model, test_noisy_band, test_set.y, config)

No perturbation
Test Accuracy:  0.6415441176470589
Delta band removed
Test Accuracy:  0.5303308823529411
Theta band removed
Test Accuracy:  0.6194852941176471
Alpha band removed
Test Accuracy:  0.6397058823529411
Mu band removed
Test Accuracy:  0.6351102941176471
Beta band removed
Test Accuracy:  0.6346507352941176
Gamma band removed
Test Accuracy:  0.6424632352941176


In [126]:
print('No perturbation')
evaluate_on_test(model, test_set.X, test_set.y, config)

print('Delta band perturbation')
test_noisy_band = addDataNoise(test_set.X, band=delta, srate=100)
evaluate_on_test(model, test_noisy_band, test_set.y, config)

print('Theta band perturbation')
test_noisy_band = addDataNoise(test_set.X, band=theta, srate=100)
evaluate_on_test(model, test_noisy_band, test_set.y, config)

print('Alpha band perturbation')
test_noisy_band = addDataNoise(test_set.X, band=alpha, srate=100)
evaluate_on_test(model, test_noisy_band, test_set.y, config)

print('Mu band perturbation')
test_noisy_band = addDataNoise(test_set.X, band=mu, srate=100)
evaluate_on_test(model, test_noisy_band, test_set.y, config)

print('Beta band perturbation')
test_noisy_band = addDataNoise(test_set.X, band=beta, srate=100)
evaluate_on_test(model, test_noisy_band, test_set.y, config)

print('Gamma band perturbation')
test_noisy_band = addDataNoise(test_set.X, band=gamma, srate=100)
evaluate_on_test(model, test_noisy_band, test_set.y, config)

No perturbation
Test Accuracy:  0.6443014705882353
Delta band perturbation
Test Accuracy:  0.6796875
Theta band perturbation
Test Accuracy:  0.6259191176470589
Alpha band perturbation
Test Accuracy:  0.6484375
Mu band perturbation
Test Accuracy:  0.6351102941176471
Beta band perturbation
Test Accuracy:  0.6332720588235294
Gamma band perturbation
Test Accuracy:  0.6433823529411764
