In [2]:
# load all fxns and data
# %matplotlib inline

from opconNosepokeFunctions import *
from supplementaryFunctions import *
from scipy.optimize import minimize
from scipy.stats import entropy
from scipy.stats import ttest_rel
from sklearn.linear_model import LogisticRegression
from scipy.ndimage import gaussian_filter1d
import statsmodels.api as sm
import statsmodels.formula.api as smf

sessdf = pd.read_csv('L:/4portProb_processed/sessdf.csv')
sessdf.drop(columns = 'Unnamed: 0', inplace = True)
exclude = ['[ 20  20  20 100]', '[0 0 0 0]', '[0]', '[0 0]',
       '[1000   80]', '[30]', '[40]', '[70]']
sessdf = sessdf[~sessdf.rewprobfull.isin(exclude)]
sessdf = sessdf[~sessdf.duplicated(subset = ['animal', 'session', 'trialstart', 'eptime'], keep = False)]

In [126]:
def data_prep(dataset, hist = 20, trialsinsess=75, task = 'unstr', head = False):
    dataset = dataset.groupby(['animal','session']).filter(lambda x: x.reward.size >= trialsinsess)
#     dataset['valid'] = np.zeros(len(dataset))
#     dataset['valid'] = dataset['valid'].where(dataset['port'].isin([4.0]), 1)
#     dataset = dataset[dataset['valid']==1.0]
    dataset['choice_t0'] = dataset.port.values
    for i in range(1,hist): 
        
        dataset['choice_t'+str(i)] = dataset.groupby(['animal','session']).port.shift(i)
        dataset['shift_t'+str(i-1)] = dataset['choice_t'+str(i)]==dataset['choice_t'+str(i-1)]
        dataset['shift_t'+str(i-1)] = dataset['shift_t'+str(i-1)].replace({True: 0, False: 1})
        dataset['reward_t'+str(i)] = dataset.groupby(['animal','session']).reward.shift(i)
        # dataset['reward_t'+str(i)] = dataset['reward_t'+str(i)].replace({0:-1})
#         dataset['choice_t'+str(i)] = dataset['choice_t'+str(i)].replace({1:'a', 2:'b', 3:'c', 4:'d'})
    dataset = dataset.dropna()
    if head == True:
        dataset = dataset.groupby(['animal','session']).head(trialsinsess)

    return dataset

In [154]:
mask = (sessdf.task.isin(['unstr']))
sessdf_prep = data_prep(sessdf[mask], hist = 4, trialsinsess = 50, head= True)

In [157]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import log_loss
from pandas.api.types import CategoricalDtype
from sklearn.metrics import confusion_matrix
%matplotlib qt
catdtype = CategoricalDtype(categories=np.arange(1,5), ordered = False)

# choice only model
res = pd.DataFrame()
dfs = [sessdf_prep]
animal = sessdf_prep[~sessdf_prep.animal.isin(['Kakuna', 'Finneon'])].animal.unique()
cond = ['unstr']
mat = 0

# define y var
y_cols = ['choice_t0']

# define x var
x_cols = ['choice_t1_1','choice_t1_2', 'choice_t1_3', 'choice_t1_4',
          'choice_t2_1','choice_t2_2', 'choice_t2_3', 'choice_t2_4',
          'choice_t3_1','choice_t3_2', 'choice_t3_3', 'choice_t3_4']
        #   'choice_t4_1', 'choice_t4_2', 'choice_t4_3', 'choice_t4_4',
        #   'choice_t5_1', 'choice_t5_2', 'choice_t5_3', 'choice_t5_4']


