## Code/Variables setup

In [1]:
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
import seaborn as sns

import util

In [2]:
########## Simulation parameters ##########
eta = 0.01          # Learning rate
n_trials0 = 1000    # Number of trials in the stationary environment
n_trials1 = 1000    # Number of trials after decreasing/increasing the reward probability
n_trials = n_trials0 + n_trials1
trials0 = list(range(n_trials0))
trials = list(range(n_trials))
GAMMA = 0.35        # Discount rate, for TD errors
REWARD_DELTA = 0.25 # Maximum quantity by which to increase or decrease the reward probability when it's time for volatility.

########## Reproducibility ##########
seed = 2021
rng = None          # We call `rng = np.random.default_rng(seed=seed)` at the start of each major section.

########## Output ##########
# Every figure and the saved session are written under this directory. Create it up front so that
# Restart & Run All on a fresh checkout does not fail at the first `savefig` with FileNotFoundError.
import os
os.makedirs('integrators', exist_ok=True)

########## Plotting config ##########
# NOTE: 'text.usetex': True renders all text through LaTeX, so a working LaTeX installation is required.
sns.set_theme(
    style='white',
    palette='tab10',
    font_scale=1.15,
    rc={'figure.autolayout': True, 'figure.dpi': 300, 'figure.figsize': (8, 8),
        # 'font.family': 'Fira Code', 'font.weight': 'light',
        'font.family': 'serif',
        'text.latex.preamble': r'\usepackage{amsmath}', 'text.usetex': True,  # for \text command
        'xtick.bottom': True, 'ytick.left': True
    }
)
HIDE = True
def hideplt(fig=None):
    """Close a figure so it is not embedded in this notebook's output when HIDE is True.

    Args:
        fig: the figure to close. None (the default) closes the current figure, as before.
    """
    if HIDE:
        plt.close(fig)

In [3]:
def calc_rpes_for_prob(reward_prob, discount_gamma, decrease):
    """Simulate one single-state learner and return its per-trial RPEs and rewards.

    Rewards are Bernoulli(reward_prob) draws made with the shared module-level generator `rng`. At trial
    `n_trials0` the reward probability is shifted by `REWARD_DELTA` (down if `decrease`, up otherwise)
    and clipped to [0, 1]. The value estimate Q starts at 0.5 and is updated with learning rate `eta`
    from the prediction error  rpe = reward + discount_gamma * Q - Q  (discount_gamma=0 gives a
    Rescorla-Wagner error; this notebook passes GAMMA for the TD variant).

    Reads the globals `n_trials`, `n_trials0`, `eta`, `REWARD_DELTA` and `rng`; `rng` must already be a
    generator (it is None until a section reseeds it).

    Args:
        reward_prob: initial reward probability in [0, 1].
        discount_gamma: discount applied to the bootstrapped value term.
        decrease: whether the reward probability goes down (True) or up (False) at trial `n_trials0`.

    Returns:
        (rpes, rewards): a float array and an integer 0/1 array, each of length `n_trials`.
    """
    # `np.int` was only a deprecated alias for the builtin `int` and was removed in NumPy 1.24
    # (it raises AttributeError there); use the builtin directly.
    rewards = np.zeros(n_trials, dtype=int)
    rpes = np.zeros(n_trials)
    Q = 0.5  # Sort of a neutral prior, the mean of 0 and 1. Just helps convergence happen slightly more quickly.
    for i in range(n_trials):
        if i == n_trials0:
            # Volatility: shift the reward probability once, staying inside [0, 1].
            if decrease:
                reward_prob = max(0, reward_prob - REWARD_DELTA)
            else:
                reward_prob = min(1, reward_prob + REWARD_DELTA)
        rewards[i] = int(rng.random() < reward_prob)
        rpes[i] = rewards[i] + discount_gamma * Q - Q
        Q += eta * rpes[i]
    return rpes, rewards

def calc_rpes_for_probs(reward_probs, discount_gamma, decrease):
    """Run `calc_rpes_for_prob` once for each reward probability, in order.

    The simulations run sequentially and share the same random stream, so the order of
    `reward_probs` matters for the sampled rewards.

    Returns:
        (rpes, rewards): two lists with one entry per element of `reward_probs`.
    """
    rpes, rewards = [], []
    for prob in reward_probs:
        rpe_trace, reward_trace = calc_rpes_for_prob(prob, discount_gamma=discount_gamma, decrease=decrease)
        rpes.append(rpe_trace)
        rewards.append(reward_trace)
    return rpes, rewards

# What do RPEs do?

An initial, exploratory look at the (asymptotic) properties of reward prediction errors (RPEs) in a stationary environment.

In [4]:
# Reseed so this exploratory section's draws are reproducible.
rng = np.random.default_rng(seed=seed)

# Nine reward probabilities from 1 down to 0 in steps of 0.125, one per panel of the 3x3 grids below.
# NOTE(review): a later cell rebinds `probs` to a 101-point grid, and the plotting helpers in the next two
# cells read the global `probs`; re-running those cells after that point pairs these 9 runs with wrong p values.
probs = np.linspace(1, 0, num=9)
# Both calls draw from the same `rng` stream (RW first, then TD); swapping them changes both sets of samples.
rpe_RW, reward_RW = calc_rpes_for_probs(probs, discount_gamma=0, decrease=True)
rpe_TD, reward_TD = calc_rpes_for_probs(probs, discount_gamma=GAMMA, decrease=True)

In [5]:
def plot_rpe_asymptote(rpes, type, reward_probs=None):
    """Scatter the stationary-block RPEs of up to nine simulations on a 3x3 grid.

    Each panel shows the RPEs of the first `n_trials0` trials of one simulation, with dotted reference
    lines at 1-p (green, drawn when p > 0) and -p (red, drawn when p < 1), which the legend labels as
    the positive and negative limits.

    Args:
        rpes: sequence of per-simulation RPE arrays aligned with `reward_probs` (at most nine).
        type: RPE flavour label used in the title, e.g. 'RW' or 'TD'.
        reward_probs: reward probability of each simulation. Defaults to the global `probs`; a later
            cell rebinds that global to a 101-point grid, so pass it explicitly to avoid depending on
            cell execution order.

    Returns:
        The matplotlib Figure.
    """
    if reward_probs is None:
        reward_probs = probs
    f, axs = plt.subplots(3, 3, sharex=True, sharey=True)
    axs = axs.flatten()

    for i, (rpe, p) in enumerate(zip(rpes, reward_probs)):
        ax = axs[i]
        sns.scatterplot(x=trials0, y=rpe[:n_trials0], linewidth=0, s=8, ax=ax)
        if p > 0:
            ax.axhline(y=1 - p, color='green', alpha=0.5, linestyle='dotted')
        if p < 1:
            ax.axhline(y=-1 * p, color='red', alpha=0.5, linestyle='dotted')
        ax.set_title(f'$p={p}$')
        ax.set_xlabel('Trial')
        ax.set_ylabel('RPE')
        ax.set_ylim([-1, 1])

    f.suptitle('Simulated ' + type + ' RPEs under the optimal policy\n' +
               '\\normalsize{Stationary environments with varying reward probability $p$}',
               x=0, y=1, horizontalalignment='left', verticalalignment='top')
    legend_stuff = [
        matplotlib.lines.Line2D([0], [0], color='green', alpha=0.5, linestyle='dotted', label='Positive Limit, $y=1-p$'),
        matplotlib.lines.Line2D([0], [0], color='red', alpha=0.5, linestyle='dotted', label='Negative Limit, $y=-p$')
    ]
    f.legend(handles=legend_stuff, loc='upper right', bbox_to_anchor=(1, 0.9835))
    return f

