# Imports

In [None]:
!pip install mne
!pip install pyriemann
!wget https://raw.githubusercontent.com/vlawhern/arl-eegmodels/master/examples/EEGNet-8-2-weights.h5
!wget https://raw.githubusercontent.com/vlawhern/arl-eegmodels/master/EEGModels.py
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/checkpoint_clean_unweighted.h5
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/checkpoint_clean_weighted.h5
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/checkpoint_unweighted.h5
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/checkpoint_weighted.h5

In [None]:
import scipy
import random
import pandas as pd

In [None]:
import numpy as np
from sklearn.metrics import accuracy_score, r2_score, f1_score

# mne imports
import mne
from mne import io
from mne.datasets import sample

# EEGNet-specific imports
from EEGModels import EEGNet
from tensorflow.keras import utils as np_utils
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import backend as K

# PyRiemann imports
from pyriemann.estimation import XdawnCovariances
from pyriemann.tangentspace import TangentSpace
# from pyriemann.utils.viz import plot_confusion_matrix
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression

# tools for plotting confusion matrices
from matplotlib import pyplot as plt

In [None]:
# class balance of Y

def class_balance(Y):
  """Print how often each label occurs in Y, as one "label: count" line per label.

  Labels are printed in the order in which they first appear in Y.
  """
  counts = {}
  for label in Y:
    counts[label] = counts.get(label, 0) + 1

  for label, count in counts.items():
    print(str(label) + ": " + str(count))

def onehot_balance(Y):
  """Print the column-wise totals of a one-hot label matrix Y (per-class counts).

  Rows are added one after another onto a float64 zero vector whose length comes
  from the first row, so Y must be non-empty.
  """
  per_class = sum(Y, np.zeros(len(Y[0])))
  print(per_class)

In [None]:
# Per-epoch input layout used below: channels x timepoints x kernels.
channels, timepoints, kernels = 2, 21000, 1

# Cleaned Dataset

In [None]:
# Colab only: mount Google Drive so the cleaned arrays under MyDrive/BCI_project can be loaded.
from google.colab import drive

drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

Mounted at /content/drive


In [None]:
# Load the cleaned epochs (X.npy) and their labels (y.npy) from Drive.
rawX, rawY = (np.load(npy_path) for npy_path in ('/content/drive/MyDrive/BCI_project/X.npy',
                                                  '/content/drive/MyDrive/BCI_project/y.npy'))

In [None]:
# Sizes of the three nesting levels of rawX: epochs, channels per epoch, samples per channel.
for level_size in (len(rawX), len(rawX[0]), len(rawX[0][0])):
  print(level_size)

878
32
21000


In [None]:
# rawX dimensions: [epochs(878)] [channels(32)] [timepoints(21000)] (see the sizes printed above)
# X dimensions:    [epochs(878)] [channel(2)] [time(21000)] [kernels(1)]
# Keep channels 30 and 31 and add a trailing singleton "kernels" axis. One vectorized
# slice replaces the per-sample Python loop, which appended ~37 million values one by one.
selected_channels = [30, 31]
rawX_arr = np.asarray(rawX)
assert rawX_arr.shape[2] >= timepoints, (
    f"expected at least {timepoints} timepoints per channel, got {rawX_arr.shape[2]}")
X = rawX_arr[:, selected_channels, :timepoints, np.newaxis]

# One label per epoch: the first entry of each row of rawY.
Y = [row[0] for row in rawY]

In [None]:
# Sizes of the four nesting levels of X: epochs, channels, timepoints, kernels.
for level_size in (len(X), len(X[0]), len(X[0][0]), len(X[0][0][0])):
  print(level_size)

878
2
21000
1


In [None]:
# Collapse the nine raw event codes (1-9) into four classes (0-3).
Ymap = {
    1: 0, 2: 0, 3: 1,
    4: 0, 5: 0, 6: 1,
    7: 2, 8: 2, 9: 3
}

