# Import Library

In [None]:
import mne
import numpy as np
import os
import os.path as op
import matplotlib.pyplot as plt
import nibabel as nib
from mne.datasets import sample
from mne.minimum_norm import make_inverse_operator, apply_inverse_epochs, apply_inverse
from mne.datasets import fetch_fsaverage
import scipy.io
from scipy.io import loadmat
from scipy.spatial import Delaunay
import PIL
from PIL import Image
import datetime
import tensorflow as tf
from tensorflow import keras
from keras.preprocessing.image import ImageDataGenerator
from keras import Sequential
from keras.layers import Conv2D, MaxPool2D, GlobalAveragePooling2D, Dense, Flatten, Concatenate, BatchNormalization, Dropout, Input
from keras.layers.merge import concatenate
from tensorflow.keras.optimizers import Adam
# Load the TensorBoard notebook extension
%load_ext tensorboard
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, cohen_kappa_score
from sklearn.model_selection import StratifiedKFold, KFold
import gc

%matplotlib inline
#%matplotlib qt

# Directory the notebook was started from.
DIRECTORY_PATH = os.getcwd()
# Root folder that receives the exported arrays. Raw string: "\M" is not a
# valid escape sequence, so the non-raw literal triggers an invalid-escape
# warning on recent Python versions. The resulting value is unchanged.
EXTERNAL_STORAGE_PATH = r"D:\Motor Imagery"
# Number of cross-validation folds.
n_splits = 5

# Preprocess Data

In [None]:
# (label found in the GDF file, name used after renaming), in recording order.
_CHANNEL_RENAMES = (
    ("EEG-Fz", "Fz"),
    ("EEG-0", "FC3"),
    ("EEG-1", "FC1"),
    ("EEG-2", "FCz"),
    ("EEG-3", "FC2"),
    ("EEG-4", "FC4"),
    ("EEG-5", "C5"),
    ("EEG-C3", "C3"),
    ("EEG-6", "C1"),
    ("EEG-Cz", "Cz"),
    ("EEG-7", "C2"),
    ("EEG-C4", "C4"),
    ("EEG-8", "C6"),
    ("EEG-9", "CP3"),
    ("EEG-10", "CP1"),
    ("EEG-11", "CPz"),
    ("EEG-12", "CP2"),
    ("EEG-13", "CP4"),
    ("EEG-14", "P1"),
    ("EEG-Pz", "Pz"),
    ("EEG-15", "P2"),
    ("EEG-16", "POz"),
    ("EOG-left", "EOG-left"),
    ("EOG-central", "EOG-central"),
    ("EOG-right", "EOG-right"),
)

# Maps the channel labels of the GDF recordings to the names used afterwards.
channels_mapping = dict(_CHANNEL_RENAMES)

# Channel type per renamed channel: the three EOG channels are "eog",
# everything else is "eeg".
channels_type_mapping = {
    renamed: "eog" if renamed.startswith("EOG") else "eeg"
    for _, renamed in _CHANNEL_RENAMES
}

In [None]:
# Brodmann atlas template (the path is specific to the author's machine).
img = nib.load("/Users/ivanl/Downloads/MRIcron_windows/MRIcron/Resources/templates/brodmann.nii.gz")
brodmann_data = img.get_fdata()

# Atlas labels kept as the "motor" region:
#   1, 2, 3 - primary somatosensory cortex (postcentral gyrus)
#   4       - primary motor cortex
#   5       - superior parietal lobule
#   6       - premotor cortex and supplementary motor area
#   7       - visuo-motor coordination
selected_area = [1, 2, 3, 4, 5, 6, 7]

# Flat boolean mask over all voxels: True where the voxel carries one of the
# selected labels.
_flat_labels = brodmann_data.reshape(-1)
brodmann_motor = np.zeros(_flat_labels.shape, dtype=bool)
for area in selected_area:
    brodmann_motor |= _flat_labels == area

print(brodmann_motor)
print("brodmann template shape: " + str(brodmann_data.shape))
print("chosen points: " + str(np.sum(brodmann_motor)))

# Voxel index grid with the index triple on the last axis, mapped through the
# image affine to obtain the coordinate of every voxel.
shape = img.shape[:3]
affine = img.affine
coords = np.stack(np.meshgrid(*[range(extent) for extent in shape], indexing='ij'), axis=-1)
mm_coords = nib.affines.apply_affine(affine, coords)

def in_hull(p, hull):
    """
    Tell, for every point of `p`, whether it lies inside `hull`.

    `p` is an `NxK` array holding `N` points in `K` dimensions.
    `hull` is either a ready-made scipy.spatial.Delaunay triangulation or an
    `MxK` array of `M` points, in which case the triangulation is computed here.

    Returns a boolean array of length `N`.
    """
    triangulation = hull if isinstance(hull, Delaunay) else Delaunay(hull)
    # find_simplex yields -1 for points that fall outside every simplex
    simplex_index = triangulation.find_simplex(p)
    return simplex_index >= 0

# Boolean masks selecting the motor-region source points of the left and right
# hemisphere. They start empty and are filled in on first use by
# apply_inverse_and_forward_kfold, which declares them `global`.
my_left_points = None
my_right_points = None

In [None]:
""""
labels utility function
"""
def load_subject_labels(name="A01E.mat", dir="drive/Shareddrives/Motor Imagery/BCI competition IV dataset/2a/2a true_labels/"):
  """Load the true class labels stored in one MATLAB file.

  Args:
    name: file name of the label file.
    dir: directory holding the file, with or without a trailing separator.
      (Shadows the builtin, but callers pass it by keyword, so it is kept.)

  Returns:
    The file's "classlabel" variable flattened to a 1-D array.
  """
  # Joined with os.path.join rather than `dir + name`: directories yielded by
  # os.walk carry no trailing separator, so plain concatenation would fuse the
  # directory and the file name.
  label_file = os.path.join(dir, name)
  return scipy.io.loadmat(label_file)["classlabel"].reshape(-1)

def load_all_true_labels(dataset_path):
  """Load the labels of every file found under `dataset_path`.

  The directory tree is walked recursively and each file is handed to
  load_subject_labels. Returns a dict keyed by file name, so a later file
  with the same name replaces an earlier one.
  """
  return {
    file_name: load_subject_labels(name=file_name, dir=folder)
    for folder, _subfolders, file_names in os.walk(dataset_path)
    for file_name in file_names
  }

"""
plot graph utility function
"""
def plot_average_graph(subject_name="A01T.gdf", Class="left", filter_channels=None):
  """Plot the trial-averaged signal of one class, one subplot per channel.

  Reads the module-level `data` (recording name -> subject dict with an
  "epoch_data" entry) and `ch_names` (channel names indexed like the channel
  axis); both must be defined before calling.

  Args:
    subject_name: key of the recording in `data`.
    Class: key of the class inside the recording's "epoch_data".
    filter_channels: optional list of channel names. When given, only these
      channels are drawn (a name that matches no channel leaves its subplot
      empty); otherwise every channel is drawn.

  Raises:
    KeyError: if the recording or the class is unknown.
    ValueError: if the recording holds no epochs of `Class`.
  """
  class_epochs = data[subject_name]["epoch_data"][Class]
  # len() instead of `!= []`: comparing a NumPy array with a list is
  # element-wise and cannot be relied on as an emptiness test.
  if len(class_epochs) == 0:
    raise ValueError("no epochs of class %r for %s" % (Class, subject_name))

  # average over trials, then put time on axis 0 and channels on axis 1
  x = np.transpose(np.mean(class_epochs, axis=0))
  n_channels = x.shape[1]

  if filter_channels is None:
    channels = list(range(n_channels))
    fig_height = 21
  else:
    # index of the first channel carrying each requested name (None if absent)
    channels = [
      next((c for c in range(n_channels) if ch_names[c] == wanted), None)
      for wanted in filter_channels
    ]
    fig_height = 10.5

  # squeeze=False keeps `axs` two-dimensional, so a single subplot still works
  fig, axs = plt.subplots(len(channels), squeeze=False, gridspec_kw={'hspace': 0})
  fig.set_size_inches(37, fig_height)
  samples = range(x.shape[0])
  for ax, channel in zip(axs[:, 0], channels):
    if channel is None:
      continue
    ax.title.set_text(ch_names[channel])
    ax.title.set_size(20)
    ax.title.set_y(0.7)
    ax.plot(samples, x[:, channel])
    ax.axvline(x=250, color="r", linestyle='--')
  plt.tight_layout()