# One asymptote figure per RPE flavour.
# NOTE(review): both file names misspell "asymptote", and differently from each other. They are kept unchanged
# in case a manuscript already references them; rename both consistently if nothing depends on them.
for rpe_set, rpe_kind, out_path in [(rpe_RW, 'RW', 'integrators/rpe-asmpytote-rw.pdf'),
                                    (rpe_TD, 'TD', 'integrators/rpe-asmptote-td.pdf')]:
    f = plot_rpe_asymptote(rpe_set, rpe_kind)
    f.savefig(out_path)
    hideplt()

In [6]:
def plot_rpe_histogram(rpes, name, reward_probs=None):
    """Bar charts of the fraction of positive vs. negative RPEs in the stationary block, on a 3x3 grid.

    For each simulation, the proportions of strictly positive and strictly negative RPEs over the first
    `n_trials0` trials are drawn (an RPE of exactly 0 counts towards neither), with dotted reference
    lines at the reward probability p (green) and 1-p (red).

    Args:
        rpes: sequence of per-simulation RPE arrays aligned with `reward_probs` (at most nine).
        name: RPE flavour label used in the title, e.g. 'RW' or 'TD'.
        reward_probs: reward probability of each simulation. Defaults to the global `probs`; a later
            cell rebinds that global to a 101-point grid, so pass it explicitly to avoid depending on
            cell execution order.

    Returns:
        The matplotlib Figure.
    """
    if reward_probs is None:
        reward_probs = probs
    f, axs = plt.subplots(3, 3, sharex=True, sharey=True)
    axs = axs.flatten()

    for i, (rpe, p) in enumerate(zip(rpes, reward_probs)):
        ax = axs[i]
        stationary = np.asarray(rpe[:n_trials0])
        positives = np.sum(stationary > 0) / n_trials0
        negatives = np.sum(stationary < 0) / n_trials0
        # NOTE(review): seaborn >= 0.13 warns that passing `palette` without `hue` is deprecated.
        sns.barplot(x=['Positive', 'Negative'], y=[positives, negatives], palette=['green', 'red'], ax=ax)
        ax.axhline(y=p, color='green', alpha=0.5, linestyle='dotted')
        ax.axhline(y=1-p, color='red', alpha=0.5, linestyle='dotted')
        ax.set_title(f'$p={p}$')
        if i % 3 == 0:  # left column of the 3x3 grid
            ax.set_ylabel(f'Proportion over {n_trials0} trials')
        if i >= 6:      # bottom row of the 3x3 grid
            ax.set_xlabel('RPE Sign')

    f.suptitle('Density histogram of ' + name + ' RPE sign under the optimal policy\n' +
               '\\large{Stationary environments with varying reward probability $p$}',
               x=0, y=1, horizontalalignment='left', verticalalignment='top')
    legend_stuff = [
        matplotlib.lines.Line2D([0], [0], color='green', alpha=0.5, linestyle='dotted', label='Reward prob, $p$'),
        matplotlib.lines.Line2D([0], [0], color='red', alpha=0.5, linestyle='dotted', label='Non-reward prob, $1-p$')
    ]
    f.legend(handles=legend_stuff, loc='upper right', bbox_to_anchor=(1, 0.9835))

    return f

# One RPE-sign figure per RPE flavour; the flavour label doubles as the file-name suffix.
for rpe_set, rpe_kind in ((rpe_RW, 'RW'), (rpe_TD, 'TD')):
    f = plot_rpe_histogram(rpe_set, rpe_kind)
    f.savefig('integrators/rpesign-' + rpe_kind + '.pdf')
    hideplt()

# Deep Dives into Reward Integration

The remainder of the notebook produces two key figures:

* Figure of Surprise against reward probability, for each integrator
* "Ridgelines" figure of smoothed surprise against trial, across various reward probabilities, for each integrator

In [7]:
def rint_double_neg(rpes):
    """Two-timescale integrator of negative RPEs, returning the effective (unscaled) dbar per trial.

    The negative part of each RPE, min(0, rpe), is leakily accumulated by two integrators whose
    'past'/'pres' weights come from `util.shift_reward_integration` with past_shift -0.009 (`_f`) and
    +0.002 (`_s`). The suffixes presumably mean fast/slow; confirm in `util`. The output is their
    difference, clipped to be <= 0.

    Args:
        rpes: 1-D sequence of per-trial RPEs.

    Returns:
        Array of length len(rpes). The output is sized from the input rather than the global
        n_trials, so traces of any length work.
    """
    xis_f = util.shift_reward_integration(orig_weight_past=0.99, orig_weight_pres=1.5, past_shift=-0.009)
    xis_s = util.shift_reward_integration(orig_weight_past=0.99, orig_weight_pres=1.5, past_shift=0.002)
    n = len(rpes)
    dbar_f, dbar_s = np.zeros(n), np.zeros(n)
    for i, rpe in enumerate(rpes):
        neg = min(0, rpe)
        prev_dbar_f = 0 if i == 0 else dbar_f[i - 1]
        prev_dbar_s = 0 if i == 0 else dbar_s[i - 1]
        dbar_f[i] = xis_f['past'] * prev_dbar_f + xis_f['pres'] * neg
        dbar_s[i] = xis_s['past'] * prev_dbar_s + xis_s['pres'] * neg
    return np.minimum(0, dbar_f - dbar_s)

