# Import and Prob functions

In [72]:
import numpy as np
from utils.plotSettings import *
import pandas as pd
import os
from scipy.ndimage import uniform_filter, gaussian_filter1d
from scipy.stats import sem, pearsonr, entropy

In [2]:
################# PROBABILITY FUNCTIONS ################
def fxn(mean, arms):
    x = np.linspace(1, arms, arms)
    sig = 1.75/2
#     amp = 1/(sig*np.sqrt(2*np.pi))
    amp = 0.7
    vo = 0.1
    gx = (amp*np.exp(-0.5*((x-mean)**2)/(sig**2)))+vo
#     gx = np.random.permutation(gx)
    return gx

def cauchy(median, arms):
#     cauchy = @(x, s, t)(1./(s*pi*(1+(((x-t)./s).^2))))
    x = np.linspace(1, arms, arms)
    s = 1.5
    cauchy = ((1/(s*np.pi*(1+(((x-median)/s)**2))))*3.5)+0.1;
    return cauchy

In [73]:
def fxn(mean, arms, permute = False):
    x = np.linspace(1, arms, arms)
    sig = 1.75/2
#     amp = 1/(sig*np.sqrt(2*np.pi))
    amp = 0.7
    vo = 0.1
    gx = (amp*np.exp(-0.5*((x-mean)**2)/(sig**2)))+vo
    if permute:
        gx = np.random.permutation(gx)
    return gx
    
fig = plt.figure(figsize = (4,6))
# plt.figure()
arms = 4

l = [fxn(i, arms, True) for i in range(1,arms+1)]
for ind in range(arms):
    ax = plt.subplot(4, 1, ind+1)
    ax.bar(np.arange(1, arms+1), l[ind], color = 'xkcd:pumpkin')
    
    sns.despine()
    # ax.set_xticks(np.arange(1,arms+1), np.arange(1,arms+1))
    ax.set_ylim(0, 0.8)
    ax.set_yticks([0.0, 0.8], [0, 80])
ax.set_xticks(np.arange(1,arms+1), np.arange(1,arms+1))
print(l,'\n')
    # ax.set_yticks(np.arange(4)*0.25, np.arange(4)*0.25)
#     if ind == 3:
#         break
# fig.supxlabel('Ports')
# fig.supylabel('Reward probability')

plt.tight_layout()


# For use of "viscm view"
# test_cm = parula_map

# if __name__ == "__main__":
#     import matplotlib.pyplot as plt
#     import numpy as np

    # try:
    #     from viscm import viscm
    #     viscm(parula_map)
    # except ImportError:
    #     print("viscm not found, falling back on simple display")
    #     plt.imshow(np.linspace(0, 100, 256)[None, :], aspect='auto',
    #                cmap=parula_map)
    # plt.show()
from utils.plotSettings import *
plt.figure()
l = [fxn(np.random.randint(1, arms+1), arms, False) for i in range(10000)]
sns.heatmap(np.corrcoef(np.array(l).T),
            cmap = 'winter', vmin = -1, vmax = 1, square = True,
            xticklabels = np.arange(1, arms+1), yticklabels = np.arange(1, arms+1))
plt.title('Structured')

[array([0.46431508, 0.8       , 0.10196115, 0.15135876]), array([0.46431508, 0.8       , 0.15135876, 0.46431508]), array([0.8       , 0.15135876, 0.46431508, 0.46431508]), array([0.46431508, 0.10196115, 0.15135876, 0.8       ])] 



Text(0.5, 1.0, 'Structured')

# Functions for value-based

In [74]:
################ ACTION SELECTION #######################
# epsilon greedy action selection
def epsilon_greedy(eps, actions, n_arms, value): 
    randn = np.random.uniform(0,1)
    if randn <= eps:
        action = np.random.randint(1, n_arms+1)
    else:
        action = np.random.choice((np.where(value == np.amax(value))[0])+1)
    return action, 1


# softmax action selection
def softmax(inv_temp, actions, arms, value):
    prob_choosing_action = np.zeros(len(arms))

    for arm in range(len(arms)):
        prob_choosing_action[arm] = (np.exp(value[arm]*inv_temp)) / np.sum(np.exp(value*inv_temp))

