**NOTE:** When running in public Colab, this cell will prompt you to upload the archive with the code.

In [None]:
# @title Environment setup


!pip install git+https://github.com/google-deepmind/disco_rl.git

In [None]:
from typing import Any

import chex
import distrax
import haiku as hk
import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt
from ml_collections import config_dict
import numpy as np
import optax
import rlax
import tqdm

# Cache for the update-rule parameters read from disk. `None` means nothing has
# been loaded yet; the loading cell only reads the file while this is `None`.
loaded_update_rule_params = None
# Initial PRNG key; it is replaced by a new seed in the configuration cell.
rng_key = jax.random.PRNGKey(1)

In [None]:
# Types & utils
from disco_rl import types
from disco_rl import utils

# Environments
from disco_rl.environments import base as base_env
from disco_rl.environments import jittable_envs

# Learning
from disco_rl import agent as agent_lib
from disco_rl.value_fns import value_fn

In [None]:
# Path to the archive with the update rule parameters (readable by `np.load`).
eta_path = ''
if loaded_update_rule_params is None:
  if not eta_path:
    raise FileNotFoundError(
        'eta_path is empty; set it to the file with the update rule'
        ' parameters before running this cell.'
    )
  with open(eta_path, 'rb') as file:
    # Materialise every array while the file is still open.
    ur_params_wb = dict(np.load(file))

  # Assemble the tree in a local variable and publish it only once it is
  # complete. If anything above or below raises, the cache stays `None` and
  # re-running the cell retries the load instead of keeping a partial dict.
  _loaded_params = {}
  for key_wb in ur_params_wb:
    # Archive keys end in '/b' or '/w'; strip the two-character suffix to get
    # the name of the layer the array belongs to.
    key = key_wb[:-2]
    if key in _loaded_params:
      continue  # Already built from the sibling '/b' or '/w' entry.
    _loaded_params[key] = {
        'b': ur_params_wb[f'{key}/b'],
        'w': ur_params_wb[f'{key}/w'],
    }
  loaded_update_rule_params = _loaded_params

In [None]:
# Start from the settings returned by `get_settings_disco`; some fields are
# overridden further down in this cell.
agent_settings = agent_lib.get_settings_disco()
# Name of the mapped axis. It is handed to the agent as `batch_axis_name` and
# later to `jax.pmap` / `jax.lax.pmean`.
axis_name = 'i'


def get_env(batch_size):
  """Builds a jittable Catch environment configured with 5 rows and 5 columns.

  Args:
    batch_size: forwarded unchanged to the environment constructor.

  Returns:
    A new `jittable_envs.CatchJittableEnvironment`.
  """
  catch_settings = config_dict.ConfigDict({'rows': 5, 'columns': 5})
  return jittable_envs.CatchJittableEnvironment(
      batch_size=batch_size, env_settings=catch_settings
  )


# Single-environment instance; its observation and action specs are used to
# build the agent below.
env = get_env(batch_size=1)

# Override the network configuration of the agent.
agent_settings.net_settings.name = 'mlp'
agent_settings.net_settings.net_args = dict(
    dense=(512, 512),
    model_arch_name='lstm',
    head_w_init_std=1e-2,
    model_kwargs=dict(
        head_mlp_hiddens=(256,),
        lstm_size=256,
    ),
)
# The end value is set equal to the start value.
# NOTE(review): presumably this yields a constant learning rate -- confirm
# against how `agent_lib` consumes `end_learning_rate`.
agent_settings.learning_rate = 5e-4
agent_settings.end_learning_rate = agent_settings.learning_rate

# This agent instance is used below to initialise the update rule parameters.
agent = agent_lib.Agent(
    agent_settings=agent_settings,
    single_observation_spec=env.single_observation_spec(),
    single_action_spec=env.single_action_spec(),
    batch_axis_name=axis_name,
)

In [None]:
# Freshly initialised update rule parameters; the second element returned by
# `init_params` is not needed here.
random_update_rule_params, _ = agent.update_rule.init_params(rng_key)

# When the discovered rule is in use, the parameters loaded from disk must have
# the same tree structure, shapes and dtypes as freshly initialised ones.
uses_disco_rule = agent_settings.update_rule_name == 'disco'
if not uses_disco_rule:
  print('Not using a discovered rule.')
