# Import Libraries

In [None]:
import mne
import numpy as np
import os
import os.path as op
import matplotlib.pyplot as plt
import nibabel as nib
from mne.datasets import sample
from mne.minimum_norm import make_inverse_operator, apply_inverse_epochs, apply_inverse, apply_inverse_raw
from mne.datasets import fetch_fsaverage
from mne.decoding import CSP, UnsupervisedSpatialFilter
import scipy.io
from scipy.io import loadmat
from scipy.spatial import Delaunay
from scipy import stats
import PIL
from PIL import Image
import datetime
import time
import tensorflow as tf
from tensorflow import keras
from keras.preprocessing.image import ImageDataGenerator
from keras import Sequential, Model
from keras.layers import Conv1D, MaxPool1D, AveragePooling1D, Conv2D, DepthwiseConv2D, SeparableConv2D, MaxPool2D, AveragePooling2D, GlobalAveragePooling2D, Dense, Activation, Flatten, Concatenate, BatchNormalization, LayerNormalization, Dropout, Input
from keras.constraints import max_norm
from keras.layers.merge import concatenate
from tensorflow.keras.optimizers import Adam
from tensorboard.backend.event_processing import event_accumulator
import pandas as pd
# Load the TensorBoard notebook extension
%load_ext tensorboard
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, cohen_kappa_score
from sklearn.model_selection import StratifiedKFold, KFold, train_test_split
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.pipeline import make_pipeline
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import gc
import json
import multiprocessing
from scipy.spatial import KDTree
from nibabel.nifti1 import Nifti1Image
from nilearn import plotting

%matplotlib inline
#%matplotlib qt

DIRECTORY_PATH = os.getcwd()
# Raw string: "\M" is not a valid escape sequence, so the plain literal emits a
# DeprecationWarning (SyntaxWarning on Python 3.12+) even though it currently
# evaluates to the same text.
EXTERNAL_STORAGE_PATH = r"E:\Motor Imagery"
# Sub-folder of EXTERNAL_STORAGE_PATH that receives the reconstructed data.
RECONSTRUCT_SAVE_FOLDER = "all motor region"
# Number of cross-validation folds (cf. the n_splits argument of
# apply_inverse_and_forward_kfold).
n_splits = 5

# Force TensorFlow to use the CPU when facing GPU memory issues.
# NOTE(review): this is set after `import tensorflow`; it only takes effect if
# TensorFlow has not initialised CUDA yet — confirm no GPU op ran before this.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

# Alternative to the line above: keep the GPU but let TensorFlow allocate
# memory on demand.
# gpus = tf.config.list_physical_devices('GPU')
# if gpus:
#   try:
#     # Currently, memory growth needs to be the same across GPUs
#     for gpu in gpus:
#       tf.config.experimental.set_memory_growth(gpu, True)
#     logical_gpus = tf.config.list_logical_devices('GPU')
#     print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
#   except RuntimeError as e:
#     # Memory growth must be set before GPUs have been initialized
#     print(e)

# Silence MNE's informational logging.
mne.set_log_level("ERROR")

# Preprocess Data

In [None]:
# GDF channel label -> 10-20 electrode name, so that a standard montage can be
# attached after renaming. The three EOG channels keep their names.
channels_mapping = {
    "EEG-Fz": "Fz",
    "EEG-0": "FC3",
    "EEG-1": "FC1",
    "EEG-2": "FCz",
    "EEG-3": "FC2",
    "EEG-4": "FC4",
    "EEG-5": "C5",
    "EEG-C3": "C3",
    "EEG-6": "C1",
    "EEG-Cz": "Cz",
    "EEG-7": "C2",
    "EEG-C4": "C4",
    "EEG-8": "C6",
    "EEG-9": "CP3",
    "EEG-10": "CP1",
    "EEG-11": "CPz",
    "EEG-12": "CP2",
    "EEG-13": "CP4",
    "EEG-14": "P1",
    "EEG-Pz": "Pz",
    "EEG-15": "P2",
    "EEG-16": "POz",
    "EOG-left": "EOG-left",
    "EOG-central": "EOG-central",
    "EOG-right": "EOG-right",
}

# Renamed channel -> MNE channel type: the "EOG-*" channels are "eog", all
# others are "eeg". Built from channels_mapping so both stay in the same order.
channels_type_mapping = {
    renamed: ("eog" if renamed.startswith("EOG") else "eeg")
    for renamed in channels_mapping.values()
}

In [None]:
# Directory containing the MRIcron templates. Both templates are read from the
# same folder; the Brodmann atlas used to be referenced without the "C:" drive
# prefix, which on Windows resolves against the *current* drive and fails
# whenever the notebook runs from another drive (e.g. E:).
MRICRON_TEMPLATE_DIR = "C:/Users/ivanlim/Downloads/MRIcron_windows/MRIcron/Resources/templates"
img = nib.load(MRICRON_TEMPLATE_DIR + "/brodmann.nii.gz")
ch2_img = nib.load(MRICRON_TEMPLATE_DIR + "/ch2.nii.gz")

brodmann_data = img.get_fdata()
# Areas 3, 1 and 2 – Primary somatosensory cortex in the postcentral gyrus (frequently referred to as Areas 3, 1, 2 by convention)
# Area 4 – Primary motor cortex
# Area 5 – Superior parietal lobule
# Area 6 – Premotor cortex and Supplementary Motor Cortex (Secondary Motor Cortex) (Supplementary motor area)
# Area 7 – Visuo-Motor Coordination
selected_area = [1, 2, 3, 4, 5, 6, 7]
#selected_area = [4]

# "old": a single mask (hence a single convex hull) covering all selected areas
# "new": one mask per selected area; the per-area hulls are combined later
reconstruct_mode = "new"

# Flattened label volume; every mask below is a boolean vector over all voxels.
brodmann_flat = brodmann_data.reshape(-1)
brodmann_motor = []
for area in selected_area:
    area_mask = brodmann_flat == area
    if reconstruct_mode == "old" and brodmann_motor:
        # merge into the single combined mask (logical OR)
        brodmann_motor[0] |= area_mask
    else:
        brodmann_motor.append(area_mask)

print(brodmann_motor)
print("brodmann template shape: " + str(brodmann_data.shape))
if reconstruct_mode == "old":
    print("chosen points: " + str(np.sum(brodmann_motor[0])))
else:
    # Union of all per-area masks; starts empty so an empty selection reports 0.
    chosen_points = np.zeros(brodmann_flat.shape, dtype=bool)
    for selected_region in brodmann_motor:
        print(np.sum(selected_region))
        chosen_points |= selected_region
    print("chosen points: " + str(np.sum(chosen_points)))

shape, affine = img.shape[:3], img.affine
# (i, j, k) index of every voxel, arranged as an array of shape (*shape, 3) ...
coords = np.moveaxis(np.indices(shape), 0, -1)
# ... mapped through the image affine to world (mm) coordinates.
# NOTE(review): later compared with mne.vertex_to_mni output, i.e. the template
# is presumably in MNI space — confirm.
mm_coords = nib.affines.apply_affine(affine, coords)

def in_hull(p, hull):
    """
    Return a boolean mask telling which points of `p` lie inside `hull`.

    `p` holds the `N x K` coordinates of `N` query points in `K` dimensions.
    `hull` is either an already-built scipy.spatial.Delaunay triangulation or
    the `M x K` coordinates of `M` points, which are triangulated here.
    A point counts as inside when it falls in some simplex of the
    triangulation (find_simplex returns -1 for points outside it).
    """
    triangulation = hull if isinstance(hull, Delaunay) else Delaunay(hull)
    simplex_ids = triangulation.find_simplex(p)
    return simplex_ids >= 0

# Motor-region vertex masks for the left / right hemisphere; filled in by
# apply_inverse_and_forward_kfold the first time it runs.
my_left_points = my_right_points = None

In [None]:
"""
labels utility function
"""
def load_subject_labels(name="A01E.mat", dir="drive/Shareddrives/Motor Imagery/BCI competition IV dataset/2a/2a true_labels/"):
  """Load the true class labels of one recording from a MATLAB file.

  Args:
    name: file name of the label file, e.g. "A01E.mat".
    dir: directory that contains `name`; may be given with or without a
      trailing path separator.

  Returns:
    The file's "classlabel" variable flattened to a 1-D numpy array.
  """
  # os.path.join instead of `dir + name`: os.walk yields directories without a
  # trailing separator, which used to produce paths like ".../labelsA01E.mat".
  data = scipy.io.loadmat(os.path.join(dir, name))["classlabel"].reshape(-1)
  return data

def load_all_true_labels(dataset_path):
  """Load every ".mat" label file found (recursively) below `dataset_path`.

  Returns:
    dict mapping file name (e.g. "A01E.mat") to the labels returned by
    load_subject_labels.
  """
  data = {}
  for root, _dirs, files in os.walk(dataset_path):
    for file in files:
      # skip stray non-label files (e.g. .DS_Store, desktop.ini), which
      # loadmat cannot parse
      if not file.lower().endswith(".mat"):
        continue
      # os.walk gives `root` without a trailing separator; add one so the
      # path is valid however load_subject_labels joins dir and name.
      data[file] = load_subject_labels(name=file, dir=os.path.join(root, ""))
  return data

"""
plot graph utility function
"""
def plot_average_graph(subject_name="A01T.gdf", Class="left", filter_channels=None):
  """Plot the trial-averaged signal of one class, one subplot per channel.

  Reads the module-level `data` and `ch_names`.
  NOTE(review): expects `data[subject_name]` to be a load_epoch_data result
  (with an "epoch_data" key); load_subject nests those one level deeper under
  a frequency key ("original", "8-12", ...) — confirm what is passed in.

  Args:
    subject_name: key into `data`.
    Class: class to plot ("left", "right", "foot", "tongue" or "unknown").
    filter_channels: optional list of channel names; one subplot per entry
      (names not found in `ch_names` leave their subplot empty). When None,
      every channel is plotted.

  Raises:
    ValueError: if `Class` has no trials.
  """
  average = {"left": None, "right": None, "foot": None, "tongue": None, "unknown": None}
  for event_class, event_data in data[subject_name]["epoch_data"].items():
    # `event_data != []` is not an emptiness test for numpy arrays (it raises
    # on shape mismatch in recent NumPy); use the trial count instead.
    if len(event_data) > 0:
      # mean over trials, transposed so that x[:, channel] is one time course
      average[event_class] = np.transpose(np.mean(event_data, axis=0))

  x = average[Class]
  if x is None:
    raise ValueError("no trials for class %r in %s" % (Class, subject_name))

  if filter_channels is None:
    row_channels = list(range(x.shape[1]))
    height = 21
  else:
    row_channels = []
    for wanted in filter_channels:
      matches = [c for c in range(x.shape[1]) if ch_names[c] == wanted]
      row_channels.append(matches[0] if matches else None)
    height = 10.5

  # squeeze=False keeps `axs` 2-D, so a single subplot is still indexable.
  fig, axs = plt.subplots(len(row_channels), squeeze=False, gridspec_kw={'hspace': 0})
  fig.set_size_inches(37, height)
  for ax, channel in zip(axs[:, 0], row_channels):
    if channel is None:
      continue
    ax.title.set_text(ch_names[channel])
    ax.title.set_size(20)
    ax.title.set_y(0.7)
    ax.plot(range(x.shape[0]), x[:, channel])
    # dashed reference line at sample 250 of the trial window
    ax.axvline(x=250, color="r", linestyle='--')
  plt.tight_layout()

def plot_multiple_graph(subject_name="A02T.gdf", classes=("left", "right", "foot", "tongue"), filter_channels=None):
  """Overlay the trial-averaged signals of several classes, one subplot per channel.

  Reads the module-level `data` and `ch_names`.
  NOTE(review): expects `data[subject_name]` to be a load_epoch_data result
  (with an "epoch_data" key) — confirm against what load_subject returns.

  Args:
    subject_name: key into `data`.
    classes: classes to overlay, each drawn in its own colour.
    filter_channels: optional list of channel names; one subplot per entry
      (names not found in `ch_names` leave their subplot empty). When None,
      every channel is plotted.

  Raises:
    ValueError: if `classes` is empty or one of them has no trials.
  """
  if not classes:
    raise ValueError("classes must not be empty")

  average = {"left": None, "right": None, "foot": None, "tongue": None, "unknown": None}
  for event_class, event_data in data[subject_name]["epoch_data"].items():
    # `event_data != []` is not an emptiness test for numpy arrays.
    if len(event_data) > 0:
      average[event_class] = np.transpose(np.mean(event_data, axis=0))

  # One colour per class. The previous literal listed "tongue" twice, which
  # silently made tongue yellow and left "unknown" without a colour.
  color = {"left": "b", "right": "g", "foot": "c", "tongue": "m", "unknown": "y"}
  x = [average[Class] for Class in classes]
  missing = [Class for Class, avg in zip(classes, x) if avg is None]
  if missing:
    raise ValueError("no trials for classes %r in %s" % (missing, subject_name))

  n_channels = x[0].shape[1]
  if filter_channels is None:
    row_channels = list(range(n_channels))
    height = 21
  else:
    row_channels = []
    for wanted in filter_channels:
      matches = [c for c in range(n_channels) if ch_names[c] == wanted]
      row_channels.append(matches[0] if matches else None)
    height = 10.5

  # squeeze=False keeps `axs` 2-D, so a single subplot is still indexable.
  fig, axs = plt.subplots(len(row_channels), squeeze=False, gridspec_kw={'hspace': 0})
  fig.set_size_inches(37, height)
  for ax, channel in zip(axs[:, 0], row_channels):
    if channel is None:
      continue
    ax.title.set_text(ch_names[channel])
    ax.title.set_size(20)
    ax.title.set_y(0.7)
    # dashed reference line at sample 250 of the trial window
    ax.axvline(x=250, color="r", linestyle='--')
    for Class, avg in zip(classes, x):
      ax.plot(range(avg.shape[0]), avg[:, channel], color=color[Class])
  plt.tight_layout()

"""
load data function
"""
def load_epoch_data(raw, name, debug=None):
  """Cut fixed-length trials out of a raw recording and group them by class.

  Args:
    raw: mne Raw object; trials are sliced from its first 22 channels.
    name: recording file name, e.g. "A01E.gdf". Its first four characters
      select the entry of the module-level `true_labels` dict used to label
      trials cued as "unknown".
    debug: None, "important" or "all" for progressively more diagnostic
      output; any truthy value also opens `raw.plot()`.

  Returns:
    dict with keys
      "raw": `raw`,
      "info": `raw.info`,
      "class_info": human-readable count of every event code,
      "epoch_data": class name -> array of trials, shape
        (n_trials, 22, stop - start); a class without trials is np.array([]).
  """
  rule = "-------------------------------------------------------------------------------------------------------"
  if debug:
    raw.plot()
  subject_data = {}

  subject_data["raw"] = raw
  subject_data["info"] = raw.info
  if debug == "all":
    print(rule)
    for key, item in raw.info.items():
      print(key, item)
    print(rule)

  # Event codes found in the GDF annotations and their meaning; the order
  # also fixes the line order of `class_info`.
  event_descriptions = [
    (276, "Idling EEG (eyes open)"),
    (277, "Idling EEG (eyes closed)"),
    (768, "Start of a trial"),
    (769, "Cue onset left (class 1)"),
    (770, "Cue onset right (class 2)"),
    (771, "Cue onset foot (class 3)"),
    (772, "Cue onset tongue (class 4)"),
    (783, "Cue unknown"),
    (1023, "Rejected trial"),
    (1072, "Eye movements"),
    (32766, "Start of a new run"),
  ]
  custom_mapping = {str(code): code for code, _ in event_descriptions}
  events_from_annot, event_dict = mne.events_from_annotations(raw, event_id=custom_mapping)
  event_onsets = events_from_annot[:, 0]  # sample index of each event
  event_codes = events_from_annot[:, 2]   # event code of each event

  # This check used to compare against " all" (leading space) and never ran.
  if debug == "all":
    print(rule)
    print(event_dict)
    print(events_from_annot)
    print(rule)

    print(rule)
    # NOTE(review): pairs events and annotations by position; they only line
    # up if every annotation was converted to an event — confirm.
    for event_row, annotation in zip(events_from_annot, raw.annotations):
      print(event_row, annotation)
    print(rule)

  class_info = "\n".join(
    description + ": " + str(np.count_nonzero(event_codes == code))
    for code, description in event_descriptions)
  subject_data["class_info"] = class_info

  if debug == "all" or debug == "important":
    print(rule)
    print(class_info)
    print(rule)

  # Sample positions of "Rejected trial" markers. A cue is skipped when such a
  # marker sits exactly 500 samples before it.
  rejected_onsets = set(event_onsets[event_codes == 1023].tolist())
  class_dict = {"left": 769, "right": 770, "foot": 771, "tongue": 772, "unknown": 783}
  # true label value -> class name, used for "unknown" cues
  class_dict_labels = {1: "left", 2: "right", 3: "foot", 4: "tongue"}
  raw_data = raw.get_data()
  # Trial window relative to the cue onset, in samples.
  # NOTE(review): older comments described this as cue+0.1s .. cue+2.1s;
  # 10 samples are 0.1 s only at 100 Hz — check against raw.info["sfreq"].
  start = 10
  stop = 510

  trials = {"left": [], "right": [], "foot": [], "tongue": [], "unknown": []}
  for event_class, event_id in class_dict.items():
    onsets = event_onsets[event_codes == event_id]
    if event_class == "unknown" and len(onsets) > 0:
      # Unknown cues are assigned a class from `true_labels`, indexed by the
      # cue's position (rejected cues included).
      subject_true_labels = true_labels[name[:4] + ".mat"]
    for i, onset in enumerate(onsets):
      # exclude artifact
      if int(onset) - 500 in rejected_onsets:
        continue
      if event_class == "unknown":
        target = class_dict_labels[int(subject_true_labels[i])]
      else:
        target = event_class
      trials[target].append(np.array(raw_data[:22, onset + start:onset + stop]))

  if debug == "all" or debug == "important":
    print(rule)
    for key, class_trials in trials.items():
      print(key, len(class_trials))
    print(rule)

  subject_data["epoch_data"] = {key: np.array(class_trials) for key, class_trials in trials.items()}

  return subject_data

def load_subject(name="A01T.gdf", dir='drive/Shareddrives/Motor Imagery/BCI competition IV dataset/2a/BCICIV_2a_gdf/', filter_bank=None, debug=None):
  """Load one GDF recording, preprocess it and cut trials per frequency band.

  Preprocessing: rename channels (channels_mapping), set channel types
  (channels_type_mapping), attach the standard 10-20 montage, drop the EOG
  channels and add an average-reference projector.

  Args:
    name: GDF file name.
    dir: directory containing `name` (trailing separator optional).
    filter_bank: optional ascending list of band edges, e.g. [4, 8, 12]; each
      consecutive pair (4-8, 8-12) yields a 5th-order Butterworth band-pass.
    debug: forwarded to load_epoch_data.

  Returns:
    dict mapping "original" (unfiltered) and "low-high" band keys to the
    result of load_epoch_data.
  """
  # Load data
  raw = mne.io.read_raw_gdf(os.path.join(dir, name), preload=True)
  # Rename channels
  raw.rename_channels(channels_mapping)
  # Set channels types
  raw.set_channel_types(channels_type_mapping)
  # Read and set the EEG electrode locations
  ten_twenty_montage = mne.channels.make_standard_montage('standard_1020')
  raw.set_montage(ten_twenty_montage)
  # Drop eog channels
  raw.drop_channels(["EOG-left", "EOG-central", "EOG-right"])
  # Set common average reference
  raw.set_eeg_reference('average', projection=True, verbose=False)

  subject_data_dict = {}

  # -1 marks the unfiltered pass; list() also accepts tuples.
  edges = [-1] + ([] if filter_bank is None else list(filter_bank))
  print("filter bank: ", edges)
  passes = list(zip(edges[:-1], edges[1:]))
  if not passes:
    # Without a filter bank the loop used to run zero times and {} was
    # returned; always produce the unfiltered "original" entry.
    passes = [(-1, None)]

  for low, high in passes:
    print("current frequency: ", low, " to ", high)

    # filter frequency
    filter_raw = raw.copy()
    if low != -1:
      iir_params = dict(order=5, ftype='butter')
      filter_raw.filter(low, high, method="iir", iir_params=iir_params)
    subject_data = load_epoch_data(filter_raw, name, debug)
    if low == -1:
      subject_data_dict["original"] = subject_data
    else:
      subject_data_dict[str(low) + "-" + str(high)] = subject_data

  return subject_data_dict

def load_all_subject(dataset_path, filter_bank=None):
  """Load every ".gdf" recording found (recursively) below `dataset_path`.

  Returns:
    dict mapping file name (e.g. "A01T.gdf") to the result of load_subject.
  """
  data = {}
  for root, _dirs, files in os.walk(dataset_path):
    for file in files:
      # skip stray non-recording files, which read_raw_gdf cannot parse
      if not file.lower().endswith(".gdf"):
        continue
      # os.walk gives `root` without a trailing separator; add one so the
      # path is valid however load_subject joins dir and name.
      data[file] = load_subject(name=file, dir=os.path.join(root, ""), filter_bank=filter_bank)
  return data

"""
create mne epochs data structure from numpy array
merge training and evaluation data
return original and bandpass-filtered data
"""
def _merge_session_trials(data, subject, freq, event):
  """Stack the non-empty T and E trials of `event` (band `freq`) on axis 0, or return None."""
  parts = []
  for file_name in (subject, subject[:3] + "E.gdf"):
    trials = data[file_name][freq]["epoch_data"][event]
    # size check: `.any()` tested for non-zero values, not for presence of trials
    if np.size(trials) > 0:
      parts.append(trials)
  if not parts:
    return None
  # concatenate instead of np.append(None, ...), which crashed when only the
  # evaluation session had trials
  return np.concatenate(parts, axis=0)


def create_epochs(data):
  """Merge each subject's training and evaluation trials into mne EpochsArray objects.

  Args:
    data: dict keyed by file name ("A01T.gdf", "A01E.gdf", ...) whose values
      map "original" and band keys (e.g. "8-12") to load_epoch_data results.
      Keys containing "E" are treated as evaluation sessions and merged into
      the matching training file.

  Returns:
    (subjects_data, subjects_filter_data):
      subjects_data[subject_id][event] -> EpochsArray of unfiltered trials;
      subjects_filter_data[subject_id][band][event] -> same per band (a
      subject appears only if it has at least one band).
    subject_id is the first three characters of the file name.
  """
  training_files = [subject for subject in data if "E" not in subject]

  subjects_data = {}
  for subject in training_files:
    info = data[subject]["original"]["info"]
    epochs_data = {}
    for event in data[subject]["original"]["epoch_data"]:
      merged = _merge_session_trials(data, subject, "original", event)
      if merged is not None:
        epochs_data[event] = mne.EpochsArray(merged, info, verbose=False)
    subjects_data[subject[:3]] = epochs_data

  subjects_filter_data = {}
  for subject in training_files:
    bands = {}
    for freq in data[subject]:
      if freq == "original":
        continue
      bands[freq] = {}
      info = data[subject][freq]["info"]
      for event in data[subject][freq]["epoch_data"]:
        merged = _merge_session_trials(data, subject, freq, event)
        if merged is not None:
          bands[freq][event] = mne.EpochsArray(merged, info, verbose=False)
    if bands:
      subjects_filter_data[subject[:3]] = bands

  return subjects_data, subjects_filter_data