def rint_single_neg(rpes):
    """One-timescale leaky integrator of negative RPEs, returning the effective (unscaled) dbar per trial.

    dbar[i] = past * dbar[i-1] + pres * min(0, rpe[i]), with dbar[-1] = 0 and weights from
    `util.shift_reward_integration` (past_shift 0).

    Args:
        rpes: 1-D sequence of per-trial RPEs.

    Returns:
        Array of length len(rpes). The output is sized from the input rather than the global
        n_trials, so traces of any length work.
    """
    xis = util.shift_reward_integration(orig_weight_past=0.99, orig_weight_pres=1.5, past_shift=0)
    dbar = np.zeros(len(rpes))
    for i, rpe in enumerate(rpes):
        prev_dbar = 0 if i == 0 else dbar[i - 1]
        dbar[i] = xis['past'] * prev_dbar + xis['pres'] * min(0, rpe)
    return dbar

def rint_double_posneg(rpes):
    """Two-timescale integrator applied separately to positive and negative RPEs.

    Positive parts max(0, rpe) and negative parts min(0, rpe) are each accumulated by an `_f` and an
    `_s` leaky integrator (weights from `util.shift_reward_integration`, past_shift -0.009 / +0.002;
    presumably fast/slow, so confirm in `util`). For each sign the difference f - s is clipped towards
    that sign (>= 0 for positives, <= 0 for negatives). The output is the absolute value of their sum.

    Args:
        rpes: 1-D sequence of per-trial RPEs.

    Returns:
        Non-negative array of length len(rpes). The output is sized from the input rather than the
        global n_trials.
    """
    xis_f = util.shift_reward_integration(orig_weight_past=0.99, orig_weight_pres=1.5, past_shift=-0.009)
    xis_s = util.shift_reward_integration(orig_weight_past=0.99, orig_weight_pres=1.5, past_shift=0.002)
    n = len(rpes)
    dbar_f_pos, dbar_s_pos = np.zeros(n), np.zeros(n)
    dbar_f_neg, dbar_s_neg = np.zeros(n), np.zeros(n)
    for i, rpe in enumerate(rpes):
        pos, neg = max(0, rpe), min(0, rpe)
        if i == 0:
            prev_f_pos = prev_s_pos = prev_f_neg = prev_s_neg = 0
        else:
            prev_f_pos, prev_s_pos = dbar_f_pos[i - 1], dbar_s_pos[i - 1]
            prev_f_neg, prev_s_neg = dbar_f_neg[i - 1], dbar_s_neg[i - 1]
        dbar_f_pos[i] = xis_f['past'] * prev_f_pos + xis_f['pres'] * pos
        dbar_s_pos[i] = xis_s['past'] * prev_s_pos + xis_s['pres'] * pos
        dbar_f_neg[i] = xis_f['past'] * prev_f_neg + xis_f['pres'] * neg
        dbar_s_neg[i] = xis_s['past'] * prev_s_neg + xis_s['pres'] * neg
    dbar_pos = np.maximum(0, dbar_f_pos - dbar_s_pos)
    dbar_neg = np.minimum(0, dbar_f_neg - dbar_s_neg)
    return np.absolute(dbar_pos + dbar_neg)

def rint_single_posneg(rpes):
    """One-timescale leaky integrators for positive and negative RPEs, combined as |pos + neg|.

    Positive parts max(0, rpe) and negative parts min(0, rpe) are accumulated separately, each with
    weights from `util.shift_reward_integration` (past_shift 0). The output is the absolute value of
    their sum.

    Args:
        rpes: 1-D sequence of per-trial RPEs.

    Returns:
        Non-negative array of length len(rpes). The output is sized from the input rather than the
        global n_trials.
    """
    xis_pos = util.shift_reward_integration(orig_weight_past=0.99, orig_weight_pres=1.5, past_shift=0)
    xis_neg = util.shift_reward_integration(orig_weight_past=0.99, orig_weight_pres=1.5, past_shift=0)
    n = len(rpes)
    dbar_pos, dbar_neg = np.zeros(n), np.zeros(n)
    for i, rpe in enumerate(rpes):
        prev_dbar_pos = 0 if i == 0 else dbar_pos[i - 1]
        prev_dbar_neg = 0 if i == 0 else dbar_neg[i - 1]
        dbar_pos[i] = xis_pos['past'] * prev_dbar_pos + xis_pos['pres'] * max(0, rpe)
        dbar_neg[i] = xis_neg['past'] * prev_dbar_neg + xis_neg['pres'] * min(0, rpe)
    return np.absolute(dbar_pos + dbar_neg)

def rint_daw(rpes, rewards, reward_lr=0.05, rbar0=0.5):
    """Opponent-style integrator (labelled 'Daw et al. DA/5HT' in the plots): RPE minus average reward.

    rbar is an exponential moving average of the reward with learning rate `reward_lr`, starting from
    `rbar0` before the first trial. The output on each trial is rpe - rbar, using rbar after it has
    been updated with that trial's reward.

    Args:
        rpes: 1-D sequence of per-trial RPEs.
        rewards: per-trial rewards aligned with `rpes`.
        reward_lr: learning rate of the average-reward estimate.
        rbar0: initial average-reward estimate (default 0.5, the previously hard-coded value).

    Returns:
        Array of length len(rpes). The output is sized from the input rather than the global
        n_trials. Entries beyond len(rewards) stay 0.
    """
    dbar = np.zeros(len(rpes))
    rbar = rbar0
    for i, (rpe, reward) in enumerate(zip(rpes, rewards)):
        rbar = reward_lr * reward + (1 - reward_lr) * rbar
        dbar[i] = rpe - rbar
    return dbar

In [8]:
# Reseed so the integrator section is reproducible independently of the exploratory section above.
rng = np.random.default_rng(seed=seed)
# 101 reward probabilities 0.00, 0.01, ..., 1.00. This rebinds the 9-point `probs` of the exploratory section.
probs = np.linspace(0, 1, num=101)

# RPEs for experiments where the reward probability decreases
# Each np.array(...) stacks 30 runs of the (rpes list, rewards list) pair returned by calc_rpes_for_probs,
# giving shape (30, 2, len(probs), n_trials); index 0 of axis 1 holds the RPEs and index 1 the rewards.
# NOTE(review): numpy presumably promotes the integer rewards to the RPEs' float dtype when stacking them.
# All four simulations below draw from the same `rng` stream in this order, so reordering them changes the results.
RW = np.array([calc_rpes_for_probs(probs, discount_gamma=0, decrease=True) for _ in range(30)])
rpes_RW, rewards_RW = RW[:, 0], RW[:, 1]
TD = np.array([calc_rpes_for_probs(probs, discount_gamma=GAMMA, decrease=True) for _ in range(30)])
rpes_TD, rewards_TD = TD[:, 0], TD[:, 1]

