In [9]:
#Import standard packages
import pandas as pd
import numpy as np
import xarray as xr
import math

import os
from tqdm.auto import tqdm

import matplotlib.pyplot as plt
import seaborn as sns

from pyaldata import *

import matplotlib.pyplot as plt
%matplotlib inline
from scipy import io
from scipy import stats
from sklearn.metrics import r2_score
import pickle
from tqdm import tqdm
import csv

#Import function to get the covariate matrix that includes spike history from previous bins
from Neural_Decoding.preprocessing_funcs import get_spikes_with_history

#Import metrics
from Neural_Decoding.metrics import get_R2
from Neural_Decoding.metrics import get_rho

#Import hyperparameter optimization packages
try:
    from hyperopt import fmin, hp, Trials, tpe, STATUS_OK
except ImportError:
    print("\nWARNING: hyperopt package is not installed. You will be unable to use section 5.")
    pass

#Import decoder functions
from Neural_Decoding.decoders import DenseNNDecoder
from Neural_Decoding.decoders import SVClassification

In [2]:
# Function that reassigns different angles into classes from 1 to 8
# going anticlockwise starting from +x direction
def determine_angle(angle):
    if angle == 0:
        return 1
    elif angle == math.pi/4:
        return 2
    elif angle == math.pi/2:
        return 3
    elif angle == 3*math.pi/4:
        return 4
    elif angle == math.pi:
        return 5
    elif angle == -3*math.pi/4:
        return 6
    elif angle == -math.pi/2:
        return 7
    elif angle == -math.pi/4:
        return 8

In [3]:
# Load data
data_dir = '../raw_data/'
fname = os.path.join(data_dir, "Chewie_CO_CS_2016-10-14.mat")

# load TrialData .mat file into a DataFrame
df = mat2dataframe(fname, shift_idx_fields=True)

# Keep only successful trials
df = select_trials(df, "result == 'R'")

In [4]:
## Classification preprocessing

# Preprocessing
# combine time bins into longer ones, e.g. group 3 time bins together
td_class = combine_time_bins(df, 3)

# Obtain only the interval between idx_target_on and idx_go_cue
td_class = restrict_to_interval(td_class, start_point_name='idx_target_on', end_point_name='idx_go_cue')

# Remove low-firing neurons
td_class = remove_low_firing_neurons(td_class, "M1_spikes",  5)
td_class = remove_low_firing_neurons(td_class, "PMd_spikes", 5)

# total number of trials
N = td_class.shape[0]

#Number of M1_neurons
N_M1 = td_class.M1_spikes[0].shape[1]
#Number of PMd_neurons
N_PMd = td_class.PMd_spikes[0].shape[1]

M1_spikes = np.empty([N_M1,N])
PMd_spikes = np.empty([N_PMd,N])
y = np.empty([N,1])

for i in range(N):
    # Get the neuron spikes for a given trial in train data
    M1_trial = np.transpose(td_class.M1_spikes[i])
    PMd_trial = np.transpose(td_class.PMd_spikes[i])
    
    # Sum all the spikes in the given trial and save them
    M1_spikes[:,i] = np.sum(M1_trial, axis=1)
    PMd_spikes[:,i] = np.sum(PMd_trial, axis=1)
    
    # Get the label
    y[i] = determine_angle(td_class.target_direction[i])
    
# Build a feature vector
F_M1 = np.empty([N, N_M1])
F_PMd = np.empty([N, N_PMd])
for i in range(N):#in range(M1_spikes.shape[1]):
    total_M1_spikes = np.sum(M1_spikes[:,i]);
    total_PMd_spikes = np.sum(PMd_spikes[:,i])
    
    f_M1 = np.transpose(M1_spikes[:,i])/total_M1_spikes
    f_PMd = np.transpose(PMd_spikes[:,i])/total_PMd_spikes
    
    # Store average firing rates
    F_M1[i,:] = f_M1
    F_PMd[i,:] = f_PMd
    
# Combine M1 and PMd features
F_M1_PMd = np.concatenate((F_M1, F_PMd), axis = 1)



In [5]:
# Split the data into test and train subsets
split = int(0.8*N)

y_train = y[0:split-1]
y_test = y[split:]

F_M1_PMd_train = F_M1_PMd[0:split-1,:]
F_M1_PMd_test = F_M1_PMd[split:,:]