# store in list then convert to df
for i, dat in enumerate(dfs):
    # get dummies 
    choice_cols = pd.get_dummies(dat[['choice_t'+str(i) for i in range(1,4)]].astype(catdtype),
                                        drop_first = False)
    rew_cols = pd.get_dummies(dat[['reward_t'+str(i) for i in range(1,4)]].astype(int),
                                    drop_first = False, columns = ['reward_t'+str(i) for i in range(1,4)])  
     
    for an in animal:

        # split data by session for train/test
        from sklearn.model_selection import GroupShuffleSplit 

        splitter = GroupShuffleSplit(random_state = 42)
        split = splitter.split(dat[dat['animal']==an], groups=dat[dat['animal']==an]['session'])
        train_inds, test_inds = next(split)

        # split by indices of dat
        train = dat.iloc[train_inds]
        test = dat.iloc[test_inds]
        print(an)
        # print('training sess = ', train.session.unique())
        # print('testing sess = ', test.session.unique())


        # get x and y acc to split
        y_train = train[y_cols].to_numpy().reshape(-1)
        y_test = test[y_cols]

        X_temp = pd.concat([choice_cols, rew_cols], axis = 1)

        # choice only model
        X = pd.DataFrame(np.c_[X_temp[x_cols]], columns = x_cols)
        
        X_train = X.iloc[train_inds]
        X_test = X.iloc[test_inds]
        
        # X_train, X_test, y_train, y_test = train_test_split(X, y)
        
        lr = LogisticRegression(multi_class='multinomial', solver = 'lbfgs',
                                # penalty = 'l2',
                                class_weight='balanced',
                                random_state = 42, fit_intercept = True)
        lr.fit(X_train, y_train)

        cols = lr.feature_names_in_
        ind = lr.classes_
        
        score = lr.score(X_test, y_test)

        y_pred_proba = lr.predict_proba(X_train)

        # ll_null = log_loss(y_train, [calc_prob(y_train)]*len(y_train), labels = [1,2,3,4])
        # ll_model = log_loss(y_train, y_pred_proba, labels = [1,2,3,4])

        ll_null = log_loss(y_train, [calc_prob(y_train)]*len(y_train))
        ll_model = log_loss(y_train, y_pred_proba)

        pseudo_r2 = (ll_null - ll_model) / ll_null
        print(round(pseudo_r2, 5), round(score, 5), an, cond[i]+dat[dat.animal==an].task.unique()[0])
        temp_res = pd.concat([pd.DataFrame(lr.coef_, columns = cols, index = ind),
                              pd.Series(lr.intercept_, index = ind, name = 'intercept'),
                              pd.Series(an, index = ind, name = 'animal'),
                              pd.Series(cond[i]+dat[dat.animal==an].task.unique()[0], index = ind, name = 'task'), 
                             pd.Series(score, index = ind, name = 'acc'),
                             pd.Series(pseudo_r2, index = ind, name = 'prsq')],
                             axis = 1)
        res = pd.concat([temp_res, res])
        print('--------------------------------------')
        # confusion matrix
        cm = confusion_matrix(y_test, lr.predict(X_test), labels=lr.classes_, normalize = 'true')
        mat+=cm
plt.figure()       
sns.heatmap(mat/len(animal), cmap = 'viridis', annot = True, vmin = 0, vmax = 1)
plt.title(f'C model, prsq = {round(res.prsq.mean(), 3)}, acc = {round(res.acc.mean(), 3)}')
print(res.acc.mean(), res.prsq.mean())

test05022023
0.12919 0.63856 test05022023 unstrunstr
--------------------------------------
Blissey




0.35376 0.6826 Blissey unstrunstr
--------------------------------------
Chikorita
0.11792 0.5929 Chikorita unstrunstr
--------------------------------------




Darkrai
0.13581 0.63207 Darkrai unstrunstr
--------------------------------------
Eevee




-0.0007 0.58348 Eevee unstrunstr
--------------------------------------
Goldeen
0.34166 0.67869 Goldeen unstrunstr
--------------------------------------
Hoppip
0.31835 0.67843 Hoppip unstrunstr
--------------------------------------




Inkay
0.34035 0.66241 Inkay unstrunstr
--------------------------------------
Jirachi
0.34884 0.64462 Jirachi unstrunstr
--------------------------------------




Kirlia
0.35325 0.68338 Kirlia unstrunstr
--------------------------------------
Mesprit




0.37479 0.64709 Mesprit unstrunstr
--------------------------------------
Nidorina
0.31945 0.66319 Nidorina unstrunstr
--------------------------------------
Oddish
0.34564 0.68092 Oddish unstrunstr
--------------------------------------




Phione
0.11865 0.57257 Phione unstrunstr
--------------------------------------
Quilava




0.35997 0.65364 Quilava unstrunstr
--------------------------------------
Raltz
0.27083 0.68046 Raltz unstrunstr
--------------------------------------




Shinx
0.34443 0.66889 Shinx unstrunstr
--------------------------------------
Togepi




0.37499 0.64082 Togepi unstrunstr
--------------------------------------
Umbreon
0.24164 0.67532 Umbreon unstrunstr




--------------------------------------
Vulpix
0.36491 0.67041 Vulpix unstrunstr
--------------------------------------




Xatu
0.10721 0.54074 Xatu unstrunstr
--------------------------------------
Yanma




