### Apply Q-learning parameters fitted in one environment to the other
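- Take the group-average Q-learning parameters (`alpha_chosen`, `alpha_unchosen`, `persev`, `scaler`, `beta`) fitted in one environment (e.g. structured) and use them to predict actions in the other environment (e.g. unstructured). Then compute the performance of the predicted actions and compare it with the actual performance.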

In [21]:
import numpy as np
import pandas as pd
from banditpy.analyses import QlearningEstimator
import mab_subjects
from banditpy.core import MultiArmedBandit

# First-exposure sessions: structured subjects first, then unstructured ones.
exps = mab_subjects.struc.first_exposure + mab_subjects.unstruc.first_exposure

# Per-subject fitted parameters, averaged within each (group, parameter) pair.
est_params_df = mab_subjects.GroupData().qlearning_2alpha_persev_anirudh
grouped_params = est_params_df.groupby(["grp", "param"])
mean_params_df = grouped_params.mean(numeric_only=True).reset_index()

# One frame of mean parameters per environment.
is_struc_row = mean_params_df["grp"] == "struc"
is_unstruc_row = mean_params_df["grp"] == "unstruc"
mean_struc = mean_params_df[is_struc_row]
mean_unstruc = mean_params_df[is_unstruc_row]

# Order in which the parameters are listed when a parameter vector is built.
param_order = ["alpha_chosen", "alpha_unchosen", "persev", "scaler", "beta"]

# Number of trials kept per session. The same value drives the trial filter
# and the length of the `trial_id` column, so they cannot drift apart.
N_TRIALS = 100


def _group_param_vector(group_means, order):
    """Stack one group's mean parameter values in the order given by `order`.

    Parameters
    ----------
    group_means : pandas.DataFrame
        Rows for a single group, with a "param" column naming the parameter
        and a "param_values" column holding its mean value.
    order : sequence of str
        Parameter names, in the order they should appear in the result.

    Returns
    -------
    numpy.ndarray
        Array of shape (len(order), 1), one row per parameter.

    Raises
    ------
    ValueError
        If a parameter does not have exactly one row in `group_means`.
    """
    rows = []
    for param_name in order:
        values = group_means.loc[
            group_means["param"] == param_name, "param_values"
        ].to_numpy()
        if values.size != 1:
            raise ValueError(
                f"expected exactly one mean value for {param_name!r}, "
                f"found {values.size}"
            )
        rows.append(values)
    return np.array(rows)


# The mean parameters depend only on the group, not on the subject, so build
# both vectors once instead of once per loop iteration.
struc_params = _group_param_vector(mean_struc, param_order)
unstruc_params = _group_param_vector(mean_unstruc, param_order)

perf_frames = []

for exp in exps:
    mab = exp.mab.filter_by_trials(min_trials=N_TRIALS, clip_max=N_TRIALS)
    name = exp.sub_name
    print(name)

    qlearn = QlearningEstimator(mab, bounds=None, model="persev")

    # Parameter switch: a structured subject is simulated with the mean
    # parameters of the unstructured group, and vice versa.
    params = unstruc_params if mab.is_structured else struc_params

    # Pass a copy so the shared per-group vector is never modified by the call.
    predicted_choices = qlearn.predict_choices(params=params.copy())
    # Recode predicted choice 0 as 2; choice 1 is kept as is.
    # NOTE(review): presumably this matches the 1/2 choice coding that
    # MultiArmedBandit expects -- confirm against banditpy.
    predicted_choices[predicted_choices == 0] = 2

    # Same task as the real session, but with the predicted choices
    # substituted for the subject's own.
    new_mab = MultiArmedBandit(
        probs=mab.probs,
        choices=predicted_choices,
        rewards=mab.rewards,
        session_ids=mab.session_ids,
        starts=mab.starts,
        stops=mab.stops,
        datetime=mab.datetime,
    )

    perf_frames.append(
        pd.DataFrame(
            {
                "name": name,
                "trial_id": np.arange(N_TRIALS) + 1,
                "perf": mab.get_performance(),
                "perf_new": new_mab.get_performance(),
                "grp": "struc" if mab.is_structured else "unstruc",
            }
        )
    )

perf_df = pd.concat(perf_frames, ignore_index=True)

mab_subjects.GroupData().save(perf_df, "perf_qlearning_switch_params")

BewilderbeastExp1Structured
BuffalordExp1Structured
GronckleExp1Structured
ToothlessExp1Structured
AggroExp1UnStructured
AuromaExp1Unstructured
BratExp1Unstructured
GrumpExp1Unstructured
perf_qlearning_switch_params saved


In [30]:
import matplotlib.pyplot as plt
import seaborn as sns
from neuropy import plotting
import mab_subjects
import numpy as np
from statannotations.Annotator import Annotator
from statplot_utils import stat_kw

fig = plotting.Fig(8, 4, size=(8.5, 11), num=1)

grpdata = mab_subjects.GroupData()
df = grpdata.perf_qlearning_switch_params

ax = fig.subplot(fig.gs[0])
hue_order = ["unstruc", "struc"]

# Styling shared by both sets of curves.
line_kw = dict(
    palette=["#f77189", "#36ada4"],
    err_kws=dict(edgecolor="none"),
    errorbar="se",
)

# Solid lines: actual performance ("perf").
# Dashed lines: performance of the choices predicted with the other group's
# mean parameters ("perf_new").
for y_col, extra_kw in (("perf", {}), ("perf_new", {"linestyle": "--"})):
    plot_kw = dict(
        data=df, x="trial_id", y=y_col, hue="grp", hue_order=hue_order, ax=ax
    )
    sns.lineplot(**line_kw, **extra_kw, **plot_kw)

# Both lineplot calls label their curves with the bare group names, so the
# default legend lists each group twice with nothing to tell actual from
# predicted. Relabel only when the handles match the expected layout
# (one set per call, in call order); otherwise keep seaborn's legend.
handles, labels = ax.get_legend_handles_labels()
n_groups = len(hue_order)
if len(handles) == 2 * n_groups:
    actual_labels = [f"{lab} (actual)" for lab in labels[:n_groups]]
    predicted_labels = [f"{lab} (switched params)" for lab in labels[n_groups:]]
    ax.legend(handles, actual_labels + predicted_labels, title="grp")

ax.set_title("Performance")
ax.set_ylabel("Performance")
ax.set_ylim(0.45, 1)
# The trailing semicolon keeps the list of XTick objects out of the cell output.
ax.set_xticks([1, 50, 100]);

[<matplotlib.axis.XTick at 0x1b66f0a3110>,
 <matplotlib.axis.XTick at 0x1b66d9265d0>,
 <matplotlib.axis.XTick at 0x1b66d925e50>]

In [None]:
import matplotlib.pyplot as plt

# Quick look at a single predicted-performance trace.
# NOTE: `new_mab` is not defined in this cell; it is whatever the prediction
# loop left behind, i.e. the last entry of `exps` that was processed.
fig_last, ax_last = plt.subplots()
ax_last.plot(new_mab.get_performance())
ax_last.set_xlabel("Trial index")
ax_last.set_ylabel("Performance")
ax_last.set_title("Predicted performance (last subject processed)");