# Classification using filter specific features

After determining discriminative features for filters corresponding to one class of EEG signals (right hand, left hand, rest, right foot), we now want to see whether they tend to appear more often in signals of the corresponding class and less often in the others. For that, the feature of interest is extracted and an LDA binary classifier (right hand or not right hand) is trained on it. If the classifier is to some degree able to correctly distinguish between classes with only that one feature, we can assume that it is characteristic for at least a subset of signals.

In [None]:
import numpy as np
from numpy.random import RandomState
import scipy
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import StratifiedKFold,StratifiedShuffleSplit

import logging
# getLogger() without a name returns the root logger; let DEBUG records through.
log = logging.getLogger()
log.setLevel("DEBUG")
from braindecode.scripts.train_experiments import setup_logging
# Project helper; presumably configures log handlers/format - confirm in braindecode.
setup_logging()

from braindecode.veganlasagne.layer_util import print_layers

import os

%load_ext autoreload
%autoreload 2

In [None]:
from matplotlib import pyplot as plt
import seaborn as sns

In [None]:
import receptive_field
import analysis
import utils

In [None]:
# Upper bound on the number of baseline windows returned by cut_input_data.
max_baseline_inputs = 5000
# Random generator that cut_input_data uses to subsample the baseline windows.
# It has to exist before the first call to that function, so it is created
# here together with the other global the function reads.
rng = RandomState()

In [None]:
# Root directory of the trained models and the model (relative to that root)
# analysed in this notebook.
modelpath = '/home/hartmank/braindecode/data/models/'
modelname = 'paper/ours/cnt/deep4/car/22'
# Root directory of the stored receptive-field data; modelname is appended to
# it when the data file is loaded.
savepath = '/home/hartmank/data/convvisual/RF_data/'

In [None]:
# Project helper called with the model directory (modelpath joined with modelname).
# Judging by the names it returns the experiment, the model and the datasets;
# datasets is indexed with 'test' further below.
exp,model,datasets = utils.get_dataset(os.path.join(modelpath,modelname))

In [None]:
# Plotting functions
def cut_input_data(RF_data,filt,separated_channels=True,use_class_start_samples=False):
    """Collect the receptive-field inputs of one filter plus baseline windows.

    Reads the notebook globals ``cl`` (index of the class of interest),
    ``n_chans``, ``rng`` and ``max_baseline_inputs``.

    RF_data: object exposing ``results`` (indexed by class) and ``classes``.
    filt: filter id; only rows of ``max_units_in_filters`` whose column 1
        equals it are used.
    separated_channels: if true, the receptive fields are cropped with a
        channel dimension of 1 instead of ``n_chans`` and the baseline is
        reshaped to the same trailing shape.
    use_class_start_samples: if true, baseline windows are cut with
        ``utils.cut_ind_windows`` at the window indices (column 2) found for
        the filter; otherwise ``utils.cut_rand_windows`` is used.

    Returns the tuple ``(X_RF_cropped, X_baseline)``; the baseline is a random
    selection of at most ``max_baseline_inputs`` windows.
    """
    n_reshape_chans = n_chans
    if separated_channels:
        n_reshape_chans = 1

    class_result = RF_data.results[cl]
    units = np.asarray(class_result.max_units_in_filters)
    filt_mask = units[:,1]==filt
    units = units[filt_mask]

    X_RF_cropped = utils.get_cropped_RF(class_result.RF_X[filt_mask].squeeze(),([0],n_reshape_chans,-1))
    window_starts = np.unique(units[:,2])
    win_len = X_RF_cropped.shape[2]

    # Stack the inputs of every class except ``cl`` to form the baseline pool.
    inputs_baseline = np.array([])
    for other in np.delete(RF_data.classes,cl):
        if inputs_baseline.size:
            inputs_baseline = np.vstack([inputs_baseline,RF_data.results[other].inputs])
        else:
            inputs_baseline = RF_data.results[other].inputs

    if use_class_start_samples:
        X_baseline = utils.cut_ind_windows(inputs_baseline,win_len,window_starts,wins_per_input=100)
    else:
        X_baseline = utils.cut_rand_windows(inputs_baseline,win_len,100)
    X_baseline = X_baseline.squeeze()

    if separated_channels:
        X_baseline = X_baseline.reshape((-1,X_RF_cropped.shape[1],X_RF_cropped.shape[2]))

    # Random subset of the baseline windows, capped at max_baseline_inputs.
    keep = rng.permutation(len(X_baseline))[:max_baseline_inputs]
    return X_RF_cropped,X_baseline[keep]