def plot_multiple_graph(subject_name="A02T.gdf", classes=("left", "right", "foot", "tongue"), filter_channels=None):
  """Overlay the trial-averaged signals of several classes, one subplot per channel.

  Reads the module-level `data` (recording name -> subject dict with an
  "epoch_data" entry) and `ch_names` (channel names indexed like the channel
  axis); both must be defined before calling.

  Args:
    subject_name: key of the recording in `data`.
    classes: sequence of class names to overlay.
    filter_channels: optional list of channel names. When given, only these
      channels are drawn (a name that matches no channel leaves its subplot
      empty); otherwise every channel is drawn.

  Raises:
    KeyError: if the recording or one of the classes is unknown.
    ValueError: if `classes` is empty or a class has no epochs.
  """
  # The original literal listed "tongue" twice, which silently left "unknown"
  # without a colour; the second entry is now "unknown", so tongue is drawn
  # in "m" instead of "y".
  color = {"left": "b", "right": "g", "foot": "c", "tongue": "m", "unknown": "y"}

  epoch_data = data[subject_name]["epoch_data"]
  x = []
  for class_name in classes:
    class_epochs = epoch_data[class_name]
    # len() instead of `!= []`: comparing a NumPy array with a list is
    # element-wise and cannot be relied on as an emptiness test.
    if len(class_epochs) == 0:
      raise ValueError("no epochs of class %r for %s" % (class_name, subject_name))
    # average over trials, then put time on axis 0 and channels on axis 1
    x.append(np.transpose(np.mean(class_epochs, axis=0)))
  if not x:
    raise ValueError("classes must name at least one class")

  n_channels = x[0].shape[1]
  if filter_channels is None:
    channels = list(range(n_channels))
    fig_height = 21
  else:
    # index of the first channel carrying each requested name (None if absent)
    channels = [
      next((c for c in range(n_channels) if ch_names[c] == wanted), None)
      for wanted in filter_channels
    ]
    fig_height = 10.5

  # squeeze=False keeps `axs` two-dimensional, so a single subplot still works
  fig, axs = plt.subplots(len(channels), squeeze=False, gridspec_kw={'hspace': 0})
  fig.set_size_inches(37, fig_height)
  for ax, channel in zip(axs[:, 0], channels):
    if channel is None:
      continue
    ax.title.set_text(ch_names[channel])
    ax.title.set_size(20)
    ax.title.set_y(0.7)
    ax.axvline(x=250, color="r", linestyle='--')
    for class_name, average in zip(classes, x):
      ax.plot(range(average.shape[0]), average[:, channel], color=color[class_name])
  plt.tight_layout()

"""
load data function
"""
def load_subject(name="A01T.gdf", dir='drive/Shareddrives/Motor Imagery/BCI competition IV dataset/2a/BCICIV_2a_gdf/', debug=None):
  """Load one GDF recording and cut it into per-class epochs.

  The recording is read with MNE, its channels are renamed and typed using the
  module-level `channels_mapping` / `channels_type_mapping`, the standard
  10-20 montage and an average-reference projector are set and the three EOG
  channels are dropped. Epochs are then sliced from the raw samples around
  every cue. Cues of unknown class (event 783) are assigned to a class using
  the module-level `true_labels` dict (key: first four characters of `name`
  plus ".mat").

  Args:
    name: file name of the recording.
    dir: directory holding the file, with or without a trailing separator.
      (Shadows the builtin, but callers pass it by keyword, so it is kept.)
    debug: None for silence, "important" to print the event and epoch counts,
      "all" to additionally dump the measurement info and the event list.

  Returns:
    dict with the keys
      "raw":        the mne Raw object,
      "info":       its measurement info,
      "class_info": multi-line string with the number of events per event code,
      "epoch_data": dict class name -> array of stacked epochs (an empty
                    array when the class has no epochs).

  Raises:
    KeyError: if the recording contains unknown-class cues but `true_labels`
      has no entry for it.
  """
  separator = "-------------------------------------------------------------------------------------------------------"

  subject_data = {}
  # Load data. os.path.join copes with directories that lack a trailing
  # separator (as yielded by os.walk), unlike plain string concatenation.
  raw = mne.io.read_raw_gdf(os.path.join(dir, name))
  # Rename channels
  raw.rename_channels(channels_mapping)
  # Set channels types
  raw.set_channel_types(channels_type_mapping)
  # Read and set the EEG electrode locations
  ten_twenty_montage = mne.channels.make_standard_montage('standard_1020')
  raw.set_montage(ten_twenty_montage)
  # Set common average reference (added as a projector)
  raw.set_eeg_reference('average', projection=True, verbose=False)
  # Drop eog channels
  raw.drop_channels(["EOG-left", "EOG-central", "EOG-right"])

  subject_data["raw"] = raw
  subject_data["info"] = raw.info
  if debug == "all":
    print(separator)
    for key, item in raw.info.items():
      print(key, item)
    print(separator)

  # Event codes found in the annotations and their meaning.
  event_descriptions = {
    276: 'Idling EEG (eyes open)',
    277: 'Idling EEG (eyes closed)',
    768: 'Start of a trial',
    769: 'Cue onset left (class 1)',
    770: 'Cue onset right (class 2)',
    771: 'Cue onset foot (class 3)',
    772: 'Cue onset tongue (class 4)',
    783: 'Cue unknown',
    1023: 'Rejected trial',
    1072: 'Eye movements',
    32766: 'Start of a new run',
  }
  # Annotation description (a string) -> integer event id.
  custom_mapping = {str(code): code for code in event_descriptions}
  events_from_annot, event_dict = mne.events_from_annotations(raw, event_id=custom_mapping)
  event_samples = events_from_annot[:, 0]
  event_codes = events_from_annot[:, 2]

  # This test used to read `debug == " all"` (leading space) and therefore
  # never fired.
  if debug == "all":
    print(separator)
    print(event_dict)
    print(events_from_annot)
    print(separator)

    print(separator)
    # Bounded by the shorter sequence: annotations missing from the mapping
    # produce no event row.
    for i in range(min(len(events_from_annot), len(raw.annotations))):
      print(events_from_annot[i], raw.annotations[i])
    print(separator)

  # One "description: count" line per event code.
  class_info = "\n".join(
    description + ": " + str(int(np.count_nonzero(event_codes == code)))
    for code, description in event_descriptions.items()
  )
  subject_data["class_info"] = class_info

  if debug == "all" or debug == "important":
    print(separator)
    print(class_info)
    print(separator)

  class_dict = {"left": 769, "right": 770, "foot": 771, "tongue": 772, "unknown": 783}
  class_dict_labels = {1: "left", 2: "right", 3: "foot", 4: "tongue"}
  rejected_trial = event_samples[event_codes == 1023]
  raw_data = raw.get_data()
  # Epoch window in samples relative to the cue.
  # NOTE(review): the original comments called these "cue+0.1s" / "cue+2.1s";
  # that only holds for a 100 Hz recording - confirm the sampling rate.
  start = 10
  stop = 510
  # A cue is skipped when a rejected-trial marker sits exactly this many
  # samples before it (presumably the trial start - TODO confirm).
  cue_delay = 500

  collected = {event_class: [] for event_class in class_dict}
  for event_class, event_id in class_dict.items():
    current_event = event_samples[event_codes == event_id]
    if len(current_event) == 0:
      continue

    subject_true_labels = None
    if event_class == "unknown":
      # Only looked up when the recording really has unknown cues, so files
      # without them do not need a label entry.
      subject_true_labels = true_labels[name[:4] + ".mat"]

    for i, cue in enumerate(current_event):
      # exclude artifact
      if np.any(cue - cue_delay == rejected_trial):
        continue
      if subject_true_labels is None:
        target_class = event_class
      else:
        # labels are indexed by cue position, rejected trials included
        target_class = class_dict_labels[int(subject_true_labels[i])]
      collected[target_class].append(np.array(raw_data[:22, cue + start:cue + stop]))

  epoch_data = {event_class: np.array(class_epochs) for event_class, class_epochs in collected.items()}

  if debug == "all" or debug == "important":
    print(separator)
    for event_class, class_epochs in epoch_data.items():
      print(event_class, len(class_epochs))
    print(separator)

  subject_data["epoch_data"] = epoch_data

  return subject_data

