**NOTE:** When running in a public Colab, the next cell installs the `disco_rl` package directly from GitHub, so no manual upload of an archive with the code is needed.

In [None]:
# @title Environment setup


# Install the `disco_rl` package straight from its GitHub repository.
!pip install git+https://github.com/google-deepmind/disco_rl.git

In [None]:
import os

import chex
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import seaborn as sns
import tqdm

# Cache for the update-rule parameters: `None` means "not loaded yet", which
# the loading cell further down checks before reading them from disk.
loaded_update_rule_params = None
# Base PRNG key (seed 1). The training cell later rebinds `rng_key` to a key
# built from seed 0.
rng_key = jax.random.PRNGKey(1)

In [None]:
# Types & utils
from disco_rl import types
from disco_rl import utils

# Environments
from disco_rl.environments import jittable_envs

# Learning
from disco_rl import agent as agent_lib

In [None]:
# Path to the `.npz` archive holding the update-rule parameters. It is left
# empty here and has to be filled in before this cell can succeed.
eta_path = ''
if loaded_update_rule_params is None:
  # Read every array eagerly so nothing depends on the file staying open.
  with open(eta_path, 'rb') as file:
    ur_params_wb = dict(np.load(file))

  # Archive keys come in '<module>/b' and '<module>/w' pairs; strip the
  # two-character suffix and keep each module name once (in archive order).
  module_names = dict.fromkeys(key_wb[:-2] for key_wb in ur_params_wb)

  # Bind the result only after the whole archive was converted. If loading
  # fails, `loaded_update_rule_params` stays `None`, so re-running this cell
  # retries instead of silently keeping an empty or partial dict.
  loaded_update_rule_params = {
      key: {
          'b': ur_params_wb[f'{key}/b'],
          'w': ur_params_wb[f'{key}/w'],
      }
      for key in module_names
  }

In [None]:
# Start from the settings object returned by `get_settings_disco()`; selected
# fields (network, learning rate, weight decay) are overridden further down.
agent_settings = agent_lib.get_settings_disco()


def get_env(batch_size):
  """Creates the jittable Catch environment used throughout this notebook.

  Args:
    batch_size: forwarded as `batch_size` to `CatchJittableEnvironment`.

  Returns:
    A `jittable_envs.CatchJittableEnvironment` configured with the settings
    from `jittable_envs.get_config_catch()`.
  """
  catch_settings = jittable_envs.get_config_catch()
  catch_env = jittable_envs.CatchJittableEnvironment(
      batch_size=batch_size, env_settings=catch_settings
  )
  return catch_env


# This environment instance is only queried for its single-observation and
# single-action specs when the agent is constructed below.
env = get_env(batch_size=1)

# Network configuration.
agent_settings.net_settings.name = 'mlp'
agent_settings.net_settings.net_args = {
    'dense': (512, 512),
    'model_arch_name': 'lstm',
    'head_w_init_std': 1e-2,
    'model_kwargs': {
        'head_mlp_hiddens': (128,),
        'lstm_size': 128,
    },
}

# Optimiser configuration: the end learning rate is set equal to the initial
# one.
agent_settings.weight_decay = 1e-3
agent_settings.learning_rate = 1e-2
agent_settings.end_learning_rate = agent_settings.learning_rate

# 'i' is also the axis name given to every `jax.pmap` call in this notebook.
agent = agent_lib.Agent(
    batch_axis_name='i',
    single_action_spec=env.single_action_spec(),
    single_observation_spec=env.single_observation_spec(),
    agent_settings=agent_settings,
)

In [None]:
# Freshly initialised update-rule parameters, used here as the reference tree
# for the expected shapes and dtypes.
random_update_rule_params, _ = agent.update_rule.init_params(rng_key)

if agent_settings.update_rule_name != 'disco':
  print('Not using a discovered rule.')
else:
  # Fail early if the loaded parameters do not match the reference tree.
  chex.assert_trees_all_equal_shapes_and_dtypes(
      random_update_rule_params, loaded_update_rule_params
  )

