In [87]:
from os.path import join as pjoin
import numpy as np
import scipy.signal
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from scipy.io import loadmat
import json

from matplotlib import pyplot as plt

In [88]:
def trca_matrix(X):
    """Compute a TRCA spatial filter from the trials of one stimulus class.

    Parameters
    ----------
    X : array_like, shape (n_trial, n_chans, n_samples)
        Epochs of a single class. A list of equally shaped
        (n_chans, n_samples) arrays is accepted as well.

    Returns
    -------
    numpy.ndarray, shape (1, n_chans)
        Real-valued eigenvector of ``inv(Q) @ S`` belonging to the largest
        eigenvalue, as a row vector. The row shape matches what the previous
        ``np.matrix`` based code produced.

    Raises
    ------
    ValueError
        If ``X`` is not three dimensional.
    numpy.linalg.LinAlgError
        If ``Q`` is singular.
    """
    X = np.asarray(X)
    if X.ndim != 3:
        raise ValueError('X should be a three dimensional array (n_trial, n_chans, n_samples)')
    n_chans = X.shape[1]

    # Inter-trial matrix S = sum over all trial pairs (i, j) of x_i @ x_j.T.
    # The double sum factorises into (sum_i x_i) @ (sum_j x_j).T, so one
    # matrix product replaces the O(n_trial**2) loop.
    trial_sum = X.sum(axis=0)
    S = np.dot(trial_sum, trial_sum.T)

    # Q: covariance of all trials concatenated along time, centred per channel.
    X1 = np.transpose(X, (1, 0, 2)).reshape(n_chans, -1)
    X1 = X1 - np.mean(X1, axis=1, keepdims=True)
    Q = np.dot(X1, X1.T)

    # TRCA eigenvalue problem inv(Q) @ S; solve() avoids forming the inverse.
    eigenvalues, eigenvectors = np.linalg.eig(np.linalg.solve(Q, S))

    # np.linalg.eig does not order its eigenvalues, so the leading component
    # has to be selected explicitly instead of assuming it is column 0.
    leading = int(np.argmax(eigenvalues.real))
    return np.real(eigenvectors[:, leading]).reshape(1, -1)

In [89]:
def train_test_split(split_type='K_Fold',test_sample_num = None, total_sample_num = None):
    """Build 0/1 selection masks for splitting samples into test folds.

    Parameters
    ----------
    split_type : str
        Only ``'K_Fold'`` is implemented.
    test_sample_num : int
        Number of consecutive samples in each test fold.
    total_sample_num : int
        Total number of samples; must be a multiple of ``test_sample_num``.

    Returns
    -------
    numpy.ndarray, shape (total_sample_num // test_sample_num, total_sample_num)
        Float array; row ``k`` holds 1.0 for the samples of test fold ``k``
        and 0.0 everywhere else.

    Raises
    ------
    ValueError
        If a sample count is missing or ``test_sample_num`` is not positive.
    Exception
        If ``split_type`` is unknown or the samples cannot be divided evenly.
    """
    if split_type != "K_Fold":
        raise Exception('split_type was not defined, use \'K_Fold \' or define it manually')

    # Fail with a readable message instead of "None % None" or a division by zero.
    if test_sample_num is None or total_sample_num is None:
        raise ValueError('test_sample_num and total_sample_num must both be given')
    if test_sample_num <= 0:
        raise ValueError('test_sample_num must be a positive integer')
    if total_sample_num % test_sample_num != 0:
        raise Exception('The sample cannot be divided into test subsets')

    folders_num = total_sample_num // test_sample_num
    split_index = np.zeros((folders_num, total_sample_num))
    for folder_index in range(folders_num):
        fold_start = folder_index * test_sample_num
        split_index[folder_index, fold_start:fold_start + test_sample_num] = 1.0
    return split_index

In [90]:
# Smoke test for train_test_split(): 30 samples in test folds of 5 samples.
# The printed result is a 6 x 30 array of 0/1 flags in which row k marks
# samples 5*k .. 5*k+4 as the test set of fold k.
split_index = train_test_split(test_sample_num=5,total_sample_num=30)
print(split_index)