else:
  chex.assert_trees_all_equal_shapes_and_dtypes(
      random_update_rule_params, loaded_update_rule_params
  )

In [None]:
# Configs
# Number of `MetaTrainAgent` instances that are created and meta-trained on.
num_agents = 1
# Length of one training rollout; validation rollouts are twice as long.
rollout_len = 16
# Number of training rollouts generated per agent for each outer update.
num_inner_steps = 1
# Batch size handed to each `MetaTrainAgent` (and thus to its environment).
batch_size_per_device = 128
# Replaces the key created in the imports cell.
rng_key = jax.random.PRNGKey(12)

# Meta-parameters that meta-training starts from. The trailing comment names
# the alternative starting point (the parameters loaded from disk).
update_rule_params = random_update_rule_params  # loaded_update_rule_params

# Configuration of the value function built inside each `MetaTrainAgent`.
# NOTE(review): the meaning of the individual fields is defined by
# `types.ValueFnConfig`, which is not visible here.
value_fn_config = types.ValueFnConfig(
    net='mlp',
    net_args=dict(
        dense=(256, 256),
        head_w_init_std=1e-2,
        action_spec=(),
    ),
    learning_rate=1e-3,
    max_abs_update=1.0,
    discount_factor=0.99,
    td_lambda=0.96,
    outer_value_cost=1.0,
)

In [None]:
def unroll_jittable_actor(
    params,
    actor_state,
    ts,
    env_state,
    rng,
    env,
    rollout_len,
    actor_step_fn,
):
  """Alternates actor and environment steps for `rollout_len` iterations.

  Args:
    params: parameters passed to `actor_step_fn` on every step.
    actor_state: actor state at the start of the rollout.
    ts: environment timestep the actor observes first.
    env_state: environment state at the start of the rollout.
    rng: PRNG key; split into one key per step.
    env: environment providing `step(env_state, actions)`.
    rollout_len: number of steps to unroll.
    actor_step_fn: callable `(params, rng, timestep, actor_state)` returning
      `(actor_timestep, new_actor_state)`.

  Returns:
    A tuple `(actor_rollout, actor_state, ts, env_state)`: the rollout built
    from the stacked per-step actor outputs, followed by the actor state,
    timestep and environment state reached after the last step.
  """
  per_step_rngs = jax.random.split(rng, rollout_len)
  initial_carry = (env_state, ts, actor_state)

  def _advance(carry, step_rng):
    cur_env_state, cur_ts, cur_actor_state = carry
    # The actor acts on the current timestep, then the environment consumes
    # the chosen actions.
    actor_timestep, next_actor_state = actor_step_fn(
        params, step_rng, cur_ts, cur_actor_state
    )
    next_env_state, next_ts = env.step(cur_env_state, actor_timestep.actions)
    return (next_env_state, next_ts, next_actor_state), actor_timestep

  final_carry, stacked_timesteps = jax.lax.scan(
      _advance, initial_carry, per_step_rngs
  )
  final_env_state, final_ts, final_actor_state = final_carry

  return (
      types.ActorRollout.from_timestep(stacked_timesteps),
      final_actor_state,
      final_ts,
      final_env_state,
  )

In [None]:
@chex.dataclass
class MetaTrainState:
  """Everything one `MetaTrainAgent` carries from one outer update to the next."""

  # Learner state of the agent; its `params` are read when unrolling the actor.
  learner_state: agent_lib.LearnerState
  # State passed to and returned by the agent's actor step.
  actor_state: types.HaikuState
  # State of the value function that is trained alongside the agent.
  value_state: types.ValueState
  # State of the environment; the concrete type depends on the environment.
  env_state: Any
  # Most recent environment timestep, i.e. the actor's next input.
  env_timestep: types.EnvironmentTimestep


