In [10]:
from tsai.models.utils import *
from tsai.basics import *
from tsai.inference import load_learner
from tsai.all import *
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split

In [11]:
import numpy as np
import pandas as pd
import csv
from tabulate import tabulate
from sklearn.utils import shuffle
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from AugmentTS.augmentts.augmenters.vae import LSTMVAE, VAEAugmenter

In [12]:
def train_model_lstm(x_train, y_train, x_valid, y_valid):
    X, y, splits = combine_split_data([x_train, x_valid], [y_train, y_valid])
    tfms  = [None, TSClassification()] # TSClassification == Categorize
    batch_tfms = TSStandardize()
    dls = get_ts_dls(X, y, splits=splits, tfms=tfms, batch_tfms=batch_tfms, bs=[64, 32])
    #dls.dataset
    model = build_ts_model(LSTM_FCNPlus, dls=dls)
    learn = Learner(dls, model, metrics=accuracy)
    learn.fit_one_cycle(10, lr_max=1e-2)
    return learn

def train_model_cnn(x_train, y_train, x_valid, y_valid):
    X, y, splits = combine_split_data([x_train, x_valid], [y_train, y_valid])
    tfms  = [None, TSClassification()] # TSClassification == Categorize
    batch_tfms = TSStandardize()
    dls = get_ts_dls(X, y, splits=splits, tfms=tfms, batch_tfms=batch_tfms, bs=[64, 32])
    #dls.dataset
    model = build_ts_model(InceptionTimePlus, dls=dls)
    learn = Learner(dls, model, metrics=accuracy)
    learn.fit_one_cycle(10, lr_max=1e-2)
    return learn

def train_model_tst(x_train, y_train, x_valid, y_valid):
    X, y, splits = combine_split_data([x_train, x_valid], [y_train, y_valid])
    tfms  = [None, TSClassification()] # TSClassification == Categorize
    batch_tfms = TSStandardize()
    dls = get_ts_dls(X, y, splits=splits, tfms=tfms, batch_tfms=batch_tfms, bs=[64, 32])
    #dls.dataset
    model = build_ts_model(TSTPlus, dls=dls)
    learn = Learner(dls, model, metrics=accuracy)
    learn.fit_one_cycle(10, lr_max=1e-2)
    return learn

def run_model(x_test, y_test, learn):
    probas, target, preds = learn.get_X_preds(x_test, y_test)
    preds_labels = np.argmax(preds, axis=1)
    target_labels = np.argmax(target, axis=-1) # undo one-hot encoding
    return preds_labels, target_labels

In [13]:
x_train_new = np.load('x_train_vae.npy')
y_train_new = np.load('y_train_vae.npy')
x_valid = np.load('x_valid.npy')
y_valid = np.load('y_valid.npy')

In [14]:
print(x_train_new.shape, y_train_new.shape)
print(x_valid.shape, y_valid.shape)

(13800, 178, 1) (13800, 2)
(2300, 178, 1) (2300, 2)


In [15]:
input_dir = 'eeg_dataset_1/'
x_test = np.load(input_dir +'x_test.npy')
y_test = np.load(input_dir +'y_test.npy')
x_test = x_test.astype(float)
def process_eeg_data(x, y):
    new_x =[]
    new_y = []
    split_size = 178
    n_splits = 23
    for i in range(x.shape[0]):
        sub_data = x[i]
        y_val = y[i]
        for i in range (n_splits):
            sample = sub_data[i*split_size:(i+1)*split_size]
            new_x.append(sample)
            temp_y = y_val
            new_y.append(temp_y)
            #print(sample.shape)
            
    new_x = np.array(new_x)
    new_y = np.array(new_y)
    # x = new_x
    # y = new_y
    return new_x, new_y
x_test, y_test = process_eeg_data(x_test, y_test)

In [16]:
Z=1.96
acc_list_mean = []
for shuffle_id in range(10):
    x_train_shuffled, y_train_shuffled = shuffle(x_train_new, y_train_new)
    learn = train_model_lstm(x_train_shuffled, y_train_shuffled, x_valid, y_valid)
    preds_labels, target_labels = run_model(x_test, y_test, learn)
    acc_list_mean.append(accuracy_score(target_labels, preds_labels))

