In [1]:
import numpy as np 
import pandas as pd
import scipy
import matplotlib.pyplot as plt
import sys
from sklearn.linear_model import LinearRegression
from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize
import dca
from dca.dca import DynamicalComponentsAnalysis
from sklearn.model_selection import KFold
from os import listdir
from os.path import isfile, join
import os
import contextlib
from IPython.utils import io
from sklearn import svm

In [2]:
# Make the `loaders` / `decoders` modules imported in the next cell resolvable.
# Guarded so that re-running this cell does not keep appending duplicate
# entries to sys.path.
if '../..' not in sys.path:
    sys.path.append('../..')

In [3]:
import loaders
import decoders
from loaders import load_shenoy_large
from decoders import lr_decoder


Find datasets in which the reach targets are arranged in a circle.

Use `spike_rates` to classify reach orientation with an SVM.

Variance test: premotor cortex vs. motor cortex.

Try to classify the target position from the neural data. We should restrict to **only reaches that go from the origin outward** (i.e., ignore trials whose target position is the origin).
We need to try a few things. 
(1) Iterate over different bin widths. 

(2) Try location='M1', and location='PMC', **Done**

(3) Try interval='before_go', 'after_go', and 'full'.

(4) Make sure the spike threshold is set to 0.

(5) Compare training the classifier on all trials to training on the dataset where we average over all trials (i.e. for each reach orientation, first average the neural data over all trials before fitting to the classifier)

Question: do I need to set `trialize=False`?

In [4]:
# Root directory that contains the DANDI dandiset folder `000121`.
# Override per machine with the DANDI_DATA_DIR environment variable instead of
# editing hard-coded paths; the default is the original location.
dat_path = os.environ.get('DANDI_DATA_DIR', '/mnt/Secondary/data')

In [5]:
def _session_files(subject_dir):
    """Return the names of the regular files directly inside ``subject_dir``, sorted.

    Sorting makes positional lookups such as ``RFiles[4]`` reproducible: the
    order returned by ``os.listdir`` is arbitrary and differs between machines.
    """
    return sorted(f for f in listdir(subject_dir) if isfile(join(subject_dir, f)))


JFiles = _session_files("%s/000121/sub-JenkinsC" % dat_path)
RFiles = _session_files("%s/000121/sub-Reggie" % dat_path)

In [9]:
# Display the Reggie session files. Positional indices into this list (e.g. RFiles[4])
# depend on the order shown here.
RFiles

['sub-Reggie_ses-20170115T125333_behavior+ecephys.nwb',
 'sub-Reggie_ses-20170116T102856_behavior+ecephys.nwb',
 'sub-Reggie_ses-20170117T104643_behavior+ecephys.nwb',
 'sub-Reggie_ses-20170118T094022_behavior+ecephys.nwb',
 'sub-Reggie_ses-20170119T123128_behavior+ecephys.nwb',
 'sub-Reggie_ses-20170124T094957_behavior+ecephys.nwb',
 'sub-Reggie_ses-20170125T100800_behavior+ecephys.nwb']

In [None]:
# Session of interest. The bare filename that used to sit in this cell was a
# SyntaxError, which broke "Restart & Run All"; it is kept here as a string.
REGGIE_SESSION = 'sub-Reggie_ses-20170119T123128_behavior+ecephys.nwb'

In [20]:
# Confirm which session sits at index 4 of RFiles. RFiles is built from
# os.listdir, so this position depends on the directory-listing order.
RFiles[4]

'sub-Reggie_ses-20170119T123128_behavior+ecephys.nwb'

In [17]:
# Attempt to load with bin size 10 ms.
# The session is selected by name rather than by position: RFiles comes from
# os.listdir, whose order is arbitrary and differs between machines.
dat = load_shenoy_large('%s/000121/sub-Reggie/%s' % (dat_path, 'sub-Reggie_ses-20170119T123128_behavior+ecephys.nwb'),
                        spike_threshold=0, trialize=True, bin_width=10, location='PMC')

3069 valid trials
96 valid  units
Trializing spike times


96it [00:52,  1.84it/s]


Processing spikes


100%|██████████| 3069/3069 [01:07<00:00, 45.54it/s]


Trializing Behavior


3069it [00:59, 51.60it/s]


In [18]:
# Shape of the first trial's spike-rate array (presumably time bins x units;
# `process` below treats axis 1 as units).
dat['spike_rates'][0].shape

(238, 89)

