In [11]:
# Enable IPython's autoreload extension. With `%autoreload 1`, only the modules
# registered through `%aimport` are reloaded before a cell is executed.
%load_ext autoreload
%autoreload 1
# Project modules registered for automatic reloading.
%aimport earl.agents.r2d2.networks
%aimport earl.agents.r2d2.r2d2
%aimport earl.agents.r2d2.utils
%aimport earl.core

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
import math

import ale_py
import gymnasium
import jax
import jax.numpy as jnp
import numpy as np
from jax_loop_utils.metric_writers.torch import TensorboardWriter

from earl.agents.r2d2 import utils
import earl.agents.r2d2.networks as r2d2_networks
from earl.agents.r2d2.r2d2 import R2D2, R2D2Config
from earl.core import env_info_from_gymnasium, env_info_from_gymnax
from earl.environment_loop.gymnasium_loop import GymnasiumLoop
from earl.environment_loop.gymnax_loop import GymnaxLoop

# `ale_py` is not referenced anywhere else in this notebook; passing it to
# `register_envs` keeps the import from being reported as unused.
gymnasium.register_envs(ale_py)  # suppress unused import warning

## init environment

In [13]:
import gymnax.environments.spaces
from gymnax.environments.minatar.asterix import MinAsterix

# Build the gymnax MinAtar Asterix environment and use its default parameters.
env = MinAsterix()
# Environment count; reused below for the agent config and for every GymnaxLoop.
num_envs = 64
env_params = env.default_params
env_info = env_info_from_gymnax(env, env_params, num_envs)
# The network setup needs a discrete action count (`.n`) and an observation
# `.shape`, so fail fast (showing the offending space) if the environment
# exposes a different kind of space.
action_space = env.action_space(env_params)
assert isinstance(action_space, gymnax.environments.spaces.Discrete), action_space
num_actions = int(action_space.n)
observation_space = env.observation_space(env_params)
assert isinstance(observation_space, gymnax.environments.spaces.Box), observation_space


## init networks and agent

In [14]:
# --- hyperparameters ---------------------------------------------------------
hidden_size = 64
num_cycles = 5_000
steps_per_cycle = 80
# Product reused by both schedules below.
total_steps = steps_per_cycle * num_cycles

# --- randomness: separate keys for the networks, the loop and the agent ------
key = jax.random.PRNGKey(0)
networks_key, loop_key, agent_key = jax.random.split(key, 3)

# --- networks ----------------------------------------------------------------
# The MLP input size is the product of all observation dimensions.
flat_observation_size = int(math.prod(observation_space.shape))
networks = r2d2_networks.make_networks_mlp(
    num_actions=num_actions,
    input_size=flat_observation_size,
    dtype=jnp.float32,
    hidden_size=hidden_size,
    key=networks_key,
)

# --- devices -----------------------------------------------------------------
devices = jax.local_devices()
print(f"running on {len(devices)} {devices[0].platform} devices")

# --- agent configuration -----------------------------------------------------
epsilon_greedy_schedule_args = {
    "init_value": 0.1,
    "end_value": 0.01,
    "transition_steps": total_steps,
}
learning_rate_schedule_args = {
    "transition_steps": total_steps // 2,
    # The peak learning rate is scaled by the number of local devices.
    "peak_value": 1e-4 * len(devices),
}
config = R2D2Config(
    epsilon_greedy_schedule_args=epsilon_greedy_schedule_args,
    num_envs_per_learner=num_envs,
    replay_seq_length=steps_per_cycle,
    buffer_capacity=steps_per_cycle * 10,
    burn_in=40,
    learning_rate_schedule_name="cosine_onecycle_schedule",
    learning_rate_schedule_args=learning_rate_schedule_args,
)

# --- agent and initial loop state ---------------------------------------------
agent = R2D2(env_info, config)
loop_state = agent.new_state(networks, agent_key)


running on 2 gpu devices


## init metric writer and training loop

Run the loop in `actor_only` mode and pass `observe_cycle` to render a video of the agent's initial (pre-training) performance.


In [15]:
# Writer for the pre-training run; it gets its own directory under `logs/`.
metric_writer = TensorboardWriter(logdir="logs/MinAsterix/before_training")
# Loop used to record the untrained agent:
#   - actor_only=True: presumably the agent only acts and is not updated —
#     TODO confirm against GymnaxLoop.
#   - observe_cycle: callback given to the loop; here the MinAtar renderer from
#     earl.agents.r2d2.utils.
# NOTE(review): the same `loop_key` is passed to the loops built in later cells
# as well — confirm that sharing the key across loops is intended.
gymnax_loop = GymnaxLoop(
    env,
    env_params,
    agent,
    num_envs,
    loop_key,
    actor_only=True,
    observe_cycle=utils.render_minatar_cycle,
    metric_writer=metric_writer,
    devices=devices,
)

In [16]:
# Show TensorBoard inline, pointed at `logs`, the parent of the per-run
# directories that the metric writers in this notebook write to.
%load_ext tensorboard
%tensorboard --logdir logs


The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 2661613), started 0:07:40 ago. (Use '!kill 2661613' to kill it.)

## run one cycle to gather video

The video will show up in the Images tab of TensorBoard.

In [17]:
# Run a single cycle that is twice as long as a training cycle, so the
# `observe_cycle` callback configured on the loop has one cycle to record.
try:
    loop_state = gymnax_loop.run(loop_state, 1, 2 * steps_per_cycle)
finally:
    # Close the writer even when the run raises, so it is not left open with
    # whatever was already logged still pending.
    metric_writer.close()

                                                        

## train

In [None]:
# Fresh writer for the training run; the previous one was closed after the
# pre-training cycle.
# NOTE(review): metrics logged during training land in the "after_training"
# directory, alongside the post-training cycle recorded in the next cell.
metric_writer = TensorboardWriter(logdir="logs/MinAsterix/after_training")
# Training loop: unlike the evaluation loops, neither `actor_only` nor
# `observe_cycle` is passed, so GymnaxLoop's defaults apply (presumably
# training mode without rendering — TODO confirm).
gymnax_loop = GymnaxLoop(
    env,
    env_params,
    agent,
    num_envs,
    loop_key,
    metric_writer=metric_writer,
    devices=devices,
)
# Continue from the state returned by the pre-training cycle and train for
# `num_cycles` cycles of `steps_per_cycle` steps each. Re-running this cell
# resumes from the current `loop_state` rather than starting over.
loop_state = gymnax_loop.run(loop_state, num_cycles, steps_per_cycle)


cycles:  60%|██████    | 3017/5000 [07:16<04:40,  7.08cycle/s]

## run one cycle to gather video after training

The video will show up in the Images tab of TensorBoard.

In [None]:
# Rebuild the loop with the same settings as the pre-training evaluation loop
# (actor_only=True plus the MinAtar render callback), reusing the writer opened
# for the training run so the output lands in the same log directory.
gymnax_loop = GymnaxLoop(
    env,
    env_params,
    agent,
    num_envs,
    loop_key,
    actor_only=True,
    observe_cycle=utils.render_minatar_cycle,
    metric_writer=metric_writer,
    devices=devices,
)
# Run a single cycle that is twice as long as a training cycle.
try:
    loop_state = gymnax_loop.run(loop_state, 1, 2 * steps_per_cycle)
finally:
    # Close the writer even when the run raises, so it is not left open with
    # whatever was already logged still pending.
    metric_writer.close()