[[1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1.
  1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 1. 1. 1. 1. 1.]]


In [91]:
def pattern_match(testsample,spatial_filters,template_storage):
    """Correlate a spatially filtered test epoch with every class template.

    All filters in ``spatial_filters`` are stacked into one filter bank. The
    test epoch and each template are projected through that bank, and the
    projections are flattened before a single Pearson correlation is taken
    per template.

    Parameters
    ----------
    testsample : array_like, shape (n_chans, n_samples)
        One epoch to classify.
    spatial_filters : dict
        Maps a class key to its spatial filter (``n_chans`` weights each).
    template_storage : dict
        Maps a class key to its template of shape (n_chans, n_samples).

    Returns
    -------
    numpy.ndarray, shape (len(template_storage),)
        Correlation coefficients in the iteration order of
        ``template_storage``.

    Raises
    ------
    TypeError
        If ``testsample`` is not two dimensional.
    """
    testsample = np.asarray(testsample)
    if testsample.ndim != 2:
        raise TypeError('testsample should be a two dimensional vector')

    # One row per filter. reshape (rather than squeeze) keeps the bank 2-D
    # even for a single filter or a single channel, and also flattens the
    # (1, n_chans) rows that np.matrix based filters carry.
    filter_bank = np.array(list(spatial_filters.values()))
    filter_bank = filter_bank.reshape(len(spatial_filters), -1)

    # np.corrcoef treats every row of a 2-D input as a separate variable, so
    # indexing [0, 1] on 2-D projections would correlate two rows of the test
    # projection with each other and ignore the template. Flattening makes it
    # one test-vs-template correlation.
    projected_sample = np.dot(filter_bank, testsample).ravel()
    corrcoef_storage = list()
    for template_iter in template_storage.keys():
        projected_template = np.dot(filter_bank, np.asarray(template_storage[template_iter])).ravel()
        corrcoef_storage.append(np.corrcoef(projected_sample, projected_template)[0, 1])

    return np.array(corrcoef_storage)

In [92]:
def result_counter(coef_vector,label,counter,right_counter):
    """Score one classification and return the updated counters.

    Integers are immutable, so incrementing the parameters cannot change the
    caller's variables; the new values are returned instead and the caller
    has to assign them back.

    Parameters
    ----------
    coef_vector : array_like
        Correlation coefficients, one per candidate class.
    label : str or int
        True class of the sample; ``int(label)`` is compared with the
        position of the largest coefficient.
    counter : int
        Number of samples scored so far.
    right_counter : int
        Number of correctly classified samples so far.

    Returns
    -------
    tuple of (int, int)
        ``(counter + 1, right_counter + is_correct)``.
    """
    # NOTE(review): np.argmax yields a 0-based position. This only equals
    # int(label) if the labels are 0-based and ordered like coef_vector --
    # confirm against the label scheme of the data.
    counter += 1
    if np.argmax(coef_vector) == int(label):
        right_counter += 1
    return counter, right_counter