mean_accuracy = np.mean(acc_list_mean)
std_dev = np.std(acc_list_mean, ddof=1)  # ddof=1 for sample standard deviation
std_error = std_dev / np.sqrt(len(acc_list_mean))
interval_range = Z*std_error  
print(mean_accuracy)
print(interval_range)

epoch,train_loss,valid_loss,accuracy,time
0,0.082436,0.141641,0.94587,00:04
1,0.060788,0.151624,0.946957,00:03
2,0.059132,0.151944,0.947391,00:03
3,0.042395,0.125323,0.955652,00:04
4,0.042664,0.113318,0.959348,00:04
5,0.035564,0.104305,0.96,00:04
6,0.026547,0.134594,0.959348,00:04
7,0.021921,0.131767,0.961087,00:03
8,0.018138,0.117465,0.965217,00:03
9,0.011056,0.141601,0.959348,00:04


epoch,train_loss,valid_loss,accuracy,time
0,0.074172,0.121534,0.951739,00:03
1,0.064747,0.231415,0.939565,00:03
2,0.057809,0.134863,0.952391,00:04
3,0.051069,0.137875,0.948696,00:04
4,0.052437,0.115605,0.960435,00:04
5,0.039286,0.118894,0.955217,00:04
6,0.024692,0.132623,0.956522,00:03
7,0.020285,0.116065,0.962826,00:03
8,0.016508,0.111087,0.964783,00:03
9,0.016423,0.143467,0.962174,00:03


epoch,train_loss,valid_loss,accuracy,time
0,0.075401,0.141776,0.942826,00:03
1,0.063225,0.141269,0.94913,00:04
2,0.057059,0.150023,0.947174,00:03
3,0.05808,0.110714,0.954348,00:04
4,0.045379,0.127002,0.954783,00:03
5,0.036528,0.102046,0.962174,00:03
6,0.027089,0.10517,0.960652,00:04
7,0.014628,0.126578,0.963261,00:04
8,0.013997,0.13144,0.964783,00:03
9,0.009942,0.125245,0.964565,00:03


epoch,train_loss,valid_loss,accuracy,time
0,0.079259,0.181098,0.932609,00:04
1,0.068585,0.117428,0.951739,00:03
2,0.085614,0.176796,0.951304,00:03
3,0.055495,0.121041,0.956087,00:04
4,0.037092,0.123846,0.958478,00:03
5,0.03322,0.118712,0.957609,00:04
6,0.02483,0.105723,0.963913,00:04
7,0.020849,0.122889,0.963478,00:03
8,0.012768,0.112805,0.965,00:04
9,0.013738,0.122795,0.963261,00:03


epoch,train_loss,valid_loss,accuracy,time
0,0.075494,0.128606,0.949565,00:04
1,0.070463,0.142044,0.948043,00:04
2,0.055036,0.142238,0.950435,00:04
3,0.046857,0.106075,0.960435,00:03
4,0.043206,0.106995,0.959565,00:04
5,0.0335,0.126556,0.952609,00:04
6,0.021508,0.099559,0.964783,00:03
7,0.017336,0.127196,0.962174,00:03
8,0.012793,0.140466,0.956522,00:03
9,0.009628,0.122605,0.963043,00:04


epoch,train_loss,valid_loss,accuracy,time
0,0.076632,0.129244,0.948043,00:04
1,0.065617,0.157983,0.942391,00:03
2,0.057884,0.110739,0.959783,00:04
3,0.051924,0.160679,0.946957,00:04
4,0.046575,0.1098,0.958913,00:04
5,0.034293,0.091801,0.966956,00:03
6,0.031877,0.11587,0.958261,00:04
7,0.017364,0.113598,0.964783,00:03
8,0.012913,0.117439,0.964348,00:03
9,0.016417,0.124704,0.963913,00:04


