In [None]:
from src.rejax.envs.bernoulli_bandit import BernoulliBandit, EnvParams
from rejax import get_algo
from rejax.evaluate import evaluate

import _pickle as pickle
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import yaml

In [None]:
# Location of the YAML file holding the per-algorithm bandit settings.
config_path = "./configs/custom/bandit.yaml"

# Parse the whole document; safe_load reads directly from the open stream.
with open(config_path) as config_file:
    config = yaml.safe_load(config_file)

In [None]:
# Experiment settings.
algo_name = "ucb"  # section of the YAML config; also the name given to get_algo
seed_id = 42  # seed of the root PRNG key
num_seeds = 10000  # number of independent runs (one PRNG key and one bandit each)
num_arms = 5  # arms per bandit instance

# Split the root key *before* it is used anywhere. `key` is consumed later to
# sample the bandits' reward probabilities, so the per-run training keys must
# come from a separate stream; deriving `keys` from `key` itself and then
# sampling with `key` again would break JAX's single-use rule for PRNG keys.
key, train_key = jax.random.split(jax.random.PRNGKey(seed_id))
keys = jax.random.split(train_key, num_seeds)

# Keep only the section that belongs to the chosen algorithm.
# NOTE: this rebinds `config`, so the cell is not idempotent: re-run the
# config-loading cell before re-running this one, otherwise the lookup is done
# on the already-reduced section.
config = config[algo_name]

In [None]:
# Display the settings selected for the chosen algorithm.
config

In [None]:
# A single environment object is shared by every run; it is also passed to the
# algorithm through the config.
env = BernoulliBandit()
config["env"] = env

# One row of arm reward probabilities per run, drawn from Beta(0.2, 0.2).
env_params = jax.random.beta(key, a=0.2, b=0.2, shape=(num_seeds, num_arms))

algo_cls = get_algo(algo_name)


def _create_algo(reward_probs):
    """Create one algorithm instance for a bandit with the given `reward_probs`."""
    return algo_cls.create(
        **config,
        env_params=EnvParams(reward_probs=reward_probs),
    )


# Map the constructor over the leading axis of `env_params`, which yields a
# batched algorithm object with one entry per run.
algo = jax.vmap(_create_algo)(env_params)


def eval_callback(algo, ts, rng):
    """Evaluate the policy obtained from `algo.make_act(ts)`.

    Args:
        algo: Algorithm object; supplies `make_act` and `env_params`.
        ts: Training state handed to `algo.make_act`.
        rng: PRNG key forwarded to `evaluate`.

    Returns:
        Whatever `evaluate` returns for this policy on the module-level `env`.
    """
    policy = algo.make_act(ts)
    params = algo.env_params
    # 200 is passed positionally to `evaluate`; presumably the number of
    # evaluation episodes -- TODO confirm against rejax.evaluate.
    return evaluate(policy, rng, env, params, 200, params.max_steps_in_episode)

# Swap in the custom evaluation callback defined above.
algo = algo.replace(eval_callback=eval_callback)

In [None]:
# Show the reward probabilities stored in the batched algorithm object.
algo.env_params.reward_probs

In [None]:
# Jit-compile a training function that is vmapped over the leading axis of
# both the batched algorithm object and the per-run PRNG keys.
vmap_train = jax.jit(jax.vmap(algo_cls.train, in_axes=(0, 0)))

# Training yields the final state and a 2-tuple of evaluation outputs; only
# the second element (`returns`) is kept.
ts, (_, returns) = vmap_train(algo, keys)

# JAX dispatches work asynchronously, so wait here to make sure the cell only
# finishes once training has actually completed.
returns.block_until_ready()

# Report something meaningful instead of printing the literal `None`.
print(returns.shape)

In [None]:
# Per run: take the last evaluation and average its returns over the last axis.
jnp.mean(returns[:, -1], axis=-1)

In [None]:
# Index of the arm with the highest true reward probability, one per run.
np.argmax(env_params, axis=-1)

In [None]:
# Index of the largest entry of the agent's `counts` parameter along the last
# axis, transposed for display.
# NOTE(review): presumably the most frequently pulled arm per run -- confirm.
np.argmax(ts.agent_ts.params["params"]["counts"], axis=-1).T

In [None]:
# Index of the largest entry of the agent's `q_values` parameter along the
# last axis, transposed for display.
# NOTE(review): presumably the arm the agent estimates to be best -- confirm.
np.argmax(ts.agent_ts.params["params"]["q_values"], axis=-1).T

In [None]:
# Print the complete parameter tree of the trained agents.
# NOTE(review): this dumps the parameters of every run at once; consider
# inspecting a slice instead if the output becomes unwieldy.
jax.debug.print("{x}", x=ts.agent_ts.params)

In [None]:
# Collect everything needed to analyse this experiment later.
learning_history = {
    "buffer": ts.store_buffer.__dict__,
    "algorithm": {
        "algo": algo_name,
        # The environment object is recorded by class name under "env", so it
        # is left out of the stored hyperparameters.
        **{k: v for k, v in config.items() if k != "env"},
    },
    "env": type(config["env"]).__name__,
    "env_params": env_params,
}

# Write through a context manager so the file is flushed and closed even if
# pickling fails; the bare `open(...)` argument used before was never closed.
with open("learning_hist-{}.pkl".format(algo_name), "wb") as history_file:
    pickle.dump(learning_history, history_file)

In [None]:
# NOTE(review): this always raises AssertionError. It looks like a deliberate
# stop so that "Run All" does not continue into the plotting cell below --
# confirm the intent before removing it.
assert 0

In [None]:
# Move the returns to host memory once; iterating over the JAX array directly
# would slice and transfer it run by run.
returns_np = np.asarray(returns)

# The x-axis is identical for every run, so build it a single time:
# evaluation index scaled by the evaluation frequency.
eval_steps = np.arange(returns_np.shape[1]) * config["eval_freq"]

# Mean and standard error over the last axis, for all runs at once.
mean_returns = returns_np.mean(axis=-1)
sem_returns = returns_np.std(axis=-1) / np.sqrt(returns_np.shape[-1])

# One curve per run, with a +/- one standard error band around it.
for run_mean, run_sem in zip(mean_returns, sem_returns):
    plt.plot(eval_steps, run_mean)
    plt.fill_between(eval_steps, run_mean - run_sem, run_mean + run_sem, alpha=0.2)
plt.xlabel("Environment steps")
plt.ylabel("Mean evaluation return")
plt.show()