"""
Create source activity and reconstructed EEG for each subject.

For each subject, the trials of the requested events (default: left, right)
are split into train and test sets with stratified k-fold.
For every fold, the noise covariance is computed on the train set only, and the
resulting inverse operator is applied to both the train and the test set.
Source activity (motor-region vertices only) is obtained by applying the inverse
operator to the epochs; reconstructed EEG is obtained by projecting that
motor-only source activity back to the sensors with the leadfield.
Both are saved to disk.
"""
def _kfold_stack_events(subject_epochs, events):
    """Return (X, Y, info): trials of the wanted events, their float labels (index in `events`) and the first entry's info."""
    X_parts, Y_parts = [], []
    info = None
    # iterate in the dict order of `subject_epochs`: fold assignment depends on it
    for event, event_epochs in subject_epochs.items():
        if info is None:
            info = event_epochs.info
        for label, wanted in enumerate(events):
            if event == wanted:
                print(event)
                trials = event_epochs.get_data()
                X_parts.append(trials)
                Y_parts.append(np.zeros(len(trials)) + label)
    if not X_parts:
        return np.array([]), np.array([]), info
    return np.concatenate(X_parts, axis=0), np.concatenate(Y_parts, axis=0), info


def _kfold_find_motor_points(my_source):
    """Return boolean masks (left, right) of source vertices inside the Brodmann-area hulls, plotting before/after."""
    mni_lh = mne.vertex_to_mni(my_source.vertices[0], 0, mne_subject)
    print(mni_lh.shape)
    mni_rh = mne.vertex_to_mni(my_source.vertices[1], 1, mne_subject)
    print(mni_rh.shape)
    template_points = mm_coords.reshape(-1, 3)

    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(projection='3d')
    for selected_region in brodmann_motor:
        region = template_points[selected_region]
        ax.scatter(region[:, 0], region[:, 1], region[:, 2], s=15, marker='|')
    ax.scatter(mni_lh[:, 0], mni_lh[:, 1], mni_lh[:, 2], s=15, marker='_')
    ax.scatter(mni_rh[:, 0], mni_rh[:, 1], mni_rh[:, 2], s=15, marker='_')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    plt.show()

    left_points = None
    right_points = None
    for selected_region in brodmann_motor:
        print(np.sum(selected_region))
        # triangulate each region once and reuse it for both hemispheres
        hull = Delaunay(template_points[selected_region])
        in_left = in_hull(mni_lh, hull)
        in_right = in_hull(mni_rh, hull)
        if left_points is None:
            left_points, right_points = in_left, in_right
        else:
            left_points = left_points | in_left
            right_points = right_points | in_right

    mni_left_motor = mne.vertex_to_mni(my_source.vertices[0][left_points], 0, mne_subject)
    print(mni_left_motor.shape)
    mni_right_motor = mne.vertex_to_mni(my_source.vertices[1][right_points], 1, mne_subject)
    print(mni_right_motor.shape)

    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(projection='3d')
    ax.scatter(mni_lh[:, 0], mni_lh[:, 1], mni_lh[:, 2], s=15, marker='|')
    ax.scatter(mni_rh[:, 0], mni_rh[:, 1], mni_rh[:, 2], s=15, marker='_')
    ax.scatter(mni_left_motor[:, 0], mni_left_motor[:, 1], mni_left_motor[:, 2], s=15, marker='o')
    ax.scatter(mni_right_motor[:, 0], mni_right_motor[:, 1], mni_right_motor[:, 2], s=15, marker='^')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    plt.show()

    return left_points, right_points


def _kfold_save_split(stcs, Y_split, subject, counter, split, leadfield, save_inverse, save_forward):
    """Save motor-region source activity and motor-only reconstructed EEG of one split ("train" or "test")."""
    if save_inverse:
        # slice source activity data: rows [:n_left] are the left hemisphere,
        # the last n_right rows the right hemisphere
        left_hemi_data = np.array([source.data[:len(source.vertices[0])][my_left_points] for source in stcs])
        right_hemi_data = np.array([source.data[-len(source.vertices[1]):][my_right_points] for source in stcs])
        source_activity_path = op.join(EXTERNAL_STORAGE_PATH, RECONSTRUCT_SAVE_FOLDER, "data", "source activity", subject)
        os.makedirs(source_activity_path, exist_ok=True)
        np.savez_compressed(op.join(source_activity_path, str(counter) + "_" + split + "_X.npz"),
                            data=np.concatenate([left_hemi_data, right_hemi_data], axis=1))
        np.savez_compressed(op.join(source_activity_path, str(counter) + "_" + split + "_Y.npz"), data=Y_split)

    if save_forward:
        # project only the motor-region sources back to the sensors
        reconstructed_eeg_data = []
        for source in stcs:
            n_left = len(source.vertices[0])
            n_right = len(source.vertices[1])
            motor_source = np.zeros_like(source.data)
            motor_source[:n_left][my_left_points] = source.data[:n_left][my_left_points]
            motor_source[-n_right:][my_right_points] = source.data[-n_right:][my_right_points]
            reconstructed_eeg_data.append(np.dot(leadfield, motor_source))
        reconstructed_eeg_path = op.join(EXTERNAL_STORAGE_PATH, RECONSTRUCT_SAVE_FOLDER, "data", "reconstructed eeg", subject)
        os.makedirs(reconstructed_eeg_path, exist_ok=True)
        np.savez_compressed(op.join(reconstructed_eeg_path, str(counter) + "_" + split + "_X.npz"),
                            data=np.array(reconstructed_eeg_data))
        np.savez_compressed(op.join(reconstructed_eeg_path, str(counter) + "_" + split + "_Y.npz"), data=Y_split)


def apply_inverse_and_forward_kfold(epochs, n_splits=5, save_inverse=True, save_forward=True, events=("left", "right"), subjects=None):
    """Run the k-fold inverse/forward reconstruction described above and save every fold.

    Args:
        epochs: dict subject -> dict event -> mne Epochs (as from create_epochs).
        n_splits: number of stratified folds (shuffled, random_state=0).
        save_inverse: save motor-region source activity.
        save_forward: save motor-only reconstructed EEG.
        events: events to use; a trial's label is the event's index here.
        subjects: subjects to process (default: all keys of `epochs`).

    Uses the module-level trans, src, bem, mne_subject, brodmann_motor and
    mm_coords, and sets the globals my_left_points / my_right_points on the
    first fold if they are still None.
    """
    global my_left_points, my_right_points

    if subjects is None:
        subjects = epochs.keys()

    method = "sLORETA"
    snr = 3.
    lambda2 = 1. / snr ** 2

    for subject in subjects:
        print(subject)
        X, Y, info = _kfold_stack_events(epochs[subject], events)

        # The forward model depends only on `info` and trans/src/bem, not on
        # the fold, so it is built once per subject instead of once per fold.
        fwd = mne.make_forward_solution(info, trans=trans, src=src,
                        bem=bem, eeg=True, meg=False, mindist=5.0, n_jobs=1, verbose=False)
        fwd_fixed = mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                     use_cps=True, verbose=False)
        leadfield = fwd_fixed['sol']['data']

        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0)
        for counter, (train_index, test_index) in enumerate(skf.split(X, Y), start=1):
            X_train = mne.EpochsArray(X[train_index], info, verbose=False)
            X_test = mne.EpochsArray(X[test_index], info, verbose=False)
            Y_train, Y_test = Y[train_index], Y[test_index]

            # NOTE(review): EpochsArray starts at tmin=0 by default, so tmax=0.
            # limits the noise covariance to the first sample of each epoch —
            # confirm this is intended.
            noise_cov = mne.compute_covariance(X_train, tmax=0., method=['shrunk', 'empirical'], rank=None, verbose=False)
            # pass a copy so the shared forward model cannot be altered between folds
            inverse_operator = make_inverse_operator(info, fwd.copy(), noise_cov, loose=0.2, depth=0.8, verbose=False)

            stc_train = apply_inverse_epochs(X_train, inverse_operator, lambda2,
                                             method=method, pick_ori="normal", verbose=True)

            # get motor region points (once)
            if my_left_points is None and my_right_points is None:
                my_left_points, my_right_points = _kfold_find_motor_points(stc_train[0])

            print("Leadfield size : %d sensors x %d dipoles" % leadfield.shape)
            print(stc_train[0].data.shape)

            _kfold_save_split(stc_train, Y_train, subject, counter, "train", leadfield, save_inverse, save_forward)
            del stc_train
            gc.collect()

            stc_test = apply_inverse_epochs(X_test, inverse_operator, lambda2,
                                            method=method, pick_ori="normal", verbose=True)
            _kfold_save_split(stc_test, Y_test, subject, counter, "test", leadfield, save_inverse, save_forward)

            del X_train, X_test, Y_train, Y_test, stc_test
            gc.collect()
            
            
def get_inverse_and_forward_information(epochs, events=("left", "right"), subject="A01"):
    """Build the forward/inverse models for one subject and locate its motor-region vertices.

    All trials of `events` are averaged, source-localised with sLORETA, and the
    source vertices falling inside the Brodmann-area hulls (module-level
    brodmann_motor / mm_coords) are selected. Two 3-D plots show the template
    vs. source positions and the selected vertices.

    Args:
        epochs: dict subject -> dict event -> mne Epochs.
        events: events whose trials are used.
        subject: subject key into `epochs` (previously hard-coded to "A01").

    Returns:
        dict with keys "my_left_points" / "my_right_points" (boolean masks over
        the left / right hemisphere vertices), "stc_data_shape", "leadfield",
        "left_vertices", "right_vertices" and "inverse_operator".

    Raises:
        ValueError: if the subject has none of the requested events.
    """
    info = None
    X_parts = []
    for event, event_epochs in epochs[subject].items():
        if info is None:
            info = event_epochs.info
        for wanted in events:
            if event == wanted:
                X_parts.append(event_epochs.get_data())
    if not X_parts:
        raise ValueError("subject %r has none of the events %r" % (subject, tuple(events)))
    X_epochs = mne.EpochsArray(np.concatenate(X_parts, axis=0), info, verbose=False)
    X_evoked = X_epochs.average().pick("eeg")

    # NOTE(review): EpochsArray starts at tmin=0 by default, so tmax=0. limits
    # the noise covariance to the first sample of each epoch — confirm.
    noise_cov = mne.compute_covariance(X_epochs, tmax=0., method=['shrunk', 'empirical'], rank=None, verbose=False)
    fwd = mne.make_forward_solution(info, trans=trans, src=src,
                    bem=bem, eeg=True, meg=False, mindist=5.0, n_jobs=1, verbose=False)
    fwd_fixed = mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                 use_cps=True, verbose=False)
    leadfield = fwd_fixed['sol']['data']
    inverse_operator = make_inverse_operator(info, fwd, noise_cov, loose=0.2, depth=0.8, verbose=False)

    method = "sLORETA"
    snr = 3.
    lambda2 = 1. / snr ** 2
    stc = apply_inverse(X_evoked, inverse_operator, lambda2, method=method, pick_ori="normal", verbose=True)

    # get motor region points
    mni_lh = mne.vertex_to_mni(stc.vertices[0], 0, mne_subject)
    print(mni_lh.shape)
    mni_rh = mne.vertex_to_mni(stc.vertices[1], 1, mne_subject)
    print(mni_rh.shape)
    template_points = mm_coords.reshape(-1, 3)

    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(projection='3d')
    ax.scatter(mni_lh[:, 0], mni_lh[:, 1], mni_lh[:, 2], s=15, marker='_', alpha=0.5)
    ax.scatter(mni_rh[:, 0], mni_rh[:, 1], mni_rh[:, 2], s=15, marker='_', alpha=0.5)
    for selected_region in brodmann_motor:
        region = template_points[selected_region]
        ax.scatter(region[:, 0], region[:, 1], region[:, 2], s=15, marker='|', alpha=0.2)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    #fig.savefig(DIRECTORY_PATH+"/brodmann template and source positions.png", dpi=1200)
    plt.show()

    my_left_points = None
    my_right_points = None
    for selected_region in brodmann_motor:
        print(np.sum(selected_region))
        # triangulate each region once and reuse it for both hemispheres
        hull = Delaunay(template_points[selected_region])
        in_left = in_hull(mni_lh, hull)
        in_right = in_hull(mni_rh, hull)
        if my_left_points is None:
            my_left_points, my_right_points = in_left, in_right
        else:
            my_left_points = my_left_points | in_left
            my_right_points = my_right_points | in_right

    mni_left_motor = mne.vertex_to_mni(stc.vertices[0][my_left_points], 0, mne_subject)
    print(mni_left_motor.shape)
    mni_right_motor = mne.vertex_to_mni(stc.vertices[1][my_right_points], 1, mne_subject)
    print(mni_right_motor.shape)

    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(projection='3d')
    ax.scatter(mni_lh[:, 0], mni_lh[:, 1], mni_lh[:, 2], s=15, marker='|', alpha=0.3)
    ax.scatter(mni_rh[:, 0], mni_rh[:, 1], mni_rh[:, 2], s=15, marker='_', alpha=0.3)
    ax.scatter(mni_left_motor[:, 0], mni_left_motor[:, 1], mni_left_motor[:, 2], s=15, marker='o', alpha=0.5)
    ax.scatter(mni_right_motor[:, 0], mni_right_motor[:, 1], mni_right_motor[:, 2], s=15, marker='^', alpha=0.5)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    #fig.savefig(DIRECTORY_PATH+"/united_convex_hull.png", dpi=1200)
    plt.show()

    print("Leadfield size : %d sensors x %d dipoles" % leadfield.shape)
    print(stc.data.shape)

    information = {"my_left_points": my_left_points,
                   "my_right_points": my_right_points,
                   "stc_data_shape": stc.data.shape,
                   "leadfield": leadfield,
                   "left_vertices": stc.vertices[0],
                   "right_vertices": stc.vertices[1],
                   "inverse_operator": inverse_operator}

    return information

def EEGInception(num_class = 2, num_channel=3, num_samples=500):
    """Build a two-stage inception-style CNN for inputs of shape (num_channel, num_samples, 1).

    Stage 1: three parallel temporal convolutions (8 filters, kernel widths 64,
    32 and 16), each followed by a depthwise convolution spanning all
    ``num_channel`` rows (depth multiplier 2). The branches are concatenated
    on the feature axis and average-pooled by 4 along time.

    Stage 2: three parallel temporal convolutions (8 filters, widths 16, 8, 4)
    on the pooled stage-1 features, concatenated and pooled by 2. Then two more
    conv units (12 filters of width 8, then 6 filters of width 4), each
    followed by pooling by 2, and finally a dense head.

    Every convolution is followed by BatchNormalization, ELU and Dropout(0.3).

    Returns:
        An uncompiled ``Model``: one sigmoid unit when ``num_class == 2``,
        otherwise ``num_class`` softmax units.
    """
    inputs = Input(shape = (num_channel, num_samples, 1))
    drop_rate = 0.3

    def conv_unit(tensor, conv):
        # conv -> BatchNorm -> ELU -> Dropout; the conv layer is constructed by
        # the caller just before this call, so layer creation order is
        # conv, BN, activation, dropout for every unit.
        tensor = conv(tensor)
        tensor = BatchNormalization()(tensor)
        tensor = Activation('elu')(tensor)
        return Dropout(drop_rate)(tensor)

    # Stage 1: temporal filter of each width, then a spatial depthwise filter.
    stage1 = []
    for width in (64, 32, 16):
        branch = conv_unit(inputs, Conv2D(8, (1, width), padding='same'))
        branch = conv_unit(branch, DepthwiseConv2D((num_channel, 1), padding='valid', depth_multiplier = 2))
        stage1.append(branch)

    merged = Concatenate(axis = -1)(stage1)
    merged = AveragePooling2D((1, 4))(merged)

    # Stage 2: three more temporal widths applied to the pooled stage-1 output.
    stage2 = [conv_unit(merged, Conv2D(8, (1, width), padding='same')) for width in (16, 8, 4)]

    features = Concatenate(axis = -1)(stage2)
    features = AveragePooling2D((1, 2))(features)

    features = conv_unit(features, Conv2D(12, (1, 8), padding='same'))
    features = AveragePooling2D((1, 2))(features)

    features = conv_unit(features, Conv2D(6, (1, 4), padding='same'))
    features = AveragePooling2D((1, 2))(features)

    embedded = Flatten()(features)

    if num_class == 2:
        head = Dense(1, activation = 'sigmoid')
    else:
        head = Dense(num_class, activation = 'softmax')

    return Model(inputs = inputs, outputs = head(embedded))

"""
Total params: 699,685
Trainable params: 0
Non-trainable params: 699,685
"""
def create_model(num_class = 2, model_name="default", num_channel=3):
    if model_name == "default":
        if num_class == 2:
            model = tf.keras.models.Sequential([
                Conv2D(filters=4, kernel_size=(3, 3), strides=(1, 1), padding='same', activation="selu"),
                BatchNormalization(),
                MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding="valid"),
                Conv2D(filters=8, kernel_size=(3, 3), strides=(1, 1), padding='same', activation="selu"),
                BatchNormalization(),
                MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding="valid"),
                Flatten(),
                Dense(50, activation="selu"),
                Dense(1, activation="sigmoid")
            ])
        else:
            model = tf.keras.models.Sequential([
                Conv2D(filters=4, kernel_size=(3, 3), strides=(1, 1), padding='same', activation="selu"),
                BatchNormalization(),
                MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding="valid"),
                Conv2D(filters=8, kernel_size=(3, 3), strides=(1, 1), padding='same', activation="selu"),
                BatchNormalization(),
                MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding="valid"),
                Flatten(),
                Dense(50, activation="selu"),
                Dense(num_class, activation="softmax")
            ])
    elif model_name == "eegnet":
        if num_class == 2:
            model = tf.keras.models.Sequential([
                Conv2D(16, (1, 64), use_bias = False, activation = 'linear', padding='same', name = 'Spectral_filter'),
                BatchNormalization(),
                DepthwiseConv2D((3, 1), use_bias = False, padding='valid', depth_multiplier = 2, activation = 'linear',
                depthwise_constraint = tf.keras.constraints.MaxNorm(max_value=1), name = 'Spatial_filter'),
                BatchNormalization(),
                Activation('elu'),
                AveragePooling2D((1, 4)),
                Dropout(0.5),
                SeparableConv2D(32, (1, 16), use_bias = False, activation = 'linear', padding = 'same'),
                BatchNormalization(),
                Activation('elu'),
                AveragePooling2D((1, 8)),
                Dropout(0.5),
                Flatten(),
                Dense(1, activation = 'sigmoid', kernel_constraint = max_norm(0.25))
            ])
        else:
            model = tf.keras.models.Sequential([
                Conv2D(16, (1, 64), use_bias = False, activation = 'linear', padding='same', name = 'Spectral_filter'),
                BatchNormalization(),
                DepthwiseConv2D((3, 1), use_bias = False, padding='valid', depth_multiplier = 2, activation = 'linear',
                depthwise_constraint = tf.keras.constraints.MaxNorm(max_value=1), name = 'Spatial_filter'),
                BatchNormalization(),
                Activation('elu'),
                AveragePooling2D((1, 4)),
                Dropout(0.5),
                SeparableConv2D(32, (1, 16), use_bias = False, activation = 'linear', padding = 'same'),
                BatchNormalization(),
                Activation('elu'),
                AveragePooling2D((1, 8)),
                Dropout(0.5),
                Flatten(),
                Dense(num_class, activation = 'softmax', kernel_constraint = max_norm(0.25))
            ])
    elif model_name == "eegnet_inception":
        model = EEGInception(num_class=num_class, num_channel=num_channel)

    return model

def stft_min_max(X, Y, debug=True):
    """Turn a signal batch into per-instance min-max scaled STFT magnitude images.

    Steps:
      1. STFT magnitude (frame_length=256, frame_step=16).
      2. Keep the first 40 frequency bins.
      3. Lay the frames of all rows side by side.
      4. Scale each instance to [0, 1] by its own min and max.

    Args:
        X: batch passed to ``tf.signal.stft``. It must be rank 3, because the
            slicing below indexes a rank-4 STFT; presumably the layout is
            (n, channels, samples).
        Y: labels. Returned unchanged; only their shape is printed when ``debug``.
        debug: print the intermediate shapes.

    Returns:
        ``(X, Y)``, where X has shape (n, 40, rows * frames, 1). An instance
        whose values are all equal (e.g. all-zero input) is mapped to zeros
        instead of NaN.
    """
    Zxx = tf.abs(tf.signal.stft(X, frame_length=256, frame_step=16))

    if debug:
        print("shape of X and Y: " + str(X.shape) + " " + str(Y.shape))
        print("shape of Zxx: " + str(Zxx.shape))

    X = Zxx[:, :, :, :40]
    X = tf.reshape(X, [X.shape[0], -1, 40])
    X = tf.transpose(X, perm=[0, 2, 1])
    X = tf.expand_dims(X, axis=3)

    # min max scaling (per instance)
    original_shape = X.shape
    X = tf.reshape(X, [original_shape[0], -1])
    X_max = tf.math.reduce_max(X, axis=1, keepdims=True)
    X_min = tf.math.reduce_min(X, axis=1, keepdims=True)
    # divide_no_nan returns 0 where max == min; plain division would give 0/0 = NaN.
    X = tf.math.divide_no_nan(tf.math.subtract(X, X_min), tf.math.subtract(X_max, X_min))
    X = tf.reshape(X, original_shape)

    if debug:
        print("shape of X and Y: " + str(X.shape) + " " + str(Y.shape))

    return X, Y