# Build the mapped list in one step so that a bad label cannot leave Y half-converted.
# If Y is not made up entirely of raw codes, this cell has presumably already run (class
# id 0 is not a raw code), and mapping again would corrupt Y, so leave it as it is.
if set(Y) <= set(Ymap):
  Y = [Ymap[label] for label in Y]
else:
  assert set(Y) <= set(Ymap.values()), (
      f"unexpected labels in Y: {sorted(set(Y) - set(Ymap) - set(Ymap.values()))}")

In [None]:
# Per-class counts of the mapped labels (printed in first-seen order).
class_balance(Y)

0: 408
1: 224
3: 142
2: 104


# Raw (only epoched) dataset

In [None]:
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S1R1.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S1R2.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S1R3.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S8R1.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S8R2.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S8R3.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S8R4.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S8R5.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S8R6.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S11R1.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S11R2.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S11R3.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S11R4.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S11R5.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S11R6.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S12R1.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S12R2.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S12R3.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S12R4.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S12R5.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S12R6.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S13R1.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S13R2.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S13R3.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S13R4.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S13R5.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S13R6.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S15R1.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S15R2.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S15R3.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S15R4.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S15R5.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S15R6.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S17R1.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S17R2.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S17R3.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S17R4.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S17R5.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S19R1.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S19R2.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S19R3.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S19R4.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S19R5.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/S19R6.mat
!wget https://raw.githubusercontent.com/NathanJackyLee/BCI/main/labels.csv

In [None]:
# One label per row of labels.csv, as a plain Python list.
rawY = pd.read_csv("/content/labels.csv")['label'].tolist()

In [None]:
# Build X / Y from the per-run .mat files downloaded above.
# matdat [channel(2)] [time(21000)] [epoch(n)]  ->  X [epoch] [channel(2)] [time(21000)] [kernels(1)]
# Labels come from labels.csv (rawY) with a fixed stride of EPOCHS_PER_FILE rows per file:
# file k reads rows k*EPOCHS_PER_FILE, k*EPOCHS_PER_FILE + 1, ...
# NOTE(review): 44 files x 20 = 880, but 878 epochs end up in X, so at least one file is
# short. Check that labels.csv is laid out with the same stride (see the warning below).
import warnings

import scipy.io

sub = [1,8,11,12,13,15,17,19]
runs = [3,6, 6, 6, 6, 6, 5, 6]
EPOCHS_PER_FILE = 20

X = []
Y = []
run_files = [(s, r) for s, n_runs in zip(sub, runs) for r in range(1, n_runs + 1)]
for filenum, (subject, run) in enumerate(run_files):
  fname = "/content/S" + str(subject) + "R" + str(run) + ".mat"
  # atleast_3d keeps an epoch axis even if a single-epoch file was stored as a 2-D array.
  matdat = np.atleast_3d(scipy.io.loadmat(fname)['dat'])
  print("on file S" + str(subject) + "R" + str(run) + ".mat")

  n_channels, n_times, n_epochs = matdat.shape
  if n_channels < 2 or n_times < 21000:
    raise ValueError(f"{fname}: expected at least 2 channels x 21000 samples, got {matdat.shape[:2]}")
  if n_epochs != EPOCHS_PER_FILE:
    # With the fixed stride, if labels.csv has one row per actual epoch, every file
    # after this one reads shifted labels.
    warnings.warn(f"{fname} has {n_epochs} epochs (expected {EPOCHS_PER_FILE}); "
                  "labels may be misaligned")

  # (channel, time, epoch) -> (epoch, channel, time, 1) for the first 2 channels / 21000 samples.
  X.extend(np.moveaxis(matdat[:2, :21000, :], -1, 0)[..., np.newaxis])
  first_label = filenum * EPOCHS_PER_FILE
  Y.extend(rawY[first_label + e] for e in range(n_epochs))

In [None]:
# Collapse the nine raw event codes (1-9) into four classes (0-3).
Ymap = {
    1: 0, 2: 0, 3: 1,
    4: 0, 5: 0, 6: 1,
    7: 2, 8: 2, 9: 3
}