#     actions = np.random.choice(arms, p = prob_choosing_action) ###### DO NOT USE THIS ######
#     print(prob_choosing_action)
    actions = np.random.multinomial(1, prob_choosing_action)
    a = arms[actions.nonzero()[0][0]]
    return a, prob_choosing_action

def softmax_sticky(inv_temp, actions, arms, value, h, scaler):
    prob_choosing_action = np.exp((value*inv_temp)+ (h*scaler))
    prob_choosing_action/=np.sum(prob_choosing_action)
    actions = np.random.multinomial(1, prob_choosing_action)
    a = arms[actions.nonzero()[0][0]]
    return a, prob_choosing_action

# softmax biased action selection
def softmax_biased(inv_temp, actions, arms, value, bias):
    prob_choosing_action_biased = np.zeros(len(arms))

    for arm in range(len(arms)):
        prob_choosing_action_biased[arm] = (np.exp((value[arm]+bias[arm])*inv_temp))/(np.sum(np.exp((value+bias)*inv_temp)))
    
#     actions = np.random.choice(arms, p = prob_choosing_action) ###### DO NOT USE THIS ######
#         print(prob_choosing_action_biased, arm)

    actions = np.random.multinomial(1, prob_choosing_action_biased)
    a = arms[actions.nonzero()[0][0]]
#     print(prob_choosing_action_biased)
    return a, prob_choosing_action_biased

# softmax weighted bias action selection
def softmax_wbias(inv_temp, actions, arms, value, wbias, w):
    prob_choosing_action_wbias = np.zeros(len(arms))
    
    for arm in range(len(arms)):
        softmax = (np.exp(value[arm]*inv_temp)) / np.sum(np.exp(value*inv_temp))
        prob_choosing_action_wbias[arm] = (w*softmax) + ((1-w)*wbias[arm])
    actions = np.random.multinomial(1, prob_choosing_action_wbias)
    a = arms[actions.nonzero()[0][0]]

    return a, prob_choosing_action_wbias    

# Win-stay lose-shift 
def wsls(actions, reward, n_arms, shift_prob):
    if len(actions) == 0:
        actions = [np.random.randint(1,n_arms+1)]
    av_actions = list(1,range(n_arms+1))
    if reward==1:
        action = actions[-1]
    else:
        shift = np.random.uniform(0,1)
        if shift<=shift_prob:
            av_actions.pop(actions[-1])
            action = np.random.choice(av_actions)
        else:
            action = actions[-1]
    return action, 1


# upper confidence bound
def ucb(c, actions, arms, value):
    prob_choosing_action = np.zeros(len(arms))
    nt = np.zeros(n_arms)
    
    for arm in range(len(arms)):
        nt[arm] = actions.count(arm)
        if nt[arm]!=0:
            prob_choosing_action[arm] = value[arm] + (c*np.sqrt(np.log(len(actions)/nt[arm])))
        else:
            prob_choosing_action[arm] = 1
        
    action = np.random.choice(np.where((prob_choosing_action == np.amax(prob_choosing_action))[0]))
    return action, prob_choosing_action


#################### REGRET #######################
# regret at each timestep for minimization?
def regret(action, prob_arms):
    reg = max(prob_arms) - prob_arms[action-1] 
    return reg

################ GIVING REWARD #####################
def rewarding(prob, reward_val):
    temp = reward_val
    rand = np.random.uniform(0, 1)
    return temp if rand <= prob else 0

################# VANILLA VALUE UPDATION #############
def qlearn(value, action, alpha, reward):
    value[int(action)-1] = value[int(action)-1] + alpha * (reward - value[int(action)-1])
    return value 

################# DEVALUE OTHER ARMS #################
def qlearnAllArms(value, action, alpha, reward):
    for val in range(len(value)):
        if val == (int(action)-1):
            value[val] = value[val] + alpha[0] * (reward - value[int(action)-1])
        else:
            value[val] = value[val] + alpha[1] * (reward - value[int(action)-1])
    return value


############### BAYESIAN(?) VALUE UPDATION ###########
def bayesQlearn(value, action, alpha, reward):
    value[int(action)] = value[int(action)] + alpha * (reward - value[int(action)])
    return value

############## NEW MATRIX-WISE LEARNING ########### 
def qlearnAllMat(value, action, alpha, reward):
    for i in range(len(value)):
        value[i] = value[i] + alpha[int(action)-1, i] * (reward - value[i])
    return value