In [None]:
class Stft_Min_Max(keras.layers.Layer):
    """Layer version of ``stft_min_max``: STFT magnitude image, min-max scaled per instance.

    The input is a rank-3 batch, presumably (n, channels, samples). The output
    shape is (n, 40, num_of_channels * frames, 1), where ``frames`` is the STFT
    frame count. That count is 16 for 500 samples with frame_length=256 and
    frame_step=16.

    Args:
        num_of_channels: number of input rows. The non-batch output dimensions
            are built from Python ints, as before, so the output keeps a static
            width.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    _FRAME_LENGTH = 256
    _FRAME_STEP = 16
    _NUM_BINS = 40
    _DEFAULT_FRAMES = 16  # used when the signal length is not statically known

    def __init__(self, num_of_channels, **kwargs):
        super(Stft_Min_Max, self).__init__(**kwargs)
        self.num_of_channels = num_of_channels

    def _num_frames(self, inputs):
        """STFT frame count from the static signal length (no end padding); 16 if unknown."""
        length = inputs.shape[-1]
        if length is None:
            return self._DEFAULT_FRAMES
        return max(0, 1 + (int(length) - self._FRAME_LENGTH) // self._FRAME_STEP)

    def call(self, inputs):
        magnitude = tf.abs(tf.signal.stft(inputs, frame_length=self._FRAME_LENGTH, frame_step=self._FRAME_STEP))
        X = magnitude[:, :, :, :self._NUM_BINS]
        X = tf.reshape(X, [tf.shape(X)[0], -1, self._NUM_BINS])
        X = tf.transpose(X, perm=[0, 2, 1])
        X = tf.expand_dims(X, axis=3)

        # min max scaling (per instance); non-batch dims given as Python ints.
        output_shape = [tf.shape(X)[0], self._NUM_BINS, self._num_frames(inputs) * self.num_of_channels, 1]
        X = tf.reshape(X, [output_shape[0], -1])
        X_max = tf.math.reduce_max(X, axis=1, keepdims=True)
        X_min = tf.math.reduce_min(X, axis=1, keepdims=True)
        # A constant instance (e.g. every source masked to zero) has max == min;
        # divide_no_nan yields 0 there instead of propagating NaN into training.
        X = tf.math.divide_no_nan(tf.math.subtract(X, X_min), tf.math.subtract(X_max, X_min))
        return tf.reshape(X, output_shape)
    
def my_mask(x):
    """Binarise ``x``: 1.0 where an entry is >= 1, 0.0 elsewhere (float32)."""
    is_selected = tf.math.greater_equal(x, 1)
    return tf.cast(is_selected, dtype=tf.float32)

def diff_mask(mask_op):
    """Wrap ``mask_op`` with a straight-through gradient.

    The forward pass returns ``mask_op(values)``. The backward pass passes the
    upstream gradient through unchanged (multiplied by ones), so that a
    non-differentiable op such as a threshold can still train its input.
    """
    @tf.custom_gradient
    def _straight_through(values):
        forward = mask_op(values)

        def _identity_grad(upstream):
            return upstream * tf.ones_like(values)

        return forward, _identity_grad

    return _straight_through

"""
Total params: 64,404,027 / 42,089,607
Trainable params: 64,404,003 / 42,089,607
Non-trainable params: 24
"""
class AutoSelect(tf.keras.Model):
    def __init__(self, select_channels, forward_matrix, random_select, use_mask, kqv, model_name="default"):
        super(AutoSelect, self).__init__()
        # preprocessing
        self.select_channels = select_channels
        self.forward_matrix = tf.transpose(tf.constant(forward_matrix), perm=[1, 0])
        self.concatenate = Concatenate(axis=1)
        self.stft_min_max = Stft_Min_Max(len(select_channels))
        self.dropout = Dropout(0.5)
        self.sigmoid = tf.keras.layers.Activation('sigmoid')
        #self.tanh = tf.keras.layers.Activation('tanh')
        #self.relu = tf.keras.layers.Activation('relu')
        
        # random select
        self.random_select = random_select
        
        # model option
        self.model_name = model_name
        
        # mask attention
        self.use_mask = use_mask
        self.mask = tf.Variable(np.ones((1, forward_matrix.shape[1])), dtype=tf.float32)
        #self.mask = tf.Variable(np.random.rand(1, forward_matrix.shape[1])+0.5, dtype=tf.float32)
        
        # correlation attention
        self.kqv = kqv
        self.concatenate_1D = Concatenate(axis=1)
        self.p1D_1 = MaxPool1D(pool_size=3, strides=1, padding='same', data_format='channels_first')
        self.conv1D_1 = Conv1D(16, 5, padding="same", activation="relu", data_format='channels_first')
        self.conv1D_2 = Conv1D(16, 10, padding="same", activation="relu", data_format='channels_first')
        self.conv1D_3 = Conv1D(16, 25, padding="same", activation="relu", data_format='channels_first')
        self.conv1D_4 = Conv1D(1, 5, padding="same", activation="sigmoid", data_format='channels_first')
        
        # dense attention
        self.source_select = Dense(forward_matrix.shape[1], activation=None)
        
        # classifier
        self.conv1 = Conv2D(filters=4, kernel_size=(3, 3), strides=(1, 1), padding='same', activation="selu")
        self.bn1 = BatchNormalization()
        self.mp1 = MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding="valid")
        self.conv2 = Conv2D(filters=8, kernel_size=(3, 3), strides=(1, 1), padding='same', activation="selu")
        self.bn2 = BatchNormalization()
        self.mp2 = MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding="valid")
        self.flatten = Flatten()
        self.dense1 = Dense(50, activation="selu")
        self.dense2 = Dense(1, activation="sigmoid")
        
        # eegnet
        if model_name == "eegnet":
            self.eegnet_model = Sequential([
                Conv2D(16, (1, 64), use_bias = False, activation = 'linear', padding='same', name = 'Spectral_filter'),
                BatchNormalization(),
                DepthwiseConv2D((3, 1), use_bias = False, padding='valid', depth_multiplier = 2, activation = 'linear',
                depthwise_constraint = max_norm(max_value=1), name = 'Spatial_filter'),
                BatchNormalization(),
                Activation('elu'),
                AveragePooling2D((1, 4)),
                Dropout(0.5),
                SeparableConv2D(32, (1, 16), use_bias = False, activation = 'linear', padding = 'same'),
                BatchNormalization(),
                Activation('elu'),
                AveragePooling2D((1, 8)),
                Dropout(0.5),
                Flatten(),
                Dense(1, activation = 'sigmoid', kernel_constraint = max_norm(0.25))
            ])
            
        # eeg-inception
        if model_name == "eegnet_inception":
            self.eegnet_inception_model = create_model(num_class = 2, model_name=model_name, num_channel=len(select_channels))

    def call(self, inputs):                                                                         # (n, 7981/6433, 500)
        # preprocessing
        x = tf.transpose(inputs, perm=[0, 2, 1])                                                    # (n, 500, 7981/6433)
            
        if self.random_select:
            x = self.dropout(x)                                                                     # (n, 500, 7981/6433)
        else:
            if self.use_mask:
                #tf.print(self.mask)
                mask = diff_mask(my_mask)(self.mask)                                                # (1, 7981/6433)
                x = x * mask                                                                        # (n, 500, 7981/6433)                  
            else:
                if self.kqv:
                    None
                else:
                    source_select = self.source_select(x)                                           # (n, 500, 7981/6433)
                    #source_select = self.dropout(source_select)   
                    source_select = self.sigmoid(source_select)
                    x = x * source_select                                                           # (n, 500, 7981/6433)
        
        x = tf.matmul(x, self.forward_matrix)                                                       # (n, 500, 22)
        x = tf.transpose(x, perm=[0, 2, 1])                                                         # (n, 22, 500)
        x = tf.gather(x, indices=self.select_channels, axis=1)                                      # (n, chan_num, 500)
        
#         # Normalization
#         x = tf.transpose(x, perm=[1, 0, 2])                                                         # (chan_num, n, 500)
#         x_mean = tf.reduce_mean(x, axis=[1, 2], keepdims=True)                                      # (chan_num, 1, 1)
#         x_std = tf.math.reduce_std(x, axis=[1, 2], keepdims=True)                                   # (chan_num, 1, 1)
#         x = (x-x_mean)/x_std                                                                        # (chan_num, n, 500)
#         x = tf.transpose(x, perm=[1, 0, 2])                                                         # (n, chan_num, 500)
        
        if self.model_name == "default":
            x = self.stft_min_max(x)                                                                # (n, 40, 48, 1)
            # classifier
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.mp1(x)
            x = self.conv2(x)
            x = self.bn2(x)
            x = self.mp2(x)
            x = self.flatten(x)
            x = self.dense1(x)
            x = self.dense2(x)
        elif self.model_name == "eegnet":
            x = tf.expand_dims(x, axis=-1)                                                          # (n, chan_num, 500, 1)
            x = self.eegnet_model(x)
        elif self.model_name == "eegnet_inception":
            x = tf.expand_dims(x, axis=-1)                                                          # (n, chan_num, 500, 1)
            x = self.eegnet_inception_model(x)
        
        return x
    
"""
Total params: 64,404,027 / 42,089,607
Trainable params: 64,404,003 / 42,089,607
Non-trainable params: 24
"""
class AutoSelectData(tf.keras.Model):
    def __init__(self, select_channels, forward_matrix, random_select, use_mask, kqv, model_name="default"):
        super(AutoSelectData, self).__init__()
        # preprocessing
        self.select_channels = select_channels
        self.forward_matrix = tf.transpose(tf.constant(forward_matrix), perm=[1, 0])
        self.concatenate = Concatenate(axis=1)
        self.stft_min_max = Stft_Min_Max(len(select_channels))
        self.dropout = Dropout(0.5)
        self.sigmoid = tf.keras.layers.Activation('sigmoid')
        #self.tanh = tf.keras.layers.Activation('tanh')
        #self.relu = tf.keras.layers.Activation('relu')
        
        # random select
        self.random_select = random_select
        
        # model option
        self.model_name = model_name
        
        # mask attention
        self.use_mask = use_mask
        self.mask = tf.Variable(np.ones((1, forward_matrix.shape[1])), dtype=tf.float32)
        #self.mask = tf.Variable(np.random.rand(1, forward_matrix.shape[1])+0.5, dtype=tf.float32)
        
        # correlation attention
        self.kqv = kqv
        self.concatenate_1D = Concatenate(axis=1)
        self.p1D_1 = MaxPool1D(pool_size=3, strides=1, padding='same', data_format='channels_first')
        self.conv1D_1 = Conv1D(16, 5, padding="same", activation="relu", data_format='channels_first')
        self.conv1D_2 = Conv1D(16, 10, padding="same", activation="relu", data_format='channels_first')
        self.conv1D_3 = Conv1D(16, 25, padding="same", activation="relu", data_format='channels_first')
        self.conv1D_4 = Conv1D(1, 5, padding="same", activation="sigmoid", data_format='channels_first')
        
        # dense attention
        self.source_select = Dense(forward_matrix.shape[1], activation=None)
        
        # classifier
        self.conv1 = Conv2D(filters=4, kernel_size=(3, 3), strides=(1, 1), padding='same', activation="selu")
        self.bn1 = BatchNormalization()
        self.mp1 = MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding="valid")
        self.conv2 = Conv2D(filters=8, kernel_size=(3, 3), strides=(1, 1), padding='same', activation="selu")
        self.bn2 = BatchNormalization()
        self.mp2 = MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding="valid")
        self.flatten = Flatten()
        self.dense1 = Dense(50, activation="selu")
        self.dense2 = Dense(1, activation="sigmoid")
        
        # eegnet
        if model_name == "eegnet":
            self.eegnet_model = Sequential([
                Conv2D(16, (1, 64), use_bias = False, activation = 'linear', padding='same', name = 'Spectral_filter'),
                BatchNormalization(),
                DepthwiseConv2D((3, 1), use_bias = False, padding='valid', depth_multiplier = 2, activation = 'linear',
                depthwise_constraint = max_norm(max_value=1), name = 'Spatial_filter'),
                BatchNormalization(),
                Activation('elu'),
                AveragePooling2D((1, 4)),
                Dropout(0.5),
                SeparableConv2D(32, (1, 16), use_bias = False, activation = 'linear', padding = 'same'),
                BatchNormalization(),
                Activation('elu'),
                AveragePooling2D((1, 8)),
                Dropout(0.5),
                Flatten(),
                Dense(1, activation = 'sigmoid', kernel_constraint = max_norm(0.25))
            ])

    def call(self, inputs):                                                                         # (n, 7981/6433, 500)
        # preprocessing
        x = tf.transpose(inputs, perm=[0, 2, 1])                                                    # (n, 500, 7981/6433)
            
        if self.random_select:
            x = self.dropout(x)                                                                     # (n, 500, 7981/6433)
        else:
            if self.use_mask:
                #tf.print(self.mask)
                mask = diff_mask(my_mask)(self.mask)                                                # (1, 7981/6433)
                x = x * mask                                                                        # (n, 500, 7981/6433)                  
            else:
                if self.kqv:
                    None
                else:
                    source_select = self.source_select(x)                                           # (n, 500, 7981/6433)
                    #source_select = self.dropout(source_select)   
                    source_select = self.sigmoid(source_select)
                    x = x * source_select                                                           # (n, 500, 7981/6433)
        
        x = tf.matmul(x, self.forward_matrix)                                                       # (n, 500, 22)
        x = tf.transpose(x, perm=[0, 2, 1])                                                         # (n, 22, 500)
        x = tf.gather(x, indices=self.select_channels, axis=1)                                      # (n, chan_num, 500)
    
        return x

## Masked Filter Bank Raw

In [None]:
"""
Create source activity from raw EEG signals for each subject
Select the most relevant sources from the source space by applying the trainable layer
Apply the forward matrix to obtain the EEG signals in channels space
Apply bandpass filtering in specified frequency ranges and epoching the raw EEG signals

