In [127]:
from scipy import io
import numpy as np
import os
import tensorflow as tf
import tensorflow.keras.backend as K
import math
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt

from TA_CSPNN import *

In [128]:
# for file in os.listdir('matlab_data/'):
#     if not file == "A01T.mat":
#         continue
#     print(file)
    
#     data = io.matlab.loadmat("matlab_data/" + file)['data']
#     labels = io.matlab.loadmat('true_labels/' + file)['classlabel']
#     break
# print(data.shape)
# print(labels.shape)

In [229]:
data = io.matlab.loadmat("baseline/A01T.mat")['data']
labels = io.matlab.loadmat('true_labels/A01T.mat')['classlabel']

In [230]:
num_samples = data.shape[2]

train_size = int(num_samples * .9) # 90/10 train validation split
val_size = num_samples - train_size
print(f"num train samples: {train_size}")
print(f"num val samples: {val_size}")

num train samples: 259
num val samples: 29


In [231]:
rand_idx = np.random.randint(0,288, num_samples)
train_idx = rand_idx[:train_size]
val_idx = rand_idx[train_size:train_size+val_size]

x_train = data[:,:,train_idx]
y_train = labels[train_idx]

x_val = data[:,:,val_idx]
y_val = labels[val_idx]
print(f"x_train shape: {x_train.shape}")
print(f"y_train shape: {y_train.shape}")
print(f"x_val shape: {x_val.shape}")
print(f"y_val shape: {y_val.shape}")

x_train shape: (22, 250, 259)
y_train shape: (259, 1)
x_val shape: (22, 250, 29)
y_val shape: (29, 1)


In [232]:
new_train = []
for i in range(x_train.shape[2]):
    new_train.append(x_train[:,:,i])
    
new_val = []
for i in range(x_val.shape[2]):
    new_val.append(x_val[:,:,i])

    
x_train = np.array(new_train)[:,None,:,:]
x_val = np.array(new_val)[:,None,:,:]

print(f"x_train shape: {x_train.shape}")
print(f"x_val shape: {x_val.shape}")

x_train shape: (259, 1, 22, 250)
x_val shape: (29, 1, 22, 250)


In [233]:
y_train = to_categorical(y_train - 1)
y_val = to_categorical(y_val - 1)

In [234]:
class model_config(object):
    def __init__(self):
        self.channels = 22
        self.timesamples = 250
        self.timeKernelLen = 64
        self.num_classes = 4
        self.Ft = 8
        self.Fs = 2

In [235]:
config = model_config()
model = TA_CSPNN(config.num_classes, Channels=config.channels, Timesamples=config.timesamples,
                timeKernelLen = config.timeKernelLen, Ft=config.Ft, Fs=config.Fs)
opt = Adam(learning_rate=0.01)
model.compile(loss='categorical_crossentropy', optimizer='Adam', metrics = ['accuracy'])

es = EarlyStopping(monitor='val_loss', mode='min', verbose=2, patience=50)

In [236]:
history = model.fit(x_train, y_train, 
                    epochs=500, 
                    validation_data=((x_val, y_val)), 
                    callbacks=[es],
                    verbose=2)

Train on 259 samples, validate on 29 samples
Epoch 1/500
259/259 - 1s - loss: 1.8003 - acc: 0.2548 - val_loss: 2.0824 - val_acc: 0.2069
Epoch 2/500
259/259 - 0s - loss: 1.6948 - acc: 0.2780 - val_loss: 1.8833 - val_acc: 0.2069
Epoch 3/500
259/259 - 0s - loss: 1.7082 - acc: 0.2432 - val_loss: 1.7060 - val_acc: 0.2069
Epoch 4/500
259/259 - 0s - loss: 1.5900 - acc: 0.2355 - val_loss: 1.5978 - val_acc: 0.2069
Epoch 5/500
259/259 - 0s - loss: 1.5226 - acc: 0.3012 - val_loss: 1.5434 - val_acc: 0.2069
Epoch 6/500
259/259 - 0s - loss: 1.4723 - acc: 0.3127 - val_loss: 1.5062 - val_acc: 0.2069
Epoch 7/500
259/259 - 0s - loss: 1.4961 - acc: 0.2896 - val_loss: 1.4647 - val_acc: 0.2069
Epoch 8/500
259/259 - 0s - loss: 1.3881 - acc: 0.3127 - val_loss: 1.4212 - val_acc: 0.2069
Epoch 9/500
259/259 - 0s - loss: 1.3718 - acc: 0.3822 - val_loss: 1.3932 - val_acc: 0.2759
Epoch 10/500
259/259 - 0s - loss: 1.3967 - acc: 0.3243 - val_loss: 1.3712 - val_acc: 0.2759
Epoch 11/500
259/259 - 0s - loss: 1.3484 - a