# RPEs for experiments where the reward probability increases
RW_inc = np.array([calc_rpes_for_probs(probs, discount_gamma=0, decrease=False) for _ in range(30)])
rpes_RW_inc, rewards_RW_inc = RW_inc[:, 0], RW_inc[:, 1]
TD_inc = np.array([calc_rpes_for_probs(probs, discount_gamma=GAMMA, decrease=False) for _ in range(30)])
rpes_TD_inc, rewards_TD_inc = TD_inc[:, 0], TD_inc[:, 1]

In [9]:
# Integrate every RPE trace with each reward-integration mechanism.
# Every result below has shape (runs, reward probs, trials).

def _integrate_runs(rint_fn, rpe_runs):
    """Apply an RPE-only integrator to every trace of every run."""
    return np.array([[rint_fn(trace) for trace in run] for run in rpe_runs])

def _integrate_runs_daw(rpe_runs, reward_runs):
    """Apply `rint_daw`, which also needs the reward traces, to every (RPE, reward) trace pair."""
    return np.array([[rint_daw(trace, reward_trace) for trace, reward_trace in zip(run, reward_run)]
                     for run, reward_run in zip(rpe_runs, reward_runs)])

# RW "effective dbars"
rint_d_n_all = _integrate_runs(rint_double_neg, rpes_RW)
rint_s_n_all = _integrate_runs(rint_single_neg, rpes_RW)
rint_d_pn_all = _integrate_runs(rint_double_posneg, rpes_RW)
rint_s_pn_all = _integrate_runs(rint_single_posneg, rpes_RW)
rint_daw_all = _integrate_runs_daw(rpes_RW, rewards_RW)

######################################################################################################################################
# TD "effective dbars"

rint_d_n_all_td = _integrate_runs(rint_double_neg, rpes_TD)
rint_s_n_all_td = _integrate_runs(rint_single_neg, rpes_TD)
rint_d_pn_all_td = _integrate_runs(rint_double_posneg, rpes_TD)
rint_s_pn_all_td = _integrate_runs(rint_single_posneg, rpes_TD)
rint_daw_all_td = _integrate_runs_daw(rpes_TD, rewards_TD)

######################################################################################################################################
# RW "effective dbars" where the reward prob increases rather than decreases

rint_d_n_all_inc = _integrate_runs(rint_double_neg, rpes_RW_inc)
rint_s_n_all_inc = _integrate_runs(rint_single_neg, rpes_RW_inc)
rint_d_pn_all_inc = _integrate_runs(rint_double_posneg, rpes_RW_inc)
rint_s_pn_all_inc = _integrate_runs(rint_single_posneg, rpes_RW_inc)
rint_daw_all_inc = _integrate_runs_daw(rpes_RW_inc, rewards_RW_inc)

######################################################################################################################################
# TD "effective dbars" where the reward prob increases rather than decreases

rint_d_n_all_inc_td = _integrate_runs(rint_double_neg, rpes_TD_inc)
rint_s_n_all_inc_td = _integrate_runs(rint_single_neg, rpes_TD_inc)
rint_d_pn_all_inc_td = _integrate_runs(rint_double_posneg, rpes_TD_inc)
rint_s_pn_all_inc_td = _integrate_runs(rint_single_posneg, rpes_TD_inc)
rint_daw_all_inc_td = _integrate_runs_daw(rpes_TD_inc, rewards_TD_inc)

In [10]:
# Check out the average dbar (effective, unscaled) for each reward integration mechanism,
# for each calculation of prediction error,
# in stationary environments across all reward probs.

def plot_rint_smooth(rint_all, name, type, window=100):
    """Scatter the converged mean effective dbar against reward probability for one integrator.

    The traces are averaged across runs, then over the last `window` trials of the stationary block
    (trials n_trials0-window .. n_trials0-1), giving one point per reward probability in the global `probs`.

    Args:
        rint_all: array (runs, reward probs, trials) of effective dbar traces.
        name: integrator name for the subtitle.
        type: RPE flavour label ('RW' or 'TD') for the title.
        window: number of trials before the volatility switch to average over (default 100).

    Returns:
        The seaborn FacetGrid.

    Raises:
        ValueError: if `window` is not in 1..n_trials0.
    """
    if not 0 < window <= n_trials0:
        raise ValueError(f'window must be in 1..{n_trials0}, got {window}')
    rint_all = np.asarray(rint_all)
    # The run count is taken from the data so the subtitle cannot drift from the actual number of runs.
    n_runs = rint_all.shape[0]
    # Average across runs, keeping the (reward_prob, trial) axes.
    rint_smooth = rint_all.mean(axis=0)
    means_by_rewardprob = rint_smooth[:, n_trials0 - window:n_trials0].mean(axis=-1)
    fg = sns.relplot(x=probs, y=means_by_rewardprob, linewidth=0, s=8, kind='scatter')
    fg.axes[0, 0].set_xlabel('Reward Probability')
    fg.axes[0, 0].set_ylabel(r'$\text{mean}(\text{effective }\bar \delta \text{, last ' + str(window) + r' trials})$')
    fg.fig.suptitle('Mean of effective $\\bar \\delta$ upon value convergence with ' + type + ' RPE\n' +
                    '\\normalsize{' + name + ' integrator, optimal policy, stationary environment}\n' +
                    '\\normalsize{Smoothed across ' + str(n_runs) + ' runs}',
                    x=0, y=1, horizontalalignment='left', verticalalignment='top')
    return fg

# (RW traces, TD traces, integrator label, file-name slug)
_converged_mean_jobs = [
    (rint_d_n_all, rint_d_n_all_td, 'Two-Timescale Negative', 'doubleneg'),
    (rint_s_n_all, rint_s_n_all_td, 'One-Timescale Negative', 'singleneg'),
    (rint_d_pn_all, rint_d_pn_all_td, 'Two-Timescale Pos/Neg', 'doubleposneg'),
    (rint_s_pn_all, rint_s_pn_all_td, 'One-Timescale Pos/Neg', 'singleposneg'),
    (rint_daw_all, rint_daw_all_td, 'Daw et al. DA/5HT', 'daw'),
]

######################################################################################################################################
# All RW figures first, then all TD figures; only the TD file names carry a '-TD' suffix.
for _kind, _suffix, _col in (('RW', '', 0), ('TD', '-TD', 1)):
    for _job in _converged_mean_jobs:
        plot_rint_smooth(_job[_col], _job[2], _kind).savefig('integrators/converged-mean-' + _job[3] + _suffix + '.pdf')
        hideplt()

