In [177]:
import numpy as np
from banditpy.core import Bandit2Arm
from banditpy.plots import plot_trial_by_trial_2Arm
from neuropy import plotting

# Number of Beta draws per arm that run_thomp takes on every trial to estimate
# how often each arm produces the larger sample.
n_sim = 500
# Candidate reward probabilities for the first arm; run_thomp assigns the
# second arm the complement (1 - p). The active list omits 0.1 and 0.5.
# Alternative evenly spaced grid (disabled):
# probs = np.arange(0.1, 1, 0.1)
probs = [0.2, 0.3, 0.4, 0.6, 0.7, 0.8, 0.9]


def run_thomp(
    delta_s,
    delta_f,
    tau,
    n_sessions=100,
    n_trials=100,
    arm_probs=None,
    n_draws=None,
    seed=None,
):
    """Simulate a two-armed bandit agent that uses Thompson sampling with decay.

    Every session starts from Beta(1, 1) beliefs for both arms. On each trial
    the agent draws ``n_draws`` samples per arm from the current Beta beliefs,
    takes the fraction of draws in which each arm is the larger one as its
    choice probability, and picks an arm accordingly. Both arms' accumulated
    evidence is then shrunk towards the Beta(1, 1) prior by ``tau`` before the
    outcome of the chosen arm is added.

    Parameters
    ----------
    delta_s : float
        Amount added to the chosen arm's alpha after a rewarded trial.
    delta_f : float
        Amount added to the chosen arm's beta after an unrewarded trial.
    tau : float
        Multiplier applied to ``alpha - 1`` and ``beta - 1`` on every trial
        (1 keeps all evidence, 0 resets to the prior).
    n_sessions : int, default 100
        Number of independent sessions to simulate.
    n_trials : int, default 100
        Number of trials per session.
    arm_probs : sequence of float, optional
        Candidate reward probabilities for arm 0; arm 1 always gets ``1 - p``.
        Defaults to the module-level ``probs``.
    n_draws : int, optional
        Posterior draws per arm per trial. Defaults to the module-level
        ``n_sim``.
    seed : int, optional
        If given, a private ``np.random.RandomState(seed)`` is used so the run
        is reproducible. If omitted, numpy's global random state is used.

    Returns
    -------
    choices : ndarray, shape (n_sessions * n_trials,)
        Chosen arm (0 or 1) on every trial.
    rewards : ndarray, shape (n_sessions * n_trials,)
        1 if the trial was rewarded, else 0.
    reward_probs : ndarray, shape (n_sessions * n_trials, 2)
        Reward probabilities of both arms for the session each trial is in.
    session_ids : ndarray, shape (n_sessions * n_trials,)
        Zero-based session index of every trial.
    """
    if arm_probs is None:
        arm_probs = probs
    if n_draws is None:
        n_draws = n_sim
    # np.random (module) and RandomState expose the same beta/choice/rand API,
    # so the unseeded path consumes the global stream in the same order.
    rng = np.random if seed is None else np.random.RandomState(seed)

    choices = []
    rewards = []
    reward_probs = []
    session_ids = []
    for session in range(n_sessions):
        p_first = rng.choice(arm_probs)
        session_probs = [p_first, 1 - p_first]
        alpha = np.ones(2)
        beta = np.ones(2)
        for _ in range(n_trials):
            # Monte Carlo estimate of P(arm 1 has the larger posterior sample).
            samples = rng.beta(alpha[:, None], beta[:, None], size=(2, n_draws))
            p_arm1 = np.argmax(samples, axis=0).mean()
            choice = rng.choice([0, 1], p=np.array([1 - p_arm1, p_arm1]))
            rewarded = rng.rand() < session_probs[choice]

            # Decay evidence for BOTH arms before adding this trial's outcome.
            alpha = 1.0 + (alpha - 1.0) * tau
            beta = 1.0 + (beta - 1.0) * tau
            if rewarded:
                alpha[choice] += delta_s
            else:
                beta[choice] += delta_f

            choices.append(choice)
            rewards.append(int(rewarded))
            session_ids.append(session)
            reward_probs.append(session_probs)

    return (
        np.array(choices, dtype=int),
        np.array(rewards, dtype=int),
        # reshape keeps the (n, 2) layout even when no trials were simulated
        np.array(reward_probs, dtype=float).reshape(-1, 2),
        np.array(session_ids, dtype=int),
    )


# Simulate one agent per (delta_s, delta_f, tau) setting and give each setting
# its own figure column: trial-by-trial plot in grid rows 0-2, optimal-choice
# probability curve in grid row 3.
params = [[7, 3, 0.5], [5, 5, 0.7], [6, 4, 0.8], [5, 8, 0.9]]
fig = plotting.Fig(8, 4, fontsize=10)