class MetaTrainAgent:
  """Bundles an agent, a value function and an environment for meta-training.

  Each instance owns its own environment (created with `get_env`), an
  `agent_lib.Agent` built from that environment's specs, and a
  `value_fn.ValueFunction`.
  """

  agent: agent_lib.Agent
  value_fn: value_fn.ValueFunction
  env: base_env.Environment

  def __init__(
      self,
      batch_size_per_device: int,
      agent_settings: config_dict.ConfigDict,
      value_fn_config: types.ValueFnConfig,
      axis_name: str | None = None,
  ):
    """Creates the environment, the agent and the value function.

    Args:
      batch_size_per_device: batch size of the environment; also used as the
        leading dimension of the dummy observation in `init_state`.
      agent_settings: settings passed to `agent_lib.Agent`.
      value_fn_config: configuration passed to `value_fn.ValueFunction`.
      axis_name: passed to the agent as `batch_axis_name` and to the value
        function as `axis_name`.
    """
    # Remember the batch size so that `init_state` does not depend on a
    # module-level variable that merely happens to hold the same value.
    self._batch_size_per_device = batch_size_per_device
    self.env = get_env(batch_size_per_device)
    self.agent = agent_lib.Agent(
        agent_settings=agent_settings,
        single_observation_spec=self.env.single_observation_spec(),
        single_action_spec=self.env.single_action_spec(),
        batch_axis_name=axis_name,
    )
    self.value_fn = value_fn.ValueFunction(value_fn_config, axis_name=axis_name)
    # Private alias of this instance's own environment (rather than of an
    # unrelated module-level environment with a different batch size).
    self._env = self.env

    # `env`, `rollout_len` and `actor_step_fn` are not arrays, so they are
    # marked static for jit.
    self._unroll_jittable_actor = jax.jit(
        unroll_jittable_actor,
        static_argnames=('env', 'rollout_len', 'actor_step_fn'),
    )

  def init_state(self, rng_key: chex.PRNGKey) -> MetaTrainState:
    """Returns the initial state of the learner, actor, value fn and env.

    Args:
      rng_key: PRNG key; it is split into independent keys for the learner,
        the actor, the value function and the environment reset.
    """
    dummy_obs = utils.zeros_like_spec(
        self.env.single_observation_spec(),
        prepend_shape=(self._batch_size_per_device,),
    )
    # Four independent keys: the parent key itself is never consumed again
    # after being split.
    learner_rng, actor_rng, value_rng, env_rng = jax.random.split(rng_key, 4)
    env_state, env_timestep = self.env.reset(env_rng)
    return MetaTrainState(
        learner_state=self.agent.initial_learner_state(learner_rng),
        actor_state=self.agent.initial_actor_state(actor_rng),
        value_state=self.value_fn.initial_state(value_rng, dummy_obs),
        env_state=env_state,
        env_timestep=env_timestep,
    )

  @property
  def learner_step(self):
    """The wrapped agent's `learner_step`."""
    return self.agent.learner_step

  @property
  def actor_step(self):
    """The wrapped agent's `actor_step`."""
    return self.agent.actor_step

  @property
  def unroll_net(self):
    """The wrapped agent's `unroll_net`."""
    return self.agent.unroll_net

  def unroll_actor(
      self, state: MetaTrainState, rng: chex.PRNGKey, rollout_len: int
  ) -> tuple[MetaTrainState, types.ActorRollout]:
    """Generates a rollout of `rollout_len` steps starting from `state`.

    Args:
      state: current state; its learner parameters, actor state, environment
        state and last timestep are the starting point of the rollout.
      rng: PRNG key for the rollout.
      rollout_len: number of environment steps to take.

    Returns:
      A tuple `(new_state, rollout)`. In `new_state` the actor state, the
      environment state and the environment timestep are advanced; the learner
      and value states are carried over unchanged.
    """
    rollout, actor_state, env_timestep, env_state = self._unroll_jittable_actor(
        state.learner_state.params,
        state.actor_state,
        state.env_timestep,
        state.env_state,
        rng,
        self.env,
        rollout_len,
        self.agent.actor_step,
    )
    new_state = MetaTrainState(
        learner_state=state.learner_state,
        actor_state=actor_state,
        value_state=state.value_state,
        env_state=env_state,
        env_timestep=env_timestep,
    )
    return new_state, rollout


# Create multiple agents.
agents = []
agents_states = []
rng, rng_key = jax.random.split(rng_key)
# A separate loop variable keeps `rng` from being overwritten by the keys of
# the individual agents.
for agent_rng in jax.random.split(rng, num_agents):
  # Forward the notebook-wide axis name so that the agent and its value
  # function are built with the same axis that `outer_grad` reduces over.
  meta_train_agent = MetaTrainAgent(
      batch_size_per_device=batch_size_per_device,
      agent_settings=agent_settings,
      value_fn_config=value_fn_config,
      axis_name=axis_name,
  )
  agents.append(meta_train_agent)
  agents_states.append(meta_train_agent.init_state(agent_rng))