In [41]:
# List Reggie sessions that contain at least one oblique reach, i.e. a target
# whose x AND y coordinates are both non-zero. `process` only keeps such reaches,
# so a session whose targets are all purely horizontal or vertical is useless there.
# Files are loaded from `dat_path`, the same root the RFiles listing was built from.
for file_name in RFiles:
    with io.capture_output():  # silence the loader's progress output
        dat = load_shenoy_large("%s/000121/sub-Reggie/%s" % (dat_path, file_name),
                                spike_threshold=0, trialize=True)
    target_pos = dat['target_pos']
    has_oblique_reach = np.any((target_pos[:, 0] != 0) & (target_pos[:, 1] != 0))
    if has_oblique_reach:
        print(file_name)

sub-Reggie_ses-20170119T123128_behavior+ecephys.nwb
sub-Reggie_ses-20170125T100800_behavior+ecephys.nwb
sub-Reggie_ses-20170117T104643_behavior+ecephys.nwb
sub-Reggie_ses-20170124T094957_behavior+ecephys.nwb


In [42]:
# List JenkinsC sessions that contain at least one oblique reach, i.e. a target
# whose x AND y coordinates are both non-zero. `process` only keeps such reaches,
# so a session whose targets are all purely horizontal or vertical is useless there.
# Files are loaded from `dat_path`, the same root the JFiles listing was built from.
for file_name in JFiles:
    with io.capture_output():  # silence the loader's progress output
        dat = load_shenoy_large("%s/000121/sub-JenkinsC/%s" % (dat_path, file_name),
                                spike_threshold=0, trialize=True)
    target_pos = dat['target_pos']
    has_oblique_reach = np.any((target_pos[:, 0] != 0) & (target_pos[:, 1] != 0))
    if has_oblique_reach:
        print(file_name)

sub-JenkinsC_ses-20151015T151424_behavior+ecephys.nwb
sub-JenkinsC_ses-20160127T110630_behavior+ecephys.nwb
sub-JenkinsC_ses-20160128T160749_behavior+ecephys.nwb


In [44]:
# Session files printed by the two filtering loops, hard-coded for reuse.
JFiles_validOri = [
    "sub-JenkinsC_ses-20151015T151424_behavior+ecephys.nwb",
    "sub-JenkinsC_ses-20160127T110630_behavior+ecephys.nwb",
    "sub-JenkinsC_ses-20160128T160749_behavior+ecephys.nwb",
]
RFiles_validOri = [
    "sub-Reggie_ses-20170119T123128_behavior+ecephys.nwb",
    "sub-Reggie_ses-20170125T100800_behavior+ecephys.nwb",
    "sub-Reggie_ses-20170117T104643_behavior+ecephys.nwb",
    "sub-Reggie_ses-20170124T094957_behavior+ecephys.nwb",
]

In [92]:
def process(dat, n_folds=None):
    """Classify reach direction from trial-averaged spike rates with an SVM.

    Only reaches whose target has BOTH a non-zero x and a non-zero y coordinate
    are used, so origin targets and purely horizontal/vertical reaches are dropped.
    Each remaining reach direction is mapped to an angle in [0, 2*pi) and binned
    into 8 sectors of width pi/4, which serve as class labels. Each trial's
    feature vector is its spike rates averaged over axis 0 (one value per unit).

    Parameters
    ----------
    dat : dict
        Loader output. Uses 'spike_rates', a sequence of per-trial 2-D arrays
        (axis 1 = units), and 'target_pos', whose columns 0 and 1 are the target
        x and y coordinates.
    n_folds : int or None, optional
        If None (default), fit and score on the same trials, i.e. return the
        TRAINING accuracy (original behavior). If an int >= 2, return the mean
        held-out accuracy over a shuffled K-fold split with random_state=0.

    Returns
    -------
    float or int
        Fraction of correctly classified reaches, or -1 when there is nothing to
        classify: no trials, 1-D per-trial arrays, or fewer than two direction
        classes among the valid reaches.
    """
    spike_rates = dat['spike_rates']
    target_pos = dat['target_pos']
    # No trials, or 1-D per-trial arrays: there is no per-unit data to average.
    if len(spike_rates) == 0 or len(spike_rates[0].shape) == 1:
        return -1

    # Keep oblique reaches only (both target coordinates non-zero). This also
    # guarantees x != 0 for the division below.
    valid_ori_indices = (target_pos[:, 1] != 0) & (target_pos[:, 0] != 0)

    # TODO: Use the orientations defined in papers
    target_pos_valid = target_pos[valid_ori_indices]
    x = target_pos_valid[:, 0]
    y = target_pos_valid[:, 1]
    # arctan returns (-pi/2, pi/2): shift quadrants II and III (x < 0) by pi and
    # quadrant IV (x > 0, y < 0) by 2*pi so every angle lands in [0, 2*pi).
    orientations = np.arctan(y / x)
    orientations = np.where(x < 0, orientations + np.pi,
                            np.where(y < 0, orientations + np.pi * 2, orientations))

    # To use an SVM classifier, map orientations to 8 classes: sectors of width
    # pi/4 starting at 0, labelled 1..8 by np.digitize.
    # NOTE(review): targets lying exactly on a multiple of pi/4 (e.g. 45-degree
    # diagonals) sit on a bin edge, so floating-point noise can split one target
    # direction across two labels -- consider centering the sectors on the targets.
    bins = np.arange(0, 2 * np.pi, .25 * np.pi)
    binned_orientations = np.digitize(orientations, bins)

    # SVC cannot be fit on zero or one class; report "nothing to classify".
    if np.unique(binned_orientations).size < 2:
        return -1

    # One feature vector per trial: the average over axis 0 of that trial's
    # rates (not averaged across trials). Computed once per trial.
    spike_rates_averaged = np.array([np.average(trial, axis=0) for trial in spike_rates])
    spike_rates_ave_valid = spike_rates_averaged[valid_ori_indices, :]

    if n_folds is None:
        # In-sample (training) accuracy, as originally reported.
        clf = svm.SVC().fit(spike_rates_ave_valid, binned_orientations)
        prediction = clf.predict(spike_rates_ave_valid)
        precision = np.average(prediction == binned_orientations)
        return precision

    # Held-out accuracy averaged over K folds.
    fold_scores = []
    splitter = KFold(n_splits=n_folds, shuffle=True, random_state=0)
    for train_idx, test_idx in splitter.split(spike_rates_ave_valid):
        labels_train = binned_orientations[train_idx]
        if np.unique(labels_train).size < 2:
            continue  # SVC cannot be fit on a single class
        clf = svm.SVC().fit(spike_rates_ave_valid[train_idx], labels_train)
        prediction = clf.predict(spike_rates_ave_valid[test_idx])
        fold_scores.append(np.average(prediction == binned_orientations[test_idx]))
    return np.average(fold_scores) if fold_scores else -1