def load_all_subject(dataset_path):
  """Load every recording found under `dataset_path`.

  The directory tree is walked recursively and each file is handed to
  load_subject. Returns a dict keyed by file name, so a later file with the
  same name replaces an earlier one.
  """
  return {
    file_name: load_subject(name=file_name, dir=folder)
    for folder, _subfolders, file_names in os.walk(dataset_path)
    for file_name in file_names
  }

"""
create mne epochs data structure from numpy array
merge training and evaluation data
"""
def create_epochs(data):
  """Merge training and evaluation epochs of each subject into mne EpochsArray objects.

  Args:
    data: dict recording file name -> subject dict as returned by
      load_subject (keys "epoch_data" and "info" are read). For every
      recording whose name contains no "E", the recording named by its first
      three characters plus "E.gdf" must also be present.

  Returns:
    dict subject id (first three characters of the file name) ->
    dict class name -> mne.EpochsArray holding the epochs of the training
    recording followed by those of the evaluation recording. Classes without
    any epoch are left out.

  Raises:
    KeyError: if the evaluation recording of a subject is missing.
  """
  subjects_data = {}

  for subject in data.keys():
    # entries whose name contains "E" are merged into their training
    # counterpart below and not processed on their own
    if "E" in subject:
      continue
    training = data[subject]["epoch_data"]
    evaluation = data[subject[:3] + "E.gdf"]["epoch_data"]

    epochs_data = {}
    for event in training.keys():
      # Emptiness is tested with len() (the original `.any()` also treated
      # all-zero data as empty), and only non-empty parts are concatenated so
      # that a class present in the evaluation file alone no longer crashes
      # on `np.append(None, ...)`.
      parts = [part for part in (training[event], evaluation[event]) if len(part) > 0]
      if parts:
        merged = np.concatenate(parts, axis=0)
        epochs_data[event] = mne.EpochsArray(merged, data[subject]["info"], verbose=False)

    subjects_data[subject[:3]] = epochs_data

  return subjects_data

"""
Create source activity and reconstructed eeg respectively for each subject

For each subject, there are four events in total, i.e. left, right, foot, tongue
Split these data into train and test set using kfold
Compute the noise covariance matrix on train set and apply it to test set
Create source activity (only motor region) first by applying an inverse operator to the epochs 
Create reconstructed eeg by applying a forward operator to the source activity acquired earlier
Save both these files to disk
"""
def _motor_region_masks(stc):
    """Return boolean masks (left, right) of the source vertices of `stc` lying inside the motor-region hull."""
    # Triangulate the selected atlas voxels once and reuse it for both hemispheres.
    motor_hull = Delaunay(mm_coords.reshape(-1, 3)[brodmann_motor])
    mni_lh = mne.vertex_to_mni(stc.vertices[0], 0, mne_subject)
    mni_rh = mne.vertex_to_mni(stc.vertices[1], 1, mne_subject)
    return in_hull(mni_lh, motor_hull), in_hull(mni_rh, motor_hull)


def _export_fold_split(stcs, labels, prefix, subject, leadfield, left_mask, right_mask, save_inverse, save_forward):
    """Save the motor-region source activity and the EEG reconstructed from it for one split of one fold.

    Files are written below EXTERNAL_STORAGE_PATH/data/<kind>/<subject>/ as
    <prefix>_X.npz and <prefix>_Y.npz (array stored under the key "data").
    """
    if save_inverse:
        # rows of source.data are ordered left hemisphere first, right hemisphere last
        left_hemi_data = np.array([source.data[:len(source.vertices[0])][left_mask] for source in stcs])
        right_hemi_data = np.array([source.data[-len(source.vertices[1]):][right_mask] for source in stcs])
        source_activity_path = op.join(EXTERNAL_STORAGE_PATH, "data", "source activity", subject)
        os.makedirs(source_activity_path, exist_ok=True)
        np.savez_compressed(op.join(source_activity_path, prefix + "_X.npz"),
                            data=np.append(left_hemi_data, right_hemi_data, axis=1))
        np.savez_compressed(op.join(source_activity_path, prefix + "_Y.npz"), data=labels)

    if save_forward:
        reconstructed_eeg_data = []
        for source in stcs:
            # keep only the motor-region sources, zero everything else, and
            # project back to the sensors with the leadfield
            motor_source = np.zeros_like(source.data)
            motor_source[:len(source.vertices[0])][left_mask] = source.data[:len(source.vertices[0])][left_mask]
            motor_source[-len(source.vertices[1]):][right_mask] = source.data[-len(source.vertices[1]):][right_mask]
            reconstructed_eeg_data.append(np.dot(leadfield, motor_source))
        reconstructed_eeg_path = op.join(EXTERNAL_STORAGE_PATH, "data", "reconstructed eeg", subject)
        os.makedirs(reconstructed_eeg_path, exist_ok=True)
        np.savez_compressed(op.join(reconstructed_eeg_path, prefix + "_X.npz"),
                            data=np.array(reconstructed_eeg_data))
        np.savez_compressed(op.join(reconstructed_eeg_path, prefix + "_Y.npz"), data=labels)