## Train classifiers
# Support vector classification
sv_classifier = SVClassification()

sv_classifier.fit(F_M1_PMd_train, np.squeeze(y_train))

In [13]:
## Regression preprocessing
folder='../preprocessed_data/'
#ENTER THE FOLDER THAT YOUR DATA IS IN

with open(folder+'individual_tar_data.pickle','rb') as f:
    M1, M1_PMd,pos_binned,vels_binned,sizes,trial_len=pickle.load(f,encoding='latin1') #If using python 3

neural_data = M1
kinematics = [pos_binned, vels_binned] 

FFNN_models = []
label_means = []
x_flat_mean = []
x_flat_std = []
N_tar = 8
for (idx,output) in enumerate(kinematics):
    start = 0
    end = sum(trial_len[0:sizes[0]])

    # Loop over the data to obtain decoders for all 8 targets
    for tar in tqdm(range(1, N_tar+1)): 

        # Preprocess data
        bins_before=6 #How many bins of neural data prior to the output are used for decoding
        bins_current=1 #Whether to use concurrent time bin of neural data
        bins_after=0 #How many bins of neural data after the output are used for decoding

        # Format for recurrent neural networks (SimpleRNN, GRU, LSTM)
        # Function to get the covariate matrix that includes spike history from previous bins
        X=get_spikes_with_history(neural_data[start:end],bins_before,bins_after,bins_current)

        # Format for Wiener Filter, Wiener Cascade, XGBoost, and Dense Neural Network
        #Put in "flat" format, so each "neuron / time" is a single feature
        X_flat=X.reshape(X.shape[0],(X.shape[1]*X.shape[2]))

        # Output covariates
        #Set decoding output
        y=output[start:end]

        # Split into training / testing / validation sets
        #Set what part of data should be part of the training/testing/validation sets
        training_range=[0, 0.7]
        testing_range=[0.7, 0.85]
        valid_range=[0.85,1]

        # Split data:
        num_examples=X.shape[0]

        #Note that each range has a buffer of"bins_before" bins at the beginning, and "bins_after" bins at the end
        #This makes it so that the different sets don't include overlapping neural data
        training_set=np.arange(int(np.round(training_range[0]*num_examples))+bins_before,int(np.round(training_range[1]*num_examples))-bins_after)
        testing_set=np.arange(int(np.round(testing_range[0]*num_examples))+bins_before,int(np.round(testing_range[1]*num_examples))-bins_after)
        valid_set=np.arange(int(np.round(valid_range[0]*num_examples))+bins_before,int(np.round(valid_range[1]*num_examples))-bins_after)

        #Get training data
        X_train=X[training_set,:,:]
        X_flat_train=X_flat[training_set,:]
        y_train=y[training_set,:]

        #Get testing data
        X_test=X[testing_set,:,:]
        X_flat_test=X_flat[testing_set,:]
        y_test=y[testing_set,:]

        #Get validation data
        X_valid=X[valid_set,:,:]
        X_flat_valid=X_flat[valid_set,:]
        y_valid=y[valid_set,:]

        # Process covariates
        #Z-score "X" inputs. 
       # X_train_mean=np.nanmean(X_train,axis=0)
       # X_train_std=np.nanstd(X_train,axis=0)
       # X_train=(X_train-X_train_mean)/X_train_std
        #X_test=(X_test-X_train_mean)/X_train_std
       # X_valid=(X_valid-X_train_mean)/X_train_std

        #Z-score "X_flat" inputs. 
        #X_flat_train_mean=np.nanmean(X_flat_train,axis=0)
       # X_flat_train_std=np.nanstd(X_flat_train,axis=0)
        #X_flat_train=(X_flat_train-X_flat_train_mean)/X_flat_train_std
        #X_flat_test=(X_flat_test-X_flat_train_mean)/X_flat_train_std
        #X_flat_valid=(X_flat_valid-X_flat_train_mean)/X_flat_train_std
        
        
        #Zero-center outputs
       # y_train_mean=np.mean(y_train,axis=0)
       # y_train=y_train-y_train_mean
        #y_test=y_test-y_train_mean
       # y_valid=y_valid-y_train_mean

        #Do optimization
        # Define parameters for hyperoptimisation
        def dnn_evaluate2(params):
            #Put parameters in proper format
            num_units=int(params['num_units'])
            frac_dropout=float(params['frac_dropout'])
            n_epochs=int(params['n_epochs'])
            model_dnn=DenseNNDecoder(units=[num_units,num_units],dropout=frac_dropout,num_epochs=n_epochs) #Define model
            model_dnn.fit(X_flat_train,y_train) #Fit model
            y_valid_predicted_dnn=model_dnn.predict(X_flat_valid) #Get validation set predictions
            return -np.mean(get_R2(y_valid,y_valid_predicted_dnn)) #Return -R2 value of validation set

        #The range of values I'll look at for the parameter
        #"hp.quniform" will allow us to look at integer (rather than continuously spaced) values.
        #So for "num_units", we are looking at values between 50 and 700 by 10 (50,60,70,...700)
        #"hp.uniform" looks at continuously spaced values
        space = {
            'frac_dropout': hp.uniform('frac_dropout', 0., 0.5),
            'num_units': hp.quniform('num_units', 50,700,10),
            'n_epochs': hp.quniform('n_epochs', 2,15,1),
        }
        #object that holds iteration results
        trials = Trials()
        
        #Set the number of evaluations below (20 in this example)
        hyperoptBest = fmin(dnn_evaluate2, space, algo=tpe.suggest, max_evals=5, trials=trials)        
        

        # Run decoder
        #Declare model
        model_dnn=DenseNNDecoder(units=int(hyperoptBest['num_units']),dropout=hyperoptBest['frac_dropout'],num_epochs=int(hyperoptBest['n_epochs']))

        #Fit model
        model_dnn.fit(X_flat_train,y_train)


       

        #R2_vw = r2_score(y_valid,y_valid_predicted_wf, multioutput='variance_weighted')

        # Save the R2 value for a given neural data and kinematics
        #R2[idx,tar-1] = R2_vw
        FFNN_models.append(model_dnn)
        #label_means.append(y_train_mean)
        #x_flat_mean.append(X_flat_train_mean)
        #x_flat_std.append(X_flat_train_std)
        
        # Find new indexes based on trial_len and sizes variables
        start = end
        new_elements = sum(trial_len[sum(sizes[0:tar]):sum(sizes[0:tar+1])])
        end = end + new_elements

  0%|                                                                                            | 0/8 [00:00<?, ?it/s]


  0%|                                                                            | 0/5 [00:00<?, ?trial/s, best loss=?][A
 20%|██████████▏                                        | 1/5 [00:00<00:02,  1.34trial/s, best loss: 47.07449618129339][A
 40%|████████████████████▍                              | 2/5 [00:03<00:06,  2.05s/trial, best loss: 11.92874929339955][A
 60%|██████████████████████████████▌                    | 3/5 [00:04<00:03,  1.64s/trial, best loss: 11.92874929339955][A
 80%|████████████████████████████████████████▊          | 4/5 [00:06<00:01,  1.47s/trial, best loss: 11.92874929339955][A
100%|███████████████████████████████████████████████████| 5/5 [00:08<00:00,  1.75s/trial, best loss: 9.737791695175348][A


 12%|██████████▌                                                                         | 1/8 [00:10<01:13, 10.56s/it]


  0%|                                                                            | 0/5 [00:00<?, ?trial/s, best loss=?][A
 20%|█████████▊                                       | 1/5 [00:02<00:08,  2.06s/trial, best loss: -0.6555297737601651][A
 40%|███████████████████▌                             | 2/5 [00:03<00:05,  1.75s/trial, best loss: -0.6555297737601651][A
 60%|█████████████████████████████▍                   | 3/5 [00:06<00:04,  2.10s/trial, best loss: -0.6555297737601651][A
 80%|███████████████████████████████████████▏         | 4/5 [00:15<00:05,  5.08s/trial, best loss: -0.6555297737601651][A
100%|█████████████████████████████████████████████████| 5/5 [00:21<00:00,  4.29s/trial, best loss: -0.6555297737601651][A


 25%|█████████████████████                                                               | 2/8 [00:33<01:47, 17.88s/it]


  0%|                                                                            | 0/5 [00:00<?, ?trial/s, best loss=?][A
 20%|█████████▊                                       | 1/5 [00:04<00:19,  4.93s/trial, best loss: -0.2963717227267783][A
 40%|███████████████████▌                             | 2/5 [00:09<00:14,  4.86s/trial, best loss: -0.2963717227267783][A
 60%|█████████████████████████████▍                   | 3/5 [00:11<00:06,  3.43s/trial, best loss: -0.2963717227267783][A
 80%|███████████████████████████████████████▏         | 4/5 [00:17<00:04,  4.27s/trial, best loss: -0.2963717227267783][A
100%|█████████████████████████████████████████████████| 5/5 [00:20<00:00,  4.01s/trial, best loss: -0.2963717227267783][A


 38%|███████████████████████████████▌                                                    | 3/8 [00:56<01:40, 20.03s/it]


  0%|                                                                            | 0/5 [00:00<?, ?trial/s, best loss=?][A
 20%|█████████▊                                       | 1/5 [00:01<00:05,  1.30s/trial, best loss: -0.6201126741845309][A
 40%|███████████████████▌                             | 2/5 [00:02<00:03,  1.16s/trial, best loss: -0.6677966597140033][A
 60%|█████████████████████████████▍                   | 3/5 [00:04<00:02,  1.47s/trial, best loss: -0.6677966597140033][A
 80%|███████████████████████████████████████▏         | 4/5 [00:05<00:01,  1.31s/trial, best loss: -0.6677966597140033][A
100%|█████████████████████████████████████████████████| 5/5 [00:07<00:00,  1.46s/trial, best loss: -0.7312618583277997][A


 50%|██████████████████████████████████████████                                          | 4/8 [01:04<01:02, 15.56s/it]


  0%|                                                                            | 0/5 [00:00<?, ?trial/s, best loss=?][A
 20%|██████████▏                                        | 1/5 [00:02<00:11,  2.84s/trial, best loss: 8.846812300461245][A
 40%|████████████████████▍                              | 2/5 [00:04<00:06,  2.27s/trial, best loss: 8.846812300461245][A
 60%|██████████████████████████████▌                    | 3/5 [00:06<00:04,  2.27s/trial, best loss: 8.492721901624556][A
 80%|████████████████████████████████████████▊          | 4/5 [00:07<00:01,  1.76s/trial, best loss: 8.492721901624556][A
100%|███████████████████████████████████████████████████| 5/5 [00:18<00:00,  3.66s/trial, best loss: 8.492721901624556][A


 62%|████████████████████████████████████████████████████▌                               | 5/8 [01:24<00:51, 17.07s/it]


  0%|                                                                            | 0/5 [00:00<?, ?trial/s, best loss=?][A
 20%|██████████                                        | 1/5 [00:00<00:03,  1.28trial/s, best loss: 0.3801344081359928][A
 40%|███████████████████▌                             | 2/5 [00:04<00:07,  2.66s/trial, best loss: -0.2549645203426927][A
 60%|████████████████████████████▊                   | 3/5 [00:05<00:03,  1.77s/trial, best loss: -0.39919681378563454][A
 80%|███████████████████████████████████████▏         | 4/5 [00:08<00:02,  2.13s/trial, best loss: -0.4561572630450661][A
100%|█████████████████████████████████████████████████| 5/5 [00:10<00:00,  2.10s/trial, best loss: -0.4561572630450661][A


 75%|███████████████████████████████████████████████████████████████                     | 6/8 [01:36<00:30, 15.40s/it]


  0%|                                                                            | 0/5 [00:00<?, ?trial/s, best loss=?][A
 20%|█████████▌                                      | 1/5 [00:04<00:19,  4.88s/trial, best loss: -0.39906487643251715][A
 40%|███████████████████▏                            | 2/5 [00:09<00:13,  4.66s/trial, best loss: -0.39906487643251715][A
 60%|█████████████████████████████▍                   | 3/5 [00:18<00:13,  6.87s/trial, best loss: -0.4155730120567039][A
 80%|███████████████████████████████████████▏         | 4/5 [00:23<00:06,  6.12s/trial, best loss: -0.4155730120567039][A
100%|█████████████████████████████████████████████████| 5/5 [00:25<00:00,  5.09s/trial, best loss: -0.4155730120567039][A


 88%|█████████████████████████████████████████████████████████████████████████▌          | 7/8 [02:06<00:19, 19.93s/it]


  0%|                                                                            | 0/5 [00:00<?, ?trial/s, best loss=?][A
 20%|█████████▊                                       | 1/5 [00:03<00:14,  3.60s/trial, best loss: -0.3872319374652566][A
 40%|███████████████████▌                             | 2/5 [00:04<00:06,  2.30s/trial, best loss: -0.5379320566601329][A
 60%|█████████████████████████████▍                   | 3/5 [00:06<00:03,  1.96s/trial, best loss: -0.5379320566601329][A
 80%|███████████████████████████████████████▏         | 4/5 [00:16<00:05,  5.17s/trial, best loss: -0.5677950301872198][A
100%|█████████████████████████████████████████████████| 5/5 [00:18<00:00,  3.60s/trial, best loss: -0.5677950301872198][A


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [02:27<00:00, 18.48s/it]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s]


  0%|                                                                            | 0/5 [00:00<?, ?trial/s, best loss=?][A
 20%|█████████▊                                       | 1/5 [00:00<00:03,  1.03trial/s, best loss: -0.5894689315887334][A
 40%|███████████████████▌                             | 2/5 [00:02<00:03,  1.11s/trial, best loss: -0.6613976925760454][A
 60%|█████████████████████████████▍                   | 3/5 [00:03<00:02,  1.10s/trial, best loss: -0.6613976925760454][A
 80%|███████████████████████████████████████▏         | 4/5 [00:04<00:01,  1.19s/trial, best loss: -0.6613976925760454][A
100%|█████████████████████████████████████████████████| 5/5 [00:08<00:00,  1.76s/trial, best loss: -0.7108434003024156][A


 12%|██████████▌                                                                         | 1/8 [00:10<01:15, 10.80s/it]


  0%|                                                                            | 0/5 [00:00<?, ?trial/s, best loss=?][A
 20%|█████████▊                                       | 1/5 [00:01<00:04,  1.18s/trial, best loss: -0.9438925259336388][A
 40%|███████████████████▌                             | 2/5 [00:08<00:14,  4.93s/trial, best loss: -0.9438925259336388][A
 60%|█████████████████████████████▍                   | 3/5 [00:14<00:10,  5.34s/trial, best loss: -0.9438925259336388][A
 80%|███████████████████████████████████████▏         | 4/5 [00:17<00:04,  4.45s/trial, best loss: -0.9438925259336388][A
100%|█████████████████████████████████████████████████| 5/5 [00:19<00:00,  3.81s/trial, best loss: -0.9438925259336388][A


 25%|█████████████████████                                                               | 2/8 [00:30<01:36, 16.10s/it]


  0%|                                                                            | 0/5 [00:00<?, ?trial/s, best loss=?][A
 20%|██████████                                        | 1/5 [00:02<00:08,  2.19s/trial, best loss: -0.652520371636373][A
 40%|███████████████████▌                             | 2/5 [00:05<00:07,  2.60s/trial, best loss: -0.6617799687034007][A
 60%|█████████████████████████████▍                   | 3/5 [00:09<00:07,  3.56s/trial, best loss: -0.6691726036945913][A
 80%|███████████████████████████████████████▏         | 4/5 [00:21<00:06,  6.81s/trial, best loss: -0.6691726036945913][A
100%|█████████████████████████████████████████████████| 5/5 [00:24<00:00,  4.92s/trial, best loss: -0.6691726036945913][A


 38%|███████████████████████████████▌                                                    | 3/8 [00:57<01:44, 20.92s/it]


  0%|                                                                            | 0/5 [00:00<?, ?trial/s, best loss=?][A
 20%|█████████▊                                       | 1/5 [00:02<00:09,  2.30s/trial, best loss: -0.8710349679871947][A
 40%|███████████████████▌                             | 2/5 [00:04<00:06,  2.19s/trial, best loss: -0.9012364631382119][A
 60%|█████████████████████████████▍                   | 3/5 [00:06<00:04,  2.06s/trial, best loss: -0.9054024159917806][A
 80%|███████████████████████████████████████▏         | 4/5 [00:07<00:01,  1.81s/trial, best loss: -0.9054024159917806][A
100%|█████████████████████████████████████████████████| 5/5 [00:08<00:00,  1.75s/trial, best loss: -0.9054024159917806][A


 50%|██████████████████████████████████████████                                          | 4/8 [01:07<01:06, 16.51s/it]


  0%|                                                                            | 0/5 [00:00<?, ?trial/s, best loss=?][A
 20%|█████████▊                                       | 1/5 [00:03<00:12,  3.22s/trial, best loss: -0.5981198502778265][A
 40%|███████████████████▌                             | 2/5 [00:04<00:05,  1.79s/trial, best loss: -0.6095930643815364][A
 60%|█████████████████████████████▍                   | 3/5 [00:06<00:04,  2.10s/trial, best loss: -0.6800973516975946][A
 80%|███████████████████████████████████████▏         | 4/5 [00:09<00:02,  2.57s/trial, best loss: -0.6800973516975946][A
100%|█████████████████████████████████████████████████| 5/5 [00:15<00:00,  3.12s/trial, best loss: -0.7039435327470476][A


 62%|████████████████████████████████████████████████████▌                               | 5/8 [01:25<00:52, 17.34s/it]


  0%|                                                                            | 0/5 [00:00<?, ?trial/s, best loss=?][A
 20%|█████████▊                                       | 1/5 [00:02<00:08,  2.10s/trial, best loss: -0.8894319363635703][A
 40%|███████████████████▌                             | 2/5 [00:03<00:04,  1.56s/trial, best loss: -0.8894319363635703][A
 60%|█████████████████████████████▍                   | 3/5 [00:04<00:02,  1.45s/trial, best loss: -0.8894319363635703][A
 80%|███████████████████████████████████████▏         | 4/5 [00:06<00:01,  1.62s/trial, best loss: -0.8894319363635703][A
100%|█████████████████████████████████████████████████| 5/5 [00:08<00:00,  1.64s/trial, best loss: -0.8894319363635703][A


 75%|███████████████████████████████████████████████████████████████                     | 6/8 [01:35<00:29, 14.57s/it]


  0%|                                                                            | 0/5 [00:00<?, ?trial/s, best loss=?][A
 20%|█████████▊                                       | 1/5 [00:03<00:13,  3.40s/trial, best loss: -0.7791719587393315][A
 40%|███████████████████▌                             | 2/5 [00:06<00:10,  3.40s/trial, best loss: -0.7791719587393315][A
 60%|█████████████████████████████▍                   | 3/5 [00:09<00:05,  2.96s/trial, best loss: -0.7791719587393315][A
 80%|███████████████████████████████████████▏         | 4/5 [00:10<00:02,  2.31s/trial, best loss: -0.7791719587393315][A
100%|█████████████████████████████████████████████████| 5/5 [00:12<00:00,  2.51s/trial, best loss: -0.7791719587393315][A


 88%|█████████████████████████████████████████████████████████████████████████▌          | 7/8 [01:49<00:14, 14.60s/it]


  0%|                                                                            | 0/5 [00:00<?, ?trial/s, best loss=?][A
 20%|█████████▊                                       | 1/5 [00:03<00:15,  3.78s/trial, best loss: -0.8779999233390507][A
 40%|███████████████████▌                             | 2/5 [00:06<00:10,  3.42s/trial, best loss: -0.9147427130905266][A
 60%|█████████████████████████████▍                   | 3/5 [00:11<00:07,  3.92s/trial, best loss: -0.9147427130905266][A
 80%|███████████████████████████████████████▏         | 4/5 [00:15<00:03,  3.90s/trial, best loss: -0.9147427130905266][A
100%|█████████████████████████████████████████████████| 5/5 [00:18<00:00,  3.74s/trial, best loss: -0.9147427130905266][A


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [02:10<00:00, 16.32s/it]


In [17]:
print(len(FFNN_models))

16


In [11]:
# Import the whole data set
# combine time bins into longer ones
td_full = combine_time_bins(df, 3)

# Remove low-firing neurons
td_full = remove_low_firing_neurons(td_full, "M1_spikes",  5)
td_full = remove_low_firing_neurons(td_full, "PMd_spikes", 5)

# Get the signal from idx_go_cue
df.idx_movement_on = df.idx_movement_on.astype(int)
td_full = restrict_to_interval(td_full, start_point_name='idx_go_cue', end_point_name='idx_trial_end')


td_full = smooth_signals(td_full, ["M1_spikes", "PMd_spikes"], std=0.05)
 



In [19]:
## Make predictions
N_trials = 740
start = 0
kinematics = [td_full.pos, td_full.vel]

for (idx,output) in enumerate(kinematics):
    predictions = []
    y_valid_full = []
    y_pred_full = []
    for i in range(N_trials-int(0.8*N_trials)):
        #end = start + trial - 1
        #print(int(0.8*N_trials),i)
        neural_data = td_full.M1_spikes[int(0.8*N_trials)+i]
        y_valid = output[int(0.8*N_trials)+i]

       # class_prediction = y_test[i]
        class_prediction =sv_classifier.predict([F_M1_PMd_test[i]])
        #print(class_prediction)
        predictions.append(class_prediction)



        # Preprocess data
        bins_before=6 #How many bins of neural data prior to the output are used for decoding
        bins_current=1 #Whether to use concurrent time bin of neural data
        bins_after=0 #How many bins of neural data after the output are used for decoding

        # Format for recurrent neural networks (SimpleRNN, GRU, LSTM)
        # Function to get the covariate matrix that includes spike history from previous bins
        X=get_spikes_with_history(neural_data,bins_before,bins_after,bins_current)

        # Format for Wiener Filter, Wiener Cascade, XGBoost, and Dense Neural Network
        #Put in "flat" format, so each "neuron / time" is a single feature
        #X_flat_final=X.reshape(X.shape[0],(X.shape[1]*X.shape[2]))


        X_flat_final = np.nan_to_num(X_flat_final)

       # X_flat_final = (X_flat_final-x_flat_mean[int(class_prediction-1)])/x_flat_std[int(class_prediction-1)]

       # y_valid = y_valid-label_means[int(class_prediction-1)]

        # Avoid some errors
        if X_flat_final.shape[0] == 0:
            continue

        y_valid_predicted = FFNN_models[int(class_prediction-1)+8*idx].predict(X_flat_final)


        y_valid_full.append(y_valid)
        y_pred_full.append(y_valid_predicted)

        # Update the starting point of the data for next iteration of the loop
        start = end + 1

        #y_pred_plot = y_valid_predicted
        #plt.plot(np.transpose(y_pred_plot[:,0]), np.transpose(y_pred_plot[:,1]))
    
    y_valid_full = np.array(y_valid_full)
    y_pred_full = np.array(y_pred_full)


    for i in range(y_valid_full.shape[0]):
        if i == 0:
            y_val = np.array(np.squeeze(y_valid_full[i]))
            y_pred = np.array(np.squeeze(y_pred_full[i]))
        else:
            y_val = np.concatenate((y_val, np.squeeze(y_valid_full[i])), axis=0)
            y_pred = np.concatenate((y_pred, np.squeeze(y_pred_full[i])), axis=0)

    R2_vw = r2_score(y_val,y_pred, multioutput='variance_weighted')
    print('R2 value:', R2_vw)

ValueError: in user code:

    c:\users\jon\appdata\local\programs\python\python37\lib\site-packages\tensorflow\python\keras\engine\training.py:1478 predict_function  *
        return step_function(self, iterator)
    c:\users\jon\appdata\local\programs\python\python37\lib\site-packages\tensorflow\python\keras\engine\training.py:1468 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    c:\users\jon\appdata\local\programs\python\python37\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:1259 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    c:\users\jon\appdata\local\programs\python\python37\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2730 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    c:\users\jon\appdata\local\programs\python\python37\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:3417 _call_for_each_replica
        return fn(*args, **kwargs)
    c:\users\jon\appdata\local\programs\python\python37\lib\site-packages\tensorflow\python\keras\engine\training.py:1461 run_step  **
        outputs = model.predict_step(data)
    c:\users\jon\appdata\local\programs\python\python37\lib\site-packages\tensorflow\python\keras\engine\training.py:1434 predict_step
        return self(x, training=False)
    c:\users\jon\appdata\local\programs\python\python37\lib\site-packages\tensorflow\python\keras\engine\base_layer.py:998 __call__
        input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)
    c:\users\jon\appdata\local\programs\python\python37\lib\site-packages\tensorflow\python\keras\engine\input_spec.py:259 assert_input_compatibility
        ' but received input with shape ' + display_shape(x.shape))

    ValueError: Input 0 of layer sequential_109 is incompatible with the layer: expected axis -1 of input shape to have value 441 but received input with shape (None, 7, 63)