Save the bandpass filtered files to disk
"""
def load_raw_subject(name="A01T.gdf", dir='drive/Shareddrives/Motor Imagery/BCI competition IV dataset/2a/BCICIV_2a_gdf/', filter_bank=None, debug=None):
  # Load data
  raw = mne.io.read_raw_gdf(dir + name, preload=True)
  # Rename channels
  raw.rename_channels(channels_mapping)
  # Set channels types
  raw.set_channel_types(channels_type_mapping)
  # Set montage
  # Read and set the EEG electrode locations
  ten_twenty_montage = mne.channels.make_standard_montage('standard_1020')
  raw.set_montage(ten_twenty_montage)
  # Drop eog channels
  raw.drop_channels(["EOG-left", "EOG-central", "EOG-right"])
  # Set common average reference
  raw.set_eeg_reference('average', projection=True, verbose=False)

  return raw


def apply_all_inverse_raw_bandpass_filter(dataset_path, epochs, subjects, filter_bank=None, save_filter=True):
    """Run ``apply_inverse_raw_bandpass_filter`` on every matching recording under ``dataset_path``.

    The motor-region forward matrix is built once, as
    [left-hemisphere motor columns | right-hemisphere motor columns] of the
    leadfield returned by ``get_inverse_and_forward_information(epochs)``.

    Args:
        dataset_path: root directory, walked recursively.
        epochs: mapping keyed by the 3-character subject id (file-name prefix).
        subjects: container of subject ids to process (e.g. {"A01", "A02"}).
        filter_bank: forwarded to ``load_raw_subject``.
        save_filter: forwarded to ``apply_inverse_raw_bandpass_filter``.
    """
    information = get_inverse_and_forward_information(epochs)
    print(information)

    leadfield = information["leadfield"]
    # Left-hemisphere sources are the first columns, right-hemisphere the last.
    left_forward = leadfield[:, :len(information["left_vertices"])][:, information["my_left_points"]]
    right_forward = leadfield[:, -len(information["right_vertices"]):][:, information["my_right_points"]]
    print(left_forward.shape)
    print(right_forward.shape)
    forward_matrix = np.concatenate((left_forward, right_forward), axis=1)
    print(forward_matrix.shape)

    del information, leadfield, left_forward, right_forward
    gc.collect()

    for root, _dirs, files in os.walk(dataset_path):
        for file in files:
            file_stem = file.split(".")[0]
            subject_id = file_stem[:3]
            if subject_id not in subjects:
                continue
            print("current file:", file)
            # os.walk roots carry no trailing separator; add one so that a
            # dir + name style path in the loader still yields a valid path.
            raw = load_raw_subject(name=file, dir=op.join(root, ""), filter_bank=filter_bank)
            apply_inverse_raw_bandpass_filter(raw, epochs[subject_id], forward_matrix, file_stem,
                                              save_filter=save_filter, events=["left", "right"])

def _stack_event_epochs(epoch, events):
    """Stack the trials of ``events`` into (X, Y, info).

    Y holds each trial's event index in ``events`` as a float. ``info`` is taken
    from the first entry of ``epoch``.
    """
    info = None
    chunks, labels = [], []
    for event in epoch.keys():
        if info is None:
            info = epoch[event].info
        for label, wanted in enumerate(events):
            if event == wanted:
                print(event)
                data = epoch[event].get_data()
                chunks.append(data)
                labels.append(np.zeros(len(data)) + label)
    if not chunks:
        raise ValueError(f"none of the events {list(events)} were found in the epochs")
    return np.concatenate(chunks, axis=0), np.concatenate(labels, axis=0), info


def _select_motor_points(my_source):
    """Plot the source space and return the (left, right) ``in_hull`` results for the motor regions.

    The per-region ``in_hull`` results over ``brodmann_motor`` are combined
    with ``+=``. Relies on the notebook globals ``mne_subject``,
    ``brodmann_motor``, ``mm_coords`` and ``in_hull``.
    """
    mni_lh = mne.vertex_to_mni(my_source.vertices[0], 0, mne_subject)
    print(mni_lh.shape)
    mni_rh = mne.vertex_to_mni(my_source.vertices[1], 1, mne_subject)
    print(mni_rh.shape)

    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(projection='3d')
    for selected_region in brodmann_motor:
        region = mm_coords.reshape(-1, 3)[selected_region]
        ax.scatter(region[:, 0], region[:, 1], region[:, 2], s=15, marker='|')
    ax.scatter(mni_lh[:, 0], mni_lh[:, 1], mni_lh[:, 2], s=15, marker='_')
    ax.scatter(mni_rh[:, 0], mni_rh[:, 1], mni_rh[:, 2], s=15, marker='_')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    plt.show()

    left_points = None
    right_points = None
    for selected_region in brodmann_motor:
        print(np.sum(selected_region))
        region = mm_coords.reshape(-1, 3)[selected_region]
        if left_points is None:
            left_points = in_hull(mni_lh, region)
            right_points = in_hull(mni_rh, region)
        else:
            left_points += in_hull(mni_lh, region)
            right_points += in_hull(mni_rh, region)

    mni_left_motor = mne.vertex_to_mni(my_source.vertices[0][left_points], 0, mne_subject)
    print(mni_left_motor.shape)
    mni_right_motor = mne.vertex_to_mni(my_source.vertices[1][right_points], 1, mne_subject)
    print(mni_right_motor.shape)

    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(projection='3d')
    ax.scatter(mni_lh[:, 0], mni_lh[:, 1], mni_lh[:, 2], s=15, marker='|')
    ax.scatter(mni_rh[:, 0], mni_rh[:, 1], mni_rh[:, 2], s=15, marker='_')
    ax.scatter(mni_left_motor[:, 0], mni_left_motor[:, 1], mni_left_motor[:, 2], s=15, marker='o')
    ax.scatter(mni_right_motor[:, 0], mni_right_motor[:, 1], mni_right_motor[:, 2], s=15, marker='^')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    plt.show()

    return left_points, right_points


def _fold_weights_path(weight_path, fold):
    """Return the weights path of ``fold`` inside ``weight_path``.

    The matching entry is the first one whose name, up to the first "_", equals
    the fold number. The returned path has a trailing "/".
    """
    for weight_file in os.listdir(weight_path):
        if weight_file.split("_")[0] == str(fold):
            return os.path.join(weight_path, weight_file) + "/"
    # Without this, a missing fold would silently load whichever entry was listed last.
    raise FileNotFoundError(f"no weights for fold {fold} in {weight_path!r}")


def _save_reconstructed_eeg(variant, file_name, fold, forward_matrix, source, info, show_shape=False):
    """Project ``source`` to sensors with ``forward_matrix`` and save it.

    The destination is
    ``EXTERNAL_STORAGE_PATH/all motor region filter/<variant>/reconstructed eeg/<file_name>/<fold>.npz``.
    """
    out_dir = op.join(EXTERNAL_STORAGE_PATH, "all motor region filter", variant, "reconstructed eeg", file_name)
    os.makedirs(out_dir, exist_ok=True)
    motor_eeg = np.dot(forward_matrix, source)
    if show_shape:
        print("motor eeg shape: ", motor_eeg.shape)
    np.savez_compressed(op.join(out_dir, str(fold) + ".npz"), data=np.array(motor_eeg), info=info)


def apply_inverse_raw_bandpass_filter(raw, epoch, forward_matrix, file_name, save_filter=True, events=("left", "right")):
    """Reconstruct motor-region EEG from ``raw`` once per cross-validation fold.

    Folds come from StratifiedKFold (``n_splits`` folds, shuffled,
    random_state=0) over the trials of ``events`` in ``epoch``. For each fold:
      1. Compute the noise covariance on the training trials (tmax=0) and
         build the forward solution and an inverse operator (loose=0.2,
         depth=0.8).
      2. Apply sLORETA (SNR 3) to the whole continuous ``raw`` and to the
         training epochs. The training-epoch result is used only to locate
         the motor-region sources, once.
      3. Load that fold's trained ``AutoSelectData`` weights and read its
         source mask.
      4. If ``save_filter``, project the motor-region sources to sensors with
         ``forward_matrix`` and save them twice: unweighted ("raw") and
         mask-weighted ("masked raw").

    Relies on these notebook globals: n_splits, trans, src, bem, mne_subject,
    brodmann_motor, mm_coords, in_hull, EXTERNAL_STORAGE_PATH, and
    my_left_points/my_right_points (set here on first use).

    Args:
        raw: continuous recording to reconstruct.
        epoch: mapping event name -> mne Epochs.
        forward_matrix: (sensors, motor sources) leadfield columns of the motor region.
        file_name: recording stem. Its first 3 characters select the weight folder.
        save_filter: write the reconstructed .npz files.
        events: event names to use; a trial's label is its event's index.

    Raises:
        ValueError: if none of ``events`` is present in ``epoch``.
        FileNotFoundError: if a fold has no weights entry.
    """
    global my_left_points, my_right_points

    data_name = file_name[:3]

    use_csp = True
    select_channels = list(range(22)) if use_csp else [7, 9, 11]
    model_name = "eegnet"
    random_select = False
    use_mask = True
    kqv = False

    method = "sLORETA"
    snr = 3.
    lambda2 = 1. / snr ** 2
    weight_path = "D:/forward and inverse results (new)/motor/ml motor/eegnet 22 channels/models/" + data_name + "/"

    X, Y, info = _stack_event_epochs(epoch, events)

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0)
    for counter, (train_index, _test_index) in enumerate(skf.split(X, Y), start=1):
        X_train = mne.EpochsArray(X[train_index], info, verbose=False)

        noise_cov = mne.compute_covariance(X_train, tmax=0., method=['shrunk', 'empirical'], rank=None, verbose=False)
        fwd = mne.make_forward_solution(info, trans=trans, src=src,
                                        bem=bem, eeg=True, meg=False, mindist=5.0, n_jobs=1, verbose=False)
        # Fixed-orientation copy; only used to report the leadfield size below.
        fwd_fixed = mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                                 use_cps=True, verbose=False)
        leadfield = fwd_fixed['sol']['data']
        inverse_operator = make_inverse_operator(info, fwd, noise_cov, loose=0.2, depth=0.8, verbose=False)

        stc_raw = apply_inverse_raw(raw, inverse_operator, lambda2,
                                    method=method, pick_ori="normal", verbose=True)
        stc_train = apply_inverse_epochs(X_train, inverse_operator, lambda2,
                                         method=method, pick_ori="normal", verbose=True)

        # Locate the motor-region source points once; later folds and files reuse them.
        if my_left_points is None and my_right_points is None:
            my_left_points, my_right_points = _select_motor_points(stc_train[0])

        print("Leadfield size : %d sensors x %d dipoles" % leadfield.shape)
        print("stc_train[0] shape: ", stc_train[0].data.shape)

        del stc_train
        gc.collect()

        print("stc raw shape: ", stc_raw.data.shape)

        # load pretrained model
        model = AutoSelectData(select_channels, forward_matrix, random_select, use_mask, kqv, model_name)
        model.load_weights(_fold_weights_path(weight_path, counter))
        model.trainable = False

        # (sources, 1) column, so it broadcasts over the time axis of motor_source.
        mask_weight = model.mask.numpy().T
        print("mask weight shape: ", mask_weight.shape)

        left_hemi_data = stc_raw.data[:len(stc_raw.vertices[0])][my_left_points]
        right_hemi_data = stc_raw.data[-len(stc_raw.vertices[1]):][my_right_points]
        motor_source = np.concatenate((left_hemi_data, right_hemi_data), axis=0)

        if save_filter:
            _save_reconstructed_eeg("raw", file_name, counter, forward_matrix, motor_source, raw.info)

        # NOTE(review): AutoSelect/AutoSelectData apply the mask in their forward
        # pass as my_mask(mask), i.e. thresholded to 0/1 at >= 1. Here the raw,
        # unthresholded mask values are used as weights; confirm this is intended.
        motor_source = motor_source * mask_weight
        if save_filter:
            _save_reconstructed_eeg("masked raw", file_name, counter, forward_matrix, motor_source, raw.info,
                                    show_shape=True)

        del stc_raw
        del motor_source
        gc.collect()
        
def _stack_trials(*trial_arrays):
    """Concatenate the non-empty trial arrays along axis 0; return None when all are empty or missing."""
    # Size (not `.any()`) decides emptiness, so an all-zero trial array still counts as data.
    non_empty = [np.asarray(trials) for trials in trial_arrays if trials is not None and np.size(trials) > 0]
    if not non_empty:
        return None
    if len(non_empty) == 1:
        return non_empty[0]
    return np.concatenate(non_empty, axis=0)


def _merge_session_epochs(training, evaluation):
    """Build ``{event: mne.EpochsArray}`` from one "T" session entry and its matching "E" entry (or None)."""
    evaluation_epochs = evaluation["epoch_data"] if evaluation else {}
    merged = {}
    for event, training_trials in training["epoch_data"].items():
        trials = _stack_trials(training_trials, evaluation_epochs.get(event))
        if trials is not None:
            # Channel info always comes from the training session entry.
            merged[event] = mne.EpochsArray(trials, training["info"], verbose=False)
    return merged


def create_epochs_filter(data):
    """Merge each subject's training ("T") and evaluation ("E") recordings into MNE epochs.

    Parameters
    ----------
    data : dict
        ``{file_name: {counter: {band: subject_data}}}`` as produced by
        ``load_all_subject_filter``; ``band`` is ``"original"`` or ``"<low>-<high>"`` and
        ``subject_data["epoch_data"]`` maps event names to trial arrays.

    Returns
    -------
    tuple of dict
        ``(subjects_data, subjects_filter_data)``, both keyed by the first three characters
        of the file name (e.g. ``"A01"``):
        ``subjects_data[subject][counter][event]`` is the unfiltered ``mne.EpochsArray`` and
        ``subjects_filter_data[subject][counter][band][event]`` the band-passed one.
        Events with no trials in either session are omitted.

    Notes
    -----
    File names containing "E" are treated as evaluation sessions and merged into the
    entry sharing their first three characters. A missing evaluation file, counter or
    band is tolerated (only the training trials are used). Every subject gets an entry
    in both returned dicts, even when it has no filtered bands.
    """
    subjects_data = {}
    subjects_filter_data = {}

    for subject, sessions in data.items():
        if "E" in subject:
            continue
        subject_id = subject[:3]
        evaluation_sessions = data.get(subject_id + "E.gdf", {})

        epochs_data = {}
        filter_epochs_data = {}
        for counter, bands in sessions.items():
            evaluation_bands = evaluation_sessions.get(counter, {})
            epochs_data[counter] = _merge_session_epochs(bands["original"], evaluation_bands.get("original"))
            filter_epochs_data[counter] = {
                freq: _merge_session_epochs(band, evaluation_bands.get(freq))
                for freq, band in bands.items()
                if freq != "original"
            }

        subjects_data[subject_id] = epochs_data
        subjects_filter_data[subject_id] = filter_epochs_data

    return subjects_data, subjects_filter_data
    
def load_epoch_data_filter(raw, masked_raw, name, debug=None):
    """Cut fixed-length, cue-locked trials out of one masked (optionally band-passed) recording.

    Parameters
    ----------
    raw : mne.io.Raw
        Original recording. Supplies ``info`` and the annotations from which cue and
        rejection events are read.
    masked_raw : mne.io.Raw
        Recording whose samples are cut into trials (via ``masked_raw.get_data()``).
    name : str
        Recording file name, e.g. ``"A01T.gdf"``. ``name[:4] + ".mat"`` selects the entry of
        the module-level ``true_labels`` dict used to label "unknown" (783) cues.
    debug : None, "important" or "all", optional
        Any truthy value opens ``raw.plot()``. "important" prints event and per-class trial
        counts; "all" additionally prints ``raw.info`` and the raw event table.

    Returns
    -------
    dict
        ``"raw"`` and ``"info"`` (from ``raw``), ``"class_info"`` (a text summary of how many
        events of each type were found) and ``"epoch_data"``: a dict mapping "left", "right",
        "foot", "tongue" and "unknown" to numpy arrays of trials. Each trial is
        ``raw_data[:22, cue + 10 : cue + 510]``. Unknown cues are filed under their true
        class, so the "unknown" entry stays empty.
    """
    if debug:
        raw.plot()
    separator = "-------------------------------------------------------------------------------------------------------"

    subject_data = {"raw": raw, "info": raw.info}
    if debug == "all":
        print(separator)
        for key, item in raw.info.items():
            print(key, item)
        print(separator)

    # BCI Competition IV 2a annotation codes and their meaning.
    event_descriptions = (
        (276, "Idling EEG (eyes open)"),
        (277, "Idling EEG (eyes closed)"),
        (768, "Start of a trial"),
        (769, "Cue onset left (class 1)"),
        (770, "Cue onset right (class 2)"),
        (771, "Cue onset foot (class 3)"),
        (772, "Cue onset tongue (class 4)"),
        (783, "Cue unknown"),
        (1023, "Rejected trial"),
        (1072, "Eye movements"),
        (32766, "Start of a new run"),
    )
    custom_mapping = {str(code): code for code, _ in event_descriptions}
    events_from_annot, event_dict = mne.events_from_annotations(raw, event_id=custom_mapping)
    event_codes = events_from_annot[:, 2]

    # Previously compared against " all" (leading space), which silently disabled this dump.
    if debug == "all":
        print(separator)
        print(event_dict)
        print(events_from_annot)
        print(separator)

        print(separator)
        # zip stops at the shorter sequence: annotations whose description is not in
        # custom_mapping produce no event row, so the two may differ in length.
        for event, annotation in zip(events_from_annot, raw.annotations):
            print(event, annotation)
        print(separator)

    class_info = "\n".join(
        description + ": " + str(int(np.count_nonzero(event_codes == code)))
        for code, description in event_descriptions
    )
    subject_data["class_info"] = class_info

    if debug in ("all", "important"):
        print(separator)
        print(class_info)
        print(separator)

    epoch_data = {"left": [], "right": [], "foot": [], "tongue": [], "unknown": []}
    rejected_trial = events_from_annot[event_codes == 1023][:, 0]
    class_dict = {"left": 769, "right": 770, "foot": 771, "tongue": 772, "unknown": 783}
    class_dict_labels = {1: "left", 2: "right", 3: "foot", 4: "tongue"}
    raw_data = masked_raw.get_data()
    # Window in samples relative to the cue onset: [cue + 10, cue + 510), i.e. 500 samples.
    # NOTE(review): the previous comments said cue+0.1 s / cue+2.1 s; that only holds at
    # 100 Hz. Confirm the intended offset for this recording's sampling rate.
    start = 10
    stop = 510
    # A trial is rejected when a "Rejected trial" (1023) marker sits 500 samples before its cue.
    trial_start_offset = 500

    for event_class, event_id in class_dict.items():
        current_event = events_from_annot[event_codes == event_id][:, 0]
        kept = ~np.isin(current_event - trial_start_offset, rejected_trial)

        if event_class != "unknown":
            # Only the first 22 channels are kept (presumably the EEG channels — TODO confirm).
            epoch_data[event_class] = np.array(
                [np.array(raw_data[:22, onset + start:onset + stop]) for onset in current_event[kept]]
            )
            continue

        if len(current_event) == 0:
            continue  # nothing to relabel, so the true-label entry is not needed

        subject_true_labels = true_labels[name[:4] + ".mat"]
        # Labels are indexed by the cue's position among all unknown cues (rejected ones included).
        for i in np.flatnonzero(kept):
            onset = current_event[i]
            current_event_data = np.expand_dims(np.array(raw_data[:22, onset + start:onset + stop]), axis=0)
            label = class_dict_labels[subject_true_labels[i]]
            if np.size(epoch_data[label]) == 0:
                epoch_data[label] = current_event_data
            else:
                epoch_data[label] = np.append(epoch_data[label], current_event_data, axis=0)

    if debug in ("all", "important"):
        print(separator)
        for key, trials in epoch_data.items():
            print(key, len(trials))
        print(separator)

    # "unknown" is never filled and is still a list here; make every entry an array.
    subject_data["epoch_data"] = {event_class: np.array(trials) for event_class, trials in epoch_data.items()}

    return subject_data
    
def load_subject_filter(name="A01T.gdf", dir='drive/Shareddrives/Motor Imagery/BCI competition IV dataset/2a/BCICIV_2a_gdf/', filter_bank=None, debug=None):
    """Epoch every saved masked EEG reconstruction of one recording, unfiltered and per frequency band.

    Reconstructions are read from
    ``EXTERNAL_STORAGE_PATH/all motor region filter/masked raw/reconstructed eeg/<subject>/<counter>.npz``
    (the location the source-reconstruction step writes to when ``save_filter`` is set).

    Parameters
    ----------
    name : str
        Recording file name, e.g. ``"A01T.gdf"``. The part before the first "." names the
        subject folder.
    dir : str
        Directory containing the recording; forwarded to ``load_raw_subject``.
    filter_bank : sequence of numbers, optional
        Band edges. Each adjacent pair ``(low, high)`` becomes one order-5 Butterworth IIR
        band-pass. ``None`` epochs only the unfiltered data.
    debug : optional
        Forwarded to ``load_raw_subject`` and ``load_epoch_data_filter``.

    Returns
    -------
    dict
        ``{counter: {"original": subject_data, "<low>-<high>": subject_data, ...}}``, where
        ``counter`` is the .npz file stem and ``subject_data`` is the dict returned by
        ``load_epoch_data_filter``. Empty if no folder exists for the subject.
    """
    raw = load_raw_subject(name, dir, filter_bank, debug)

    directory_path = op.join(EXTERNAL_STORAGE_PATH, "all motor region filter", "masked raw", "reconstructed eeg")
    subject_name = name.split(".")[0]

    masked_raw = {}
    if subject_name in os.listdir(directory_path):
        print(name, subject_name)
        subject_path = op.join(directory_path, subject_name)
        for subject_file in os.listdir(subject_path):
            # `with` closes the .npz archive once the array has been read into memory.
            with np.load(op.join(subject_path, subject_file), allow_pickle=True) as archive:
                masked_eeg_data = archive["data"]
            masked_raw[subject_file.split(".")[0]] = mne.io.RawArray(masked_eeg_data, raw.info)

    # Copy into a list: `[-1] + ndarray` would broadcast-add instead of concatenating,
    # and the caller's sequence must not be modified.
    edges = [] if filter_bank is None else list(filter_bank)
    print("filter bank: ", [-1] + edges)

    subject_data_dict = {}
    for counter, masked_raw_i in masked_raw.items():
        print(counter)
        subject_data_dict[counter] = {}

        # The unfiltered trials are always produced, including when filter_bank is None.
        subject_data_dict[counter]["original"] = load_epoch_data_filter(raw, masked_raw_i, name, debug)

        for low, high in zip(edges, edges[1:]):
            print("current frequency: ", low, " to ", high)
            masked_filter_raw = masked_raw_i.copy()
            # Fresh parameter dict per band, so nothing computed for one band leaks into the next.
            iir_params = dict(order=5, ftype='butter')
            masked_filter_raw.filter(low, high, method="iir", iir_params=iir_params)
            subject_data_dict[counter][str(low) + "-" + str(high)] = load_epoch_data_filter(raw, masked_filter_raw, name, debug)

    return subject_data_dict
    
def load_all_subject_filter(dataset_path, filter_bank=None):
    """Run ``load_subject_filter`` on every file found (recursively) under ``dataset_path``.

    Returns a dict keyed by bare file name. If the same name occurs in several
    sub-directories, the one visited last by ``os.walk`` wins.
    """
    return {
        file_name: load_subject_filter(name=file_name, dir=root, filter_bank=filter_bank)
        for root, _subdirs, file_names in os.walk(dataset_path)
        for file_name in file_names
    }

In [None]:
# Work from drive G: (the Google Drive mount); later cells use paths relative to it.
os.chdir("G:")

# Fetch the fsaverage template files through MNE; its parent folder is the subjects dir.
fs_dir = fetch_fsaverage(verbose=True)
subjects_dir = op.dirname(fs_dir)

# Template subject, head<->MRI transform (MNE's built-in 'fsaverage'), source space and BEM solution.
mne_subject = 'fsaverage'
trans = 'fsaverage'
src = op.join(fs_dir, 'bem', 'fsaverage-ico-5-src.fif')
bem = op.join(fs_dir, 'bem', 'fsaverage-5120-5120-5120-bem-sol.fif')

# source[0] / source[1] are treated as the left / right hemisphere; keep only in-use vertex coordinates.
source = mne.read_source_spaces(src)
left, right = source[0], source[1]
left_pos, right_pos = (hemi["rr"][hemi["inuse"] == 1] for hemi in (left, right))

transformation = mne.read_trans(op.join(fs_dir, "bem", "fsaverage-trans.fif"))

save_path = op.join(os.getcwd(), "Shared drives", "Motor Imagery", "Source Estimate")

In [None]:
# Filter-bank edges from lowcut to highcut (inclusive) in steps of `interval`: 4, 8, ..., 40.
lowcut, highcut, interval = 4, 40, 4
filter_bank = list(np.arange(lowcut, highcut + interval, step=interval))
print("filter bank: ", filter_bank)

# Relative to the working directory chosen by the fsaverage cell (drive G:).
true_labels_path = "Shared drives/Motor Imagery/BCI competition IV dataset/2a/2a true_labels/"
dataset_path = 'Shared drives/Motor Imagery/BCI competition IV dataset/2a/BCICIV_2a_gdf/'

# `true_labels` is a module-level global that the epoching code reads, so load it first.
true_labels = load_all_true_labels(true_labels_path)
data = load_all_subject(dataset_path, filter_bank=filter_bank)

In [None]:
# some information to help to understand functions and data structure
# for key, item in data.items():
#   print(key)

# ch_names = data["A01T.gdf"]["info"]["ch_names"]
# print(ch_names)

# print(data["A01T.gdf"]["class_info"])

# for key, value in data.items():
#   print(key)
#   for event_class, event_data in value["epoch_data"].items():
#       print(event_class, len(event_data))
#   print()

# subject_name = "A01T.gdf"
# Class = "left"
# filter_channels = ["C3", "Cz", "C4"]
# plot_average_graph(subject_name, Class, filter_channels)

# subject_name = "A02T.gdf"
# classes = ["left", "right"]
# filter_channels = ["C3", "Cz", "C4"]
# plot_multiple_graph(subject_name, classes, filter_channels)

In [None]:
# Build per-subject epochs with `create_epochs` (defined in an earlier cell, not shown here;
# presumably the unfiltered counterpart of create_epochs_filter — TODO confirm).
# The CNN cells below read `epochs[subject][class]`, so they expect the structure produced here.
epochs, epochs_filter_bank = create_epochs(data)
#apply_inverse_and_forward_kfold(epochs, n_splits=n_splits, save_inverse=True, save_forward=True)
print(epochs.keys())

In [None]:
# Presumably source-reconstructs the recordings under `dataset_path` band by band and, with
# save_filter=True, writes the results to disk (helper defined in an earlier cell — TODO confirm).
# NOTE(review): the subjects used elsewhere are A01-A09 (see `results` below); confirm what the
# ["A10"] argument is meant to select or skip.
apply_all_inverse_raw_bandpass_filter(dataset_path, epochs, ["A10"], filter_bank=filter_bank, save_filter=True)

In [None]:
# Filter-bank edges from lowcut to highcut (inclusive) in steps of `interval`: 4, 8, ..., 40.
lowcut, highcut, interval = 4, 40, 4
filter_bank = list(np.arange(lowcut, highcut + interval, step=interval))
print("filter bank: ", filter_bank)

# Relative to the working directory chosen by the fsaverage cell (drive G:).
true_labels_path = "Shared drives/Motor Imagery/BCI competition IV dataset/2a/2a true_labels/"
dataset_path = 'Shared drives/Motor Imagery/BCI competition IV dataset/2a/BCICIV_2a_gdf/'

# load_epoch_data_filter reads the global `true_labels`, so it must exist before loading subjects.
true_labels = load_all_true_labels(true_labels_path)
data = load_all_subject_filter(dataset_path, filter_bank=filter_bank)

In [None]:
# Merge T and E sessions with create_epochs_filter: `epochs` becomes
# {subject: {counter: {class: EpochsArray}}} and `epochs_filter_bank`
# {subject: {counter: {band: {class: EpochsArray}}}}.
# NOTE(review): this rebinds `epochs`. The CNN cells below iterate `epochs[subject]` expecting
# class names ("left", "right", ...) as keys, which this counter-keyed structure does not
# provide — confirm which cells are meant to run after this one.
epochs, epochs_filter_bank = create_epochs_filter(data)
#apply_inverse_and_forward_kfold(epochs, n_splits=n_splits, save_inverse=True, save_forward=True)
print(epochs.keys())

In [None]:
# # some information to help to understand functions and data structure
# my_epochs = epochs["A01"]["right"]
# my_evoked = my_epochs.average().pick("eeg")

# noise_cov = mne.compute_covariance(my_epochs, tmax=0., method=['shrunk', 'empirical'], rank=None, verbose=False)
# fwd = mne.make_forward_solution(my_epochs.info, trans=trans, src=src,
#                             bem=bem, eeg=True, meg=False, mindist=5.0, n_jobs=1)
# # forward matrix
# fwd_fixed = mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
#                                          use_cps=True)

# inverse_operator = make_inverse_operator(
#     my_epochs.info, fwd, noise_cov, loose=0.2, depth=0.8)

# method = "sLORETA"
# snr = 3.
# lambda2 = 1. / snr ** 2
# stc = mne.minimum_norm.apply_inverse(my_evoked, inverse_operator, lambda2,
#                               method=method, pick_ori="normal", verbose=True)
# stc.plot(hemi='both', background='white')

# reconstruct_evoked = mne.apply_forward(fwd_fixed, stc, my_evoked.info)
# my_evoked.plot_topomap()
# reconstruct_evoked.plot_topomap()

In [None]:
# # load gpt2 generated data
# fake_data_directory_path = '/standard_generate_dataset/finetune/gpt2xcnn/var_5.0_random_size_0.4.json'
# f = open(DIRECTORY_PATH + fake_data_directory_path)
# generate_data = json.load(f)
# X_fake = []
# Y_fake = []
 
# # Iterating through the json
# # list
# for key, value in generate_data.items():
#     X_fake.append(np.array(value))
#     if int(key) < 125:
#         Y_fake.append(np.zeros(1))
#     else:
#         Y_fake.append(np.ones(1))
# f.close()

# del generate_data
# gc.collect()

# X_fake = np.array(X_fake)
# Y_fake = np.array(Y_fake).reshape(-1)

# print(X_fake.shape)
# print(Y_fake.shape)

In [None]:
# # visualize fake data and compare with real data
# times = np.linspace(0, 1.75, 10)
# info = epochs["A05"]["left"].info
# print(fake_data_directory_path)
# print("fake left")
# fake_left_evoked = mne.EpochsArray(X_fake[Y_fake==0], info, verbose=False).average().pick("eeg")
# #fake_left_evoked = mne.EpochsArray(np.expand_dims(X_fake[Y_fake==0][0], axis=0), info, verbose=False).average().pick("eeg")
# fake_left_evoked.plot_topomap(times=times)
# #fake_left_evoked.plot()
# print("fake right")
# fake_right_evoked = mne.EpochsArray(X_fake[Y_fake==1], info, verbose=False).average().pick("eeg")
# #fake_right_evoked = mne.EpochsArray(np.expand_dims(X_fake[Y_fake==1][0], axis=0), info, verbose=False).average().pick("eeg")
# fake_right_evoked.plot_topomap(times=times)
# #fake_right_evoked.plot()

# real_train_X = np.load(op.join(op.join(EXTERNAL_STORAGE_PATH, "primary motor region", "data", "reconstructed eeg", "A05"), "1_train_X.npz"), allow_pickle=True)["data"]
# real_train_Y = np.load(op.join(op.join(EXTERNAL_STORAGE_PATH, "primary motor region", "data", "reconstructed eeg", "A05"), "1_train_Y.npz"), allow_pickle=True)["data"]
# print("real left")
# real_left_evoked = mne.EpochsArray(real_train_X[real_train_Y==0], info, verbose=False).average().pick("eeg")
# real_left_evoked.plot_topomap(times=times)
# #real_left_evoked.plot()
# print("real right")
# real_right_evoked = mne.EpochsArray(real_train_X[real_train_Y==1], info, verbose=False).average().pick("eeg")
# real_right_evoked.plot_topomap(times=times)
# #real_right_evoked.plot()

# CNN Classification (original data)

In [None]:
# CNN cross-validation on the original EEG epochs, one subject at a time.
# Uses names from other cells: epochs (indexed as epochs[subject][class]), n_splits,
# create_model and stft_min_max. Fold-averaged scores per subject are stored in `results`.
"""
labels
left (class 0) right (class 1) foot (class 2) tongue (class 3)

