### Loading Dataset and Visualizing Data
Each .mat file contains a four-way tensor of electroencephalogram (EEG) data for one subject.
See the reference paper for details.

size(eeg) = [Num. of targets, Num. of channels, Num. of sampling points, Num. of trials]

* Num. of targets: 12
* Num. of channels: 8
* Num. of sampling points: 1114
* Num. of trials: 15
* Sampling rate: 256 Hz
* The order of the stimulus frequencies in the EEG data:
[9.25, 11.25, 13.25, 9.75, 11.75, 13.75, 10.25, 12.25, 14.25, 10.75, 12.75, 14.75] Hz
(e.g., eeg(1,:,:,:) and eeg(5,:,:,:) contain the EEG recorded while the subject gazed at the visual stimuli flickering at 9.25 Hz and 11.75 Hz, respectively. These indices are MATLAB-style and 1-based; in Python they are eeg[0] and eeg[4].)
* The onset of visual stimulation is at the 39th sample point.
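A minimal sketch of the indexing conventions above, assuming the tensor has already been loaded as a NumPy array `eeg` with the documented shape (a random array stands in for real data here). `STIM_FREQS`, `ONSET_IDX`, and the helper `target_data` are illustrative names, not part of the dataset.

```python
import numpy as np

# Stimulus frequency (Hz) of each target, in the order stored along axis 0 of `eeg`
STIM_FREQS = [9.25, 11.25, 13.25, 9.75, 11.75, 13.75,
              10.25, 12.25, 14.25, 10.75, 12.75, 14.75]
FS = 256          # sampling rate (Hz)
ONSET_IDX = 38    # the "39th sample" in 1-based counting is index 38 in Python

def target_data(eeg, freq):
    """Hypothetical helper: (channels, samples, trials) block for the target flickering at `freq` Hz."""
    return eeg[STIM_FREQS.index(freq)]

# Random stand-in with the documented shape (targets, channels, samples, trials)
eeg = np.random.randn(12, 8, 1114, 15)
block = target_data(eeg, 11.75)     # same data as MATLAB eeg(5,:,:,:)
print(block.shape)                  # (8, 1114, 15)
print(ONSET_IDX / FS)               # stimulus onset in seconds after the first sample
```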

Reference:
Masaki Nakanishi, Yijun Wang, Yu-Te Wang and Tzyy-Ping Jung,
"A Comparison Study of Canonical Correlation Analysis Based Methods for Detecting Steady-State Visual Evoked Potentials,"
PLoS One, vol. 10, no. 10, e0140703, 2015.

Channel locations (figure not reproduced here): PO7, PO3, POz, PO4, PO8, O1, Oz, O2.

In [298]:
# Imports (filt_tools is a local helper module that provides the Butterworth filters)
import scipy.io as sio
from filt_tools import butter_bandpass_filter, butter_notch_filter
from scipy.signal import welch
import matplotlib.pyplot as plt
import numpy as np
import os
%matplotlib qt
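`filt_tools` is not included in this notebook, so its implementation is not shown. Assuming `butter_bandpass_filter` is a Butterworth band-pass built on `scipy.signal`, a minimal stand-in could look like the sketch below. `bandpass_sketch` is a hypothetical name, and it filters forward and backward (zero phase) with `sosfiltfilt`, which may differ from what the original helper does.

```python
import numpy as np
from scipy.signal import butter, sosfiltfilt

def bandpass_sketch(data, lowcut, highcut, fs, order=2):
    """Hypothetical stand-in for filt_tools.butter_bandpass_filter.
    Zero-phase Butterworth band-pass applied along the last (time) axis."""
    sos = butter(order, [lowcut, highcut], btype='bandpass', fs=fs, output='sos')
    return sosfiltfilt(sos, data, axis=-1)

# Example: a 12 Hz component (inside 7-30 Hz) plus 60 Hz line noise (outside the band)
fs = 256
t = np.arange(1024) / fs
x = np.sin(2 * np.pi * 12 * t) + np.sin(2 * np.pi * 60 * t)
y = bandpass_sketch(x[np.newaxis, np.newaxis, :], 7, 30, fs=fs)   # shaped like one trial: (1, 1, 1024)

spectrum = np.abs(np.fft.rfft(y[0, 0]))
freqs = np.fft.rfftfreq(y.shape[-1], d=1 / fs)
amp_12 = spectrum[np.argmin(np.abs(freqs - 12))]
amp_60 = spectrum[np.argmin(np.abs(freqs - 60))]
print(amp_60 / amp_12)   # small ratio: the 60 Hz component is strongly attenuated
```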

In [145]:
subject = "s1"
data_path = f"C:/Users/nelso/Google Drive/Colab Notebooks/SSVEP_CNN_Project/cca_ssvep/{subject}.mat"
stim_freq = {1:9.25,2:11.25,3:13.25,4:9.75,5:11.75,6:13.75,7:10.25,8:12.25,9:14.25,10:10.75,11:12.75,12:14.75}
chan_placement = {1:"PO7",2:"PO3",3:"POz",4:"PO4",5:"PO8",6:"O1",7:"Oz",8:"O2"}

df = sio.loadmat(data_path)
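Despite the name `df`, `sio.loadmat` returns a plain Python dict (not a DataFrame) with one entry per MATLAB variable plus metadata keys such as `'__header__'`. The self-contained sketch below writes a synthetic file containing an `eeg` variable of the documented shape and reads it back the same way; the file name `s_demo.mat` is made up.

```python
import os
import tempfile
import numpy as np
import scipy.io as sio

# Write a synthetic file using the dataset's variable name ('eeg'), then read it back
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "s_demo.mat")
    sio.savemat(path, {"eeg": np.zeros((12, 8, 1114, 15))})
    mat = sio.loadmat(path)

# Skip the metadata entries ('__header__', '__version__', '__globals__')
data_keys = [k for k in mat if not k.startswith("__")]
print(data_keys)            # ['eeg']
print(mat["eeg"].shape)     # (12, 8, 1114, 15)
```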

In [431]:
def eeg_loader_loo(data_path, leave_out, fmin=7, fmax=30, num_wind=4, visualize=True):
    """Leave-one-subject-out loader.

    Every .mat file under `data_path` except `<leave_out>.mat` goes into the training
    set; `<leave_out>.mat` becomes the test set. Each trial is band-pass filtered, the
    first 90 samples are discarded (1114 - 90 = 1024 samples remain), and the rest is
    split into `num_wind` equal, non-overlapping windows.
    """
    stim_freq = {1:9.25,2:11.25,3:13.25,4:9.75,5:11.75,6:13.75,7:10.25,8:12.25,9:14.25,10:10.75,11:12.75,12:14.75}
    fs = 256     # sampling rate (Hz)
    crop = 90    # samples dropped at the start of every trial (stimulus onset is at sample 39)

    def data_transform(eeg_dict):
        eeg = eeg_dict['eeg']                      # (targets, channels, samples, trials)
        n_targets, n_trials = eeg.shape[0], eeg.shape[-1]
        # Put all trials in the first dimension: trial 1 (targets 1..12), trial 2 (targets 1..12), ...
        eeg_trial = np.vstack([eeg[:, :, :, i] for i in range(n_trials)])
        # Band-pass filter each trial, then drop the first `crop` samples
        filt_eeg = np.vstack([
            butter_bandpass_filter(eeg_trial[i:i+1], fmin, fmax, fs=fs, order=2)[:, :, crop:]
            for i in range(eeg_trial.shape[0])])
        # Labels follow the same stacking order as the trials
        labels = np.tile(np.arange(1, n_targets + 1), n_trials)
        return filt_eeg, labels

    # Walk the data folder; the held-out subject becomes the test set
    train_x, train_y = [], []
    eeg_test = y_test = None
    for dirpath, _, filenames in os.walk(data_path):
        for fname in sorted(filenames):
            if not fname.endswith('.mat'):
                continue
            x, y = data_transform(sio.loadmat(os.path.join(dirpath, fname)))
            if fname[:-4] == leave_out:
                eeg_test, y_test = x, y
            else:
                train_x.append(x)
                train_y.append(y)
    if eeg_test is None:
        raise FileNotFoundError(f"{leave_out}.mat not found under {data_path}")
    eeg_train, y_train = np.vstack(train_x), np.concatenate(train_y)

    # Split every trial into `num_wind` time windows (1024 / 4 = 256 samples = 1 s by default).
    # Windows are stacked window-major, so the label vector is repeated `num_wind` times.
    eeg_train = np.vstack(np.split(eeg_train, num_wind, axis=-1))
    eeg_test = np.vstack(np.split(eeg_test, num_wind, axis=-1))
    y_train, y_test = np.tile(y_train, num_wind), np.tile(y_test, num_wind)
    print("New train shape", eeg_train.shape)
    print("New test shape", eeg_test.shape)

    # Shuffle the train and test sets (samples and labels together)
    idx_train = np.random.permutation(eeg_train.shape[0])
    eeg_train, y_train = eeg_train[idx_train], y_train[idx_train]
    idx_test = np.random.permutation(eeg_test.shape[0])
    eeg_test, y_test = eeg_test[idx_test], y_test[idx_test]

    if visualize:
        # Plot a few training windows (indices 12-24) and their Welch PSDs
        for i in range(12, min(25, eeg_train.shape[0])):
            eeg_plot = eeg_train[i]
            plt.figure()
            plt.plot(eeg_plot.T)
            plt.title(f'Filtered EEG, {stim_freq[y_train[i]]} Hz')

            freqs, psd = welch(eeg_plot, fs=fs)
            plt.figure()
            plt.plot(freqs, psd.T)
            plt.xlim([fmin, fmax])
            plt.title(f'PSD, {stim_freq[y_train[i]]} Hz')
    return eeg_train, eeg_test, y_train, y_test
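The loader relies on the labels lining up with the order in which trials are stacked and windows are split. The small check below uses a synthetic tensor whose every value equals its target label, with no filtering and the 90-sample crop assumed already applied (1024 samples left). It confirms that repeating the 12 labels once per trial, and then the result once per window, matches the stacking used above.

```python
import numpy as np

n_targets, n_chans, n_samples, n_trials, num_wind = 12, 8, 1024, 15, 4

# Synthetic tensor in which every value equals its target label (1..12)
eeg = np.broadcast_to(np.arange(1, n_targets + 1)[:, None, None, None],
                      (n_targets, n_chans, n_samples, n_trials))

# Trial stacking: trial 1 (targets 1..12), trial 2 (targets 1..12), ...
stacked = np.vstack([eeg[:, :, :, i] for i in range(n_trials)])    # (180, 8, 1024)
labels = n_trials * list(range(1, n_targets + 1))                   # 15 * [1..12]

# Windowing: split the time axis, stack the windows along axis 0 (window-major)
windows = np.vstack(np.split(stacked, num_wind, axis=-1))          # (720, 8, 256)
win_labels = np.array(num_wind * labels)                           # 4 * [...]

# Every window carries the label it was generated from
print(windows.shape, bool(np.all(windows[:, 0, 0] == win_labels)))
```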

In [468]:
%matplotlib qt
X_train,X_test,y_train,y_test = eeg_loader_loo('C:/Users/nelso/Google Drive/Colab Notebooks/SSVEP_CNN_Project/cca_ssvep/','s1')

New train shape (6480, 8, 256)
New test shape (720, 8, 256)
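The PSD plots call `welch` with its default segment length. For a 1 s window (256 samples at 256 Hz) that gives 1 Hz frequency bins, which is coarser than the 0.5 Hz spacing between neighbouring targets (e.g., 9.25 and 9.75 Hz). The sketch below uses a pure 11.75 Hz sine as an idealized stand-in for an SSVEP window; zero-padding through `nfft` places the peak on a finer grid but does not improve the underlying frequency resolution.

```python
import numpy as np
from scipy.signal import welch

fs = 256
t = np.arange(256) / fs                           # one 1 s window, as produced by the loader
x = np.sin(2 * np.pi * 11.75 * t)                 # idealized SSVEP at the 11.75 Hz target

f_default, p_default = welch(x, fs=fs)            # nperseg = 256 -> 1 Hz bins
f_padded, p_padded = welch(x, fs=fs, nfft=1024)   # zero-padded -> 0.25 Hz grid

print(f_default[1] - f_default[0])                # 1.0
print(f_default[np.argmax(p_default)])            # 12.0 (nearest 1 Hz bin)
print(f_padded[np.argmax(p_padded)])              # 11.75
```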




In [497]:
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.utils import to_categorical

# One-hot encode y. Fit the encoder once, on the training labels, and reuse it for the
# test labels so both sets share the same class-to-column mapping.
encoder = LabelEncoder()
encoder.fit(y_train)
y_train_encoded = to_categorical(encoder.transform(y_train))
y_test_encoded = to_categorical(encoder.transform(y_test), num_classes=len(encoder.classes_))
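`LabelEncoder` numbers only the classes it was fitted on, so two encoders fitted on different label sets can assign different columns to the same target. In this notebook both splits contain all 12 targets, so separately fitted encoders happen to agree, but fitting one encoder and reusing it is the safer pattern. The illustration below uses made-up label arrays, and `np.eye(n)[idx]` serves as a NumPy-only equivalent of `to_categorical` for integer labels.

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder

y_train = np.tile(np.arange(1, 13), 2)   # training labels cover all 12 targets (1..12)
y_test = np.array([3, 12])               # a test subset that contains only two of them

shared = LabelEncoder().fit(y_train)
# Identity-matrix indexing yields the same one-hot rows as to_categorical
onehot_test = np.eye(len(shared.classes_))[shared.transform(y_test)]
print(onehot_test.argmax(axis=1))        # columns 2 and 11, consistent with the training encoding

# An encoder fitted on y_test alone renumbers the classes it happens to see
separate = LabelEncoder().fit(y_test).transform(y_test)
print(separate)                          # columns 0 and 1, no longer aligned with training
```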

### Model definition and training

In [471]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Permute, Dropout
from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from tensorflow.keras.layers import SeparableConv2D, DepthwiseConv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import SpatialDropout2D
from tensorflow.keras.regularizers import l1_l2
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.constraints import max_norm
from tensorflow.keras import backend as K

In [458]:
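# Earlier Sequential variant (F1 = F2 = 96, kernLength = 256, D = 1), kept for reference.
# To run it, uncomment and add: from tensorflow.keras import Sequential, layers
# model = Sequential()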

# F1 = 96
# kernLength = 256
# Chans = 8
# D = 1
# dropoutRate = 0.5
# nb_classes = 12
# Samples = 256
# F2 = 96
# model.add(layers.InputLayer(input_shape=(Chans,Samples,1)))
# model.add(layers.Conv2D(F1, (1, kernLength), padding = 'same',
#                             input_shape = (Chans, Samples,1),
#                             use_bias = False))
# model.add(layers.BatchNormalization())
# model.add(layers.DepthwiseConv2D((Chans, 1), use_bias = False, 
#                                    depth_multiplier = D,
#                                    depthwise_constraint = max_norm(1.)))
# model.add(layers.BatchNormalization())

# model.add(layers.Activation('elu'))
# model.add(layers.AveragePooling2D((1, 4)))
# model.add(layers.Dropout(dropoutRate))
# model.add(layers.SeparableConv2D(F2, (1, 16),use_bias = False, 
#                                  padding = 'same'))
# model.add(layers.BatchNormalization())
# model.add(layers.Activation('elu'))
# model.add(layers.AveragePooling2D((1, 8)))
# model.add(layers.Dropout(dropoutRate))
# model.add(layers.Flatten())
# model.add(layers.Dense(nb_classes, name = 'dense',activation='softmax'))

# model.compile(
# optimizer='adam',
# loss='categorical_crossentropy',
# metrics=['accuracy'])

# model.summary()

Model: "sequential_17"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_15 (Conv2D)           (None, 8, 256, 96)        24576     
_________________________________________________________________
batch_normalization_23 (Batc (None, 8, 256, 96)        384       
_________________________________________________________________
depthwise_conv2d_11 (Depthwi (None, 1, 256, 96)        768       
_________________________________________________________________
batch_normalization_24 (Batc (None, 1, 256, 96)        384       
_________________________________________________________________
activation_11 (Activation)   (None, 1, 256, 96)        0         
_________________________________________________________________
average_pooling2d_11 (Averag (None, 1, 64, 96)         0         
_________________________________________________________________
dropout_11 (Dropout)         (None, 1, 64, 96)       

In [493]:
def SSVEPNet(nb_classes, Chans = 8, Samples = 256, 
             dropoutRate = 0.5, kernLength = 64, F1 = 8, 
             D = 2, F2 = 16, norm_rate = 0.25, dropoutType = 'Dropout'):
    """Compact CNN for windows of shape (Chans, Samples, 1); returns a softmax over nb_classes."""
    if dropoutType == 'SpatialDropout2D':
        dropoutType = SpatialDropout2D
    elif dropoutType == 'Dropout':
        dropoutType = Dropout
    else:
        raise ValueError('dropoutType must be one of SpatialDropout2D '
                         'or Dropout, passed as a string.')
    
    input1   = Input(shape = (Chans, Samples, 1))
    ##################################################################
    # Block 1: F1 temporal filters of length kernLength, then a depthwise convolution
    # spanning all channels that learns D spatial filters per temporal filter
    block1       = Conv2D(F1, (1, kernLength), padding = 'same',
                          use_bias = False)(input1)
    block1       = BatchNormalization()(block1)
    block1       = DepthwiseConv2D((Chans, 1), use_bias = False, 
                                   depth_multiplier = D,
                                   depthwise_constraint = max_norm(1.))(block1)
    block1       = BatchNormalization()(block1)
    block1       = Activation('elu')(block1)
    block1       = AveragePooling2D((1, 4))(block1)
    block1       = dropoutType(dropoutRate)(block1)
    
    # Block 2: separable convolution (per-map temporal filter + pointwise mix into F2 maps)
    block2       = SeparableConv2D(F2, (1, 16),
                                   use_bias = False, padding = 'same')(block1)
    block2       = BatchNormalization()(block2)
    block2       = Activation('elu')(block2)
    block2       = AveragePooling2D((1, 8))(block2)
    block2       = dropoutType(dropoutRate)(block2)
    
    # Classifier: flatten, then a dense layer with a max-norm weight constraint and softmax
    flatten      = Flatten(name = 'flatten')(block2)
    dense        = Dense(nb_classes, name = 'dense', 
                         kernel_constraint = max_norm(norm_rate))(flatten)
    softmax      = Activation('softmax', name = 'softmax')(dense)
    
    return Model(inputs=input1, outputs=softmax)

In [501]:
SSVEPModel = SSVEPNet(12)
SSVEPModel.compile(optimizer='adam',
          loss= "categorical_crossentropy",
          metrics=['accuracy'])

In [489]:
SSVEPModel.summary()

Model: "model_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_11 (InputLayer)        [(None, 64, 256, 1)]      0         
_________________________________________________________________
conv2d_20 (Conv2D)           (None, 64, 256, 8)        512       
_________________________________________________________________
batch_normalization_38 (Batc (None, 64, 256, 8)        32        
_________________________________________________________________
depthwise_conv2d_16 (Depthwi (None, 1, 256, 16)        1024      
_________________________________________________________________
batch_normalization_39 (Batc (None, 1, 256, 16)        64        
_________________________________________________________________
activation_21 (Activation)   (None, 1, 256, 16)        0         
_________________________________________________________________
average_pooling2d_21 (Averag (None, 1, 64, 16)         0   
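The parameter counts in the summaries can be checked by hand: every convolution uses `use_bias=False`, each `BatchNormalization` layer stores 4 values per feature map (gamma, beta, moving mean, moving variance), and the dense layer keeps its bias. The hypothetical helper `ssvepnet_param_counts` below encodes that arithmetic. The input shape in the summary above shows it was produced with `Chans=64` rather than the default of 8; its printed counts, like those of the earlier Sequential variant (F1 = 96, kernLength = 256, D = 1), agree with these formulas.

```python
def ssvepnet_param_counts(Chans, Samples, nb_classes, kernLength=64, F1=8, D=2, F2=16, sep_len=16):
    """Hypothetical helper: per-layer parameter counts for the SSVEPNet architecture."""
    time_after_pool = Samples // 4 // 8      # 'same' convs keep Samples; pools divide by 4, then 8
    return {
        "conv2d": kernLength * F1,                            # (1 x kernLength) kernel, 1 input map
        "batchnorm_1": 4 * F1,
        "depthwise_conv2d": Chans * F1 * D,                   # (Chans x 1) kernel, D per input map
        "batchnorm_2": 4 * F1 * D,
        "separable_conv2d": sep_len * F1 * D + F1 * D * F2,   # depthwise (1 x 16) + pointwise 1x1
        "batchnorm_3": 4 * F2,
        "dense": F2 * time_after_pool * nb_classes + nb_classes,  # weights + bias
    }

# Configuration behind the summary printed above (its input shape shows Chans=64)
c64 = ssvepnet_param_counts(Chans=64, Samples=256, nb_classes=12)
print(c64["conv2d"], c64["batchnorm_1"], c64["depthwise_conv2d"], c64["batchnorm_2"])  # 512 32 1024 64

# Earlier Sequential variant: F1 = F2 = 96, kernLength = 256, D = 1, Chans = 8
seq = ssvepnet_param_counts(Chans=8, Samples=256, nb_classes=12, kernLength=256, F1=96, D=1, F2=96)
print(seq["conv2d"], seq["batchnorm_1"], seq["depthwise_conv2d"], seq["batchnorm_2"])  # 24576 384 768 384

# Total for the default SSVEPNet(12) used in this notebook (Chans=8)
print(sum(ssvepnet_param_counts(Chans=8, Samples=256, nb_classes=12).values()))
```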

In [None]:
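number_of_epochs = 1000
batch_size = 32
# The model's Input expects (Chans, Samples, 1); add the trailing singleton axis explicitly
SSVEPModel.fit(X_train[..., np.newaxis], y_train_encoded,
               batch_size=batch_size, epochs=number_of_epochs, verbose=2,
               validation_data=(X_test[..., np.newaxis], y_test_encoded))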

Epoch 1/1000
203/203 - 6s - loss: 2.4588 - accuracy: 0.1231 - val_loss: 2.4687 - val_accuracy: 0.1250
Epoch 2/1000
203/203 - 5s - loss: 2.2120 - accuracy: 0.2012 - val_loss: 2.3393 - val_accuracy: 0.1806
Epoch 3/1000
203/203 - 5s - loss: 1.9861 - accuracy: 0.2931 - val_loss: 2.2308 - val_accuracy: 0.2917
Epoch 4/1000
203/203 - 5s - loss: 1.8464 - accuracy: 0.3531 - val_loss: 2.1542 - val_accuracy: 0.3472
Epoch 5/1000
203/203 - 5s - loss: 1.7532 - accuracy: 0.3804 - val_loss: 2.1629 - val_accuracy: 0.2903
Epoch 6/1000
203/203 - 5s - loss: 1.6912 - accuracy: 0.4066 - val_loss: 2.0910 - val_accuracy: 0.3125
Epoch 7/1000
203/203 - 5s - loss: 1.6473 - accuracy: 0.4207 - val_loss: 2.0553 - val_accuracy: 0.3278
Epoch 8/1000
203/203 - 5s - loss: 1.6220 - accuracy: 0.4253 - val_loss: 2.0523 - val_accuracy: 0.3125
Epoch 9/1000
203/203 - 5s - loss: 1.5834 - accuracy: 0.4367 - val_loss: 2.0287 - val_accuracy: 0.3097
Epoch 10/1000
203/203 - 5s - loss: 1.5587 - accuracy: 0.4508 - val_loss: 2.0428 - 

Epoch 81/1000
203/203 - 5s - loss: 1.2813 - accuracy: 0.5235 - val_loss: 1.8826 - val_accuracy: 0.3431
Epoch 82/1000
203/203 - 5s - loss: 1.2980 - accuracy: 0.5248 - val_loss: 1.8873 - val_accuracy: 0.3667
Epoch 83/1000
203/203 - 5s - loss: 1.3016 - accuracy: 0.5235 - val_loss: 1.8761 - val_accuracy: 0.3500
Epoch 84/1000
203/203 - 5s - loss: 1.3002 - accuracy: 0.5287 - val_loss: 1.8654 - val_accuracy: 0.3681
Epoch 85/1000
203/203 - 5s - loss: 1.2806 - accuracy: 0.5347 - val_loss: 1.8435 - val_accuracy: 0.3986
Epoch 86/1000
203/203 - 5s - loss: 1.2923 - accuracy: 0.5221 - val_loss: 1.8634 - val_accuracy: 0.3597
Epoch 87/1000
203/203 - 5s - loss: 1.3065 - accuracy: 0.5198 - val_loss: 1.8594 - val_accuracy: 0.3722
Epoch 88/1000
203/203 - 5s - loss: 1.2874 - accuracy: 0.5323 - val_loss: 1.8594 - val_accuracy: 0.3736
Epoch 89/1000
203/203 - 5s - loss: 1.2822 - accuracy: 0.5390 - val_loss: 1.8525 - val_accuracy: 0.3792
Epoch 90/1000
203/203 - 5s - loss: 1.2952 - accuracy: 0.5282 - val_loss: 

Epoch 160/1000
203/203 - 5s - loss: 1.2678 - accuracy: 0.5323 - val_loss: 1.8405 - val_accuracy: 0.3861
Epoch 161/1000
203/203 - 5s - loss: 1.2592 - accuracy: 0.5338 - val_loss: 1.8519 - val_accuracy: 0.3889
Epoch 162/1000
203/203 - 5s - loss: 1.2696 - accuracy: 0.5281 - val_loss: 1.8058 - val_accuracy: 0.3861
Epoch 163/1000
203/203 - 5s - loss: 1.2583 - accuracy: 0.5360 - val_loss: 1.8396 - val_accuracy: 0.3625
Epoch 164/1000
203/203 - 5s - loss: 1.2562 - accuracy: 0.5392 - val_loss: 1.8194 - val_accuracy: 0.4028
Epoch 165/1000
203/203 - 5s - loss: 1.2598 - accuracy: 0.5367 - val_loss: 1.8369 - val_accuracy: 0.3750
Epoch 166/1000
203/203 - 5s - loss: 1.2572 - accuracy: 0.5363 - val_loss: 1.8526 - val_accuracy: 0.3681
Epoch 167/1000
203/203 - 5s - loss: 1.2618 - accuracy: 0.5424 - val_loss: 1.8400 - val_accuracy: 0.3625
Epoch 168/1000
203/203 - 5s - loss: 1.2552 - accuracy: 0.5373 - val_loss: 1.8333 - val_accuracy: 0.4028
Epoch 169/1000
203/203 - 5s - loss: 1.2569 - accuracy: 0.5426 - 

Epoch 239/1000
203/203 - 5s - loss: 1.2371 - accuracy: 0.5503 - val_loss: 1.8534 - val_accuracy: 0.3750
Epoch 240/1000
203/203 - 5s - loss: 1.2532 - accuracy: 0.5517 - val_loss: 1.8453 - val_accuracy: 0.3708
Epoch 241/1000
203/203 - 5s - loss: 1.2446 - accuracy: 0.5537 - val_loss: 1.8062 - val_accuracy: 0.4000
Epoch 242/1000
203/203 - 5s - loss: 1.2401 - accuracy: 0.5444 - val_loss: 1.7880 - val_accuracy: 0.4139
Epoch 243/1000
203/203 - 5s - loss: 1.2415 - accuracy: 0.5529 - val_loss: 1.8216 - val_accuracy: 0.3833
Epoch 244/1000
203/203 - 5s - loss: 1.2476 - accuracy: 0.5455 - val_loss: 1.8154 - val_accuracy: 0.3972
Epoch 245/1000
203/203 - 5s - loss: 1.2330 - accuracy: 0.5465 - val_loss: 1.7934 - val_accuracy: 0.3681
Epoch 246/1000
203/203 - 5s - loss: 1.2241 - accuracy: 0.5614 - val_loss: 1.8385 - val_accuracy: 0.4014
Epoch 247/1000
203/203 - 5s - loss: 1.2444 - accuracy: 0.5481 - val_loss: 1.8365 - val_accuracy: 0.3806
Epoch 248/1000
203/203 - 5s - loss: 1.2152 - accuracy: 0.5580 - 

Epoch 318/1000
203/203 - 5s - loss: 1.2250 - accuracy: 0.5529 - val_loss: 1.8566 - val_accuracy: 0.3583
Epoch 319/1000
203/203 - 5s - loss: 1.2214 - accuracy: 0.5579 - val_loss: 1.8285 - val_accuracy: 0.3833
Epoch 320/1000
203/203 - 5s - loss: 1.2284 - accuracy: 0.5627 - val_loss: 1.8877 - val_accuracy: 0.3694
Epoch 321/1000
203/203 - 5s - loss: 1.2041 - accuracy: 0.5688 - val_loss: 1.8028 - val_accuracy: 0.3722
Epoch 322/1000
203/203 - 5s - loss: 1.2246 - accuracy: 0.5515 - val_loss: 1.8271 - val_accuracy: 0.3944
Epoch 323/1000
203/203 - 5s - loss: 1.2153 - accuracy: 0.5653 - val_loss: 1.8191 - val_accuracy: 0.3875
Epoch 324/1000
203/203 - 5s - loss: 1.2112 - accuracy: 0.5654 - val_loss: 1.8140 - val_accuracy: 0.3861
Epoch 325/1000
203/203 - 5s - loss: 1.2085 - accuracy: 0.5699 - val_loss: 1.8083 - val_accuracy: 0.3958
Epoch 326/1000
203/203 - 5s - loss: 1.2051 - accuracy: 0.5662 - val_loss: 1.7970 - val_accuracy: 0.3917
Epoch 327/1000
203/203 - 5s - loss: 1.2306 - accuracy: 0.5579 - 
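In the excerpts shown, training accuracy levels off around 0.52-0.57 while validation accuracy on the held-out subject stays between roughly 0.34 and 0.41 from epoch 81 onward, so most of the 1000 epochs add little. A small parser is enough to locate the best epoch in a saved `verbose=2` log; `parse_keras_log` is a hypothetical helper, and the excerpt below is copied from the log above (the truncated epoch-248 line is skipped). Because the held-out subject doubles as validation data here, choosing an epoch this way makes the reported test accuracy optimistic. Inside `fit`, `tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=..., restore_best_weights=True)` is the built-in way to stop once validation loss plateaus.

```python
import re

EPOCH_RE = re.compile(r"Epoch (\d+)/\d+")
METRIC_RE = re.compile(r"val_loss: ([\d.]+) - val_accuracy: ([\d.]+)")

def parse_keras_log(text):
    """Return {epoch: (val_loss, val_accuracy)}; epochs with truncated metric lines are skipped."""
    history, epoch = {}, None
    for line in text.splitlines():
        m = EPOCH_RE.match(line.strip())
        if m:
            epoch = int(m.group(1))
            continue
        m = METRIC_RE.search(line)
        if m and epoch is not None:
            history[epoch] = (float(m.group(1)), float(m.group(2)))
    return history

log = """Epoch 241/1000
203/203 - 5s - loss: 1.2446 - accuracy: 0.5537 - val_loss: 1.8062 - val_accuracy: 0.4000
Epoch 242/1000
203/203 - 5s - loss: 1.2401 - accuracy: 0.5444 - val_loss: 1.7880 - val_accuracy: 0.4139
Epoch 243/1000
203/203 - 5s - loss: 1.2415 - accuracy: 0.5529 - val_loss: 1.8216 - val_accuracy: 0.3833
Epoch 248/1000
203/203 - 5s - loss: 1.2152 - accuracy: 0.5580 - """

history = parse_keras_log(log)
best = max(history, key=lambda e: history[e][1])   # epoch with the highest val_accuracy
print(best, history[best])                          # 242 (1.788, 0.4139)
```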