############# side-weighted softmax values - gaussian transformation? ##############
def convSoftmax(inv_temp, actions, arms, value, sd):
    prob_choosing_action = np.zeros(len(arms))
    value = gaussian_filter1d(value, sigma = sd)
    
    for arm in range(len(arms)):
        prob_choosing_action[arm] = (np.exp(value[arm]*inv_temp)) / np.sum(np.exp(value*inv_temp))

    # prob_choosing_action = gaussian_filter1d(prob_choosing_action, sigma = sd)
    actions = np.random.multinomial(1, prob_choosing_action)
    a = arms[actions.nonzero()[0][0]]
    return a, prob_choosing_action


In [None]:
def calc_prob(pk):
    # calc prob of actions
    unique, counts = np.unique(np.array(pk), return_counts =True)
    outcomes = len(pk)
    return counts/outcomes

In [7]:
# alphas = np.linspace(0,1, num=5)
# taus = np.logspace(-2,2, num=5)
# cs = np.linspace(0,0.5, num=5)
# fig = plt.figure(figsize = (15, 10))
# ind=1
 
# for i, alpha in enumerate(alphas):
#     for j, temperature in enumerate(taus):

############### ENVIRONMENT #######################
np.random.seed(4231)
n_arms = 4
arms = list(range(1,n_arms+1))

prob_arms = np.ones(n_arms)
rew_val = np.ones(n_arms)
eps = 0.2
alpha = 0.1
# acc to fit params - mean of all alpha and tau fitted = [0.09339148, 0.18939606] # 20250123
gamma = 0.2
c = 0.1
shift_prob = 1

# temperature = 0.18939606
tau = 0.1
inv_temp = 1/tau
sd = 0.8
# bias = [0., 0, 0.1, 0]
# wbias = np.array([0.1, 0.5, 0.3, 0.1])
# sum(wbias)

In [59]:
# w=0.5

trials = 100
sessions = 1000
window = 5

sess_mean_list = []
reward_hist = {}
rew_prob = {}
value_hist = {}
action_hist = {}
corrcoef_hist = {}
regret_hist = {}
lls= []
df = pd.DataFrame()
df['trial'] = " "
df['action'] = " "
df['reward'] = " "
df['session'] = " "
df['rewprob'] = " "
df['regret'] = " "


################## run here ##############
chance_level_sess = []
# h = np.zeros(n_arms)


# x0 = np.array([0.15387731, 0.08234891, 0.0685517 , 0.05269941, 0.28824557, 0.33129077])
# x0 = np.array([0.15, -0.01, -0.08, -0.085, 0.06]) # structured
# x0 = np.array([0.17, -0.07, -0.04, -0.05, 0.07]) # unstructured
# x0 = np.array([0.3, -0.1, -0.1, -0.1, 0.1])
# x0 = np.array([0.3, 0.1, -0.1, -0.2, 0.1, 0.25, 1])
# alpha_diag, alpha_1diag, alpha_2diag, alpha_3diag, tau, sticky, scaler = x0
x0 = np.array([0.2, 0.0, 0, 0])
alpha_diag, alpha_1diag, alpha_2diag, alpha_3diag = x0
alpha = np.array([[alpha_diag, alpha_1diag, alpha_2diag, alpha_3diag],
                      [alpha_1diag, alpha_diag, alpha_1diag, alpha_2diag],
                      [alpha_2diag, alpha_1diag, alpha_diag, alpha_1diag],
                      [alpha_3diag, alpha_2diag, alpha_1diag, alpha_diag]])
inv_temp = 1/tau
value_arr = np.zeros((sessions, trials, n_arms))

for sess in range(sessions):

    actions = []
    if sess%3==0:
        q0 = 0.25*np.ones(n_arms)
    # h = np.zeros(n_arms)
    value=np.copy(q0)
    
    sess_mean = np.random.randint(1,n_arms+1) # check randint docs for details - returns number between 1 and 8
    sess_mean_list.append(sess_mean)
    gx = fxn(sess_mean, n_arms, True)
    # l = [fxn(np.random.randint(1, n_arms+1), n_arms, False) for i in range(10000)]
    # alpha = np.corrcoef(np.array(l).T)*0.1