def apply_inverse_and_forward_kfold(epochs, n_splits=5, save_inverse=True, save_forward=True, events=("left", "right"), subjects=None):
    """Export per-fold source activity and reconstructed EEG for each subject.

    For every subject the epochs of the requested events are stacked and split
    with a stratified, shuffled k-fold (random_state=0). In each fold the
    noise covariance is estimated on the training epochs, an sLORETA inverse
    operator is built and applied to both the training and the test epochs,
    and the results are written to disk for each split.

    Reads the module-level `trans`, `src`, `bem`, `mne_subject`, `mm_coords`,
    `brodmann_motor` and `EXTERNAL_STORAGE_PATH`. The motor-region masks are
    computed on first use and cached in the globals `my_left_points` /
    `my_right_points`.

    Args:
        epochs: dict subject -> dict event name -> mne Epochs object.
        n_splits: number of folds.
        save_inverse: write the motor-region source activity.
        save_forward: write the EEG reconstructed from the motor-region sources.
        events: event names to use; the label of an epoch is the position of
            its event in this sequence.
        subjects: subjects to process; all subjects of `epochs` when None.

    Raises:
        ValueError: if a subject has no epochs for any of `events`.
    """
    global my_left_points, my_right_points

    events = list(events)
    if subjects is None:
        subjects = epochs.keys()

    method = "sLORETA"
    snr = 3.
    lambda2 = 1. / snr ** 2

    for subject in subjects:
        print(subject)
        info = None
        X_parts, Y_parts = [], []
        for event, event_epochs in epochs[subject].items():
            if info is None:
                info = event_epochs.info
            if event in events:
                print(event)
                event_data = event_epochs.get_data()
                X_parts.append(event_data)
                Y_parts.append(np.full(len(event_data), float(events.index(event))))
        if not X_parts:
            raise ValueError("subject %s has no epochs for events %s" % (subject, events))
        X = np.concatenate(X_parts, axis=0)
        Y = np.concatenate(Y_parts, axis=0)

        # The forward model depends only on `info`, so it is computed once per
        # subject instead of once per fold.
        fwd = mne.make_forward_solution(info, trans=trans, src=src,
                                        bem=bem, eeg=True, meg=False, mindist=5.0, n_jobs=1, verbose=False)
        fwd_fixed = mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                                 use_cps=True, verbose=False)
        leadfield = fwd_fixed['sol']['data']

        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0)
        for counter, (train_index, test_index) in enumerate(skf.split(X, Y), start=1):
            Y_train, Y_test = Y[train_index], Y[test_index]
            X_train = mne.EpochsArray(X[train_index], info, verbose=False)
            X_test = mne.EpochsArray(X[test_index], info, verbose=False)

            # NOTE(review): the epochs are built with the default start time,
            # so tmax=0. presumably leaves very few samples for the
            # covariance estimate - confirm this is intended.
            noise_cov = mne.compute_covariance(X_train, tmax=0., method=['shrunk', 'empirical'], rank=None, verbose=False)
            inverse_operator = make_inverse_operator(info, fwd, noise_cov, loose=0.2, depth=0.8, verbose=False)

            # train set
            stc_train = apply_inverse_epochs(X_train, inverse_operator, lambda2,
                                             method=method, pick_ori="normal", verbose=True)

            # get motor region points (once)
            if my_left_points is None and my_right_points is None:
                my_left_points, my_right_points = _motor_region_masks(stc_train[0])

            _export_fold_split(stc_train, Y_train, str(counter) + "_train", subject, leadfield,
                               my_left_points, my_right_points, save_inverse, save_forward)

            # free the training estimates before the test estimates are computed
            del stc_train
            gc.collect()

            # test set
            stc_test = apply_inverse_epochs(X_test, inverse_operator, lambda2,
                                            method=method, pick_ori="normal", verbose=True)
            _export_fold_split(stc_test, Y_test, str(counter) + "_test", subject, leadfield,
                               my_left_points, my_right_points, save_inverse, save_forward)

            del X_train, X_test, Y_train, Y_test
            del stc_test
            gc.collect()

"""
Total params: 699,685
Trainable params: 0
Non-trainable params: 699,685
"""
def create_model():
    """Build the (uncompiled) CNN classifier.

    Two blocks of Conv2D (3x3, selu) -> BatchNormalization -> MaxPool2D
    (2x2, stride 1) with 4 and 8 filters respectively, followed by Flatten,
    a 50-unit selu Dense layer and a single sigmoid output unit.
    """
    stack = []
    for n_filters in (4, 8):
        stack.append(Conv2D(filters=n_filters, kernel_size=(3, 3), strides=(1, 1), padding='same', activation="selu"))
        stack.append(BatchNormalization())
        stack.append(MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding="valid"))
    stack.append(Flatten())
    stack.append(Dense(50, activation="selu"))
    stack.append(Dense(1, activation="sigmoid"))

    return tf.keras.models.Sequential(stack)

def stft_min_max(X, Y, debug=True, frame_length=256, frame_step=16, n_freq_bins=40):
    """Turn signals into min-max scaled STFT magnitude images.

    Args:
        X: signals with time on the last axis; the indexing below requires a
            rank-3 input (presumably trials x channels x samples - TODO confirm).
        Y: labels, returned untouched.
        debug: print the shapes before and after the transformation.
        frame_length: STFT window length in samples.
        frame_step: STFT hop in samples.
        n_freq_bins: number of lowest frequency bins to keep.

    Returns:
        (X, Y) where X has shape (batch, kept frequency bins,
        second axis x time frames, 1) and every instance is scaled to [0, 1]
        with its own minimum and maximum.
    """
    Zxx = tf.abs(tf.signal.stft(X, frame_length=frame_length, frame_step=frame_step))

    if debug:
      print("shape of X and Y: " + str(X.shape) + " " + str(Y.shape))
      print("shape of Zxx: " + str(Zxx.shape))

    # keep the lowest frequency bins, merge the second axis with the frame
    # axis, and lay the result out as (batch, bins, merged axis, 1)
    X = Zxx[:, :, :, :n_freq_bins]
    kept_bins = X.shape[-1]
    X = tf.reshape(X, [X.shape[0], -1, kept_bins])
    X = tf.transpose(X, perm=[0, 2, 1])
    X = tf.expand_dims(X, axis=3)

    # min max scaling (per instance)
    original_shape = X.shape
    X = tf.reshape(X, [original_shape[0], -1])
    X_max = tf.math.reduce_max(X, axis=1, keepdims=True)
    X_min = tf.math.reduce_min(X, axis=1, keepdims=True)
    # divide_no_nan yields 0 instead of NaN for a constant instance
    # (max == min); every other value is the plain quotient
    X = tf.math.divide_no_nan(tf.math.subtract(X, X_min), tf.math.subtract(X_max, X_min))
    X = tf.reshape(X, original_shape)

    if debug:
      print("shape of X and Y: " + str(X.shape) + " " + str(Y.shape))

    return X, Y

In [None]:
# cd to google drive
# Switch the working directory to the G: drive (the mounted Google Drive);
# the relative "Shared drives/..." paths used below depend on it.
os.chdir("G:")

# Download fsaverage files (fetch_fsaverage returns the fsaverage folder)
fs_dir = fetch_fsaverage(verbose=True)
subjects_dir = op.dirname(fs_dir)

# Template-subject settings read as globals by apply_inverse_and_forward_kfold:
mne_subject = 'fsaverage'
trans = 'fsaverage'  # MNE has a built-in fsaverage transformation
src = op.join(fs_dir, 'bem', 'fsaverage-ico-5-src.fif')
bem = op.join(fs_dir, 'bem', 'fsaverage-5120-5120-5120-bem-sol.fif')

# Source space of the template: one entry per hemisphere; keep the positions
# ("rr") of the points flagged as in use.
source = mne.read_source_spaces(src)
left = source[0]
right = source[1]
left_pos = left["rr"][left["inuse"]==1]
right_pos = right["rr"][right["inuse"]==1]

# Transformation file shipped with fsaverage
transformation = mne.read_trans(op.join(fs_dir, "bem", "fsaverage-trans.fif"))

# Output folder for source estimates, relative to the working directory set above
save_path = op.join(os.getcwd(), "Shared drives", "Motor Imagery", "Source Estimate")

In [None]:
# True class labels of every recording, keyed by label file name.
# load_subject reads this global to resolve cues of unknown class, so it must
# be populated before the recordings are loaded.
true_labels_path = "Shared drives/Motor Imagery/BCI competition IV dataset/2a/2a true_labels/"
true_labels = load_all_true_labels(true_labels_path)

# Every GDF recording, keyed by file name (see load_subject for the layout).
dataset_path = 'Shared drives/Motor Imagery/BCI competition IV dataset/2a/BCICIV_2a_gdf/'
data = load_all_subject(dataset_path)

