In [23]:
%matplotlib inline
import mne
import os
import pyedflib
import pandas as pd
import numpy as np
import seaborn as sns
from mne.io import read_raw_edf, RawArray, concatenate_raws
from mne.stats import permutation_cluster_1samp_test as pcluster_test
import warnings
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
import autoreject
from autoreject import get_rejection_threshold
from mne.decoding import CSP
from sklearn.model_selection import train_test_split, cross_val_score
import pywt
from scipy.stats import pointbiserialr
from math import sqrt

from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import classification_report,confusion_matrix
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC 
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.naive_bayes import GaussianNB

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D,MaxPooling2D, Flatten
from tensorflow.keras import datasets, layers, models

import mrmr
from mrmr import mrmr_classif

from torch.utils.data import DataLoader, TensorDataset
import torch
from torch import nn

In [24]:
def load_data(subject_id:list, task_id:list, montage_name:str):
    """Load, concatenate and annotate the EDF recordings of the requested runs.

    For every (subject, task) pair the file ``S<subject>/S<subject>R<task>.edf``
    under the dataset directory is read with MNE, all recordings are
    concatenated into one Raw object, the channels are renamed using the names
    listed in the dataset's ``wfdbcal`` file (dots stripped) and a standard
    montage is attached.

    Args:
        subject_id: subject identifiers as they appear in the file names
            (e.g. ``"001"``).
        task_id: run identifiers as they appear in the file names
            (e.g. ``"04"``).
        montage_name: name passed to ``mne.channels.make_standard_montage``.

    Returns:
        The concatenated Raw object with renamed channels and montage set.

    Raises:
        ValueError: if ``subject_id`` or ``task_id`` is empty, or if
            ``wfdbcal`` lists fewer channel names than the recording has
            channels.
    """
    dataset_path = "Online Dataset/eeg-motor-movementimagery-dataset-1.0.0/files"

    if not subject_id or not task_id:
        raise ValueError("subject_id and task_id must each contain at least one entry")

    # --- Full Path ---
    paths = [
        os.path.join(dataset_path, f"S{subject}/S{subject}R{task}.edf").replace("\\", "/")
        for subject in subject_id
        for task in task_id
    ]

    # --- Read EDF Files ---
    subject_raws = []
    for file_path in paths:
        # RuntimeWarnings raised while reading are silenced for this call only.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=RuntimeWarning)
            data = mne.io.read_raw_edf(file_path, preload=True, verbose=False)
        subject_raws.append(data)

    raws_data = concatenate_raws(subject_raws)

    # Printed after concatenation, exactly as before.
    print(subject_raws)

    # --- Channel names from the calibration file ---
    # The first tab-separated field of each line is the channel name; the file
    # is located relative to dataset_path instead of repeating the full path.
    wfdbcal_path = os.path.join(dataset_path, "wfdbcal").replace("\\", "/")
    with open(wfdbcal_path, "r") as file:
        channel_names = [line.split('\t')[0].strip().replace(".", "") for line in file]

    old_ch_names = raws_data.info['ch_names']
    if len(channel_names) < len(old_ch_names):
        raise ValueError(
            f"wfdbcal lists {len(channel_names)} channel names but the recording "
            f"has {len(old_ch_names)} channels"
        )

    # zip() stops at the recording's channel count, so extra lines are ignored.
    raws_data.rename_channels(dict(zip(old_ch_names, channel_names)))

    # Set montage
    raws_data.set_montage(montage = mne.channels.make_standard_montage(montage_name))

    # Optional visual checks:
    # raws_data.plot_sensors(show_names=True);
    # raws_data.compute_psd().plot_topomap();
    # raws_data.compute_psd().plot();

    return raws_data

In [25]:
def preprocessing(raws_data, event, chans_selected:list):
    """Turn a continuous recording into cleaned, baseline-corrected epochs.

    Steps, in order: common average reference, 8-15 Hz Butterworth band-pass,
    ICA with removal of the components flagged by ``find_bads_eog`` (using
    channel 'T9'), epoching from -0.5 s to 4 s around the annotated events,
    channel selection, baseline correction on (-0.5, 0) and dropping of epochs
    that exceed the threshold returned by ``get_rejection_threshold``.

    Args:
        raws_data: continuous recording; it is copied, not modified.
        event: mapping of event name to event code, passed as ``event_id``.
        chans_selected: channel names to keep in the epochs.

    Returns:
        The epochs that remain after rejection.
    """
    def band_pass(inst):
        # In-place 8-15 Hz IIR band-pass; a fresh parameter dict on every call.
        return inst.filter(l_freq=8.0, h_freq=15.0, method='iir',
                           iir_params={"order": 6, "ftype": 'butter'})

    # --- Common Average Reference (CAR), added as a projector then applied ---
    referenced = raws_data.copy().set_eeg_reference('average', projection=True)
    referenced.apply_proj()

    # --- Band-pass the continuous data ---
    filtered = band_pass(referenced.copy())

    # --- Fit ICA and exclude the components flagged as ocular ---
    ica = mne.preprocessing.ICA(n_components=63, random_state=97, max_iter=800)
    ica.fit(filtered.copy())
    flagged, _ = ica.find_bads_eog(filtered.copy(), ch_name='T9', threshold=1.5)
    ica.exclude = flagged
    cleaned = ica.apply(filtered.copy(), exclude=ica.exclude)

    # --- Events come from the annotations of the filtered recording ---
    events, _ = mne.events_from_annotations(filtered)

    # --- Epoch the ICA-cleaned signal (band-passed once more beforehand) ---
    epochs = mne.Epochs(band_pass(cleaned.copy()), events,
                        tmin = -0.5, tmax = 4, event_id = event,
                        preload= True, verbose=False, event_repeated='drop')

    # --- Keep only the channels of interest ---
    epochs = epochs.pick_channels(chans_selected)

    # --- Band-pass the epochs, then subtract the pre-stimulus baseline ---
    band_passed_epochs = band_pass(epochs.copy())
    corrected = band_passed_epochs.copy().apply_baseline((-0.5, 0))

    # --- Drop epochs exceeding the estimated rejection threshold ---
    stream_mi = corrected.copy()
    stream_mi.drop_bad(reject = get_rejection_threshold(stream_mi))

    return stream_mi

In [26]:
def extract_CSP(epochs, n_component = 7):
    """Fit two CSP decompositions on the epochs and project the trials.

    One CSP object keeps the default output and is used for plotting the
    spatial patterns; the other is configured with
    ``transform_into='csp_space'`` and provides the projected time series.

    Args:
        epochs: epochs object providing ``get_data()``, ``events`` and ``info``.
        n_component: number of CSP components to keep.

    Returns:
        Tuple ``(X, y, csp, csp_wt, new_data)``: the epoch data, the event
        codes, the two fitted CSP objects and the trials transformed by
        ``csp_wt``.
    """
    trials = epochs.get_data()
    codes = epochs.events[:, -1]

    # CSP with default output; only used for the pattern plot below.
    csp = CSP(n_components = n_component, norm_trace=False)
    csp.fit(trials, codes)

    # CSP configured to return the signals projected into CSP space.
    csp_wt = CSP(n_components = n_component, reg=None, log=None,
                 norm_trace=False, transform_into='csp_space')
    csp_wt.fit(trials, codes)
    projected = csp_wt.transform(trials)

    # Visualize CSP patterns
    csp.plot_patterns(epochs.info)
    print(csp)

    return trials, codes, csp, csp_wt, projected

In [27]:
def extract_WT(epochs, target , n_component = 7):
    """Compute Morlet scalograms of the CSP-projected epochs.

    The epochs are projected with ``extract_CSP``; each CSP component of each
    trial is transformed with a continuous wavelet transform (scales 8..15)
    and the coefficient magnitudes are stored. The first (up to) four trials
    are plotted as a visual check.

    Args:
        epochs: epochs object providing ``get_data()`` and ``events``.
        target: unused; kept so existing calls keep working.
        n_component: number of CSP components to transform.

    Returns:
        Tuple ``(train_data_cwt, data_WT, labels, y_mi)``:
        ``train_data_cwt`` has shape (trials, channels, scales, samples) where
        only the first ``n_component`` entries of the channel axis are filled
        and the remainder is zero; ``data_WT`` is the same array flattened to
        (trials, -1); ``labels`` and ``y_mi`` are the event codes.
    """
    epochs_data = epochs.get_data()
    labels = epochs.events[:,-1]

    train_size = len(labels)
    n_channels, n_times = epochs_data.shape[1], epochs_data.shape[2]
    scales = range(8,16)

    # Zero-initialised: slots beyond n_component are never written, and an
    # uninitialised buffer would leak arbitrary memory into data_WT.
    train_data_cwt = np.zeros((train_size, n_channels, len(scales), n_times))

    X_mi, y_mi, csp, csp_wt, data_csp = extract_CSP(epochs.copy(), n_component)

    for ii in range(train_size):
        for jj in range(n_component):
            signal = data_csp[ii, jj, :]
            coeff, _ = pywt.cwt(signal, scales, 'morl', 1)
            # Keep at most n_times columns so the slice fits the buffer.
            train_data_cwt[ii, jj, :, :] = np.abs(coeff[:, :n_times])

    # Preview at most four trials; fewer are shown when fewer epochs survive.
    for j in range(min(4, train_size)):
        sample = train_data_cwt[j]
        # squeeze=False keeps a 2-D axes array even when n_component == 1.
        fig, axs = plt.subplots(1, n_component, figsize=(40, 5), squeeze=False)
        for i in range(n_component):
            axs[0, i].imshow(sample[i], aspect='auto', cmap='viridis')
            axs[0, i].set_title(f'Channel {i+1}')
        fig.suptitle(f'Class {labels[j]}', fontsize=16)
        plt.show()

    # Flatten everything but the trial axis.
    data_WT = np.reshape(train_data_cwt, (train_data_cwt.shape[0], -1))

    return train_data_cwt, data_WT, labels, y_mi

In [28]:
# --- Data Details ---
# Identifiers are strings because load_data() splices them directly into the
# file names ("S<subject>/S<subject>R<task>.edf").
subject_id = ["001"]
task_id = ["04", "08", "12"]

# --- Set Montage ---
# Name handed to mne.channels.make_standard_montage() inside load_data().
montage_name = 'standard_1020'

# --- Set Channels Select ---
# Channel names kept by preprocessing() (62 entries).
chans_selected = ['Fc5', 'Fc3', 'Fc1', 'Fcz', 'Fc2', 'Fc4', 'Fc6', 'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6', 'Cp5', 'Cp3', 'Cp1', 'Cpz', 'Cp2', 'Cp4', 'Cp6', 'Fp1', 'Fpz', 'Fp2', 'Af7', 'Af3', 'Afz', 'Af4', 'Af8', 'F7', 'F5', 'F3', 'F1', 'Fz', 'F2', 'F4', 'F6', 'F8', 'Ft7', 'Ft8', 'T7', 'T8', 'Tp7', 'Tp8', 'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8', 'Po7', 'Po3', 'Poz', 'Po4', 'Po8', 'O1', 'Oz', 'O2', 'Iz']

# Alternative, smaller selection (disabled):
#['C5','C3','C1','Cz','C2','C4','C6']

# --- Set event ---
# Three-class alternative including rest (disabled):
# event = {'rest':1, 'left': 2, 'right': 3}
# target = ['rest','left','right']
# numclass = [1,2,3]

# Two-class setup: `event` is passed to mne.Epochs as event_id.
# NOTE(review): codes 2/3 presumably correspond to annotations T1/T2 as
# numbered by mne.events_from_annotations -- confirm.
event = {'left': 2, 'right': 3}
target = ['left','right']
numclass = [2,3]

In [29]:
# Sanity check: number of selected channel names.
np.array(chans_selected).shape

(62,)

In [30]:
# Read and concatenate the configured runs, then build the cleaned epochs.
raws_data = load_data(subject_id, task_id, montage_name)
epochs = preprocessing(raws_data, event, chans_selected)

[<RawEDF | S001R04.edf, 64 x 60000 (375.0 s), ~29.4 MB, data loaded>, <RawEDF | S001R08.edf, 64 x 20000 (125.0 s), ~9.8 MB, data loaded>, <RawEDF | S001R12.edf, 64 x 20000 (125.0 s), ~9.8 MB, data loaded>]
EEG channel type selected for re-referencing
Adding average EEG reference projection.
1 projection items deactivated
Average reference projection was added, but has not been applied yet. Use the apply_proj method to apply it.
Created an SSP operator (subspace dimension = 1)
1 projection items activated
SSP projectors applied...
Filtering raw data in 3 contiguous segments
Setting up band-pass filter from 8 - 15 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 24 (effective, after forward-backward)
- Cutoffs at 8.00, 15.00 Hz: -6.02, -6.02 dB

Fitting ICA to data using 64 channels (please be patient, this may take a while)
    Applying projection operator with 1 vector (pre-whitener computat

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s


... filtering target
Setting up band-pass filter from 1 - 10 Hz

FIR filter parameters
---------------------
Designing a two-pass forward and reverse, zero-phase, non-causal bandpass filter:
- Windowed frequency-domain design (firwin2) method
- Hann window
- Lower passband edge: 1.00
- Lower transition bandwidth: 0.50 Hz (-12 dB cutoff frequency: 0.75 Hz)
- Upper passband edge: 10.00 Hz
- Upper transition bandwidth: 0.50 Hz (-12 dB cutoff frequency: 10.25 Hz)
- Filter length: 1600 samples (10.000 s)



[Parallel(n_jobs=1)]: Done  63 out of  63 | elapsed:    0.1s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s finished


Applying ICA to Raw instance
    Applying projection operator with 1 vector (pre-whitener application)
    Transforming to ICA space (63 components)
    Zeroing out 14 ICA components
    Projecting back using 64 PCA components
Used Annotations descriptions: ['T0', 'T1', 'T2']
Filtering raw data in 3 contiguous segments
Setting up band-pass filter from 8 - 15 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 24 (effective, after forward-backward)
- Cutoffs at 8.00, 15.00 Hz: -6.02, -6.02 dB

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Setting up band-pass filter from 8 - 15 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 24 (effective, after forward-backward)
- Cutoffs at 8.00, 15.00 Hz: -6.02, -6.02 dB

Applying baseline correction (mode: mean)
Estimating reject

In [31]:
# Epoch data as a NumPy array; the next cell appends a trailing axis to it.
data = epochs.get_data()

In [32]:
# Append a trailing singleton axis so each trial can be fed to Conv2D layers.
# (The former np.moveaxis(data, 0, 0) moved axis 0 onto itself and was a no-op.)
train_X = np.expand_dims(data, -1)
train_X.shape

(45, 34, 721, 1)

In [33]:
from sklearn.preprocessing import OneHotEncoder

# The dense-output flag is called `sparse_output` in newer scikit-learn and
# `sparse` in older releases; try the new name first and fall back.
try:
    encoder = OneHotEncoder(sparse_output=False)
except TypeError:
    encoder = OneHotEncoder(sparse=False)

# One-hot encode the event codes: one row per epoch, one column per class.
y = epochs.events[:, -1]
y = encoder.fit_transform(np.array(y).reshape(-1, 1))
y.shape

(45, 2)

In [34]:
# Hold out 35% of the epochs for testing; random_state fixes the split.
X_train, X_test, y_train, y_test = train_test_split(train_X, y, test_size=0.35, random_state=42)

In [35]:
# Shapes of the training and test inputs.
X_train.shape, X_test.shape

((29, 34, 721, 1), (16, 34, 721, 1))

In [36]:
# Shapes of the one-hot training and test labels.
y_train.shape, y_test.shape

((29, 2), (16, 2))

In [37]:
class MSNN(tf.keras.Model):
    """Multi-scale convolutional network for EEG trials.

    A spectral convolution is followed by three separable temporal
    convolutions. After each temporal stage a spatial convolution whose kernel
    spans all electrodes yields a feature map; the three maps are concatenated
    on the last axis, averaged over the time axis and classified by a single
    dense layer followed by softmax.

    Args:
        n_channels: number of electrodes; height of the spatial kernels.
        fs: sampling frequency in Hz; the spectral kernel spans ``fs / 2``
            samples.
        n_classes: number of output units.

    The defaults reproduce the previously hard-coded configuration
    (34 electrodes, 160 Hz, 2 classes).
    """
    # Runs once, when the class body is executed: sets the global Keras float
    # type to float64.
    tf.keras.backend.set_floatx("float64")
    def __init__(self, n_channels=34, fs=160, n_classes=2):
        super(MSNN, self).__init__()
        self.C = n_channels # the number of electrodes
        self.fs = fs # the sampling frequency

        # Regularizer
        self.regularizer = tf.keras.regularizers.L1L2(l1=.001, l2=.01)

        # Activation functions
        self.activation = tf.keras.layers.LeakyReLU()
        self.softmax = tf.keras.layers.Softmax()

        # Layer factories sharing the regularizer.
        def conv(filters, kernel):
            return tf.keras.layers.Conv2D(filters, kernel,
                                          kernel_regularizer=self.regularizer)

        def sepconv(filters, kernel):
            return tf.keras.layers.SeparableConv2D(filters, kernel, padding="same",
                                                   depthwise_regularizer=self.regularizer,
                                                   pointwise_regularizer=self.regularizer)

        # Spectral convolution
        self.conv0 = conv(4, (1, int(self.fs/2)))

        # Spatio-temporal convolutions: temporal (t) then spatial (s) per scale
        self.conv1t = sepconv(16, (1, 25))
        self.conv1s = conv(16, (self.C, 1))

        self.conv2t = sepconv(32, (1, 15))
        self.conv2s = conv(32, (self.C, 1))

        self.conv3t = sepconv(64, (1, 6))
        self.conv3s = conv(64, (self.C, 1))

        # Flattening
        self.flatten = tf.keras.layers.Flatten()

        # Dropout
        self.dropout = tf.keras.layers.Dropout(0.5)

        # Decision making
        self.dense = tf.keras.layers.Dense(n_classes, activation=None,
                                           kernel_regularizer=self.regularizer)

    def embedding(self, x, random_mask=False):
        """Return the concatenated multi-scale feature maps for ``x``.

        ``random_mask`` is accepted for compatibility and is not used.
        """
        x = self.activation(self.conv0(x))

        x = self.activation(self.conv1t(x))
        f1 = self.activation(self.conv1s(x))

        x = self.activation(self.conv2t(x))
        f2 = self.activation(self.conv2s(x))

        x = self.activation(self.conv3t(x))
        f3 = self.activation(self.conv3s(x))

        feature = tf.concat((f1, f2, f3), -1)
        return feature

    def classifier(self, feature, training=None):
        """Flatten, apply dropout and map the feature to class probabilities.

        ``training`` is forwarded to the dropout layer; leaving it at ``None``
        keeps the previous behaviour (Keras decides).
        """
        feature = self.flatten(feature)
        feature = self.dropout(feature, training=training)
        y_hat = self.softmax(self.dense(feature))
        return y_hat

    def GAP(self, feature):
        """Average ``feature`` over its second-to-last axis."""
        return tf.reduce_mean(feature, -2)

    def call(self, x, training=None):
        """Full forward pass: embedding, average pooling, classification."""
        # Extract feature using MSNN encoder
        feature = self.embedding(x)

        # Global Average Pooling
        feature = self.GAP(feature)

        # Decision making
        y_hat = self.classifier(feature, training=training)
        return y_hat

In [39]:
class Shallow_convnet(tf.keras.Model):
    """Shallow baseline: one spectral and one spatial convolution.

    The spatial feature map is averaged over the time axis and classified by a
    single dense layer followed by softmax. The method names mirror ``MSNN`` so
    the same training utilities can drive either model.

    Args:
        n_channels: number of electrodes; height of the spatial kernel.
        fs: sampling frequency in Hz; the spectral kernel spans ``fs / 2``
            samples.
        n_classes: number of output units.

    The defaults reproduce the previously hard-coded configuration
    (34 electrodes, 160 Hz, 2 classes).
    """
    # Runs once, when the class body is executed: sets the global Keras float
    # type to float64.
    tf.keras.backend.set_floatx("float64")
    def __init__(self, n_channels=34, fs=160, n_classes=2):
        super(Shallow_convnet, self).__init__()
        self.C = n_channels # the number of electrodes
        self.fs = fs # the sampling frequency

        # Regularizer
        self.regularizer = tf.keras.regularizers.L1L2(l1=.001, l2=.01)

        # Activation functions
        self.activation = tf.keras.layers.LeakyReLU()
        self.softmax = tf.keras.layers.Softmax()

        # Layer factory sharing the regularizer.
        def conv(filters, kernel):
            return tf.keras.layers.Conv2D(filters, kernel,
                                          kernel_regularizer=self.regularizer)

        # Spectral convolution
        self.conv0 = conv(16, (1, int(self.fs/2)))

        # Spatial convolution
        self.conv1s = conv(16, (self.C, 1))

        # Pooling layer; created as before but not used in the forward pass.
        self.pooling0 = tf.keras.layers.MaxPool2D((2, 2), strides=None, padding='valid',
                                                  data_format=None, name=None)

        # Flattening
        self.flatten = tf.keras.layers.Flatten()

        # Dropout
        self.dropout = tf.keras.layers.Dropout(0.5)

        # Decision making
        self.dense = tf.keras.layers.Dense(n_classes, activation=None,
                                           kernel_regularizer=self.regularizer)

    def embedding(self, x, random_mask=False):
        """Return the spatial feature map for ``x``.

        ``random_mask`` is accepted for compatibility and is not used.
        """
        x = self.activation(self.conv0(x))
        # Only one feature map exists, so there is nothing to concatenate.
        return self.activation(self.conv1s(x))

    def classifier(self, feature, training=None):
        """Flatten, apply dropout and map the feature to class probabilities.

        ``training`` is forwarded to the dropout layer; leaving it at ``None``
        keeps the previous behaviour (Keras decides).
        """
        feature = self.flatten(feature)
        feature = self.dropout(feature, training=training)
        y_hat = self.softmax(self.dense(feature))
        return y_hat

    def GAP(self, feature):
        """Average ``feature`` over its second-to-last axis."""
        return tf.reduce_mean(feature, -2)

    def call(self, x, training=None):
        """Full forward pass: embedding, average pooling, classification."""
        # Extract feature using the shallow encoder
        feature = self.embedding(x)

        # Global Average Pooling
        feature = self.GAP(feature)

        # Decision making
        y_hat = self.classifier(feature, training=training)
        return y_hat

In [40]:
# Define the informative segments selection agent module.
# Define actor network (for categorical actions: selection/rejection)
class ACTOR(tf.keras.Model):
    """Single dense layer mapping a state to one logit per action."""

    def __init__(self, n_actions=2):
        super().__init__()
        weight_penalty = tf.keras.regularizers.L1L2(l1=.001, l2=.01)
        self.actor = tf.keras.layers.Dense(
            n_actions,
            activation=None,
            kernel_regularizer=weight_penalty,
        )

    def call(self, segment):
        # No activation is applied: the caller receives raw logits.
        logits = self.actor(segment)
        return logits
    
# Define critic network
class CRITIC(tf.keras.Model):
    """Single dense unit followed by a sigmoid, used as the state value."""

    def __init__(self):
        super().__init__()
        weight_penalty = tf.keras.regularizers.L1L2(l1=.001, l2=.01)
        self.critic = tf.keras.layers.Dense(
            1,
            activation=None,
            kernel_regularizer=weight_penalty,
        )

    def call(self, segment):
        # The sigmoid squashes the raw score into (0, 1).
        raw_value = self.critic(segment)
        return tf.keras.activations.sigmoid(raw_value)

In [41]:
# Define utility functions.
def gradient(model, inputs, labels, mask=None):
    """Compute the binary cross-entropy loss and its gradients for ``model``.

    Without a mask the model is called directly. With a mask, the embedding is
    multiplied element-wise by ``mask`` before pooling and classification.

    Returns:
        ``(loss, grad)``: the loss returned by ``binary_crossentropy`` and the
        gradients with respect to ``model.trainable_variables``.
    """
    with tf.GradientTape() as tape:
        if mask is not None:
            masked_embedding = model.embedding(inputs) * mask
            yhat = model.classifier(model.GAP(masked_embedding))
        else:
            yhat = model(inputs)

        loss = tf.keras.losses.binary_crossentropy(labels, yhat)

    return loss, tape.gradient(loss, model.trainable_variables)

def agent_gradient(model, actor, critic, inputs, feature, labels, state, state_next):
    """Compute actor and critic losses and gradients for one selection step.

    The reward is the loss of the full model on ``inputs`` minus the loss of
    the classifier on the agent-built ``feature``. The advantage combines that
    reward with the discounted critic value of ``state_next`` minus the critic
    value of ``state``.

    Returns:
        ``(critic_loss, critic_grad, actor_loss, actor_grad)``.
    """
    discount = 0.95 # discount factor
    bce = tf.keras.losses.binary_crossentropy

    # One tape per network, since each tape is consumed by a single gradient call.
    with tf.GradientTape() as critic_tape, tf.GradientTape() as actor_tape:
        full_model_loss = bce(labels, model(inputs))
        selected_loss = bce(labels, model.classifier(feature))

        # Reward, r_t
        reward = full_model_loss - selected_loss
        # Advantage, A_t
        advantage = reward[:, None] + discount * critic(state_next) - critic(state)
        # Critic loss
        critic_loss = 0.5 * tf.math.square(advantage)
        # Actor loss
        actor_loss = -tf.math.log(tf.nn.softmax(actor(state))) * advantage

    critic_grad = critic_tape.gradient(critic_loss, critic.trainable_variables)
    actor_grad = actor_tape.gradient(actor_loss, actor.trainable_variables)
    return critic_loss, critic_grad, actor_loss, actor_grad

In [20]:
# Define the experiment class for MSNN.
# training_FM trains and tests MSNN without the agent module;
# training_AM pre-trains MSNN and then continues training with the agent module.
class experiment():
    """Train and test a model with and without the segment-selection agent.

    Args:
        train_X, train_Y: training inputs and one-hot labels.
        test_X, test_Y: test inputs and one-hot labels.
        model_cls: zero-argument callable building the model; defaults to
            ``MSNN``.

    Raises:
        ValueError: if the training set is empty.
    """
    def __init__(self, train_X, train_Y, test_X, test_Y, model_cls=None):
        # Load dataset.
        # For simplicity, we just removed validating phase here.
        self.Xtr, self.Ytr = train_X, train_Y
        self.Xts, self.Yts = test_X, test_Y
        self.Yts = np.argmax(self.Yts, axis=-1) # To use scikit-learn accuracy function
        self.model_cls = MSNN if model_cls is None else model_cls

        n_train = self.Xtr.shape[0]
        if n_train == 0:
            raise ValueError("training set is empty")

        # Randomize the training dataset.
        rand_idx = np.random.permutation(n_train)
        self.Xtr, self.Ytr = self.Xtr[rand_idx, :, :, :], self.Ytr[rand_idx, :]

        # Learning schedules
        self.init_LR = 1e-3
        self.num_epochs_pre = 20 # Pre-training epochs
        self.num_epochs = 30
        # Capped at the dataset size so at least one minibatch always exists.
        self.num_batch = min(20, n_train)
        # Kept for compatibility; training uses one optimizer per network.
        self.optimizer = self._new_optimizer()

        # Number of full minibatches per epoch (a trailing partial batch is skipped).
        self.num_batch_iter = int(n_train/self.num_batch)

    def _new_optimizer(self):
        # Fresh RMSprop so every network owns its optimizer state.
        return tf.keras.optimizers.RMSprop(learning_rate=self.init_LR)

    def _minibatch(self, batch):
        # Return the batch-th slice of the shuffled training data.
        rows = slice(batch * self.num_batch, (batch + 1) * self.num_batch)
        return self.Xtr[rows, :, :, :], self.Ytr[rows, :]

    def _supervised_epochs(self, model, optimizer, num_epochs, loss_curve):
        # Plain supervised training; appends per-batch mean losses to loss_curve.
        for epoch in range(num_epochs):
            loss_per_epoch = 0

            for batch in range(self.num_batch_iter):
                xb, yb = self._minibatch(batch)

                # Estimate loss
                loss, grads = gradient(model, xb, yb)

                # Update the parameters
                optimizer.apply_gradients(zip(grads, model.trainable_variables))
                loss_curve.append(np.mean(loss))
                loss_per_epoch += np.mean(loss)

            loss_per_epoch /= self.num_batch_iter

            # Reporting
            print(f"Iteration {epoch + 1}, Training Loss {loss_per_epoch:>.04f}")

    def _predict_test(self, model):
        # Predicted class indices for the test set held by this instance.
        return np.argmax(model(self.Xts), axis=-1)

    def training_FM(self):
        """Train without the agent module; return the per-batch loss curve."""
        msnn = self.model_cls()

        # To record the loss curve.
        loss_FM = []
        self._supervised_epochs(msnn, self._new_optimizer(), self.num_epochs, loss_FM)

        # Test the learned model.
        Yts_hat = self._predict_test(msnn)
        print(f"Testing accuracy: {accuracy_score(self.Yts, Yts_hat)}!\n")
        return loss_FM

    def training_AM(self):
        """Pre-train, then train with the agent module; return the loss curve."""
        msnn = self.model_cls()
        model_optimizer = self._new_optimizer()

        # To record the loss curve.
        loss_AM = []
        # Pre-training without the agent module
        self._supervised_epochs(msnn, model_optimizer, self.num_epochs_pre, loss_AM)

        # Call agent module.
        actor = ACTOR()
        critic = CRITIC()
        actor_optimizer = self._new_optimizer()
        critic_optimizer = self._new_optimizer()

        # Training with the agent module
        for epoch in range(self.num_epochs - self.num_epochs_pre):
            loss_per_epoch = 0

            for batch in range(self.num_batch_iter):
                # Sample minibatch.
                xb, yb = self._minibatch(batch)
                # Extract full segments.
                features = msnn.embedding(xb)
                # Feature width is read from the embedding rather than hard-coded.
                n_features = features.shape[-1]

                agg_wo_current = np.zeros((self.num_batch, n_features))
                num_added = np.zeros((self.num_batch, n_features)) # To estimate the denominator.
                mask = np.zeros(features.shape) # Mask generated by the agent module
                for t in range(features.shape[-2] - 1): # t = 1,...,T'
                    print('epoch :',epoch, 'batch :', batch, 't :', t)
                    deno1 = np.copy(num_added)
                    deno2 = np.copy(num_added) + 1 # For the features with the current segment.
                    # To avoid zero-division.
                    deno1[deno1 == 0] = 1.

                    agg_w_current = agg_wo_current + features[:, 0, t, :]

                    # Define state, s_t.
                    state = np.concatenate((agg_wo_current/deno1, agg_w_current/deno2), axis=-1)
                    # Get action, a_t (sampled from the actor's logits).
                    action_probs = actor(state)
                    # Repeat the sampled action across every feature column.
                    action = np.tile(tf.random.categorical(action_probs, 1).numpy(), n_features)
                    mask[:, 0, t, :] = action
                    num_added += action

                    # Current feature after action decision, phi_t.
                    deno3 = np.copy(num_added)
                    deno3[deno3 == 0] = 1 # To avoid zero-division.
                    feature = (agg_wo_current + features[:, 0, t, :] * action)/deno3

                    # Define next state, s_{t+1}, temporally.
                    agg_wo_current = feature
                    tmp = agg_wo_current + features[:, 0, t + 1, :]
                    state_next = np.concatenate((agg_wo_current/deno3, tmp/(deno3 + 1)), axis=-1)

                    # Calculate critic and actor loss values
                    critic_loss, critic_grads, actor_loss, actor_grads =\
                    agent_gradient(msnn, actor, critic, xb, feature, yb, state, state_next)

                    critic_optimizer.apply_gradients(zip(critic_grads, critic.trainable_variables))
                    actor_optimizer.apply_gradients(zip(actor_grads, actor.trainable_variables))

                # Finally, predict labels of input EEG using the selected segments.
                # Update the parameters
                loss, grads = gradient(msnn, xb, yb, mask)
                model_optimizer.apply_gradients(zip(grads, msnn.trainable_variables))
                loss_AM.append(np.mean(loss))
                loss_per_epoch += np.mean(loss)

            loss_per_epoch /= self.num_batch_iter

            # Reporting
            print(f"Iteration {epoch + 1 + self.num_epochs_pre}, Training Loss {loss_per_epoch:>.04f}")

        # Test the learned model.
        Yts_hat = self._predict_test(msnn)
        print(f"\nTesting accuracy: {accuracy_score(self.Yts, Yts_hat)}!\n")
        return loss_AM
        
# Run both training variants on the split above and keep their loss curves.
exp = experiment(X_train, y_train, X_test, y_test)
loss_FM = exp.training_FM()
loss_AM = exp.training_AM()

Iteration 1, Training Loss 0.6931
Iteration 2, Training Loss 0.6877
Iteration 3, Training Loss 0.6751
Iteration 4, Training Loss 0.6758
Iteration 5, Training Loss 0.6760


KeyboardInterrupt: 

In [43]:
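# Define the experiment class for Shallow_convnet (this redefines `experiment`).
# training_FM trains and tests the model without the agent module;
# training_AM pre-trains the model and then continues training with the agent module.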
class experiment():
    """Train and test a model with and without the segment-selection agent.

    Args:
        train_X, train_Y: training inputs and one-hot labels.
        test_X, test_Y: test inputs and one-hot labels.
        model_cls: zero-argument callable building the model; defaults to
            ``Shallow_convnet``.

    Raises:
        ValueError: if the training set is empty.
    """
    def __init__(self, train_X, train_Y, test_X, test_Y, model_cls=None):
        # Load dataset.
        # For simplicity, we just removed validating phase here.
        self.Xtr, self.Ytr = train_X, train_Y
        self.Xts, self.Yts = test_X, test_Y
        self.Yts = np.argmax(self.Yts, axis=-1) # To use scikit-learn accuracy function
        self.model_cls = Shallow_convnet if model_cls is None else model_cls

        n_train = self.Xtr.shape[0]
        if n_train == 0:
            raise ValueError("training set is empty")

        # Randomize the training dataset.
        rand_idx = np.random.permutation(n_train)
        self.Xtr, self.Ytr = self.Xtr[rand_idx, :, :, :], self.Ytr[rand_idx, :]

        # Learning schedules
        self.init_LR = 1e-3
        self.num_epochs_pre = 20 # Pre-training epochs
        self.num_epochs = 30
        # Capped at the dataset size so at least one minibatch always exists.
        self.num_batch = min(20, n_train)
        # Kept for compatibility; training uses one optimizer per network.
        self.optimizer = self._new_optimizer()

        # Number of full minibatches per epoch (a trailing partial batch is skipped).
        self.num_batch_iter = int(n_train/self.num_batch)

    def _new_optimizer(self):
        # Fresh RMSprop so every network owns its optimizer state.
        return tf.keras.optimizers.RMSprop(learning_rate=self.init_LR)

    def _minibatch(self, batch):
        # Return the batch-th slice of the shuffled training data.
        rows = slice(batch * self.num_batch, (batch + 1) * self.num_batch)
        return self.Xtr[rows, :, :, :], self.Ytr[rows, :]

    def _supervised_epochs(self, model, optimizer, num_epochs, loss_curve):
        # Plain supervised training; appends per-batch mean losses to loss_curve.
        for epoch in range(num_epochs):
            loss_per_epoch = 0

            for batch in range(self.num_batch_iter):
                xb, yb = self._minibatch(batch)

                # Estimate loss
                loss, grads = gradient(model, xb, yb)

                # Update the parameters
                optimizer.apply_gradients(zip(grads, model.trainable_variables))
                loss_curve.append(np.mean(loss))
                loss_per_epoch += np.mean(loss)

            loss_per_epoch /= self.num_batch_iter

            # Reporting
            print(f"Iteration {epoch + 1}, Training Loss {loss_per_epoch:>.04f}")

    def _predict_test(self, model):
        # Predicted class indices for the test set held by this instance.
        return np.argmax(model(self.Xts), axis=-1)

    def training_FM(self):
        """Train without the agent module; return the per-batch loss curve."""
        shallow = self.model_cls()

        # To record the loss curve.
        loss_FM = []
        self._supervised_epochs(shallow, self._new_optimizer(), self.num_epochs, loss_FM)

        # Test the learned model.
        Yts_hat = self._predict_test(shallow)
        print(f"Testing accuracy: {accuracy_score(self.Yts, Yts_hat)}!\n")
        return loss_FM

    def training_AM(self):
        """Pre-train, then train with the agent module; return the loss curve."""
        shallow = self.model_cls()
        model_optimizer = self._new_optimizer()

        # To record the loss curve.
        loss_AM = []
        # Pre-training without the agent module
        self._supervised_epochs(shallow, model_optimizer, self.num_epochs_pre, loss_AM)

        # Call agent module.
        actor = ACTOR()
        critic = CRITIC()
        actor_optimizer = self._new_optimizer()
        critic_optimizer = self._new_optimizer()

        # Training with the agent module
        for epoch in range(self.num_epochs - self.num_epochs_pre):
            loss_per_epoch = 0

            for batch in range(self.num_batch_iter):
                # Sample minibatch.
                xb, yb = self._minibatch(batch)
                # Extract full segments.
                features = shallow.embedding(xb)
                # Feature width is read from the embedding rather than hard-coded.
                n_features = features.shape[-1]

                agg_wo_current = np.zeros((self.num_batch, n_features))
                num_added = np.zeros((self.num_batch, n_features)) # To estimate the denominator.
                mask = np.zeros(features.shape) # Mask generated by the agent module
                for t in range(features.shape[-2] - 1): # t = 1,...,T'
                    print('epoch :',epoch, 'batch :', batch, 't :', t)
                    deno1 = np.copy(num_added)
                    deno2 = np.copy(num_added) + 1 # For the features with the current segment.
                    # To avoid zero-division.
                    deno1[deno1 == 0] = 1.

                    agg_w_current = agg_wo_current + features[:, 0, t, :]

                    # Define state, s_t.
                    state = np.concatenate((agg_wo_current/deno1, agg_w_current/deno2), axis=-1)
                    # Get action, a_t (sampled from the actor's logits).
                    action_probs = actor(state)
                    # Repeat the sampled action across every feature column.
                    action = np.tile(tf.random.categorical(action_probs, 1).numpy(), n_features)
                    mask[:, 0, t, :] = action
                    num_added += action

                    # Current feature after action decision, phi_t.
                    deno3 = np.copy(num_added)
                    deno3[deno3 == 0] = 1 # To avoid zero-division.
                    feature = (agg_wo_current + features[:, 0, t, :] * action)/deno3

                    # Define next state, s_{t+1}, temporally.
                    agg_wo_current = feature
                    tmp = agg_wo_current + features[:, 0, t + 1, :]
                    state_next = np.concatenate((agg_wo_current/deno3, tmp/(deno3 + 1)), axis=-1)

                    # Calculate critic and actor loss values
                    critic_loss, critic_grads, actor_loss, actor_grads =\
                    agent_gradient(shallow, actor, critic, xb, feature, yb, state, state_next)

                    critic_optimizer.apply_gradients(zip(critic_grads, critic.trainable_variables))
                    actor_optimizer.apply_gradients(zip(actor_grads, actor.trainable_variables))

                # Finally, predict labels of input EEG using the selected segments.
                # Update the parameters
                loss, grads = gradient(shallow, xb, yb, mask)
                model_optimizer.apply_gradients(zip(grads, shallow.trainable_variables))
                loss_AM.append(np.mean(loss))
                loss_per_epoch += np.mean(loss)

            loss_per_epoch /= self.num_batch_iter

            # Reporting
            print(f"Iteration {epoch + 1 + self.num_epochs_pre}, Training Loss {loss_per_epoch:>.04f}")

        # Test the learned model.
        Yts_hat = self._predict_test(shallow)
        print(f"\nTesting accuracy: {accuracy_score(self.Yts, Yts_hat)}!\n")
        return loss_AM
        
# Run both training variants on the split above and keep their loss curves.
exp = experiment(X_train, y_train, X_test, y_test)
loss_FM = exp.training_FM()
loss_AM = exp.training_AM()

Iteration 1, Training Loss 0.6931
Iteration 2, Training Loss 0.6918
Iteration 3, Training Loss 0.6902
Iteration 4, Training Loss 0.6888
Iteration 5, Training Loss 0.6882
Iteration 6, Training Loss 0.6881
Iteration 7, Training Loss 0.6881
Iteration 8, Training Loss 0.6881
Iteration 9, Training Loss 0.6881
Iteration 10, Training Loss 0.6881
Iteration 11, Training Loss 0.6881
Iteration 12, Training Loss 0.6881
Iteration 13, Training Loss 0.6881
Iteration 14, Training Loss 0.6881
Iteration 15, Training Loss 0.6881
Iteration 16, Training Loss 0.6881
Iteration 17, Training Loss 0.6881
Iteration 18, Training Loss 0.6881
Iteration 19, Training Loss 0.6881
Iteration 20, Training Loss 0.6881
Iteration 21, Training Loss 0.6881
Iteration 22, Training Loss 0.6881
Iteration 23, Training Loss 0.6881
Iteration 24, Training Loss 0.6881
Iteration 25, Training Loss 0.6881
Iteration 26, Training Loss 0.6881
Iteration 27, Training Loss 0.6881
Iteration 28, Training Loss 0.6881
Iteration 29, Training Loss 0

KeyboardInterrupt: 