0.09628 0.6125 Yanma unstrunstr
--------------------------------------
Zacian
0.34692 0.68 Zacian unstrunstr




--------------------------------------
Alakazam
0.33113 0.68729 Alakazam unstrunstr
--------------------------------------




Bayleef
0.44791 0.75072 Bayleef unstrunstr
--------------------------------------




Cresselia
0.20109 0.68401 Cresselia unstrunstr
--------------------------------------
Emolga




0.34749 0.66241 Emolga unstrunstr
--------------------------------------
Giratina
0.34737 0.6498 Giratina unstrunstr
--------------------------------------




Haxorus
0.14397 0.60541 Haxorus unstrunstr
--------------------------------------
Ivysaur




0.37305 0.64972 Ivysaur unstrunstr
--------------------------------------
Jigglypuff
0.01029 0.53053 Jigglypuff unstrunstr
--------------------------------------
Lugia
0.00806 0.58891 Lugia unstrunstr
--------------------------------------
0.6460149482529143 0.25982867079539446




In [156]:
# plot all coefficients averaged across animals
res.reset_index().groupby('index')[np.append(x_cols, 'intercept')].mean().T.plot(kind = 'bar', figsize = (10,5))

<Axes: >

In [None]:
res

Unnamed: 0,choice_t1_1,choice_t1_2,choice_t1_3,choice_t1_4,choice_t2_1,choice_t2_2,choice_t2_3,choice_t2_4,choice_t3_1,choice_t3_2,choice_t3_3,choice_t3_4,intercept,animal,task,acc,prsq
1.0,0.066905,-0.148301,0.100114,-0.207940,0.762889,0.146945,-0.365570,-0.733486,0.350716,0.011443,-0.211401,-0.339980,-0.200890,Lugia,unstrunstr,0.588910,0.008061
2.0,0.078785,0.373549,-0.071401,-0.348520,0.058057,0.498055,-0.181743,-0.341956,-0.061288,0.501910,-0.042327,-0.365882,0.034411,Lugia,unstrunstr,0.588910,0.008061
1.0,0.104610,-0.185014,0.092443,-0.217704,0.797560,0.109260,-0.424599,-0.687887,0.353605,0.020109,-0.167302,-0.412077,-0.219723,Jigglypuff,unstrunstr,0.530526,0.010295
2.0,0.113295,0.345856,-0.077451,-0.328415,0.028291,0.411222,-0.122043,-0.264185,-0.127374,0.517764,-0.006162,-0.330943,0.056355,Jigglypuff,unstrunstr,0.530526,0.010295
1.0,0.481189,-0.307947,0.047645,-0.337284,1.184153,-0.126079,-0.534985,-0.639487,0.721011,-0.287376,-0.285365,-0.264667,-0.119584,Ivysaur,unstrunstr,0.649722,0.373050
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2.0,-0.030031,0.527574,0.002965,-0.349711,0.044145,0.681136,-0.118884,-0.455600,-0.023529,0.626029,-0.098228,-0.353475,0.156020,Chikorita,unstrunstr,0.592903,0.117918
1.0,0.525786,-0.301956,0.032758,-0.368442,1.155144,-0.129514,-0.530134,-0.607349,0.697180,-0.262401,-0.286030,-0.260602,-0.114518,Blissey,unstrunstr,0.682597,0.353757
2.0,-0.364525,0.796588,-0.167597,-0.517432,-0.183698,0.806221,-0.181769,-0.693718,-0.210741,0.801319,-0.115901,-0.727643,-0.257910,Blissey,unstrunstr,0.682597,0.353757
1.0,0.270287,-0.222072,0.066061,-0.387965,0.927131,0.001058,-0.496398,-0.705480,0.509726,-0.135786,-0.286339,-0.361289,-0.283526,test05022023,unstrunstr,0.638560,0.129191


In [141]:
res_with_intercept = res.copy(deep= True)

In [148]:
(res_with_intercept.choice_t1_1 - res_with_intercept.choice_t1_2)

1.0    0.285430
2.0   -0.312760
3.0    0.093494
4.0   -0.066164
1.0    0.329719
         ...   
4.0    0.532177
1.0    0.518341
2.0   -0.569921
3.0   -0.122932
4.0    0.174512
Length: 128, dtype: float64

