In [None]:
# --------------------------------------------------------------------deep learning--------------------------------------------------------------------
!pip install mne &> /dev/null
!pip install skorch -U &> /dev/null
!pip install mne-icalabel &> /dev/null
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
import os
import numpy as np
import mne
import matplotlib.pyplot as plt
import math
from scipy.io import loadmat
from mne.preprocessing import ICA
from mne_icalabel import label_components
from sklearn.base import BaseEstimator, TransformerMixin
import pickle
from skorch import NeuralNetClassifier
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import cross_val_score, cross_val_predict
from mne.decoding import CSP
from mne.filter import filter_data
from sklearn.preprocessing import StandardScaler

Mounted at /content/drive


In [None]:
# Pass False as MNE's verbosity setting so the loading/epoching calls below
# do not flood the cell output with log messages.
mne.set_log_level(False)

In [None]:
sbj_n = 4 #@param {type:"integer"}
# Name of the selected subject's folder, e.g. 'Subject 4'
sbj_path = f'Subject {sbj_n}'

In [None]:
# Root folder on the mounted Drive; the subject folder is resolved below it
base_path = '/content/drive/MyDrive/EEG data'
subj_folder = os.path.join(base_path, sbj_path)
# Every GDF recording of the subject.
# NOTE(review): `l` is not referenced by any later cell visible here — confirm before removing.
l = [file for file in os.listdir(subj_folder) if file.endswith('gdf')]

In [None]:
# Recordings split by the digit right before the extension ("..._1.gdf" vs "..._2.gdf").
# os.listdir returns entries in arbitrary order, so the lists are sorted to make
# the file order (and any later truncation of the lists) reproducible.
type_1 = sorted(file for file in os.listdir(subj_folder) if file.endswith('1.gdf'))
type_2 = sorted(file for file in os.listdir(subj_folder) if file.endswith('2.gdf'))

In [None]:
# Use the same number of recordings for both session types
smaller = min(len(type_1), len(type_2))
type_1, type_2 = type_1[:smaller], type_2[:smaller]

In [None]:
# Constants
# Electrode labels assigned, in order, to the recorded channels (see load_dataset)
ch_list = ['Fp1', 'Fp2', 'AF3', 'AF4', 'F7', 'F3', 'Fz', 'F4', 'F8', 'FT7', 'FC3', 'FCz','FC4', 'FT8', 'T7', 'C3', 'Cz', 'C4', 'T8', 'TP7', 'CP3', 'CPz', 'CP4', 'TP8', 'P7', 'P3', 'Pz', 'P4', 'P8', 'O1', 'Oz', 'O2']
# Epoching: start of each epoch relative to its event, in seconds (tmin=st, tmax=st+dur)
st = 0
# Sub-epochs: sliding-window length and hop, both in seconds
sub_dur = 2
stride = 0.25
# Number of splits used by the cross-validation splitters
n_splits = 4

In [None]:
def load_ds(pt, f):       # For BCIC Dataset
  """Read one recording located at ``pt``/``f``.

  The joined path is printed first. Paths ending in ``gdf`` are read with
  ``mne.io.read_raw_gdf(..., preload=True)``; any other path yields ``None``.
  """
  full_path = os.path.join(pt, f)
  print("File path: {}".format(full_path))
  if not full_path.endswith('gdf'):
    return None
  return mne.io.read_raw_gdf(full_path, preload=True)

