The purpose of this code is to investigate how nonlinear choice selectivity (the switch signal) contributes to the population encoding of choice history, which is fundamental for flexible decision-making. We first generate surrogate neural datasets that have different population encoding properties of linear and nonlinear choice selectivity, and then decode the choice history across two successive trials.

In [1]:
import mat4py
import pandas as pd
import numpy as np
import scipy
import scipy.io as sio
import math
import warnings
import warnings
import time
from itertools import combinations
from itertools import product

from numpy import linalg as LA
import statsmodels.api as sm
from scipy.stats import nbinom
from scipy import stats
from scipy import optimize
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn import linear_model
from sklearn.decomposition import PCA

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import gridspec
from mpl_toolkits import mplot3d
from mpl_toolkits.mplot3d import Axes3D
%matplotlib widget

In [2]:
# Load the preprocessed behavioral + neural data (MATLAB .mat files).
# Each file provides 'Data' (the table) and 'Var' (names for its leading columns).

# --- probabilistic reversal learning (PRL) task ---
data = mat4py.loadmat('C:/Users/liang/Documents/GitHub/ChoiceInteraction/PRL_all_-500_500_250_1500_spkcounts_norm0.mat')
Var = data['Var']
varName = dict(enumerate(Var))   # positional column index -> variable name
PRL = pd.DataFrame.from_dict(data['Data']).rename(columns=varName).dropna()
# Number of trailing columns that are not named by 'Var'. The functions below
# read the last `ts` columns of a dataframe as the per-epoch spike counts.
# NOTE(review): `ts` is derived from the PRL table only and is reused for MP;
# confirm that both files contain the same number of epochs.
ts = PRL.shape[1] - len(Var)
PRL_regressors = [
    'Loc', 'PreLoc', 'RL', 'PRL', 'LocInter',
    'Col', 'PreCol', 'RC', 'PRC', 'ColInter',
    'Rwd', 'PreRwd', 'POS', 'ChosenMag', 'UnchosenMag', 'LMag', 'HVL', 'SwitchHVL',
]

# --- matching pennies (MP) task ---
data = mat4py.loadmat('C:/Users/liang/Documents/GitHub/ChoiceInteraction/MP_all_-500_500_250_1500_spkcounts_norm0.mat')
Var = data['Var']
varName = dict(enumerate(Var))
MP = pd.DataFrame.from_dict(data['Data']).rename(columns=varName).dropna()
MP_regressors = ['Loc', 'PreLoc', 'RL', 'PRL', 'LocInter', 'Rwd', 'PreRwd']

In [3]:
def calculate_coef(data, regressors, timestamps):
    """
    Fit a multiple linear regression to the (gaussianized) spike counts of every neuron
    and every requested epoch, using the task variables in `regressors` as predictors.

    Args:
        data: pandas dataframe with one row per neuron-trial; must contain a 'cellid'
              column, the regressor columns, and the spike counts in its last `ts`
              columns (`ts` is the module-level epoch count).
              Cells are looked up as cellid == 1..N, where N is the number of unique ids.
        regressors: names of the columns used as independent variables
        timestamps: indices (into the last `ts` columns) of the epochs to analyze

    Returns:
        dict with
        'coef': array [neuron, 1 + len(regressors), epoch]; index 0 on the second axis
                is the intercept, followed by the regressors in the given order
        'res':  array [neuron, epoch] holding the variance of the regression residuals
    """
    n_cells = data.cellid.unique().size
    n_epochs = len(timestamps)
    coef = np.zeros((n_cells, len(regressors) + 1, n_epochs))
    residual = np.zeros((n_cells, n_epochs))
    model = linear_model.LinearRegression()

    for cell in range(1, n_cells + 1):
        row = cell - 1
        cell_rows = data[data['cellid'] == cell]
        design = cell_rows.loc[:, regressors].to_numpy()
        spikes = cell_rows[cell_rows.columns[-ts:]].to_numpy()

        for epoch, t in enumerate(timestamps):
            counts = spikes[:, t]

            # Gaussianize: the trial with the i-th smallest spike count receives the
            # i-th smallest value of a fresh standard-normal sample.
            order = np.argsort(counts)
            normal_draws = np.sort(np.random.normal(size=len(counts)))
            gaussianized = np.zeros(len(counts))
            gaussianized[order] = normal_draws

            model.fit(design, gaussianized)
            coef[row, :, epoch] = np.append(model.intercept_, model.coef_)

            fitted = np.dot(design, coef[row, 1:, epoch])
            residual[row, epoch] = np.var(gaussianized - fitted - coef[row, 0, epoch])

    return {'coef': coef, 'res': residual}