In [11]:
# Heuristic to estimate the "window" of effective dbar to consider: 25% and 75% quantiles.
def quantile_rint_all(rint_all, q=(0.25, 0.75), n_first=250):
    """Quantiles of effective dbar pooled over the first `n_first` trials of every run and reward probability.

    Args:
        rint_all: array-like of shape (runs, reward probs, trials).
        q: quantile level(s) passed to np.quantile. The default is now an immutable tuple instead of a
            shared list; the values are unchanged.
        n_first: number of leading trials pooled from each trace (default 250, as before).

    Returns:
        np.ndarray of quantiles, one per entry of `q`.
    """
    pooled = np.asarray(rint_all)[:, :, :n_first].ravel()
    return np.quantile(pooled, q)

# 25%/75% quantiles of each integrator's early effective dbar, printed for inspection and
# used to choose the rescaling windows in the next cell.
q_d_n, q_s_n, q_d_pn, q_s_pn, q_daw = (
    quantile_rint_all(rint)
    for rint in (rint_d_n_all, rint_s_n_all, rint_d_pn_all, rint_s_pn_all, rint_daw_all)
)
print('RW', q_d_n, q_s_n, q_d_pn, q_s_pn, q_daw, sep='\n')

q_d_n_td, q_s_n_td, q_d_pn_td, q_s_pn_td, q_daw_td = (
    quantile_rint_all(rint)
    for rint in (rint_d_n_all_td, rint_s_n_all_td, rint_d_pn_all_td, rint_s_pn_all_td, rint_daw_all_td)
)
print('\nTD', q_d_n_td, q_s_n_td, q_d_pn_td, q_s_pn_td, q_daw_td, sep='\n')

RW
[-9.80224027 -3.07024636]
[-28.58822694 -10.01080247]
[1.36095501 6.24649194]
[ 4.90294435 16.69165636]
[-0.79165542 -0.20932016]

TD
[-8.08310697 -2.95642527]
[-23.38970872  -7.32741487]
[1.61127077 7.76406534]
[ 6.02918143 23.17503487]
[-0.68944206 -0.14724725]


In [12]:
# Piecewise-linear squashing into [0, 1] (a "linear sigmoid").
def window_rint(rint, bound0, bound1):
    """Map `bound0` to 0 and `bound1` to 1 linearly, clipping everything outside to [0, 1].

    Works elementwise on scalars or arrays. Passing bound0 > bound1 reverses the direction, so
    smaller inputs map closer to 1.
    """
    span = bound1 - bound0
    scaled = (rint - bound0) / span
    return np.clip(scaled, 0, 1)

######################################################################################################################################
# RW "scaled dbar" variables.
# `window_rint` is elementwise, so each whole (runs, probs, trials) array is rescaled in one call. This is
# equivalent to rescaling it run by run. The negative-valued integrators pass their (75%, 25%) quantiles in that
# order, so more negative dbar maps towards 1. The *_inc arrays reuse the quantiles estimated from the
# decreasing-reward runs, so both directions share one scale.

rint_d_n_all_scale = window_rint(rint_d_n_all, q_d_n[1], q_d_n[0])
rint_s_n_all_scale = window_rint(rint_s_n_all, q_s_n[1], q_s_n[0])
rint_d_pn_all_scale = window_rint(rint_d_pn_all, q_d_pn[0], q_d_pn[1])
rint_s_pn_all_scale = window_rint(rint_s_pn_all, q_s_pn[0], q_s_pn[1])  # todo ! fix.
rint_daw_all_scale = window_rint(rint_daw_all, 1, -2)  # NOTE not q_daw[1], q_daw[0]

rint_d_n_all_scale_inc = window_rint(rint_d_n_all_inc, q_d_n[1], q_d_n[0])
rint_s_n_all_scale_inc = window_rint(rint_s_n_all_inc, q_s_n[1], q_s_n[0])
rint_d_pn_all_scale_inc = window_rint(rint_d_pn_all_inc, q_d_pn[0], q_d_pn[1])
rint_s_pn_all_scale_inc = window_rint(rint_s_pn_all_inc, q_s_pn[0], q_s_pn[1])
rint_daw_all_scale_inc = window_rint(rint_daw_all_inc, 1, -2)  # NOTE not q_daw[1], q_daw[0]

######################################################################################################################################
# TD "scaled dbar" variables (same scheme, using the TD quantiles).

rint_d_n_all_scale_td = window_rint(rint_d_n_all_td, q_d_n_td[1], q_d_n_td[0])
rint_s_n_all_scale_td = window_rint(rint_s_n_all_td, q_s_n_td[1], q_s_n_td[0])
rint_d_pn_all_scale_td = window_rint(rint_d_pn_all_td, q_d_pn_td[0], q_d_pn_td[1])
rint_s_pn_all_scale_td = window_rint(rint_s_pn_all_td, q_s_pn_td[0], q_s_pn_td[1])  # todo ! fix.
rint_daw_all_scale_td = window_rint(rint_daw_all_td, 1, -2)  # NOTE not q_daw[1], q_daw[0]

rint_d_n_all_scale_inc_td = window_rint(rint_d_n_all_inc_td, q_d_n_td[1], q_d_n_td[0])
rint_s_n_all_scale_inc_td = window_rint(rint_s_n_all_inc_td, q_s_n_td[1], q_s_n_td[0])
rint_d_pn_all_scale_inc_td = window_rint(rint_d_pn_all_inc_td, q_d_pn_td[0], q_d_pn_td[1])
rint_s_pn_all_scale_inc_td = window_rint(rint_s_pn_all_inc_td, q_s_pn_td[0], q_s_pn_td[1])
rint_daw_all_scale_inc_td = window_rint(rint_daw_all_inc_td, 1, -2)  # NOTE not q_daw[1], q_daw[0]

In [13]:
def calc_surprise_1state(dbscale, *, I_SC=0.5, mu=1.0, sigma=0.05, cue=1.2):
    """Surprise (-log activation) of a one-dimensional Gaussian cue detector whose attention tracks dbscale.

    Attention w_A interpolates linearly from `I_SC` (dbscale = 0) to 1 (dbscale = 1). Both the observed
    cue and the prototype `mu` are scaled by w_A, and the activation is the N(mu*w_A, sigma) density
    at cue*w_A. The activation is floored at `util.TINYNONZERO` so the log stays finite.

    Args:
        dbscale: scalar or array of scaled dbar values (nominally in [0, 1]); evaluated elementwise.
        I_SC, mu, sigma, cue: model constants, now keyword-only parameters. The defaults are the values
            that were previously hard-coded. Changing `cue` appears to mainly rescale the output.

    Returns:
        Surprise values broadcast to the shape of `dbscale`.
    """
    # Agent "attention"
    w_A = (1 - dbscale) * I_SC + dbscale
    # Activation and surprise
    A = scipy.stats.norm.pdf(x=cue * w_A, loc=mu * w_A, scale=sigma)
    A = np.maximum(A, util.TINYNONZERO)  # avoid log(0) when the density underflows
    return -np.log(A)

