Created 5.19.20

Authors: Margot Wagner, Sam Russman

COGS 260 Neural Data Analysis

In [191]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import glob
from math import ceil

%matplotlib inline

### Load dataset

In [192]:
def load_data(filename, transpose=False):
    '''
    Read a header-less CSV file and hand back its contents as an array
    
    param:   filename (str) - filename/path to data
             transpose - if true, swap rows and columns so the result is
                         samples x features (default False)
    return:  data - data values (n_samples, n_features) 
    '''
    table = pd.read_csv(filename, header=None)
    
    # flip orientation only when the caller asks for it
    oriented = table.transpose() if transpose else table
    return oriented.values

In [193]:
def norm_data(data):
    '''
    Normalize data to z-values (0 mean and 1 std dev) feature by feature
    
    Features with zero standard deviation (constant columns) are left at 0
    after centering instead of being divided by zero, which would otherwise
    produce NaN/inf values.
    
    param:   data - data values (n_samples, n_features)           
    return:  data - normalized data, same shape as the input
    '''        
    means = data.mean(axis=0)    # mean for each feature
    stdevs = data.std(axis=0)    # std dev for each feature
    
    # a constant feature has std dev 0; divide by 1 so it stays at 0
    safe_stdevs = np.where(stdevs == 0, 1.0, stdevs)
    
    return (data - means) / safe_stdevs    # normalized data (Z-score)

### PCA 

In [194]:
from sklearn.decomposition import PCA
def run_pca(data):
    '''
    Does PCA on normalized data set, keeping all components
    
    param: data - normalized data (n_samples, n_features)
    return: 
        data_pc: data transformed onto principal components, array, shape (n_samples, n_components)
        components:  principal axes in feature space, array, shape (n_components, n_features)
        weights: percentage of variance explained by each of the selected components. array, shape (n_components,)
        pca: the fitted PCA model (can be reused to transform data)
    '''
    # create PCA model (no n_components given, so all components are kept)
    pca = PCA() 
    
    # fit model to data and project the data onto the principal components
    # (fit alone would return the model itself, not the transformed data)
    data_pc = pca.fit_transform(data)  
    
    # obtain components and components' weights
    components = pca.components_
    weights = pca.explained_variance_ratio_
    
    return data_pc, components, weights, pca

In [196]:
def cum_var_plot(weights, desired_var, plot=False):
    '''
    Cumulative variance plot (number of components vs cumulative variance captured) with calculated number 
    of PCs required to get to a certain desired variance explained
    
    params:
        weights: percentage of variance explained by each of the selected components. array, shape (n_components,)
        desired_var:  fraction of variance to find number of PCs for (same scale as weights, e.g. 0.8)
        plot:  draws the cumulative variance curve if true (default False)
    return
        pcs_req:   pcs required to capture at least desired variance
        captured_var   exact variance captured by pcs_req
    raises
        ValueError if no number of components reaches desired_var
    
    '''
    INDEX_SHIFT = 1    # converts a 0-based array index into a 1-based PC count
    
    # cumulative variance captured
    cum_var = np.cumsum(weights) 
    
    # indices at which the cumulative variance has reached the target
    # (>= so that hitting the target exactly counts as "at least")
    reached = np.flatnonzero(cum_var >= desired_var)
    if reached.size == 0:
        raise ValueError(
            'desired_var={} is never reached by the cumulative variance'.format(desired_var))
    
    # find pcs req to get desired variance
    pcs_req = int(reached[0]) + INDEX_SHIFT
    
    # actual variance captured
    captured_var = cum_var[pcs_req-INDEX_SHIFT]
    
    # plotting
    if plot:
        plt.figure()
        plt.plot(range(INDEX_SHIFT,len(cum_var)+INDEX_SHIFT), cum_var)
        plt.axvline(x=pcs_req, ymin=0, ymax=1, color='k', linestyle='--')
        plt.xlabel('Number of components')
        plt.ylabel('Cumulative variance captured')
        plt.title('Cumulative Variance Captured by Principal Components')
    
    return pcs_req, captured_var