In [None]:
# plot coefficients averaged across animals for each port
port1_c = res.reset_index().groupby('index')[x_cols[::4]].mean()
port2_c = res.reset_index().groupby('index')[x_cols[1::4]].mean()
port3_c = res.reset_index().groupby('index')[x_cols[2::4]].mean()
port4_c = res.reset_index().groupby('index')[x_cols[3::4]].mean()
print(port1_c)
plt.figure()
ax = plt.subplot(221)
ax.plot(port1_c.T, 'o-')
ax.set_ylim(-1,1.2)

ax = plt.subplot(222)
ax.plot(port2_c.T, 'o-')
ax.set_ylim(-1,1.2)

ax = plt.subplot(223)
ax.plot(port3_c.T, 'o-')
ax.set_ylim(-1,1.2)

ax = plt.subplot(224)
ax.bar(np.arange(3), port4_c.T, 'o-')
ax.set_ylim(-1,1.2)

plt.legend(port1_c.index)
sns.despine()
plt.tight_layout()

       choice_t1_1  choice_t2_1  choice_t3_1
index                                       
1.0       0.603290     1.192849     0.818921
2.0      -0.242801    -0.221927    -0.153658
3.0       0.051530    -0.234898    -0.108380
4.0      -0.412019    -0.736024    -0.556883


In [183]:

from sklearn.model_selection import train_test_split
from sklearn.metrics import log_loss
from pandas.api.types import CategoricalDtype
from sklearn.metrics import confusion_matrix

catdtype = CategoricalDtype(categories=np.arange(1,5), ordered = False)

# choice*reward only model
res = pd.DataFrame()
dfs = [sessdf_prep]
animal = sessdf_prep[~sessdf_prep.animal.isin(['Kakuna', 'Finneon'])].animal.unique()
cond = ['unstr']
mat = 0
# define y var
y_cols = ['choice_t0']

# define x var
# x_cols = ['choice_t1_1', 'choice_t1_2', 'choice_t1_3', 'choice_t1_4',
#           'choice_t2_1', 'choice_t2_2', 'choice_t2_3', 'choice_t2_4',
#           'choice_t3_1', 'choice_t3_2', 'choice_t3_3', 'choice_t3_4']


# store in list then convert to df
for i, dat in enumerate(dfs):
    # get dummies 
    choice_cols = pd.get_dummies(dat[['choice_t'+str(i) for i in range(1,6)]].astype(catdtype),
                                        drop_first = False)
    rew_cols = pd.get_dummies(dat[['reward_t'+str(i) for i in range(1,6)]].astype(int),
                                    drop_first = False, columns = ['reward_t'+str(i) for i in range(1,6)])
    cr_cols = np.array([choice_cols[f'choice_t{time_lag}_{choice}']*rew_cols[f'reward_t{time_lag}_{rew}'] 
                        for time_lag in range(1,4) for choice in range(1,5) for rew in [-1, 1]])
    cr_cols = pd.DataFrame(cr_cols.T, columns = [f'cr_t{time_lag}_{choice}_{rew}' for time_lag in range(1,4) for choice in range(1,5) for rew in [-1, 1]])
    x_cols = cr_cols.columns
    x_cols = cr_cols.drop(columns = ['cr_t1_1_1', 'cr_t2_1_1', 'cr_t3_1_1']).columns
     
    for an in animal:

        # split data by session for train/test
        from sklearn.model_selection import GroupShuffleSplit 

        splitter = GroupShuffleSplit(random_state = 7)
        split = splitter.split(dat[dat['animal']==an], groups=dat[dat['animal']==an]['session'])
        train_inds, test_inds = next(split)

        # split by indices of dat
        train = dat.iloc[train_inds]
        test = dat.iloc[test_inds]
        print(an)
        # print('training sess = ', train.session.unique())
        # print('testing sess = ', test.session.unique())


        # get x and y acc to split
        y_train = train[y_cols].to_numpy().reshape(-1)
        y_test = test[y_cols]

        X_temp = pd.concat([choice_cols, rew_cols, cr_cols], axis = 1)

        # choice only model
        X = pd.DataFrame(np.c_[X_temp[x_cols]], columns = x_cols)
        
        X_train = X.iloc[train_inds]
        X_test = X.iloc[test_inds]
        
        # X_train, X_test, y_train, y_test = train_test_split(X, y)
        
        lr = LogisticRegression(multi_class='multinomial', solver = 'lbfgs',
                                # penalty = 'l2',
                                random_state = 42, fit_intercept = False)
        lr.fit(X_train, y_train)

        cols = lr.feature_names_in_
        ind = lr.classes_
        
        score = lr.score(X_test, y_test)

        y_pred_proba = lr.predict_proba(X_train)

        # ll_null = log_loss(y_train, [calc_prob(y_train)]*len(y_train), labels = [1,2,3,4])
        # ll_model = log_loss(y_train, y_pred_proba, labels = [1,2,3,4])

        ll_null = log_loss(y_train, [calc_prob(y_train)]*len(y_train))
        ll_model = log_loss(y_train, y_pred_proba)

        pseudo_r2 = (ll_null - ll_model) / ll_null
        print(round(pseudo_r2, 5), round(score, 5), an, cond[i]+dat[dat.animal==an].task.unique()[0])
        temp_res = pd.concat([pd.DataFrame(lr.coef_, columns = cols, index = ind),
                              pd.Series(lr.intercept_, index = ind, name = 'intercept'),
                              pd.Series(an, index = ind, name = 'animal'),
                              pd.Series(cond[i]+dat[dat.animal==an].task.unique()[0], index = ind, name = 'task'), 
                             pd.Series(score, index = ind, name = 'acc'),
                             pd.Series(pseudo_r2, index = ind, name = 'prsq')],
                             axis = 1)
        res = pd.concat([temp_res, res])
        # confusion matrix
        cm = confusion_matrix(y_test, lr.predict(X_test), labels=lr.classes_, normalize = 'true')
        mat+=cm
        print('--------------------------------------')
