In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from bsuite.baselines.base import Agent
import dm_env
from helx.rl import baselines
import gym
from gym_minigrid.wrappers import RGBImgPartialObsWrapper, ImgObsWrapper
from bsuite.utils.gym_wrapper import DMEnvFromGym
import wandb
import logging

In [3]:
def make(name):
    env = gym.make(name)
    env = RGBImgPartialObsWrapper(env)  # Get pixel observations
    env = ImgObsWrapper(env)  # Get rid of the 'mission' field
    env = DMEnvFromGym(env)  #  Convert to dm_env.Environment
    return env


def run(
    agent: Agent,
    env: dm_env.Environment,
    num_episodes: int,
    eval_mode: bool = False,
) -> Agent:
    wandb.init(project="dqn")
    logging.info(
        "Starting {} agent {} on environment {}.\nThe scheduled number of episode is {}".format(
            "evaluating" if eval_mode else "training", agent, env, num_episodes
        )
    )
    for episode in range(num_episodes):
        print(
            "Starting episode number {}/{}\t\t\t".format(episode, num_episodes - 1),
            end="\r",
        )
        wandb.log({"Episode": episode})
        # initialise environment
        timestep = env.reset()
        while not timestep.last():
            # policy
            action = agent.select_action(timestep)
            # step environment
            new_timestep = env.step(tuple(action))
            wandb.log({"Reward": new_timestep.reward})
            # update
            if not eval_mode:
                loss = agent.update(timestep, action, new_timestep)
                if loss is not None:
                    wandb.log({"Bellman MSE": float(loss)})
                wandb.log({"Iteration": agent.iteration})
            # prepare next
            timestep = new_timestep
    return agent

In [4]:
env = make("MiniGrid-Empty-8x8-v0")

In [5]:
dqn = baselines.dqn.Dqn(env.observation_spec(), env.action_spec(), baselines.dqn.HParams())

In [9]:
run(dqn, env, 100)

0,1
Episode,0
_runtime,3
_timestamp,1619522186
_step,0


0,1
Episode,▁
_runtime,▁
_timestamp,▁
_step,▁


Starting episode number 0/99			

AssertionError: unknown action