#     median = np.random.randint(1,n_arms+1)
#      gx = cauchy(median, n_arms)

    prob_arms = np.random.permutation(np.copy(gx))
    rew_prob[sess] = prob_arms
#     rew_val = np.copy(gx)

    chance_level_sess.append(np.mean(prob_arms, axis = 0))
    rew_temp = []
    value_temp = []
    corrcoef = []
    regrets = []
        
    reward = 0
    
    for trial in range(trials):
#         action, p = epsilon_greedy(eps, actions, n_arms, value)
#         action, p = wsls(actions, reward, n_arms, shift_prob)
#         action, p = ucb(c, actions, arms, value)
        action, p = softmax(inv_temp, actions, arms, value)
        # action, p = softmax_sticky(inv_temp, actions, arms, value, h, scaler)
#         action, p = softmax_biased(inv_temp, actions, arms, value, bias)
#         action, p = softmax_wbias(inv_temp, actions, arms, value, wbias, w)
        
        actions.append(action)
        reg = regret(action, prob_arms)
        regrets.append(reg)
        reward = rewarding(prob_arms[int(action)-1], rew_val[int(action)-1])
        rew_temp.append(reward)
        df.loc[len(df.index)] = [trial, action, reward, sess, prob_arms[int(action)-1], reg]
        # value = qlearn(value, action, alpha, reward)
        # value = qlearnAllArms(value, action, alpha, reward)
        value = qlearnAllMat(value, action, alpha, reward)
        # add perseverative term to chosen arm
        chosen = np.zeros(n_arms)
        chosen[int(action)-1] = 1
        # h = h + sticky*(chosen - h)
        
        # save values!!
        value_arr[sess, trial] = value
    # value_temp.append(value)
    # value_hist[sess] = value_temp
    regret_hist[sess] = regrets
    
plt.plot(df.groupby('session')['action'].get_group(0), 'o')

[<matplotlib.lines.Line2D at 0x1e84556b1c0>]

In [69]:
df['choice_t1'] = df.groupby('session').action.shift(-1)
df['choice_t2'] = df.groupby('session').action.shift(-2)
df['shift_t0'] = (df['choice_t1']==df['action']).replace({True: 0, False: 1})
df['shift_t1'] = (df['choice_t2']==df['action']).replace({True: 0, False: 1})
# df['rr'] = (df.groupby('session', as_index = False)
#             .reward
#             .rolling(window, center=True)
#             .mean()
#             .reward)
# from utils.supplementaryFunctions import calc_prob
# df['entropy'] = (df.groupby('session', as_index = False)
#                      .action
#                      .rolling(window, center=True)
#                      .apply(lambda x: entropy(calc_prob(x), base = 2))
#                      .action)

  df['shift_t0'] = (df['choice_t1']==df['action']).replace({True: 0, False: 1})
  df['shift_t1'] = (df['choice_t2']==df['action']).replace({True: 0, False: 1})


In [None]:
def data_prep(dataset, hist = 20, trialsinsess = 100, head = False):
    dataset = dataset.groupby(['session']).filter(lambda x: x.reward.size >= trialsinsess)
    dataset['ct0'] = dataset.action.values
    for i in range(1,hist): 
        dataset['ct'+str(i)] = dataset.groupby(['session']).action.shift(i) #previous action
        dataset['shift_t'+str(i-1)] = dataset['ct'+str(i)]==dataset['ct'+str(i-1)]
        dataset['shift_t'+str(i-1)] = dataset['shift_t'+str(i-1)].replace({True: 0, False: 1})
        dataset['rt'+str(i)] = dataset.groupby(['session']).reward.shift(i) # previous reward
        dataset['rt'+str(i)] = dataset['rt'+str(i)]#.replace({0:-1})
#         dataset['choice_t'+str(i)] = dataset['choice_t'+str(i)].replace({1:'a', 2:'b', 3:'c', 4:'d'})
    dataset = dataset.dropna()
    if head == True:
        dataset = dataset.groupby(['session']).head(trialsinsess)

    return dataset
hist = 6
data = data_prep(df, hist = hist, trialsinsess = 100, head = True)