In [4]:
def coef_manipulation(original_coef, regressors, var, coefmode, corr, timestamps):
    """
    Manipulate the regression coefficients of the nonlinear choice selectivity (switch,
    i.e. the `var + 'Inter'` regressor) obtained from the linear regression model, so that
    surrogate neural activity with designed population encoding properties can be generated.
    There are two types of manipulations (coefmode):
    1) 'Removal', removal/lesion of the switch signal (its coefficients are set to 0).
    2) 'Correlated', shuffle the switch coefficients across the neural population so that
       their rank order relative to the choice coefficients follows a bivariate normal
       sample drawn with correlation `corr`.
    Any other `coefmode` leaves the coefficients untouched and emits a UserWarning.

    Args:
        original_coef: array [neuron, 1 + len(regressors), epoch] from 'calculate_coef'
                       (index 0 on the second axis is the intercept)
        regressors: independent variables of the regression model, in column order
        var: choice variable, 'Col' for target color or 'Loc' for target location
        coefmode: 'Removal' or 'Correlated'
        corr: [-1,1], targeted correlation coefficient if coefmode=='Correlated'
        timestamps: kept for interface compatibility; the number of epochs is taken
                    from original_coef.shape[2]

    Returns:
        mani_coef: manipulated copy of `original_coef` (the input is not modified)
    """
    cell_num = original_coef.shape[0]
    n_epochs = original_coef.shape[2]
    mani_coef = original_coef.copy()

    # +1 because column 0 of the coefficient array holds the intercept.
    names = list(regressors)
    var1_ind = names.index(var) + 1
    var2_ind = names.index(var + 'Inter') + 1

    if coefmode == 'Removal':
        mani_coef[:, var2_ind, :] = 0

    elif coefmode == 'Correlated':
        # One bivariate sample is shared by all epochs; only its rank structure is used.
        a = np.random.multivariate_normal([0, 0], [[1, corr], [corr, 1]], size=cell_num)
        a_by_first = a[a[:, 0].argsort()]
        a_sorted = np.zeros((cell_num, 2))

        for time_ind in range(n_epochs):
            var1_coef = original_coef[:, var1_ind, time_ind]
            var2_coef = original_coef[:, var2_ind, time_ind]

            # Give each neuron the sample row whose first component has the same rank
            # as the neuron's choice coefficient ...
            a_sorted[np.argsort(var1_coef), :] = a_by_first
            # ... then hand out the sorted switch coefficients following the rank of
            # the second component.
            var2_cellind = np.argsort(a_sorted[:, 1])
            mani_coef[var2_cellind, var2_ind, time_ind] = np.sort(var2_coef)

    else:
        warnings.warn(
            "coef_manipulation: unrecognized coefmode %r; expected 'Removal' or "
            "'Correlated'. Coefficients are returned unchanged." % (coefmode,),
            stacklevel=2,
        )

    return mani_coef

