In [32]:
''' 
Model for common spatial pattern (CSP) feature calculation and classification for EEG data
'''

import numpy as np
import time
from sklearn.svm import LinearSVC, SVC
from sklearn.model_selection import KFold
from scipy import signal 
from scipy.signal import butter, sosfilt, sosfreqz
import scipy.io as sio
import pyriemann.utils.mean as rie_mean
from scipy.special import binom
from scipy import linalg

def gevd(x1,x2,no_pairs):
    '''Solve generalized eigenvalue decomposition
    
    Keyword arguments:
    x1 -- numpy array of size [NO_channels, NO_samples]
    x2 -- numpy array of size [NO_channels, NO_samples]
    no_pairs -- number of pairs of eigenvectors to be returned 
    Return:	numpy array of 2*No_pairs eigenvectors 
    '''
    ev,vr= linalg.eig(x1,x2,right=True) 
    evAbs = np.abs(ev)
    sort_indices = np.argsort(evAbs)
    chosen_indices = np.zeros(2*no_pairs).astype(int)
    chosen_indices[0:no_pairs] = sort_indices[0:no_pairs]
    chosen_indices[no_pairs:2*no_pairs] = sort_indices[-no_pairs:]
    
    w = vr[:,chosen_indices] # ignore nan entries 
    return w

def butter_fir_filter(signal_in,filter_coeff):
    if filter_coeff.ndim == 2: # butter worth 
        return sosfilt(filter_coeff, signal_in)
    elif filter_coeff.ndim ==1: # fir filter 
        NO_channels ,NO_samples = signal_in.shape 
        sig_filt = np.zeros((NO_channels ,NO_samples))
        for channel in range(0,NO_channels):
            sig_filt[channel] = signal.convolve(signal_in[channel,:],filter_coeff,mode='same') # signal has same size as signal_in (centered)
        return sig_filt

def butter_bandpass(lowcut, highcut, fs, order=5):
        nyq = 0.5 * fs
        low = lowcut / nyq
        high = highcut / nyq
        sos = butter(order, [low, high], analog=False, btype='band', output='sos')
        return sos

def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
        sos = butter_bandpass(lowcut, highcut, fs, order=order)
        y = sosfilt(sos, data)
        return y


def load_bands(bandwidth,f_s,max_freq = 40):
    ''' Filter N channels with fir filter of order 101
    Keyword arguments:
    bandwith -- numpy array containing bandwiths ex. [2,4,8,16,32]
    f_s -- sampling frequency
    Return: numpy array of normalized frequency bands
    '''
    f_bands = np.zeros((99,2)).astype(float)
    band_counter = 0
    for bw in bandwidth:
        startfreq = 4
        while (startfreq + bw <= max_freq): 
            f_bands[band_counter] = [startfreq, startfreq + bw]
            if bw ==1: # do 1Hz steps
                startfreq = startfreq +1
            elif bw == 2: # do 2Hz steps
                startfreq = startfreq +2 
            else : # do 4 Hz steps if Bandwidths >= 4Hz
                startfreq = startfreq +4

            band_counter += 1 

    # convert array to normalized frequency 
    f_bands_nom = 2*f_bands[:band_counter]/f_s
    return f_bands_nom

def csp_one_one(cov_matrix,NO_csp,NO_classes):
    '''
    calculate spatial filter for class all pairs of classes 
    Keyword arguments:
    cov_matrix -- numpy array of size [NO_channels, NO_channels]
    NO_csp -- number of spatial filters (24)
    Return:spatial filter numpy array of size [22,NO_csp] 
    '''
    N, _ = cov_matrix[0].shape 
    n_comb = binom(NO_classes,2)

    NO_filtpairs = int(NO_csp/(n_comb*2))
    
    w = np.zeros((N,NO_csp))
    
    kk = 0 # internal counter 
    for cc1 in range(0,NO_classes):
        for cc2 in range(cc1+1,NO_classes):
            w[:,NO_filtpairs*2*(kk):NO_filtpairs*2*(kk+1)] = gevd(cov_matrix[cc1], cov_matrix[cc2],NO_filtpairs)
            kk +=1
    return w 