q = value_arr.reshape(df.shape[0], 4)
r = data.reward.to_numpy()
r1 = data.rt1.to_numpy()
r2 = data.rt2.to_numpy()
r3 = data.rt3.to_numpy()
r4 = data.rt4.to_numpy()
#  get rpe
a = (data.action.to_numpy(dtype=int))-1
rpe = r - q[data.index, a]

data['rpe'] = rpe

# make lin reg model
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
d1 = np.abs(data.action - data.ct1)
d2 = np.abs(data.action - data.ct2)
d3 = np.abs(data.action - data.ct3)
coefs = np.zeros((4, 3))
intercepts = np.zeros((4, 3))
for trial_lag in range(3):
    r_arr = [r1, r2, r3]
    d_arr = [d1, d2, d3]
    for dist in range(4):
        y = rpe[(d_arr[trial_lag]==dist)]
        X = r_arr[trial_lag][(d_arr[trial_lag]==dist)].reshape(-1, 1)
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

        reg = LinearRegression(fit_intercept = True).fit(X_train, y_train)
        coefs[dist, trial_lag] = reg.coef_[0]
        print(reg.coef_, dist, trial_lag)
        intercepts[dist, trial_lag] = reg.intercept_

In [None]:
plt.figure()
plt.plot(coefs.T, 'o-')
plt.xticks([0, 1, 2], ['t-1', 't-2', 't-3'])
plt.legend([0, 1, 2, 3], title = 'distance')
plt.xlabel('trial history')
plt.ylabel('coefficient between reward t-n and rpe (slope)')
plt.title(f'RPE{[alpha_diag, alpha_1diag, alpha_2diag, alpha_3diag, tau, sticky, scaler]}')
# sns.despine()
plt.tight_layout()

In [35]:
# potentially plot everything for this model, rr, entropy, tm, regret, distance, bias analysis, variability 
%matplotlib qt
fig = plt.figure(figsize = (10, 7))

def avg_mat(df, col):
    g = df.groupby('session').cumcount()
    L = np.array(df.set_index(['session',g])
           .unstack(fill_value=0)
           .stack().groupby(level=0)
           .apply(lambda x: x[col].values.tolist())
           .tolist())
    return L


# figure 1 - regret across all sessions
ax = plt.subplot(221)
reg_mat = avg_mat(df, 'regret')
reg_mean = np.mean(reg_mat, axis = 0)
reg_sem = sem(reg_mat, nan_policy = 'omit')
ax.plot(reg_mean, color = 'xkcd:azure')
ax.fill_between(np.arange(reg_mat.shape[1]), reg_mean - reg_sem, reg_mean + reg_sem,  color = 'xkcd:azure', alpha = 0.2)
ax.set_title('Regret')

# figure 2 - performance plot across all sessions
ax = plt.subplot(222)
rr_mat = avg_mat(df, 'rr')
rr_mean = np.mean(rr_mat, axis = 0)
rr_sem = sem(rr_mat, nan_policy = 'omit')
ax.plot(rr_mean, color = 'xkcd:azure')
ax.fill_between(np.arange(rr_mat.shape[1]), rr_mean - rr_sem, rr_mean + rr_sem,  color = 'xkcd:azure', alpha = 0.2)
ax.set_title('Performance - reward rate')

# figure 3 - entropy plot across all sessions
ax = plt.subplot(223)
entropy_mat = avg_mat(df, 'entropy')
entropy_mean = np.mean(entropy_mat, axis = 0)
entropy_sem = sem(entropy_mat, nan_policy = 'omit')
ax.plot(entropy_mean, color = 'xkcd:azure')
ax.fill_between(np.arange(entropy_mat.shape[1]), entropy_mean - entropy_sem,
                 entropy_mean + entropy_sem,  color = 'xkcd:azure', alpha = 0.2)
ax.set_title('Entropy')

sns.despine()
from utils.plotSettings import *
parula = get_parula_cmap()
# figure 4 - transition matrix
ax = plt.subplot(224)
sns.heatmap(pd.crosstab(df.action, df.choice_t1, normalize = 'index'),
            cmap = parula, annot = True, fmt = '.2f', vmin = 0.0, vmax = 0.7, #mask = np.eye(4),
            xticklabels = np.arange(1,5), yticklabels = np.arange(1,5), ax = ax)