In [5]:
def generate_data(data, regressors, var, res, mani_coef, timestamps):
    """
    Generate surrogate data using the manipulated regression coefficients.

    For every neuron and epoch a gaussian response is simulated from the regression
    model (intercept + regressors + gaussian noise), and the neuron's original spike
    counts are then re-assigned to trials following the rank order of that simulated
    response, so the surrogate keeps the original spike-count distribution.

    Args:
        data: pandas dataframe with a 'cellid' column, the regressor columns and the
              spike counts in its last `ts` columns (`ts` is the module-level epoch count).
              Cells are looked up as cellid == 1..N, where N is the number of unique ids.
        regressors: independent variables of the regression model, in column order
        var: choice variable, 'Col' for target color or 'Loc' for target location
             (not used by this function; kept for interface compatibility)
        res: residual variance [neuron, epoch] from function 'calculate_coef'; the whole
             dict returned by 'calculate_coef' is also accepted, in which case its 'res'
             entry is used
        mani_coef: manipulated regression coefficients [neuron, 1 + regressor, epoch]
        timestamps: indices (into the last `ts` columns) of the epochs to simulate

    Returns:
        surrogateData: copy of `data` whose spike counts at `timestamps` are simulated
        refit_coef: coefficients obtained by applying the same regression model to the
                    surrogate spike counts, to examine if the manipulation is valid
    """
    # Callers in this file hand over the full calculate_coef output.
    if isinstance(res, dict):
        res = res['res']

    cell_num = data.cellid.unique().size
    surrogateData = data.copy()
    refit_coef = np.zeros(mani_coef.shape)
    clf = linear_model.LinearRegression()

    cell_ids = data['cellid'].to_numpy()
    # Filled through a per-cell row mask so that surrogate rows stay aligned with the
    # rows of `data` even when `data` is not sorted by cellid.
    Y = np.zeros((len(data), len(timestamps)))

    for c in range(1, cell_num + 1):

        rows = cell_ids == c
        cData = data[rows]
        F = cData.loc[:, regressors].to_numpy()
        spikes = cData[cData.columns[-ts:]].to_numpy()
        y = np.zeros((len(F), len(timestamps)))

        for time_ind, t in enumerate(timestamps):

            FR = spikes[:, t]
            noise = np.random.normal(0, np.sqrt(res[c-1, time_ind]), len(FR))
            y_gaussian = mani_coef[c-1, 0, time_ind] + np.dot(F, mani_coef[c-1, 1:, time_ind]) + noise

            # Rank-match: the trial with the i-th smallest simulated response receives
            # the i-th smallest original spike count.
            y[np.argsort(y_gaussian), time_ind] = np.sort(FR)

            clf.fit(F, y[:, time_ind])
            refit_coef[c-1, :, time_ind] = np.append(clf.intercept_, clf.coef_)

        Y[rows] = y

    surrogateData.iloc[:, -ts + np.array(timestamps)] = Y

    return surrogateData, refit_coef