In [98]:
class filter_applyer():
    """Design a band-pass filter once and apply it to data.

    Parameters
    ----------
    sample_rate : float
        Sampling rate of the data in Hz.
    filter_type : str
        ``'FIR'`` designs a Hamming-window band-pass filter. ``'IIR'`` is
        accepted but no design is implemented for it.
    high_cut_frequency : float
        Lower band edge in Hz (cut-off of the high-pass side).
    low_cut_frequency : float
        Upper band edge in Hz (cut-off of the low-pass side).
    filter_order : int
        Passed to ``scipy.signal.firwin`` as the number of taps; odd values
        are raised to the next even number.
    """

    def __init__(self,sample_rate,filter_type,high_cut_frequency,low_cut_frequency,filter_order) -> None:
        self.sample_rate = sample_rate
        self.filter_type = filter_type
        self.high_cut_frequency = high_cut_frequency
        self.low_cut_frequency = low_cut_frequency
        self.filter_order = filter_order
        # Coefficients stay None until a design branch fills them in, so that
        # filter_apply() can report a missing design explicitly.
        self.filter_b = None
        self.filter_a = None
        self._filter_design()

    def _filter_design(self):
        """Compute ``filter_b`` / ``filter_a`` for the configured filter type."""
        if self.filter_type == 'FIR':
            if self.filter_order % 2 != 0:
                self.filter_order += 1
            # Band edges are normalised by the Nyquist rate, as firwin
            # expects when no sampling frequency is passed.
            nyq_rate = self.sample_rate / 2.0
            freq_cutoff = [self.high_cut_frequency / nyq_rate, self.low_cut_frequency / nyq_rate]
            # NOTE(review): firwin's first argument is the number of taps and
            # a filter of order N has N + 1 taps -- confirm whether
            # filter_order + 1 was intended here.
            self.filter_b = scipy.signal.firwin(self.filter_order, freq_cutoff, window='hamming', pass_zero='bandpass')
            self.filter_a = 1
        elif self.filter_type == 'IIR':
            # No IIR design yet; coefficients stay None.
            pass

    def filter_apply(self,data):
        """Filter ``data`` along its last axis with the designed coefficients.

        Raises
        ------
        NotImplementedError
            If no coefficients are available for ``filter_type``.
        """
        if self.filter_b is None or self.filter_a is None:
            raise NotImplementedError(
                'No filter coefficients designed for filter_type {!r}'.format(self.filter_type))
        return scipy.signal.lfilter(self.filter_b, self.filter_a, data)

In [94]:
# Smoke test for filter_applyer: FIR band-pass between 6 Hz and 80 Hz at a
# 250 Hz sample rate; prints the sample rate and the filter coefficients.
# NOTE(review): the name `filter` shadows Python's built-in filter() for the
# rest of the notebook session.
# NOTE(review): the recorded output lists 49 coefficients, while the class as
# written passes filter_order (48) to firwin as the tap count; the output
# appears to predate the last run of the class cell -- confirm by re-running.
filter = filter_applyer(250,'FIR',6,80,48)
print(filter.sample_rate)
print(filter.filter_b)

250
[-4.71842418e-04  1.26631211e-03  5.85815119e-04 -1.66662178e-03
  1.04112726e-03  6.11615084e-04 -5.35670174e-03 -8.66699792e-04
  1.14171297e-04 -1.32803884e-02 -6.99903835e-03 -9.47797484e-04
 -2.59797972e-02 -1.94487854e-02 -9.84152473e-04 -4.20330734e-02
 -3.98547428e-02  4.76130221e-03 -5.81367401e-02 -7.32096984e-02
  3.09527471e-02 -7.01321623e-02 -1.67038331e-01  2.38958350e-01
  5.91273788e-01  2.38958350e-01 -1.67038331e-01 -7.01321623e-02
  3.09527471e-02 -7.32096984e-02 -5.81367401e-02  4.76130221e-03
 -3.98547428e-02 -4.20330734e-02 -9.84152473e-04 -1.94487854e-02
 -2.59797972e-02 -9.47797484e-04 -6.99903835e-03 -1.32803884e-02
  1.14171297e-04 -8.66699792e-04 -5.35670174e-03  6.11615084e-04
  1.04112726e-03 -1.66662178e-03  5.85815119e-04  1.26631211e-03
 -4.71842418e-04]


