In [None]:
from sklearn.linear_model import LinearRegression, Lasso, Ridge
import numpy as np
import os
import seaborn as sns
import sys
from functions import *

# Two-entry hex colour palette.
# NOTE(review): not referenced in any of the visible cells, and re-assigned with the same value
# in the plotting cell -- presumably intended for control vs. SST-silenced; confirm or remove.
colors = ['#b3e2cd', '#fdcdac']

# Generate Figure 7g: the r^2 value before and after silencing of SST neurons

In [None]:
def beta_conj(rnn_activity, stims, rules, resps, hp_task_test, hp_test):
    """ Compute the strength of linear and conjunctive (stimulus x rule) coding for each neuron.

    For every neuron, the activity averaged over the response window is first
    regressed (no intercept) onto one-hot stimulus and rule regressors. A second
    regression (no intercept) then fits one-hot stimulus-x-rule regressors to
    the residual left by the first fit.

    Parameters
    ----------
    rnn_activity : np.ndarray
        Indexed as rnn_activity[trial, time_step, 0, neuron]; only index 0 of the
        third axis is read (presumably the batch dimension -- TODO confirm).
    stims : sequence
        Per-trial stimulus, matched with == against tuples
        (ref_card, card1, card2, card3), where every card is one of
        (0, 0), (0, 1), (1, 0), (1, 1) and the three test cards are a permutation
        of the cards that differ from the reference card.
    rules : sequence
        Per-trial rule, matched with == against 'color' and 'shape'.
    resps : sequence
        Per-trial responses. Kept for interface compatibility; the response is
        not a regressor in either model, so this argument is not read.
    hp_task_test : dict
        Must provide 'resp_start' and 'resp_end' (bounds of the response window).
    hp_test : dict
        Must provide 'dt'; the window bounds are floor-divided by it to get time-step indices.

    Returns
    -------
    betas : list
        Flat list; for each neuron in order, the 26 coefficients of the linear
        model (24 stimulus regressors followed by 2 rule regressors).
    beta_conjs : list
        Flat list; for each neuron in order, the 48 coefficients of the
        conjunctive model (column index = rule_index*24 + stimulus_index).
    rsqrs : list
        One r^2 per neuron for the linear model (0 if the activity is constant across trials).
    rsqr_conjs : list
        One r^2 per neuron for the conjunctive model fitted to the residual
        (0 if the residual is constant across trials).

    Trials whose stimulus or rule matches none of the known values get all-zero
    rows in the corresponding regressors.
    """
    # imported locally so the function does not depend on `from functions import *` re-exporting it
    import itertools

    # setup
    n_trials = rnn_activity.shape[0]
    n_neurons = rnn_activity.shape[-1]
    n_stims = 4*3*2    # 4 ref cards x 3! orderings of the remaining three test cards
    n_rules = 2
    flat_tol = 1e-10    # std below which a signal is treated as constant across trials

    rule_dict = {'color': 0, 'shape': 1}

    # map every possible stimulus to a column index
    cards = [(0, 0), (0, 1), (1, 0), (1, 1)]
    stim_dict = {}
    for ref_card in cards:
        test_cards = [c for c in cards if c != ref_card]
        for (card1, card2, card3) in itertools.permutations(test_cards):
            stim_dict[(ref_card, card1, card2, card3)] = len(stim_dict)

    # one-hot design matrices
    X_stim = np.zeros([n_trials, n_stims])
    X_rule = np.zeros([n_trials, n_rules])
    X_stimxrule = np.zeros([n_trials, n_stims*n_rules])

    for rule, r in rule_dict.items():
        rule_trials = [i for i in range(len(rules)) if rules[i] == rule]
        X_rule[rule_trials, r] = 1
    for stim, s in stim_dict.items():
        stim_trials = [i for i in range(len(stims)) if stims[i] == stim]
        X_stim[stim_trials, s] = 1
        # the conjunction regressor reuses the trials already matched to this stimulus
        for rule, r in rule_dict.items():
            both_trials = [i for i in stim_trials if rules[i] == rule]
            X_stimxrule[both_trials, r*n_stims + s] = 1

    # time-averaged activity during the response window
    # (cast to int: `//` yields floats when dt or the bounds are floats, and float indices are rejected)
    dt = hp_test['dt']
    ts = np.arange(int(hp_task_test['resp_start']//dt), int(hp_task_test['resp_end']//dt))
    y = np.mean(rnn_activity[:, ts, 0, :], axis=1)

    X = np.concatenate([X_stim, X_rule], axis=1)
    betas = []
    beta_conjs = []
    rsqrs = []
    rsqr_conjs = []

    for n in range(n_neurons):
        activity = y[:, n]

        # linear model: stimulus + rule
        reg = LinearRegression(fit_intercept=False).fit(X=X, y=activity)
        if np.std(activity) <= flat_tol:
            # a constant signal is fit perfectly and score() would report 1; report 0 instead
            r_sqr = 0
        else:
            r_sqr = reg.score(X=X, y=activity)

        # conjunctive model fitted to what the linear model leaves unexplained
        residual_activity = activity - reg.predict(X)
        reg_conj = LinearRegression(fit_intercept=False).fit(X=X_stimxrule, y=residual_activity)
        if np.std(residual_activity) <= flat_tol:
            r_sqr_conj = 0
        else:
            r_sqr_conj = reg_conj.score(X=X_stimxrule, y=residual_activity)

        rsqrs.append(r_sqr)
        rsqr_conjs.append(r_sqr_conj)
        betas.extend(reg.coef_)
        beta_conjs.extend(reg_conj.coef_)

    return betas, beta_conjs, rsqrs, rsqr_conjs

In [None]:
# Fit the linear / conjunctive regressions for every selected model, with and without SST
# silencing, and collect the per-neuron results in `all_data` and the pooled *_allmodels lists.
# NOTE(review): `time`, `pickle`, `torch`, `HiddenPrints` and `load_model_v2` are not imported
# explicitly in the visible cells -- presumably re-exported by `from functions import *`; confirm.
start = time.time()


def _trial_labels(test_data):
    """Return (rules, resps, stims) for the trials stored in a saved `test_data` dict.

    rules : test_data['rules'], unchanged.
    resps : per trial, the first index at which element 0 of the stored response equals 1.
    stims : per trial, a tuple of four (color, shape) pairs: the center card followed by
            test cards 0, 1 and 2, read from element 0 of the stored stimulus.
    Element 0 is presumably the batch index -- TODO confirm.
    """
    rules = test_data['rules']
    resps = [torch.where(resp[0]==1)[0].numpy()[0] for resp in test_data['resps']]
    stims = [((stim[0]['center_card']['color'], stim[0]['center_card']['shape']),
              (stim[0]['test_cards'][0]['color'], stim[0]['test_cards'][0]['shape']),
              (stim[0]['test_cards'][1]['color'], stim[0]['test_cards'][1]['shape']),
              (stim[0]['test_cards'][2]['color'], stim[0]['test_cards'][2]['shape'])) for stim in test_data['stims']]
    return rules, resps, stims


all_data = []    # one result dict per analysed model

model_dir = ''    # directory holding the trained models -- must be filled in (os.listdir('') raises FileNotFoundError)
test_data_dir = '/where/test/data/is/stored/'    # directory holding the pickled test data
models = sorted(os.listdir(model_dir))
betas_allmodels = []
betas_allmodels_nosst = []
beta_conjs_allmodels = []
beta_conjs_allmodels_nosst = []
rsqrs_allmodels = []
rsqrs_nosst_allmodels = []
rsqr_conjs_allmodels = []
rsqr_conjs_nosst_allmodels = []

n_models = 0    # number of models that passed every filter below
for model_name in models:
    # only analyse the runs from this date that are marked as successful
    if not ('2023-05-10' in model_name and 'success' in model_name):
        continue
    print(model_name+'\n')

    # load model
    path_to_file = os.path.join(model_dir, model_name)    # robust to a model_dir without trailing separator
    with HiddenPrints():
        model, hp_test, hp_task_test, optim, saved_data = load_model_v2(path_to_file=path_to_file, model_name=model_name,
                                                                        simple=False, plot=False, toprint=False)

    # load test data (control condition)
    # NOTE: pickle.load can execute arbitrary code -- only open files from a trusted source
    with open(os.path.join(test_data_dir, model_name+'_testdata_noiseless'), 'rb') as f:
        neural_data = pickle.load(f)
    test_data = neural_data['test_data']
    mean_perf = np.mean([_[0] for _ in test_data['perfs']])
    mean_perf_rule = np.mean([_[0] for _ in test_data['perf_rules']])
    if mean_perf<0.8 or mean_perf_rule<0.8:
        print('low performing model ({}/{})'.format(mean_perf, mean_perf_rule))
        continue
    rnn_activity = neural_data['rnn_activity'].detach().cpu().numpy()
    neuron_idx = model.rnn.cg_idx['sr_esoma'].tolist()    # restrict the analysis to the 'sr_esoma' units
    rnn_activity = rnn_activity[:, :, :, neuron_idx]

    # load test data recorded with SST cells silenced
    with open(os.path.join(test_data_dir, model_name+'_testdata_silenceSRSST_noiseless'), 'rb') as f:
        neural_data_nosst = pickle.load(f)
    test_data_nosst = neural_data_nosst['test_data']
    # skip the model if any silenced trial has an all-zero response vector
    if (np.sum(np.array([_[0].numpy() for _ in test_data_nosst['resps']]), axis=1)==0).any():
        print('During some trials where SST cells are silenced, the network does not make a choice, pass')
        continue
    # NOTE(review): the next two values are not used in the visible cells; kept in case later cells read them
    mean_perf_nosst = np.mean([_[0] for _ in test_data_nosst['perfs']])
    mean_perf_rule_nosst = np.mean([_[0] for _ in test_data_nosst['perf_rules']])
    rnn_activity_nosst = neural_data_nosst['rnn_activity'].detach().cpu().numpy()
    rnn_activity_nosst = rnn_activity_nosst[:, :, :, neuron_idx]

    #===== analysis =====#
    # control
    rules, resps, stims = _trial_labels(test_data)
    betas, beta_conjs, rsqrs, rsqr_conjs = beta_conj(rnn_activity=rnn_activity, stims=stims, rules=rules, resps=resps, hp_task_test=hp_task_test, hp_test=hp_test)
    betas_allmodels.extend(betas)
    beta_conjs_allmodels.extend(beta_conjs)
    rsqrs_allmodels.extend(rsqrs)
    rsqr_conjs_allmodels.extend(rsqr_conjs)

    # silence SST
    rules_nosst, resps_nosst, stims_nosst = _trial_labels(test_data_nosst)
    betas_nosst, beta_conjs_nosst, rsqrs_nosst, rsqr_conjs_nosst = beta_conj(rnn_activity=rnn_activity_nosst, stims=stims_nosst, rules=rules_nosst, resps=resps_nosst, hp_task_test=hp_task_test, hp_test=hp_test)
    betas_allmodels_nosst.extend(betas_nosst)
    beta_conjs_allmodels_nosst.extend(beta_conjs_nosst)
    rsqrs_nosst_allmodels.extend(rsqrs_nosst)
    rsqr_conjs_nosst_allmodels.extend(rsqr_conjs_nosst)

    n_models += 1

    all_data.append({'model_name': model_name,
                     'hp': hp_test,
                     'betas': betas,
                     'beta_conjs': beta_conjs,
                     'rsqrs': rsqrs,
                     'rsqr_conjs': rsqr_conjs,
                     'betas_nosst': betas_nosst,
                     'beta_conjs_nosst': beta_conjs_nosst,
                     'rsqrs_nosst': rsqrs_nosst,
                     'rsqr_conjs_nosst': rsqr_conjs_nosst})


print('Elapsed time: {}s'.format(time.time()-start))

In [None]:
# Pool the per-neuron conjunctive r^2 values (control and SST-silenced) from every model
# whose 'dend_nonlinearity' hyperparameter is 'divisive_2'.
rsqr_conjs_selectmodels = []
rsqr_conjs_nosst_selectmodels = []

for data in all_data:
    if data['hp']['dend_nonlinearity']!='divisive_2':
        continue
    rsqr_conjs_selectmodels += data['rsqr_conjs']
    rsqr_conjs_nosst_selectmodels += data['rsqr_conjs_nosst']

# Paired plot: one faint line per neuron, from its control value (x=0) to its SST-silenced value (x=1).
y = [rsqr_conjs_selectmodels, rsqr_conjs_nosst_selectmodels]
colors = ['#b3e2cd', '#fdcdac']
fig, ax = plt.subplots(figsize=[2, 3])
line_style = dict(color='k', alpha=0.01, marker='o')
for i, r2_control in enumerate(y[0]):
    ax.plot([0, 1], [r2_control, y[1][i]], **line_style)
make_pretty_axes(ax)
ax.set_xlim([-0.2, 1.2])
plt.show()

# One-sided test that the control values are greater than the SST-silenced ones.
# NOTE(review): the two samples are plotted as per-neuron pairs, yet ttest_ind treats them as
# independent -- consider whether a paired test (stats.ttest_rel) is the intended analysis.
print(stats.ttest_ind(a=rsqr_conjs_selectmodels, b=rsqr_conjs_nosst_selectmodels, alternative='greater'))