In [86]:
# Parameter grid for the classification sweep.
bin_widths = [10, 50, 100]  # bin widths, presumably ms (cf. "bin size 10 ms" above)
intervals = ['before_go', 'after_go', 'full']
locations = [
    'M1',
    'PMC',
]
spike_threshold = 0  # item (4) of the TODO list: threshold must be 0

In [94]:
import itertools

# For every Reggie session with oblique reaches, load the data for each
# (location, interval, bin width) combination and record the SVM accuracy from
# `process` (-1 means there was nothing to classify). Files are loaded from
# `dat_path`, the same root the session listings were built from.
sweep_results = []
for file_name in RFiles_validOri:
    print("---------------------------")
    print("For file {0}".format(file_name))
    for location, interval, bin_width in itertools.product(locations, intervals, bin_widths):
        with io.capture_output():  # silence the loader's progress output
            dat = load_shenoy_large("%s/000121/sub-Reggie/%s" % (dat_path, file_name),
                                    spike_threshold=spike_threshold, trialize=True,
                                    location=location, interval=interval, bin_width=bin_width)
        precision = process(dat)
        # Keep results for later comparison (e.g. pd.DataFrame(sweep_results)).
        sweep_results.append({'file': file_name, 'location': location, 'interval': interval,
                              'bin_width': bin_width, 'accuracy': precision})
        print("Location: {0}, Interval: {1}, Bin Width: {2}, Classification Precision: {3}".format(
                location, interval, bin_width, precision))
                            

---------------------------
For file sub-Reggie_ses-20170119T123128_behavior+ecephys.nwb
Location: M1, Interval: before_go, Bin Width: 10, Classification Precision: -1
Location: M1, Interval: before_go, Bin Width: 50, Classification Precision: 0.3592039800995025
Location: M1, Interval: before_go, Bin Width: 100, Classification Precision: 0.35323383084577115
Location: M1, Interval: after_go, Bin Width: 10, Classification Precision: 0.32424537487828625
Location: M1, Interval: after_go, Bin Width: 50, Classification Precision: 0.29892891918208375
Location: M1, Interval: after_go, Bin Width: 100, Classification Precision: 0.3037974683544304
Location: M1, Interval: full, Bin Width: 10, Classification Precision: 0.29600778967867575
Location: M1, Interval: full, Bin Width: 50, Classification Precision: 0.28237585199610515
Location: M1, Interval: full, Bin Width: 100, Classification Precision: 0.28432327166504384
Location: PMC, Interval: before_go, Bin Width: 10, Classification Precision: -1
L

In [47]:
# Quick sanity check of the most recently loaded dataset's array shapes / lengths.
for key in ('target_pos', 'go_times'):
    print(dat[key].shape)
for key in ('spike_rates', 'behavior'):
    print(len(dat[key]))
for key in ('behavior', 'spike_rates'):
    print(dat[key][1].shape)

(3069, 3)
(3069,)
3069
3069
(25, 2)
(25, 46)