In [None]:
def unroll_cpu_actor(
    params,
    actor_state,
    ts,
    env_state,
    rng,
    env,
    rollout_len,
    actor_step_fn,
    devices,
):
  """Collects a rollout with a Python loop around `actor_step_fn`.

  On every step the current timestep is sharded across `devices`, the actor
  step is run with one PRNG key per device, and the resulting actions are
  gathered back before being passed to `env.step`.

  Args:
    params: parameters forwarded to `actor_step_fn`.
    actor_state: actor state threaded through successive actor steps.
    ts: timestep to start acting from.
    env_state: environment state to start from.
    rng: PRNG key; split once per step.
    env: environment exposing `step(env_state, actions)`.
    rollout_len: number of steps to unroll.
    actor_step_fn: callable invoked as `(params, keys, ts, actor_state)` and
      returning `(actor_timestep, actor_state)`.
    devices: devices the timestep is sharded across; its length also sets the
      number of per-step keys.

  Returns:
    A tuple `(actor_rollout, actor_state, ts, env_state)` where
    `actor_rollout` is built from the per-step outputs stacked along axis 1.
  """
  num_devices = len(devices)
  collected_steps = []

  for _ in range(rollout_len):
    rng, key = jax.random.split(rng)
    per_device_keys = jax.random.split(key, num_devices)
    ts = utils.shard_across_devices(ts, devices)

    step_output, actor_state = actor_step_fn(
        params, per_device_keys, ts, actor_state
    )
    collected_steps.append(step_output)

    # The environment is stepped with the actions gathered from all devices.
    env_state, ts = env.step(
        env_state, utils.gather_from_devices(step_output.actions)
    )

  stacked_steps = utils.tree_stack(collected_steps, axis=1)
  rollout = types.ActorRollout.from_timestep(stacked_steps)
  return rollout, actor_state, ts, env_state


def unroll_jittable_actor(
    params,
    actor_state,
    ts,
    env_state,
    rng,
    env,
    rollout_len,
    actor_step_fn,
    devices,
):
  """Collects a rollout with `jax.lax.scan` over actor and environment steps.

  The signature mirrors `unroll_cpu_actor`, but `actor_step_fn` and `devices`
  are discarded: the module-level `agent.actor_step` is always used.

  Args:
    params: parameters forwarded to `agent.actor_step`.
    actor_state: actor state threaded through the scan.
    ts: timestep to start acting from.
    env_state: environment state to start from.
    rng: PRNG key, split into `rollout_len` per-step keys.
    env: environment exposing `step(env_state, actions)`.
    rollout_len: number of scan iterations.
    actor_step_fn: ignored.
    devices: ignored.

  Returns:
    A tuple `(actor_rollout, actor_state, ts, env_state)` where
    `actor_rollout` is built from the per-step outputs stacked by the scan.
  """
  del actor_step_fn, devices
  step_fn = agent.actor_step

  def _scan_body(loop_state, key):
    cur_env_state, cur_ts, cur_actor_state = loop_state
    step_output, next_actor_state = step_fn(
        params, key, cur_ts, cur_actor_state
    )
    next_env_state, next_ts = env.step(cur_env_state, step_output.actions)
    return (next_env_state, next_ts, next_actor_state), step_output

  step_keys = jax.random.split(rng, rollout_len)
  final_state, stacked_steps = jax.lax.scan(
      _scan_body, (env_state, ts, actor_state), step_keys
  )
  env_state, ts, actor_state = final_state

  rollout = types.ActorRollout.from_timestep(stacked_steps)
  return rollout, actor_state, ts, env_state

In [None]:
import collections


def split_tree_on_dim(tree, axis: int):
  """Cuts a PyTree into per-index PyTrees along `axis`.

  Only JAX/NumPy array leaves with more than `axis` dimensions take part. All
  of them must have the same size along `axis`; every other leaf is copied
  into each output tree unchanged.

  Args:
    tree: PyTree to split.
    axis: dimension to split along.

  Returns:
    A list with one tree per index along `axis`. In each tree the
    participating leaves keep their rank, with size 1 along `axis`. The list
    is empty when no leaf participates.

  Raises:
    ValueError: if participating leaves disagree on their size along `axis`.
  """

  def _is_splittable(leaf):
    return isinstance(leaf, (jnp.ndarray, np.ndarray)) and leaf.ndim > axis

  # First pass: find the common size along `axis` (-1 means "none seen yet").
  split_size = -1
  for leaf in jax.tree_util.tree_leaves(tree):
    if not _is_splittable(leaf):
      continue
    leaf_size = leaf.shape[axis]
    if split_size == -1:
      split_size = leaf_size
    elif leaf_size != split_size:
      raise ValueError(
          f"Inconsistent dimension sizes found for axis {axis}. "
          f"Expected {split_size}, but found leaf with shape {leaf.shape} "
          f"(size {leaf_size})."
      )

  def _take_singleton(index, leaf):
    if _is_splittable(leaf) and leaf.shape[axis] == split_size:
      # Select one index, then restore the removed dimension with size 1.
      return jnp.expand_dims(jnp.take(leaf, index, axis=axis), axis)
    return leaf

  # Second pass: one sliced copy of the tree per index.
  return [
      jax.tree.map(lambda leaf, index=index: _take_singleton(index, leaf), tree)
      for index in range(split_size)
  ]


