In [22]:
%load_ext autoreload
%autoreload 1
%aimport earl.agents.r2d2.networks
%aimport earl.agents.r2d2.r2d2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
import ale_py
import equinox as eqx
import gymnasium
import jax
import jax.numpy as jnp
import jax.random

import earl.agents.r2d2.networks
import earl.agents.r2d2.r2d2 as r2d2

gymnasium.register_envs(ale_py)  # unnecessary but helpful for IDEs

env = gymnasium.make("BreakoutNoFrameskip-v4")
env = gymnasium.wrappers.AtariPreprocessing(env, noop_max=0)
stack_size = 4
env = gymnasium.wrappers.FrameStackObservation(env, stack_size=stack_size)



key = jax.random.PRNGKey(0)
torso_key, lstm_key, dueling_value_key, dueling_advantage_key = jax.random.split(key, num=4)
hidden_size = 512
network = r2d2.R2D2Network(
  torso=earl.agents.r2d2.networks.DeepAtariTorso(in_channels=stack_size, hidden_sizes=(hidden_size - env.action_space.n - 1,), key=torso_key),
  lstm_cell=eqx.nn.LSTMCell(hidden_size, hidden_size, key=lstm_key),
  dueling_value=eqx.nn.Linear(hidden_size, 1, key=dueling_value_key),
  dueling_advantage=eqx.nn.Linear(hidden_size, 1, key=dueling_advantage_key),
  num_actions=env.action_space.n,
)

obs = env.observation_space.sample()
print(f"obs.shape: {obs.shape}")

action = jax.random.randint(key, (1,), 0, env.action_space.n)
reward = jax.random.uniform(key, ())
hidden = (jnp.zeros((hidden_size,)), jnp.zeros((hidden_size,)))

print(network(obs, action, reward, hidden))


obs.shape: (4, 84, 84)
resnetoutput shape: (3872,)
(Array([0.02557073], dtype=float32), (Array([-1.78308619e-05,  1.17526269e-02, -3.49557549e-02,  9.84092616e-03,
       -2.66669318e-02, -2.60461064e-04, -4.35773656e-03, -8.26306548e-03,
        1.05510810e-02,  3.05276830e-04, -6.31999737e-03, -5.29096229e-03,
       -3.73429023e-02,  4.82893502e-03,  7.03593949e-03, -7.20780529e-03,
        1.92972366e-02, -1.80898644e-02,  9.45764408e-03, -1.43679336e-03,
       -1.00516537e-02,  1.49921072e-03,  9.39293671e-03, -1.35090388e-02,
       -1.75048560e-02,  5.83022251e-04,  2.18355041e-02, -2.42176577e-02,
       -3.43016000e-03,  3.38377710e-03,  2.64539160e-02, -1.00015467e-02,
       -9.39817354e-03, -1.26885492e-02, -8.68064631e-03, -1.58104748e-02,
       -1.15965004e-03, -1.25149367e-02,  9.85334162e-03,  3.99144134e-03,
       -9.19124391e-03,  1.59394685e-02,  1.41830631e-02,  3.25209983e-02,
       -1.83296727e-03, -1.52574750e-02,  1.44205894e-02, -4.38240962e-03,
       -4.3

In [24]:
jnp.squeeze(jax.nn.one_hot(action, num_classes=env.action_space.n))

Array([0., 1., 0., 0.], dtype=float32)