In [6]:
def create_dataset(data, var, train_trialNum=100, test_trialNum=50):
    """
    Generate pseudo-trials. Because most of the neurons were collected from separate sessions,
    the activity of individual neurons is resampled (with replacement) within each behavioral
    condition and stacked into a population activity as if the neurons were recorded
    simultaneously. Conditions are all combinations of var, 'Pre'+var and 'PreRwd'.

    Args:
        data: pandas dataframe with a 'cellid' column, the columns var, 'Pre'+var and
              'PreRwd', and the spike counts in its last `ts` columns
              (`ts` is the module-level epoch count)
        var: choice variable, 'Col' for target color or 'Loc' for target location
        train_trialNum: number of pseudo-trials drawn per condition for each training set
        test_trialNum: number of pseudo-trials drawn per condition for each testing set

    Returns:
        train, test: lists with one array per fold (3 folds). Each array is
        [pseudo-trial, epoch, neuron + 3]; the last three entries on the final axis are
        the labels var, 'Pre'+var and their product. Neurons with fewer than 3 trials
        in a condition are filled with NaN for that condition.
    """
    cell_ids = data.cellid.unique()
    n_cells = cell_ids.size
    k = 3
    train_chunks = [[] for _ in range(k)]
    test_chunks = [[] for _ in range(k)]

    conditions = product(np.unique(data[var]),
                         np.unique(data['Pre'+var]),
                         np.unique(data['PreRwd']))

    for val1, val2, val3 in conditions:

        in_condition = data[(data[var]==val1) & (data['Pre'+var]==val2) & (data['PreRwd']==val3)]
        train_block = [np.zeros([train_trialNum, ts, n_cells]) for _ in range(k)]
        test_block = [np.zeros([test_trialNum, ts, n_cells]) for _ in range(k)]

        for c, cell in enumerate(cell_ids):

            cell_trials = in_condition[in_condition.cellid == cell]

            # Too few trials to split into k folds: mark this neuron as missing.
            if len(cell_trials) < 3:
                for fold in range(k):
                    train_block[fold][:, :, c] = np.nan
                    test_block[fold][:, :, c] = np.nan
                continue

            spikes = cell_trials[cell_trials.columns[-ts:]].to_numpy()
            splitter = KFold(n_splits=k, shuffle=True)

            for fold, (fit_rows, held_rows) in enumerate(splitter.split(np.array(cell_trials.index))):
                # Resample with replacement inside the fold's train / test trials.
                fit_pick = np.random.randint(fit_rows.size, size=train_trialNum)
                held_pick = np.random.randint(held_rows.size, size=test_trialNum)
                train_block[fold][:, :, c] = spikes[fit_rows[fit_pick]]
                test_block[fold][:, :, c] = spikes[held_rows[held_pick]]

        # Append the three label columns to the neuron axis.
        labels = [val1, val2, val1*val2]
        for fold in range(k):
            train_chunks[fold].append(
                np.concatenate((train_block[fold], np.tile(labels, (train_trialNum, ts, 1))), axis=2))
            test_chunks[fold].append(
                np.concatenate((test_block[fold], np.tile(labels, (test_trialNum, ts, 1))), axis=2))

    train = [np.concatenate(chunks, axis=0) if chunks else [] for chunks in train_chunks]
    test = [np.concatenate(chunks, axis=0) if chunks else [] for chunks in test_chunks]

    return train, test


In [7]:
def _axis_projection(clf, X):
    """Signed decision value of a fitted linear SVC for the rows of X, divided by the norm of its weight vector."""
    return np.ravel((np.dot(X, np.transpose(clf.coef_)) + clf.intercept_) / LA.norm(clf.coef_))