In [None]:
def outer_grad(
    update_rule_params: types.MetaParams,
    agent_state: MetaTrainState,
    train_rollouts: types.ActorRollout,
    valid_rollout: types.ActorRollout,
    rng: chex.PRNGKey,
    agent: MetaTrainAgent,
    axis_name: str | None = axis_name,
):
  """Computes the meta-gradient for one agent.

  The agent is updated once per training rollout using `update_rule_params`,
  the updated agent is then evaluated on `valid_rollout`, and the resulting
  loss is differentiated with respect to `update_rule_params`.

  Args:
    update_rule_params: meta-parameters the gradient is taken with respect to.
    agent_state: state of the agent before the inner updates.
    train_rollouts: training rollouts stacked along the leading axis, one per
      inner update.
    valid_rollout: rollout on which the updated agent is evaluated.
    rng: PRNG key; split into keys for the inner updates and the validation
      pass.
    agent: provides `learner_step`, `unroll_net`, `actor_step` and `value_fn`.
    axis_name: if not None, gradients, training metrics and the meta log are
      averaged over this axis with `jax.lax.pmean`.

  Returns:
    A tuple `(meta_grads, (new_agent_state, train_metrics, meta_log))`.
  """

  # Number of inner updates: the leading axis of the stacked training rollouts.
  unroll_len = train_rollouts.rewards.shape[0]

  def _inner_step(carry, inputs):
    # One inner update: advances the learner and the value function using a
    # single training rollout. The meta-parameters are threaded through the
    # carry unchanged.
    update_rule_params, learner_state, actor_state, value_state = carry
    actor_rollout, learner_rng = inputs

    # Update learner.
    new_learner_state, new_actor_state, metrics = agent.learner_step(
        rng=learner_rng,
        rollout=actor_rollout,
        learner_state=learner_state,
        agent_net_state=actor_state,
        update_rule_params=update_rule_params,
    )

    # Update value function.
    # The logits come from the parameters and actor state from *before* the
    # learner update above.
    agent_out, _ = agent.unroll_net(
        learner_state.params, actor_state, actor_rollout
    )
    new_value_state, _, _ = agent.value_fn.update(
        value_state, actor_rollout, agent_out['logits']
    )

    return (
        update_rule_params,
        new_learner_state,
        new_actor_state,
        new_value_state,
    ), metrics

  def _outer_loss(
      update_rule_params: types.MetaParams,
      agent_state: MetaTrainState,
      train_rollouts: types.ActorRollout,
      valid_rollout: types.ActorRollout,
  ):
    """Calculates loss for the update rule."""
    train_rng, valid_rng = jax.random.split(rng, 2)

    # Perform inner steps (i.e. updates).
    learner_rngs = jax.random.split(train_rng, unroll_len)
    (_, new_learner_state, new_actor_state, new_value_state), train_metrics = (
        jax.lax.scan(
            _inner_step,
            (
                update_rule_params,
                agent_state.learner_state,
                agent_state.actor_state,
                agent_state.value_state,
            ),
            (train_rollouts, learner_rngs),
        )
    )
    # 'meta_out' is taken out of the metrics; it feeds the train regularisers
    # below and is therefore not part of the returned `train_metrics`.
    train_meta_out = train_metrics.pop('meta_out')

    # Run inference on the validation rollout.
    # The same `valid_rng` and the rollout's first actor state are used for
    # every timestep the step function is applied to.
    agent_rollout_on_valid, _ = hk.BatchApply(
        lambda ts: agent.actor_step(
            actor_params=new_learner_state.params,
            rng=valid_rng,
            timestep=ts,
            actor_state=valid_rollout.first_state(time_axis=0),
        )
    )(valid_rollout.to_env_timestep())

    # Calculate value_fn on the validation rollout.
    value_out, _, _, _ = agent.value_fn.get_value_outs(
        new_value_state, valid_rollout, agent_rollout_on_valid['logits']
    )

    # The last timestep is dropped from the actions and the logits.
    actions_on_valid = valid_rollout.actions[:-1]
    logits_on_valid = agent_rollout_on_valid['logits'][:-1]
    # No gradient flows through the advantages.
    adv_t = jax.lax.stop_gradient(value_out.normalized_adv)
    pg_loss_per_step = utils.differentiable_policy_gradient_loss(
        logits_on_valid, actions_on_valid, adv_t=adv_t, backprop=False
    )
    # Negative entropy: minimising this term increases the policy entropy.
    entropy_loss_per_step = -distrax.Softmax(logits_on_valid).entropy()

    # Compute policy gradient loss.
    chex.assert_rank((pg_loss_per_step, entropy_loss_per_step), 2)  # [T, B]
    rl_loss = (pg_loss_per_step + 1e-2 * entropy_loss_per_step).mean()

    # Meta regularizers.
    reg_loss = 0

    # Validation regularisers.
    # Negative entropies of the softmax over the 'y' output and over the 'z'
    # output looked up at the taken actions.
    agent_out_on_valid = agent_rollout_on_valid.agent_outs
    z_a = utils.batch_lookup(agent_out_on_valid['z'][:-1], actions_on_valid)
    y_entropy_loss = -jnp.mean(
        distrax.Softmax(agent_out_on_valid['y']).entropy()
    )
    z_entropy_loss = -jnp.mean(distrax.Softmax(z_a).entropy())
    reg_loss += 1e-3 * (y_entropy_loss + z_entropy_loss)

    # Train regularisers.
    # Squared mean (over axes 1-3) of each of the three 'meta_out' entries,
    # averaged over the remaining axis.
    dp, dy, dz = train_meta_out['pi'], train_meta_out['y'], train_meta_out['z']
    chex.assert_equal_shape_prefix([dp, dy, dz], 3)  # [N, T, B, ...]
    reg_loss += 1e-3 * jnp.mean(jnp.square(jnp.mean(dy, axis=(1, 2, 3))))
    reg_loss += 1e-3 * jnp.mean(jnp.square(jnp.mean(dz, axis=(1, 2, 3))))
    reg_loss += 1e-3 * jnp.mean(jnp.square(jnp.mean(dp, axis=(1, 2, 3))))
    # KL term between the (gradient-stopped) target logits and `dp`; the last
    # timestep of the target logits is dropped so that the shapes match.
    logits = train_meta_out['target_out']['logits'][:, :-1]
    chex.assert_equal_shape([logits, dp])  # [N, T, B, A]
    target_kl_loss = rlax.categorical_kl_divergence(
        jax.lax.stop_gradient(logits), dp
    )
    reg_loss += 1e-2 * jnp.mean(target_kl_loss)

    # Meta loss.
    meta_loss = rl_loss.mean() + reg_loss

    # Scalars reported for logging.
    meta_log = dict(
        adv=value_out.adv.mean(),
        normalized_adv=value_out.normalized_adv.mean(),
        entropy=distrax.Softmax(logits_on_valid).entropy().mean(),
        value=value_out.value.mean(),
        val_importance_weight=jnp.mean(jnp.minimum(value_out.rho, 1.0)),
        meta_loss=meta_loss,
        rl_loss=rl_loss,
        reg_loss=reg_loss,
    )
    # The environment state and timestep are carried over unchanged; only the
    # learner, actor and value states are advanced by the inner updates.
    new_agent_state = MetaTrainState(
        learner_state=new_learner_state,
        actor_state=new_actor_state,
        value_state=new_value_state,
        env_state=agent_state.env_state,
        env_timestep=agent_state.env_timestep,
    )

    return meta_loss, (new_agent_state, train_metrics, meta_log)

  # Differentiate with respect to the first argument (the meta-parameters);
  # everything else returned by `_outer_loss` is auxiliary output.
  meta_grads, outputs = jax.grad(_outer_loss, has_aux=True)(
      update_rule_params, agent_state, train_rollouts, valid_rollout
  )
  new_agent_state, train_metrics, meta_log = outputs
  if axis_name is not None:
    # Average over the mapped axis; the agent state is not averaged.
    (meta_grads, train_metrics, meta_log) = jax.lax.pmean(
        (meta_grads, train_metrics, meta_log), axis_name
    )

  return meta_grads, (new_agent_state, train_metrics, meta_log)