def calc_surprise_2state(dbscale, generator=None):
    """Surprise (-log activation) of a two-dimensional Gaussian cue detector, evaluated elementwise over dbscale.

    Attention weights interpolate from I_SC = (0.99, 0.01) at dbscale = 0 to (1, 1) at dbscale = 1. The
    detector covariance is estimated from 2 x 100000 noisy samples (sd 0.05) around two prototypes,
    (1, 0) and (1, 1). Activation is the N(w*mu, Sigma) density at w*cue, floored at `util.TINYNONZERO`.

    Args:
        dbscale: array of scaled dbar values of any shape, e.g. (reward probs, trials).
        generator: np.random.Generator used for the covariance samples. Defaults to the module-level
            `rng`, so each default call advances that shared stream.

    Returns:
        Float array with the same shape as `dbscale`.
    """
    gen = rng if generator is None else generator
    dbscale = np.asarray(dbscale, dtype=float)
    # Agent "attention": one weight per cue dimension, shape dbscale.shape + (2,).
    I_SC = np.array([0.99, 0.01])
    d = dbscale[..., np.newaxis]
    w_As = (1 - d) * I_SC + d
    # Cue prototypes
    mu = np.array([1.0, 0.5])
    proto0 = np.array([1.0, 0])
    proto1 = np.array([1.0, 1])
    cues0 = proto0 + gen.normal(loc=0, scale=0.05, size=(100000, 2))
    cues1 = proto1 + gen.normal(loc=0, scale=0.05, size=(100000, 2))
    Sigma = np.cov(np.concatenate([cues0, cues1]), rowvar=False)
    cue = np.array([1.1, -0.1])
    # Activation and surprise. The Gaussian density depends only on x - mean = w*cue - w*mu = w*(cue - mu),
    # so every point is evaluated in one vectorised call instead of one scipy call per (prob, trial).
    A = scipy.stats.multivariate_normal.pdf(w_As * (cue - mu), mean=np.zeros(2), cov=Sigma)
    A = np.reshape(A, dbscale.shape)  # scipy squeezes singleton axes; restore the input shape
    A = np.maximum(A, util.TINYNONZERO)
    return -np.log(A)

######################################################################################################################################
# RW surprise variables.

def _surprise_runs(scaled_runs):
    """Surprise for every run of a (runs, probs, trials) scaled-dbar array, computed run by run."""
    return np.array([calc_surprise_1state(run) for run in scaled_runs])

surp_d_n_all = _surprise_runs(rint_d_n_all_scale)
surp_s_n_all = _surprise_runs(rint_s_n_all_scale)
surp_d_pn_all = _surprise_runs(rint_d_pn_all_scale)
surp_s_pn_all = _surprise_runs(rint_s_pn_all_scale)
surp_daw_all = _surprise_runs(rint_daw_all_scale)

surp_d_n_all_inc = _surprise_runs(rint_d_n_all_scale_inc)
surp_s_n_all_inc = _surprise_runs(rint_s_n_all_scale_inc)
surp_d_pn_all_inc = _surprise_runs(rint_d_pn_all_scale_inc)
surp_s_pn_all_inc = _surprise_runs(rint_s_pn_all_scale_inc)
surp_daw_all_inc = _surprise_runs(rint_daw_all_scale_inc)

# The main surprise-by-reward-prob figure uses only the first simulation run of each mechanism.
surp_d_n, surp_s_n, surp_d_pn, surp_s_pn, surp_daw = (
    runs[0] for runs in (surp_d_n_all, surp_s_n_all, surp_d_pn_all, surp_s_pn_all, surp_daw_all)
)

######################################################################################################################################
# TD surprise variables.

surp_d_n_all_td = _surprise_runs(rint_d_n_all_scale_td)
surp_s_n_all_td = _surprise_runs(rint_s_n_all_scale_td)
surp_d_pn_all_td = _surprise_runs(rint_d_pn_all_scale_td)
surp_s_pn_all_td = _surprise_runs(rint_s_pn_all_scale_td)
surp_daw_all_td = _surprise_runs(rint_daw_all_scale_td)

surp_d_n_all_inc_td = _surprise_runs(rint_d_n_all_scale_inc_td)
surp_s_n_all_inc_td = _surprise_runs(rint_s_n_all_scale_inc_td)
surp_d_pn_all_inc_td = _surprise_runs(rint_d_pn_all_scale_inc_td)
surp_s_pn_all_inc_td = _surprise_runs(rint_s_pn_all_scale_inc_td)
surp_daw_all_inc_td = _surprise_runs(rint_daw_all_scale_inc_td)

# The main surprise-by-reward-prob figure uses only the first simulation run of each mechanism.
surp_d_n_td, surp_s_n_td, surp_d_pn_td, surp_s_pn_td, surp_daw_td = (
    runs[0] for runs in (surp_d_n_all_td, surp_s_n_all_td, surp_d_pn_all_td, surp_s_pn_all_td, surp_daw_all_td)
)