epoch,train_loss,valid_loss,accuracy,time
0,0.070519,0.142205,0.947174,00:03
1,0.058937,0.147166,0.94413,00:04
2,0.0589,0.115683,0.953043,00:04
3,0.044057,0.152915,0.950217,00:04
4,0.038917,0.125242,0.951739,00:04
5,0.032124,0.11426,0.956087,00:04
6,0.023924,0.097941,0.959783,00:04
7,0.017273,0.129111,0.960217,00:03
8,0.011626,0.122117,0.959783,00:04
9,0.00964,0.11443,0.961957,00:04


epoch,train_loss,valid_loss,accuracy,time
0,0.071597,0.124869,0.949348,00:04
1,0.068499,0.163164,0.947609,00:04
2,0.053081,0.127133,0.956304,00:04
3,0.049874,0.124537,0.953913,00:03
4,0.042455,0.142329,0.953696,00:04
5,0.03079,0.130504,0.95587,00:04
6,0.025754,0.141339,0.955435,00:03
7,0.018338,0.123309,0.963261,00:03
8,0.014932,0.119323,0.966522,00:03
9,0.013246,0.117543,0.966522,00:03


epoch,train_loss,valid_loss,accuracy,time
0,0.074861,0.141952,0.94587,00:04
1,0.058864,0.122916,0.953043,00:04
2,0.057619,0.145109,0.947609,00:04
3,0.049641,0.112273,0.952174,00:04
4,0.042912,0.112496,0.955217,00:04
5,0.035776,0.096258,0.962174,00:03
6,0.025166,0.112529,0.960435,00:03
7,0.016151,0.122837,0.960435,00:03
8,0.012872,0.135298,0.963043,00:03
9,0.010802,0.111794,0.966087,00:03


epoch,train_loss,valid_loss,accuracy,time
0,0.0791,0.119465,0.951304,00:03
1,0.075128,0.139575,0.955,00:04
2,0.061338,0.134622,0.950652,00:03
3,0.048908,0.105505,0.958913,00:03
4,0.039424,0.109557,0.959348,00:03
5,0.03671,0.172535,0.948913,00:04
6,0.029664,0.101575,0.963913,00:03
7,0.015547,0.146199,0.96,00:04
8,0.014183,0.139979,0.961304,00:04
9,0.009354,0.142942,0.962174,00:04


0.9729565217391304
0.0009489407087953801


In [17]:
Z=1.96
acc_list_mean = []
for shuffle_id in range(10):
    x_train_shuffled, y_train_shuffled = shuffle(x_train_new, y_train_new)
    learn = train_model_cnn(x_train_shuffled, y_train_shuffled, x_valid, y_valid)
    preds_labels, target_labels = run_model(x_test, y_test, learn)
    acc_list_mean.append(accuracy_score(target_labels, preds_labels))

mean_accuracy = np.mean(acc_list_mean)
std_dev = np.std(acc_list_mean, ddof=1)  # ddof=1 for sample standard deviation
std_error = std_dev / np.sqrt(len(acc_list_mean))
interval_range = Z*std_error  
print(mean_accuracy)
print(interval_range)

epoch,train_loss,valid_loss,accuracy,time
0,0.081189,0.153986,0.951522,00:05
1,0.077566,0.165366,0.94587,00:05
2,0.068099,0.109186,0.96087,00:05
3,0.05093,0.16636,0.95,00:05
4,0.049713,0.132227,0.957391,00:05
5,0.040202,0.105549,0.959348,00:05
6,0.031545,0.112181,0.962609,00:05
7,0.020982,0.103567,0.964348,00:05
8,0.015804,0.114522,0.965217,00:05
9,0.014913,0.139739,0.958261,00:05


epoch,train_loss,valid_loss,accuracy,time
0,0.080556,0.172872,0.933043,00:06
1,0.078483,0.141439,0.946304,00:05
2,0.071702,0.129989,0.953913,00:06
3,0.058567,0.153782,0.941739,00:05
4,0.04761,0.107282,0.958478,00:05
5,0.036312,0.147552,0.958478,00:05
6,0.025217,0.182249,0.953696,00:05
7,0.019465,0.13549,0.962391,00:05
8,0.017003,0.142732,0.96,00:05
9,0.012534,0.158932,0.955,00:06