In [197]:
def biplot(data, pca, plot=False):
    '''
    Projects data onto the top 2 principal components and optionally draws
    the PC1 vs PC2 scatter plot
    
    params
        data - normalized data, array shape (n_samples, n_features)
        pca - fitted pca model (must provide transform and explained_variance_ratio_)
        plot - draws the scatter plot if true (default False)
        
    return
        top_pcs[:,PC1_IDX] - data projected onto the top PC, one value per sample
        top_var[PC1_IDX] - fraction of variance captured by top PC
    '''
    PC1_IDX = 0    # index of first PC
    PC2_IDX = 1    # index of second PC
    DIM = 2        # dimensions to plot 
    
    # top 2 pc's and their variance explained
    top_pcs = pca.transform(data)[:,:DIM]
    top_var = pca.explained_variance_ratio_[:DIM]
    
    # plotting
    if plot:
        plt.figure()
        plt.scatter(top_pcs[:,PC1_IDX], top_pcs[:,PC2_IDX])
        plt.xlabel('PC1 ({:.1%})'.format(top_var[PC1_IDX]))
        plt.ylabel('PC2 ({:.1%})'.format(top_var[PC2_IDX]))
        # reference lines through the origin; the extent arguments of
        # axvline/axhline are axes fractions (0-1), not data values, so the
        # defaults are used to span the whole axes
        plt.axvline(x=0, color='k', ls='--')
        plt.axhline(y=0, color='k', ls='--')
    
    return top_pcs[:,PC1_IDX], top_var[PC1_IDX]

# LDA Classification

In [None]:
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import precision_recall_curve
def lda(data):
    '''
    Linear discriminant analysis (LDA) classifier for input data into 3 stimulus conditions (left, right, no stimulus)
    
    Splits the data into train and test sets, fits the classifier on the
    training set and prints the test accuracy and classification report.
    
    param:
        data - normalized and potentially transformed to pcs; rows are
               expected in the order N_TRIALS x 'left', N_TRIALS x 'right',
               N_TRIALS x 'none' (3 * N_TRIALS rows in total)
    
    '''
    # imported here because the notebook's import cells do not bring these
    # names into scope
    from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
    from sklearn.model_selection import train_test_split
    
    N_TRIALS = 20
    class_labels = ['left', 'right', 'none']
    labels = np.repeat(np.array(class_labels), N_TRIALS)
    
    # TRANSFORM DATA TO LOWER DIM
    # Automatically sets test as 0.25
    # Random state so that it's the same every run
    X_train, X_test, y_train, y_test = train_test_split(data, labels, random_state = 0)

    # train LDA classifier (fit must be called on an instance, not the class)
    # look into using LDA for dimensionality reduction??
    lda_model = LinearDiscriminantAnalysis().fit(X_train , y_train)
    y_predict = lda_model.predict(X_test)
    
    # model accuracy for X_test
    lda_acc = 100*accuracy_score(y_test, y_predict)
    print('Accuracy:',round(lda_acc,2),'%')
    # labels is passed explicitly so each display name lines up with its
    # class; otherwise the classes are reported in sorted order
    # ('left', 'none', 'right') and the names would be mismatched
    print(classification_report(y_test, y_predict, labels=class_labels,
                                target_names=['Left', 'Right', 'None']))
    
    # creating a confusion matrix (rows/columns in class_labels order)
    # NOTE(review): cm is computed but not displayed yet - plotting appears unfinished
    cm = confusion_matrix(y_test, y_predict, labels=class_labels)
    
    plt.show()

### Receiver Operating Characteristic (ROC)

ROC curves have true positive rate (predicted positive and actually positive) on the Y axis and false positive rate (predicted positive, actually negative) on the X axis. The top left corner of the plot is the "ideal" point, with a false positive rate of zero and a true positive rate of one.  
  
In order to extend ROC curve and ROC area to multi-label classification, it is necessary to binarize the output (one vs all). One ROC curve can be drawn per label and/or one can draw a ROC curve by considering each element of the label indicator matrix as a binary prediction (micro-averaging).  
  
Another evaluation for multi-label classification is macro-averaging, which gives equal weight to the classification of each label.  
https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html#plot-roc-curves-for-the-multilabel-problem

In [None]:
def roc_curve():
    '''
    Placeholder for the ROC / precision-recall evaluation of the classifier.
    
    Not implemented yet; calling it raises NotImplementedError.
    '''
    #https://towardsdatascience.com/the-5-classification-evaluation-metrics-you-must-know-aa97784ff226
    #https://stackoverflow.com/questions/56090541/how-to-plot-precision-and-recall-of-multiclass-classifier/56092736
    #https://scikit-learn.org/stable/auto_examples/model_selection/plot_precision_recall.html#in-multi-label-settings
    
    # Precision-Recall curve
    raise NotImplementedError('roc_curve is not implemented yet')

### TODO: Add regularization step

### TODO: Add bootstrapping?

# Data Analysis

In [198]:
# Mounted shared Google Drive for data
# Data is recordings from 360 good channels x time
# Data should be samples x features (time x features)
# 360 features, 51 samples (time points) per trial.
N_SAMPLES = 42       # expected number of trial files (21 left, 21 right)
N_TIMEPOINTS = 51    # time points per trial
N_EXTRA_DATA = 2     # extra summary rows stored below the time points

# Collect the trial files first so the output can be sized to what is
# actually on disk rather than to the hard-coded N_SAMPLES
os.chdir('/Volumes/GoogleDrive/Shared drives/COGS 260 Project/Data/fragments3')
filenames = glob.glob('*60mA*')

