In [None]:
import logging

import dm_env
import wandb
from bsuite.baselines.base import Agent
from helx.rl import baselines

In [5]:
def run(
    agent: Agent,
    env: dm_env.Environment,
    num_episodes: int,
    eval_mode: bool = False,
    project: str = "dqn",
) -> Agent:
    """Run `agent` on `env` for `num_episodes` episodes, logging metrics to wandb.

    For every environment step a single wandb row is logged containing the
    episode index and the reward returned by `env.step`. When training, the row
    also holds the value returned by `agent.update` (as "Bellman MSE", when not
    None) and `agent.iteration`.

    Args:
        agent: agent exposing `select_action(timestep)`,
            `update(timestep, action, new_timestep)` and an `iteration` attribute.
        env: the `dm_env.Environment` to interact with.
        num_episodes: number of episodes to run.
        eval_mode: if True, actions are selected but `agent.update` is never called.
        project: wandb project name passed to `wandb.init` (default "dqn").

    Returns:
        The same `agent` object (mutated by `agent.update` when training).
    """
    # NOTE(review): the wandb run is never finished here; call `wandb.finish()`
    # after `run` if a clean run boundary is needed.
    wandb.init(project=project)
    logging.info(
        "Starting %s agent %s on environment %s.\nThe scheduled number of episode is %s",
        "evaluating" if eval_mode else "training",
        agent,
        env,
        num_episodes,
    )
    for episode in range(num_episodes):
        # Carriage return keeps the progress indicator on a single output line.
        print(
            "Starting episode number {}/{}\t\t\t".format(episode, num_episodes - 1),
            end="\r",
        )
        timestep = env.reset()
        while not timestep.last():
            action = agent.select_action(timestep)
            # NOTE(review): the env is stepped with `tuple(action)`, which assumes
            # `select_action` returns an iterable — TODO confirm for scalar actions.
            new_timestep = env.step(tuple(action))
            # Gather all per-step metrics and log them in ONE call: each
            # `wandb.log` call advances wandb's step counter by default, so
            # separate calls would spread a single env step over several rows.
            metrics = {"Episode": episode, "Reward": new_timestep.reward}
            if not eval_mode:
                loss = agent.update(timestep, action, new_timestep)
                if loss is not None:
                    metrics["Bellman MSE"] = float(loss)
                metrics["Iteration"] = agent.iteration
            wandb.log(metrics)
            timestep = new_timestep
    return agent

In [None]:
from functools import partial
from typing import NamedTuple, Tuple

import dm_env
import jax
# import jax.numpy as jnp
# from bsuite.baselines.base import Action, Agent
# from dm_env import specs
# from jax.experimental import stax
# from jax.experimental.optimizers import OptimizerState, rmsprop_momentum