plt.figure()
sns.heatmap(mat/len(animal), cmap = 'viridis', annot = True, vmin = 0, vmax = 1)
plt.title(f'C:R model, prsq = {round(res.prsq.mean(), 3)}, acc = {round(res.acc.mean(), 3)}')
print(res.acc.mean(), res.prsq.mean())

test05022023




-0.0232 0.55969 test05022023 unstrunstr
--------------------------------------
Blissey




0.00592 0.37171 Blissey unstrunstr
--------------------------------------
Chikorita




-0.01967 0.59374 Chikorita unstrunstr
--------------------------------------
Darkrai




-0.01595 0.62232 Darkrai unstrunstr
--------------------------------------
Eevee




-0.02069 0.59107 Eevee unstrunstr
--------------------------------------
Goldeen




-0.00016 0.38235 Goldeen unstrunstr
--------------------------------------
Hoppip




0.00188 0.31681 Hoppip unstrunstr
--------------------------------------
Inkay




0.00255 0.38525 Inkay unstrunstr
--------------------------------------
Jirachi




0.01098 0.36492 Jirachi unstrunstr
--------------------------------------
Kirlia




0.01201 0.29827 Kirlia unstrunstr
--------------------------------------
Mesprit




0.01386 0.36943 Mesprit unstrunstr
--------------------------------------
Nidorina




0.00195 0.376 Nidorina unstrunstr
--------------------------------------
Oddish




0.00951 0.31911 Oddish unstrunstr
--------------------------------------
Phione




-0.01932 0.57381 Phione unstrunstr
--------------------------------------
Quilava




0.00844 0.33447 Quilava unstrunstr
--------------------------------------
Raltz




-0.00316 0.42818 Raltz unstrunstr
--------------------------------------
Shinx




0.00484 0.38537 Shinx unstrunstr
--------------------------------------
Togepi




0.01535 0.34 Togepi unstrunstr
--------------------------------------
Umbreon




-0.002 0.41098 Umbreon unstrunstr
--------------------------------------
Vulpix




0.01513 0.34797 Vulpix unstrunstr
--------------------------------------
Xatu




-0.01805 0.5408 Xatu unstrunstr
--------------------------------------
Yanma




-0.01988 0.56647 Yanma unstrunstr
--------------------------------------
Zacian




0.00367 0.37966 Zacian unstrunstr
--------------------------------------
Alakazam




0.00177 0.3792 Alakazam unstrunstr
--------------------------------------
Bayleef




0.00055 0.28094 Bayleef unstrunstr
--------------------------------------
Cresselia




-0.00644 0.48853 Cresselia unstrunstr
--------------------------------------
Emolga




0.00477 0.35231 Emolga unstrunstr
--------------------------------------
Giratina




-1e-05 0.35449 Giratina unstrunstr
--------------------------------------
Haxorus




-0.01728 0.57412 Haxorus unstrunstr
--------------------------------------
Ivysaur




0.01657 0.31994 Ivysaur unstrunstr
--------------------------------------
Jigglypuff




-0.02636 0.46895 Jigglypuff unstrunstr
--------------------------------------
Lugia