channels
c3(7) cz(9) c4(11)
"""
results = {"A01": {}, "A02": {}, "A03": {}, "A04": {}, "A05": {}, "A06": {}, "A07": {}, "A08": {}, "A09": {}}
events = ["left", "right"]  # classes to keep; a trial's label is its index in this list
#select_channels = [7, 9, 11]
select_channels = list(range(22))
classes = 2
debug = True
training = True  # False: score a saved checkpoint instead of training a new model
model_name = "eegnet_inception"
num_epochs = 200

# train model on each subject individually
data_list = []
for subject in results.keys():
  data_list.append(subject)

# train model on individual subject
# data_list = []
# data_list.append("A05")

for data_name in data_list:
  # metric accumulators: summed over folds, divided by n_splits after the fold loop
  accuracy = 0
  precision = 0
  recall = 0
  f1 = 0
  kappa = 0

  # Stack the trials of every class in `events`; the label is the class's index in `events`.
  X, Y = [], []
  for event in epochs[data_name].keys():
    for i in range(len(events)):
      if event == events[i]:
        if len(X) == 0:
          X = epochs[data_name][event].get_data()
          Y = np.zeros(len(epochs[data_name][event].get_data())) + i
        else:
          X = np.append(X, epochs[data_name][event].get_data(), axis=0)
          Y = np.append(Y, np.zeros(len(epochs[data_name][event].get_data())) + i, axis=0)

  X = np.array(X)
  Y = np.array(Y)

  # Stratified folds keep the class balance in each split; random_state=0 makes them reproducible.
  skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0)
  for train_index, test_index in skf.split(X, Y):
    X_train, X_test = X[train_index], X[test_index]
    Y_train, Y_test = Y[train_index], Y[test_index]

    # add fake data
    # (1) real + fake
#     X_train = np.append(X_train, X_fake, axis=0)
#     Y_train = np.append(Y_train, Y_fake, axis=0)
    # (2) fake
#     X_train = X_fake
#     Y_train = Y_fake

    # keep the channels listed in select_channels (all 22 here; [7, 9, 11] would be C3, Cz, C4)
    X_train = X_train[:, select_channels, :]
    X_test = X_test[:, select_channels, :]

#     # Normalization
#     X_train_temp = np.zeros(X_train.shape)
#     for i in range(X_train.shape[1]):
#         temp = X_train[:, i, :]
#         X_train_temp[:, i, :] = (temp-np.mean(temp))/np.std(temp)
#     X_test_temp = np.zeros(X_test.shape)
#     for i in range(X_test.shape[1]):
#         temp = X_test[:, i, :]
#         X_test_temp[:, i, :] = (temp-np.mean(temp))/np.std(temp)
#     X_train = np.array(X_train_temp)
#     X_test = np.array(X_test_temp)

    print(data_name)
    # Model-specific input format: "default" goes through stft_min_max (defined elsewhere);
    # the EEGNet variants get a trailing singleton axis.
    if model_name == "default":
        X_train, Y_train = stft_min_max(X_train, Y_train, debug)
        X_test, Y_test = stft_min_max(X_test, Y_test, debug)
    elif model_name == "eegnet":
        X_train = np.expand_dims(X_train, axis=-1)
        X_test = np.expand_dims(X_test, axis=-1)
    elif model_name == "eegnet_inception":
        X_train = np.expand_dims(X_train, axis=-1)
        X_test = np.expand_dims(X_test, axis=-1)

    if debug:
      print("shape of X_train and Y_train: " + str(X_train.shape) + " " + str(Y_train.shape))
      print("shape of X_test and Y_test: " + str(X_test.shape) + " " + str(Y_test.shape))

    if training:
      # create new model
      model = create_model(num_class=classes, model_name=model_name, num_channel=len(select_channels))

      log_dir = DIRECTORY_PATH + "/motor/logs/" + data_name + "/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
      tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
      optimizer = Adam(learning_rate=1e-5)
      if classes == 2:
          model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])
      else:
          model.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy", metrics=["accuracy"])
      # The test fold is passed as validation_data for monitoring only: the sole callback is
      # TensorBoard, so it does not drive early stopping or checkpoint selection.
      model.fit(X_train, Y_train, validation_data=(X_test, Y_test), batch_size=32, epochs=num_epochs, callbacks=[tensorboard_callback], verbose=0)

      Y_hat = model.predict(X_test)
      # 2 classes: threshold at 0.5 (assumes create_model ends in one sigmoid unit — TODO confirm)
      if classes == 2:
          Y_hat = (Y_hat >= 0.5)
      else:
          Y_hat = np.argmax(Y_hat, axis=1)
      accuracy += accuracy_score(Y_test, Y_hat)
      precision += precision_score(Y_test, Y_hat, average="macro")
      recall += recall_score(Y_test, Y_hat, average="macro")
      f1 += f1_score(Y_test, Y_hat, average="macro")
      kappa += cohen_kappa_score(Y_test, Y_hat)

      # save model
      # NOTE(review): the folder name is subject + truncated fold accuracy, so two folds with the
      # same accuracy overwrite each other's weights.
      model.save_weights(DIRECTORY_PATH + "/motor/models/" + data_name + "_" + str(accuracy_score(Y_test, Y_hat))[:6] + "/")
    else:
      # load pretrained model
      # NOTE(review): every subject and fold is scored with the same hard-coded A05 checkpoint — confirm intended.
      model = create_model(num_class=classes, model_name=model_name, num_channel=len(select_channels))
      model.load_weights(DIRECTORY_PATH + "/motor/models/" + "A05_0.7547/")
      # freeze model
      model.trainable = False
      optimizer = Adam(learning_rate=1e-5)
      if classes == 2:
          model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])
      else:
          model.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy", metrics=["accuracy"])

      Y_hat = model.predict(X_test)
      if classes == 2:
          Y_hat = (Y_hat >= 0.5)
      else:
          Y_hat = np.argmax(Y_hat, axis=1)
      accuracy += accuracy_score(Y_test, Y_hat)
      precision += precision_score(Y_test, Y_hat, average="macro")
      recall += recall_score(Y_test, Y_hat, average="macro")
      f1 += f1_score(Y_test, Y_hat, average="macro")
      kappa += cohen_kappa_score(Y_test, Y_hat)

  # average over the folds
  accuracy /= n_splits
  precision /= n_splits
  recall /= n_splits
  f1 /= n_splits
  kappa /= n_splits

  if debug:
    print("accuracy: " + str(accuracy))
    print("precision: " + str(precision))
    print("recall: " + str(recall))
    print("f1: " + str(f1))
    print("kappa: " + str(kappa))

  results[data_name]["accuracy"] = accuracy
  results[data_name]["precision"] = precision
  results[data_name]["recall"] = recall
  results[data_name]["f1"] = f1
  results[data_name]["kappa"] = kappa

In [None]:
# Calculate average performance: collect every subject's fold-averaged score per metric,
# then report the mean and (population, ddof=0) standard deviation across subjects.
result_accuracy = [value["accuracy"] for value in results.values()]
result_precision = [value["precision"] for value in results.values()]
result_recall = [value["recall"] for value in results.values()]
result_f1 = [value["f1"] for value in results.values()]
result_kappa = [value["kappa"] for value in results.values()]

for metric_label, scores in (("accuracy", result_accuracy), ("precision", result_precision),
                             ("recall", result_recall), ("f1", result_f1), ("kappa", result_kappa)):
    print("%s: (mean) %s (std) %s" % (metric_label, np.mean(scores), np.std(scores)))

In [None]:
# time computation: latency of a single-trial prediction (preprocessing + model.predict)
# with a pretrained model on the original EEG. Uses epochs, create_model and stft_min_max
# from other cells.
events = ["left", "right"]
#select_channels = [7, 9, 11]
select_channels = list(range(22))
classes = 2
debug = False
model_name = "default"
warm_up = 10 # initializing memory allocators, and GPU-related initializations 

data_list = ["A01"]

for data_name in data_list:
    # Stack the trials of the classes in `events`; a trial's label is its class's index in `events`.
    X, Y = [], []
    for event in epochs[data_name].keys():
        if event in events:
            event_trials = epochs[data_name][event].get_data()
            X.append(event_trials)
            Y.append(np.zeros(len(event_trials)) + events.index(event))
    X = np.concatenate(X, axis=0)
    Y = np.concatenate(Y, axis=0)
    X = X[:, select_channels, :]

#   # Normalization
#   X_temp = np.zeros(X.shape)
#   for i in range(X.shape[1]):
#     temp = X[:, i, :]
#     X_temp[:, i, :] = (temp-np.mean(temp))/np.std(temp)
#   X = np.array(X_temp)

    # A batch holding only the first trial.
    X_test = np.expand_dims(X[0], axis=0)
    Y_test = np.expand_dims(Y[0], axis=0)

    # load pretrained model
    model = create_model(num_class=classes, model_name=model_name, num_channel=len(select_channels))
    model.load_weights("D:/forward and inverse results (new)/motor/default 22 channels/original EEG/models/A01_0.8571/")
    # freeze model
    model.trainable = False
    optimizer = Adam(learning_rate=1e-5)
    if classes == 2:
        model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])
    else:
        model.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    if model_name == "default":
        model.build(input_shape=(None, 40, 16*len(select_channels), 1))
    elif model_name in ("eegnet", "eegnet_inception"):
        model.build(input_shape=(None, len(select_channels), 500, 1))
    # summary() prints by itself; wrapping it in print() only added a stray "None".
    model.summary()

    for i in range(warm_up):
        # Only the final pass is timed; the earlier passes absorb one-off start-up costs.
        timed = i == warm_up - 1
        if timed:
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start = time.perf_counter()
        if model_name == "default":
            X_time, Y_time = stft_min_max(X_test, Y_test, debug)
        elif model_name in ("eegnet", "eegnet_inception"):
            X_time = np.expand_dims(X_test, axis=-1)
        # verbose=0: a progress bar printed inside the timed region would inflate the measurement.
        Y_hat = model.predict(X_time, verbose=0)
        if timed:
            end = time.perf_counter()
        print(X_time.shape)
        if timed:
            print("time used: ", (end - start)*1000, "ms")

# CNN Classification (reconstructed data)

In [None]:
# CNN evaluation on the reconstructed EEG, one subject at a time.
# The fold splits are precomputed: for k = 1..n_splits this reads <k>_train_X.npz, <k>_test_X.npz,
# <k>_train_Y.npz and <k>_test_Y.npz from
# EXTERNAL_STORAGE_PATH/all motor region/data/reconstructed eeg/<subject>.
# Uses n_splits, create_model and stft_min_max from other cells; scores are stored in `results`.
"""
labels
left (class 0) right (class 1) foot (class 2) tongue (class 3)

channels
c3(7) cz(9) c4(11)
"""

results = {"A01": {}, "A02": {}, "A03": {}, "A04": {}, "A05": {}, "A06": {}, "A07": {}, "A08": {}, "A09": {}}
#select_channels = [7, 9, 11]
select_channels = list(range(22))
classes = 2
debug = True
training = True  # False: score a saved checkpoint instead of training a new model
model_name = "eegnet"
num_epochs = 200

# train model on each subject individually
data_list = []
for subject in results.keys():
  data_list.append(subject)

# train model on individual subject
# data_list = []
# data_list.append("A02")

for data_name in data_list:
  # load data from external storage
  directory_path = op.join(EXTERNAL_STORAGE_PATH, "all motor region", "data", "reconstructed eeg", data_name)
  counter = 0
  # metric accumulators: summed over folds, divided by n_splits afterwards
  accuracy = 0
  precision = 0
  recall = 0
  f1 = 0
  kappa = 0
  while(counter < n_splits):
    counter += 1  # fold files are numbered from 1
    # NOTE(review): these np.load archives are never closed; consider `with np.load(...) as f:`.
    X_train = np.load(op.join(directory_path, str(counter)+"_train_X.npz"), allow_pickle=True)["data"]
    X_test = np.load(op.join(directory_path, str(counter)+"_test_X.npz"), allow_pickle=True)["data"]
    Y_train = np.load(op.join(directory_path, str(counter)+"_train_Y.npz"), allow_pickle=True)["data"]
    Y_test = np.load(op.join(directory_path, str(counter)+"_test_Y.npz"), allow_pickle=True)["data"]

    #X_train, _, Y_train, _ = train_test_split(X_fake, Y_fake, test_size=0.8, random_state=456, stratify=Y_fake)

    # keep the channels listed in select_channels (all 22 here; [7, 9, 11] would be C3, Cz, C4)
    X_train = X_train[:, select_channels, :]
    X_test = X_test[:, select_channels, :]

#     # Normalization
#     X_train_temp = np.zeros(X_train.shape)
#     for i in range(X_train.shape[1]):
#         temp = X_train[:, i, :]
#         X_train_temp[:, i, :] = (temp-np.mean(temp))/np.std(temp)
#     X_test_temp = np.zeros(X_test.shape)
#     for i in range(X_test.shape[1]):
#         temp = X_test[:, i, :]
#         X_test_temp[:, i, :] = (temp-np.mean(temp))/np.std(temp)
#     X_train = np.array(X_train_temp)
#     X_test = np.array(X_test_temp)

    print(data_name)
    # Model-specific input format: "default" goes through stft_min_max (defined elsewhere);
    # the EEGNet variants get a trailing singleton axis.
    if model_name == "default":
        X_train, Y_train = stft_min_max(X_train, Y_train, debug)
        X_test, Y_test = stft_min_max(X_test, Y_test, debug)
    elif model_name == "eegnet":
        X_train = np.expand_dims(X_train, axis=-1)
        X_test = np.expand_dims(X_test, axis=-1)
    elif model_name == "eegnet_inception":
        X_train = np.expand_dims(X_train, axis=-1)
        X_test = np.expand_dims(X_test, axis=-1)

    if debug:
      print("shape of X_train and Y_train: " + str(X_train.shape) + " " + str(Y_train.shape))
      print("shape of X_test and Y_test: " + str(X_test.shape) + " " + str(Y_test.shape))

    if training:
      # create new model
      model = create_model(num_class=classes, model_name=model_name, num_channel=len(select_channels))

      log_dir = DIRECTORY_PATH + "/motor/logs/" + data_name + "/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
      tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
      optimizer = Adam(learning_rate=1e-5)
      if classes == 2:
          model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])
      else:
          model.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy", metrics=["accuracy"])
      # The test fold is validation_data for monitoring only (TensorBoard is the only callback).
      model.fit(X_train, Y_train, validation_data=(X_test, Y_test), batch_size=32, epochs=num_epochs, callbacks=[tensorboard_callback], verbose=0)

      Y_hat = model.predict(X_test)
      # 2 classes: threshold at 0.5 (assumes create_model ends in one sigmoid unit — TODO confirm)
      if classes == 2:
          Y_hat = (Y_hat >= 0.5)
      else:
          Y_hat = np.argmax(Y_hat, axis=1)
      accuracy += accuracy_score(Y_test, Y_hat)
      precision += precision_score(Y_test, Y_hat, average="macro")
      recall += recall_score(Y_test, Y_hat, average="macro")
      f1 += f1_score(Y_test, Y_hat, average="macro")
      kappa += cohen_kappa_score(Y_test, Y_hat)

      # save model
      # NOTE(review): the folder name is subject + truncated fold accuracy, so folds with equal
      # accuracy overwrite each other's weights.
      model.save_weights(DIRECTORY_PATH + "/motor/models/" + data_name + "_" + str(accuracy_score(Y_test, Y_hat))[:6] + "/")
    else:
      # load pretrained model
      # NOTE(review): every subject and fold is scored with the same hard-coded A09 checkpoint — confirm intended.
      model = create_model(num_class=classes, model_name=model_name, num_channel=len(select_channels))
      model.load_weights(DIRECTORY_PATH + "/motor/models/" + "A09_0.9183/")
      # freeze model
      model.trainable = False
      optimizer = Adam(learning_rate=1e-5)
      if classes == 2:
          model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])
      else:
          model.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy", metrics=["accuracy"])

      Y_hat = model.predict(X_test)
      if classes == 2:
          Y_hat = (Y_hat >= 0.5)
      else:
          Y_hat = np.argmax(Y_hat, axis=1)
      accuracy += accuracy_score(Y_test, Y_hat)
      precision += precision_score(Y_test, Y_hat, average="macro")
      recall += recall_score(Y_test, Y_hat, average="macro")
      f1 += f1_score(Y_test, Y_hat, average="macro")
      kappa += cohen_kappa_score(Y_test, Y_hat)

  # average over the folds
  accuracy /= n_splits
  precision /= n_splits
  recall /= n_splits
  f1 /= n_splits
  kappa /= n_splits
  if debug:
    print("accuracy: " + str(accuracy))
    print("precision: " + str(precision))
    print("recall: " + str(recall))
    print("f1: " + str(f1))
    print("kappa: " + str(kappa))

  results[data_name]["accuracy"] = accuracy
  results[data_name]["precision"] = precision
  results[data_name]["recall"] = recall
  results[data_name]["f1"] = f1
  results[data_name]["kappa"] = kappa

In [None]:
# Calculate average performance: collect every subject's fold-averaged score per metric,
# then report the mean and (population, ddof=0) standard deviation across subjects.
result_accuracy = [value["accuracy"] for value in results.values()]
result_precision = [value["precision"] for value in results.values()]
result_recall = [value["recall"] for value in results.values()]
result_f1 = [value["f1"] for value in results.values()]
result_kappa = [value["kappa"] for value in results.values()]

for metric_label, scores in (("accuracy", result_accuracy), ("precision", result_precision),
                             ("recall", result_recall), ("f1", result_f1), ("kappa", result_kappa)):
    print("%s: (mean) %s (std) %s" % (metric_label, np.mean(scores), np.std(scores)))

In [None]:
# time computation
# Latency of one trial through: sLORETA inverse -> keep only motor-region
# sources -> forward-project back to sensor space -> CNN prediction.
# The loop runs `warm_up` passes; only the last pass is timed.
events = ["left", "right"]
#select_channels = [7, 9, 11]
select_channels = list(range(22))
classes = 2
debug = False
model_name = "eegnet"
warm_up = 10 # initializing memory allocators, and GPU-related initializations

data_list = []
data_list.append("A01")

information = get_inverse_and_forward_information(epochs)
print(information)
leadfield = information["leadfield"]
inverse_operator = information["inverse_operator"]
my_left_points = information["my_left_points"]
my_right_points = information["my_right_points"]
info = epochs["A01"]["left"].info

# Inverse-solution settings are loop-invariant, so they live outside the timed region.
method = "sLORETA"
snr = 3.
lambda2 = 1. / snr ** 2

for data_name in data_list:

  # Stack the trials of every event in `events`; the label is the event's index.
  X, Y = [], []
  for event in epochs[data_name].keys():
    for i in range(len(events)):
      if event == events[i]:
        event_data = epochs[data_name][event].get_data()
        event_labels = np.zeros(len(event_data)) + i
        if len(X) == 0:
          X = event_data
          Y = event_labels
        else:
          X = np.append(X, event_data, axis=0)
          Y = np.append(Y, event_labels, axis=0)

  X = np.array(X)
  Y = np.array(Y)

  print("shape of X and Y: " + str(X.shape) + " " + str(Y.shape))

  # A single trial, keeping a leading batch axis of size 1.
  X_test = np.expand_dims(X[0], axis=0)
  Y_test = np.expand_dims(Y[0], axis=0)

  # load pretrained model
  model = create_model(num_class=classes, model_name=model_name, num_channel=len(select_channels))
  model.load_weights("D:/forward and inverse results (new)/motor/eegnet 22 channels (channel-wise normalization)/all motor/models/A01_0.8214/")
  # freeze model
  model.trainable = False
  optimizer = Adam(learning_rate=1e-5)
  if classes == 2:
      model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])
  else:
      model.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy", metrics=["accuracy"])
  if model_name == "default":
      model.build(input_shape=(None, 40, 16*len(select_channels), 1))
  elif model_name in ("eegnet", "eegnet_inception"):
      model.build(input_shape=(None, len(select_channels), 500, 1))
  else:
      # Fail fast: otherwise X_time below is never assigned and a stale
      # notebook global could silently be fed to the model.
      raise ValueError("unknown model_name: " + model_name)

  for i in range(warm_up):
    timed = (i == warm_up - 1)
    if timed:
        # perf_counter is a monotonic, high-resolution interval clock;
        # time.time() can tick in ~15.6 ms steps on Windows.
        start = time.perf_counter()

    X_epochs = mne.EpochsArray(X_test, info, verbose=False)
    stc_test = apply_inverse_epochs(X_epochs, inverse_operator, lambda2,
                                    method=method, pick_ori="normal", verbose=debug)
    # Zero every source outside the selected motor regions, then
    # forward-project the remaining activity back to the sensors.
    reconstructed_eeg_data = []
    for source in stc_test:
        n_left = len(source.vertices[0])
        n_right = len(source.vertices[1])
        motor_source = np.zeros_like(source.data)
        motor_source[:n_left][my_left_points] = source.data[:n_left][my_left_points]
        motor_source[-n_right:][my_right_points] = source.data[-n_right:][my_right_points]
        reconstructed_eeg_data.append(np.dot(leadfield, motor_source))
    reconstructed_eeg_data = np.array(reconstructed_eeg_data)
    reconstructed_eeg_data = reconstructed_eeg_data[:, select_channels, :]

#     # Normalization
#     reconstructed_eeg_data_temp = np.zeros(reconstructed_eeg_data.shape)
#     for j in range(reconstructed_eeg_data.shape[1]):
#       temp = reconstructed_eeg_data[:, j, :]
#       reconstructed_eeg_data_temp[:, j, :] = (temp-np.mean(temp))/np.std(temp)
#     reconstructed_eeg_data = np.array(reconstructed_eeg_data_temp)

    if model_name == "default":
      X_time, Y_time = stft_min_max(reconstructed_eeg_data, Y_test, debug)
    else:  # "eegnet" / "eegnet_inception" (validated above)
      X_time = np.expand_dims(reconstructed_eeg_data, axis=-1)

    # verbose=0 keeps the Keras progress bar out of the timed region.
    Y_hat = model.predict(X_time, verbose=0)

    if timed:
        end = time.perf_counter()
    # Printed after the clock stops so console I/O does not inflate the latency.
    print(reconstructed_eeg_data.shape)
    if timed:
        print("time used: ", (end - start)*1000, "ms")

# CNN Classification (Auto-select Source Activity)

In [None]:
information = get_inverse_and_forward_information(epochs)
print(information)

# Leadfield restricted to the selected motor-region sources.
# - Left-hemisphere columns are the first len(left_vertices) columns.
# - Right-hemisphere columns are the last len(right_vertices) columns.
# - Each hemisphere block is then filtered by its own point selector.
_lf = information["leadfield"]
_left_cols = _lf[:, :len(information["left_vertices"])][:, information["my_left_points"]]
_right_cols = _lf[:, -len(information["right_vertices"]):][:, information["my_right_points"]]
print(_left_cols.shape)
print(_right_cols.shape)
forward_matrix = np.concatenate((_left_cols, _right_cols), axis=1)
print(forward_matrix.shape)
del _lf, _left_cols, _right_cols

In [None]:
"""
labels
left (class 0) right (class 1) foot (class 2) tongue (class 3)