# Build the mapped list in one step so that a bad label cannot leave Y half-converted.
# If Y is not made up entirely of raw codes, this cell has presumably already run (class
# id 0 is not a raw code), and mapping again would corrupt Y, so leave it as it is.
if set(Y) <= set(Ymap):
  Y = [Ymap[label] for label in Y]
else:
  assert set(Y) <= set(Ymap.values()), (
      f"unexpected labels in Y: {sorted(set(Y) - set(Ymap) - set(Ymap.values()))}")

In [None]:
# Per-class counts of the mapped labels (printed in first-seen order).
class_balance(Y)

0: 397
1: 214
3: 154
2: 113


In [None]:
# Length of the innermost (kernels) axis of one sample; expected to be 1.
print(len(X[0][0][0]))

1


In [None]:
# Number of epochs and number of labels; these must be equal.
print(len(X),len(Y))

878 878


# Preparing data for the model

In [None]:
# Shuffle the epoch indices with a fixed seed and split them 50% / 25% / 25% into
# train / validation / test. The seed makes the split, and therefore which epochs
# count as held-out, reproducible across runs.
SPLIT_SEED = 42
epochs = len(X)
gacha = list(range(epochs))
random.Random(SPLIT_SEED).shuffle(gacha)

q2 = int(epochs * 0.5)
q3 = int(epochs * 0.75)

gacha_train = gacha[:q2]
gacha_val = gacha[q2:q3]
gacha_test = gacha[q3:]

# Fix the one-hot width from the full label set. Otherwise to_categorical infers it from
# each split's own maximum label, and a split missing the highest class would get fewer
# columns than the 4-way model output.
num_classes = int(max(Y)) + 1

def _take_split(indices):
  """Gather the epochs at `indices` as float32 inputs and float32 one-hot targets."""
  X_part = np.array([X[i] for i in indices], dtype=np.float32)
  Y_part = np.array(np_utils.to_categorical([Y[i] for i in indices], num_classes=num_classes),
                    dtype=np.float32)
  return X_part, Y_part

X_train, Y_train = _take_split(gacha_train)
X_val, Y_val = _take_split(gacha_val)
X_test, Y_test = _take_split(gacha_test)

In [None]:
# Per-class label counts in each split, in the order train, validation, test.
for split_onehot in (Y_train, Y_val, Y_test):
  onehot_balance(split_onehot)

[203. 113.  52.  71.]
[108.  52.  25.  34.]
[97. 59. 27. 37.]


In [None]:
# Number of epochs in the test split.
print(len(Y_test))

220


# Model

In [None]:
# Configure the EEGNet-8,2,16 model with a kernel length of 32 samples (other model
# configurations may do better, but this is a good starting point) for 4 classes,
# using the channels x timepoints input size defined at the top of the notebook.
eegnet_kwargs = dict(
    nb_classes=4,
    Chans=channels,
    Samples=timepoints,
    dropoutRate=0.5,
    kernLength=32,
    F1=8,
    D=2,
    F2=16,
    dropoutType='Dropout',
)
model = EEGNet(**eegnet_kwargs)

In [None]:
# compile the model and set the optimizers
model.compile(loss='categorical_crossentropy', optimizer='adam',
              metrics = ['accuracy'])

In [None]:
# count number of parameters in the model
numParams    = model.count_params()
print(numParams)

42948


In [None]:
# Write the model to disk each time the monitored metric (Keras default: val_loss) improves.
# NOTE(review): this is the same path the setup cell downloads the pretrained
# "clean weighted" checkpoint to, so training here overwrites that downloaded file.
checkpointer = ModelCheckpoint(
    filepath='/content/checkpoint_clean_weighted.h5',
    save_best_only=True,
    verbose=1,
)