for i, (delta_s, delta_f, tau) in enumerate(params):
    choices, rewards, reward_probs, session_ids = run_thomp(delta_s, delta_f, tau)
    task = Bandit2Arm(
        probs=reward_probs,
        choices=choices,
        rewards=rewards,
        session_ids=session_ids,
    )
    perf = task.get_optimal_choice_probability()

    # Upper panel: trial-by-trial view of the simulated sessions.
    ax = fig.subplot(fig.gs[:3, i])
    plot_trial_by_trial_2Arm(task, ax=ax, sort_by_deltaprob=True)
    ax.set_title(f"deltaS={delta_s}, deltaF={delta_f}, tau={tau}")

    # Lower panel: probability of choosing the better arm across trials.
    ax2 = fig.subplot(fig.gs[3, i])
    ax2.plot(np.arange(100), perf, color="k")
    ax2.set(ylim=(0.4, 1.0), xlabel="Trial", ylabel="Pr(High)")
    ax2.grid(axis="y")

In [None]:
# Draw the trial-by-trial plot on its own (no `ax` is passed here).
# NOTE(review): `task` is not defined in this cell; it is whatever the loop in
# the previous cell assigned last, i.e. the final entry of `params`.
plot_trial_by_trial_2Arm(task, sort_by_deltaprob=True)

In [175]:
from scipy.stats import beta as beta_dist
import matplotlib.pyplot as plt
from neuropy import plotting

# Step-by-step updating of a Beta belief from simulated Bernoulli outcomes:
# one panel per step, each showing the belief held BEFORE that step's outcome.
prob = 0.3  # reward probability used to simulate outcomes (dashed line)
alpha = 1
beta = 1
x = np.linspace(0, 1, 100)

fig = plotting.Fig(20, 5)
for i in range(20):
    # Plot the current Beta(alpha, beta) density first ...
    pdf_values = beta_dist.pdf(x, alpha, beta)
    ax = fig.subplot(fig.gs[i])
    ax.fill_between(x, pdf_values, alpha=0.5, color="green")
    ax.set_ylim(0, 8)
    ax.axvline(prob, color="k", ls="--")

    # ... then simulate one outcome and count it as a success or a failure.
    rand_val = np.random.rand()
    if rand_val < prob:
        alpha += 1
    else:
        beta += 1

In [198]:
import seaborn as sns
import pandas as pd
from neuropy import plotting
from mab_colors import colors_2arm
from statplotannot.plots import fix_legend
from pathlib import Path

# data = {
#     "name": [
#         "GrumpExp1Unstruc",
#         "BratExp1Unstruc",
#         "ToothlessExp1Struc",
#         "BewilderbeastExp1Struc",
#         "GronckleExp1Struc",
#         "AuromaExp1Unstruc",
#         "AggroExp1Unstruc",
#     ],
#     "delta_s": [7.29, 2.30, 9.27, 3.07, 5.94, 7.16, 8.86],
#     "delta_f": [2.04, 2.08, 6.11, 0.51, 2.42, 0.51, 5.86],
#     "tau": [0.76, 0.86, 0.75, 0.96, 0.87, 0.8, 0.75],
#     "grp": ["unstruc", "unstruc", "struc", "struc", "struc", "unstruc", "unstruc"],
# }
# data = pd.DataFrame(data)

# Read the fitted parameter table, reshape it to long format (one row per
# subject and parameter) and keep the rows flagged first_experience == True.
fp = Path("D:/Data/mab/thomp_params_delta_tau_2.csv")
data = pd.read_csv(fp, sep=",").melt(
    id_vars=["sub_name", "grp", "first_experience"], var_name="param"
)
data = data.loc[data["first_experience"] == True]

fig = plotting.Fig(5, 4, fontsize=10)
ax = fig.subplot(fig.gs[1])

# Arguments common to the group-mean bars and the per-subject points.
shared_kw = dict(
    data=data,
    x="param",
    y="value",
    hue="grp",
    hue_order=["unstruc", "struc"],
    ax=ax,
)
sns.barplot(errorbar="se", palette=colors_2arm(1.2), **shared_kw)
sns.stripplot(dodge=True, palette=["gray", "gray"], size=5, **shared_kw)

ax.set(
    xlabel="",
    ylabel="Parameter value",
    title="Estimating Thompson parameters",
)
fix_legend(ax, loc="upper right")

In [192]:
# Display the long-format parameter table.
# NOTE(review): the saved output below comes from an earlier execution
# (In [192]) than the cell that builds `data` (In [198]); it still lists
# first_experience == False rows, so it may be stale - re-run to confirm.
data

Unnamed: 0,sub_name,grp,first_experience,param,value
0,AggroExp1Unstructured,unstruc,True,delta_s,8.86697
1,AuromaExp1Unstructured,unstruc,True,delta_s,7.168745
2,BratExp1Unstructured,unstruc,True,delta_s,2.30535
3,GronckleExp2Unstructured,unstruc,False,delta_s,3.455762
4,GrumpExp1Unstructured,unstruc,True,delta_s,7.296441
5,ToothlessExp2Unstructured,unstruc,False,delta_s,9.439173
6,BewilderbeastExp1Structured,struc,True,delta_s,3.077212
7,BuffalordExp1Structured,struc,True,delta_s,9.081098
8,GronckleExp1Structured,struc,True,delta_s,5.94163
9,GrumpExp2Structured,struc,False,delta_s,5.502178