-0.0254 0.42864 Lugia unstrunstr
--------------------------------------
0.42204729481669156 -0.002744173395498227


In [167]:
rew_cols

Unnamed: 0,reward_t1_0,reward_t1_1,reward_t2_0,reward_t2_1,reward_t3_0,reward_t3_1
3,False,True,True,False,True,False
4,False,True,False,True,True,False
5,False,True,False,True,False,True
6,True,False,False,True,False,True
7,False,True,True,False,False,True
...,...,...,...,...,...,...
7609387,False,True,False,True,False,True
7609388,False,True,False,True,False,True
7609389,False,True,False,True,False,True
7609390,False,True,False,True,False,True


In [173]:

from sklearn.model_selection import train_test_split
from sklearn.metrics import log_loss
from pandas.api.types import CategoricalDtype
from sklearn.metrics import confusion_matrix
catdtype = CategoricalDtype(categories=np.arange(1,5), ordered = False)

# choice*reward only model
res_crc = pd.DataFrame()
dfs = [sessdf_prep]
animal = sessdf_prep[~sessdf_prep.animal.isin(['Kakuna', 'Finneon', 'Bayleef'])].animal.unique()
cond = ['unstr']

# define y var
y_cols = ['choice_t0']
mat = 0
# define x var
# x_cols = ['choice_t1_1', 'choice_t1_2', 'choice_t1_3', 'choice_t1_4',
#           'choice_t2_1', 'choice_t2_2', 'choice_t2_3', 'choice_t2_4',
#           'choice_t3_1', 'choice_t3_2', 'choice_t3_3', 'choice_t3_4']


# store in list then convert to df
for i, dat in enumerate(dfs):
    # get dummies 
    choice_cols = pd.get_dummies(dat[['choice_t'+str(i) for i in range(1,4)]].astype(catdtype),
                                        drop_first = False)
    rew_cols = pd.get_dummies(dat[['reward_t'+str(i) for i in range(1,4)]].astype(int),
                                    drop_first = False, columns = ['reward_t'+str(i) for i in range(1,4)])
    cr_cols = np.array([choice_cols[f'choice_t{time_lag}_{choice}']*rew_cols[f'reward_t{time_lag}_1'] 
                        for time_lag in range(1,4) for choice in range(1,5)])
    cr_cols = pd.DataFrame(cr_cols.T, columns = [f'cr_t{time_lag}_{choice}_1' for time_lag in range(1,4) for choice in range(1,5)])
    # crc = cr_cols.drop(columns = ['cr_t1_4_1', 'cr_t2_4_1', 'cr_t3_4_1']).columns
    #'choice_t1_4', 'choice_t2_4', 'choice_t3_4',
    # cc = choice_cols.drop(columns = ['choice_t4_1', 'choice_t4_2', 'choice_t4_3',
    #    'choice_t4_4', 'choice_t5_1', 'choice_t5_2', 'choice_t5_3',
    #    'choice_t5_4']).columns
    # define x var
    # x_cols = np.append(choice_cols.columns[:12], cr_cols.columns)
    cc = choice_cols.columns
    x_cols = np.append(cc, cr_cols.columns)
     
    for an in animal:

        # split data by session for train/test
        from sklearn.model_selection import GroupShuffleSplit 

        splitter = GroupShuffleSplit(random_state = 7)
        split = splitter.split(dat[dat['animal']==an], groups=dat[dat['animal']==an]['session'])
        train_inds, test_inds = next(split)

        # split by indices of dat
        train = dat.iloc[train_inds]
        test = dat.iloc[test_inds]
        print(an)
        # print('training sess = ', train.session.unique())
        # print('testing sess = ', test.session.unique())


        # get x and y acc to split
        y_train = train[y_cols].to_numpy().reshape(-1)
        y_test = test[y_cols]

        X_temp = pd.concat([choice_cols, rew_cols, cr_cols], axis = 1)

        # choice only model
        X = pd.DataFrame(np.c_[X_temp[x_cols]], columns = x_cols)
        
        X_train = X.iloc[train_inds]
        X_test = X.iloc[test_inds]
        
        # X_train, X_test, y_train, y_test = train_test_split(X, y)
        
        lr = LogisticRegression(multi_class='multinomial', solver = 'lbfgs',
                                # penalty = 'l2',
                                class_weight='balanced',
                                random_state = 42, fit_intercept = False)
        lr.fit(X_train, y_train)

        cols = lr.feature_names_in_
        ind = lr.classes_
        
        score = lr.score(X_test, y_test)

        y_pred_proba = lr.predict_proba(X_train)

        # ll_null = log_loss(y_train, [calc_prob(y_train)]*len(y_train), labels = [1,2,3,4])
        # ll_model = log_loss(y_train, y_pred_proba, labels = [1,2,3,4])

        ll_null = log_loss(y_train, [calc_prob(y_train)]*len(y_train))
        ll_model = log_loss(y_train, y_pred_proba)

        pseudo_r2 = (ll_null - ll_model) / ll_null
        print(round(pseudo_r2, 5), round(score, 5), an, cond[i]+dat[dat.animal==an].task.unique()[0])
        temp_res = pd.concat([pd.DataFrame(lr.coef_, columns = cols, index = ind),
                              pd.Series(lr.intercept_, index = ind, name = 'intercept'),
                              pd.Series(an, index = ind, name = 'animal'),
                              pd.Series(cond[i]+dat[dat.animal==an].task.unique()[0], index = ind, name = 'task'), 
                             pd.Series(score, index = ind, name = 'acc'),
                             pd.Series(pseudo_r2, index = ind, name = 'prsq')],
                             axis = 1)
        res_crc = pd.concat([temp_res, res_crc])
        
        # confusion matrix
        cm = confusion_matrix(y_test, lr.predict(X_test), labels=lr.classes_, normalize = 'true')
        mat+=cm
        print('--------------------------------------')
