In [1]:
from src.rejax.envs.bernoulli_bandit import BernoulliBandit, EnvParams
from rejax import get_algo
from rejax.evaluate import evaluate

import _pickle as pickle
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import yaml

In [2]:
# Location of the bandit experiment YAML. This is an absolute path on the
# author's machine and has to be adjusted when running anywhere else.
config_path = "/Users/chanb/research/ualberta/sandbox/rejax/configs/custom/bandit.yaml"

# Parse the YAML directly from the open file handle.
with open(config_path, "r") as config_file:
    config = yaml.safe_load(config_file)

In [3]:
# Experiment settings.
algo_name = "ucb"
seed_id = 42
num_seeds = 10000
num_arms = 5

# Keep only the section of the loaded YAML that belongs to `algo_name`.
# NOTE: this rebinds `config`, so re-running this cell without re-running the
# loading cell indexes into the already-narrowed dict.
config = config[algo_name]

# Root PRNG key plus one derived key per independent run.
key = jax.random.PRNGKey(seed_id)
keys = jax.random.split(key, num=num_seeds)

In [4]:
# Display the configuration section selected for `algo_name`.
config

{'agent_kwargs': {'confidence': 1.0},
 'total_timesteps': 10000,
 'eval_freq': 50,
 'buffer_size': 10000}

In [5]:
# Environment instance; it is also stored in the config dict that is unpacked
# into `algo_cls.create` below.
env = BernoulliBandit()
config["env"] = env

# One row of arm reward probabilities per run, drawn from Beta(0.2, 0.2).
# NOTE(review): `key` is the same key that was split into `keys` in the setup
# cell, so it is consumed twice -- confirm this reuse is intended.
env_params = jax.random.beta(key, a=0.2, b=0.2, shape=(num_seeds, num_arms))

algo_cls = get_algo(algo_name)


def _create_algo(reward_probs):
    """Build one algorithm instance for a single row of reward probabilities."""
    return algo_cls.create(**config, env_params=EnvParams(reward_probs=reward_probs))


# Map the constructor over the leading axis of `env_params`.
algo = jax.vmap(_create_algo)(env_params)


def eval_callback(algo, ts, rng):
    """Evaluate the policy derived from train state `ts`.

    Builds the acting function with `algo.make_act(ts)` and hands it to
    `evaluate` together with `rng`, the notebook-level `env` and the
    algorithm's own `env_params`.

    The literal 200 is forwarded positionally to `evaluate`; it is presumably
    the number of evaluation episodes -- TODO confirm against
    `rejax.evaluate.evaluate`. The last argument is the episode step limit
    read from `algo.env_params.max_steps_in_episode`.
    """
    return evaluate(
        algo.make_act(ts),
        rng,
        env,
        algo.env_params,
        200,
        algo.env_params.max_steps_in_episode,
    )

# Swap in the custom evaluation callback on the (batched) algorithm object.
algo = algo.replace(eval_callback=eval_callback)

In [6]:
# Display the reward probabilities stored on the batched algorithm object.
algo.env_params.reward_probs

Array([[9.8228019e-01, 9.9999964e-01, 8.2774550e-01, 8.4435036e-03,
        9.0748203e-01],
       [3.3898753e-01, 4.3795776e-01, 5.6661258e-05, 1.4348644e-03,
        3.3623093e-01],
       [9.9965811e-01, 9.3900967e-01, 3.0559276e-03, 4.0207192e-02,
        1.8510675e-02],
       ...,
       [2.0241037e-01, 9.8987877e-01, 1.5433159e-04, 4.8612203e-02,
        9.1597050e-01],
       [9.9443638e-01, 9.9909961e-01, 9.8552716e-01, 1.1944144e-05,
        9.8501843e-01],
       [9.9912339e-01, 4.8892325e-01, 1.0000000e+00, 8.4167169e-03,
        6.0977685e-01]], dtype=float32)

In [7]:
# Map `train` over the leading axis of both arguments (the stacked algorithm
# object and the per-run keys), then JIT-compile the batched function.
batched_train = jax.vmap(algo_cls.train, in_axes=(0, 0))
vmap_train = jax.jit(batched_train)

ts, (_, returns) = vmap_train(algo, keys)

# Wait until the training computation has actually finished.
returns.block_until_ready()

# NOTE(review): presumably here only so the cell does not end on the array
# expression above (which would echo it); it prints the literal text "None".
print(None)

  return lax_numpy.astype(self, dtype, copy=copy, device=device)


