In [10]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import trange
import os
import json
import pandas as pd
from banditpy.models import BanditTrainer2Arm


# Seed the global numpy and torch generators before any stochastic work so the
# training run can be repeated after a kernel restart.
# NOTE(review): assumes BanditTrainer2Arm draws from these global generators
# rather than a private one -- confirm against banditpy.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run sizes, named so they can be tuned in one place.
N_TRAIN_SESSIONS = 20000
N_TEST_SESSIONS = 200
TRIALS_PER_SESSION = 200

bt = BanditTrainer2Arm(
    n_train_sessions=N_TRAIN_SESSIONS,
    n_test_sessions=N_TEST_SESSIONS,
    trials_per_session=TRIALS_PER_SESSION,
)

# Train, keeping the returned frame for the plots in the next cell, then
# persist the trained model via the trainer.
df_train = bt.train(structured=True)
bt.save_model()

Starting training for 20000 sessions...
Training Session 100/20000, Avg Loss (last 100): 0.6087
Training Session 200/20000, Avg Loss (last 100): 0.1906
Training Session 300/20000, Avg Loss (last 100): 0.2419
Training Session 400/20000, Avg Loss (last 100): 0.1309
Training Session 500/20000, Avg Loss (last 100): 0.0967
Training Session 600/20000, Avg Loss (last 100): 0.2127
Training Session 700/20000, Avg Loss (last 100): 0.1857
Training Session 800/20000, Avg Loss (last 100): 0.0989
Training Session 900/20000, Avg Loss (last 100): 0.1897
Training Session 1000/20000, Avg Loss (last 100): 0.1606
Training Session 1100/20000, Avg Loss (last 100): 0.1051
Training Session 1200/20000, Avg Loss (last 100): 0.1696
Training Session 1300/20000, Avg Loss (last 100): 0.1244
Training Session 1400/20000, Avg Loss (last 100): 0.0925
Training Session 1500/20000, Avg Loss (last 100): 0.0961
Training Session 1600/20000, Avg Loss (last 100): 0.1116
Training Session 1700/20000, Avg Loss (last 100): 0.1292


In [13]:
from banditpy.core import Bandit2Arm
import matplotlib.pyplot as plt
from neuropy import plotting

# Evaluate the trainer on a fixed pair of arm reward probabilities.
df_test = bt.evaluate(reward_probs=[0.8, 0.2])

fig = plotting.Fig(6, 5)

# Panel 0: the loss history recorded by the trainer.
ax = fig.subplot(fig.gs[0])
ax.plot(bt.training_loss_history)
ax.set_ylabel("Training loss")
ax.set_title("Loss")

# Panels 1 and 2: performance curve for the training frame, then the test
# frame. Each panel is titled so the two curves can be told apart.
for i, (label, df) in enumerate([("Train", df_train), ("Test", df_test)]):
    ax = fig.subplot(fig.gs[i + 1])
    task = Bandit2Arm(
        probs=df.loc[:, ["arm1_reward_prob", "arm2_reward_prob"]].to_numpy(),
        choices=df["chosen_action"].to_numpy(),
        rewards=df["reward"].to_numpy(),
        session_ids=df["session_id"].to_numpy(),
    )

    ax.plot(task.get_performance())
    ax.set_title(label)
    ax.set_xlabel("Trials")
    ax.set_ylim(0.4, 1)
    ax.set_ylabel("Pr (High)")

Starting evaluation with fixed weights...
Model loaded from two_arm_task_model.pt
Evaluation Session 50/200 complete.
Evaluation Session 100/200 complete.
Evaluation Session 150/200 complete.
Evaluation Session 200/200 complete.
Evaluation complete.


In [None]:
# Sanity check: list the distinct action values present in the evaluation frame.
df_test.loc[:, "chosen_action"].unique()