Epoch 90/500
259/259 - 0s - loss: 1.1779 - acc: 0.4672 - val_loss: 1.3041 - val_acc: 0.4138
Epoch 91/500
259/259 - 0s - loss: 1.1407 - acc: 0.5019 - val_loss: 1.3127 - val_acc: 0.3103
Epoch 92/500
259/259 - 0s - loss: 1.1303 - acc: 0.5174 - val_loss: 1.3146 - val_acc: 0.3103
Epoch 93/500
259/259 - 0s - loss: 1.1635 - acc: 0.4749 - val_loss: 1.3046 - val_acc: 0.3448
Epoch 94/500
259/259 - 0s - loss: 1.0875 - acc: 0.5637 - val_loss: 1.3744 - val_acc: 0.2759
Epoch 95/500
259/259 - 0s - loss: 1.1962 - acc: 0.4595 - val_loss: 1.3324 - val_acc: 0.3103
Epoch 96/500
259/259 - 0s - loss: 1.1129 - acc: 0.5251 - val_loss: 1.2768 - val_acc: 0.3448
Epoch 97/500
259/259 - 0s - loss: 1.1518 - acc: 0.4942 - val_loss: 1.2483 - val_acc: 0.4483
Epoch 98/500
259/259 - 0s - loss: 1.1534 - acc: 0.4942 - val_loss: 1.2596 - val_acc: 0.5862
Epoch 99/500
259/259 - 0s - loss: 1.1250 - acc: 0.5174 - val_loss: 1.2818 - val_acc: 0.3103
Epoch 100/500
259/259 - 0s - loss: 1.1034 - acc: 0.5251 - val_loss: 1.3058 - val

Epoch 179/500
259/259 - 0s - loss: 1.0596 - acc: 0.5753 - val_loss: 1.1928 - val_acc: 0.4828
Epoch 180/500
259/259 - 0s - loss: 1.0625 - acc: 0.5598 - val_loss: 1.1513 - val_acc: 0.5862
Epoch 181/500
259/259 - 0s - loss: 1.0584 - acc: 0.5483 - val_loss: 1.1700 - val_acc: 0.5517
Epoch 182/500
259/259 - 0s - loss: 1.1042 - acc: 0.5174 - val_loss: 1.2022 - val_acc: 0.4828
Epoch 183/500
259/259 - 0s - loss: 1.0761 - acc: 0.5405 - val_loss: 1.2430 - val_acc: 0.3448
Epoch 184/500
259/259 - 0s - loss: 1.1146 - acc: 0.5251 - val_loss: 1.2576 - val_acc: 0.4138
Epoch 185/500
259/259 - 0s - loss: 1.0390 - acc: 0.5483 - val_loss: 1.2558 - val_acc: 0.3793
Epoch 186/500
259/259 - 0s - loss: 1.0700 - acc: 0.5251 - val_loss: 1.3594 - val_acc: 0.2414
Epoch 187/500
259/259 - 0s - loss: 1.0510 - acc: 0.5598 - val_loss: 1.4845 - val_acc: 0.2414
Epoch 188/500
259/259 - 0s - loss: 1.0449 - acc: 0.5521 - val_loss: 1.3836 - val_acc: 0.2759
Epoch 189/500
259/259 - 0s - loss: 1.0800 - acc: 0.5290 - val_loss: 1.