def select_event(events_from_annot, event_dict, sfreq):
  """Pick the event codes used by the pipeline and estimate the trial duration.

  Args:
    events_from_annot: integer array with one row per event; column 0 is the
      sample index and column 2 the event code.
    event_dict: maps annotation names to event codes; must contain the keys
      '769', '770', '33024' and '800'.
    sfreq: sampling rate used to convert sample gaps to seconds.

  Returns:
    (needed_event, dur): the codes for '769', '770', '33024', '800' in that
    order, and the mean (rounded to 2 decimals) of every other gap between
    consecutive needed events, in seconds.
  """
  # Codes of the four annotations the pipeline relies on ('800' is the end-of-trial marker)
  needed_event = [event_dict[key] for key in ('769', '770', '33024', '800')]
  end_trial = needed_event[3]
  # Drop whatever was recorded after the final end-of-trial marker
  reversed_codes = events_from_annot[::-1, 2].tolist()
  trailing = reversed_codes.index(end_trial)
  events_annot = events_from_annot[:-trailing] if trailing > 0 else events_from_annot[:]
  # Keep only the rows whose code is one of the needed events
  filtered_events = events_annot[np.isin(events_annot[:, 2], needed_event)]
  # Gaps between consecutive kept events in seconds; gaps 0, 2, 4, ... are averaged
  gaps = np.diff(filtered_events[:, 0])/sfreq
  dur = round(np.mean(gaps[::2]),2)
  print('Average duration of this session trials : ',dur)
  return needed_event, dur

