In [1]:
from src.rejax.envs.bernoulli_bandit import BernoulliBandit, EnvParams
from rejax import get_algo
from rejax.evaluate import evaluate

import jax
import jax.numpy as jnp
import numpy as np
import yaml

In [2]:
# NOTE(review): absolute, machine-specific path — this cell only runs on a
# checkout located at exactly this directory; consider a repo-relative path.
config_path = "/Users/chanb/research/ualberta/sandbox/rejax/configs/custom/bandit.yaml"

# Read as UTF-8 explicitly instead of relying on the platform default
# encoding, and let PyYAML parse directly from the file object rather than
# from an intermediate string.
with open(config_path, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

In [3]:
# Experiment selection: which section of the loaded YAML to use, the root
# seed, and how many PRNG keys to derive from it.
algo_name, seed_id, num_seeds = "dqn", 0, 2

# Split the single root key into `num_seeds` keys.
key = jax.random.PRNGKey(seed_id)
keys = jax.random.split(key, num=num_seeds)

# Keep only the section for the chosen algorithm.
# NOTE: this rebinds `config`, so re-running this cell on its own would index
# the already-narrowed dict with `algo_name` again; reload the YAML first.
config = config[algo_name]

In [4]:
# Display the configuration dict currently bound to `config` for inspection.
config

{'env': 'CartPole-v1',
 'agent': 'DuelingQNetwork',
 'agent_kwargs': {'activation': 'swish'},
 'num_envs': 1,
 'buffer_size': 50000,
 'fill_buffer': 5000,
 'batch_size': 100,
 'max_grad_norm': 10,
 'learning_rate': 0.001,
 'num_epochs': 5,
 'total_timesteps': 100000,
 'eval_freq': 2500,
 'polyak': 0.98,
 'eps_start': 1,
 'eps_end': 0.05,
 'exploration_fraction': 0.5,
 'gamma': 0.99,
 'ddqn': True,
 'normalize_observations': False}

In [5]:
# Replace whatever environment the config names with a bandit instance.
env = BernoulliBandit()
config["env"] = env

# Reward probabilities, one row per key in `keys`.
env_params = jnp.array([[0.5, 0.5], [0.1, 0.9]])

algo_cls = get_algo(algo_name)

# The mapped key is deliberately ignored: vmapping over `keys` is only used to
# give the created algorithm a leading batch axis of the same length.
algo = jax.vmap(lambda _key: algo_cls.create(**config))(keys)


def eval_callback(algo, ts, rng):
    """Evaluate the policy derived from train state `ts`.

    Builds an action function with `algo.make_act(ts)` and passes it to
    `evaluate` together with `rng`, the notebook-level `env`, and the
    algorithm's own `env_params`.

    NOTE(review): the literal 128 is presumably the number of evaluation
    episodes — confirm against `rejax.evaluate.evaluate`.

    Returns whatever `evaluate` returns.
    """
    return evaluate(
        algo.make_act(ts),
        rng,
        env,
        algo.env_params,
        128,
        algo.env_params.max_steps_in_episode,
    )


# Install the per-row reward probabilities and the custom evaluation callback.
algo = algo.replace(
    env_params=algo.env_params.replace(reward_probs=env_params),
    eval_callback=eval_callback,
)

In [6]:
# Display the object currently bound to `algo` for inspection.
algo

DQN(env=<src.rejax.envs.bernoulli_bandit.BernoulliBandit object at 0x12b51e950>, env_params=EnvParams(max_steps_in_episode=Array([1, 1], dtype=int32, weak_type=True), reward_probs=Array([[0.5, 0.5],
       [0.1, 0.9]], dtype=float32)), eval_callback=<function eval_callback at 0x12b410160>, eval_freq=2500, skip_initial_evaluation=False, total_timesteps=100000, learning_rate=Array([0.001, 0.001], dtype=float32, weak_type=True), gamma=Array([0.99, 0.99], dtype=float32, weak_type=True), max_grad_norm=Array([10, 10], dtype=int32, weak_type=True), normalize_rewards=False, reward_normalization_discount=0.99, normalize_observations=False, target_update_freq=1, polyak=Array([0.98, 0.98], dtype=float32, weak_type=True), num_envs=1, buffer_size=50000, fill_buffer=5000, batch_size=100, eps_start=Array([1, 1], dtype=int32, weak_type=True), eps_end=Array([0.05, 0.05], dtype=float32, weak_type=True), exploration_fraction=0.5, agent=EpsilonGreedyPolicy(
    # attributes
    hidden_layer_sizes = (64, 6

In [7]:
# Map `train` over the leading axis of both the algorithm object and the keys,
# then JIT-compile the batched function.
vmap_train = jax.jit(
    jax.vmap(
        algo_cls.train,
        in_axes=(0, 0),
    )
)

# The second element of the result is unpacked as a pair; only its second
# item (`returns`) is kept.
ts, (_, returns) = vmap_train(algo, keys)

# Wait for the asynchronous computation to finish before displaying the result.
returns.block_until_ready()

  return lax_numpy.astype(self, dtype, copy=copy, device=device)


Array([[[0., 0., 0., ..., 1., 0., 1.],
        [0., 1., 1., ..., 0., 1., 1.],
        [1., 0., 0., ..., 1., 0., 1.],
        ...,
        [1., 1., 0., ..., 1., 1., 0.],
        [0., 0., 1., ..., 0., 1., 0.],
        [0., 0., 1., ..., 1., 0., 1.]],

       [[1., 0., 1., ..., 0., 1., 1.],
        [1., 1., 1., ..., 1., 0., 0.],
        [1., 1., 1., ..., 0., 1., 1.],
        ...,
        [1., 0., 1., ..., 1., 1., 1.],
        [0., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.]]], dtype=float32, weak_type=True)

In [8]:
# One scalar per entry along the leading axis of `returns`: the mean over
# everything underneath that axis.
jax.vmap(jnp.mean)(returns)

Array([0.5062881 , 0.86509144], dtype=float32)