def print_features(score,p,labels,indeces):
    """Print score, p-value and label of the selected features, one per line.

    score, p, labels: sequences indexed by feature position.
    indeces: feature positions to print, in output order.
    """
    for idx in indeces:
        # Single-argument call form: produces the same text with the Python 2
        # print statement and with the Python 3 print function.
        print('Score %f  p %f  : %s'%(score[idx],p[idx],labels[idx]))
        
def plot_avg(m,s,title='',color='b'):
    """Plot the curve ``m`` with a shaded band from ``m-s`` to ``m+s``.

    m: 1-D array of values; the x axis is ``0 .. m.shape[0]-1``.
    s: half-width of the band, subtracted from and added to ``m``.
    title: legend label attached to the line.
    color: colour used for both the line and the band.
    """
    x = np.arange(m.shape[0])
    lower,upper = m-s,m+s
    # The line gets the higher zorder so it is drawn on top of the band.
    plt.fill_between(x,lower,upper,color=color,zorder=100,alpha=0.3)
    plt.plot(x,m,color=color,zorder=101,label=title)
    
    
def plot_dist_comparison(features,features_base,idx):
    """Overlay the distributions of feature column ``idx`` for class and baseline.

    features: 2-D array of class features (column ``idx`` is plotted, label 'Class').
    features_base: 2-D array of baseline features (label 'Baseline').
    idx: column to plot.
    """
    for values,name in ((features,'Class'),(features_base,'Baseline')):
        sns.distplot(values[:,idx],label=name)
    plt.legend()

In [None]:
def scorer(pred, y):
    """Return ``[mean of both rates, true-positive rate, true-negative rate]``.

    pred: predicted binary labels (1 = positive, 0 = negative).
    y: true binary labels, same length as ``pred``.

    Both arguments are converted to arrays first so that plain lists work:
    for a list ``y==1`` is a single boolean, and indexing with it would
    silently select one element instead of applying a mask. A rate is ``nan``
    when ``y`` contains no sample of the corresponding label.
    """
    pred = np.asarray(pred)
    y = np.asarray(y)
    T_pos = np.mean(pred[y==1]==1)
    T_neg = np.mean(pred[y==0]==0)

    return [np.mean([T_pos,T_neg]),T_pos,T_neg]

# Get characteristic features for Layer 16

In this notebook we will investigate the features that were determined to be characteristic for filter 70 Layer 16 in the filter analysis notebook. It seemed to react strongly to the signal being in a specific phase shift for 11.9 Hz (occurring on the right motor cortex):  
Score 0.752000  p 0.000000  : Phase 11.9047619048 FCC4h  
Score 0.653000  p 0.000000  : Phase 11.9047619048 FC4  
Score 0.601000  p 0.000000  : Phase 11.9047619048 C4  
Score 0.545000  p 0.000000  : Phase 11.9047619048 C2

We will try to find windows of signals from class 0 that are locked into that phase.

In [None]:
# Create the project's IO object and replace it with whatever load() returns
# for the stored file under savepath/modelname.
RF_save = receptive_field.ReceptiveFieldInputsIO()
# NOTE(review): the file name says Layer28 while the text of this section
# talks about Layer 16 - confirm that this is the intended file.
filename = 'Layer28_nUnits200_nFilters05_filterdiff_traindata.data'
RF_save = RF_save.load(os.path.join(savepath,modelname,filename))

In [None]:
# Index of the class of interest (the text of this notebook treats it as "right hand").
cl = 0
# Metadata stored with the results of that class.
n_chans = RF_save.results[cl].n_chans
sampling_rate = RF_save.results[cl].sampling_rate
sensor_names = RF_save.results[cl].sensor_names
# Crop the stored receptive-field inputs (project helper) and keep the first
# element of the result; the visible cells only read its shape[2], which
# serves as the window length.
X_RF_tmp = utils.get_cropped_RF(RF_save.results[cl].RF_X,([0],[1],n_chans,-1))[0]

