Name: Holly Kular\
Date: 03-19-2024\
Email: hkular@ucsd.edu\
decode_L1.m\
Description: wrapper script for decoding analysis on layer 1 of probabilistic RNN

In [1]:
# Imports: these are all needed for fnc
import numpy as np
import matplotlib.pyplot as plt
import sys
import os

from sklearn.model_selection import cross_val_score
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.svm import SVC  
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import make_classification
from scipy.optimize import curve_fit

In [2]:
# add decoding functions
from fnc_fit_and_score import fnc_fit_and_score
from fnc_decode_L1 import fnc_decode_L1


In [3]:
# MODIFY HERE - inputs

# Parameters describing what is decoded; prob_split, afc and coh also
# select the data directory built below.
RNN_params = {
    'prob_split': '70_30',
    'afc': 2,
    'coh': 'hi',
    'feedback': False,
    'thresh': [.3, .7],
}

# Settings handed to the decoding function
D_params = {
    'time_avg': False,
    't_win': [200, -1],
    'n_cvs': 5,
    'num_cgs': 30,
    'label': 'stim',  # 'stim' or 'choice'
}

# Data location: the directory root differs between Linux and other platforms
data_dir = (
    f"/mnt/neurocube/local/serenceslab/holly/RNN_Geo/data/rdk_{RNN_params['prob_split']}_{RNN_params['afc']}afc/feedforward_only/{RNN_params['coh']}_coh"
    if sys.platform.startswith('linux')
    else f"/Volumes/serenceslab/holly/RNN_Geo/data/rdk_{RNN_params['prob_split']}_{RNN_params['afc']}afc/feedforward_only/{RNN_params['coh']}_coh"
)

# Edit the file name to use a different number of trials or a different
# default stim (0 is the one the RNN was trained on)
data_file = f"{data_dir}/Trials200_0expected.npz"

# Timing of task
task_info = {
    'trials': 200,
    'trial_dur': 250,  # trial duration (timesteps)
    'stim_on': 80,
    'stim_dur': 50,
}

# Set to True to write the figures to data_dir
save_plt = False

In [6]:
# Run the decoding analysis with the inputs configured above
acc = fnc_decode_L1(
    RNN_params,
    D_params,
    data_file,
    task_info,
)

In [5]:
# Display the value returned by fnc_decode_L1 as this cell's output
acc

nan

In [None]:
# Make plots based on outputs

# Plot decoding accuracy over time.
# `acc` is the result returned by fnc_decode_L1 in the decoding cell.
# NOTE(review): assumes `acc` holds one accuracy value per timestep
# (length task_info['trial_dur']) — confirm against fnc_decode_L1.
timesteps = np.arange(task_info['trial_dur'])

fig, ax = plt.subplots()
ax.plot(timesteps, acc)
ax.set_xlabel('Time Step')
ax.set_ylabel('Decoding Accuracy')
ax.set_title('Decoding Accuracy Over Time')
# Shade the interval from stimulus onset to stimulus offset
ax.axvspan(task_info['stim_on'], task_info['stim_on'] + task_info['stim_dur'], color='gray', alpha=0.3)
if save_plt:
    fig.savefig(f"{data_dir}/decode_stim_exp.png")
plt.show()


# fit logistic function
def logistic_func(x, a, b, c, d):
    """Four-parameter logistic curve: a / (1 + exp(-c * (x - d))) + b.

    Parameters
    ----------
    x : scalar or array-like
        Point(s) at which to evaluate the curve. Sequences such as `range`
        or lists are accepted and converted to a float array.
    a : float
        Height of the sigmoid (difference between the two asymptotes).
    b : float
        Vertical offset added to the sigmoid.
    c : float
        Slope factor multiplying (x - d) inside the exponential.
    d : float
        Location of the midpoint; at x == d the value is a / 2 + b.

    Returns
    -------
    numpy.float64 or numpy.ndarray
        Curve value(s), with the same shape as `x`.
    """
    # Coerce first so that non-array sequences (range, list) support `x - d`
    x = np.asarray(x, dtype=float)
    return a / (1 + np.exp(-c * (x - d))) + b

# Fit the logistic curve to the decoding accuracy returned by fnc_decode_L1.
# NOTE(review): assumes `acc` holds one accuracy value per timestep
# (length task_info['trial_dur']) — confirm against fnc_decode_L1.
fit_timesteps = np.arange(task_info['trial_dur'])
acc_values = np.asarray(acc, dtype=float)

# NOTE(review): no initial guess (p0) is supplied, so curve_fit uses its
# default starting point; consider a data-driven p0 if the fit fails to converge.
popt, _ = curve_fit(logistic_func, fit_timesteps, acc_values)

# 'popt' contains the four fitted parameters (a, b, c, d) of logistic_func

# Optional: Estimate the y-value of the asymptote (taken from parameter b).
# NOTE(review): for c > 0, b is the value the curve approaches as x -> -inf,
# while a + b is the limit as x -> +inf; confirm which one is intended here.
asymptote_y = popt[1]
print("Estimated y-value of the asymptote:", asymptote_y)

# Optional: Visualization of the data together with the fitted curve
fit_fig, fit_ax = plt.subplots()
fit_ax.plot(fit_timesteps, acc_values, label='Data')
fit_ax.plot(fit_timesteps, logistic_func(fit_timesteps, *popt), label=f'y asymptote of logistic {round(asymptote_y, 3)}')
# Shade the interval from stimulus onset to stimulus offset
fit_ax.axvspan(task_info['stim_on'], task_info['stim_on'] + task_info['stim_dur'], color='gray', alpha=0.3)
fit_ax.set_xlabel('Time Step')
fit_ax.set_ylabel('Decoding Accuracy')
fit_ax.set_title('Decoding Accuracy Over Time (Logistic Fit)')
fit_ax.legend()
if save_plt:
    fit_fig.savefig(f"{data_dir}/decode_stim_expfit.png")
plt.show()