def generate_projection(data,class_vec,NO_csp,filter_bank,time_windows,NO_classes=4): 
    ''' generate spatial filters for every timewindow and frequancy band
    Keyword arguments:
    data -- numpy array of size [NO_trials,channels,time_samples]
    class_vec -- containing the class labels, numpy array of size [NO_trials]
    NO_csp -- number of spatial filters (24)
    filter_bank -- numpy array containing butter sos filter coeffitions dim  [NO_bands,order,6]
    time_windows -- numpy array [[start_time1,end_time1],...,[start_timeN,end_timeN]] 
    Return: spatial filter numpy array of size [NO_timewindows,NO_freqbands,22,NO_csp] 
    '''
    time_windows = time_windows.reshape((-1,2))
    NO_bands = filter_bank.shape[0]
    NO_time_windows = len(time_windows[:,0])
    NO_channels = len(data[0,:,0])
    NO_trials = class_vec.size

    # Initialize spatial filter: 
    w = np.zeros((NO_time_windows,NO_bands,NO_channels,NO_csp))

    # iterate through all time windows 
    for t_wind in range(0,NO_time_windows):
        # get start and end point of current time window 
        t_start = time_windows[t_wind,0]
        t_end = time_windows[t_wind,1]

        # iterate through all frequency bandwids 
        for subband in range(0,NO_bands): 

            cov = np.zeros((NO_classes,NO_trials, NO_channels,NO_channels)) # sum of covariance depending on the class
            cov_avg = np.zeros((NO_classes,NO_channels,NO_channels))
            cov_cntr = np.zeros(NO_classes).astype(int) # counter of class occurence 

            #go through all trials and estimate covariance matrix of every class 
            for trial in range(0,NO_trials):
                #frequency band of every channel
                data_filter = butter_fir_filter(data[trial,:,t_start:t_end], filter_bank[subband])
                cur_class_idx = int(class_vec[trial]-1)

                # caclulate current covariance matrix 
                cov[cur_class_idx,cov_cntr[cur_class_idx],:,:] = np.dot(data_filter,np.transpose(data_filter))

                # update covariance matrix and class counter 
                cov_cntr[cur_class_idx] += 1

            # calculate average of covariance matrix 
            for clas in range(0,NO_classes):
                cov_avg[clas,:,:] = rie_mean.mean_covariance(cov[clas,:cov_cntr[clas],:,:], metric = 'euclid')
                w[t_wind,subband,:,:] = csp_one_one(cov_avg,NO_csp,NO_classes) 
    return w


def extract_feature(data,w,filter_bank,time_windows):
    ''' calculate features using the precalculated spatial filters
    Keyword arguments:
    data -- numpy array of size [NO_trials,channels,time_samples]
    w -- spatial filters, numpy array of size [NO_timewindows,NO_freqbands,22,NO_csp]
    filter_bank -- numpy array containing butter sos filter coeffitions dim  [NO_bands,order,6]
    time_windows -- numpy array [[start_time1,end_time1],...,[start_timeN,end_timeN]] 
    Return: features, numpy array of size [NO_trials,(NO_csp*NO_bands*NO_time_windows)] 
    '''
    NO_csp = len(w[0,0,0,:])
    time_windows = time_windows.reshape((-1,2))
    NO_time_windows = int(time_windows.size/2)
    NO_bands = filter_bank.shape[0]
    NO_trials = len(data[:,0,0])
    NO_features = NO_csp*NO_bands*NO_time_windows
    
    feature_mat = np.zeros((NO_trials, NO_time_windows,NO_bands,NO_csp))
    
    # initialize feature vector 
    feat = np.zeros((NO_time_windows,NO_bands,NO_csp))

    # go through all trials 
    for trial in range(0,NO_trials):

        # iterate through all time windows 
        for t_wind in range(0,NO_time_windows):
            # get start and end point of current time window 
            t_start = time_windows[t_wind,0]
            t_end = time_windows[t_wind,1]

            for subband in range(0,NO_bands):
                #Apply spatial Filter to data 
                cur_data_s = np.dot(np.transpose(w[t_wind,subband]),data[trial,:,t_start:t_end])
                #frequency filtering  
                cur_data_f_s = butter_fir_filter(cur_data_s,filter_bank[subband])
                # calculate variance of all channels 
                feat[t_wind,subband] = np.var(cur_data_f_s,axis=1)

        for subband in range(0,NO_bands):
            feat[:,subband] = np.log10(feat[:,subband])#/np.sum(feat[:,subband]))
        # store feature in list 
        feature_mat[trial,:,:,:] = feat
    return np.reshape(feature_mat,(NO_trials,-1)) #

def load_filterbank(bandwidth,fs, order = 4, max_freq = 40,ftype = 'butter'): 
    ''' Calculate Filters bank with Butterworth filter  
    Keyword arguments:
    bandwith -- numpy array containing bandwiths ex. [2,4,8,16,32]
    f_s -- sampling frequency
    Return: numpy array containing filters coefficients dimesnions 'butter': [N_bands,order,6] 'fir': [N_bands,order]
    '''
    f_band_nom = load_bands(bandwidth,fs,max_freq) # get normalized bands 
    n_bands = f_band_nom.shape[0]
    if ftype == 'butter': 
        filter_bank = np.zeros((n_bands,order,6))
    elif ftype == 'fir':
        filter_bank = np.zeros((n_bands,order))

    for band_idx in range(n_bands):
        if ftype == 'butter': 
            filter_bank[band_idx] = butter(order, f_band_nom[band_idx], analog=False, btype='band', output='sos')
        elif ftype == 'fir':
            filter_bank[band_idx] = signal.firwin(order,f_band_nom[band_idx],pass_zero=False)
    return filter_bank