In [None]:
def outer_update(
    update_rule_params: types.MetaParams,
    meta_opt_state: optax.OptState,
    agents_states: list[MetaTrainState],
    rng: chex.PRNGKey,
    axis_name: str | None = axis_name,
):
  """Performs one meta-update of the update rule parameters.

  For every agent this generates `num_inner_steps` training rollouts and one
  validation rollout of twice the length, computes the meta-gradient with
  `outer_grad`, averages the gradients over agents and applies them with the
  module-level optimiser `meta_opt`.

  Args:
    update_rule_params: current meta-parameters.
    meta_opt_state: state of the meta-optimiser.
    agents_states: one state per agent in the module-level `agents` list.
    rng: PRNG key for acting and updating.
    axis_name: forwarded to `outer_grad`.

  Returns:
    A tuple `(update_rule_params, meta_opt_state, agents_states, metrics,
    meta_log)` with the updated meta-parameters, optimiser state and agent
    states, followed by the training metrics and the meta log.
  """
  # Work on a shallow copy so that the caller's list is not modified in place.
  agents_states = list(agents_states)

  # Generate inputs.
  train_rollouts = [None] * num_agents
  valid_rollouts = [None] * num_agents
  rng_act, rng_upd = jax.random.split(rng)
  rngs_per_agent_act = jax.random.split(rng_act, num_agents)
  for agent_i in range(num_agents):
    a, state = agents[agent_i], agents_states[agent_i]
    rollouts = [None] * num_inner_steps
    # One key per training rollout plus a dedicated key for the validation
    # rollout, so that no key is consumed after it has been split.
    rngs_per_rollout = jax.random.split(
        rngs_per_agent_act[agent_i], num_inner_steps + 1
    )
    for step_i in range(num_inner_steps):
      state, rollouts[step_i] = a.unroll_actor(
          state, rngs_per_rollout[step_i], rollout_len
      )
    train_rollouts[agent_i] = utils.tree_stack(rollouts)
    agents_states[agent_i], valid_rollouts[agent_i] = a.unroll_actor(
        state, rngs_per_rollout[num_inner_steps], rollout_len * 2
    )

  # Calculate meta gradients.
  meta_grads = [None] * num_agents
  rngs_per_agent_upd = jax.random.split(rng_upd, num_agents)
  # NOTE: `metrics` and `meta_log` are overwritten on every iteration, so the
  # values of the last agent are the ones that get returned.
  metrics, meta_log = None, None
  for agent_i in range(num_agents):
    meta_grads[agent_i], (agents_states[agent_i], metrics, meta_log) = (
        outer_grad(
            update_rule_params=update_rule_params,
            agent_state=agents_states[agent_i],
            train_rollouts=train_rollouts[agent_i],
            valid_rollout=valid_rollouts[agent_i],
            rng=rngs_per_agent_upd[agent_i],
            agent=agents[agent_i],
            axis_name=axis_name,
        )
    )

  # Log the mean reward and the counts of positive / negative rewards.
  rewards = [None] * num_agents
  pos_rewards = [None] * num_agents
  neg_rewards = [None] * num_agents

  for agent_i in range(num_agents):
    assert train_rollouts[agent_i] is not None
    r = train_rollouts[agent_i].rewards
    rewards[agent_i] = r.mean()

    pos_rewards[agent_i] = (r > 0).sum()
    neg_rewards[agent_i] = (r < 0).sum()

  # Pass through meta optimizer.
  # Gradients are stacked along a new leading axis and averaged over agents.
  meta_gradient = jax.tree.map(
      lambda x: x.mean(axis=0), utils.tree_stack(meta_grads)
  )
  meta_update, meta_opt_state = meta_opt.update(meta_gradient, meta_opt_state)
  update_rule_params = optax.apply_updates(update_rule_params, meta_update)

  meta_log['meta_grad_norm'] = optax.global_norm(meta_gradient)
  meta_log['meta_up_norm'] = optax.global_norm(meta_update)
  meta_log['rewards'] = utils.tree_stack(rewards).mean()
  meta_log['pos_rewards'] = utils.tree_stack(pos_rewards).mean()
  meta_log['neg_rewards'] = utils.tree_stack(neg_rewards).mean()

  return update_rule_params, meta_opt_state, agents_states, metrics, meta_log