def sub_epochs(epochs):
  """Cut every epoch into overlapping fixed-length windows.

  The window length and hop are the notebook-level ``sub_dur`` and ``stride``
  (seconds), converted to samples with ``epochs.info['sfreq']``. Each window
  becomes a one-epoch ``mne.EpochsArray`` sharing ``epochs.info``; all windows
  are concatenated, in order, into a single Epochs object that is returned.
  """
  sf = epochs.info['sfreq']
  windows = []
  for epoch in epochs:
    # Add a leading axis so slices have shape (1, channels, samples)
    data = epoch[np.newaxis,:,:]
    # How many full windows fit into this epoch
    n_windows = int((data.shape[2] - (sub_dur * sf)) // (stride * sf) + 1)
    for i in range(n_windows):
      first = int(i * (stride * sf))
      last = first + int(sub_dur * sf)
      windows.append(mne.EpochsArray(data[:, :, first:last], info=epochs.info))
  # Merge the one-window objects into a single Epochs object
  return mne.epochs.concatenate_epochs(windows)

def create_label(size, lbl):
  """Return an array of shape ``size`` in which every entry equals ``lbl``."""
  return np.full(shape=size, fill_value=lbl)

def epoch_array(epoch_l, epoch_r, epoch_x):
  """Turn three sets of epochs into one sub-epoch array with integer labels.

  Each input is cut into sub-epochs with ``sub_epochs`` and converted to an
  array with ``get_data(copy=True)``. Labels are 1 for ``epoch_l``, 2 for
  ``epoch_r`` and 0 for ``epoch_x``.

  Returns:
    (epoch_data, label_data): arrays concatenated along the first axis in the
    order l, r, x. ``balance_dataset`` relies on this block order.
  """
  arrays, labels = [], []
  for epochs, lbl in ((epoch_l, 1), (epoch_r, 2), (epoch_x, 0)):
    data = sub_epochs(epochs).get_data(copy=True)   # sub-epochs as an array
    arrays.append(data)
    labels.append(create_label(data.shape[0], lbl))
  epoch_data = np.concatenate(arrays, axis=0)
  label_data = np.concatenate(labels, axis=0)
  return epoch_data, label_data

def balance_dataset(X, y, n_trials=8, rng=None):
  """Down-sample class 0 so that it has as many samples as class 1.

  The samples are expected in the block order produced by ``epoch_array``:
  all of class 1, then class 2, then class 0. Each class block is divided
  into ``n_trials`` equally sized consecutive groups (ids ``0..n_trials-1``).
  From every class-0 group, ``count(class 1) // n_trials`` samples are drawn
  without replacement; classes 1 and 2 are kept untouched.

  Args:
    X: array whose first axis indexes samples.
    y: label array containing only 0, 1 and 2.
    n_trials: number of groups per class (default 8, the previous hard-coded value).
    rng: ``None`` to draw from NumPy's global random state (previous
      behaviour), or a seed / ``numpy.random.Generator`` for reproducible draws.

  Returns:
    (balanced_X, balanced_y, new_groups): kept samples (classes 1 and 2 first,
    then the selected class-0 samples in ascending index order), their labels,
    and the group id of every kept sample.

  Raises:
    ValueError: if ``y`` holds other labels, a class count is not a multiple
      of ``n_trials``, or class 0 has fewer samples per group than requested.
  """
  y = np.asarray(y)
  n_l, n_r, n_x = (int(np.count_nonzero(y == lbl)) for lbl in (1, 2, 0))
  if n_l + n_r + n_x != y.size:
    raise ValueError("y may only contain the labels 0, 1 and 2")
  for lbl, n in ((1, n_l), (2, n_r), (0, n_x)):
    if n % n_trials:
      raise ValueError("class {} has {} samples, which is not a multiple of n_trials={}".format(lbl, n, n_trials))
  per_group = n_l // n_trials
  if n_x // n_trials < per_group:
    raise ValueError("class 0 has fewer samples per group than class 1; cannot down-sample")

  def _layout(n_first, n_second, n_rest):
    # Group ids for three consecutive class blocks, each cut into n_trials equal runs
    return np.hstack([np.repeat(np.arange(n_trials), n // n_trials) for n in (n_first, n_second, n_rest)])

  groups = _layout(n_l, n_r, n_x)
  # Global legacy state by default keeps earlier runs/seed calls behaving the same
  sampler = np.random if rng is None else np.random.default_rng(rng)
  rest_mask = (y == 0)
  picked = []
  for group_val in range(n_trials):
    candidates = np.where((groups == group_val) & rest_mask)[0]
    picked.extend(sampler.choice(candidates, size=per_group, replace=False))
  # Ascending order keeps the selected class-0 samples grouped by trial,
  # which is what the regenerated group ids below assume
  picked.sort()
  balanced_X = np.concatenate([X[y != 0], X[picked]])
  balanced_y = np.concatenate([y[y != 0], y[picked]])
  new_groups = _layout(n_l, n_r, per_group * n_trials)
  return balanced_X, balanced_y, new_groups

def save_epoch(X, y, file_name):
  """Write one recording's sub-epoch data and labels as text files.

  Files go to the ``Epoch`` folder inside ``subj_folder`` (created if it does
  not exist yet) and are named after ``file_name`` without its last four
  characters, with ``_data.txt`` / ``_label.txt`` appended. ``X`` is flattened
  to 2-D (one row per sample) because ``np.savetxt`` cannot write
  higher-dimensional arrays.
  """
  # Names
  epoch_folder = os.path.join(subj_folder, 'Epoch')
  os.makedirs(epoch_folder, exist_ok=True)   # np.savetxt does not create folders
  stem = file_name[:-4]                      # drop the 4-character extension, e.g. '.gdf'
  # Path
  ep_file = os.path.join(epoch_folder, stem + '_data.txt')
  lbl_file = os.path.join(epoch_folder, stem + '_label.txt')
  np.savetxt(ep_file, X.reshape(X.shape[0], -1))
  np.savetxt(lbl_file, y)

In [None]:
def load_dataset(file_list: list):
  """Load several recordings from ``subj_folder`` as one continuous recording.

  The files are read with ``load_ds`` and concatenated, channels are renamed
  in order to the labels in ``ch_list``, the 'standard_1020' montage is
  attached, and the events are extracted from the annotations.

  Returns:
    (raw_cat, needed_event, dur, events_from_annot): the concatenated
    recording, the event codes and mean trial duration reported by
    ``select_event``, and the full event array.
  """
  raws = [load_ds(subj_folder, fname) for fname in file_list]
  raw_cat = mne.concatenate_raws(raws)
  # Pair the existing channel names with ch_list position by position
  raw_cat.rename_channels(dict(zip(raw_cat.ch_names, ch_list)))
  raw_cat.set_montage(mne.channels.make_standard_montage('standard_1020'))
  events_from_annot, event_dict = mne.events_from_annotations(raw_cat)
  needed_event, dur = select_event(events_from_annot, event_dict, raw_cat.info['sfreq'])
  return raw_cat, needed_event, dur, events_from_annot

In [None]:
# Sanity check: for every type-1 recording print its path, the mean trial
# duration and the event codes returned by select_event.
# NOTE: x, events_from_annot, event_dict, needed_event and dur remain defined
# at notebook level after this loop.
for data in type_1:
  x = load_ds(subj_folder, data)    # Load data
  events_from_annot, event_dict = mne.events_from_annotations(x)
  needed_event, dur = select_event(events_from_annot, event_dict, x.info['sfreq'])
  print(needed_event)

File path: /content/drive/MyDrive/EEG data/Subject 4/record-[2024.03.10]_S9_1.gdf
Average duration of this session trials :  6.21
[3, 4, 2, 5]
File path: /content/drive/MyDrive/EEG data/Subject 4/record-[2024.03.10]_S7_1.gdf
Average duration of this session trials :  6.21
[3, 4, 2, 5]


In [142]:
def create_dataset(file_list: list, balance: bool = True):
  """Build the sub-epoch dataset (data, labels, CV groups) from recordings.

  For every file: load it, obtain the event codes and mean trial duration
  from ``select_event``, re-reference to the channel average, epoch the three
  cue classes from ``st`` to ``st + dur``, band-pass the epochs to 6-14 Hz,
  cut them into labelled sub-epochs with ``epoch_array``, optionally balance
  the classes, and write the result to disk with ``save_epoch``.

  Args:
    file_list: recording file names inside ``subj_folder``.
    balance: when True (default, the previous behaviour) class 0 is
      down-sampled with ``balance_dataset``; when False every sub-epoch is
      kept and group ids are built directly from the per-class counts.

  Returns:
    (X, y, groups): data, labels and group ids of all files concatenated
    along the first axis.

  Raises:
    ValueError: if ``file_list`` is empty.
  """
  if not file_list:
    raise ValueError("file_list is empty: there is no recording to load")
  combined_set = []
  combined_labels = []
  combined_groups = []
  for data in file_list:
    x = load_ds(subj_folder, data)    # Load data
    events_from_annot, event_dict = mne.events_from_annotations(x)
    # Event select
    needed_event, dur = select_event(events_from_annot, event_dict, x.info['sfreq'])
    x = x.set_eeg_reference("average")
    print(needed_event)
    # One Epochs object per cue class: needed_event[0], [1], [2] (named l, r, x)
    epoch_l, epoch_r, epoch_x = (
        mne.Epochs(x, events_from_annot, event_id=ev, tmin=st, tmax=st+dur, baseline=None, preload=True)
        for ev in needed_event[:3]
    )
    # If class x has more epochs than l and r together, drop its last epoch
    if len(epoch_x) > (len(epoch_l)+len(epoch_r)):
      epoch_x.drop([-1])
    # Band-pass 6-14 Hz (retain alpha power band ?)
    # NOTE(review): filtering short epochs can add edge artifacts; consider
    # filtering the continuous recording before epoching.
    for ep in (epoch_l, epoch_r, epoch_x):
      ep.filter(l_freq=6, h_freq=14)
    X, y = epoch_array(epoch_l, epoch_r, epoch_x)

    if balance:
      X, y, current_group = balance_dataset(X, y)
    else:
      # Group ids 0..7 repeated per class block, in the block order l (1), r (2), x (0)
      unique, counts = np.unique(y, return_counts=True)
      current_group = np.hstack([np.repeat(np.arange(8), int(counts[i]/8)) for i in (1, 2, 0)])
    save_epoch(X, y, data)
    combined_set.append(X)
    combined_labels.append(y)
    combined_groups.append(current_group)

  # Concatenation along the first axis also covers the single-file case
  result_X = np.concatenate(combined_set, axis=0)
  result_y = np.concatenate(combined_labels)
  result_group = np.concatenate(combined_groups)
  return result_X, result_y, result_group

In [143]:
from sklearn.model_selection import (
    GroupKFold,
    KFold,
    StratifiedGroupKFold,
    StratifiedKFold,
    StratifiedShuffleSplit
)

# Colormaps. NOTE(review): neither name is referenced by the cells visible
# here — presumably intended for plotting CV splits; confirm or remove.
cmap_data = plt.cm.Paired
cmap_cv = plt.cm.coolwarm

# Type 1

In [None]:
# Sub-epoch data, labels and CV group ids for the type-1 recordings
ds, lbl, grouping = create_dataset(type_1)

In [68]:
# The three arrays must have the same length (one label and one group id per sub-epoch)
print(len(ds))
print(len(lbl))
print(len(grouping))

1088
1088
1088


In [69]:
# Stratified K-fold that keeps all samples sharing a group id in the same fold
cv = StratifiedGroupKFold(n_splits)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy import signal
from skorch.callbacks import Checkpoint, EpochScoring, EarlyStopping

In [146]:
t = 2       # window length in seconds; multiplied by sf to get the network's input length
sf = 128    # samples per second after resampling; EEGNet.__init__ reads this global for its kernel length
cls_n = 3   # number of output classes (labels 0, 1, 2)

In [145]:
# https://github.com/High-East/BCI-ToolBox/blob/master/models/EEGNet/EEGNet.py
class EEGNet(nn.Module):
    """Compact CNN for EEG windows shaped (batch, 1, channels, time samples).

    Three stages are applied in order: a convolution along time (``conv1``),
    a grouped convolution across all channels producing ``d`` maps per
    temporal filter (``conv2``), and a depthwise-then-pointwise convolution
    along time (``conv3``); a linear layer maps the flattened features to
    class scores (logits, no softmax).

    Args:
        in_chn: number of EEG channels (height of the input).
        n_cls: number of output classes.
        input_ts: number of time samples per window (width of the input).
        f1: number of filters in ``conv1``.
        f2: number of pointwise filters in ``conv3``.
        d: output maps per ``conv1`` filter in the grouped ``conv2``.
        drop_prob: dropout probability used after ``conv2`` and ``conv3``.
        sfreq: sampling rate that sets the first kernel length
            (``ceil(sfreq / 2)``). ``None`` (default) falls back to the
            notebook-level ``sf``, which is what the class used before.

    Raises:
        ValueError: if the window is too short to leave any feature after
            the two pooling stages.
    """
    def __init__(self, in_chn, n_cls, input_ts, f1=8, f2=16, d=2, drop_prob=0.5, sfreq=None):
        super(EEGNet, self).__init__()

        if sfreq is None:
            sfreq = sf   # notebook-level sampling rate (previous implicit behaviour)

        self.F1 = f1   # filters of the first (temporal) convolution
        self.F2 = f2   # pointwise filters of the last convolution
        self.D = d     # maps per F1 filter in the grouped channel convolution
        #
        self.kernel_l = math.ceil(sfreq/2)
        self.chn = in_chn
        self.cls = n_cls
        self.drop_prob = drop_prob
        self.tp = input_ts

        k1 = self.kernel_l
        p1 = k1 // 2
        k3 = k1 // 4
        p3 = 8   # kept as in the original; equals k3 // 2 when sfreq is 128

        # Spectral
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, self.F1, (1, k1), padding=(0, p1), bias=False),
            nn.BatchNorm2d(self.F1)
        )

        # Spectral-specific Spatial
        self.conv2 = nn.Sequential(
            nn.Conv2d(self.F1, self.D*self.F1, (self.chn, 1), groups=self.F1, bias=False),
            nn.BatchNorm2d(self.D*self.F1),
            nn.ELU(),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(self.drop_prob)
        )

        # Temporal
        self.conv3 = nn.Sequential(
            nn.Conv2d(self.D*self.F1, self.D*self.F1, (1, k3), padding=(0, p3), groups=self.D*self.F1, bias=False),
            nn.Conv2d(self.D*self.F1, self.F2, (1, 1), bias=False),
            nn.BatchNorm2d(self.F2),
            nn.ELU(),
            nn.AvgPool2d((1, 8)),
            nn.Dropout(self.drop_prob)
        )

        # Time length after each stage, following the layers above. Deriving
        # the classifier width from the real geometry (instead of
        # input_ts // 32) keeps it correct for any sfreq / input_ts.
        len1 = input_ts + 2*p1 - k1 + 1     # after conv1
        len2 = len1 // 4                    # after AvgPool2d((1, 4))
        len3 = len2 + 2*p3 - k3 + 1         # after the depthwise conv of conv3
        len4 = len3 // 8                    # after AvgPool2d((1, 8))
        if len4 <= 0:
            raise ValueError("input_ts={} is too short for sfreq={}".format(input_ts, sfreq))

        self.classifier = nn.Linear(self.F2 * len4, self.cls, bias=True)

    def forward(self, x):
        """Return class logits of shape (batch, n_cls) for x of shape (batch, 1, channels, time)."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # Flatten per sample; a fixed second dimension with -1 in front could
        # silently mix samples if the feature count were ever wrong.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

In [None]:
# Number of time samples per window after resampling (t seconds at sf samples/s).
# Despite the name this is a sample count, not a sampling frequency.
new_sf = t * sf

In [149]:
# Training hyper-parameters passed to NeuralNetClassifier
n_epoch = 200       # max_epochs
# learn_r = 0.001
learn_r = 0.0005    # lr
n_batch = 32        # batch_size

In [150]:
# Resample every type-1 window to new_sf samples along the time axis
x_resample = signal.resample(ds, new_sf, axis=2)
# Integer class labels for the cross-entropy loss
y_resample = lbl.astype(np.int64)
# Insert a singleton axis at position 1 (the Conv2d input channel) and cast to float32
x_resample = np.expand_dims(x_resample, axis=1).astype(np.float32)

In [151]:
# skorch wrapper around EEGNet; channel count and window length are taken from the data.
# NOTE(review): Checkpoint saves the best-validation-loss weights, but no cell
# visible here loads them back — confirm which weights the CV scores are meant to use.
# NOTE(review): no random seed is set for torch, so runs are not reproducible.
net = NeuralNetClassifier(
    EEGNet,
    module__in_chn=x_resample.shape[-2],
    module__n_cls=cls_n,
    module__input_ts=x_resample.shape[-1],
    criterion = torch.nn.CrossEntropyLoss(),
    optimizer = torch.optim.AdamW,
    iterator_train__shuffle=True,
    batch_size = n_batch,
    callbacks=[
        EpochScoring(scoring='accuracy', name='train_acc', on_train=True),
        Checkpoint(monitor='valid_loss_best'),  # save based on validation loss
        #EarlyStopping(patience=50, monitor='valid_loss')
    ],
    max_epochs=n_epoch,
    lr=learn_r,
    device='cuda' if torch.cuda.is_available() else 'cpu'
)

In [None]:
# Grouped, stratified cross-validation on the type-1 data
scores = cross_val_score(net, x_resample, y_resample, groups=grouping, cv=cv, n_jobs=None)

In [103]:
# Per-fold scores of the grouped CV on type-1 data, and their mean
print("Cross-validated accuracy scores:", scores)
print("Mean accuracy:", np.mean(scores))

Cross-validated accuracy scores: [0.40073529 0.25       0.25       0.36029412]
Mean accuracy: 0.31525735294117646


### Shuffle

In [147]:
# Random stratified train/test splits; random_state makes them repeatable.
# WARNING: this splitter does not use group ids, and sub-epochs overlap
# (stride < sub_dur), so windows cut from the same trial can land in both
# train and test. Treat these scores as optimistic compared with the grouped CV.
cv2 = StratifiedShuffleSplit(n_splits, random_state=0)

In [None]:
# Shuffle-split evaluation on the type-1 data.
# NOTE(review): per the scikit-learn docs StratifiedShuffleSplit ignores `groups`,
# so passing it here has no effect on the splits — confirm this is intended.
scores2 = cross_val_score(net, x_resample, y_resample, groups=grouping, cv=cv2, n_jobs=None)

In [106]:
# Per-split scores of the shuffle split on type-1 data, and their mean
print("Cross-validated accuracy scores:", scores2)
print("Mean accuracy:", np.mean(scores2))

Cross-validated accuracy scores: [0.69724771 0.26605505 0.25688073 0.6146789 ]
Mean accuracy: 0.45871559633027525


# Type 2

In [152]:
# Sub-epoch data, labels and CV group ids for the type-2 recordings
ds2, lbl2, grouping2 = create_dataset(type_2)

File path: /content/drive/MyDrive/EEG data/Subject 4/record-[2024.03.10]_S8_2.gdf
Average duration of this session trials :  6.27
[4, 5, 2, 6]
File path: /content/drive/MyDrive/EEG data/Subject 4/record-[2024.03.10]_S10_2.gdf
Average duration of this session trials :  6.28
[4, 5, 2, 6]


In [153]:
# Fresh grouped, stratified K-fold splitter for the type-2 data (rebinds `cv`)
cv = StratifiedGroupKFold(n_splits)

In [178]:
# NOTE(review): overrides the 0.0005 learning rate set earlier, so the type-2
# models are trained with a different lr than the type-1 models — confirm intended.
learn_r = 0.01

In [179]:
# Same preparation as for type 1, applied to the type-2 data.
# NOTE: rebinds x_resample / y_resample, replacing the type-1 arrays.
x_resample = signal.resample(ds2, new_sf, axis=2)
y_resample = lbl2.astype(np.int64)
# Insert a singleton axis at position 1 (the Conv2d input channel) and cast to float32
x_resample = np.expand_dims(x_resample, axis=1).astype(np.float32)

In [180]:
# Rebuild the classifier for the type-2 data so it picks up the current learn_r.
# NOTE(review): this repeats the earlier NeuralNetClassifier construction
# argument for argument; a small factory function would avoid the duplication.
net = NeuralNetClassifier(
    EEGNet,
    module__in_chn=x_resample.shape[-2],
    module__n_cls=cls_n,
    module__input_ts=x_resample.shape[-1],
    criterion = torch.nn.CrossEntropyLoss(),
    optimizer = torch.optim.AdamW,
    iterator_train__shuffle=True,
    batch_size = n_batch,
    callbacks=[
        EpochScoring(scoring='accuracy', name='train_acc', on_train=True),
        Checkpoint(monitor='valid_loss_best'),  # save based on validation loss
        #EarlyStopping(patience=50, monitor='valid_loss')
    ],
    max_epochs=n_epoch,
    lr=learn_r,
    device='cuda' if torch.cuda.is_available() else 'cpu'
)

In [None]:
# Grouped, stratified cross-validation on the type-2 data (rebinds `scores`)
scores = cross_val_score(net, x_resample, y_resample, groups=grouping2, cv=cv, n_jobs=None)

In [182]:
# Per-fold scores of the grouped CV on type-2 data, and their mean
print("Cross-validated accuracy scores:", scores)
print("Mean accuracy:", np.mean(scores))

Cross-validated accuracy scores: [0.375      0.41666667 0.3287037  0.36111111]
Mean accuracy: 0.3703703703703704


### Shuffle

In [183]:
# Random stratified train/test splits; random_state makes them repeatable.
# WARNING: this splitter does not use group ids, and sub-epochs overlap
# (stride < sub_dur), so windows cut from the same trial can land in both
# train and test. Treat these scores as optimistic compared with the grouped CV.
cv2 = StratifiedShuffleSplit(n_splits, random_state=0)

In [None]:
# Shuffle-split evaluation on the type-2 data.
# NOTE(review): per the scikit-learn docs StratifiedShuffleSplit ignores `groups`,
# so passing it here has no effect on the splits — confirm this is intended.
scores2 = cross_val_score(net, x_resample, y_resample, groups=grouping2, cv=cv2, n_jobs=None)

In [185]:
# Per-split scores of the shuffle split on type-2 data, and their mean
print("Cross-validated accuracy scores:", scores2)
print("Mean accuracy:", np.mean(scores2))
# --------------------------------------------------------------------deep learning--------------------------------------------------------------------

Cross-validated accuracy scores: [0.4137931  0.8045977  0.42528736 0.8045977 ]
Mean accuracy: 0.6120689655172414