In [14]:
def plot_summary_subplots(vals, names, t0, t1, valname, ylab, color_pt, reward_probs=None):
    """Per-mechanism summary of a quantity over a trial window: median and interquartile range.

    For each mechanism, `val[:, t0:t1]` (reward probs x trials) is summarised per reward probability
    by its median (scatter points in tab10 colour `color_pt`) and its 25-75% band (shaded, tab10
    colour 0). Panels fill a 3x2 grid; the sixth cell is reduced to a bare x-axis.

    Args:
        vals: list of arrays (reward probs, trials), one per mechanism (at most six).
        names: panel titles aligned with `vals`.
        t0, t1: trial window [t0, t1) to summarise.
        valname: figure title.
        ylab: y-axis label for each panel.
        color_pt: tab10 palette index for the median points.
        reward_probs: x values, one per row of each val. Defaults to the global `probs`.

    Returns:
        The matplotlib Figure.
    """
    if reward_probs is None:
        reward_probs = probs
    colors = sns.color_palette('tab10')
    f, axs = plt.subplots(3, 2, sharex=True, sharey=True)
    axs = axs.flatten()

    for i, (val, name) in enumerate(zip(vals, names)):
        window = val[:, t0:t1]
        q25, q75 = np.quantile(window, q=[0.25, 0.75], axis=-1)
        axs[i].fill_between(x=reward_probs, y1=q25, y2=q75, alpha=0.2, color=colors[0])
        sns.scatterplot(x=reward_probs, y=np.median(window, axis=-1), linewidth=0, s=8, color=colors[color_pt], ax=axs[i])
        axs[i].set_title(name)
        axs[i].set_xlabel('Reward Probability')
        axs[i].set_ylabel(ylab)

    # Turn the unused sixth panel into a bare x-axis so the shared bottom row still has an axis line.
    axs[-1].set_frame_on(False)
    axs[-1].get_yaxis().set_visible(False)
    xmin, xmax = axs[-1].get_xaxis().get_view_interval()
    ymin, _ = axs[-1].get_yaxis().get_view_interval()
    axs[-1].add_artist(matplotlib.lines.Line2D((xmin, xmax), (ymin, ymin), color='black', linewidth=2))
    axs[-1].set_xlabel('Reward Probability')

    # The median is drawn as scatter points, so the legend shows a point marker rather than a dotted line.
    legend_stuff = [
        matplotlib.lines.Line2D([0], [0], color=colors[color_pt], marker='o', markersize=4, linestyle='None', label='Median'),
        matplotlib.patches.Patch(color=colors[0], alpha=0.2, label='Interquartile Range')
    ]
    f.legend(handles=legend_stuff, loc='upper right', bbox_to_anchor=(1, 0.9835))
    f.suptitle(valname + '\n' +
               '\\normalsize{By reward probability $p$, per-mechanism}',
               x=0, y=1, horizontalalignment='left', verticalalignment='top')
    return f

In [15]:
# First simulation run of each integrator's scaled dbar, for the per-mechanism summary figures.
rint_d_n_0_scale = rint_d_n_all_scale[0]
rint_s_n_0_scale = rint_s_n_all_scale[0]
rint_d_pn_0_scale = rint_d_pn_all_scale[0]
rint_s_pn_0_scale = rint_s_pn_all_scale[0]
rint_daw_0_scale = rint_daw_all_scale[0]

# Titles use raw strings: the previous non-raw literals relied on the invalid escape sequence for 'delta'
# (a SyntaxWarning since Python 3.12). The rendered text is unchanged.
scaled = [rint_d_n_0_scale, rint_s_n_0_scale, rint_d_pn_0_scale, rint_s_pn_0_scale, rint_daw_0_scale]
names = ['Two-Timescale Negative', 'One-Timescale Negative', 'Two-Timescale Pos/Neg', 'One-Timescale Pos/Neg', 'Daw et al. DA/5HT']
f = plot_summary_subplots(scaled, names, t0=n_trials0-100, t1=n_trials0,
    valname=r'$\bar \delta_{\text{scaled}}$ upon value convergence in stationary environment, RW RPE',
    ylab=r'$\bar \delta_{\text{scaled}}$ last 100 trials', color_pt=1)
f.savefig('integrators/summary_dbscaled_stationary-RW.pdf')
hideplt()

rint_d_n_0_scale_td = rint_d_n_all_scale_td[0]
rint_s_n_0_scale_td = rint_s_n_all_scale_td[0]
rint_d_pn_0_scale_td = rint_d_pn_all_scale_td[0]
rint_s_pn_0_scale_td = rint_s_pn_all_scale_td[0]
rint_daw_0_scale_td = rint_daw_all_scale_td[0]

scaled = [rint_d_n_0_scale_td, rint_s_n_0_scale_td, rint_d_pn_0_scale_td, rint_s_pn_0_scale_td, rint_daw_0_scale_td]
f = plot_summary_subplots(scaled, names, t0=n_trials0-100, t1=n_trials0,
    valname=r'$\bar \delta_{\text{scaled}}$ upon value convergence in stationary environment, TD RPE',
    ylab=r'$\bar \delta_{\text{scaled}}$ last 100 trials', color_pt=1)
f.savefig('integrators/summary_dbscaled_stationary-TD.pdf')
hideplt()

In [16]:
names = ['Two-Timescale Negative', 'One-Timescale Negative', 'Two-Timescale Pos/Neg', 'One-Timescale Pos/Neg', 'Daw et al. DA/5HT']
# One stationary-surprise summary per RPE flavour (first simulation run of each mechanism).
for rpe_kind, surps in (('RW', [surp_d_n, surp_s_n, surp_d_pn, surp_s_pn, surp_daw]),
                        ('TD', [surp_d_n_td, surp_s_n_td, surp_d_pn_td, surp_s_pn_td, surp_daw_td])):
    f = plot_summary_subplots(surps, names, t0=n_trials0-100, t1=n_trials0,
        valname='Surprise upon value convergence in stationary environment, ' + rpe_kind + ' RPE',
        ylab='Surprise last 100 trials', color_pt=1)
    f.savefig('integrators/summary_surprise_1d_stationary-' + rpe_kind + '.pdf')
    hideplt()

In [17]:
# surp2_d_n = calc_surprise_2state(rint_d_n_0_scale)
# surp2_s_n = calc_surprise_2state(rint_s_n_0_scale)

# surp2_d_pn = calc_surprise_2state(rint_d_pn_0_scale)
# surp2_s_pn = calc_surprise_2state(rint_s_pn_0_scale)

# surp2_daw = calc_surprise_2state(rint_daw_0_scale)

# surps = [surp2_d_n, surp2_s_n, surp2_d_pn, surp2_s_pn, surp2_daw]
# names = ['Two-Timescale Negative', 'One-Timescale Negative', 'Two-Timescale Pos/Neg', 'One-Timescale Pos/Neg', 'Daw et al. DA/5HT']
# f = plot_summary_subplots(surps, names, t0=n_trials0-100, t1=n_trials0,
#     valname='Surprise upon value convergence in stationary environment',
#     ylab='Surprise last 100 trials', color_pt=1)
# f.savefig('integrators/summary_surprise_2d_stationary.pdf')

# f = plot_summary_subplots(surps, names, t0=n_trials0, t1=n_trials0+200,
#     valname='Surprise after reward volatility',
#     ylab='Surprise in trials after $p$ change', color_pt=3)
# f.savefig('integrators/summary_surprise_2d_volatile.pdf')

In [18]:
c = ['#d2e3f0', '#f6d3d4', '#d4ecd4']  # Corresponds to a few colors in the 'tab10' palette at 20% transparency