def get_data(training):
    ''' Loads the dataset 2a of the BCI Competition IV
    available on http://bnci-horizon-2020.eu/database/data-sets
    Keyword arguments:
    training -- if True, load training data
    if False, load testing data
    Return: data_return     numpy matrix    size = NO_valid_trial x 22 x 1750
            class_return    numpy matrix    size = NO_valid_trial
    '''
    NO_channels = 22
    NO_tests = 6*48 
    Window_Length = 7*250 # duration of one trial

    class_return = np.zeros(NO_tests)
    data_return = np.zeros((NO_tests,NO_channels,Window_Length))

    NO_valid_trial = 0
    if training:
        a = sio.loadmat('dataset/'+'A01T.mat')
    else:
        a = sio.loadmat('dataset/'+'A01E.mat')
    a_data = a['data']
    for ii in range(0,a_data.size):
        a_data1 = a_data[0,ii]
        a_data2=[a_data1[0,0]]
        a_data3=a_data2[0]
        a_X         = a_data3[0]
        a_trial     = a_data3[1] # array containing start times for all trials
        a_y         = a_data3[2]
        a_fs        = a_data3[3]
        a_classes   = a_data3[4]
        a_artifacts = a_data3[5]
        a_gender    = a_data3[6]
        a_age       = a_data3[7]
        for trial in range(0,a_trial.size):
            data_return[NO_valid_trial,:,:] = np.transpose(a_X[int(a_trial[trial]):(int(a_trial[trial])+Window_Length),:NO_channels])
            class_return[NO_valid_trial] = int(a_y[trial])
            NO_valid_trial +=1
    return data_return[0:NO_valid_trial,:,:], class_return[0:NO_valid_trial]

class CSP_Model:

    def __init__(self):
        self.svm_c  = 0.1 # 0.05 for linear, 20 for rbf, poly: 0.1
        self.fs = 250 # sampling frequency 
        self.NO_channels = 22 # number of EEG channels
        self.NO_csp = 24 # Total number of CSP feature per band and timewindow
        self.bw = np.array([2,4,8,16,32]) # bandwidth of filtered signals 
        self.ftype = 'butter' # 'fir', 'butter'
        self.forder= 2 # 4
        self.filter_bank = load_filterbank(self.bw,self.fs,order=self.forder,max_freq=40,ftype = self.ftype) # get filterbank coeffs  
        time_windows_flt = np.array([
                                [2.5,3.5],
                                [3,4],
                                [3.5,4.5],
                                [4,5],
                                [4.5,5.5],
                                [5,6],
                                [2.5,4.5],
                                [3,5],
                                [3.5,5.5],
                                [4,6],
                                [2.5,6]])*self.fs # time windows in [s] x fs for using as a feature
        self.time_windows = time_windows_flt.astype(int) 
        self.NO_bands = self.filter_bank.shape[0]
        self.NO_time_windows = int(self.time_windows.size/2)
        self.NO_features = self.NO_csp*self.NO_bands*self.NO_time_windows
        self.clf = LinearSVC(C = self.svm_c, intercept_scaling=1, loss='hinge', max_iter=1000,multi_class='ovr', penalty='l2', random_state=1, tol=0.00001)


    def run_csp(self, train_data, train_label, eval_data, eval_label):
        ################################ Training ############################################################################
        start_train = time.time()
        # 1. Apply CSP to bands to get spatial filter 
        w = generate_projection(train_data,train_label, self.NO_csp,self.filter_bank,self.time_windows)
        # 2. Extract features for training 
        feature_mat = extract_feature(train_data,w,self.filter_bank,self.time_windows)
        # 3. Stage Train SVM Model 
        # 2. Train SVM Model  
        self.clf.fit(feature_mat,train_label) 
        
        end_train = time.time()
        print('train time:' + str(end_train-start_train))
        
        ################################# Evaluation ###################################################
        start_eval = time.time()
        eval_feature_mat = extract_feature(eval_data,w,self.filter_bank,self.time_windows)
        success_rate = self.clf.score(eval_feature_mat,eval_label)
        end_eval = time.time()
        
        ml.save(f'clfs/model', self.clf)
        
        print('eval time:'+ str(end_eval-start_eval))
        return success_rate 


    def load_data(self):
        train_data,train_label = get_data(True)
        eval_data,eval_label = get_data(False)
        return train_data, train_label, eval_data, eval_label

def main():
    model = CSP_Model() 
    mode = "train"
    
    if(mode == "train"):
        # load Eval data 
        train_data, train_label, eval_data, eval_label = model.load_data()
        success_rate = model.run_csp(train_data, train_label, eval_data, eval_label)
        print("success rate: " + str(success_rate))    
    elif(mode == "pred"):
        train_data, train_label, eval_data, eval_label = model.load_data()
        eval_feature_mat = extract_feature(eval_data,w,self.filter_bank,self.time_windows)
        success_rate = self.clf.score(eval_feature_mat,eval_label)
        clf =  ml.load(f'clfs/model')
        pred = clf.predict(eval_feature_mat)
        print(pred)

if __name__ == '__main__':
    main()

NameError: name 'w' is not defined