def decode(data, var, timestamps, num_iter=10, train_trialNum=100, test_trialNum=50):
    """
    Apply linear support vector machines in combination with 3-fold cross validation to decode
    the agent's current choice, previous choice, switch and choice history across two
    successive trials from pseudo-population activity built by 'create_dataset'.

    Args:
        data: pandas dataframe accepted by 'create_dataset'
        var: choice variable, 'Col' for target color or 'Loc' for target location
        timestamps: indices of the epochs to decode
        num_iter: number of times the pseudo-trials are regenerated and decoded
        train_trialNum: pseudo-trials per condition in each training set
        test_trialNum: pseudo-trials per condition in each testing set; the projection
                       array is sized for 8 conditions (test_trialNum*8 test trials)

    Returns:
        Accuracy_df: decoding/classification accuracy for the 4 variables, one row per
                     (iteration, fold, epoch), plus the epoch in column 'Time'
        Projection_df: the labels and the projection onto the encoding axis of each variable
                       defined by the SVM, one row per pseudo-trial, plus 'Time'
    """
    n_folds = 3
    Accuracy = np.zeros((num_iter, n_folds, len(timestamps), 5))
    Projection = np.zeros((num_iter, n_folds, len(timestamps), test_trialNum*8, 7))

    clf1 = SVC(kernel='linear')
    clf2 = SVC(kernel='linear')
    clf3 = SVC(kernel='linear')
    clf4 = SVC(kernel='linear', probability=True)

    for n in range(num_iter):

        # create_dataset takes (data, var, ...); forward the caller's trial counts
        # instead of hard-coded values.
        train, test = create_dataset(data, var, train_trialNum=train_trialNum, test_trialNum=test_trialNum)

        for k in range(n_folds):

            # Keep only neurons that have data (no NaN) in both the train and test sets.
            cellindex = (~np.isnan(train[k][:, 0, :-3].mean(axis=0))) & (~np.isnan(test[k][:, 0, :-3].mean(axis=0)))

            for i, t in enumerate(timestamps):

                # The same neuron subset must be used for fitting and for prediction.
                X_train = train[k][:, t, :-3][:, cellindex]
                X_test = test[k][:, t, :-3][:, cellindex]
                y_train = train[k][:, t, -3:]
                y_test = test[k][:, t, -3:]
                # Choice sequence: one class per (current, previous) combination.
                seq_train = y_train[:, 0]*2 + y_train[:, 1]
                seq_test = y_test[:, 0]*2 + y_test[:, 1]

                clf1.fit(X_train, y_train[:, 0])
                clf2.fit(X_train, y_train[:, 1])
                clf3.fit(X_train, y_train[:, 2])
                clf4.fit(X_train, seq_train)

                Accuracy[n, k, i, 0] = accuracy_score(clf1.predict(X_test), y_test[:, 0])
                Accuracy[n, k, i, 1] = accuracy_score(clf2.predict(X_test), y_test[:, 1])
                Accuracy[n, k, i, 2] = accuracy_score(clf3.predict(X_test), y_test[:, 2])
                Accuracy[n, k, i, 3] = accuracy_score(clf4.predict(X_test), seq_test)
                Accuracy[n, k, i, 4] = t

                Projection[n, k, i, :, 0] = y_test[:, 0]
                Projection[n, k, i, :, 1] = y_test[:, 1]
                Projection[n, k, i, :, 2] = y_test[:, 2]
                Projection[n, k, i, :, 3] = _axis_projection(clf1, X_test)
                Projection[n, k, i, :, 4] = _axis_projection(clf2, X_test)
                Projection[n, k, i, :, 5] = _axis_projection(clf3, X_test)
                Projection[n, k, i, :, 6] = t

    Accuracy_df = pd.DataFrame(data=Accuracy.reshape(-1, 5),
                               columns=['Current', 'Previous', 'Switch', 'Choice Sequence', 'Time'])
    Projection_df = pd.DataFrame(data=Projection.reshape(-1, 7),
                                 columns=['Current', 'Previous', 'Switch', 'CurrentProjection',
                                          'PreviousProjection', 'SwitchProjection', 'Time'])

    return Accuracy_df, Projection_df

In [500]:
def plot_reorder(corrcoef,accuracy,Allcoef,var,regressors,title,filename):
    """
    Scatter plot the relationship between correlation coefficient and the decoding accuracy,
    one panel per decoded variable (current choice, previous choice, switch, choice sequence).

    For each panel an OLS fit of accuracy on |[corrcoef, coefficient norms]| is computed
    (no constant column is added here) and the weight and p-value of the first regressor
    (the correlation coefficient) are printed inside the panel. The last three rows of
    `corrcoef` / `accuracy` are left out of both the fit and the scatter.

    Args:
        corrcoef: 2-D column of correlation coefficients (concatenated along axis 1 below)
        accuracy: array [dataset, variable] of decoding accuracies
        Allcoef: coefficient array whose last axis is indexed by 1 + regressor position
        var: choice variable, 'Col' for target color or 'Loc' for target location
        regressors: independent variables of the regression model, in column order
        title: figure title
        filename: accepted but not used by this function
    """
    # Positions (+1 for the intercept) of var, 'Pre'+var and var+'Inter'.
    names = np.asarray(regressors)
    var_ind = np.array([np.flatnonzero(names == key)[0] + 1
                        for key in (var, 'Pre'+var, var+'Inter')])
    coef_norm = np.linalg.norm(Allcoef[:, :, var_ind], axis=1)
    regmat = np.concatenate((corrcoef, coef_norm), axis=1)

    plt.close('all')
    my_dpi = 120
    fig = plt.figure(figsize=(2000/my_dpi, 500/my_dpi), dpi=my_dpi, facecolor=(1, 1, 1))
    gs = gridspec.GridSpec(1, 4)
    titles = ['Choice(t)', 'Choice(t-1)', 'Switch', 'Choice sequence']
    y_ticks = np.linspace(0.4, 1, 4)

    for panel, panel_title in enumerate(titles):

        ax = plt.subplot(gs[panel])
        fit = sm.OLS(accuracy[:-3, panel], np.abs(regmat[:-3])).fit()
        ax.scatter(corrcoef[:-3], accuracy[:-3, panel], color=[0.8, 0.8, 0.8])

        # Scientific notation for small p-values, fixed-point otherwise.
        weight, pval = fit.params[0], fit.pvalues[0]
        p_text = f"{pval:.1e}" if pval < 0.01 else f"{pval:.2f}"
        ax.text(-0.7, 0.45, f"w={weight:.3f}\np={p_text}", fontsize=14)

        ax.set_ylim(0.4, 1.1)
        ax.set_xlim(-0.8, 0.8)

        if panel == 0:
            ax.set_ylabel('Accuracy (%)', fontsize=14)

        ax.set_title(panel_title, fontsize=14)
        ax.set_position([panel*0.9/4+0.05, 0.15, 0.19, 0.7])
        ax.set_xticks([-0.5, 0, 0.5])
        ax.set_xticklabels([-0.5, 0, 0.5], fontsize=14)
        ax.set_yticks(y_ticks)
        ax.set_yticklabels(y_ticks, fontsize=14)

    fig.text(0.5, 0.03, 'Correlation between Choice(t) and Switch', ha='center', fontsize=14)

    plt.suptitle(title, fontsize=16)


