# imports and functions

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import sem, entropy, norm
import seaborn as sns
from parula_cmap import *
parula = get_parula_cmap()
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
sns.set_style("whitegrid")
sns.set_context("notebook")
def fxn(mean, arms, permute = False):
    x = np.linspace(1, arms, arms)
    sig = 1.75/2
#     amp = 1/(sig*np.sqrt(2*np.pi))
    amp = 0.7
    vo = 0.1
    gx = (amp*np.exp(-0.5*((x-mean)**2)/(sig**2)))+vo
    if permute:
        gx = np.random.permutation(gx)
    return gx

def f(x, a, b, c):
    '''exponential fxn: a is amplitude, b is decay rate, c is offset'''
    return a*np.exp(b*x) + c

def gaussian_policy_gradient(sessions, trials, a_mu, a_r, arms, a_, b_, c_, f):
    """
    Gaussian policy gradient for multi-armed bandit problem.
    
    Parameters:
    sessions (int): Number of sessions.
    trials (int): Number of trials per session.
    a_mu (float): Learning rate for action update.
    a_r (float): Learning rate for reward update.
    arms (int): Number of arms.
    a_ (float): Amplitude for the exponential function.
    b_ (float): Decay rate for the exponential function.
    c_ (float): Offset for the exponential function.
    
    Returns:
    mu, sigma, V, rp, r, a: numpy arrays containing the action values, state values, rewards, and actions.
    """
    p = np.zeros((arms, sessions, trials))
    a = np.zeros((sessions, trials))
    r = np.zeros((sessions, trials))
    rp = np.zeros((sessions, arms))
    mu = np.zeros((sessions, trials))
    V = np.zeros((sessions, trials))
    sigma = np.ones((sessions, trials))

    for s in np.arange(sessions):
        center = np.random.choice(np.arange(1, arms+1))
        rp[s] = fxn(center, arms, True)
        mu[s, 0] = np.random.choice(np.arange(1, arms+1))
        V[s, 0] = 0.25
        sigma[s, 0] = 0.5 

        for t in np.arange(trials):
            # calculate probability of actions
            p[:, s, t] = np.array([np.exp(-(i - mu[s, t])**2/(2*(sigma[s, t]**2))) for i in np.arange(1, arms+1)])
            p[:, s, t] = p[:, s, t]/np.sum(p[:, s, t])

            # sample action
            actions = np.random.multinomial(1, p[:, s, t])
            a[s, t] = np.arange(arms, dtype = int)[actions.nonzero()[0][0]]+1

            # get reward 
            rand = np.random.uniform(0, 1)
            r[s, t] = 1 if rand <= rp[s][int(a[s, t]) - 1] else 0

            # reward prediction error
            delta = r[s, t] - V[s, t]
            if t<trials-1:
                # action update
                mu[s, t+1] = mu[s, t] + (a_mu*delta*(a[s, t] - mu[s, t]))

                # calculate state value
                V[s, t+1] = V[s, t] + a_r*delta

                # use state value as sigma?
                # sigma[t+1] = np.exp(-V[t+1]*0.9)
                # sigma[s, t+1] = f(V[s, t+1], a_, b_, c_)
                sigma[s, t+1] = f(V[s, t+1], a_, b_, c_)
    
    return mu, sigma, V, rp, r, a

def plot_sigma_function(s):
    plt.figure()
    plt.plot(V[s, :], sigma[s, :], '.-',color = 'xkcd:cornflower')
    plt.xlabel('V')
    plt.ylabel('sigma')

def plot_histogram_choices(mu, ax):
    ax.hist(mu.flatten(), bins = 100, color = 'xkcd:cornflower')
    ax.set_xlabel('mu')
    ax.set_ylabel('counts')