In [None]:
# Class weights for the imbalanced labels (used when USE_CLASS_WEIGHTS is True).
# NOTE(review): these look like rounded inverse class frequencies relative to class 0;
# recompute them if the dataset or split changes.
class_weights = {
    0: 1,
    1: 1.9,
    2: 3.6,
    3: 2.6
}
# Set to False for the unweighted ("standard") run; class_weight=None is Keras' default.
USE_CLASS_WEIGHTS = True

fittedModel = model.fit(X_train, Y_train, batch_size = 16, epochs = 300,
                        verbose = 2, validation_data=(X_val, Y_val),
                        callbacks=[checkpointer],
                        class_weight = class_weights if USE_CLASS_WEIGHTS else None)

# The checkpointer (save_best_only=True) writes the best epoch to disk, but the in-memory
# model finishes with the final epoch's weights. Reload the best ones before evaluating.
model.load_weights(checkpointer.filepath)

In [None]:
### this is for raw weighted
# WEIGHTS_PATH = '/content/checkpoint_weighted.h5'
# model.load_weights(WEIGHTS_PATH)

### this is for raw unweighted
# WEIGHTS_PATH = '/content/checkpoint_unweighted.h5'
# model.load_weights(WEIGHTS_PATH)

### this is for clean unweighted
# WEIGHTS_PATH = '/content/checkpoint_clean_unweighted.h5'
# model.load_weights(WEIGHTS_PATH)

### this is for clean weighted
# WEIGHTS_PATH = '/content/checkpoint_clean_weighted.h5'
# model.load_weights(WEIGHTS_PATH)

In [None]:
# Every epoch in X as one float32 array, for prediction over the whole dataset.
XX = np.asarray(X, dtype=np.float32)

In [None]:
# One-hot labels for every epoch; the width is the largest label + 1, as to_categorical infers.
YY = np_utils.to_categorical(Y, num_classes=int(max(Y)) + 1)

In [None]:
# Class probabilities for every epoch in XX, then the most likely class per epoch.
probs = model.predict(XX)
preds = np.argmax(probs, axis=-1)



In [None]:
# Number of predictions (one per epoch in XX).
len(preds)

878

In [None]:
# Integer class id per epoch, recovered from the one-hot rows of YY.
actual = np.argmax(YY, axis=-1)

In [None]:
# Metrics over every epoch in X.
# NOTE: XX/YY cover the whole dataset, including the epochs the model was trained and
# validated on, so these numbers overstate how well it generalizes.
acc = accuracy_score(actual,preds)
f1 = f1_score(actual,preds,average='macro')
print("Accuracy (all epochs, incl. train/val): ",acc)
print("F1-score (all epochs, incl. train/val): ",f1)

# Held-out metrics: only the epochs in gacha_test (positions in X, hence in preds/actual).
# These are meaningful only if the model was trained in this session on that split; a
# checkpoint loaded from disk may have been trained on a different random split.
test_actual = actual[gacha_test]
test_preds = preds[gacha_test]
test_acc = accuracy_score(test_actual, test_preds)
test_f1 = f1_score(test_actual, test_preds, average='macro')
print("Test accuracy: ", test_acc)
print("Test F1-score: ", test_f1)

Accuracy:  0.6343963553530751
F1-score:  0.6073879802748562


In [None]:
# 4x4 zero matrix; each row is a separate list so the rows can be incremented independently.
confmat = [[0] * 4 for _ in range(4)]

In [None]:
# Display the freshly initialised 4x4 zero matrix.
confmat

[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]

In [None]:
# Confusion matrix: rows = predicted class, columns = actual class (the transpose of
# sklearn.metrics.confusion_matrix, whose rows are the true labels).
# It is re-initialised here so that re-running this cell does not add the counts twice.
confmat = [[0] * 4 for _ in range(4)]
for predicted, true_label in zip(preds, actual):
  confmat[predicted][true_label] += 1

In [None]:
# One line per predicted class (row), with counts per actual class (columns).
for predicted_class in range(len(confmat)):
  print(confmat[predicted_class])

[295, 72, 34, 45]
[35, 113, 11, 12]
[36, 15, 58, 10]
[31, 14, 10, 87]