In [33]:
def original_removal_cmp(data, regressors, var):
    """
    Compare the decoding accuracy and projection between simulated data using the
    un-manipulated/original regression coefficients and the regression coefficients
    after the removal of switch signal. Epochs 0..17 are analyzed.

    Args:
        data: pandas dataframe with the behavioral and neural data of one task
        regressors: independent variables for the regression model
        var: choice variable, 'Col' for target color or 'Loc' for target location

    Returns:
        original_accuracy, original_projection, removal_accuracy, removal_projection:
        the outputs of 'decode' for the two surrogate datasets
    """
    epochs = np.arange(18)

    original = calculate_coef(data, regressors, epochs)
    mani_coef = coef_manipulation(original['coef'], regressors, var, 'Removal', 0, epochs)

    # generate_data expects the residual array, not the whole calculate_coef dict.
    residual = original['res']
    original_data, original_refit_coef = generate_data(data, regressors, var, residual, original['coef'], epochs)
    removal_data, removal_refit_coef = generate_data(data, regressors, var, residual, mani_coef, epochs)

    original_accuracy, original_projection = decode(original_data, var, epochs, num_iter=2)
    removal_accuracy, removal_projection = decode(removal_data, var, epochs, num_iter=2)

    return original_accuracy, original_projection, removal_accuracy, removal_projection

In [216]:
def reorder_data(data, regressors, var, corr=1):
    """
    Get the decoding accuracy and projection for the datasets, in which the population
    encoding of switch and choice are correlated. Only epoch 6 is analyzed.

    Args:
        data: pandas dataframe with the behavioral and neural data of one task
        regressors: independent variables for the regression model
        var: choice variable, 'Col' for target color or 'Loc' for target location
        corr: [-1,1], targeted correlation handed to 'coef_manipulation' (default 1)

    Returns:
        reorder_accuracy, reorder_projection: the outputs of 'decode' for the surrogate data
    """
    epochs = [6]

    original = calculate_coef(data, regressors, epochs)
    # 'Correlated' is the reordering mode implemented by coef_manipulation; any other
    # mode string leaves the coefficients unmanipulated.
    mani_coef = coef_manipulation(original['coef'], regressors, var, 'Correlated', corr, epochs)
    # generate_data expects the residual array, not the whole calculate_coef dict.
    surrogate, reorder_refit_coef = generate_data(data, regressors, var, original['res'], mani_coef, epochs)
    reorder_accuracy, reorder_projection = decode(surrogate, var, epochs, num_iter=2)

    return reorder_accuracy, reorder_projection