In [95]:
class data_trainer():
    """Load one recording, cut it into filtered epochs and train per-class models.

    The constructor reads ``<data_path>/<usrname>/<block_num>/EEG.mat`` and
    the JSON experiment configuration, then slices the recording right away.
    The file must hold a struct ``EEG`` with the fields ``event``, ``data``
    and ``srate``. The configuration must provide ``blocks_in_data``,
    ``epochs_in_trials`` and ``epoch_length`` (seconds, keyed by block name).

    Parameters
    ----------
    usrname : str
        Subject folder name below ``data_path``.
    data_path : str
        Root folder of the recordings.
    block_num : str
        Block folder name; also the key into ``epoch_length``.
    spatial_filter_type : str
        Spatial filter method; only ``'TRCA'`` is implemented.
    cross_validation : bool
        Stored on the instance; not used by this class itself.
    config_path : str
        Path of the JSON configuration file.
    """

    # Offset in seconds between an event time stamp and the epoch start.
    VISUAL_DELAY_SECONDS = 0.14
    # Band-pass settings handed to filter_applyer: lower edge, upper edge, order.
    BAND_LOW_EDGE_HZ = 7.0
    BAND_HIGH_EDGE_HZ = 80.0
    FILTER_ORDER = 64

    def __init__(self, usrname, data_path, block_num, spatial_filter_type='TRCA', cross_validation=True, config_path='config.json') -> None:
        with open(config_path) as file:
            self.data_config = json.load(file)
        data_dir = pjoin(data_path,usrname,block_num,'EEG.mat')
        self.raw_data = loadmat(data_dir)
        self.raw_data = self.raw_data['EEG'][0]
        self.event = self.raw_data['event'][0]
        self.event_size = self.event.shape[0]
        self.data = self.raw_data['data'][0]
        # Read experiment parameter settings from the json file
        self.blocks_in_data = self.data_config['blocks_in_data']
        self.epochs_in_data = self.data_config['epochs_in_trials']
        self.slice_data_storage = dict()
        self.template_storage = dict()
        self.sample_rate = self.raw_data['srate'][0][0][0]
        self.epoch_length = int(self.data_config['epoch_length'][block_num]*self.sample_rate)
        self.visual_delay = int(self.VISUAL_DELAY_SECONDS*self.sample_rate)
        self.cross_validation = cross_validation
        # Initial temporal filter parameters
        self.filter_object = filter_applyer(self.sample_rate,'FIR',self.BAND_LOW_EDGE_HZ,self.BAND_HIGH_EDGE_HZ,self.FILTER_ORDER)
        self.spatial_filter_type = spatial_filter_type

        self.data_slice()

    def data_slice(self):
        """Cut one epoch per event, zero-mean it, filter it and group by event type.

        Fills ``self.slice_data_storage`` (event type -> list of filtered
        epochs) and sets ``self.event_series`` to the view of its keys.
        """
        for event_iter in range(self.event_size):
            event_type = self.event[event_iter][0][0][0]
            event_time_stamp = int(self.event[event_iter][0][1][0][0])
            epoch_start = event_time_stamp + self.visual_delay
            # NOTE(review): an event close to the end of the recording yields
            # a shorter slice than epoch_length -- confirm this cannot happen.
            epoch_cut = self.data[:, epoch_start:epoch_start + self.epoch_length]
            # Zero-mean each channel before filtering
            epoch_cut = epoch_cut - np.mean(epoch_cut, axis=-1, keepdims=True)
            filtered_epoch_cut = self.filter_object.filter_apply(epoch_cut)
            event_list = self.slice_data_storage.setdefault(event_type, list())
            event_list.append(filtered_epoch_cut)
        self.event_series = self.slice_data_storage.keys()
        # Report the epoch shape from whichever event type was stored first,
        # rather than assuming an event labelled '1' exists.
        if self.slice_data_storage:
            first_event_epochs = next(iter(self.slice_data_storage.values()))
            print(first_event_epochs[0].shape)
        print('Data sliced ready!')
        print('Total number of events: {}'.format(len(self.slice_data_storage)))

    def trainer(self):
        """Train a spatial filter and a template for every event type."""
        self.spatial_filters = dict()
        for train_trial_iter in self.event_series:
            # Epochs are stored as a list; stack them into one
            # (n_trial, n_chans, n_samples) array for the feature extractor.
            train_data = np.array(self.slice_data_storage[train_trial_iter])
            self.spatial_filters[train_trial_iter] = self.feature_extract(train_data)
            self.template_calculate(train_data, train_trial_iter)

    def feature_extract(self,data):
        """Return the spatial filter for ``data`` using the configured method.

        Raises
        ------
        Exception
            If ``spatial_filter_type`` is not ``'TRCA'``.
        """
        if self.spatial_filter_type == 'TRCA':
            return trca_matrix(data)
        raise Exception('Method not define, you can define it manually!')

    def template_calculate(self, train_data, event_type):
        """Store the across-trial mean of ``train_data`` as the template of ``event_type``."""
        self.template_storage[event_type] = np.mean(train_data, axis=0)

    def train_result_get(self):
        """Return ``(spatial_filters, template_storage)`` of the last training run."""
        return self.spatial_filters, self.template_storage