plt.figure()
sns.heatmap(mat/len(animal), cmap = 'viridis', annot = True, vmin = 0, vmax = 1)
plt.title(f'C+ C:R model, prsq = {round(res.prsq.mean(), 3)}, acc = {round(res.acc.mean(), 3)}')
print(res.acc.mean(), res.prsq.mean())

test05022023




0.14758 0.59177 test05022023 unstrunstr
--------------------------------------
Blissey




0.34718 0.69299 Blissey unstrunstr
--------------------------------------
Chikorita




0.12916 0.60194 Chikorita unstrunstr
--------------------------------------
Darkrai




0.15203 0.585 Darkrai unstrunstr
--------------------------------------
Eevee




0.0755 0.55913 Eevee unstrunstr
--------------------------------------
Goldeen




0.33045 0.70156 Goldeen unstrunstr
--------------------------------------
Hoppip




0.33151 0.66051 Hoppip unstrunstr
--------------------------------------
Inkay




0.34404 0.69889 Inkay unstrunstr
--------------------------------------
Jirachi




0.34495 0.68923 Jirachi unstrunstr
--------------------------------------
Kirlia




0.34709 0.69299 Kirlia unstrunstr
--------------------------------------
Mesprit




0.34235 0.72004 Mesprit unstrunstr
--------------------------------------
Nidorina




0.34226 0.69414 Nidorina unstrunstr
--------------------------------------
Oddish




0.35511 0.66322 Oddish unstrunstr
--------------------------------------
Phione




0.10961 0.676 Phione unstrunstr
--------------------------------------
Quilava




0.32875 0.72238 Quilava unstrunstr
--------------------------------------
Raltz




0.25257 0.65614 Raltz unstrunstr
--------------------------------------
Shinx




0.32936 0.71037 Shinx unstrunstr
--------------------------------------
Togepi




0.34509 0.70047 Togepi unstrunstr
--------------------------------------
Umbreon




0.25928 0.65319 Umbreon unstrunstr
--------------------------------------
Vulpix




0.33663 0.70877 Vulpix unstrunstr
--------------------------------------
Xatu




0.11952 0.59852 Xatu unstrunstr
--------------------------------------
Yanma




0.08885 0.62 Yanma unstrunstr
--------------------------------------
Zacian




0.34631 0.69895 Zacian unstrunstr
--------------------------------------
Alakazam




0.32468 0.70719 Alakazam unstrunstr
--------------------------------------
Cresselia




0.18379 0.68095 Cresselia unstrunstr
--------------------------------------
Emolga




0.3455 0.68899 Emolga unstrunstr
--------------------------------------
Giratina




0.34205 0.70805 Giratina unstrunstr
--------------------------------------
Haxorus




0.13602 0.66938 Haxorus unstrunstr
--------------------------------------
Ivysaur




0.3474 0.705 Ivysaur unstrunstr
--------------------------------------
Jigglypuff




-0.00222 0.56737 Jigglypuff unstrunstr
--------------------------------------
Lugia