# ax.patch.set_facecolor('white')
ax.set_title('Transition matrix')

  .stack().groupby(level=0)
  .stack().groupby(level=0)
  rr_sem = sem(rr_mat, nan_policy = 'omit')
  .stack().groupby(level=0)
  entropy_sem = sem(entropy_mat, nan_policy = 'omit')


Text(0.5, 1.0, 'Transition matrix')

In [None]:
plt.figure()
sns.heatmap(alpha, cmap = parula_map, annot = True)

In [70]:
plt.figure()
sns.heatmap(pd.crosstab(df.action, df.choice_t1, normalize = 'index'), square = True,
            cmap = parula, vmin = 0, vmax = 0.5, xticklabels=np.arange(1,5), yticklabels= np.arange(1,5))
plt.xlabel('Choice at t+1')
plt.ylabel('Choice at t')
plt.tight_layout()
# ax.patch.set_facecolor('white')


# Gradient bandit algorithm

In [None]:
# setup env 
arms = 4
alpha = 0.1
trials = 100
sessions = 1
rew_val = 1
# prob_arms = fxn(2, 4)

# parameterized policy 
def policy(a, theta_arm):
    return (np.exp(theta_arm[a])/ np.sum(np.exp(theta_arm)))

# update policy
for session in range(sessions):
    # initialize
    rr = np.zeros(trials)
    theta_arm = np.zeros(arms)
    del_theta = np.zeros(arms)
    R_hist = np.zeros(trials)
    mean_p = np.random.randint(1, 5)
    prob_arms = fxn(mean_p, arms)
    
    for trial in range(trials):
        chosen = np.random.multinomial(1, [policy(a, theta_arm) for a in range(arms)]).nonzero()[0][0]
    
        R = rewarding(prob_arms[chosen], rew_val)
        
        # R_hist[trial] = R
    
        rr[trial] = np.nanmean(R_hist)
    
        del_theta = [(alpha*(1-policy(a, theta_arm))*(R - rr[trial])) 
                     if (chosen == a) 
                     else (-alpha*(policy(a, theta_arm))*(R - rr[trial])) for a in range(arms)]
            
        theta_arm = [(theta_arm[a] + del_theta[a]) for a in range(arms)]
        
        print(trial, R, chosen, theta_arm[chosen])

        R_hist[trial] = R

# Thompson sampling variants

In [75]:
%matplotlib qt
from scipy.stats import beta
trials = 100
sessions = 1000
arms = 4

a = np.zeros((sessions, trials))
r = np.zeros((sessions, trials))
p_est = np.zeros((sessions, trials, arms), dtype = float)
rp_set = [fxn(np.random.randint(1, arms+1), arms, True) for i in range(10000)]
w_diag, w_1diag, w_2diag, w_3diag = 0.95, 0.8, 0.7, 0.6
w = np.array([[w_diag, w_1diag, w_2diag, w_3diag],
                      [w_1diag, w_diag, w_1diag, w_2diag],
                      [w_2diag, w_1diag, w_diag, w_1diag],
                      [w_3diag, w_2diag, w_1diag, w_diag]])
alpha0 = 1
beta0 = 1
alphas = np.ones(arms)*alpha0
betas = np.ones(arms)*beta0
# samples = np.zeros((arms, 500))