In [None]:
# Receptive-field inputs of filter 0 and baseline windows, channels kept together.
X_RF_cropped,X_baseline = cut_input_data(RF_save,0,separated_channels=False)
feat_mean_diff,feat_p,index_labels,features_class,features_base = utils.extract_features_and_diff(X_RF_cropped,X_baseline,sampling_rate)
# Feature indices ordered from the largest to the smallest mean difference.
sort_mean_diff = feat_mean_diff.argsort()[::-1]

# Import the submodule explicitly: a plain `import scipy` does not guarantee
# that scipy.fftpack has been loaded.
import scipy.fftpack
frequencies = scipy.fftpack.fftfreq(X_RF_tmp.shape[2], 1./sampling_rate)
# fftfreq lists the non-negative frequencies first; keep that half. Floor
# division keeps the slice bound an integer on both Python 2 and Python 3.
frequencies = frequencies[:frequencies.shape[0]//2].astype(str)

# Human-readable feature names built from the index labels; the phase features
# use the frequency list without its first (0 Hz) entry.
labels = utils.make_labels_from_index_labels(index_labels.tolist(),
                                    {'FFT':[frequencies,sensor_names],
                                    'FFTc':[frequencies,sensor_names],
                                    'Phase':[frequencies[1:],sensor_names],
                                    'Phasec':[frequencies[1:],sensor_names],
                                    'Mean':[sensor_names],
                                    'Meanc':[sensor_names]})

In [None]:
# Drop the reference to the loaded receptive-field data (presumably to free
# memory); the cells below do not use it.
del RF_save

### Get test set and cut it
Here we get the test set and cut the trials into n windows of size t, shifting by 1 sample.  
t: length of the signals in the receptive field of a unit in the layer  
n: total length of signal - t

Windowing is necessary because the phase feature is time-dependent and can change quickly. We cannot rely on it being present throughout the complete signal; instead, we search for windows in trials that exhibit distinguishable values.

In [None]:
# Fetch batches of the test set through the project helper and keep the first one.
# NOTE(review): the meaning of the arguments 1000 and True is defined in utils - confirm.
batch_test = utils.get_dataset_batches(exp,datasets['test'],1000,True)[0]
inputs,targets = batch_test
# Regroup the targets per input with a last axis of size 4, sum over the
# middle axis and take the arg-max: one class index per input.
targets = targets.reshape((len(inputs),-1,4))
targets = targets.sum(axis=1).argmax(axis=1)
# Split into inputs of class cl and all remaining inputs.
inputs_class = inputs[targets==cl]
inputs_base = inputs[targets!=cl]

In [None]:
# Remove the names of the full test batch; inputs_class and inputs_base remain.
del batch_test,inputs,targets

In [None]:
# Cut windows of length X_RF_tmp.shape[2] out of the class and baseline trials
# (project helper; per the text above it shifts by one sample), then drop
# singleton axes.
inputs_class_windows = utils.cut_all_windows(inputs_class,X_RF_tmp.shape[2]).squeeze()
inputs_base_windows = utils.cut_all_windows(inputs_base,X_RF_tmp.shape[2]).squeeze()

<b>Extract features from windows</b>

In [None]:
# Indices of the 20 features that come first in sort_mean_diff.
top_features = sort_mean_diff[:20]

features_test_class,_ = utils.extract_features(inputs_class_windows,sampling_rate)
features_test_class = features_test_class[:,top_features]

# The baseline windows are processed in 10 chunks and only the selected
# columns of each chunk are kept.
features_test_base = list()
input_batches = np.array_split(np.arange(len(inputs_base_windows)),10)
for batch in input_batches:
    # array_split puts the larger chunks first, so an empty chunk means that
    # all remaining chunks are empty as well.
    if len(batch)==0:
        break
    tmp,_ = utils.extract_features(inputs_base_windows[batch],sampling_rate)
    features_test_base.extend(tmp[:,top_features])
features_test_base = np.asarray(features_test_base)

# `inputs` was deleted after the class/baseline split; inputs_class was sliced
# from it along the first axis only, so its shape[2] is the same signal length.
n_windows = inputs_class.shape[2]-X_RF_tmp.shape[2]
# Regroup the flat window list as (window, trial, feature).
# NOTE(review): assumes utils.cut_all_windows orders its output window-major - confirm.
features_test_class_perwin = features_test_class.reshape((n_windows,inputs_class.shape[0],features_test_class.shape[1]))
features_test_base_perwin = features_test_base.reshape((n_windows,inputs_base.shape[0],features_test_base.shape[1]))

# Feature value over trials
This plot shows the mean value of the feature in windows from right hand (blue) and not right hand (green) trials. The shaded area shows the corresponding band of ±1 standard deviation around the mean.

In the case of Phase 11.9047619048 FCC4h, it shows the two phase means to be in counterphase in windows starting between sample 200 and sample 400. This observation is in agreement with the distribution of starting samples for that feature in the filter analysis notebook.

In [None]:
# Mean and standard deviation over axis 1 of the per-window feature arrays.
mean_class = np.mean(features_test_class_perwin,axis=1)
mean_base = np.mean(features_test_base_perwin,axis=1)
std_class = np.std(features_test_class_perwin,axis=1)
std_base = np.std(features_test_base_perwin,axis=1)

# Column 0 is the first of the selected features.
plot_avg(mean_class[:,0],std_class[:,0],title='Class',color='b')
plot_avg(mean_base[:,0],std_base[:,0],title='Baseline',color='g')
plt.title('Phase 11.9047619048 FCC4h in windows starting at different samples')
plt.xlabel('Window starting sample')
plt.ylabel('Feature value')
plt.legend()
plt.show()

To further investigate the difference between class and no class windows, we plot the distributions of the feature in the windows starting at sample 250.

In [None]:
# Compare class and baseline at the SAME window so the two distributions are
# comparable, and report that window in the title.
window_start = 250
plot_dist_comparison(features_test_class_perwin[window_start],features_test_base_perwin[window_start],0)
plt.title('Phase 11.9047619048 FCC4h in window starting at sample %d'%window_start)
plt.xlabel('Feature value')
plt.show()

# Classification with LDA
Here we classify trials belonging to right hand / not right hand using the phase feature. For this we will calculate for each trial the phase values for windows starting at samples 250-254. These 5 values of Phase 11.9047619048 FCC4h will be our classification features.

We will train an LDA with shrinkage. In total we perform a 10-fold crossvalidation 500 times and compute the mean and standard deviation for  
1: True positives (right hand)  
2: True negatives (not right hand)  
3: Overall performance

Because our two class sets are imbalanced, we will randomly sample 80 trials from the not right hand set for each crossvalidation.

In [None]:
rng = RandomState()

# Classification features: value of the first selected feature in the windows
# starting at samples 250-254, arranged as (n_trials, 5).
data_class = features_test_class_perwin[250:255,:,0].T
data_base = features_test_base_perwin[250:255,:,0].T
n_class = data_class.shape[0]

scores = list()
for i in range(500):
    # Balance the two sets: draw as many baseline trials as there are class
    # trials, without repeats, so that no trial can end up in both the
    # training and the test part of a split.
    data_base_tmp = data_base[rng.permutation(data_base.shape[0])[:n_class]]
    data = np.vstack([data_class,data_base_tmp])
    # Class trials are the positives (1) and baseline trials the negatives (0),
    # matching what scorer reports as true positives / true negatives.
    t = np.hstack([np.ones((n_class)),np.zeros((data_base_tmp.shape[0]))])
    clf = LDA(solver='lsqr', shrinkage='auto')

    # 10 folds; shuffle=True is required for random_state to have an effect.
    skf = StratifiedKFold(n_splits=10,shuffle=True,random_state=rng.randint(999999))
    for train, test in skf.split(data, t):
        clf.fit(data[train],t[train])
        pred = clf.predict(data[test])
        scores.append(scorer(pred,t[test]))

# Each row of scores is [overall, true positives, true negatives] for one fold.
scores_mean = np.mean(scores,axis=0)
scores_std = np.std(scores,axis=0)
# Single-argument call form works as Python 2 statement and Python 3 function.
print('True positives: %f+-%f'%(scores_mean[1],scores_std[1]))
print('True negatives: %f+-%f'%(scores_mean[2],scores_std[2]))
print('Total Score: %f+-%f'%(scores_mean[0],scores_std[0]))