class SimpleReplayBuffer:
  """Bounded FIFO store of rollout slices, sampled uniformly with replacement.

  Each added rollout is cut into size-1 slices along axis 2 and every slice is
  stored as one entry. Once `capacity` entries are held, adding more evicts
  the oldest ones.
  """

  def __init__(self, capacity: int, seed: int):
    """Creates an empty buffer.

    Args:
      capacity: maximum number of entries kept.
      seed: seed for the NumPy generator used when sampling.
    """
    self.capacity = capacity
    self.np_rng = np.random.default_rng(seed)
    self.buffer = collections.deque(maxlen=capacity)

  def add(self, rollout: types.ActorRollout):
    """Fetches `rollout` to the host and appends its slices along axis 2."""
    host_rollout = jax.device_get(rollout)
    self.buffer.extend(split_tree_on_dim(host_rollout, 2))

  def sample(self, batch_size: int) -> types.ActorRollout | None:
    """Draws `batch_size` entries and concatenates them along axis 2.

    Args:
      batch_size: number of entries to draw (with replacement).

    Returns:
      The concatenated batch, or `None` (after printing a warning) when the
      buffer is empty.
    """
    if not self.buffer:
      print("Warning: Trying to sample from an empty buffer.")
      return None

    picks = self.np_rng.integers(len(self.buffer), size=batch_size)
    chosen = [self.buffer[idx] for idx in picks]

    def _concat(*leaves):
      return np.concatenate(leaves, axis=2)

    return jax.tree.map(_concat, *chosen)

  def __len__(self) -> int:
    """Returns the number of entries currently stored."""
    return len(self.buffer)


def accumulate_rewards_scan_fn(carry, x):
  """Single scan step that accumulates rewards, gated by discounts.

  Args:
    carry: reward sum accumulated so far.
    x: pair `(rewards, discounts)` for the current step.

  Returns:
    A pair `(new_carry, acc_rewards)`. `acc_rewards` is `carry + rewards`, and
    `new_carry` is `acc_rewards * discounts`, so a zero discount resets the
    sum carried into the next step.
  """
  rewards, discounts = x
  # Build a new value rather than using `+=`, which would modify a NumPy
  # `carry` in place and so change the caller's array.
  acc_rewards = carry + rewards
  return acc_rewards * discounts, acc_rewards


def accumulate_rewards(acc_rewards, x):
  """Scans `accumulate_rewards_scan_fn` over a sequence of rewards/discounts.

  Args:
    acc_rewards: initial carry for the scan.
    x: pair `(rewards, discounts)`, scanned along the leading axis.

  Returns:
    The `(final_carry, stacked_outputs)` pair produced by `jax.lax.scan`.
  """
  rewards, discounts = x
  final_carry, stacked_outputs = jax.lax.scan(
      accumulate_rewards_scan_fn, acc_rewards, xs=(rewards, discounts)
  )
  return final_carry, stacked_outputs

In [None]:
# Used below only to derive the number of environments:
# num_envs = batch_size // replay_ratio.
replay_ratio = 32

# Number of steps per actor unroll.
rollout_len = 29
# Number of iterations of the act/learn loop below.
num_steps = 1000
# Number of buffer entries sampled for each learner step.
batch_size = 64
# Learner steps start once the buffer holds at least this many entries.
min_buffer_size = batch_size

rng_key = jax.random.PRNGKey(0)

num_envs = batch_size // replay_ratio
# Take up to `num_envs` devices (fewer if fewer are available).
devices = tuple(jax.devices()[:num_envs])
env = get_env(num_envs)

# Device-parallel versions of the agent steps; the axis name 'i' matches the
# `batch_axis_name` the agent was constructed with.
actor_step_fn = jax.pmap(agent.actor_step, 'i', devices=devices)
learner_step_fn = jax.pmap(agent.learner_step, 'i', devices=devices)
# Positional arguments 5-8 (env, rollout_len, actor_step_fn, devices) are
# passed as static broadcasted arguments.
# NOTE(review): this rebinds `unroll_jittable_actor` to its pmapped version,
# so re-running this cell without re-running the defining cell wraps the
# already-pmapped function again — consider binding to a separate name.
unroll_jittable_actor = jax.pmap(
    unroll_jittable_actor,
    axis_name='i',
    devices=devices,
    static_broadcasted_argnums=(5, 6, 7, 8),
)
jittable_accumulate_rewards = jax.pmap(
    accumulate_rewards,
    axis_name='i',
    devices=devices,
)

# Selects between the scanned, pmapped unroll and the Python-loop unroll.
is_jittable_actor = isinstance(
    env, jittable_envs.batched_jittable_env.BatchedJittableEnvironment
)