channels
c3(7) cz(9) c4(11)
"""

results = {"A01": {}, "A02": {}, "A03": {}, "A04": {}, "A05": {}, "A06": {}, "A07": {}, "A08": {}, "A09": {}}
#select_channels = [7, 9, 11]
select_channels = list(range(22))
debug = True
training = True
random_select = False
use_mask = True
kqv = False
model_name = "eegnet"
num_epochs = 200

# train model on each subject individually
data_list = []
for subject in results.keys():
  data_list.append(subject)

# train model on individual subject
# data_list = []
# data_list.append("A02")
# data_list.append("A05")


def _binary_fold_metrics(y_true, y_pred):
  """Return (accuracy, precision, recall, f1, kappa) for one fold.

  Precision, recall and F1 are macro-averaged over the classes.
  """
  return (accuracy_score(y_true, y_pred),
          precision_score(y_true, y_pred, average="macro"),
          recall_score(y_true, y_pred, average="macro"),
          f1_score(y_true, y_pred, average="macro"),
          cohen_kappa_score(y_true, y_pred))


for data_name in data_list:
  # load data from external storage (one set of .npz files per fold)
  directory_path = op.join(EXTERNAL_STORAGE_PATH, "all motor region", "data", "source activity", data_name)
  counter = 0
  accuracy = 0
  precision = 0
  recall = 0
  f1 = 0
  kappa = 0

  while(counter < n_splits):
    counter += 1
    X_train = np.load(op.join(directory_path, str(counter)+"_train_X.npz"), allow_pickle=True)["data"].astype(np.float32)
    X_test = np.load(op.join(directory_path, str(counter)+"_test_X.npz"), allow_pickle=True)["data"].astype(np.float32)
    Y_train = np.load(op.join(directory_path, str(counter)+"_train_Y.npz"), allow_pickle=True)["data"].astype(np.float32)
    Y_test = np.load(op.join(directory_path, str(counter)+"_test_Y.npz"), allow_pickle=True)["data"].astype(np.float32)

    if debug:
      print(data_name)
      print("shape of X_train and Y_train: " + str(X_train.shape) + " " + str(Y_train.shape))
      print("shape of X_test and Y_test: " + str(X_test.shape) + " " + str(Y_test.shape))

    if training:
      # create new model
      model = AutoSelect(select_channels, forward_matrix, random_select, use_mask, kqv, model_name)
      #model.build(input_shape=(None, forward_matrix.shape[1], 500))
      #print(model.summary())

      log_dir = DIRECTORY_PATH + "/motor/logs/" + data_name + "/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
      tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
      optimizer = Adam(learning_rate=1e-5)
      model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])
      model.fit(X_train, Y_train, validation_data=(X_test, Y_test), batch_size=32, epochs=num_epochs, callbacks=[tensorboard_callback], verbose=0)

      Y_hat = model.predict(X_test, batch_size=8)
    else:
      # load pretrained model
      # NOTE(review): the same "A09_0.9183" checkpoint is loaded for every
      # subject and fold -- confirm this is intended.
      model = AutoSelect(select_channels, forward_matrix, random_select, use_mask, kqv, model_name)
      model.load_weights(DIRECTORY_PATH + "/motor/models/" + "A09_0.9183/")
      # freeze model
      model.trainable = False
      optimizer = Adam(learning_rate=1e-5)
      model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])

      Y_hat = model.predict(X_test)

    # The model is trained with binary_crossentropy, so outputs are
    # thresholded at 0.5 to obtain class predictions.
    Y_hat = (Y_hat >= 0.5)
    fold_accuracy, fold_precision, fold_recall, fold_f1, fold_kappa = _binary_fold_metrics(Y_test, Y_hat)
    accuracy += fold_accuracy
    precision += fold_precision
    recall += fold_recall
    f1 += fold_f1
    kappa += fold_kappa

    if training:
      # save model; the directory name carries this fold's test accuracy
      model.save_weights(DIRECTORY_PATH + "/motor/models/" + data_name + "_" + str(fold_accuracy)[:6] + "/")

    del X_train, Y_train, X_test, Y_test
    gc.collect()

  accuracy /= n_splits
  precision /= n_splits
  recall /= n_splits
  f1 /= n_splits
  kappa /= n_splits
  if debug:
    print("accuracy: " + str(accuracy))
    print("precision: " + str(precision))
    print("recall: " + str(recall))
    print("f1: " + str(f1))
    print("kappa: " + str(kappa))

  results[data_name]["accuracy"] = accuracy
  results[data_name]["precision"] = precision
  results[data_name]["recall"] = recall
  results[data_name]["f1"] = f1
  results[data_name]["kappa"] = kappa

In [None]:
# Calculate average performance
# Mean and (population) std of each metric across subjects.
# Subjects whose entry in `results` is still empty -- e.g. when `data_list`
# held only a subset of subjects -- are skipped instead of raising KeyError.
evaluated_results = [value for value in results.values() if value]
result_accuracy = [value["accuracy"] for value in evaluated_results]
result_precision = [value["precision"] for value in evaluated_results]
result_recall = [value["recall"] for value in evaluated_results]
result_f1 = [value["f1"] for value in evaluated_results]
result_kappa = [value["kappa"] for value in evaluated_results]

print("accuracy: (mean) " + str(np.mean(result_accuracy)) + " (std) " + str(np.std(result_accuracy)))
print("precision: (mean) " + str(np.mean(result_precision)) + " (std) " + str(np.std(result_precision)))
print("recall: (mean) " + str(np.mean(result_recall)) + " (std) " + str(np.std(result_recall)))
print("f1: (mean) " + str(np.mean(result_f1)) + " (std) " + str(np.std(result_f1)))
print("kappa: (mean) " + str(np.mean(result_kappa)) + " (std) " + str(np.std(result_kappa)))

In [None]:
# time computation
# Latency of one trial through: sLORETA inverse -> slice the selected
# motor-region sources -> AutoSelect model prediction.
# The loop runs `warm_up` passes; only the last pass is timed.
events = ["left", "right"]
#select_channels = [7, 9, 11]
select_channels = list(range(22))
debug = False
random_select = False
use_mask = False
kqv = False
model_name = "eegnet"
warm_up = 10 # initializing memory allocators, and GPU-related initializations

data_list = []
data_list.append("A01")

information = get_inverse_and_forward_information(epochs)
print(information)
print(information["leadfield"][:, :len(information["left_vertices"])][:, information["my_left_points"]].shape)
print(information["leadfield"][:, -len(information["right_vertices"]):][:, information["my_right_points"]].shape)
forward_matrix = np.concatenate((information["leadfield"][:, :len(information["left_vertices"])][:, information["my_left_points"]],
                          information["leadfield"][:, -len(information["right_vertices"]):][:, information["my_right_points"]]),
                          axis = 1)
print(forward_matrix.shape)
inverse_operator = information["inverse_operator"]
my_left_points = information["my_left_points"]
my_right_points = information["my_right_points"]
info = epochs["A01"]["left"].info

# Inverse-solution settings are loop-invariant, so they live outside the timed region.
method = "sLORETA"
snr = 3.
lambda2 = 1. / snr ** 2

for data_name in data_list:

  # Stack the trials of every event in `events`; the label is the event's index.
  X, Y = [], []
  for event in epochs[data_name].keys():
    for i in range(len(events)):
      if event == events[i]:
        event_data = epochs[data_name][event].get_data()
        event_labels = np.zeros(len(event_data)) + i
        if len(X) == 0:
          X = event_data
          Y = event_labels
        else:
          X = np.append(X, event_data, axis=0)
          Y = np.append(Y, event_labels, axis=0)

  X = np.array(X)
  Y = np.array(Y)
  # A single trial, keeping a leading batch axis of size 1.
  X_test = np.expand_dims(X[0], axis=0)
  Y_test = np.expand_dims(Y[0], axis=0)

  # load pretrained model
  model = AutoSelect(select_channels, forward_matrix, random_select, use_mask, kqv, model_name)
  model.load_weights("D:/forward and inverse results (new)/motor/eegnet 22 channels (channel-wise normalization)/all motor attention/models/A01_0.8181/")
  # freeze model
  model.trainable = False
  optimizer = Adam(learning_rate=1e-5)
  model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])
  model.build(input_shape=(None, forward_matrix.shape[1], 500))

  for i in range(warm_up):
    timed = (i == warm_up - 1)
    if timed:
        # perf_counter is a monotonic, high-resolution interval clock;
        # time.time() can tick in ~15.6 ms steps on Windows.
        start = time.perf_counter()

    X_epochs = mne.EpochsArray(X_test, info, verbose=False)
    stc_test = apply_inverse_epochs(X_epochs, inverse_operator, lambda2,
                                    method=method, pick_ori="normal", verbose=debug)
    # Keep only the selected motor-region sources of each hemisphere and
    # stack them (left first, then right) along the source axis.
    left_hemi_data = []
    right_hemi_data = []
    for source in stc_test:
        left_hemi_data.append(source.data[:len(source.vertices[0])][my_left_points])
        right_hemi_data.append(source.data[-len(source.vertices[1]):][my_right_points])
    left_hemi_data = np.array(left_hemi_data)
    right_hemi_data = np.array(right_hemi_data)

    X_time = np.append(left_hemi_data, right_hemi_data, axis=1)
    # verbose=0 keeps the Keras progress bar out of the timed region.
    Y_hat = model.predict(X_time, verbose=0)

    if timed:
        end = time.perf_counter()
    # Printed after the clock stops so console I/O does not inflate the latency.
    print(X_time.shape)
    if timed:
        print("time used: ", (end - start)*1000, "ms")

# Visualizing Enhanced EEG Signals

## Save Enhanced Signals

In [None]:
# data_list = ["A01", "A02", "A03", "A04", "A05", "A06", "A07", "A08", "A09"]
# select_channels = list(range(22))
# #select_channels = [7, 9, 11]
# debug = True
# model_name = "eegnet"
# random_select = False
# use_mask = True
# kqv = False

# for data_name in data_list:
#   # load data from external storage
#   directory_path = op.join(EXTERNAL_STORAGE_PATH, "all motor region", "data", "source activity", data_name)
#   counter = 0
#   accuracy = 0
#   precision = 0
#   recall = 0
#   f1 = 0
#   kappa = 0
#   print(data_name)
    
#   while(counter < n_splits):
#     counter += 1
#     X_train = np.load(op.join(directory_path, str(counter)+"_train_X.npz"), allow_pickle=True)["data"].astype(np.float32)
#     X_test = np.load(op.join(directory_path, str(counter)+"_test_X.npz"), allow_pickle=True)["data"].astype(np.float32)
#     Y_train = np.load(op.join(directory_path, str(counter)+"_train_Y.npz"), allow_pickle=True)["data"].astype(np.float32)
#     Y_test = np.load(op.join(directory_path, str(counter)+"_test_Y.npz"), allow_pickle=True)["data"].astype(np.float32)
    
#     print("before...")
#     print("X_train shape: {}, X_test shape: {}".format(X_train.shape, X_test.shape))    
    
#     model = AutoSelectData(select_channels, forward_matrix, random_select, use_mask, kqv, model_name)
#     #weight_path = "D:/forward and inverse results (new)/motor/ml motor/"+ model_name + "/models/" + data_name + "/"
#     weight_path = "D:/forward and inverse results (new)/motor/ml motor/eegnet 22 channels/models/" + data_name + "/"
    
#     for weight_file in os.listdir(weight_path):
#         if weight_file.split("_")[0] == str(counter):
#             break
#     load_weights_file = os.path.join(weight_path, weight_file) + "/"
#     model.load_weights(load_weights_file)
#     #model.build((None, X_train.shape[1], X_train.shape[2]))
#     #print(model.summary())
#     model.trainable = False
#     X_train = model(X_train, training=False).numpy()
#     X_test = model(X_test, training=False).numpy()
    
#     print("after...")
#     print("X_train shape: {}, X_test shape: {}".format(X_train.shape, X_test.shape))
    
#     enhanced_eeg_path = op.join(EXTERNAL_STORAGE_PATH, "all motor region", "data", "enhanced eeg", data_name)
#     if not op.exists(enhanced_eeg_path):
#         os.makedirs(enhanced_eeg_path)
#     np.savez_compressed(op.join(enhanced_eeg_path, str(counter)+"_train_X.npz"), data=X_train)
#     np.savez_compressed(op.join(enhanced_eeg_path, str(counter)+"_train_Y.npz"), data=Y_train)
#     np.savez_compressed(op.join(enhanced_eeg_path, str(counter)+"_test_X.npz"), data=X_test)
#     np.savez_compressed(op.join(enhanced_eeg_path, str(counter)+"_test_Y.npz"), data=Y_test)

In [None]:
# Build mne.EpochsArray containers of original vs. enhanced EEG, split by
# class (0 = left, 1 = right) and by train/test, for every subject.
data_list = ["A01", "A02", "A03", "A04", "A05", "A06", "A07", "A08", "A09"]
events = ["left", "right"]
epochs_visualize = {}
# A01's channel info is reused for every subject (assumes all subjects share
# the same channel layout -- as the original cell also did).
info = epochs["A01"]["left"].info

for data_name in data_list:
  epochs_visualize[data_name] = {}

  # Stack the trials of every event in `events`; the label is the event's index.
  X, Y = [], []
  for event in epochs[data_name].keys():
    for i in range(len(events)):
      if event == events[i]:
        event_data = epochs[data_name][event].get_data()
        event_labels = np.zeros(len(event_data)) + i
        if len(X) == 0:
          X = event_data
          Y = event_labels
        else:
          X = np.append(X, event_data, axis=0)
          Y = np.append(Y, event_labels, axis=0)

  X = np.array(X)
  Y = np.array(Y)

  print(data_name)

  # Re-create a deterministic stratified split; it presumably matches the
  # split used to save the enhanced files, and the label check below verifies it.
  skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0)
  directory_path = op.join(EXTERNAL_STORAGE_PATH, "all motor region", "data", "enhanced eeg", data_name)
  counter = 0
  for train_index, test_index in skf.split(X, Y):
    counter += 1
    X_train, X_test = X[train_index], X[test_index]
    Y_train, Y_test = Y[train_index], Y[test_index]

    X_train_enhanced = np.load(op.join(directory_path, str(counter)+"_train_X.npz"), allow_pickle=True)["data"].astype(np.float32)
    X_test_enhanced = np.load(op.join(directory_path, str(counter)+"_test_X.npz"), allow_pickle=True)["data"].astype(np.float32)
    Y_train_enhanced = np.load(op.join(directory_path, str(counter)+"_train_Y.npz"), allow_pickle=True)["data"].astype(np.float32)
    Y_test_enhanced = np.load(op.join(directory_path, str(counter)+"_test_Y.npz"), allow_pickle=True)["data"].astype(np.float32)

    # Labels must match trial-for-trial. np.array_equal also rejects shape
    # mismatches, and an explicit raise is not stripped under `python -O`.
    if not np.array_equal(Y_train, Y_train_enhanced):
      raise ValueError(data_name + " fold " + str(counter) + ": train labels do not match the saved enhanced labels")
    if not np.array_equal(Y_test, Y_test_enhanced):
      raise ValueError(data_name + " fold " + str(counter) + ": test labels do not match the saved enhanced labels")

    # visualize eeg signals in time series
    train_is_0, train_is_1 = (Y_train == 0), (Y_train == 1)
    test_is_0, test_is_1 = (Y_test == 0), (Y_test == 1)
    X_train_0 = X_train[train_is_0]
    X_train_1 = X_train[train_is_1]
    X_test_0 = X_test[test_is_0]
    X_test_1 = X_test[test_is_1]

    X_train_enhanced_0 = X_train_enhanced[train_is_0]
    X_train_enhanced_1 = X_train_enhanced[train_is_1]
    X_test_enhanced_0 = X_test_enhanced[test_is_0]
    X_test_enhanced_1 = X_test_enhanced[test_is_1]

    print(counter)
    print(len(X_train_0), len(X_train_1), len(X_test_0), len(X_test_1))
    print(len(X_train_enhanced_0), len(X_train_enhanced_1), len(X_test_enhanced_0), len(X_test_enhanced_1))

    # NOTE: these entries are overwritten on every fold, so only the LAST
    # fold is kept; earlier folds are still loaded so their labels are checked.
    epochs_visualize[data_name]["original"] = {}
    epochs_visualize[data_name]["enhanced"] = {}

    epochs_visualize[data_name]["original"]["X_train_0"] = mne.EpochsArray(X_train_0, info, verbose=False)
    epochs_visualize[data_name]["original"]["X_train_1"] = mne.EpochsArray(X_train_1, info, verbose=False)
    epochs_visualize[data_name]["original"]["X_test_0"] = mne.EpochsArray(X_test_0, info, verbose=False)
    epochs_visualize[data_name]["original"]["X_test_1"] = mne.EpochsArray(X_test_1, info, verbose=False)

    epochs_visualize[data_name]["enhanced"]["X_train_0"] = mne.EpochsArray(X_train_enhanced_0, info, verbose=False)
    epochs_visualize[data_name]["enhanced"]["X_train_1"] = mne.EpochsArray(X_train_enhanced_1, info, verbose=False)
    epochs_visualize[data_name]["enhanced"]["X_test_0"] = mne.EpochsArray(X_test_enhanced_0, info, verbose=False)
    epochs_visualize[data_name]["enhanced"]["X_test_1"] = mne.EpochsArray(X_test_enhanced_1, info, verbose=False)

## Plot Graph

In [None]:
%matplotlib qt

In [None]:
picked_time = [0.75, 1, 1.25, 1.5]
topomap_subject = "A09"

In [None]:
epochs_visualize[topomap_subject]["original"]["X_train_0"].average().plot_topomap(res=256, size=4)

In [None]:
epochs_visualize[topomap_subject]["original"]["X_train_1"].average().plot_topomap(res=256, size=4)

In [None]:
epochs_visualize[topomap_subject]["enhanced"]["X_train_0"].average().plot_topomap(res=256, size=4)

In [None]:
epochs_visualize[topomap_subject]["enhanced"]["X_train_1"].average().plot_topomap(res=256, size=4)

In [None]:
epochs_visualize[topomap_subject]["original"]["X_train_0"].average().plot(spatial_colors=True)

In [None]:
epochs_visualize[topomap_subject]["original"]["X_train_1"].average().plot(spatial_colors=True)

In [None]:
epochs_visualize[topomap_subject]["enhanced"]["X_train_0"].average().plot(spatial_colors=True)

In [None]:
epochs_visualize[topomap_subject]["enhanced"]["X_train_1"].average().plot(spatial_colors=True)

In [None]:
def visualize_subject_spectrogram(evoked_data, plot_channel=(7, 9, 11), sfreq=None):
    """Plot the log-magnitude STFT of selected channels, one panel per channel.

    Parameters
    ----------
    evoked_data : array-like, shape (n_channels, n_times)
        Channel-by-time data, e.g. ``Evoked.get_data()``.
    plot_channel : sequence of int
        Channel indices to plot. Indices 7, 9 and 11 are titled C3, Cz and
        C4; any other index is titled with its number.
    sfreq : float, optional
        Sampling frequency in Hz. When given, the time axis (x) is in seconds
        from frame centres and the frequency axis (y) is in Hz. When omitted,
        the x axis spans 0-2 s and the y axis shows STFT bin indices.

    In each panel, rows are frequency and columns are time.
    """
    ch_num_to_ch_name = {7: "C3", 9: "Cz", 11: "C4"}
    frame_length, frame_step = 128, 16
    max_freq_rows = 50  # only the lowest 50 STFT bins are shown
    Zxx = tf.signal.stft(evoked_data, frame_length=frame_length, frame_step=frame_step)
    Zxx = tf.abs(Zxx)

    print("shape of evoked_data: " + str(evoked_data.shape))
    print("shape of Zxx: " + str(Zxx.shape))

    # plot spectrogram; squeeze=False keeps a 2-D axes array even for a single channel
    fig, axes = plt.subplots(len(plot_channel), 1, figsize=(10, 10), squeeze=False)
    fig.suptitle('Short Time Fourier Transform Magnitude')
    # x is time and y is frequency (the labels were previously swapped)
    fig.supxlabel('Time [sec]')
    fig.supylabel('frequency (Hz)' if sfreq is not None else 'frequency bin')
    for ax, channel_i in zip(axes[:, 0], plot_channel):
        # Zxx[channel] is (frames, bins); transpose to (bins, frames) for pcolormesh
        log_spec = tf.math.log(tf.transpose(Zxx[channel_i]))
        height = min(max_freq_rows, log_spec.shape[0])
        width = log_spec.shape[1]
        if sfreq is not None:
            # tf.signal.stft uses fft_length = 128 here, so bins are sfreq/128 apart
            x_axis = (np.arange(width) * frame_step + frame_length / 2) / sfreq
            y_axis = np.arange(height) * sfreq / frame_length
        else:
            x_axis = tf.linspace(0, 2, num=width)
            y_axis = range(height)
        ax.title.set_text(ch_num_to_ch_name.get(channel_i, str(channel_i)))
        im = ax.pcolormesh(x_axis, y_axis, log_spec[:height, ])
        fig.colorbar(mappable=im, ax=ax)
    plt.show()

In [None]:
visualize_subject_spectrogram(epochs_visualize[topomap_subject]["original"]["X_train_0"].average().get_data())

In [None]:
visualize_subject_spectrogram(epochs_visualize[topomap_subject]["original"]["X_train_1"].average().get_data())

In [None]:
visualize_subject_spectrogram(epochs_visualize[topomap_subject]["enhanced"]["X_train_0"].average().get_data())

In [None]:
visualize_subject_spectrogram(epochs_visualize[topomap_subject]["enhanced"]["X_train_1"].average().get_data())

# Visualizing Mask and Attention

In [None]:
# Rebuild an AutoSelect model with the mask enabled and load trained weights
# so its learned parameters can be inspected below.
#select_channels = [7, 9, 11]
select_channels = list(range(22))
# Positional args follow the training cells' call order:
# (select_channels, forward_matrix, random_select=False, use_mask=True, kqv=False, model_name="eegnet")
model = AutoSelect(select_channels, forward_matrix, False, True, False, "eegnet")
# The input width is the number of selected sources. It is taken from
# forward_matrix (as in the AutoSelect timing cell) instead of a hard-coded
# 6433, so the cell stays valid if the source selection changes.
# 500 is presumably the number of time samples per trial.
model.build(input_shape=(None, forward_matrix.shape[1], 500))
#print(model.summary())
load_model_directory = "all motor mask initialize 1/models/"
load_model_subject = "A01_0.8928/"

model.load_weights("D:/forward and inverse results (new)/motor/eegnet 22 channels/" + load_model_directory + load_model_subject)

In [None]:
# List the shape of every weight array held by the loaded model.
for weight_array in model.get_weights():
    print(weight_array.shape)

In [None]:
# Convert each hemisphere's source vertices to MNI coordinates
# (hemisphere 0 = left, 1 = right) and show them in 3-D, overlaying the
# selected motor-region points.
mni_lh = mne.vertex_to_mni(information["left_vertices"], 0, mne_subject)
print(mni_lh.shape)
mni_rh = mne.vertex_to_mni(information["right_vertices"], 1, mne_subject)
print(mni_rh.shape)

mni_left_points = mni_lh[information["my_left_points"]]
print(mni_left_points.shape)
mni_right_points = mni_rh[information["my_right_points"]]
print(mni_right_points.shape)

fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(projection='3d')
# Draw order matters: each scatter call takes the next colour in the cycle.
scatter_layers = ((mni_lh, '|'), (mni_rh, '_'), (mni_left_points, 'o'), (mni_right_points, '^'))
for layer_coords, layer_marker in scatter_layers:
    ax.scatter(layer_coords[:, 0], layer_coords[:, 1], layer_coords[:, 2], s=15, marker=layer_marker)
for set_label, label_text in ((ax.set_xlabel, 'X'), (ax.set_ylabel, 'Y'), (ax.set_zlabel, 'Z')):
    set_label(label_text)

plt.show()

## Mask Visualization

In [None]:
# for mask
#np.sum(model.get_weights()[-1] >= 1)

# Inspect the model's last weight array (presumably the learned source mask
# when use_mask=True -- confirm against AutoSelect). get_weights() copies
# every model variable into NumPy, so it is fetched once and reused.
mask_weights = model.get_weights()[-1]
for pts in mask_weights:
    print(pts)
print(np.max(mask_weights))
print(np.min(mask_weights))
print(np.mean(mask_weights))
print(np.std(mask_weights))

In [None]:
# Scatter the learned mask values back onto each hemisphere's full source
# grid. Selected points get their mask value and all other points stay 0.
# The mask stores left-hemisphere points first, then right.
mask_values = model.get_weights()[-1].reshape(-1)  # fetched once; get_weights() copies all variables
n_left_selected = np.sum(information["my_left_points"])
left_weight = np.zeros(information["my_left_points"].shape)
right_weight = np.zeros(information["my_right_points"].shape)
left_weight[information["my_left_points"]] = mask_values[:n_left_selected]
right_weight[information["my_right_points"]] = mask_values[n_left_selected:]

total_weight = np.append(left_weight, right_weight, axis=0)
print(total_weight)
print(total_weight.shape)

# binary weights (threshold 1; the checkpoint folder name suggests the mask
# was initialised to 1 -- NOTE(review): confirm)
total_weight = (total_weight >= 1).astype(int)
print(total_weight)
print("original: ", n_left_selected + np.sum(information["my_right_points"]))
print("mask: ", np.sum(total_weight))

In [None]:
# Project the binary source mask onto the voxel grid `mm_coords`.
# - Each voxel within radius 2 (mm, going by the name) of a source point
#   takes that point's mask value.
# - Source points are ordered left then right hemisphere, matching total_weight.
# - Where balls overlap, the later source point wins.
voxel_coords = mm_coords.reshape(-1, 3)
vertices_tree = KDTree(voxel_coords)
final_weight = np.zeros(voxel_coords.shape)
source_points = np.append(mni_lh, mni_rh, axis=0)

max_points = 0  # largest voxel index reached by any source point
for i, vertex_i in enumerate(vertices_tree.query_ball_point(source_points, 2)):
    if vertex_i:
        max_points = max(max_points, np.max(vertex_i))
        # fill all voxels in the ball at once
        final_weight[vertex_i] = total_weight[i]
        # alternative: fill only the biggest index
        #final_weight[np.max(vertex_i)] = total_weight[i]
print(max_points)

final_weight = final_weight.reshape(mm_coords.shape)[:, :, :, 0]

visualization_mask = Nifti1Image(final_weight, ch2_img.affine, ch2_img.header)
nib.save(visualization_mask, os.path.join('C:/Users/ivanlim/Desktop/EEG-forward-and-inverse', 'visualization_mask.nii.gz'))

In [None]:
# Save the voxel mask (same data and template affine/header) under a second filename.
mask_output_dir = 'C:/Users/ivanlim/Desktop/EEG-forward-and-inverse'
mask_volume = np.asanyarray(final_weight)
visualization_mask_with_shape = Nifti1Image(mask_volume, ch2_img.affine, ch2_img.header)
nib.save(visualization_mask_with_shape, os.path.join(mask_output_dir, 'visualization_mask_with_shape.nii.gz'))

In [None]:
# Glass-brain view (sagittal + axial) of the saved mask volume; voxels at or below 0.9 are hidden.
glass_brain_source = "C:/Users/ivanlim/Desktop/EEG-forward-and-inverse/visualization_mask_with_shape.nii.gz"
plotting.plot_glass_brain(glass_brain_source, display_mode='xz', threshold=0.9)

# Statistical test

In [None]:
def parse_tensorboard(path, scalars):
    """Return a dict mapping each requested tag to a DataFrame of its events.

    Tags are looked up among the accumulator's *tensor* events and read with
    ``EventAccumulator.Tensors``. Each DataFrame holds every logged event for
    that tag, with the TensorEvent fields as columns.

    Parameters
    ----------
    path : str
        Path to a TensorBoard event file or run directory.
    scalars : iterable of str
        Tag names to extract (e.g. "epoch_loss", "epoch_accuracy").

    Raises
    ------
    ValueError
        If any requested tag is not present in the log.
    """
    scalars = list(scalars)
    ea = event_accumulator.EventAccumulator(
        path,
        # 0 = keep every event. TENSORS must be listed explicitly: the default
        # size guidance reservoir-samples tensor events down to 10 entries,
        # which would silently drop most epochs of a long run.
        size_guidance={event_accumulator.SCALARS: 0, event_accumulator.TENSORS: 0},
    )
    ea.Reload()
    available = ea.Tags()['tensors']
    missing = [s for s in scalars if s not in available]
    if missing:
        raise ValueError("some scalars were not found in the event accumulator: " + ", ".join(missing))
    return {k: pd.DataFrame(ea.Tensors(k)) for k in scalars}

## Right-tailed test (paired)

### subject-wise

In [None]:
method_path = "D:/forward and inverse results (new)/motor/eeg-inception 22 channels (channel-wise normalization)"
accuracy_dict = {}

# accuracy_dict[method][subject][run_name] -> final-epoch validation accuracy
# read from each run's TensorBoard "validation" log.
# Listings are sorted so run order is deterministic; run folders are named
# "%Y%m%d-%H%M%S", so this order is chronological. The paired test below
# matches a method's runs to the "original EEG" runs by this position.
for method in sorted(os.listdir(method_path)):
    logs_path = os.path.join(method_path, method, "logs")
    accuracy_dict.setdefault(method, {})
    for subject in sorted(os.listdir(logs_path)):
        accuracy_dict[method].setdefault(subject, {})
        subject_path = os.path.join(logs_path, subject)
        # loop variable is `run_name`, not `time`, so the imported `time`
        # module is not shadowed for later cells
        for run_name in sorted(os.listdir(subject_path)):
            validation_path = os.path.join(subject_path, run_name, "validation")
            tensorboard_logs = os.path.join(validation_path, sorted(os.listdir(validation_path))[0])
            frames = parse_tensorboard(tensorboard_logs, ["epoch_loss", "epoch_accuracy"])
            last_proto = frames["epoch_accuracy"]["tensor_proto"].iloc[-1]
            accuracy_dict[method][subject][run_name] = tf.make_ndarray(last_proto)[()]

# Per subject: run-by-run accuracy differences (method - "original EEG").
accuracy_mean_std = {}
for method in accuracy_dict.keys():
    if method == "original EEG":
        continue
    accuracy_mean_std[method] = {}
    for subject in accuracy_dict[method].keys():
        method_runs = list(accuracy_dict[method][subject].values())
        control_runs = list(accuracy_dict["original EEG"][subject].values())
        if len(method_runs) != len(control_runs):
            raise ValueError(method + " / " + subject + ": " + str(len(method_runs))
                             + " runs vs " + str(len(control_runs)) + " control runs; cannot pair them")
        differences = [m - c for m, c in zip(method_runs, control_runs)]
        accuracy_mean_std[method][subject] = {
            "accuracy": differences,
            "mean": np.mean(differences),
            # sample standard deviation (ddof=1), as the t statistic requires
            "std": np.std(differences, ddof=1),
        }
        #print(method, subject, mean, std)

subject_list = ["A01", "A02", "A03", "A04", "A05", "A06", "A07", "A08", "A09"]
t_test_results = {}

# Right-tailed paired t-test per subject and method
# (H1: the method's accuracy exceeds "original EEG").
for subject in subject_list:
    t_test_results[subject] = {}
    count = 0
    for method in accuracy_mean_std.keys():
#         if method == "all motor attention with dropout 0.5":
#             continue

        mean = accuracy_mean_std[method][subject]["mean"]
        std = accuracy_mean_std[method][subject]["std"]
        n = len(accuracy_mean_std[method][subject]["accuracy"])  # number of paired runs (was hard-coded 5)
        t_test = mean / np.sqrt((std**2)/n)
        dof = n - 1
        if t_test > 0:
            p_value = stats.t.sf(t_test, dof)
            if p_value < 0.05:
                count += 1
                print(subject, method, mean)
        #print(t_test, dof, p_value)

    print(count)

### method-wise

In [None]:
method_path = "D:/forward and inverse results (new)/motor/eeg-inception 22 channels (channel-wise normalization)"
accuracy_dict = {}

# accuracy_dict[method][subject][run_name] -> final-epoch validation accuracy
# read from each run's TensorBoard "validation" log. Listings are sorted so
# results do not depend on filesystem listing order.
for method in sorted(os.listdir(method_path)):
    logs_path = os.path.join(method_path, method, "logs")
    accuracy_dict.setdefault(method, {})
    for subject in sorted(os.listdir(logs_path)):
        accuracy_dict[method].setdefault(subject, {})
        subject_path = os.path.join(logs_path, subject)
        # loop variable is `run_name`, not `time`, so the imported `time`
        # module is not shadowed for later cells
        for run_name in sorted(os.listdir(subject_path)):
            validation_path = os.path.join(subject_path, run_name, "validation")
            tensorboard_logs = os.path.join(validation_path, sorted(os.listdir(validation_path))[0])
            frames = parse_tensorboard(tensorboard_logs, ["epoch_loss", "epoch_accuracy"])
            last_proto = frames["epoch_accuracy"]["tensor_proto"].iloc[-1]
            accuracy_dict[method][subject][run_name] = tf.make_ndarray(last_proto)[()]

# Per method and subject: mean (and descriptive population std) over runs.
accuracy_mean_std = {}
for method in accuracy_dict.keys():
    accuracy_mean_std[method] = {}
    for subject in accuracy_dict[method].keys():
        accuracy_time = list(accuracy_dict[method][subject].values())
        mean = np.mean(accuracy_time)
        std = np.std(accuracy_time)
        accuracy_mean_std[method][subject] = {}
        accuracy_mean_std[method][subject]["mean"] = mean
        accuracy_mean_std[method][subject]["std"] = std
        #print(method, subject, mean)

subject_list = ["A01", "A02", "A03", "A04", "A05", "A06", "A07", "A08", "A09"]
t_test_results = {}

# Per method: the subject-wise mean-accuracy gains over "original EEG".
for method in accuracy_mean_std.keys():
#     if method == "original EEG" or method == "all motor attention with dropout 0.5":
#         continue
    if method == "original EEG":
        continue
    t_test_results[method] = []
    for subject in subject_list:
        accuracy_1 = accuracy_mean_std[method][subject]["mean"]
        accuracy_2 = accuracy_mean_std["original EEG"][subject]["mean"]
        mean = accuracy_1 - accuracy_2
        t_test_results[method].append(mean)
        #print(mean)

# Right-tailed paired t-test across subjects (H1: the method beats
# "original EEG"), using the sample std (ddof=1) as the t statistic requires.
for method in accuracy_mean_std.keys():
#     if method == "original EEG" or method == "all motor attention with dropout 0.5":
#         continue
    if method == "original EEG":
        continue
    differences = t_test_results[method]
    n = len(differences)  # number of paired subjects (was hard-coded 9)
    mean = np.mean(differences)
    std = np.std(differences, ddof=1)
    t_test = mean / np.sqrt((std**2)/n)
    dof = n - 1
    if t_test > 0:
        p_value = stats.t.sf(t_test, dof)
        if p_value < 0.05:
            print(method, mean)
    #print(t_test, dof, p_value)

# Comparison

## Traditional machine learning method

### Generate trainable attention model weights

In [None]:
information = get_inverse_and_forward_information(epochs)
print(information)

# Leadfield restricted to the selected motor-region sources.
# - Left-hemisphere columns are the first len(left_vertices) columns.
# - Right-hemisphere columns are the last len(right_vertices) columns.
# - Each hemisphere block is then filtered by its own point selector.
_lf = information["leadfield"]
_left_cols = _lf[:, :len(information["left_vertices"])][:, information["my_left_points"]]
_right_cols = _lf[:, -len(information["right_vertices"]):][:, information["my_right_points"]]
print(_left_cols.shape)
print(_right_cols.shape)
forward_matrix = np.concatenate((_left_cols, _right_cols), axis=1)
print(forward_matrix.shape)
del _lf, _left_cols, _right_cols

In [None]:
"""
labels
left (class 0) right (class 1) foot (class 2) tongue (class 3)