In [None]:
# some information to help to understand functions and data structure
# for key, item in data.items():
#   print(key)

# ch_names = data["A01T.gdf"]["info"]["ch_names"]
# print(ch_names)

# print(data["A01T.gdf"]["class_info"])

# for key, value in data.items():
#   print(key)
#   for event_class, event_data in value["epoch_data"].items():
#       print(event_class, len(event_data))
#   print()

# subject_name = "A01T.gdf"
# Class = "left"
# filter_channels = ["C3", "Cz", "C4"]
# plot_average_graph(subject_name, Class, filter_channels)

# subject_name = "A02T.gdf"
# classes = ["left", "right"]
# filter_channels = ["C3", "Cz", "C4"]
# plot_multiple_graph(subject_name, classes, filter_channels)

In [None]:
# Merge the training and evaluation recordings of each subject into
# mne.EpochsArray objects: subject id -> class name -> epochs.
epochs = create_epochs(data)
# Uncomment to export per-fold source activity / reconstructed EEG to disk
# (here: subject A05, foot vs. tongue).
#apply_inverse_and_forward_kfold(epochs, n_splits=n_splits, subjects=["A05"], events=['foot', 'tongue'])

In [None]:
# Example snippets (commented out) that help to understand the functions and data structures
# my_epochs = epochs["A01"]["right"]
# my_evoked = my_epochs.average().pick("eeg")

# noise_cov = mne.compute_covariance(my_epochs, tmax=0., method=['shrunk', 'empirical'], rank=None, verbose=False)
# fwd = mne.make_forward_solution(my_epochs.info, trans=trans, src=src,
#                             bem=bem, eeg=True, meg=False, mindist=5.0, n_jobs=1)
# # forward matrix
# fwd_fixed = mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
#                                          use_cps=True)

# inverse_operator = make_inverse_operator(
#     my_epochs.info, fwd, noise_cov, loose=0.2, depth=0.8)

# method = "sLORETA"
# snr = 3.
# lambda2 = 1. / snr ** 2
# stc = mne.minimum_norm.apply_inverse(my_evoked, inverse_operator, lambda2,
#                               method=method, pick_ori="normal", verbose=True)

# reconstruct_evoked = mne.apply_forward(fwd_fixed, stc, my_evoked.info)
# my_evoked.plot_topomap()
# reconstruct_evoked.plot_topomap()

# CNN Classification (original data)

In [None]:
"""
labels
left (class 0) right (class 1) foot (class 2) tongue (class 3)

channels
c3(7) cz(9) c4(11)
"""
results = {"A01": {}, "A02": {}, "A03": {}, "A04": {}, "A05": {}, "A06": {}, "A07": {}, "A08": {}, "A09": {}}
labels = {"left": 0, "right": 1}
events = ["left", "right"]
select_channels = [7, 9, 11]
debug = True
training = False

# Subjects to process: one entry per key of `results`.
data_list = list(results.keys())

# To process a single subject instead:
# data_list = ["A09"]

for data_name in data_list:
  # Metric sums over the cross-validation folds; turned into means after the
  # fold loop.
  accuracy = 0
  precision = 0
  recall = 0
  f1 = 0
  kappa = 0

  # Pool the epochs of the requested events. The label of an event is its
  # position in `events`. Each event's data is fetched once and all parts are
  # joined with a single concatenate.
  X_parts, Y_parts = [], []
  for event in epochs[data_name].keys():
    for class_index, event_name in enumerate(events):
      if event == event_name:
        event_data = epochs[data_name][event].get_data()
        X_parts.append(event_data)
        Y_parts.append(np.zeros(len(event_data)) + class_index)
  if not X_parts:
    raise ValueError("no epochs found for subject " + data_name + " and events " + str(events))
  X = np.concatenate(X_parts, axis=0)
  Y = np.concatenate(Y_parts, axis=0)

  skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0)
  for train_index, test_index in skf.split(X, Y):
    # pick c3, cZ, c4 channels
    X_train = X[train_index][:, select_channels, :]
    X_test = X[test_index][:, select_channels, :]

    print(data_name)
    print("Training...")
    X_train, Y_train = stft_min_max(X_train, Y[train_index], debug)
    print("Testing...")
    X_test, Y_test = stft_min_max(X_test, Y[test_index], debug)

    if debug:
      print("shape of X_train and Y_train: " + str(X_train.shape) + " " + str(Y_train.shape))
      print("shape of X_test and Y_test: " + str(X_test.shape) + " " + str(Y_test.shape))

    model = create_model()
    optimizer = Adam(learning_rate=1e-5)
    if training:
      # Train a fresh model on this fold and log the run for TensorBoard.
      log_dir = DIRECTORY_PATH + "/logs/" + data_name + "/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
      tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
      model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])
      model.fit(X_train, Y_train, validation_data=(X_test, Y_test), batch_size=32, epochs=200, callbacks=[tensorboard_callback], verbose=0)
    else:
      # Load the pretrained weights and freeze the model: evaluation only.
      model.load_weights(DIRECTORY_PATH + "/models/" + "A09_0.9183/")
      model.trainable = False
      model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])

    # Threshold the model output at 0.5 to get binary predictions, then score
    # the fold. The scoring is shared by the training and evaluation paths.
    Y_hat = model.predict(X_test)
    Y_hat = (Y_hat >= 0.5)
    fold_accuracy = accuracy_score(Y_test, Y_hat)
    accuracy += fold_accuracy
    precision += precision_score(Y_test, Y_hat, average="macro")
    recall += recall_score(Y_test, Y_hat, average="macro")
    f1 += f1_score(Y_test, Y_hat, average="macro")
    kappa += cohen_kappa_score(Y_test, Y_hat)

    if training:
      # save model (the folder name carries the first six characters of the fold accuracy)
      model.save_weights(DIRECTORY_PATH + "/models/" + data_name + "_" + str(fold_accuracy)[:6] + "/")

  accuracy /= n_splits
  precision /= n_splits
  recall /= n_splits
  f1 /= n_splits
  kappa /= n_splits

  if debug:
    print("accuracy: " + str(accuracy))
    print("precision: " + str(precision))
    print("recall: " + str(recall))
    print("f1: " + str(f1))
    print("kappa: " + str(kappa))

  results[data_name]["accuracy"] = accuracy
  results[data_name]["precision"] = precision
  results[data_name]["recall"] = recall
  results[data_name]["f1"] = f1
  results[data_name]["kappa"] = kappa

In [None]:
# Calculate average performance
# Only subjects whose metrics have been filled in are averaged, and the divisor
# is the number of such subjects, so a run restricted to a subset of subjects
# still produces a valid mean.
metric_names = ["accuracy", "precision", "recall", "f1", "kappa"]
evaluated = [value for value in results.values()
             if all(name in value for name in metric_names)]
if not evaluated:
  raise ValueError("results holds no evaluated subject; run the classification cell first")

average_accuracy = sum(value["accuracy"] for value in evaluated) / len(evaluated)
average_precision = sum(value["precision"] for value in evaluated) / len(evaluated)
average_recall = sum(value["recall"] for value in evaluated) / len(evaluated)
average_f1 = sum(value["f1"] for value in evaluated) / len(evaluated)
average_kappa = sum(value["kappa"] for value in evaluated) / len(evaluated)

print("average accuracy: " + str(average_accuracy))
print("average precision: " + str(average_precision))
print("average recall: " + str(average_recall))
print("average f1: " + str(average_f1))
print("average kappa: " + str(average_kappa))

# CNN Classification (reconstructed data)

In [None]:
"""
labels
left (class 0) right (class 1) foot (class 2) tongue (class 3)

channels
c3(7) cz(9) c4(11)
"""