# c = 2
# w = 1
for sess in range(sessions):
    # reward probability changes here
    rp = rp_set[sess]
    alphas = np.ones(arms)*alpha0
    betas = np.ones(arms)*beta0
    # if sess%3==0:
    #     alphas = np.ones(arms)*alpha0
    #     betas = np.ones(arms)*beta0
    for t in range(trials):
        # draw mean of distribution from s+ alpha, f+beta
        # p_est[sess, t, arm] = (alphas[arm])/(alphas[arm]+betas[arm])   # beta prior
        # p_est[sess, t, arm] = ((alphas[arm]+alpha0)/(alphas[arm]+alpha0+betas[arm]+beta0))+(beta.std(alphas[arm]+alpha0, betas[arm]+beta0)*c)       # beta prior mean+ c*ucb
        samples = np.random.beta(alphas, betas, size = (500, arms))  # trying to get var estimate
        # p_est[sess, t, arm] = np.random.beta(alphas[arm]+alpha0, betas[arm]+beta0, size = 500)      # actual thompson sampling

        # draw arm using samples drawn from distr
        arm_choices = np.argmax(samples, axis=1)
        counts = np.bincount(arm_choices, minlength=arms)
        p_est[sess, t, :] = counts / counts.sum()
        # p_est[sess, t, :] = np.clip(p_est[sess, t, :], a_min=1e-6, a_max = 1)

        # a[sess, t] = np.random.choice(np.where(p_est[sess, t] == np.amax(p_est[sess, t]))[0])
        a[sess, t] = np.random.choice(np.arange(arms), p = p_est[sess, t, :])
        chosen = int(a[sess, t])
        # reward chosen arm
        r[sess, t] = rewarding(rp[int(a[sess, t])], 1)
        
        alphas = np.array([(alphas[arm]*w[arm, chosen]) for arm in range(arms)])
        betas = np.array([(betas[arm]*w[arm, chosen]) for arm in range(arms)])
        # alphas = np.clip(alphas, a_min=1, a_max = None)
        # betas = np.clip(betas, a_min = 1, a_max = None)
        alphas[chosen] = alphas[chosen]+r[sess, t]+alpha0
        betas[chosen] = betas[chosen]+(1-r[sess, t])+beta0
        
        # alphas[chosen] = (w*alphas[chosen])+r[sess, t]+alpha0
        # betas[chosen] = (w*betas[chosen])+(1-r[sess, t])+beta0

        

In [None]:
# ashesh's modified thompson sampling with learning rate l and forgetting rate w
alpha0 = 1
beta0 = 1
w = 0.95
l = 1
for sess in range(sessions):
    # reward probability changes here
    rp = rp_set[sess]
    if sess%3==0:
        alphas = np.ones(arms)*alpha0
        betas = np.ones(arms)*beta0
        s = np.zeros(arms)
        f = np.zeros(arms)

    for t in range(trials):
        # draw samples from distribution
        samples = np.random.beta(alphas, betas, size = (20, arms))  # trying to get var estimate

        # calc prob using samples
        arm_choices = np.argmax(samples, axis=1)
        counts = np.bincount(arm_choices, minlength=arms)
        p_est[sess, t, :] = counts / counts.sum()

        # select an arm using the probabilities
        a[sess, t] = np.random.choice(np.arange(arms), p = p_est[sess, t, :])
        chosen = int(a[sess, t])

        # reward chosen arm
        r[sess, t] = rewarding(rp[int(a[sess, t])], 1)

        # # update success/fail for all arms
        s = np.array([(w*x)+(l*r[sess, t]) if i==chosen else w*x for i, x in enumerate(s)])
        f = np.array([(w*x)+(l*(1-r[sess, t])) if i==chosen else w*x for i, x in enumerate(f)])

        # linked parameters
        # s = np.array([(w*x) + (l_plus[i, chosen]*r[sess, t])     + (l_minus[i, chosen]*(1-r[sess, t])) for i, x in enumerate(s)])
        # f = np.array([(w*x) + (l_plus[i, chosen]*(1-r[sess, t])) + (l_minus[i, chosen]*r[sess, t]) for i, x in enumerate(f)])

        # update alpha and beta
        alphas = alpha0 + s
        betas = beta0 + f

In [77]:
w_diag, w_1diag, w_2diag, w_3diag = 1, 0.95, 0.85, 0.7
l_plus = np.array([[w_diag, w_1diag, w_2diag, w_3diag],
                      [w_1diag, w_diag, w_1diag, w_2diag],
                      [w_2diag, w_1diag, w_diag, w_1diag],
                      [w_3diag, w_2diag, w_1diag, w_diag]])
w_diag, w_1diag, w_2diag, w_3diag = 0, 0.9, 0.8, 0.7
l_minus = np.array([[w_diag, w_1diag, w_2diag, w_3diag],
                      [w_1diag, w_diag, w_1diag, w_2diag],
                      [w_2diag, w_1diag, w_diag, w_1diag],
                      [w_3diag, w_2diag, w_1diag, w_diag]])

In [79]:
# plot behavior scatter in random subset of sessions
plt.plot(a[9:12].flatten(), '.', color = 'xkcd:cornflower')
plt.fill_between(np.arange(300),
                 np.repeat(np.argmax(rp_set[9:12], axis = 1)-0.5, 100), 
                 np.repeat(np.argmax(rp_set[9:12], axis = 1)+0.5, 100), 
                 alpha = 0.2, color = 'xkcd:cornflower')
