# SSVEP Decoding Framework

I want to build a general pipeline for SSVEP EEG signal decoding, in which it is easy to complete basic EEG processing stages like:

* cutting and slicing your data
* filtering data
* applying feature extraction method
* matching pattern and get result

The key inspiration of the framework design ethic is modular, and I want to take the manually parts like experimental information, special filters and feature extraction methods out of the main executing. So that it's much easier to re-write or add new minds into the framework. In the other words, it is a framework that flexible, easy for new learner and try something new.

There are 3 main class defined in this framework. If you just focus on get model for online experiment or do some cross validation, the `data_runner` class and `data_cross_validation` class are what you need, you can just read through and run them. In the `filter_apply` class, you can configure your own time-filter parameters, and what I must admit is that the filter parameters in the current version is not the best, and haven't been optimized at all! 

You may notice that there are several functions above the main classes, they can be named as helper functions. I can move them together and separate them to a helper collection and make the main program clear.

## Import necessary external packages here

In this framework, I try to use external packages as little as possible. Compared to use the toolbox like mne, it's a bit of troublesome, but not too much. Jump out of the mne processing and data framework can make you understand the data route more clearly.

However, if you like, you can feel free add some packages for boosting the function of the framework.

In [133]:
from os.path import join as pjoin
import numpy as np
import scipy.signal
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from scipy.io import loadmat
import json
from collections import defaultdict
from matplotlib import pyplot as plt

In [134]:
def trca_matrix(X):
    """ TRCA kernel function

    Args:
        X (Numpy array): Three dimsional ndarry matrix, shape as (n_trials, n_channels, n_samples)

    Returns:
        spatial_filter(Numpy array): A ndarray vector, shape as (n_channels, )
    
    REF:
    [1] M. Nakanishi, Y. Wang, X. Chen, Y. -T. Wang, X. Gao and T. -P. Jung, "Enhancing Detection of SSVEPs for a High-Speed Brain Speller Using Task-Related Component Analysis," in IEEE Transactions on Biomedical Engineering, vol. 65, no. 1, pp. 104-112, Jan. 2018, doi: 10.1109/TBME.2017.2694818.
    """
    n_chans = X.shape[1]
    n_trial = X.shape[0]
    S = np.zeros((n_chans, n_chans))
    # Computation of correlation matrices
    for trial_i in range(n_trial):
        for trial_j in range(n_trial):
            x_i = X[trial_i, :, :]
            x_j = X[trial_j, :, :]
            S = S + np.dot(x_i, x_j.T)
    X = np.transpose(X, (1, 2, 0))
    X1 = X.reshape((n_chans, -1),order='F')
    X1 = X1 - np.mean(X1, axis=1, keepdims=True)
    Q = np.dot(X1, X1.T)
    S = np.matrix(S)
    Q = np.matrix(Q)
    # TRCA eigenvalue algorithm
    [W, V] = np.linalg.eig(np.dot(Q.I,S))

    spatial_filter = V[:, 0].reshape(-1)

    return spatial_filter

In [135]:
def train_test_split(split_type='K_Fold',test_sample_num = None, total_sample_num = None):
    """[summary]

    Args:
        split_type (str, optional): Key words of the cross_validation method. Defaults to 'K_Fold'.
        test_sample_num (int, optional): the number of test samples (trials) contained in one cross validatioon seperation. Defaults to None.
        total_sample_num (int, optional): total samples (trials) of your dataset. Defaults to None.

    Raises:
        Exception: The sample cannot be divided into test subsets
        Exception: Split_type was not defined, use 'K_Fold' or define it manually

    Returns:
        numpy array: A matrix shape as (total_sample_num/test_sample_num, total_sample_num)
    """
    if split_type == "K_Fold":
        if total_sample_num%test_sample_num != 0:
            raise Exception('The sample cannot be divided into test subsets')
        else:
            folders_num = int(total_sample_num/test_sample_num)
            split_index = np.zeros((folders_num,total_sample_num))
            for folder_index in range(folders_num):
                split_index[folder_index, folder_index*test_sample_num:(folder_index+1)*test_sample_num] = np.ones((test_sample_num,))
        return split_index
    
    raise Exception('split_type was not defined, use \'K_Fold \' or define it manually')

In [145]:
# Test function for train_test_split()
split_index = train_test_split(test_sample_num=5,total_sample_num=30)