259/259 - 0s - loss: 1.0383 - acc: 0.5792 - val_loss: 1.1498 - val_acc: 0.4828
Epoch 268/500
259/259 - 0s - loss: 1.0268 - acc: 0.5637 - val_loss: 1.1275 - val_acc: 0.5172
Epoch 269/500
259/259 - 0s - loss: 1.0511 - acc: 0.5598 - val_loss: 1.1429 - val_acc: 0.5172
Epoch 270/500
259/259 - 0s - loss: 1.0059 - acc: 0.6100 - val_loss: 1.1617 - val_acc: 0.5172
Epoch 271/500
259/259 - 0s - loss: 1.0178 - acc: 0.5830 - val_loss: 1.1800 - val_acc: 0.5172
Epoch 272/500
259/259 - 0s - loss: 1.0549 - acc: 0.5328 - val_loss: 1.2963 - val_acc: 0.3448
Epoch 273/500
259/259 - 0s - loss: 0.9826 - acc: 0.5830 - val_loss: 1.4251 - val_acc: 0.2414
Epoch 274/500
259/259 - 0s - loss: 0.9817 - acc: 0.6100 - val_loss: 1.3332 - val_acc: 0.3793
Epoch 275/500
259/259 - 0s - loss: 1.0287 - acc: 0.5830 - val_loss: 1.2295 - val_acc: 0.4828
Epoch 276/500
259/259 - 0s - loss: 1.0524 - acc: 0.5753 - val_loss: 1.2509 - val_acc: 0.4483
Epoch 277/500
259/259 - 0s - loss: 0.9533 - acc: 0.6100 - val_loss: 1.2805 - val_acc

Epoch 356/500
259/259 - 0s - loss: 0.9862 - acc: 0.6409 - val_loss: 1.0464 - val_acc: 0.6552
Epoch 357/500
259/259 - 0s - loss: 0.9983 - acc: 0.5907 - val_loss: 1.1548 - val_acc: 0.4828
Epoch 358/500
259/259 - 0s - loss: 0.9936 - acc: 0.5830 - val_loss: 1.1516 - val_acc: 0.4828
Epoch 359/500
259/259 - 0s - loss: 0.9828 - acc: 0.5830 - val_loss: 1.1260 - val_acc: 0.4138
Epoch 360/500
259/259 - 0s - loss: 0.9064 - acc: 0.6178 - val_loss: 1.0872 - val_acc: 0.5172
Epoch 361/500
259/259 - 0s - loss: 1.0331 - acc: 0.5946 - val_loss: 1.0811 - val_acc: 0.4828
Epoch 362/500
259/259 - 0s - loss: 1.0004 - acc: 0.5907 - val_loss: 1.1991 - val_acc: 0.4483
Epoch 363/500
259/259 - 0s - loss: 0.9856 - acc: 0.6062 - val_loss: 1.2096 - val_acc: 0.4138
Epoch 364/500
259/259 - 0s - loss: 0.9213 - acc: 0.6255 - val_loss: 1.1927 - val_acc: 0.4138
Epoch 365/500
259/259 - 0s - loss: 0.9616 - acc: 0.6332 - val_loss: 1.1707 - val_acc: 0.4828
Epoch 366/500
259/259 - 0s - loss: 1.0607 - acc: 0.5792 - val_loss: 1.

In [237]:
test_data = io.matlab.loadmat("baseline/A01E.mat")['data']
y_test = io.matlab.loadmat('true_labels/A01E.mat')['classlabel']

In [238]:
test_data.shape

(22, 250, 288)

In [239]:
x_test = []
for i in range(test_data.shape[2]):
    x_test.append(test_data[:,:,i])
    
x_test = np.array(x_test)[:,None,:,:]

In [240]:
x_test.shape

(288, 1, 22, 250)

In [241]:
y_pred = model.predict(x_test)

pred = np.argmax(y_pred, axis=1) + 1

acc = 0
for i in range(len(y_test)):
    if pred[i] == y_test[i,0]:
        acc += 1
        
acc / len(y_test)

0.2673611111111111