# In [11]:
# testing starmap

# imports: these are all needed for fnc
import numpy as np
import matplotlib.pyplot as plt
import sys
import os
from fnc_fit_and_score import fnc_fit_and_score
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.svm import SVC  
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import make_classification
from scipy.optimize import curve_fit
import multiprocessing
from tqdm import tqdm 
# ---- What we are decoding: settings identifying the RNN / task variant ----
RNN_params = {
    'prob_split': '70_30',
    'afc': 2,
    'coh': 'hi',
    'feedback': False,
    'thresh': [.3, .7],
}

# ---- Decoding settings ----
D_params = {
    'time_avg': False,
    't_win': [200, -1],
    'n_cvs': 5,
    'num_cgs': 30,
    'label': 'stim',  # 'stim' or 'choice'
}

# ---- Where the data lives ----
# The data root differs between linux and every other platform; the rest of
# the path is built from the RNN settings above.
data_dir = (
    f"/mnt/neurocube/local/serenceslab/holly/RNN_Geo/data/rdk_{RNN_params['prob_split']}_{RNN_params['afc']}afc/feedforward_only/{RNN_params['coh']}_coh"
    if sys.platform.startswith('linux')
    else f"/Volumes/serenceslab/holly/RNN_Geo/data/rdk_{RNN_params['prob_split']}_{RNN_params['afc']}afc/feedforward_only/{RNN_params['coh']}_coh"
)

# Edit this file name to use a different number of trials or a different
# default stim (0 is the one the RNN was trained on).
data_file = f"{data_dir}/Trials200_0expected.npz"

# ---- Task timing ----
task_info = {
    'trials': 200,
    'trial_dur': 250,  # trial duration (timesteps)
    'stim_on': 80,
    'stim_dur': 50,
}
save_plt = False

# Unpack the parameter dicts into plain variables. Each lookup falls back to
# a default when the key is missing from its dict.
# NOTE: the 'prob_split' key was previously misspelled as 'prb_split', so the
# configured value was never read and the default was always used.
prob_split = RNN_params.get('prob_split', '70_30')
afc = RNN_params.get('afc', 2)
coh = RNN_params.get('coh', 'hi')
feedback = RNN_params.get('feedback', False)
thresh = RNN_params.get('thresh', [.3, .7])
time_avg = D_params.get('time_avg', False)
t_win = D_params.get('t_win', [200, -1])
label = D_params.get('label', 'stim')
n_cvs = D_params.get('n_cvs', 5)
num_cgs = D_params.get('num_cgs', 30)
# Penalty (C) values to evaluate: num_cgs points, log-spaced from 1e-5 to 1e1
Cs = np.logspace(-5, 1, num=num_cgs)

# Accuracy storage: length n_cvs, NaN-filled until written
acc = np.full(n_cvs, np.nan)

# Search grid: sweep C with a linear kernel only
param_grid = dict(C=Cs, kernel=['linear'])

# Grid-search object wrapping an SVC with balanced class weights (the classes
# are biased, e.g. 70/30). CV folds could also be specified here, but per the
# original note they are handled by hand in a loop instead.
grid = GridSearchCV(
    estimator=SVC(class_weight='balanced'),
    param_grid=param_grid,
    refit=True,
    verbose=0,
)

# Open the saved trials archive
data = np.load(data_file)

# Variables used for decoding
labs = np.squeeze(data['labs'])
data_d = data['fr1']  # layer 1 firing rate [trial x time step x unit] matrix

# Basic bookkeeping derived from the data
tris = len(data_d)             # number of trials (first axis)
tri_ind = np.arange(tris)      # indices 0...tris-1
hold_out = int(tris / n_cvs)   # how many trials to hold out
# placeholder until the decoding results are computed
decoding_acc = np.nan





if __name__ == "__main__":
    # Use roughly 70% of the available CPUs, but never fewer than one worker.
    # os.cpu_count() may return None when the count cannot be determined.
    n_workers = max(1, round((os.cpu_count() or 1) * .7))

    # One task per time step: each task carries that step's slice of the
    # data (all trials, all units) plus the shared CV / labelling settings.
    tasks = [
        (t_step, data_d[:, t_step, :], tri_ind, hold_out, n_cvs, labs, label, thresh, grid)
        for t_step in range(task_info['trial_dur'])
    ]

    # Only the `multiprocessing` module is imported in this file, so the pool
    # class must be referenced through it (a bare `Pool` raises NameError).
    # The `with` block tears the pool down once starmap has returned.
    with multiprocessing.Pool(processes=n_workers) as pool:
        results = pool.starmap(fnc_fit_and_score, tasks, chunksize=10)

    # Process the results from each worker process (expected to be one
    # sequence of accuracies per time step) into a mean accuracy per step.
    decoding_acc = np.mean(np.array(results), axis=1)


# Output: Decoding Progress:   0%|                                | 0/250 [01:17<?, ?it/s]