In [137]:
def pattern_match(testsample,spatial_filters,template_storage):
    """ 
    Function for pattern matching, input a single trial test sample, then filtering with the pre-trained spatial filter, finally calculated person correlation with all the template signal(using np.corrcoef method)

    Args:
        testsample (Numpy array): A numpy array of single test sample, shape as (N_chan,N_samples)
        spatial_filters (dict): spatial filter dictionary storage, the keys are the trial indexes, values are corresponding spatial filters which trained from feature extracted method 
        template_storage (dict): template dictionary storage, the keys are the trial indexes, values are corrsponding spatial filters which calculated from train samples.

    Raises:
        TypeError: Testsample should be a two dimensional vector

    Returns:
        Numpy array: correlation coefficient storage vector, shape as (N_trials,1) 
    """
    if len(testsample.shape) != 2:
        raise TypeError('Testsample should be a two dimensional vector')
    corrcoef_storage = list()
    spatial_filter = np.squeeze(np.array(list(spatial_filters.values())))
    for template_iter in template_storage.keys():
        corrcoef_storage.append(np.corrcoef(np.dot(spatial_filter,testsample).reshape(1,-1),np.dot(spatial_filter,template_storage[template_iter]).reshape(1,-1))[0,1])
    
    corrcoef_storage = np.array(corrcoef_storage)

    return corrcoef_storage

In [138]:
class filter_applyer():
    """
    A filter warp class, you can design and build time-filter by modify or inherit and rewrite the _filter_design method. With this class, you can easily try different filter paramters setting without break the structure.

    Attention, this filter_appler only support band-pass filter paramter, for it is the most common type in EEG processing.
    """
    def __init__(self,sample_rate,high_cut_frequency,low_cut_frequency,filter_order,filter_type = 'FIR') -> None:
        """
        Args:
            sample_rate (float): Sample rate of filter
            high_cut_frequency (float): High cut-off frequence of the band-pass filter
            low_cut_frequency (float): Low cut-off frequence of the band-pass filter
            filter_order (int): Filter order of FIR filter
            filter_type (str,option): Choose a filter type, in this veision only support FIR which is also default.
        """
        self.sample_rate = sample_rate
        self.filter_type = filter_type
        self.high_cut_frequency = high_cut_frequency
        self.low_cut_frequency = low_cut_frequency
        self.filter_order = filter_order
        self._filter_design()
    
    def _filter_design(self):
        if self.filter_type == 'FIR':
            if self.filter_order%2 == 0:
                self.filter_order+=1
            # Nyquist rate of signal
            nyq_rate = self.sample_rate/2.0
            # 1-D array cut of
            freq_cutoff = [self.high_cut_frequency/nyq_rate, self.low_cut_frequency/nyq_rate]
            # Get the fir filter coef
            taps = scipy.signal.firwin(self.filter_order, freq_cutoff, window='hamming', pass_zero='bandpass')
            self.filter_b = taps
            self.filter_a = 1
        elif self.filter_type == 'IIR':
            pass
            
            
    def filter_apply(self,data):
        filter_data = scipy.signal.lfilter(self.filter_b,self.filter_a,data)
        return filter_data

In [144]:
# Test filter
filter = filter_applyer(250,6,80,48,'FIR')

In [140]:
class result_analyser():
    def __init__(self,labels,CV_loops,epoch_num) -> None:
        self.labels = list(labels)
        self.epoch_num = epoch_num
        self.CV_loops = CV_loops
        self.acc_storage = np.zeros(self.CV_loops)
        self.result_matrix = np.zeros((CV_loops,len(self.labels),epoch_num))
        self.trial_counter = 0
    
    def result_decide(self,coef_vector,CV_iter,trial_iter,epoch_iter):
        _result = self.labels[np.argmax(coef_vector)]
        self.trial_counter += 1
        if _result == trial_iter:
            self.result_matrix[CV_iter, self.labels.index(trial_iter), epoch_iter] = 1
        if self.trial_counter % len(self.labels)==0 and epoch_iter == self.epoch_num-1:
            self.acc_storage[CV_iter] = np.mean(self.result_matrix[CV_iter,:,:])
            print('ACC of the {} cross validation loop is: {}'.format(CV_iter,self.acc_storage[CV_iter]))
        
    def ACC_calculate(self):
        self.overall_ACC = np.mean(self.acc_storage)
        print('Overall ACC of current data is {}'.format(self.overall_ACC))