In [None]:
devices = jax.devices()
# With a named axis the update is mapped over all devices with pmap; otherwise
# it is simply jitted.
jit_outer_update = (
    jax.jit(outer_update)
    if axis_name is None
    else jax.pmap(outer_update, axis_name=axis_name, devices=devices)
)

# Number of meta-training steps.
num_steps = 1000
# Key used by the training loop.
rng = jax.random.split(rng_key, 1)[0]

# Meta-optimiser and its state for the update rule parameters.
meta_opt = optax.adam(5e-4)
meta_opt_state = meta_opt.init(update_rule_params)

# Per-step logs, keyed by the meta-training step.
meta_log = dict()
metrics = dict()

In [None]:
# Run meta-training (note that compilation can take time!).

if axis_name is not None:
  # Every device gets its own copy of the parameters, the optimiser state and
  # the agent states.
  step_update_rule_params = jax.device_put_replicated(
      update_rule_params, devices
  )
  step_meta_opt_state = jax.device_put_replicated(meta_opt_state, devices)
  step_agents_states = jax.device_put_replicated(agents_states, devices)
else:
  step_update_rule_params = update_rule_params
  step_meta_opt_state = meta_opt_state
  step_agents_states = agents_states

for meta_step in tqdm.tqdm(range(num_steps)):
  # Steps that were already logged are skipped, which makes it possible to
  # interrupt the cell and continue afterwards.
  if meta_step in metrics:
    continue

  rng, step_rngs = jax.random.split(rng)
  if axis_name is not None:
    # One key per device.
    step_rngs = jax.random.split(step_rngs, len(devices))

  (
      step_update_rule_params,
      step_meta_opt_state,
      step_agents_states,
      step_metrics,
      step_meta_log,
  ) = jit_outer_update(
      update_rule_params=step_update_rule_params,
      meta_opt_state=step_meta_opt_state,
      agents_states=step_agents_states,
      rng=step_rngs,
  )
  # Fetch the logs of this step to the host before storing them.
  metrics[meta_step], meta_log[meta_step] = jax.device_get(
      (step_metrics, step_meta_log)
  )