results = {"A01": {}, "A02": {}, "A03": {}, "A04": {}, "A05": {}, "A06": {}, "A07": {}, "A08": {}, "A09": {}}
labels = {"left": 0, "right": 1}
select_channels = [7, 9, 11]
debug = True
training = False


def _load_npz_data(path):
  """Return the "data" array stored in the .npz archive at `path`.

  The archive is opened in a `with` block so that its file handle is closed as
  soon as the array has been read.
  """
  # NOTE(review): allow_pickle=True lets a crafted archive execute arbitrary
  # code on load; only use it on files produced by this project.
  with np.load(path, allow_pickle=True) as archive:
    return archive["data"]


# Subjects to process: one entry per key of `results`.
data_list = list(results.keys())

# To process a single subject instead:
# data_list = ["A09"]

for data_name in data_list:
  # load data from external storage
  directory_path = op.join(EXTERNAL_STORAGE_PATH, "all motor region", "data", "reconstructed eeg", data_name)

  # Metric sums over the folds; turned into means after the fold loop.
  accuracy = 0
  precision = 0
  recall = 0
  f1 = 0
  kappa = 0

  # Fold files are numbered 1..n_splits.
  for counter in range(1, n_splits + 1):
    X_train = _load_npz_data(op.join(directory_path, str(counter)+"_train_X.npz"))
    X_test = _load_npz_data(op.join(directory_path, str(counter)+"_test_X.npz"))
    Y_train = _load_npz_data(op.join(directory_path, str(counter)+"_train_Y.npz"))
    Y_test = _load_npz_data(op.join(directory_path, str(counter)+"_test_Y.npz"))

    # pick c3, cZ, c4 channels
    X_train = X_train[:, select_channels, :]
    X_test = X_test[:, select_channels, :]

    print(data_name)
    print("Training...")
    X_train, Y_train = stft_min_max(X_train, Y_train, debug)
    print("Testing...")
    X_test, Y_test = stft_min_max(X_test, Y_test, debug)

    if debug:
      print("shape of X_train and Y_train: " + str(X_train.shape) + " " + str(Y_train.shape))
      print("shape of X_test and Y_test: " + str(X_test.shape) + " " + str(Y_test.shape))

    model = create_model()
    optimizer = Adam(learning_rate=1e-5)
    if training:
      # Train a fresh model on this fold and log the run for TensorBoard.
      log_dir = DIRECTORY_PATH + "/logs/" + data_name + "/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
      tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
      model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])
      model.fit(X_train, Y_train, validation_data=(X_test, Y_test), batch_size=32, epochs=200, callbacks=[tensorboard_callback], verbose=0)
    else:
      # Load the pretrained weights and freeze the model: evaluation only.
      model.load_weights(DIRECTORY_PATH + "/models/" + "A09_0.9183/")
      model.trainable = False
      model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])

    # Threshold the model output at 0.5 to get binary predictions, then score
    # the fold. The scoring is shared by the training and evaluation paths.
    Y_hat = model.predict(X_test)
    Y_hat = (Y_hat >= 0.5)
    fold_accuracy = accuracy_score(Y_test, Y_hat)
    accuracy += fold_accuracy
    precision += precision_score(Y_test, Y_hat, average="macro")
    recall += recall_score(Y_test, Y_hat, average="macro")
    f1 += f1_score(Y_test, Y_hat, average="macro")
    kappa += cohen_kappa_score(Y_test, Y_hat)

    if training:
      # save model (the folder name carries the first six characters of the fold accuracy)
      model.save_weights(DIRECTORY_PATH + "/models/" + data_name + "_" + str(fold_accuracy)[:6] + "/")

  accuracy /= n_splits
  precision /= n_splits
  recall /= n_splits
  f1 /= n_splits
  kappa /= n_splits
  if debug:
    print("accuracy: " + str(accuracy))
    print("precision: " + str(precision))
    print("recall: " + str(recall))
    print("f1: " + str(f1))
    print("kappa: " + str(kappa))

  results[data_name]["accuracy"] = accuracy
  results[data_name]["precision"] = precision
  results[data_name]["recall"] = recall
  results[data_name]["f1"] = f1
  results[data_name]["kappa"] = kappa

In [None]:
# Calculate average performance
# Only subjects whose metrics have been filled in are averaged, and the divisor
# is the number of such subjects, so a run restricted to a subset of subjects
# still produces a valid mean.
metric_names = ["accuracy", "precision", "recall", "f1", "kappa"]
evaluated = [value for value in results.values()
             if all(name in value for name in metric_names)]
if not evaluated:
  raise ValueError("results holds no evaluated subject; run the classification cell first")

average_accuracy = sum(value["accuracy"] for value in evaluated) / len(evaluated)
average_precision = sum(value["precision"] for value in evaluated) / len(evaluated)
average_recall = sum(value["recall"] for value in evaluated) / len(evaluated)
average_f1 = sum(value["f1"] for value in evaluated) / len(evaluated)
average_kappa = sum(value["kappa"] for value in evaluated) / len(evaluated)

print("average accuracy: " + str(average_accuracy))
print("average precision: " + str(average_precision))
print("average recall: " + str(average_recall))
print("average f1: " + str(average_f1))
print("average kappa: " + str(average_kappa))

# CNN Classification (Auto-select Source Activity)