def surp_ridgelines(surp_all, ax, color_i, step=10, overlap=0.25):
    """Plot the run-averaged Surprise ridgeline figure into a matplotlib axis object.

    One ridge is drawn for every `step`-th reward probability of the global `probs` (11 ridges for the
    101-point grid). Each ridge shows surprise averaged across runs against trial. The stationary
    block (the first n_trials0 trials) is filled with c[0] and the trials after the switch with c[color_i].

    Args:
        surp_all: array (runs, reward probs, trials) of surprise values.
        ax: matplotlib Axes to draw into.
        color_i: index into `c` for the post-switch fill.
        step: stride through the reward probabilities (default 10, as before).
        overlap: fraction by which neighbouring ridges overlap vertically (default 0.25, as before).
    """
    # Plotting setup: ridge count follows len(probs) instead of a hard-coded 101/11.
    idx = np.arange(0, len(probs), step)
    prob_labels = np.round(probs[idx], decimals=2)
    # Calculate x, y values
    x_trials = np.array(trials)
    surp_smooth = np.mean(surp_all, axis=0)  # average across runs -> (reward probs, trials)
    curves = surp_smooth[idx]  # fancy indexing returns a copy, so the in-place edits below leave surp_smooth intact
    # Rescale all ridges together (one shared min and max) into [0, 1] so that their heights stay comparable.
    curves -= np.min(curves)
    span = np.max(curves)
    if span > 0:  # a constant surface would otherwise divide 0 by 0 and produce NaN ridges
        curves /= span
    offsets = np.arange(len(idx)) * (1 - overlap)  # y-value offset for each new curve
    curves += offsets[:, np.newaxis]
    # Ridgelines! And fills.
    for curve, offset in zip(curves, offsets):
        ax.fill_between(x=x_trials[:n_trials0], y1=curve[:n_trials0], y2=offset, color=c[0])
        ax.fill_between(x=x_trials[n_trials0:], y1=curve[n_trials0:], y2=offset, color=c[color_i])
        # Seaborn makes things a little slower, but it plays nicely with automatically hiding redundant features in our subplots
        sns.lineplot(x=x_trials, y=curve, color='black', alpha=0.8, linewidth=0.5, ax=ax)
    # Offset tickmarks by just above the "0" mark for each ridge
    ax.set_yticks(offsets + 0.1)
    ax.set_yticklabels(prob_labels)

def plot_ridgeline_subplots(vals, names, decrease, kind):
    """Lay out one surprise ridgeline panel per integrator on a 3x2 grid (the sixth cell is an axis stub).

    Args:
        vals: list of per-integrator surprise arrays, each (runs, reward probs, trials).
        names: panel titles aligned with `vals`.
        decrease: whether the reward probability decreased (True) or increased (False) at n_trials0.
            It selects the post-switch fill colour and the wording of the title and legend.
        kind: RPE flavour label ('RW' or 'TD') for the title.

    Returns:
        The matplotlib Figure.
    """
    f, axs = plt.subplots(3, 2, sharex=True, sharey=True)
    axs = axs.flatten()
    color_i = 1 if decrease else 2
    for i, (val, name) in enumerate(zip(vals, names)):
        surp_ridgelines(val, axs[i], color_i)
        axs[i].set_title(name)
        axs[i].set_xlabel('Trial')
        axs[i].set_ylabel('Initial $p$')

    # Turn the unused sixth panel into a bare x-axis so the shared bottom row still has an axis line.
    axs[-1].set_frame_on(False)
    axs[-1].get_yaxis().set_visible(False)
    xmin, xmax = axs[-1].get_xaxis().get_view_interval()
    ymin, _ = axs[-1].get_yaxis().get_view_interval()
    axs[-1].add_artist(matplotlib.lines.Line2D((xmin, xmax), (ymin, ymin), color='black', linewidth=2))
    axs[-1].set_xlabel('Trial')

    legend_stuff = [
        matplotlib.patches.Patch(facecolor=c[0], edgecolor='black', linewidth=0.5,
                                 label='Surprise in stationary environment'),
        matplotlib.patches.Patch(facecolor=c[color_i], edgecolor='black', linewidth=0.5,
                                 label='Surprise after ' + ('decrease' if decrease else 'increase') + ' in $p$')
    ]
    f.legend(handles=legend_stuff, loc='upper right', bbox_to_anchor=(1, 0.9835))
    # Run count comes from the data (first axis of the surprise arrays) rather than a hard-coded 30.
    n_runs = np.shape(vals[0])[0] if len(vals) else 0
    f.suptitle('Surprise in reward-volatile environment with ' + ('decreasing' if decrease else 'increasing') +
               ' rewards, ' + kind + ' RPE\n' +
               '\\normalsize{By reward probability $p$, per-mechanism}\n' +
               '\\normalsize{Smoothed across ' + str(n_runs) + ' runs}',
               x=0, y=1, horizontalalignment='left', verticalalignment='top')
    return f

In [19]:
names = ['Two-Timescale Negative', 'One-Timescale Negative', 'Two-Timescale Pos/Neg', 'One-Timescale Pos/Neg', 'Daw et al. DA/5HT']

# (reward decreased?, RPE flavour, per-integrator surprise arrays, output file), drawn in this order.
ridgeline_jobs = [
    (True, 'RW', [surp_d_n_all, surp_s_n_all, surp_d_pn_all, surp_s_pn_all, surp_daw_all],
     'integrators/ridgelines-surprise-dec-RW.pdf'),
    (False, 'RW', [surp_d_n_all_inc, surp_s_n_all_inc, surp_d_pn_all_inc, surp_s_pn_all_inc, surp_daw_all_inc],
     'integrators/ridgelines-surprise-inc-RW.pdf'),
    (True, 'TD', [surp_d_n_all_td, surp_s_n_all_td, surp_d_pn_all_td, surp_s_pn_all_td, surp_daw_all_td],
     'integrators/ridgelines-surprise-dec-TD.pdf'),
    (False, 'TD', [surp_d_n_all_inc_td, surp_s_n_all_inc_td, surp_d_pn_all_inc_td, surp_s_pn_all_inc_td, surp_daw_all_inc_td],
     'integrators/ridgelines-surprise-inc-TD.pdf'),
]
for is_decrease, rpe_kind, surps_all, out_path in ridgeline_jobs:
    f = plot_ridgeline_subplots(surps_all, names, decrease=is_decrease, kind=rpe_kind)
    f.savefig(out_path)
    hideplt()

# Session Load/Save

In [20]:
# Persist the whole interpreter session (every global, including the large simulation arrays) to disk,
# presumably so later work can reload it without re-running the simulations.
# NOTE(review): imported here rather than in the top import cell; move it up if this cell is always run.
import dill

dill.dump_session('integrators/session.pkl')
# To restore, uncomment the line below. Loading unpickles arbitrary code, so only load session files you created yourself.
# dill.load_session('integrators/session.pkl')