In [529]:
def plot_original_removal(original_accuracy,removal_accuracy,task):

    """
    Plot the decoding accuracy for the datasets with original regression coefficients versus the one with switch signal removed.

    The figure is a 2x4 grid: each of the four decoded variables takes two adjacent axes,
    the first showing the time window tlim[0] (x-label 'Target onset') and the second the
    window tlim[1] (x-label 'Feedback onset'). Original data are drawn in 'forestgreen',
    switch-removed data in 'darkviolet'; a dashed line marks 0.25 for 'Choice Sequence'
    and 0.5 for the other variables.

    Args:
        original_accuracy: dataframe with a 'Time' column and the columns
                           'Current', 'Previous', 'Switch', 'Choice Sequence'
        removal_accuracy: dataframe with the same columns, for the switch-removed data
        task: 'PRL' selects the PRL tick/limit settings; any other value selects the
              alternative settings; also used as the figure title

    """  
    
    plt.close('all')
    fig, axs = plt.subplots(2, 4, figsize=(6,4),facecolor=(1, 1, 1))
    variables = ['Current','Previous','Switch','Choice Sequence']

    # Per-task x-axis settings; the first entry of ticks/tlim is used for the first
    # axes of each pair, the second entry for the second axes.
    # NOTE(review): tick labels presumably denote seconds from event onset - confirm.
    if task == 'PRL':
        ticklabels = ['0', '0.5', '1']
        ticks=[[2,4,6],[11,13,15]]
        tlim=[[-1,8],[8,17]]
    else:
        ticklabels = ['0', '0.5']
        ticks=[[2,4],[11,13]]
        tlim=[[0,7],[9,16]]
        
    ticksize=10
    titlesize=12

    for var in range(4):
        
        for ax_ind in range(2):
            
#             ax=plt.subplot(gs[int(np.floor(var/2)),ax_ind+2*(var%2)])
            # Variables 0/1 go to the top row, 2/3 to the bottom row; each variable
            # occupies columns 2*(var%2) and 2*(var%2)+1.
            ax=axs[int(np.floor(var/2)),ax_ind+2*(var%2)]
            # Only rows with Time strictly inside the window are plotted.
            sns.lineplot(original_accuracy[(original_accuracy.Time>tlim[ax_ind][0]) & (original_accuracy.Time<tlim[ax_ind][1])],
                         x='Time',y=variables[var],color='forestgreen',ax=ax,zorder=2,marker='o',markeredgecolor=None)
            sns.lineplot(removal_accuracy[(removal_accuracy.Time>tlim[ax_ind][0]) & (removal_accuracy.Time<tlim[ax_ind][1])],
                         x='Time',y=variables[var],color='darkviolet',ax=ax,zorder=2,marker='o',markeredgecolor=None)

            ax.set_xticks(ticks[ax_ind])
            ax.set_ylim(0,1.05)
            ax.set_xlim(tlim[ax_ind][0]+1,tlim[ax_ind][1]-1)
                
                
            if ax_ind == 0:

                # Grey band spanning x = 2..4 on the first axes of the pair.
                ax.fill_between([2,4],[1,1],color=[0.6,0.6,0.6],alpha=0.15,edgecolor=None)
                ax.spines[['right', 'top']].set_visible(False)
                ax.set_xlabel('Target onset',fontsize=ticksize)
                    
                # Only the left-most pair of each row keeps its y label and y tick labels.
                # NOTE(review): the label says '%' while the y-limits are 0..1.05 - confirm units.
                if var%2==0:
                    ax.set_ylabel('Accuracy (%)')
                else:
                    ax.set_ylabel('')
                    ax.set_yticklabels('')
                
            else:
                # Grey band spanning x = 11..13 on the second axes of the pair; the left
                # spine and y ticks are hidden so the pair reads as one broken axis.
                ax.fill_between([11,13],[1,1],color=[0.6,0.6,0.6],alpha=0.15,edgecolor=None)
                ax.spines[['left', 'right', 'top']].set_visible(False)
                ax.set_xlabel('Feedback onset',fontsize=ticksize)
                ax.set_ylabel('')
                ax.set_yticks([])
                
            # Dashed reference line: 0.25 for the 4-class choice sequence, 0.5 otherwise.
            if var == 3:
                ax.plot(np.array(tlim[ax_ind]),[0.25,0.25],'k--',zorder=1)
            else:
                ax.plot(np.array(tlim[ax_ind]),[0.5,0.5],'k--',zorder=1)     
                
            # Top-row axes carry no x tick labels / x label.
            if var < 2:
                ax.set_xticklabels('')
                ax.set_xlabel('')
            else:
                ax.set_xticklabels(ticklabels,fontsize=ticksize)
                    

                
            # Manual placement of each axes in figure coordinates.
            ax.set_position([0.2+0.4*(var%2)+0.17*ax_ind,0.55-0.4*np.floor(var/2),0.15,0.3])

                    
        # Variable name, written in the data coordinates of the second axes of the pair
        # (`ax` is the last axes handled by the inner loop).
        ax.text(10,1.05,variables[var],fontsize=titlesize,ha='center')


    plt.suptitle(task,fontsize=titlesize)