if axis_name is not None:
  # Collect the logs from the devices and average over the leading axis.
  metrics, meta_log = utils.gather_from_devices((metrics, meta_log))
  metrics, meta_log = jax.tree.map(lambda x: x.mean(0), (metrics, meta_log))

In [None]:
# Flatten the per-step meta log into long-format rows (step, value, f) that can
# be turned into a DataFrame for plotting.
meta_log_cpu = jax.device_get(meta_log)
# `np.unique` already returns the steps in sorted order.
steps = np.unique(list(meta_log_cpu.keys()))
# Keys must match the ones written into the meta log by `outer_update`
# (note the plural 'rewards').
logged_keys = (
    'meta_grad_norm',
    'meta_up_norm',
    'meta_loss',
    'rewards',
    'pos_rewards',
    'neg_rewards',
)
rows = []
for i in steps:
  step_log = meta_log_cpu[i]
  rows.extend(
      dict(step=i, value=float(step_log[key]), f=key) for key in logged_keys
  )

In [None]:
import pandas as pd
import seaborn as sns

df = pd.DataFrame(rows)

# Positive and negative reward counts over meta-training steps, drawn in a
# single panel with one line per quantity.
reward_count_rows = df[df['f'].isin(['pos_rewards', 'neg_rewards'])]
sns.relplot(
    data=reward_count_rows,
    kind='line',
    x='step',
    y='value',
    hue='f',
    aspect=1.5,
    errorbar=None,
)
plt.show()

In [None]:
# One panel per quantity; the panels share the x axis but each has its own
# y axis.
panel_keys = ['meta_grad_norm', 'meta_loss', 'meta_up_norm', 'rewards']
panel_rows = df[df['f'].isin(panel_keys)]
sns.relplot(
    data=panel_rows,
    kind='line',
    x='step',
    y='value',
    col='f',
    aspect=1.5,
    errorbar=None,
    facet_kws=dict(sharey=False, sharex=True),
)
plt.show()