epoch,train_loss,valid_loss,accuracy,time
0,0.073201,0.176859,0.940217,00:06
1,0.07704,0.145647,0.945652,00:05
2,0.073079,0.160492,0.94413,00:06
3,0.050932,0.144776,0.95,00:06
4,0.04325,0.137676,0.95587,00:06
5,0.049827,0.156784,0.949783,00:05
6,0.027697,0.109456,0.964565,00:05
7,0.022214,0.129746,0.957391,00:05
8,0.016639,0.126422,0.962174,00:06
9,0.013718,0.12638,0.962174,00:05


epoch,train_loss,valid_loss,accuracy,time
0,0.087579,0.143939,0.948913,00:05
1,0.076441,0.126692,0.951304,00:05
2,0.073307,0.155739,0.942391,00:05
3,0.063449,0.117711,0.956522,00:05
4,0.046661,0.117316,0.958043,00:06
5,0.036743,0.126783,0.955,00:05
6,0.032861,0.150357,0.957174,00:05
7,0.022686,0.148118,0.959565,00:06
8,0.015399,0.136501,0.961304,00:05
9,0.014455,0.153044,0.95913,00:06


epoch,train_loss,valid_loss,accuracy,time
0,0.076549,0.189411,0.94913,00:05
1,0.078633,0.153144,0.948043,00:05
2,0.062342,0.149284,0.952174,00:05
3,0.064435,0.133203,0.959783,00:05
4,0.048342,0.190596,0.945435,00:05
5,0.041374,0.12598,0.960652,00:05
6,0.034127,0.127967,0.959348,00:05
7,0.021265,0.109176,0.963043,00:06
8,0.015854,0.130329,0.961957,00:05
9,0.014805,0.132778,0.961304,00:05


epoch,train_loss,valid_loss,accuracy,time
0,0.077667,0.179369,0.949565,00:05
1,0.079877,0.131378,0.950217,00:05
2,0.062853,0.180585,0.945217,00:05
3,0.059556,0.134236,0.945435,00:05
4,0.043507,0.113048,0.957391,00:05
5,0.035722,0.1072,0.960217,00:05
6,0.025718,0.11651,0.959565,00:05
7,0.02362,0.11535,0.958913,00:06
8,0.020353,0.143196,0.956522,00:05
9,0.019825,0.115821,0.961304,00:05


epoch,train_loss,valid_loss,accuracy,time
0,0.079287,0.159166,0.939565,00:06
1,0.076438,0.16115,0.942609,00:05
2,0.058994,0.118136,0.958478,00:05
3,0.053827,0.155376,0.951304,00:05
4,0.05004,0.142486,0.948696,00:05
5,0.032326,0.138275,0.952826,00:06
6,0.030993,0.092732,0.96413,00:06
7,0.028875,0.136167,0.954783,00:05
8,0.01596,0.123384,0.965217,00:05
9,0.014905,0.123803,0.96413,00:06


epoch,train_loss,valid_loss,accuracy,time
0,0.082042,0.12853,0.953478,00:06
1,0.0832,0.151259,0.946304,00:05
2,0.064606,0.148126,0.953043,00:05
3,0.053384,0.155813,0.95087,00:06
4,0.041464,0.136724,0.953261,00:05
5,0.04203,0.104042,0.95413,00:06
6,0.031205,0.126771,0.955652,00:05
7,0.025575,0.143505,0.956087,00:05
8,0.021078,0.122075,0.959783,00:06
9,0.013739,0.13623,0.953696,00:05


epoch,train_loss,valid_loss,accuracy,time
0,0.085372,0.142966,0.949565,00:05
1,0.087815,0.149377,0.945652,00:06
2,0.075052,0.156391,0.951304,00:05
3,0.056704,0.137228,0.95087,00:05
4,0.047565,0.143549,0.951739,00:05
5,0.034466,0.138287,0.955435,00:05
6,0.035213,0.136038,0.956957,00:05
7,0.020761,0.143521,0.961087,00:05
8,0.018404,0.134496,0.95913,00:05
9,0.013392,0.133708,0.959565,00:05