In [141]:
class data_trainer():
    
    def __init__(self, usrname, data_path, block_num, spatial_filter_type='TRCA', cross_validation=True) -> None:
        with open('config.json') as file:
            self.data_config = json.load(file)
        data_dir = pjoin(data_path,usrname,block_num,'EEG.mat')
        self.raw_data = loadmat(data_dir)
        self.raw_data = self.raw_data['EEG'][0]
        self.event = self.raw_data['event'][0]
        self.event_size = self.event.shape[0]
        self.data = self.raw_data['data'][0]
        # Read experiment paramter setting in json file
        self.blocks_in_data = self.data_config['blocks_in_data']
        self.epochs_in_data = self.data_config['epochs_in_trials']
        self.slice_data_storage = dict()
        self.template_storage = dict()
        self.sample_rate = self.raw_data['srate'][0][0][0]
        self.epoch_length = int(self.data_config['epoch_length'][block_num]*self.sample_rate)
        self.visual_delay = int(0.14*self.sample_rate)
        self.cross_validation = cross_validation
        # Initail time filter paramters here
        self.filter_object = filter_applyer(self.sample_rate,7.0,80.0,64,filter_type='FIR')
        self.spatial_filter_type = spatial_filter_type
        self.template_storage = dict()

        self.data_slice()
   
    def data_slice(self):
        for event_iter in range(self.event_size):
            event_type = self.event[event_iter][0][0][0]
            event_time_stamp = int(self.event[event_iter][0][1][0][0])
            epoch_cut = self.data[:,event_time_stamp+self.visual_delay:event_time_stamp+self.visual_delay+self.epoch_length]
            #Zero-mean (Must DO!)
            epoch_cut = epoch_cut-np.mean(epoch_cut,axis=-1,keepdims=True)
            filtered_epoch_cut = self.filter_object.filter_apply(epoch_cut)
            event_list = self.slice_data_storage.setdefault(event_type, list())
            event_list.append(filtered_epoch_cut)
        self.event_series = self.slice_data_storage.keys()
        print('Data sliced ready!')
        print('Total number of events: {}'.format(len(self.slice_data_storage)))
    
    def trainer(self):
        self.spatial_filters = dict()
        for train_trial_iter in self.event_series:
            self.spatial_filters[train_trial_iter] = self.feature_extract(self.slice_data_storage[train_trial_iter])
            self.template_calculate(self.slice_data_storage[train_trial_iter], train_trial_iter)

    def feature_extract(self,data):
        if self.spatial_filter_type == 'TRCA':
            return trca_matrix(data)
        raise Exception('Method not define, you can define it manually!')
    
    def template_calculate(self, train_data, event_type):
        self.template_storage[event_type] = np.mean(train_data, axis=0)
        self.template_events = list(self.template_storage.keys())
    
    def train_result_get(self):
        return self.spatial_filters, self.template_storage    

In [142]:
class data_cross_validation(data_trainer):
    
    def cross_validation_runner(self):
        self.dataset_split_index = train_test_split(split_type='K_Fold',test_sample_num=self.epochs_in_data, total_sample_num=self.epochs_in_data*self.blocks_in_data)
        self.result_saver = result_analyser(self.event_series,self.dataset_split_index.shape[0],self.epochs_in_data)
        for cross_validation_iter in range(self.dataset_split_index.shape[0]):
            self.CV_iter = cross_validation_iter
            print('cross validation loop: {}'.format(cross_validation_iter))
            validation_index = self.dataset_split_index[cross_validation_iter,:]
            self.trainer(1-validation_index)
            self.tester(validation_index)
        self.result_saver.ACC_calculate()
    
    def trainer(self,select_index):
        self.spatial_filters = dict()
        for train_trial_iter in self.event_series:
            self.spatial_filters[train_trial_iter] = self.feature_extract(np.array(self.slice_data_storage[train_trial_iter])[select_index==1,:,:])
            self.template_calculate(np.array(self.slice_data_storage[train_trial_iter])[select_index==1,:,:],train_trial_iter)
    
    def tester(self,select_index):
        self.corrcoef_storage = dict()
        for test_trial_iter in self.event_series:
            test_epoches = np.array(self.slice_data_storage[test_trial_iter])[select_index==1,:,:]
            corrcoef_list = self.corrcoef_storage.setdefault(test_trial_iter, list())
            for test_epoch_iter in range(test_epoches.shape[0]):
                coef_vector = pattern_match(test_epoches[test_epoch_iter,:,:],self.spatial_filters,self.template_storage)
                corrcoef_list.append(coef_vector)
                self.result_saver.result_decide(coef_vector,self.CV_iter,test_trial_iter,test_epoch_iter)

In [143]:
# Test cross_validation class
tester = data_cross_validation('mengqiangfan','./ThesisData/','block4')
tester.cross_validation_runner()

# TODO: Add support for IIR filter
# TODO: Add LDA module
# TODO: Test TDCA

(8, 125)
Data sliced ready!
Total number of events: 12
cross validation loop: 0
ACC of the 0 cross validation loop is: 0.9333333333333333
cross validation loop: 1
ACC of the 1 cross validation loop is: 0.85
cross validation loop: 2
ACC of the 2 cross validation loop is: 0.8666666666666667
cross validation loop: 3
ACC of the 3 cross validation loop is: 0.8666666666666667
cross validation loop: 4
ACC of the 4 cross validation loop is: 0.9
cross validation loop: 5
ACC of the 5 cross validation loop is: 0.8
Overall ACC of current data is 0.8694444444444445