In [None]:
def get_inverse_and_forward_information(epochs, events=("left", "right"), subject="A01"):
    """Compute forward/inverse quantities from one subject's pooled epochs.

    The epochs of ``subject`` whose event name appears in ``events`` are pooled
    and averaged into an evoked response. A forward solution and an sLORETA
    inverse solution are computed from the module-level ``trans``, ``src`` and
    ``bem`` settings. The source vertices of each hemisphere are then tested
    with ``in_hull`` against ``mm_coords.reshape(-1, 3)[brodmann_motor]``
    (presumably a point-in-hull test selecting motor-region vertices; confirm
    against the definition of ``in_hull``).

    Args:
        epochs: Nested mapping ``epochs[subject][event]`` whose values provide
            ``.info`` and ``.get_data()``.
        events: Names of the events whose epochs are pooled.
        subject: Key of the subject to use; defaults to "A01".

    Returns:
        dict with the keys
        ``"my_left_points"`` / ``"my_right_points"`` (results of ``in_hull``
        for the left / right hemisphere vertices),
        ``"stc_data_shape"`` (shape of the source estimate data),
        ``"leadfield"`` (``fwd_fixed['sol']['data']``),
        ``"left_vertices"`` / ``"right_vertices"`` (``stc.vertices[0]`` / ``[1]``).

    Raises:
        ValueError: if ``subject`` has no epochs for any of ``events``.
    """
    subject_epochs = epochs[subject]

    # The measurement info is taken from the first event of the subject; the
    # data of each requested event is fetched exactly once.
    info = None
    X_parts = []
    for event in subject_epochs.keys():
        if info is None:
            info = subject_epochs[event].info
        if event in events:
            X_parts.append(subject_epochs[event].get_data())
    if not X_parts:
        raise ValueError("no epochs found for subject %s and events %s" % (subject, list(events)))

    X = np.concatenate(X_parts, axis=0)
    X_epochs = mne.EpochsArray(X, info, verbose=False)
    X_evoked = X_epochs.average().pick("eeg")

    noise_cov = mne.compute_covariance(X_epochs, tmax=0., method=['shrunk', 'empirical'], rank=None, verbose=False)
    fwd = mne.make_forward_solution(info, trans=trans, src=src,
                    bem=bem, eeg=True, meg=False, mindist=5.0, n_jobs=1, verbose=False)
    # Fixed-orientation copy of the forward solution; its gain matrix is the leadfield.
    fwd_fixed = mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                 use_cps=True, verbose=False)
    leadfield = fwd_fixed['sol']['data']
    inverse_operator = make_inverse_operator(info, fwd, noise_cov, loose=0.2, depth=0.8, verbose=False)

    method = "sLORETA"
    snr = 3.
    lambda2 = 1. / snr ** 2
    stc = apply_inverse(X_evoked, inverse_operator, lambda2, method=method, pick_ori="normal", verbose=True)

    # get motor region points
    mni_lh = mne.vertex_to_mni(stc.vertices[0], 0, mne_subject)
    print(mni_lh.shape)
    mni_rh = mne.vertex_to_mni(stc.vertices[1], 1, mne_subject)
    print(mni_rh.shape)

    motor_region = mm_coords.reshape(-1, 3)[brodmann_motor]

    # Optional debug plot: motor-region points vs. all source vertices.
    # fig = plt.figure(figsize=(8, 8))
    # ax = fig.add_subplot(projection='3d')
    # ax.scatter(motor_region[:, 0], motor_region[:, 1], motor_region[:, 2], s=15, marker='|')
    # ax.scatter(mni_lh[:, 0], mni_lh[:, 1], mni_lh[:, 2], s=15, marker='_')
    # ax.scatter(mni_rh[:, 0], mni_rh[:, 1], mni_rh[:, 2], s=15, marker='_')
    # ax.set_xlabel('X Label')
    # ax.set_ylabel('Y Label')
    # ax.set_zlabel('Z Label')
    # plt.show()

    my_left_points = in_hull(mni_lh, motor_region)
    my_right_points = in_hull(mni_rh, motor_region)

    # MNI coordinates of the selected vertices: rows of the coordinates computed
    # above, picked with the same index that would select the vertices.
    mni_left_motor = mni_lh[my_left_points]
    print(mni_left_motor.shape)
    mni_right_motor = mni_rh[my_right_points]
    print(mni_right_motor.shape)

    # Optional debug plot: all source vertices vs. the selected ones.
    # fig = plt.figure(figsize=(8, 8))
    # ax = fig.add_subplot(projection='3d')
    # ax.scatter(mni_lh[:, 0], mni_lh[:, 1], mni_lh[:, 2], s=15, marker='|')
    # ax.scatter(mni_rh[:, 0], mni_rh[:, 1], mni_rh[:, 2], s=15, marker='_')
    # ax.scatter(mni_left_motor[:, 0], mni_left_motor[:, 1], mni_left_motor[:, 2], s=15, marker='o')
    # ax.scatter(mni_right_motor[:, 0], mni_right_motor[:, 1], mni_right_motor[:, 2], s=15, marker='^')
    # ax.set_xlabel('X Label')
    # ax.set_ylabel('Y Label')
    # ax.set_zlabel('Z Label')
    # plt.show()

    print("Leadfield size : %d sensors x %d dipoles" % leadfield.shape)
    print(stc.data.shape)

    information = {"my_left_points": my_left_points,
                   "my_right_points": my_right_points,
                   "stc_data_shape": stc.data.shape,
                   "leadfield": leadfield,
                   "left_vertices": stc.vertices[0],
                   "right_vertices": stc.vertices[1]}

    return information

In [None]:
# Run the inverse/forward analysis with the function's default arguments and
# print the resulting dictionary for inspection.
information = get_inverse_and_forward_information(epochs)
print(information)

In [None]:
# Build the reduced forward matrix: leadfield columns of the selected
# left-hemisphere sources followed by those of the selected right-hemisphere
# sources. The leadfield columns are taken to be ordered left hemisphere first,
# right hemisphere last.
leadfield = information["leadfield"]
n_left = len(information["left_vertices"])
n_right = len(information["right_vertices"])

left_forward = leadfield[:, :n_left][:, information["my_left_points"]]
# An explicit start index is used instead of a negative one: with zero
# right-hemisphere vertices, "-0:" would select every column.
right_forward = leadfield[:, leadfield.shape[1] - n_right:][:, information["my_right_points"]]

print(left_forward.shape)
print(right_forward.shape)
forward_matrix = np.concatenate((left_forward, right_forward), axis=1)
print(forward_matrix.shape)

In [None]:
class Stft_Min_Max(keras.layers.Layer):
    """STFT magnitude followed by per-instance min-max scaling.

    The STFT is taken over the last axis of the input (frame length 256,
    step 16), the first 40 frequency bins of its magnitude are kept, and all
    channels/frames are laid out along one axis. Each instance is then scaled
    to the range [0, 1] using its own minimum and maximum.

    The output shape is fixed to (batch, 40, 48, 1); this matches inputs of
    shape (batch, 3, 500) as used by the caller in this notebook — other input
    sizes will fail in the final reshape.
    """

    def __init__(self, **kwargs):
        # Forward standard Keras layer arguments (name, dtype, ...) to the base class.
        super(Stft_Min_Max, self).__init__(**kwargs)

    def call(self, inputs):
        """Return the scaled spectrogram tensor for ``inputs``."""
        Zxx = tf.signal.stft(inputs, frame_length=256, frame_step=16)
        Zxx = tf.abs(Zxx)
        # Keep the first 40 frequency bins, merge the channel and frame axes,
        # and put frequency first: (batch, 40, channels * frames, 1).
        X = Zxx[:, :, :, :40]
        X = tf.reshape(X, [tf.shape(X)[0], -1, 40])
        X = tf.transpose(X, perm=[0, 2, 1])
        X = tf.expand_dims(X, axis=3)

        # min max scaling (per instance)
        # The spatial sizes are hard-coded (rather than read from tf.shape) so
        # that the output keeps a fully specified shape apart from the batch.
        original_shape = [tf.shape(X)[0], 40, 48, 1]
        X = tf.reshape(X, [original_shape[0], -1])
        X_max = tf.math.reduce_max(X, axis=1, keepdims=True)
        X_min = tf.math.reduce_min(X, axis=1, keepdims=True)
        # divide_no_nan yields 0 instead of NaN when an instance is constant
        # (max == min), which would otherwise poison the whole batch.
        X = tf.math.divide_no_nan(tf.math.subtract(X, X_min), tf.math.subtract(X_max, X_min))
        X = tf.reshape(X, original_shape)

        return X


