In [151]:
# Code used to test different classifiers on the CO task data
import pandas as pd
import numpy as np
import xarray as xr
import math

import os
from tqdm.auto import tqdm

import matplotlib.pyplot as plt
import seaborn as sns

from pyaldata import *

#Import decoder functions
from Neural_Decoding.decoders import WienerFilterClassification
from Neural_Decoding.decoders import SVClassification
from Neural_Decoding.decoders import DenseNNClassification
from Neural_Decoding.decoders import LSTMClassification

In [123]:
# Load data
data_dir = '../raw_data/'
fname = os.path.join(data_dir, "Chewie_CO_CS_2016-10-14.mat")

# load TrialData .mat file into a DataFrame
df = mat2dataframe(fname, shift_idx_fields=True)

In [143]:
# Keep only successful trials
df = select_trials(df, "result == 'R'")

# Preprocessing
# combine time bins into longer ones, e.g. group 3 time bins together
td = combine_time_bins(df, 3)

# Obtain only the interval between idx_target_on and idx_go_cue
print('Orig:', td.M1_spikes[0].shape)
td = restrict_to_interval(td, start_point_name='idx_target_on', end_point_name='idx_go_cue')
print('New:', td.M1_spikes[0].shape)

# Remove low-firing neurons
print('Orig:', td.M1_spikes[0].shape)
td = remove_low_firing_neurons(td, "M1_spikes",  5)
td = remove_low_firing_neurons(td, "PMd_spikes", 5)
print('New:', td.M1_spikes[0].shape)

# Combine M1 and PMd
td = merge_signals(td, ["M1_spikes", "PMd_spikes"], "both_spikes")

Orig: (119, 88)
New: (40, 88)
Orig: (40, 88)
New: (40, 57)




In [144]:
# Split the data into training and testing subsets
# total number of trials
N = td.shape[0]

#Number of M1_neurons
N_M1 = td.M1_spikes[0].shape[1]
#Number of PMd_neurons
N_PMd = td.PMd_spikes[0].shape[1]

M1_spikes = np.empty([N_M1,N])
PMd_spikes = np.empty([N_PMd,N])
y = np.empty([N,1])

for i in range(N):
    # Get the neuron spikes for a given trial in train data
    M1_trial = np.transpose(td.M1_spikes[i])
    PMd_trial = np.transpose(td.PMd_spikes[i])
    
    # Sum all the spikes in the given trial and save them
    M1_spikes[:,i] = np.sum(M1_trial, axis=1)
    PMd_spikes[:,i] = np.sum(PMd_trial, axis=1)
    
    # Get the label
    y[i] = determine_angle(td.target_direction[i])



In [145]:
# Build a feature vector
F_M1 = np.empty([N, N_M1])
F_PMd = np.empty([N, N_PMd])
for i in range(N):#in range(M1_spikes.shape[1]):
    total_M1_spikes = np.sum(M1_spikes[:,i]);
    total_PMd_spikes = np.sum(PMd_spikes[:,i])
    
    f_M1 = np.transpose(M1_spikes[:,i])/total_M1_spikes
    f_PMd = np.transpose(PMd_spikes[:,i])/total_PMd_spikes
    
    # Store average firing rates
    F_M1[i,:] = f_M1
    F_PMd[i,:] = f_PMd
    
# Additional combined feature vector
F_M1_PMd = np.concatenate((F_M1, F_PMd), axis = 1)

In [146]:
# Split the data into test and train subsets
split = int(0.8*N)


y_train = y[0:split-1]
y_test = y[split:]

F_PMd_train = F_PMd[0:split-1,:]
F_PMd_test = F_PMd[split:,:]

F_M1_PMd_train = F_M1_PMd[0:split-1,:]
F_M1_PMd_test = F_M1_PMd[split:,:]

print(np.squeeze(y_train).shape)

(591,)


In [147]:
## Train classifiers
wf_classifier = WienerFilterClassification()

wf_classifier.fit(F_M1_PMd_train, np.squeeze(y_train))

wf_prediction = wf_classifier.predict(F_M1_PMd_test)

check_wf = wf_prediction==np.squeeze(y_test)
accuracy = np.count_nonzero(check_wf)/(y_test.shape[0])
print('Accuracy:', accuracy)

Accuracy: 0.2635135135135135


In [148]:
## Train classifiers
# Support vector classification
sv_classifier = SVClassification()

sv_classifier.fit(F_M1_PMd_train, np.squeeze(y_train))

sv_prediction = sv_classifier.predict(F_M1_PMd_test)

check_sv = sv_prediction==np.squeeze(y_test)
accuracy = np.count_nonzero(check_sv)/(y_test.shape[0])
print('Accuracy:', accuracy)

Accuracy: 0.8445945945945946


In [149]:
# DenseNN classifier
dnn_classifier = DenseNNClassification()

dnn_classifier.fit(F_PMd_train, np.squeeze(y_train))
dnn_prediction = dnn_classifier.predict(F_PMd_test)

check_dnn = dnn_prediction==np.squeeze(y_test)
accuracy_dnn = np.count_nonzero(check_dnn)/(y_test.shape[0])
print('Accuracy:', accuracy_dnn)

Accuracy: 0.3581081081081081