None


In [8]:
# Per run: average `returns` over its last axis, then keep the final entry
# along the remaining per-run axis (i.e. the last recorded evaluation).
jnp.mean(returns, axis=-1)[:, -1]

Array([1.        , 0.47      , 1.        , ..., 0.98499995, 1.        ,
       1.        ], dtype=float32)

In [9]:
# Index of the arm with the highest sampled reward probability in each run.
jnp.argmax(env_params, axis=-1)

Array([1, 1, 0, ..., 1, 1, 2], dtype=int32)

In [10]:
# Arm with the largest `counts` entry in each run's final agent parameters,
# transposed so the runs lie along the last axis for display.
agent_counts = ts.agent_ts.params["params"]["counts"]
jnp.argmax(agent_counts, axis=-1).T

Array([[1, 1, 0, ..., 1, 1, 2]], dtype=int32)

In [11]:
# Arm with the largest `q_values` entry in each run's final agent parameters,
# transposed so the runs lie along the last axis for display.
agent_q_values = ts.agent_ts.params["params"]["q_values"]
jnp.argmax(agent_q_values, axis=-1).T

Array([[1, 1, 0, ..., 1, 1, 2]], dtype=int32)

In [12]:
# Print the full agent parameter tree of the final train state.
agent_params = ts.agent_ts.params
jax.debug.print("{x}", x=agent_params)

{'params': {'counts': Array([[[3188., 5757.,  273.,   17.,  765.]],

       [[ 788., 8222.,   77.,   81.,  832.]],

       [[8334., 1613.,   17.,   19.,   17.]],

       ...,

       [[  34., 8665.,   18.,   20., 1263.]],

       [[2661., 3013., 2247.,   16., 2063.]],

       [[4885.,   57., 4986.,   17.,   55.]]], dtype=float32), 'q_values': Array([[[0.9805517 , 1.        , 0.7948715 , 0.        , 0.9006534 ]],

       [[0.33629432, 0.44234985, 0.        , 0.01234568, 0.3401443 ]],

       [[0.9996283 , 0.93924385, 0.        , 0.05263158, 0.        ]],

       ...,

       [[0.29411766, 0.9893836 , 0.        , 0.05      , 0.9144896 ]],

       [[0.993987  , 0.9990052 , 0.98664886, 0.        , 0.9825496 ]],

       [[0.99938   , 0.491228  , 1.        , 0.        , 0.4727273 ]]],      dtype=float32), 'timesteps': Array([10000., 10000., 10000., ..., 10000., 10000., 10000.], dtype=float32)}}


In [14]:
# Bundle everything needed to reuse this run: the stored buffer's attributes,
# the algorithm name and config (minus the non-serialised `env` entry), the
# environment class name, and the sampled environment parameters.
learning_hist = {
    "buffer": ts.store_buffer.__dict__,
    "algorithm": {
        "algo": algo_name,
        **{k: v for k, v in config.items() if k != "env"},
    },
    "env": type(config["env"]).__name__,
    "env_params": env_params,
}

# Write through a context manager so the file is flushed and closed as soon as
# the dump finishes, instead of leaving an open handle for the garbage
# collector. The file is in pickle format even though it is named ".dill".
with open("learning_hist-{}.dill".format(algo_name), "wb") as out_file:
    pickle.dump(learning_hist, out_file)

In [None]:
# Learning curves: one line per run with a +/- one standard-error band.
# NOTE: this draws one curve per entry of `returns`, i.e. one per run.
returns_np = np.asarray(returns)

# x positions: evaluation index scaled by the evaluation frequency.
eval_steps = np.arange(returns_np.shape[1]) * config["eval_freq"]

# Statistics over the last axis, computed once for all runs instead of once
# per loop iteration. The band is the standard error of the mean
# (std / sqrt(n)), not the raw standard deviation.
mean_returns = returns_np.mean(axis=-1)
sem_returns = returns_np.std(axis=-1) / np.sqrt(returns_np.shape[-1])

fig, ax = plt.subplots()
for run_mean, run_sem in zip(mean_returns, sem_returns):
    ax.plot(eval_steps, run_mean)
    ax.fill_between(eval_steps, run_mean - run_sem, run_mean + run_sem, alpha=0.2)
ax.set_xlabel("Evaluation index x eval_freq")
ax.set_ylabel("Mean evaluation return")
plt.show()