channels
c3(7) cz(9) c4(11)
"""

# Per-subject metric store for A01..A09, filled by the training loop.
results = {"A0" + str(k): {} for k in range(1, 10)}
#select_channels = [7, 9, 11]
select_channels = list(range(22))  # channel indices 0-21
# run configuration
debug, training = True, True
random_select, use_mask, kqv = False, True, False
model_name = "eegnet"
num_epochs = 200

# train model on each subject individually
data_list = []
for subject in results:
  data_list.append(subject)

# train model on individual subject
# data_list = []
# data_list.append("A02")
# data_list.append("A05")

for data_name in data_list:
  # load data from external storage
  directory_path = op.join(EXTERNAL_STORAGE_PATH, "all motor region", "data", "source activity", data_name)
  counter = 0
  accuracy = 0
  precision = 0
  recall = 0
  f1 = 0
  kappa = 0
    
  while(counter < n_splits):
    counter += 1
    X_train = np.load(op.join(directory_path, str(counter)+"_train_X.npz"), allow_pickle=True)["data"].astype(np.float32)
    X_test = np.load(op.join(directory_path, str(counter)+"_test_X.npz"), allow_pickle=True)["data"].astype(np.float32)
    Y_train = np.load(op.join(directory_path, str(counter)+"_train_Y.npz"), allow_pickle=True)["data"].astype(np.float32)
    Y_test = np.load(op.join(directory_path, str(counter)+"_test_Y.npz"), allow_pickle=True)["data"].astype(np.float32)
    
    if debug:
      print(data_name)
      print("shape of X_train and Y_train: " + str(X_train.shape) + " " + str(Y_train.shape))
      print("shape of X_test and Y_test: " + str(X_test.shape) + " " + str(Y_test.shape))

    if training:
      # create new model
      model = AutoSelect(select_channels, forward_matrix, random_select, use_mask, kqv, model_name)
      #model.build(input_shape=(None, forward_matrix.shape[1], 500))
      #print(model.summary())
      
      log_dir = DIRECTORY_PATH + "/" + model_name + "/ml motor/logs/" + data_name + "/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
      tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
      optimizer = Adam(learning_rate=1e-5)
      model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])
      model.fit(X_train, Y_train, validation_data=(X_test, Y_test), batch_size=32, epochs=num_epochs, callbacks=[tensorboard_callback], verbose=0)
        
      Y_hat = model.predict(X_test, batch_size=8)
      Y_hat = (Y_hat >= 0.5)
      accuracy += accuracy_score(Y_test, Y_hat)
      precision += precision_score(Y_test, Y_hat, average="macro")
      recall += recall_score(Y_test, Y_hat, average="macro")
      f1 += f1_score(Y_test, Y_hat, average="macro")
      kappa += cohen_kappa_score(Y_test, Y_hat)

      # save model
      model.save_weights(DIRECTORY_PATH + "/" + model_name + "/ml motor/models/" + data_name + "/" + str(counter) + "_" + str(accuracy_score(Y_test, Y_hat))[:6] + "/")
    else:
      # load pretrained model
      model = AutoSelect(select_channels, forward_matrix, random_select, use_mask, kqv, model_name)
      model.load_weights(DIRECTORY_PATH + "/default/ml motor/models/A09/" + "1_0.9183/")
      # freeze model
      model.trainable = False
      optimizer = Adam(learning_rate=1e-5)
      model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])
        
      Y_hat = model.predict(X_test)
      Y_hat = (Y_hat >= 0.5)
      accuracy += accuracy_score(Y_test, Y_hat)
      precision += precision_score(Y_test, Y_hat, average="macro")
      recall += recall_score(Y_test, Y_hat, average="macro")
      f1 += f1_score(Y_test, Y_hat, average="macro")
      kappa += cohen_kappa_score(Y_test, Y_hat)
    
    del X_train, Y_train, X_test, Y_test
    gc.collect()

  accuracy /= n_splits
  precision /= n_splits
  recall /= n_splits
  f1 /= n_splits
  kappa /= n_splits
  if debug:
    print("accuracy: " + str(accuracy))
    print("precision: " + str(precision))
    print("recall: " + str(recall))
    print("f1: " + str(f1))
    print("kappa: " + str(kappa))

  results[data_name]["accuracy"] = accuracy
  results[data_name]["precision"] = precision
  results[data_name]["recall"] = recall
  results[data_name]["f1"] = f1
  results[data_name]["kappa"] = kappa

In [None]:
# Calculate average performance: mean and standard deviation, across subjects,
# of each metric stored in `results`.
result_accuracy = [value["accuracy"] for value in results.values()]
result_precision = [value["precision"] for value in results.values()]
result_recall = [value["recall"] for value in results.values()]
result_f1 = [value["f1"] for value in results.values()]
result_kappa = [value["kappa"] for value in results.values()]

for metric_label, metric_scores in (("accuracy", result_accuracy),
                                    ("precision", result_precision),
                                    ("recall", result_recall),
                                    ("f1", result_f1),
                                    ("kappa", result_kappa)):
  print(metric_label + ": (mean) " + str(np.mean(metric_scores)) + " (std) " + str(np.std(metric_scores)))

### LDA, SVM, Random Forest, AdaBoost

In [None]:
"""
labels
left (class 0) right (class 1) foot (class 2) tongue (class 3)

channels
c3(7) cz(9) c4(11)
"""
results = {"A01": {}, "A02": {}, "A03": {}, "A04": {}, "A05": {}, "A06": {}, "A07": {}, "A08": {}, "A09": {}}
events = ["left", "right"]
select_channels = list(range(22))
#select_channels = [7, 9, 11]
classes = 2
debug = True
use_csp = True

warm_up = 10 # initializing memory allocators, and GPU-related initializations 

# Subject-wise baseline: filter-bank CSP features (use_csp=True) or flattened
# EEG -> StandardScaler -> classifier, evaluated with stratified k-fold CV.
# Single-trial inference time is also measured once (subject A01, fold 1).


def _stack_event_epochs(epochs_by_event, event_names):
  """Concatenate the epochs of every event listed in `event_names`.

  Events are visited in the iteration order of `epochs_by_event`; every trial
  is labelled with the index of its event in `event_names`. Returns ([], [])
  when no event matches.
  """
  X_parts, Y_parts = [], []
  for event in epochs_by_event.keys():
    if event not in event_names:
      continue
    # call get_data() once per event; each call copies the whole data array
    data = epochs_by_event[event].get_data()
    X_parts.append(data)
    Y_parts.append(np.zeros(len(data)) + event_names.index(event))
  if not X_parts:
    return [], []
  return np.concatenate(X_parts, axis=0), np.concatenate(Y_parts, axis=0)


# train model on each subject individually
data_list = list(results.keys())

# train model on individual subject
# data_list = []
# data_list.append("A05")

for data_name in data_list:
  accuracy = []
  precision = []
  recall = []
  f1 = []
  kappa = []

  X, Y = [], []
  if use_csp:
    # stack the bands on a new leading axis; labels are taken from the first
    # band (every band presumably holds the same trials in the same order)
    band_data = []
    for freq in epochs_filter_bank[data_name].keys():
      X_freq, Y_freq = _stack_event_epochs(epochs_filter_bank[data_name][freq], events)
      if not band_data:
        Y = Y_freq
      band_data.append(X_freq)
    if band_data:
      X = np.stack(band_data, axis=0)
  else:
    X, Y = _stack_event_epochs(epochs[data_name], events)
  X = np.array(X)
  Y = np.array(Y)
  print(X.shape, Y.shape)

  skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0)
  # with a band axis, split on the first band (the split only needs n_trials and Y)
  fold_iter = skf.split(X[0], Y) if len(X.shape) == 4 else skf.split(X, Y)
  count = 0
  for count, (train_index, test_index) in enumerate(fold_iter, start=1):
    Y_train, Y_test = Y[train_index], Y[test_index]

    if use_csp:
      # fit one CSP per band on the training trials only and concatenate the
      # per-band features column-wise
      train_features, test_features = [], []
      for band in X:
        csp_pipeline = make_pipeline(
            CSP(n_components=4, reg='diagonal_fixed', log=True, norm_trace=False, rank='full')
        )
        train_features.append(csp_pipeline.fit_transform(band[train_index], Y_train))
        test_features.append(csp_pipeline.transform(band[test_index]))
      X_train = np.concatenate(train_features, axis=1)
      X_test = np.concatenate(test_features, axis=1)
    else:
      X_train, X_test = X[train_index], X[test_index]

    print(X_train.shape, Y_train.shape)
    print(X_test.shape, Y_test.shape)

    # add fake data
    # (1) real + fake
    # X_train = np.append(X_train, X_fake, axis=0)
    # Y_train = np.append(Y_train, Y_fake, axis=0)
    # (2) fake
    # X_train = X_fake
    # Y_train = Y_fake

    # pick the selected channels (e.g. c3, cz, c4)
    if not use_csp:
      X_train = X_train[:, select_channels, :]
      X_test = X_test[:, select_channels, :]

    print(data_name)
    X_train = X_train.reshape(X_train.shape[0], -1)
    X_test = X_test.reshape(X_test.shape[0], -1)

    if debug:
      print("shape of X_train and Y_train: " + str(X_train.shape) + " " + str(Y_train.shape))
      print("shape of X_test and Y_test: " + str(X_test.shape) + " " + str(Y_test.shape))

    #model = make_pipeline(StandardScaler(), SVC(gamma='auto', random_state=0))
    #model = make_pipeline(StandardScaler(), RandomForestClassifier(random_state=0))
    #model = make_pipeline(StandardScaler(), AdaBoostClassifier(n_estimators=100, random_state=0))
    model = make_pipeline(StandardScaler(), LinearDiscriminantAnalysis())
    model.fit(X_train, Y_train)
    Y_hat = model.predict(X_test)

    accuracy.append(accuracy_score(Y_test, Y_hat))
    precision.append(precision_score(Y_test, Y_hat, average="macro"))
    recall.append(recall_score(Y_test, Y_hat, average="macro"))
    f1.append(f1_score(Y_test, Y_hat, average="macro"))
    kappa.append(cohen_kappa_score(Y_test, Y_hat))

    # time computation: single-trial prediction latency, measured once
    # (subject A01, fold 1) on the last of `warm_up` repetitions
    if data_name != "A01" or count != 1:
      continue
    for i in range(warm_up):
      if i == warm_up - 1:
        # perf_counter is monotonic and high-resolution, unlike time.time
        start = time.perf_counter()
      X_time = np.expand_dims(X_test[0], 0)
      Y_hat = model.predict(X_time)
      if i == warm_up - 1:
        end = time.perf_counter()
        print("time used: ", (end - start)*1000, "ms")

  if debug:
    print("accuracy: " + str(np.mean(accuracy)))
    print("precision: " + str(np.mean(precision)))
    print("recall: " + str(np.mean(recall)))
    print("f1: " + str(np.mean(f1)))
    print("kappa: " + str(np.mean(kappa)))

  results[data_name]["accuracy"] = accuracy
  results[data_name]["precision"] = precision
  results[data_name]["recall"] = recall
  results[data_name]["f1"] = f1
  results[data_name]["kappa"] = kappa

In [None]:
# Calculate average performance.
# Each results[subject][metric] is a list of per-fold scores here, so the folds
# are averaged per subject first; mean and std are then taken across subjects
# (one value per subject), matching the other summary cells. np.mean of a
# scalar returns the scalar, so already-averaged entries also work.
result_accuracy = []
result_precision = []
result_recall = []
result_f1 = []
result_kappa = []
for key, value in results.items():
  result_accuracy += [np.mean(value["accuracy"])]
  result_precision += [np.mean(value["precision"])]
  result_recall += [np.mean(value["recall"])]
  result_f1 += [np.mean(value["f1"])]
  result_kappa += [np.mean(value["kappa"])]

print("accuracy: (mean) " + str(np.mean(result_accuracy)) + " (std) " + str(np.std(result_accuracy)))
print("precision: (mean) " + str(np.mean(result_precision)) + " (std) " + str(np.std(result_precision)))
print("recall: (mean) " + str(np.mean(result_recall)) + " (std) " + str(np.std(result_recall)))
print("f1: (mean) " + str(np.mean(result_f1)) + " (std) " + str(np.std(result_f1)))
print("kappa: (mean) " + str(np.mean(result_kappa)) + " (std) " + str(np.std(result_kappa)))

#np.savez(os.path.join(DIRECTORY_PATH, "lda.npz"), results=results)

In [None]:
# svm = np.load(os.path.join(DIRECTORY_PATH, "svm.npz"), allow_pickle=True)
# print(svm["results"].item()["A01"])

In [None]:
"""
labels
left (class 0) right (class 1) foot (class 2) tongue (class 3)