# Initial learner/actor state and the update-rule parameters, replicated on
# every device. The same `rng_key` is used for both initialisations and for
# the environment reset below.
learner_state = jax.device_put_replicated(
    agent.initial_learner_state(rng_key), devices
)
actor_state = jax.device_put_replicated(
    agent.initial_actor_state(rng_key), devices
)
update_rule_params = jax.device_put_replicated(
    loaded_update_rule_params, devices
)

env_state, ts = env.reset(rng_key)
# Running reward sum, one entry per environment.
acc_rewards = jnp.zeros((num_envs,))

if is_jittable_actor:
  unroll_actor, acc_rewards_fn = (
      unroll_jittable_actor,
      jittable_accumulate_rewards,
  )
  # The pmapped functions take their inputs sharded across the devices.
  env_state, ts, acc_rewards = utils.shard_across_devices(
      (env_state, ts, acc_rewards), devices
  )
else:
  unroll_actor, acc_rewards_fn = unroll_cpu_actor, accumulate_rewards

buffer = SimpleReplayBuffer(capacity=1024, seed=17)

# Per-iteration logs filled in by the training loop.
all_metrics, all_rewards, all_discounts = [], [], []
all_steps, all_returns = [], []
total_steps = 0

# Main loop: each iteration unrolls the actor, stores the rollout in the
# replay buffer, logs rewards/returns and, once the buffer is large enough,
# performs one learner step on a sampled batch.
for step in tqdm.tqdm(range(num_steps)):
  # Fresh keys for this iteration's acting and learning phases.
  rng_key, rng_actor, rng_learner = jax.random.split(rng_key, 3)

  if is_jittable_actor:
    # The pmapped unroll takes one key per device.
    rng_actor = jax.random.split(rng_actor, len(devices))

  # Act for `rollout_len` steps using the current learner parameters.
  actor_rollout, actor_state, ts, env_state = unroll_actor(
      learner_state.params,
      actor_state,
      ts,
      env_state,
      rng_actor,
      env,
      rollout_len,
      actor_step_fn,
      devices,
  )
  buffer.add(actor_rollout)

  # Every entry of the rewards array counts as one environment step.
  total_steps += np.prod(actor_rollout.rewards.shape)
  # Update the running reward sums; `returns` holds the accumulated value
  # at every step of this rollout.
  acc_rewards, returns = acc_rewards_fn(
      acc_rewards,
      (actor_rollout.rewards, actor_rollout.discounts),
  )

  # Keep host-side copies for the analysis cells below.
  all_steps.append(total_steps)
  all_rewards.append(jax.device_get(actor_rollout.rewards))
  all_discounts.append(jax.device_get(actor_rollout.discounts))
  all_returns.append(jax.device_get(returns))

  # One learner key per device (split even when no learner step follows).
  rng_learner = jax.random.split(rng_learner, len(devices))

  # Learn only once the buffer holds at least `min_buffer_size` entries.
  if len(buffer) >= min_buffer_size:
    learner_batch = buffer.sample(batch_size)
    # The second output of the learner step is discarded here.
    learner_state, _, metrics = learner_step_fn(
        rng=rng_learner,
        rollout=learner_batch,
        learner_state=learner_state,
        agent_net_state=actor_state,
        update_rule_params=update_rule_params,
    )
    all_metrics.append(jax.device_get(metrics))


# Pass all logged trees through `utils.gather_from_devices` in a single call.
gathered_logs = utils.gather_from_devices(
    (all_metrics, all_rewards, all_discounts, all_returns)
)
all_metrics, all_rewards, all_discounts, all_returns = gathered_logs
# Average every metric leaf over its leading axis.
all_metrics = jax.tree.map(lambda leaf: leaf.mean(0), all_metrics)

In [None]:
all_steps = np.array(all_steps)
all_discounts = np.array(all_discounts)
all_returns = np.array(all_returns)

# `1 - discount` is non-zero exactly where the discount is below one; those
# entries are the ones counted as episodes and whose returns are summed.
episode_end_mask = 1 - all_discounts
total_episodes = episode_end_mask.sum(axis=(1, 2))
total_returns = (all_returns * episode_end_mask).sum(axis=(1, 2))
avg_returns = total_returns / total_episodes

# Metrics are only logged once learning has started, so left-pad every metric
# series with NaN up to the number of logged iterations.
pad_width = len(all_steps) - len(all_metrics)
padded_metrics = {
    key: np.pad(
        np.array([m[key] for m in all_metrics]),
        (pad_width, 0),
        constant_values=np.nan,
    )
    for key in all_metrics[0].keys()
}

columns = dict(
    steps=all_steps,
    avg_returns=avg_returns,
    **padded_metrics,
)
df = pd.DataFrame(columns)

# Label every row with the name of the update rule in use.
df['name'] = agent_settings.update_rule_name

In [None]:
# Plot the average return against the cumulative step count.
sns.lineplot(data=df, x='steps', y='avg_returns')