class AutoSelect(tf.keras.Model):
    """Source-activity classifier with a learnable (or random) source selection.

    Source activity is weighted per source, either by a sigmoid Dense layer
    (learned selection) or by Dropout (random selection). It is then projected
    to sensor space with the fixed ``forward_matrix``, reduced to the channels
    in ``select_channels``, converted to a min-max scaled spectrogram and
    classified by a small CNN with a single sigmoid output.

    Args:
        select_channels: Sensor indices kept after the forward projection.
            ``Stft_Min_Max`` emits a fixed-size output that corresponds to
            three channels.
        forward_matrix: Array of shape (sensors, sources) mapping source
            activity to sensor signals.
        random_select: If True, use Dropout(0.5) on the sources instead of the
            learned sigmoid weighting.
    """

    def __init__(self, select_channels, forward_matrix, random_select):
        super(AutoSelect, self).__init__()
        # preprocessing
        self.select_channels = select_channels
        # Stored transposed, i.e. (sources, sensors), so that it can be
        # right-multiplied onto (batch, time, sources) activity.
        self.forward_matrix = tf.transpose(tf.constant(forward_matrix), perm=[1, 0])
        self.concatenate = Concatenate(axis=1)
        self.stft_min_max = Stft_Min_Max()
        self.dropout = Dropout(0.5)
        self.source_select = Dense(forward_matrix.shape[1], activation='sigmoid')
        self.random_select = random_select

        # classifier
        self.conv1 = Conv2D(filters=4, kernel_size=(3, 3), strides=(1, 1), padding='same', activation="selu")
        self.bn1 = BatchNormalization()
        self.mp1 = MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding="valid")
        self.conv2 = Conv2D(filters=8, kernel_size=(3, 3), strides=(1, 1), padding='same', activation="selu")
        self.bn2 = BatchNormalization()
        self.mp2 = MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding="valid")
        self.flatten = Flatten()
        self.dense1 = Dense(50, activation="selu")
        self.dense2 = Dense(1, activation="sigmoid")

    def call(self, inputs, training=None):
        """Return the class probability for a batch of source activity.

        Args:
            inputs: Tensor of shape (batch, sources, time). The shape comments
                below use the sizes of this notebook (7981 sources, 500
                samples, 22 sensors).
            training: Keras training flag, forwarded to the Dropout and
                BatchNormalization layers.
        """
        # preprocessing
        x = tf.transpose(inputs, perm=[0, 2, 1])  # (n, 500, 7981)

        if self.random_select:
            x = self.dropout(x, training=training)  # (n, 500, 7981)
        else:
            source_select = self.source_select(x)  # (n, 500, 7981)
            x = x * source_select                  # (n, 500, 7981)

        # tf.matmul needs both operands in the same dtype; the forward matrix
        # keeps the dtype of the NumPy array it was built from (often float64),
        # so it is cast to the dtype of the activations.
        x = tf.matmul(x, tf.cast(self.forward_matrix, x.dtype))  # (n, 500, 22)
        x = tf.transpose(x, perm=[0, 2, 1])                      # (n, 22, 500)
        # Keep only the selected channels, in the given order.
        picked = [tf.expand_dims(x[:, channel, :], axis=1) for channel in self.select_channels]
        x = self.concatenate(picked)              # (n, 3, 500)
        x = self.stft_min_max(x)                  # (n, 40, 48, 1)

        # classifier
        x = self.conv1(x)
        x = self.bn1(x, training=training)
        x = self.mp1(x)
        x = self.conv2(x)
        x = self.bn2(x, training=training)
        x = self.mp2(x)
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dense2(x)

        return x

In [None]:
"""
labels
left (class 0) right (class 1) foot (class 2) tongue (class 3)

channels
c3(7) cz(9) c4(11)
"""

results = {"A01": {}, "A02": {}, "A03": {}, "A04": {}, "A05": {}, "A06": {}, "A07": {}, "A08": {}, "A09": {}}
labels = {"left": 0, "right": 1}
select_channels = [7, 9, 11]
debug = True
training = True


def _load_fold_tensor(path):
  """Return the "data" array of the .npz archive at `path` as a float32 tensor.

  The archive is opened in a `with` block so that its file handle is closed as
  soon as the array has been read.
  """
  # NOTE(review): allow_pickle=True lets a crafted archive execute arbitrary
  # code on load; only use it on files produced by this project.
  with np.load(path, allow_pickle=True) as archive:
    return tf.constant(archive["data"].astype(np.float32))


# Subjects to process: one entry per key of `results`.
data_list = list(results.keys())

# To process a single subject instead:
# data_list = ["A09"]

for data_name in data_list:
  # load data from external storage
  directory_path = op.join(EXTERNAL_STORAGE_PATH, "all motor region", "data", "source activity", data_name)

  # Metric sums over the folds; turned into means after the fold loop.
  accuracy = 0
  precision = 0
  recall = 0
  f1 = 0
  kappa = 0

  # Fold files are numbered 1..n_splits.
  for counter in range(1, n_splits + 1):
    # A new model is built for every fold; clearing the Keras global state
    # keeps it from accumulating across folds and subjects.
    tf.keras.backend.clear_session()

    X_train = _load_fold_tensor(op.join(directory_path, str(counter)+"_train_X.npz"))
    X_test = _load_fold_tensor(op.join(directory_path, str(counter)+"_test_X.npz"))
    Y_train = _load_fold_tensor(op.join(directory_path, str(counter)+"_train_Y.npz"))
    Y_test = _load_fold_tensor(op.join(directory_path, str(counter)+"_test_Y.npz"))

    if debug:
      print("shape of X_train and Y_train: " + str(X_train.shape) + " " + str(Y_train.shape))
      print("shape of X_test and Y_test: " + str(X_test.shape) + " " + str(Y_test.shape))

    model = AutoSelect(select_channels, forward_matrix, False)
    optimizer = Adam(learning_rate=1e-5)
    if training:
      # Train a fresh model on this fold and log the run for TensorBoard.
      log_dir = DIRECTORY_PATH + "/logs/" + data_name + "/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
      tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
      model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])
      model.fit(X_train, Y_train, validation_data=(X_test, Y_test), batch_size=32, epochs=200, callbacks=[tensorboard_callback], verbose=1)
    else:
      # Load the pretrained weights and freeze the model: evaluation only.
      # NOTE(review): this is the same weights folder the plain-CNN cells load;
      # confirm it actually holds AutoSelect weights.
      model.load_weights(DIRECTORY_PATH + "/models/" + "A09_0.9183/")
      model.trainable = False
      model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])

    # Threshold the sigmoid output at 0.5 to get binary predictions, then
    # score the fold. The scoring is shared by the training and evaluation paths.
    Y_hat = model.predict(X_test)
    Y_hat = (Y_hat >= 0.5)
    fold_accuracy = accuracy_score(Y_test, Y_hat)
    accuracy += fold_accuracy
    precision += precision_score(Y_test, Y_hat, average="macro")
    recall += recall_score(Y_test, Y_hat, average="macro")
    f1 += f1_score(Y_test, Y_hat, average="macro")
    kappa += cohen_kappa_score(Y_test, Y_hat)

    if training:
      # save model (the folder name carries the first six characters of the fold accuracy)
      model.save_weights(DIRECTORY_PATH + "/models/" + data_name + "_" + str(fold_accuracy)[:6] + "/")

    # Release the fold data before the next fold is loaded.
    del X_train, Y_train, X_test, Y_test
    gc.collect()

  accuracy /= n_splits
  precision /= n_splits
  recall /= n_splits
  f1 /= n_splits
  kappa /= n_splits
  if debug:
    print("accuracy: " + str(accuracy))
    print("precision: " + str(precision))
    print("recall: " + str(recall))
    print("f1: " + str(f1))
    print("kappa: " + str(kappa))

  results[data_name]["accuracy"] = accuracy
  results[data_name]["precision"] = precision
  results[data_name]["recall"] = recall
  results[data_name]["f1"] = f1
  results[data_name]["kappa"] = kappa

In [None]:
# Calculate average performance
# Only subjects whose metrics have been filled in are averaged, and the divisor
# is the number of such subjects, so a run restricted to a subset of subjects
# still produces a valid mean.
metric_names = ["accuracy", "precision", "recall", "f1", "kappa"]
evaluated = [value for value in results.values()
             if all(name in value for name in metric_names)]
if not evaluated:
  raise ValueError("results holds no evaluated subject; run the classification cell first")

average_accuracy = sum(value["accuracy"] for value in evaluated) / len(evaluated)
average_precision = sum(value["precision"] for value in evaluated) / len(evaluated)
average_recall = sum(value["recall"] for value in evaluated) / len(evaluated)
average_f1 = sum(value["f1"] for value in evaluated) / len(evaluated)
average_kappa = sum(value["kappa"] for value in evaluated) / len(evaluated)

print("average accuracy: " + str(average_accuracy))
print("average precision: " + str(average_precision))
print("average recall: " + str(average_recall))
print("average f1: " + str(average_f1))
print("average kappa: " + str(average_kappa))