channels
c3(7) cz(9) c4(11)
"""
results = {"A01": {}, "A02": {}, "A03": {}, "A04": {}, "A05": {}, "A06": {}, "A07": {}, "A08": {}, "A09": {}}
events = ["left", "right"]
use_csp = False
select_channels = list(range(22))
#select_channels = [7, 9, 11]
classes = 2
debug = True
model_name = "eegnet"
random_select = False
use_mask = True
kqv = False

# Classical classifiers on the output of the pretrained AutoSelectData model.
# For each subject and fold: load that fold's saved weights, pass the stored
# source-activity data through the frozen model, optionally extract filter-bank
# CSP features, then fit a StandardScaler + classifier pipeline. Per-fold
# metrics are stored in `results`.

# train model on each subject individually
data_list = list(results.keys())

# train model on individual subject
# data_list = []
# data_list.append("A05")

for data_name in data_list:
  # load data from external storage
  directory_path = op.join(EXTERNAL_STORAGE_PATH, "all motor region", "data", "source activity", data_name)
  accuracy = []
  precision = []
  recall = []
  f1 = []
  kappa = []

  # fold files are numbered 1..n_splits
  for counter in range(1, n_splits + 1):
    X_train = np.load(op.join(directory_path, str(counter)+"_train_X.npz"), allow_pickle=True)["data"].astype(np.float32)
    X_test = np.load(op.join(directory_path, str(counter)+"_test_X.npz"), allow_pickle=True)["data"].astype(np.float32)
    Y_train = np.load(op.join(directory_path, str(counter)+"_train_Y.npz"), allow_pickle=True)["data"].astype(np.float32)
    Y_test = np.load(op.join(directory_path, str(counter)+"_test_Y.npz"), allow_pickle=True)["data"].astype(np.float32)

    if debug:
      print(data_name)
      print("shape of X_train and Y_train: " + str(X_train.shape) + " " + str(Y_train.shape))
      print("shape of X_test and Y_test: " + str(X_test.shape) + " " + str(Y_test.shape))

    model = AutoSelectData(select_channels, forward_matrix, random_select, use_mask, kqv, model_name)
    #weight_path = "D:/forward and inverse results (new)/motor/ml motor/"+ model_name + "/models/" + data_name + "/"
    weight_path = "D:/forward and inverse results (new)/motor/ml motor/eegnet 22 channels/models/" + data_name + "/"

    # checkpoint folders are named "<fold>_<accuracy>"; pick this fold's folder
    # and fail loudly rather than silently loading another fold's weights
    candidates = [name for name in sorted(os.listdir(weight_path)) if name.split("_")[0] == str(counter)]
    if not candidates:
      raise FileNotFoundError("no weights for fold " + str(counter) + " in " + weight_path)
    weight_file = candidates[0]
    load_weights_file = os.path.join(weight_path, weight_file) + "/"
    model.load_weights(load_weights_file)
    #model.build((None, X_train.shape[1], X_train.shape[2]))
    #print(model.summary())
    model.trainable = False
    X_train = model(X_train, training=False).numpy()
    X_test = model(X_test, training=False).numpy()

    if use_csp:
      # 4 Hz-wide bands from 4 to 40 Hz ("old version" feature extraction)
      lowcut = 4
      highcut = 40
      interval = 4
      filter_bank = list(np.arange(lowcut, highcut+interval, step=interval))
      # NOTE(review): assumes the model output matches the subject's EEG
      # channel layout and sampling rate in `info` -- confirm
      info = epochs[data_name]["left"].info

      X_train = mne.EpochsArray(X_train, info, verbose=False)
      X_test = mne.EpochsArray(X_test, info, verbose=False)

      train_features, test_features = [], []
      for low, high in zip(filter_bank[:-1], filter_bank[1:]):
        print("current frequency: ", low, " to ", high)

        # filter frequency; a fresh params dict per band
        iir_params = dict(order=5, ftype='butter')
        X_train_filter = X_train.copy()
        X_train_filter.filter(low, high, method="iir", iir_params=iir_params)

        X_test_filter = X_test.copy()
        X_test_filter.filter(low, high, method="iir", iir_params=iir_params)

        # CSP is fitted on the training fold only
        csp_pipeline = make_pipeline(
            CSP(n_components=4, reg='diagonal_fixed', log=True, norm_trace=False, rank='full')
        )
        train_features.append(csp_pipeline.fit_transform(X_train_filter.get_data(), Y_train))
        test_features.append(csp_pipeline.transform(X_test_filter.get_data()))

      X_train = np.concatenate(train_features, axis=1)
      X_test = np.concatenate(test_features, axis=1)
    else:
      X_train = X_train.reshape(X_train.shape[0], -1)
      X_test = X_test.reshape(X_test.shape[0], -1)

    print("X train: ", X_train.shape, " X test: ", X_test.shape)

    #model = make_pipeline(StandardScaler(), SVC(gamma='auto', random_state=0))
    #model = make_pipeline(StandardScaler(), RandomForestClassifier(random_state=0))
    model = make_pipeline(StandardScaler(), AdaBoostClassifier(n_estimators=100, random_state=0))
    #model = make_pipeline(StandardScaler(), LinearDiscriminantAnalysis())
    model.fit(X_train, Y_train)
    Y_hat = model.predict(X_test)

    accuracy.append(accuracy_score(Y_test, Y_hat))
    precision.append(precision_score(Y_test, Y_hat, average="macro"))
    recall.append(recall_score(Y_test, Y_hat, average="macro"))
    f1.append(f1_score(Y_test, Y_hat, average="macro"))
    kappa.append(cohen_kappa_score(Y_test, Y_hat))

  if debug:
    print("accuracy: " + str(np.mean(accuracy)))
    print("precision: " + str(np.mean(precision)))
    print("recall: " + str(np.mean(recall)))
    print("f1: " + str(np.mean(f1)))
    print("kappa: " + str(np.mean(kappa)))

  results[data_name]["accuracy"] = accuracy
  results[data_name]["precision"] = precision
  results[data_name]["recall"] = recall
  results[data_name]["f1"] = f1
  results[data_name]["kappa"] = kappa

### Masked raw -> bandpass filter -> CSP

In [None]:
"""
labels
left (class 0) right (class 1) foot (class 2) tongue (class 3)

channels
c3(7) cz(9) c4(11)
"""
results = {"A01": {}, "A02": {}, "A03": {}, "A04": {}, "A05": {}, "A06": {}, "A07": {}, "A08": {}, "A09": {}}
events = ["left", "right"]
select_channels = list(range(22))
#select_channels = [7, 9, 11]
classes = 2
debug = True

# Masked raw -> band-pass filter bank -> CSP -> classifier. The filter-bank
# epochs are stored per fold (epochs_filter_bank[subject][str(fold)][band][event]);
# the fold's train/test indices come from the same seeded StratifiedKFold used
# in the other cells.


def _stack_event_epochs(epochs_by_event, event_names):
  """Concatenate the epochs of every event listed in `event_names`.

  Events are visited in the iteration order of `epochs_by_event`; every trial
  is labelled with the index of its event in `event_names`. Returns ([], [])
  when no event matches.
  """
  X_parts, Y_parts = [], []
  for event in epochs_by_event.keys():
    if event not in event_names:
      continue
    # call get_data() once per event; each call copies the whole data array
    data = epochs_by_event[event].get_data()
    X_parts.append(data)
    Y_parts.append(np.zeros(len(data)) + event_names.index(event))
  if not X_parts:
    return [], []
  return np.concatenate(X_parts, axis=0), np.concatenate(Y_parts, axis=0)


# train model on each subject individually
data_list = list(results.keys())

# train model on individual subject
# data_list = []
# data_list.append("A01")

for data_name in data_list:
  accuracy = []
  precision = []
  recall = []
  f1 = []
  kappa = []

  skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0)

  for counter in range(1, n_splits + 1):
    fold_bank = epochs_filter_bank[data_name][str(counter)]
    # stack the bands on a new leading axis; labels are taken from the first
    # band (every band presumably holds the same trials in the same order)
    band_data, Y = [], []
    for freq in fold_bank.keys():
      X_freq, Y_freq = _stack_event_epochs(fold_bank[freq], events)
      if not band_data:
        Y = Y_freq
      band_data.append(X_freq)
    X = np.stack(band_data, axis=0)
    if debug:
        print("shape of X and Y: " + str(X.shape) + " " + str(Y.shape))
    # the split only depends on n_trials and Y; take this fold's indices
    train_index, test_index = list(skf.split(X[0], Y))[counter - 1]
    Y_train, Y_test = Y[train_index], Y[test_index]

    # fit one CSP per band on the training trials only, then concatenate the
    # per-band features once, after the loop
    train_features, test_features = [], []
    for band in X:
      csp_pipeline = make_pipeline(
          CSP(n_components=4, reg='diagonal_fixed', log=True, norm_trace=False, rank='full')
      )
      train_features.append(csp_pipeline.fit_transform(band[train_index], Y_train))
      test_features.append(csp_pipeline.transform(band[test_index]))
    X_train = np.concatenate(train_features, axis=1)
    X_test = np.concatenate(test_features, axis=1)

    if debug:
      print(data_name)
      print("shape of X_train and Y_train: " + str(X_train.shape) + " " + str(Y_train.shape))
      print("shape of X_test and Y_test: " + str(X_test.shape) + " " + str(Y_test.shape))

    #model = make_pipeline(StandardScaler(), SVC(gamma='auto', random_state=0))
    #model = make_pipeline(StandardScaler(), RandomForestClassifier(random_state=0))
    model = make_pipeline(StandardScaler(), AdaBoostClassifier(n_estimators=100, random_state=0))
    #model = make_pipeline(StandardScaler(), LinearDiscriminantAnalysis())
    model.fit(X_train, Y_train)
    Y_hat = model.predict(X_test)

    accuracy.append(accuracy_score(Y_test, Y_hat))
    precision.append(precision_score(Y_test, Y_hat, average="macro"))
    recall.append(recall_score(Y_test, Y_hat, average="macro"))
    f1.append(f1_score(Y_test, Y_hat, average="macro"))
    kappa.append(cohen_kappa_score(Y_test, Y_hat))

  if debug:
    print("accuracy: " + str(np.mean(accuracy)))
    print("precision: " + str(np.mean(precision)))
    print("recall: " + str(np.mean(recall)))
    print("f1: " + str(np.mean(f1)))
    print("kappa: " + str(np.mean(kappa)))

  results[data_name]["accuracy"] = accuracy
  results[data_name]["precision"] = precision
  results[data_name]["recall"] = recall
  results[data_name]["f1"] = f1
  results[data_name]["kappa"] = kappa

In [None]:
# Calculate average performance.
# Each results[subject][metric] is a list of per-fold scores here, so the folds
# are averaged per subject first; mean and std are then taken across subjects
# (one value per subject), matching the other summary cells. np.mean of a
# scalar returns the scalar, so already-averaged entries also work.
result_accuracy = []
result_precision = []
result_recall = []
result_f1 = []
result_kappa = []
for key, value in results.items():
  result_accuracy += [np.mean(value["accuracy"])]
  result_precision += [np.mean(value["precision"])]
  result_recall += [np.mean(value["recall"])]
  result_f1 += [np.mean(value["f1"])]
  result_kappa += [np.mean(value["kappa"])]

print("accuracy: (mean) " + str(np.mean(result_accuracy)) + " (std) " + str(np.std(result_accuracy)))
print("precision: (mean) " + str(np.mean(result_precision)) + " (std) " + str(np.std(result_precision)))
print("recall: (mean) " + str(np.mean(result_recall)) + " (std) " + str(np.std(result_recall)))
print("f1: (mean) " + str(np.mean(result_f1)) + " (std) " + str(np.std(result_f1)))
print("kappa: (mean) " + str(np.mean(result_kappa)) + " (std) " + str(np.std(result_kappa)))

#np.savez(os.path.join(DIRECTORY_PATH, "adaboost_boosted.npz"), results=results)

In [None]:
# svm_boosted = np.load(os.path.join(DIRECTORY_PATH, "svm_boosted.npz"), allow_pickle=True)
# print(svm_boosted["results"].item()["A01"])

Refers to the pipeline: raw EEG -> source estimate -> band-pass filter -> CSP.

In [None]:
# time computation
# Single-trial latency of the full pipeline for subject A01:
# raw EEG -> sLORETA source estimate -> selected motor-region sources ->
# pretrained AutoSelectData model -> (optional filter-bank CSP) -> classifier.
# Only the last of `warm_up` repetitions is timed.
events = ["left", "right"]
use_csp = False
select_channels = list(range(22))
#select_channels = [7, 9, 11]
debug = False
random_select = False
use_mask = True
kqv = False
model_name = "eegnet"
warm_up = 10 # initializing memory allocators, and GPU-related initializations 

data_list = []
data_list.append("A01")

information = get_inverse_and_forward_information(epochs)
print(information)
print(information["leadfield"][:, :len(information["left_vertices"])][:, information["my_left_points"]].shape)
print(information["leadfield"][:, -len(information["right_vertices"]):][:, information["my_right_points"]].shape)
forward_matrix = np.concatenate((information["leadfield"][:, :len(information["left_vertices"])][:, information["my_left_points"]], 
                          information["leadfield"][:, -len(information["right_vertices"]):][:, information["my_right_points"]]),
                          axis = 1)
print(forward_matrix.shape)
inverse_operator = information["inverse_operator"]
my_left_points = information["my_left_points"]
my_right_points = information["my_right_points"]
info = epochs["A01"]["left"].info


def _stack_event_epochs(epochs_by_event, event_names):
  """Concatenate the epochs of every event listed in `event_names`.

  Events are visited in the iteration order of `epochs_by_event`; every trial
  is labelled with the index of its event in `event_names`. Returns ([], [])
  when no event matches.
  """
  X_parts, Y_parts = [], []
  for event in epochs_by_event.keys():
    if event not in event_names:
      continue
    # call get_data() once per event; each call copies the whole data array
    data = epochs_by_event[event].get_data()
    X_parts.append(data)
    Y_parts.append(np.zeros(len(data)) + event_names.index(event))
  if not X_parts:
    return [], []
  return np.concatenate(X_parts, axis=0), np.concatenate(Y_parts, axis=0)


for data_name in data_list:
  X, Y = _stack_event_epochs(epochs[data_name], events)
  X = np.array(X)
  Y = np.array(Y)
  # a single trial is used as the timed input
  X_test = np.expand_dims(X[0], axis=0)
  Y_test = np.expand_dims(Y[0], axis=0)

  # load pretrained model
  model = AutoSelectData(select_channels, forward_matrix, random_select, use_mask, kqv, model_name)
  #weight_path = "D:/forward and inverse results (new)/motor/ml motor/"+ model_name + "/models/" + data_name + "/"
  weight_path = "D:/forward and inverse results (new)/motor/ml motor/eegnet 22 channels/models/" + data_name + "/"
  # checkpoint folders are named "<fold>_<accuracy>" with folds numbered from 1;
  # use fold 1 and fail loudly instead of silently loading an arbitrary folder
  counter = 1
  candidates = [name for name in sorted(os.listdir(weight_path)) if name.split("_")[0] == str(counter)]
  if not candidates:
    raise FileNotFoundError("no weights for fold " + str(counter) + " in " + weight_path)
  weight_file = candidates[0]
  load_weights_file = os.path.join(weight_path, weight_file) + "/"
  model.load_weights(load_weights_file)
  #model.build((None, X_train.shape[1], X_train.shape[2]))
  #print(model.summary())
  model.trainable = False

  #ml_model = make_pipeline(StandardScaler(), SVC(gamma='auto', random_state=0))
  #ml_model = make_pipeline(StandardScaler(), RandomForestClassifier(random_state=0))
  ml_model = make_pipeline(StandardScaler(), AdaBoostClassifier(n_estimators=100, random_state=0))
  #ml_model = make_pipeline(StandardScaler(), LinearDiscriminantAnalysis())

  # dummy training: only needed so predict() can run; the timing does not
  # depend on the learned weights (36 columns = 9 bands x 4 CSP components)
  if use_csp:
    ml_model.fit(X[:, select_channels,:].reshape(X.shape[0], -1)[:, :36], Y)
  else:
    ml_model.fit(X[:, select_channels,:].reshape(X.shape[0], -1), Y)

  csp_pipeline = make_pipeline(
      CSP(n_components=4, reg='diagonal_fixed', log=True, norm_trace=False, rank='full')
  )
  csp_pipeline.fit(X, Y)

  for i in range(warm_up):
    if i == warm_up-1:
        # perf_counter is monotonic and high-resolution, unlike time.time
        start = time.perf_counter()

    X_epochs = mne.EpochsArray(X_test, info, verbose=False)
    method = "sLORETA"
    snr = 3.
    lambda2 = 1. / snr ** 2
    stc_test = apply_inverse_epochs(X_epochs, inverse_operator, lambda2,
                                  method=method, pick_ori="normal", verbose=debug)
    # slice source activity data: keep the selected points of each hemisphere
    left_hemi_data = []
    right_hemi_data = []
    for source in stc_test:
        left_hemi_data.append(source.data[:len(source.vertices[0])][my_left_points])
        right_hemi_data.append(source.data[-len(source.vertices[1]):][my_right_points])
    left_hemi_data = np.array(left_hemi_data)
    right_hemi_data = np.array(right_hemi_data)

    X_time = np.append(left_hemi_data, right_hemi_data, axis=1)
    X_time = model(X_time, training=False).numpy()

    if not use_csp:
        X_time = X_time.reshape(X_time.shape[0], -1)

    if use_csp:
        lowcut = 4
        highcut = 40
        interval = 4
        filter_bank = list(np.arange(lowcut, highcut+interval, step=interval))
        info = epochs[data_name]["left"].info

        X_time = mne.EpochsArray(X_time, info, verbose=False)

        X_time_csp = []

        for freq_i in range(len(filter_bank)-1):
            low = filter_bank[freq_i]
            high = filter_bank[freq_i+1]
            #print("current frequency: ", low, " to ", high)

            # filter frequency
            iir_params = dict(order=5, ftype='butter')
            X_time_filter = X_time.copy()
            X_time_filter.filter(low, high, method="iir", iir_params=iir_params)
            X_time_csp_feature = csp_pipeline.transform(X_time_filter.get_data())

            if len(X_time_csp) == 0:
                X_time_csp = X_time_csp_feature
            else:
                X_time_csp = np.append(X_time_csp, X_time_csp_feature, axis=1)

        X_time = np.array(X_time_csp, copy=True)

    Y_hat = ml_model.predict(X_time)

    if i == warm_up-1:
        end = time.perf_counter()
        print("time used: ", (end - start)*1000, "ms")

### Results

In [None]:
results_path = os.path.join(DIRECTORY_PATH, "ml motor results (22 channels)")

# For every saved result file, print each subject's fold-averaged accuracy (%)
# followed by the mean and std of those per-subject values.
for ml_result in os.listdir(results_path):
    print(ml_result.split(".")[0])
    result = np.load(os.path.join(results_path, ml_result), allow_pickle=True)
    per_subject_metrics = result["results"].item()
    subject_mean_accuracy = []
    for subject in per_subject_metrics:
        subject_accuracy = np.mean(per_subject_metrics[subject]["accuracy"]) * 100
        print(subject)
        print(subject_accuracy)
        subject_mean_accuracy.append(subject_accuracy)
    print("average: ", np.mean(subject_mean_accuracy), "std: ", np.std(subject_mean_accuracy))

### Statistical test (subject-wise)

In [None]:
results_path = os.path.join(DIRECTORY_PATH, "ml motor results csp (22 channels)")

# Subject-wise one-sided paired t-test across folds: for each classifier result
# "<name>.npz" and its "<name>_boosted.npz" counterpart, count the subjects whose
# boosted fold accuracies are significantly greater (p < 0.05).
for ml_result in os.listdir(results_path):
    if "boosted" in ml_result.split(".")[0]:
        continue
    print(ml_result.split(".")[0])
    result_original = np.load(os.path.join(results_path, ml_result), allow_pickle=True)
    result_boosted = np.load(os.path.join(results_path, ml_result.split(".")[0]+"_boosted.npz"), allow_pickle=True)
    original_accuracy = {}
    boosted_accuracy = {}
    for subject, metrics in result_original["results"].item().items():
        original_accuracy[subject] = np.array(metrics["accuracy"])
    for subject, metrics in result_boosted["results"].item().items():
        boosted_accuracy[subject] = np.array(metrics["accuracy"])
    count = 0
    for subject in original_accuracy.keys():
        # paired per-fold differences; the global accuracy_mean_std is not touched
        differences = boosted_accuracy[subject] - original_accuracy[subject]
        n = len(differences)
        if n < 2:
            # a t statistic needs at least two paired observations
            continue
        mean = np.mean(differences)
        # sample standard deviation (ddof=1) as required by the t statistic
        sem = np.std(differences, ddof=1) / np.sqrt(n)
        # identical differences (sem == 0): t is +inf only when the mean gain is positive
        t_test = mean / sem if sem > 0 else (np.inf if mean > 0 else 0.0)
        df = n - 1
        if t_test > 0:
            # one-sided p-value for H1: mean difference > 0
            p_value = stats.t.sf(t_test, df)
            if p_value < 0.05:
                count += 1
                print(subject, mean)
        #print(t_test, df, p_value)
    print(count)

### Statistical test (method-wise)

In [None]:
results_path = os.path.join(DIRECTORY_PATH, "ml motor results csp (22 channels)")

# Method-wise one-sided paired t-test across subjects: for each classifier
# result "<name>.npz", test whether the subject-mean accuracies of
# "<name>_boosted.npz" are greater; print the mean gain when p < 0.05.
for ml_result in os.listdir(results_path):
    if "boosted" in ml_result.split(".")[0]:
        continue
    print(ml_result.split(".")[0])
    result_original = np.load(os.path.join(results_path, ml_result), allow_pickle=True)
    result_boosted = np.load(os.path.join(results_path, ml_result.split(".")[0]+"_boosted.npz"), allow_pickle=True)
    original_accuracy = {}
    boosted_accuracy = {}
    for subject, metrics in result_original["results"].item().items():
        original_accuracy[subject] = np.mean(np.array(metrics["accuracy"]))
    for subject, metrics in result_boosted["results"].item().items():
        boosted_accuracy[subject] = np.mean(np.array(metrics["accuracy"]))
    # one paired difference per subject
    accuracy_mean = [boosted_accuracy[subject] - original_accuracy[subject] for subject in original_accuracy.keys()]

    n = len(accuracy_mean)
    if n < 2:
        # a t statistic needs at least two paired observations
        continue
    mean = np.mean(accuracy_mean)
    # sample standard deviation (ddof=1) as required by the t statistic
    sem = np.std(accuracy_mean, ddof=1) / np.sqrt(n)
    # identical differences (sem == 0): t is +inf only when the mean gain is positive
    t_test = mean / sem if sem > 0 else (np.inf if mean > 0 else 0.0)
    df = n - 1
    if t_test > 0:
        # one-sided p-value for H1: mean difference > 0
        p_value = stats.t.sf(t_test, df)
        if p_value < 0.05:
            print(mean)
    #print(t_test, df, p_value)