In [246]:
def plotProjection(Projection,task,var):
    """
    Plot the projection of population activity onto the encoding axis of each task variable,
    which to some extent serves as dimensionality reduction. Each datapoint represents a
    pseudo-trial, colored by the (current, previous) choice combination.

    The figure is saved four times under different view angles (file names start with
    task+' '+var+' projection on the encoding axis') and is shown afterwards, so the
    saved files are complete even when the backend's show() blocks.

    Args:
        Projection: dataframe with the columns 'Current', 'Previous',
                    'CurrentProjection', 'PreviousProjection' and 'SwitchProjection'
        task: task name, used in the output file names
        var: choice variable; 'Col' selects the color legend and axis ranges, any other
             value selects the location legend and axis ranges
    """
    plt.close('all')
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    colors = ['#59ce8f', '#ff7f50', '#0761f2', '#7a0bc0']

    for var1 in [-1,1]:
        for var2 in [-1,1]:

            # (-1,-1)->0, (-1,1)->1, (1,-1)->2, (1,1)->3
            c = int(var1+var2/2+1.5)

            conditioned = Projection[(Projection.Current==var1) & (Projection.Previous==var2)]

            ax.scatter3D(conditioned.CurrentProjection.values, conditioned.PreviousProjection.values,
                         conditioned.SwitchProjection.values, color=colors[c], marker='o', depthshade=True, s=20, alpha=0.5)

    # Transparent panes with white ("invisible") edges, and no grid.
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.pane.fill = False
        axis.pane.set_edgecolor('w')
    ax.grid(False)

    # Legend entries follow the scatter order above; limits/ticks are per (x, y, z) axis.
    if var=='Col':
        legend_labels = ['G->G','R->G','G->R','R->R']
        limits = [(-5,5), (-5,5), (-5,5)]
        ticks = [[-2.5,0,2.5], [-2.5,0,2.5], [-2.5,0,2.5]]
    else:
        legend_labels = ['L->L','R->L','L->R','R->R']
        limits = [(-30,30), (-10,10), (-10,10)]
        ticks = [[-15,0,15], [-5,0,5], [-5,0,5]]

    plt.legend(legend_labels)

    ax.set_xlim(*limits[0])
    ax.set_ylim(*limits[1])
    ax.set_zlim(*limits[2])

    ax.set_xticks(ticks[0])
    ax.set_yticks(ticks[1])
    ax.set_zticks(ticks[2])

    ax.set_xlabel(var+'(t)')
    ax.set_ylabel(var+'(t-1)')
    ax.set_zlabel('Switch')

    # One saved image per view angle: (file-name suffix, azim, elev).
    views = [('', 100, -170), (' 1', 0, 0), (' 2', 90, 0), (' 3', 90, 90)]
    for suffix, azim, elev in views:
        ax.view_init(azim=azim,elev=elev)
        plt.savefig(task+' '+var+' projection on the encoding axis'+suffix)

    # Show last: calling show() before savefig can leave nothing to save on blocking backends.
    plt.show()