plt.yticks([0, 1, 2, 3], [1,2,3,4])
plt.xlabel('Trials')
plt.ylabel('Choices')
plt.tight_layout()

In [80]:
# transition matrix

parula = get_parula_cmap()
sns.heatmap(pd.crosstab(pd.Series(a.flatten()).shift(-1), a.flatten(), normalize= 'index'),
            cmap = parula,
            square=True, annot=True,
            vmax = 0.25,
            xticklabels=np.arange(1,5), 
            yticklabels=np.arange(1,5))
plt.xlabel('Choice at trial t+1')
plt.ylabel('Choice at trial t')
plt.tight_layout()

In [90]:
alphas = np.array([25, 1, 1, 3])
betas = np.array([20, 1, 1, 2])
# draw expectation of each arm being selected (alp/ alp+beta)
q = alphas/(alphas+betas)
ucb = np.sqrt((alphas*betas)/(((alphas+betas)**2)*alphas+betas+np.ones(arms)))
print(q)
print(ucb)
# softmax prob of choosing actions
invtemp=1/tau
P = np.exp(invtemp*(q+c*ucb))
P = P/ np.sum(P) 
P

[0.55555556 0.5        0.5        0.6       ]
[0.09936019 0.40824829 0.40824829 0.2773501 ]


array([0.22592438, 0.17653742, 0.17653742, 0.42100078])

In [81]:
# reward rate
# plt.figure()
plt.plot(np.mean(r, axis = 0), color = 'xkcd:cornflower')
plt.fill_between(np.arange(trials), np.mean(r, axis = 0) - sem(r, axis = 0), np.mean(r, axis = 0) + sem(r, axis = 0), alpha = 0.3, color = 'xkcd:cornflower')
chance_level = np.mean(rp_set)
plt.axhline(chance_level, linestyle = '--', color = 'k')
plt.xlabel('Trials in session')
plt.ylabel('Reward rate')
plt.tight_layout()

In [82]:
# calculate switch-rate (trial-wise?)
# get all first blocks
switches = (np.diff(a, axis = 1)!=0).astype(int)
plt.plot(np.mean(np.concatenate([switches[range(0, 999, 3), :], switches[range(1, 1000, 3), :], switches[range(2, 1000, 3), :]], axis = 1), axis = 0), '.-', 
         color = 'xkcd:cornflower')
plt.axvline(99.5, color =  'grey', lw = 0.75, linestyle = '--')
plt.axvline(198.5, color =  'grey', lw = 0.75, linestyle = '--')
# plt.legend()
plt.ylabel('Switch rate')
plt.tight_layout()

In [None]:
# calculate switch-distance (trial-wise?)
# get all first blocks
switches = np.abs(np.diff(a, axis = 1))
# switches[switches==0]=np.nan
plt.plot(np.mean(np.concatenate([switches[range(0, 999, 3), :], switches[range(1, 1000, 3), :], switches[range(2, 1000, 3), :]], axis = 1), axis = 0), '.-', 
         color = 'xkcd:cornflower')
plt.axvline(99.5, color =  'grey', lw = 0.75, linestyle = '--')
plt.axvline(198.5, color =  'grey', lw = 0.75, linestyle = '--')
# plt.legend()
plt.ylabel('Switch distance')
plt.tight_layout()

In [85]:
# how many times the algo was wrong 
chosen_a = [np.argmax(p_est[sess, 99, :])- np.argmax(rp_set[sess]) for sess in range(sessions)]
sns.histplot(chosen_a, discrete = True, color = 'xkcd:cornflower')
plt.ylabel('Count')
plt.xlabel('Disp. by which wrong')
plt.tight_layout()

In [86]:
x = np.linspace(beta.ppf(0.01, 10, 10),
                beta.ppf(0.99, 10, 10), 100)
plt.figure()
for ax_n in range(arms):
    plt.subplot(1, arms, ax_n+1)
    rv = beta(alphas[ax_n], betas[ax_n])
    plt.plot(x, rv.pdf(x), 'k-', lw=2, label='frozen pdf')
    # plt.xlim(0,1)
plt.tight_layout()