def plot_mu_sigma(s, mu, sigma, rp):
    colors = plt.cm.cool_r(np.linspace(0, 1, 8))
    for i in range(0, 160, 20):
        x = np.linspace(mu[s, i] - 3*sigma[s, i], mu[s, i] + 3*sigma[s, i], 100)
        plt.plot(x, norm.pdf(x, mu[s, i], sigma[s, i]), color = colors[i//20])
    # x = np.linspace(mu[0] - 3*sigma, mu[0] + 3*sigma, 100)
    # plt.plot(x, stats.norm.pdf(x, mu[0], sigma))
    plt.axvline(mu[s, t], color = 'r', linestyle = '--')
    plt.axvline(np.argmax(rp[s])+1, color = 'g', linewidth = 40, alpha = 0.1)
    plt.xlim(0,4)

def calc_prob(pk):
    # calc prob of actions
    _, counts = np.unique(np.array(pk), return_counts =True)
    outcomes = len(pk)
    return counts/outcomes

def create_dataframe(mu, sigma, V, rp, r, a):
    df = pd.DataFrame()
    window = 5
    df['action'] = a.flatten()
    df['reward'] = r.flatten()
    df['session'] = np.repeat(np.arange(sessions), trials)
    df['trial'] = np.tile(np.arange(trials), sessions)
    df['mu'] = mu.flatten()
    df['V'] = V.flatten()
    df['sigma'] = sigma.flatten()
    df['rewprob'] = np.repeat(rp, trials, axis = 0)[np.arange(trials*sessions), np.array(a.flatten() - 1, dtype = int)]
    df['regret'] = abs(df['rewprob'] - 0.8)
    df['rr'] = (df.groupby('session', as_index = False)
                .reward
                .rolling(window, center=True)
                .mean()
                .reward)
    df['choice_t1'] = df.groupby('session').action.shift(-1, fill_value = 0)
    df.loc[(df.choice_t1 == 0), 'choice_t1'] = df.loc[(df.choice_t1 == 0), 'action']
    df['choice_t2'] = df.groupby('session').action.shift(-2, fill_value = 0)
    df.loc[(df.choice_t2 == 0), 'choice_t2'] = df.loc[(df.choice_t2 == 0), 'action']
    df['shift_t0'] = (df['choice_t1']==df['action']).replace({True: 0, False: 1})
    df['shift_t1'] = (df['choice_t2']==df['action']).replace({True: 0, False: 1})
    df['disp'] = df['choice_t1']-df['action']
    df['entropy'] = (df.groupby('session', as_index = False)
                        .action
                        .rolling(window, center=True)
                        .apply(lambda x: entropy(calc_prob(x), base = 2))
                        .action)
    return df

# PG discrete choice simulation

In [2]:
trials = 100
sessions = 1000
arms = 4
a_mu = 0.2
a_r = 0.2

def f(x, a, b, c):
    '''sigmoid fxn: a is amplitude, b is steepness, c is offset'''
    return a / (1. + np.exp(-b * (x - c)))

a_ = 0.7399
b_ = -4.856
c_ = 0.606

# a_ = 2.2597172731
# b_ = -0.47236397750
# c_ = -1.3944851128
# a_ = 2
# b_ = -2.3235334564429153
# c_ = -0.005890831236694547

mu, sigma, V, rp, r, a = gaussian_policy_gradient(sessions, trials, a_mu, a_r, arms, a_, b_, c_, f)

In [5]:
# comparing parameter variation 
a_ = 1
c_ = 0
pn = 1
x = np.linspace(0, 1, 100)

fig = plt.figure(figsize=(20, 10))
for b_ in np.linspace(-5, 0, 9):
    ax = plt.subplot(3, 3, pn)
    ax.set_title(f'b_ = {b_}')
    mu, sigma, V, rp, r, a = gaussian_policy_gradient(sessions, trials, a_mu, a_r, arms, a_, b_, c_)
    plot_histogram_choices(mu, ax)
    axin = inset_axes(ax, width="30%", height="20%", loc='upper right')
    axin.plot(x, f(x, a_, b_, c_), color='coral', lw = 2)
    # axin.set_ylim(0, 1)
    ax.set_xlim(0, 5)
    pn+= 1
plt.tight_layout()

  plt.tight_layout()


In [87]:
plot_sigma_function(2)
plt.figure()
ax = plt.subplot(111)
plot_histogram_choices(mu, ax)

In [3]:
df = create_dataframe(mu, sigma, V, rp, r, a)
# potentially plot everything for this model, rr, entropy, tm, regret, distance, bias analysis, variability 
%matplotlib qt
fig = plt.figure(figsize = (10, 7))

def avg_mat(df, col):
    g = df.groupby('session').cumcount()
    L = np.array(df.set_index(['session',g])
           .unstack(fill_value=0)
           .stack().groupby(level=0)
           .apply(lambda x: x[col].values.tolist())
           .tolist())
    return L

secret_sauce = 'xkcd:cornflower'
# figure 1 - regret across all sessions
ax = plt.subplot(221)
reg_mat = avg_mat(df, 'regret')
reg_mean = np.mean(reg_mat, axis = 0)
reg_sem = sem(reg_mat, nan_policy = 'omit')
ax.plot(reg_mean, color = secret_sauce)
ax.fill_between(np.arange(reg_mat.shape[1]), reg_mean - reg_sem, reg_mean + reg_sem,  color = 'xkcd:cornflower', alpha = 0.2)
ax.set_title('Regret')

# figure 2 - performance plot across all sessions
ax = plt.subplot(222)
rr_mat = avg_mat(df, 'rr')
rr_mean = np.mean(rr_mat, axis = 0)
rr_sem = sem(rr_mat, nan_policy = 'omit')
ax.plot(rr_mean, color = secret_sauce)
ax.fill_between(np.arange(rr_mat.shape[1]), rr_mean - rr_sem, rr_mean + rr_sem,  color = secret_sauce, alpha = 0.2)
ax.set_title('Performance - reward rate')

# figure 3 - entropy plot across all sessions
ax = plt.subplot(223)
entropy_mat = avg_mat(df, 'entropy')
entropy_mean = np.mean(entropy_mat, axis = 0)
entropy_sem = sem(entropy_mat, nan_policy = 'omit')
ax.plot(entropy_mean, color = secret_sauce)
ax.fill_between(np.arange(entropy_mat.shape[1]), entropy_mean - entropy_sem,
                 entropy_mean + entropy_sem,  color = secret_sauce, alpha = 0.2)
ax.set_title('Entropy')

sns.despine()

# figure 4 - transition matrix
ax = plt.subplot(224)
sns.heatmap(pd.crosstab(df.action, df.choice_t1, normalize = 'index'),
            cmap = parula, annot = True, fmt = '.2f', vmin = 0.0, vmax = 0.3, 
            # mask = np.eye(4),
            xticklabels = np.arange(1,5), yticklabels = np.arange(1,5), ax = ax)
# ax.patch.set_facecolor('white')
ax.set_title('Transition matrix')

Text(0.5, 1.0, 'Transition matrix')

In [4]:
# distance traveled in the first n trials
fig = plt.figure(figsize = (10, 5))

def calc_dist_metric(tempdf, mask):
    mean_distance = pd.Series([2, 4/3, 4/3, 2], index=[1, 2, 3, 4])

    filtered = tempdf[mask]

    # calculate `d_value` for each group
    d_values = filtered.groupby('session')['disp'].apply(lambda x: x.abs().mean())

    # calculate `d_chance` for each group
    d_chances = filtered.groupby('session').apply(
        lambda group: np.sum(group['action'].value_counts(normalize=True) * mean_distance)
    )

    tempdf['d_value'] = tempdf.set_index('session').index.map(d_values).values
    tempdf['d_chance'] = tempdf.set_index('session').index.map(d_chances).values
    return tempdf

for enum, trial_group in enumerate(range(0, 101, 10)):    
    if trial_group == 0:
        continue
    ax = plt.subplot(2, 5, enum)

    tempdf = df[(df.trial.isin(np.arange(trial_group-9, trial_group)))]
    # compute switch probability after outcome
    rdf = calc_dist_metric(tempdf, mask = (tempdf.reward == 1) & (tempdf.shift_t0 == 1))
    y = rdf.d_value.mean()
    yerr = rdf.d_value.sem()
    chan = rdf.d_chance.mean()
    ax.bar(0, chan, color = 'grey', alpha = 0.2)
    ax.bar(0, y, color = 'red', yerr = yerr, label = 'R')

    nrdf = calc_dist_metric(tempdf, mask = (tempdf.reward == 0) & (tempdf.shift_t0 == 1))
    y = nrdf.d_value.mean()
    yerr = nrdf.d_value.sem()
    chan = nrdf.d_chance.mean()
    ax.bar(1, chan, color = 'grey', alpha = 0.2)
    ax.bar(1, y, color = 'blue', yerr = yerr, label = 'NR')
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_xticks([0, 1], ['R', 'NR'], fontsize = 'large', color = 'grey')
    ax.set_ylim(1, 1.7)
    ax.set_title(f'Trials {trial_group-9} to {trial_group}')
# fig.supxlabel('Outcome at trial t')
plt.legend()
fig.supylabel('Average switch distance at trial t+1')
sns.despine()
plt.tight_layout()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  tempdf['d_value'] = tempdf.set_index('session').index.map(d_values).values
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  tempdf['d_chance'] = tempdf.set_index('session').index.map(d_chances).values
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  tempdf['d_value'] = tempdf.set_index('session').index

In [None]:
plt.figure(figsize = (4.5, 5))
ax = plt.subplot(111)
rdf = calc_dist_metric(df, mask = (df.reward == 1) & (df.shift_t0 == 1))
nrdf = calc_dist_metric(df, mask = (df.reward == 0) & (df.shift_t0 == 1))
y = rdf.d_value.mean()
yerr = rdf.d_value.sem()
chan = rdf.d_chance.mean()
ax.bar(0, chan, color = 'grey', alpha = 0.2)
ax.bar(0, y, color = 'red', yerr = yerr, label = 'R')
y = nrdf.d_value.mean()
yerr = nrdf.d_value.sem()
chan = nrdf.d_chance.mean()
ax.bar(1, chan, color = 'grey', alpha = 0.2)
ax.bar(1, y, color = 'blue', yerr = yerr, label = 'NR')
plt.xticks([0, 1], ['R', 'NR'], fontsize = 'large', color = 'grey')
plt.yticks(color = 'grey', fontsize = 'large')
plt.ylabel('Average switch distance', fontsize = 'x-large')
plt.tight_layout()

In [6]:
fig = plt.figure(figsize = (10, 5))

for enum, trial_group in enumerate(range(0, 101, 10)):
    
    if trial_group == 0:
        continue
    ax = plt.subplot(2, 5, enum)
    
    temperdf = df[(df.trial.isin(np.arange(trial_group-9, trial_group)))]
    # compute switch probability after outcome
    ax.bar(0, temperdf[(temperdf.reward == 1)].shift_t0.mean(),
           color = 'red', 
           yerr = temperdf[(temperdf.reward == 1)].shift_t0.sem(), 
           label = 'R')
    ax.bar(1, temperdf[(temperdf.reward == 0)].shift_t0.mean(), 
           color = 'blue', 
           yerr = temperdf[(temperdf.reward == 0)].shift_t0.sem(),
           label = 'NR')
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_xticks([0, 1], ['R', 'NR'], fontsize = 'large', color = 'grey')
    ax.set_ylim(0, 0.7)
    ax.set_title(f'Trials {trial_group-9} to {trial_group}')

# fig.supxlabel('Outcome at trial t')
plt.legend()
fig.supylabel('Switch probability at trial t+1')
sns.despine()
plt.tight_layout()

In [7]:
plt.figure(figsize = (4.5, 5))
# yerr = tempdf[(tempdf.reward == 1) & (tempdf.shift_t0 == 1)].d_value.mean()
plt.bar([0], tempdf[(tempdf.reward == 1)].shift_t0.mean(), color = 'red')
plt.bar([1], tempdf[(tempdf.reward == 0)].shift_t0.mean(), color = 'blue')
plt.xticks([0, 1], ['R', 'NR'], fontsize = 'large', color = 'grey')
plt.yticks(color = 'grey', fontsize = 'large')
plt.ylabel('Average switch probability', fontsize = 'x-large')
plt.tight_layout()

In [85]:
df['trial_group'] = pd.cut(df['trial']+1, bins = np.arange(0, 101, 10), labels = np.arange(1, 11))
sns.pointplot(data = df, x = 'trial_group', y = 'shift_t0',
            hue = 'reward', palette = ['xkcd:cornflower', 'xkcd:coral'], markers = '.', linestyles = '--', capsize = 0.25, errorbar = None, lw = 2.5)

<Axes: xlabel='trial_group', ylabel='shift_t0'>

In [103]:
# random shot
mat = pd.crosstab(df.action, df.choice_t1, normalize = 'index').to_numpy()
l =[]
for i in range(-3, 4):
    l.append(np.mean(np.diagonal(mat, offset = i)))
plt.bar(np.arange(-3, 4), l, color = 'xkcd:cornflower')
sns.despine()
def gaussian(x, mean, amplitude, standard_deviation):
    return amplitude * np.exp( - (x - mean)**2 / (2*standard_deviation ** 2))
from scipy.optimize import curve_fit
popt, _ = curve_fit(gaussian, np.arange(-3, 4), l, p0=[1., 0., 1.])
plt.plot(np.arange(-3, 4), gaussian(np.arange(-3, 4), *popt), color = 'xkcd:coral', linewidth = 2)
plt.ylim(-0.01)

(-0.01, 0.698131287976535)

# PG fitting

In [8]:
sessdf = pd.read_csv('L:/4portProb_processed/sessdf.csv')
sessdf.drop(columns = 'Unnamed: 0', inplace = True)
window = 7
exclude = ['[ 20  20  20 100]', '[0 0 0 0]', '[0]', '[0 0]',
       '[1000   80]', '[30]', '[40]', '[70]']
sessdf = sessdf[~sessdf.rewprobfull.isin(exclude)]
sessdf = sessdf[~sessdf.duplicated(subset = ['animal', 'session', 'trialstart', 'eptime'], keep = False)]
bin_size = 50
sessdf['sess_bin'] = sessdf.groupby(['animal', 'task'])['session'].transform(lambda x: pd.cut(x, bins=range(0, x.max() + bin_size, bin_size), labels=False, right=False)+1)

trialsinsess = 100
data = sessdf[(sessdf.sess_bin>=4) & (sessdf.task == 'unstr')]
data = data.groupby(['animal','session']).filter(lambda x: x.reward.size >= trialsinsess).groupby(['animal','session']).head(trialsinsess)

In [136]:
# import pgfittingFunctions as pgf
def fitgPG(x0, sessdf, arms):
    a_mu, a_r, a_, b_, c_, mu_init, V_init, sigma_init = x0
    ll = 0
    
    sessions = sessdf['session'].nunique()
    trials = 100 # automate later
    p = np.zeros((arms, sessions, trials))
    mu = np.zeros((sessions, trials))
    V = np.zeros((sessions, trials))
    sigma = np.ones((sessions, trials))
    P = np.zeros((sessions, trials))
    
    for s, (_, group) in enumerate(sessdf.reset_index().groupby('session')):
        mu[s, 0] = mu_init
        V[s, 0] = V_init
        sigma[s, 0] = sigma_init
        for t, (_, trial) in enumerate(group.iterrows()):
            p[:, s, t] = np.array([np.exp(-(i - mu[s, t])**2/(2*(sigma[s, t]**2))) for i in np.arange(1, arms+1)])
            p[:, s, t] = p[:, s, t]/np.sum(p[:, s, t])

            # which action on this trial
            a = trial['port']
            index = int(a-1)

            # probability of selected action on this trial
            P[s, t] = p[index, s, t]

            # rewarded?
            r = trial['reward']

            # reward prediction error
            delta = r - V[s, t]
            if t<trials-1:
                # action update
                mu[s, t+1] = mu[s, t] + (a_mu*delta*(a - mu[s, t]))

                # calculate state value
                V[s, t+1] = V[s, t] + a_r*delta

                # use state value as sigma?
                # sigma[t+1] = np.exp(-V[t+1]*0.9)
                # sigma[s, t+1] = f(V[s, t+1], a_, b_, c_)
                sigma[s, t+1] = f(V[s, t+1], a_, b_, c_)
            else:break


    ll += np.nansum(np.log(P))
    nll = -ll
    return nll
from pybads import BADS
extra_params = 4
n = 0
k = 8
# a_mu, a_r, a_, b_, c_, mu_init, V_init, sigma_init
plb = np.array([0, 0, 0, -100, -1, 0, 0, 1e-5])
pub = np.array([1, 1, 10, 0, 1, 5, 10, 10])
n_trials = data.shape[0]
options = {}
def fun_for_pybads(x):
    return fitgPG(x, data, extra_params)
# for n in range(n_optim):
    # run multiple optimizations
    # print('Running optimization ' + str(n) + '...')
options['random_seed'] = n
bads = BADS(fun_for_pybads, None, plausible_lower_bounds= plb, plausible_upper_bounds=pub, options=options)
optimize_result = bads.optimize()
print([optimize_result.x, optimize_result.fval, optimize_result.success])
nll = optimize_result.fval
bic = k*np.log(n_trials) + 2*nll

Detected fully unconstrained optimization.
Initial starting point is invalid or not provided. Initial point randomly sampled uniformly from plausible box



  p[:, s, t] = p[:, s, t]/np.sum(p[:, s, t])
  ll += np.nansum(np.log(P))


ValueError: FunctionLogger:InvalidFuncValue:
            The returned function value must be a finite real-valued scalar
            (returned value inf)

In [None]:
from scipy.optimize import differential_evolution as de

de(
    func = fitgPG,
    bounds = [(0, 1), (0, 1), (0, 10), (-100, 0), (-1, 1), (0, 5), (0, 10), (1e-5, 10)],
    args = (data, extra_params),
    maxiter = 100,
    disp = True, workers = 4
)
    

  with DifferentialEvolutionSolver(func, bounds, args=args,