epoch,train_loss,valid_loss,accuracy,time
0,0.075975,0.150988,0.95,00:06
1,0.083785,0.164302,0.943478,00:05
2,0.0647,0.138143,0.951739,00:05
3,0.057915,0.11892,0.94913,00:06
4,0.047062,0.113556,0.958913,00:06
5,0.034779,0.096814,0.962609,00:06
6,0.027613,0.120742,0.958043,00:05
7,0.021065,0.108365,0.958696,00:05
8,0.017899,0.130022,0.957826,00:05
9,0.01651,0.126459,0.958696,00:05


0.9700434782608696
0.0013471058072291923


In [18]:
Z=1.96
acc_list_mean = []
for shuffle_id in range(10):
    x_train_shuffled, y_train_shuffled = shuffle(x_train_new, y_train_new)
    learn = train_model_tst(x_train_shuffled, y_train_shuffled, x_valid, y_valid)
    preds_labels, target_labels = run_model(x_test, y_test, learn)
    acc_list_mean.append(accuracy_score(target_labels, preds_labels))

mean_accuracy = np.mean(acc_list_mean)
std_dev = np.std(acc_list_mean, ddof=1)  # ddof=1 for sample standard deviation
std_error = std_dev / np.sqrt(len(acc_list_mean))
interval_range = Z*std_error  
print(mean_accuracy)
print(interval_range)

epoch,train_loss,valid_loss,accuracy,time
0,0.131377,0.260718,0.923261,00:05
1,0.113559,0.219387,0.923696,00:05
2,0.109277,0.190745,0.931304,00:05
3,0.083169,0.166284,0.936739,00:05
4,0.080173,0.210604,0.936522,00:05
5,0.078791,0.155411,0.946739,00:05
6,0.056126,0.160781,0.949783,00:05
7,0.053344,0.156758,0.948043,00:05
8,0.049937,0.141439,0.951522,00:05
9,0.046235,0.141324,0.952174,00:05


epoch,train_loss,valid_loss,accuracy,time
0,0.145294,0.334448,0.921957,00:05
1,0.121851,0.265421,0.930435,00:05
2,0.105485,0.19036,0.927826,00:05
3,0.097467,0.156533,0.938913,00:05
4,0.085938,0.17434,0.931957,00:05
5,0.077912,0.173921,0.94413,00:05
6,0.065825,0.152627,0.947826,00:05
7,0.059341,0.168232,0.941739,00:05
8,0.059506,0.160714,0.944783,00:05
9,0.057141,0.147139,0.951304,00:05


epoch,train_loss,valid_loss,accuracy,time
0,0.138375,0.218857,0.93,00:05
1,0.119464,0.307286,0.894348,00:05
2,0.108638,0.299921,0.913696,00:05
3,0.104408,0.212116,0.931304,00:05
4,0.090335,0.160695,0.942609,00:05
5,0.078037,0.159177,0.946957,00:05
6,0.075118,0.167118,0.937391,00:05
7,0.059752,0.169826,0.948913,00:05
8,0.062863,0.169757,0.95087,00:05
9,0.057014,0.168635,0.95087,00:05


epoch,train_loss,valid_loss,accuracy,time
0,0.131544,0.288483,0.90587,00:05
1,0.1216,0.237265,0.929348,00:06
2,0.097894,0.188797,0.943913,00:05
3,0.090392,0.165851,0.944348,00:05
4,0.086111,0.15766,0.946522,00:05
5,0.074656,0.19099,0.923913,00:05
6,0.063063,0.179457,0.936522,00:05
7,0.069244,0.160557,0.94087,00:05
8,0.055145,0.144846,0.947391,00:05
9,0.057192,0.155306,0.945217,00:05