In [96]:
class data_cross_validation(data_trainer):
    """K-fold cross-validation on top of ``data_trainer``.

    Each fold holds out ``epochs_in_data`` consecutive epochs of every event
    type, trains on the remaining epochs and classifies the held-out ones.
    """

    def cross_validation_runner(self):
        """Run every fold; per-fold accuracies are kept in ``self.fold_accuracy``.

        The counters ``result_counter`` / ``result_right_counter`` and
        ``corrcoef_storage`` hold the values of the last fold afterwards.
        """
        # NOTE(review): the masks have epochs_in_data * blocks_in_data entries;
        # this assumes every event type has exactly that many epochs -- confirm.
        self.dataset_split_index = train_test_split(split_type='K_Fold',test_sample_num=self.epochs_in_data, total_sample_num=self.epochs_in_data*self.blocks_in_data)
        self.fold_accuracy = list()
        for cross_validation_iter in range(self.dataset_split_index.shape[0]):
            print('cross validation loop: {}'.format(cross_validation_iter))
            validation_index = self.dataset_split_index[cross_validation_iter,:]
            # The complement of the validation mask selects the training epochs.
            self.trainer(1-validation_index)
            self.tester(validation_index)
            if self.result_counter:
                self.fold_accuracy.append(self.result_right_counter/self.result_counter)
            else:
                self.fold_accuracy.append(float('nan'))

    def trainer(self,select_index):
        """Train filters and templates on the epochs where ``select_index == 1``."""
        train_mask = np.asarray(select_index) == 1
        self.spatial_filters = dict()
        for train_trial_iter in self.event_series:
            train_epoches = np.array(self.slice_data_storage[train_trial_iter])[train_mask,:,:]
            self.spatial_filters[train_trial_iter] = self.feature_extract(train_epoches)
            self.template_calculate(train_epoches,train_trial_iter)

    def tester(self,select_index):
        """Classify the epochs where ``select_index == 1`` and count the hits.

        Resets and fills ``corrcoef_storage`` (event type -> list of
        coefficient vectors), ``result_counter`` (epochs tested) and
        ``result_right_counter`` (epochs classified correctly).
        """
        test_mask = np.asarray(select_index) == 1
        self.corrcoef_storage = dict()
        self.result_counter = 0
        self.result_right_counter = 0
        # pattern_match returns one coefficient per template in the iteration
        # order of template_storage, so position k belongs to the k-th key.
        template_labels = list(self.template_storage.keys())
        for test_trial_iter in self.event_series:
            test_epoches = np.array(self.slice_data_storage[test_trial_iter])[test_mask,:,:]
            corrcoef_list = self.corrcoef_storage.setdefault(test_trial_iter, list())
            for test_epoch_iter in range(test_epoches.shape[0]):
                coef_vector = pattern_match(test_epoches[test_epoch_iter,:,:],self.spatial_filters,self.template_storage)
                corrcoef_list.append(coef_vector)
                # Counted inline: integer attributes handed to a helper cannot
                # be updated in place, so the counters would stay at zero.
                self.result_counter += 1
                if template_labels[int(np.argmax(coef_vector))] == test_trial_iter:
                    self.result_right_counter += 1

In [99]:
# Test cross_validation class: loads ./ThesisData/mengqiangfan/block2/EEG.mat
# (the constructor also reads config.json from the working directory and
# slices the recording), then runs every cross-validation fold.
tester = data_cross_validation('mengqiangfan','./ThesisData/','block2')
tester.cross_validation_runner()

(8, 100)
Data sliced ready!
Total number of events: 12
cross validation loop: 0
cross validation loop: 1
cross validation loop: 2
cross validation loop: 3
cross validation loop: 4
cross validation loop: 5