# Output has one column per trial file:
#   rows 0..N_TIMEPOINTS-1 : top pc value at each time point
#   row  N_TIMEPOINTS      : variance explained by the top pc
#   row  N_TIMEPOINTS + 1  : number of pc's required for 80% variance explained
output = np.zeros([N_TIMEPOINTS + N_EXTRA_DATA, len(filenames)])

# Run PCA on all trials
trials = []    # trial names, in the same order as the columns of output
for index, filename in enumerate(filenames):
    trials.append(filename.split('processed_')[1].split('.')[0])
    
    # Initialize trial results
    trial_results = np.zeros([N_TIMEPOINTS + N_EXTRA_DATA])
    
    # Load and normalize data
    data = load_data(filename, transpose=True)
    data = norm_data(data)
    
    # Run pca
    data_pc, components, weights, pca = run_pca(data)
    
    # further analysis
    # plotting options turned off
    pcs_req, captured_var = cum_var_plot(weights, 0.8)
    top_pc, top_var = biplot(data, pca)

    # Organize results for single trial
    trial_results[:N_TIMEPOINTS] = top_pc
    trial_results[N_TIMEPOINTS] = top_var
    trial_results[N_TIMEPOINTS + 1] = pcs_req

    # Add to all trial output
    output[:, index] = trial_results
    

In [200]:
# Show the results as a table with one column per trial name
pd.DataFrame(output, columns=trials)

Unnamed: 0,trial_1,trial_2,trial_3,trial_4,trial_5,trial_6,trial_7,trial_8,trial_9,trial_10,...,trial_12,trial_13,trial_14,trial_15,trial_16,trial_17,trial_18,trial_19,trial_20,trial_21
0,-7.525431,-12.024689,5.33369,-15.201771,-13.522962,-3.214963,-8.840905,-5.616349,2.267068,18.009457,...,8.372454,-6.747731,-8.150211,-5.91118,-7.375954,-11.190864,-6.6779,1.358868,-8.308519,-8.765593
1,-8.982833,-11.488158,-3.604749,-23.319618,-14.633243,-5.781539,-7.56347,-6.2361,-1.745985,10.371672,...,1.000795,-7.12494,-9.297901,-9.277882,-7.677249,-11.693961,-5.604511,-6.476312,-4.408922,-9.1006
2,-11.141523,-12.10397,-6.50561,31.331777,-13.59886,-5.007512,-9.575792,-7.779647,-4.553653,7.666959,...,-1.551744,-11.104538,-10.345505,-7.602133,-11.980036,-11.704638,-6.106437,-8.276917,-8.315065,-8.595269
3,-10.242713,-13.468507,26.830914,78.79078,-13.752685,-1.185144,-11.357273,-5.609592,-3.901071,35.034961,...,26.189399,-14.237173,-10.728647,-5.423194,-15.271513,-12.069536,-8.904184,35.688795,-16.293597,-8.545494
4,-7.571089,-9.937053,53.651628,39.59484,-16.090566,0.490573,-9.878586,-1.082359,-3.505413,56.125345,...,47.429164,-10.923508,-10.869301,-6.213522,-14.269988,-11.295716,-10.926428,72.033223,-16.277885,-9.19539
5,-7.812236,-4.419489,27.917408,-15.595708,-16.497828,-3.24042,-7.503044,-1.525361,-4.285373,32.900691,...,24.448618,-6.582909,-10.538577,-6.026159,-12.058131,-9.633889,-8.881263,40.055373,-9.99866,-9.358978
6,-9.263641,-5.989432,-6.248622,-13.74204,-14.131065,-7.423902,-7.711403,-5.999808,-4.011297,3.944421,...,-3.263659,-6.204981,-9.224242,-4.73328,-10.898774,-8.895161,-5.169044,-3.042554,-7.374699,-8.619074
7,-7.912177,-11.697593,-5.481162,1.329512,-12.181476,-6.163017,-9.041615,-7.321664,-2.634631,4.835585,...,-0.010915,-6.464773,-8.1371,-5.934424,-9.426252,-8.595743,-4.700914,-1.308462,-8.463546,-7.757039
8,-6.508985,-12.558046,2.471332,-5.076013,-11.405892,-3.225754,-7.223392,-4.705108,-1.029203,10.861139,...,7.934724,-6.007766,-8.547449,-6.914309,-7.694197,-7.668934,-7.413196,8.034752,-7.568919,-7.649506
9,-7.790424,-10.514532,-2.436373,-9.288329,-10.825628,-2.228198,-3.899352,-2.068588,0.446678,6.495142,...,3.282302,-6.613868,-8.993205,-5.113409,-6.749303,-7.356499,-9.0202,0.866594,-4.843447,-8.308831