epoch,train_loss,valid_loss,accuracy,time
0,0.140829,0.45484,0.925435,00:04
1,0.116086,0.198018,0.941522,00:05
2,0.108032,0.196682,0.93087,00:05
3,0.100821,0.224954,0.911739,00:05
4,0.084851,0.197777,0.935652,00:05
5,0.085736,0.160548,0.938696,00:05
6,0.082046,0.176575,0.928696,00:05
7,0.064763,0.167375,0.939783,00:05
8,0.057242,0.169458,0.94,00:05
9,0.054516,0.164483,0.942391,00:05


epoch,train_loss,valid_loss,accuracy,time
0,0.124342,0.316493,0.926087,00:05
1,0.12386,0.195582,0.925,00:05
2,0.097514,0.18104,0.934348,00:05
3,0.091928,0.161261,0.940217,00:05
4,0.07883,0.168808,0.935217,00:05
5,0.071767,0.149752,0.938261,00:05
6,0.06423,0.143463,0.946087,00:05
7,0.059271,0.135895,0.943696,00:05
8,0.051029,0.131222,0.951304,00:05
9,0.046912,0.13124,0.953043,00:05


epoch,train_loss,valid_loss,accuracy,time
0,0.15101,0.22152,0.92587,00:05
1,0.118123,0.229464,0.935,00:05
2,0.11143,0.193134,0.933043,00:05
3,0.099544,0.223696,0.913913,00:05
4,0.089981,0.172173,0.946957,00:05
5,0.073156,0.197111,0.940435,00:05
6,0.068311,0.160055,0.946957,00:05
7,0.067792,0.159314,0.944348,00:05
8,0.061166,0.150618,0.947826,00:05
9,0.055102,0.159403,0.948696,00:05


epoch,train_loss,valid_loss,accuracy,time
0,0.141348,0.257093,0.931087,00:05
1,0.115766,0.226352,0.921304,00:05
2,0.103641,0.194387,0.933696,00:05
3,0.091909,0.174178,0.930435,00:05
4,0.077865,0.146227,0.94913,00:05
5,0.086833,0.173319,0.934783,00:05
6,0.068605,0.192768,0.921304,00:05
7,0.061518,0.148365,0.946957,00:05
8,0.053044,0.150779,0.948696,00:05
9,0.049303,0.145941,0.948696,00:05


epoch,train_loss,valid_loss,accuracy,time
0,0.142611,0.287419,0.913913,00:05
1,0.090487,0.210519,0.934783,00:05
2,0.097648,0.204718,0.929565,00:05
3,0.090583,0.184306,0.939565,00:05
4,0.089369,0.18636,0.937609,00:05
5,0.079097,0.15604,0.941304,00:05
6,0.067472,0.158792,0.943043,00:05
7,0.058967,0.154784,0.941304,00:05
8,0.053346,0.137181,0.949348,00:05
9,0.048796,0.140142,0.950435,00:05


epoch,train_loss,valid_loss,accuracy,time
0,0.147965,0.239167,0.92087,00:05
1,0.117286,0.232904,0.922174,00:05
2,0.105599,0.189082,0.93913,00:05
3,0.098343,0.180069,0.931957,00:05
4,0.077831,0.171961,0.943043,00:05
5,0.073354,0.231825,0.934348,00:05
6,0.0718,0.195263,0.93,00:05
7,0.066065,0.159475,0.947174,00:05
8,0.052763,0.167236,0.947826,00:06
9,0.057194,0.166866,0.946957,00:05


0.9603043478260871
0.0031299270119587515


In [None]:
learn_cnn = train_model_cnn(x_train_new, y_train_new, x_valid, y_valid)
preds_labels_cnn, target_labels_cnn = run_model(x_test, y_test, learn_cnn)
acc_cnn = accuracy_score(target_labels_cnn, preds_labels_cnn)

learn_tst = train_model_tst(x_train_new, y_train_new, x_valid, y_valid)
preds_labels_tst, target_labels_tst = run_model(x_test, y_test, learn_tst)
acc_tst = accuracy_score(target_labels_tst, preds_labels_tst)



In [None]:
print('LSTM_FCN: ', acc_lstm)
print('InceptionTime: ', acc_cnn)
print('TST: ', acc_tst)