0.00407 0.56286 Lugia unstrunstr
--------------------------------------
0.6460149482529143 0.25982867079539446


In [174]:
# plot all coefficients averaged across animals
res_crc.reset_index().groupby('index')[np.append(x_cols, 'intercept')].mean().T.plot(kind = 'bar', figsize = (10,5))

<Axes: >

In [169]:
# unrew_terms = ['cr_t1_1_-1', 'cr_t1_2_-1', 'cr_t1_3_-1', 'cr_t1_4_-1', 'cr_t2_1_-1', 'cr_t2_2_-1', 'cr_t2_3_-1', 'cr_t2_4_-1', 'cr_t3_1_-1', 'cr_t3_2_-1', 'cr_t3_3_-1', 'cr_t3_4_-1']
rew_terms = ['cr_t1_1_1', 'cr_t1_2_1', 'cr_t1_3_1', 'cr_t1_4_1', 'cr_t2_1_1', 'cr_t2_2_1', 'cr_t2_3_1', 'cr_t2_4_1', 'cr_t3_1_1', 'cr_t3_2_1', 'cr_t3_3_1', 'cr_t3_4_1']
# plt.figure()
# sns.heatmap(res.reset_index().groupby('index')[unrew_terms[4:8]].mean(),
#             xticklabels = ['1 unrew', '2 unrew', '3 unrew', '4 unrew'],
#             annot = True,
#             fmt = '.2f',
#             cmap = 'vlag',
#             center = 0)
plt.figure()
sns.heatmap(res_crc.reset_index().groupby('index')[rew_terms[:4]].mean(),
            xticklabels = ['1 rew', '2 rew', '3 rew', '4 rew'],
            annot = True,
            fmt = '.2f',
            cmap = 'vlag',
            center = 0)

<Axes: ylabel='index'>

In [187]:
np.add(res_crc[res_crc.animal == animal][cc[:4]], res_crc[res_crc.animal == animal][rew_terms[:4]].to_numpy())

Unnamed: 0,choice_t1_1,choice_t1_2,choice_t1_3,choice_t1_4
1.0,0.414106,-0.267291,0.049313,-0.406879
2.0,-0.375817,0.771039,-0.077985,-0.43506
3.0,0.109121,0.04767,0.456221,-0.193897
4.0,-0.14741,-0.551418,-0.427548,1.035836


In [191]:
animal = 'Quilava'
plt.figure()
sns.heatmap(res_crc[res_crc.animal == animal][cc[:4]].to_numpy()+res_crc[res_crc.animal == animal][rew_terms[:4]], annot = True, fmt = '.2f', cmap = 'seismic', center = 0)
plt.title('Choice at t-1*rew predicting choice at t0')
plt.figure()
sns.heatmap(res_crc[res_crc.animal == animal][cc[:4]], annot = True, fmt = '.2f', cmap = 'seismic', center = 0)
plt.title('Choice at t-1* unrew predicting choice at t0')
# plt.figure()
# sns.heatmap(res_crc[res_crc.animal == animal][crc[8:]], annot = True, fmt = '.2f', cmap = 'seismic', center = 0)
# plt.title('Choice at t-3 predicting choice at t0')

Text(0.5, 1.0, 'Choice at t-1* unrew predicting choice at t0')

In [172]:
animal = 'Jirachi'
plt.figure()
sns.heatmap(res[res.animal == animal][cc[:4]], annot = True, fmt = '.2f', cmap = 'seismic', center = 0)
plt.title('Choice at t-1 predicting choice at t0')
plt.figure()
sns.heatmap(res[res.animal == animal][cc[4:8]], annot = True, fmt = '.2f', cmap = 'seismic', center = 0)
plt.title('Choice at t-2 predicting choice at t0')
plt.figure()
sns.heatmap(res[res.animal == animal][cc[8:]], annot = True, fmt = '.2f', cmap = 'seismic', center = 0)
plt.title('Choice at t-3 predicting choice at t0')
# plt.figure()
# sns.heatmap(res[res.animal == animal][unrew_terms[:4]], annot = True, fmt = '.2f', cmap = 'vlag', center = 0)
# plt.title('Unrewarded')
# plt.figure()
# sns.heatmap(res[res.animal == animal][rew_terms[:4]], annot = True, fmt = '.2f', cmap = 'vlag', center = 0)
# plt.title('Rewarded')

Text(0.5, 1.0, 'Choice at t-3 predicting choice at t0')