# Entropic Desired Dynamics for Intrinsic ConTrol (EDDICT), a self-contained JAX implementation

This is a simplified version of the code used in the EDDICT paper (to appear at NeurIPS 2021). In this stand-alone Google Colab, EDDICT is trained on a continuous grid world with an uncontrollable distractor. The resulting latent representations can then be seen to yield an interpretable model of the controllable aspects of the environment (i.e. the $(x,y)$ coordinates) while being invariant to the uncontrollable aspects (i.e. the distractor $(x,y)$ coordinates).

## LICENSE

Copyright 2021 DeepMind Technologies Limited.


Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [None]:
# @title (Optional) Install JAX 0.2.25 with CUDA support
!pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
!pip install --upgrade "jax[cuda]==0.2.25" -f https://storage.googleapis.com/jax-releases/jax_releases.html  # Note: wheels only available on linux.

In [None]:
#@title Imports
import functools
import dataclasses
import datetime
import math
import operator

import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, grad, random, lax
import matplotlib.pyplot as plt
from typing import Callable, NewType, Optional, Tuple, Union
from typing_extensions import Protocol
try:
  import chex
except ImportError:
  !pip install git+git://github.com/deepmind/chex.git@v0.1.0
  import chex
try:
  import haiku as hk
except ImportError:
  !pip install git+git://github.com/deepmind/dm-haiku.git@v0.0.5
  import haiku as hk
try:
  import optax
except ImportError:
  !pip install git+git://github.com/deepmind/optax.git@v0.1.0
  import optax
try:
  import rlax
except ImportError:
  !pip install git+git://github.com/deepmind/rlax.git@b652c45382605d3bf2c7db837364deda19819fce
  import rlax


try:
  import jax.tools.colab_tpu
  try:
    jax.tools.colab_tpu.setup_tpu()
  except KeyError:  # Not on a TPU Colab backend.
    pass
except ImportError:
  pass

In [None]:
# @title Environment Dynamics

# Define some type aliases. These don't actually help us in Colab, but if
# we were to run through a type checker we'd get helpful errors if we tried
# to pass actions where states were expected, etc.
# The string passed to `NewType` must match the name it is assigned to,
# otherwise type checkers reject the declaration.
Actions = NewType('Actions', chex.Array)
States = NewType('States', chex.Array)
LatentCodes = NewType('LatentCodes', chex.Array)
LatentCodeDeltas = NewType('LatentCodeDeltas', chex.Array)
QValues = NewType('QValues', chex.Array)


class Policy(Protocol):
  """Interface for an (unconditional) policy.

  Any callable with this signature satisfies the protocol; no subclassing
  is required.
  """

  def __call__(self, rng_key: chex.PRNGKey, states: States) -> Actions:
    """Generate actions from a batch of environment states and an RNG key.

    Args:
      rng_key: PRNGKey to use for any stochasticity in action selection.
      states: A batch of environment states.
    Returns:
      A batch of integer actions, with shape equal to `states.shape[:-1]`.
    """


def sample_ball(
  rng_key: chex.PRNGKey,
  size: Union[int, chex.Shape],
  dim: int = 2,
) -> jnp.ndarray:
  """Draw points uniformly from the unit `dim`-ball, i.e. ||x|| <= 1.

  A Gaussian sample is normalized to obtain a uniformly distributed
  direction, then scaled by a radius `u ** (1 / dim)` with `u` uniform on
  [0, 1), which makes the resulting points uniform over the ball's volume.

  Args:
    rng_key: A JAX PRNGKey.
    size: An integer or tuple of integers giving the leading dimensions
      (e.g. batch size) of the result.
    dim: The dimension of the ball to sample within.

  Returns:
    An array of shape `size + (dim,)` whose entries along the final axis
    are samples from the unit `dim`-ball.

  See:
    Muller, M.E. (1959). "A note on a method for generating points
    uniformly on n-dimensional spheres". Communications of the
    ACM, Volume 2 Issue 4, pp19-20.
  """
  # Normalize an integer `size` to a one-element shape tuple.
  leading_shape = size if not isinstance(size, int) else (size,)

  direction_key, radius_key = jax.random.split(rng_key, 2)

  # Uniformly distributed unit vectors from normalized Gaussians.
  gaussians = jax.random.normal(direction_key, leading_shape + (dim,))
  lengths = jnp.linalg.norm(gaussians, axis=-1, keepdims=True)

  # Radii distributed so that volume (not radius) is uniformly covered.
  uniforms = jax.random.uniform(radius_key, leading_shape + (1,))
  radii = uniforms ** (1 / dim)

  return gaussians / lengths * radii


@chex.dataclass(frozen=True)
class Unroll:
  """The result of unrolling environment dynamics according to a policy.

  `states` is expected to have one more element along its leading dimension
  than `actions`, to account for the state reached after the last action.
  """
  # Actions chosen by the policy, one entry per step along the leading axis.
  actions: Actions
  # Environment states visited, including the state after the last action.
  states: States


@chex.dataclass(frozen=True)
class Environment:
  """Defines dynamics for a noisy point mass with distractor.

  The state of the environment consists of 2 (x, y) pairs as a 4-vector.
  The first pair is under the agent's control: the agent takes a step
  in one of the 4 compass directions or 4 diagonal directions (or stays
  still) with step length `action_size`, and is then displaced to a
  point uniformly distributed in the circle of radius `noise_size`
  centered on its position after applying the deterministic effect
  of the chosen action. The distractor (x, y) coordinates are subject to
  a random walk under the same kind of circular noise but with a radius of
  `action_size`; the agent's actions do not influence them whatsoever.

  All coordinates are clipped to [-1, 1] after every step.

  This class is stateless but for its configuration, and simply provides
  methods to generate/mutate states.
  """
  action_size: float = 0.1
  noise_size: float = 0.2

  # Unannotated, so these are plain class attributes and not dataclass fields.
  STATE_DIM = 4
  # One row per discrete action: no-op, 4 compass steps, 4 unit diagonals.
  UNSCALED_ACTION_EFFECTS = np.array(
      [[0., 0.],
        [-1., 0.],
        [0., -1.],
        [1., 0.],
        [0., 1.],
        [-np.sqrt(.5), -np.sqrt(.5)],
        [-np.sqrt(.5), np.sqrt(.5)],
        [np.sqrt(.5), -np.sqrt(.5)],
        [np.sqrt(.5), np.sqrt(.5)],
      ],
      dtype=np.float32,
  )

  @property
  def num_actions(self) -> int:
    """Return the number of actions available in the environment."""
    return self.UNSCALED_ACTION_EFFECTS.shape[0]

  def initialize(
      self,
      rng_key: chex.PRNGKey,
      size: Union[int, chex.Shape],
  ) -> jnp.ndarray:
    """Generate a set of initial states of the environment.

    Args:
      rng_key: A PRNGKey to use for random number generation.
      size: An integer or tuple of integers indicating the leading
        dimensions (e.g. batch size).
    Returns:
      An array with `size` as its leading dimension(s) and 4 as its
      final dimension where every element along the final axis represents
      an initial state of the environment drawn from the uniform
      distribution on [-1, 1]^4.
    """
    size = (size,) if isinstance(size, int) else size
    return jax.random.uniform(
        rng_key,
        size + (self.STATE_DIM,),
        minval=-1,
        maxval=1,
    )

  def transition(
      self,
      rng_key: chex.PRNGKey,
      states: States,
      actions: Actions,
      stop_gradient: bool = True,
  ) -> States:
    """Step the environment dynamics.

    Args:
      rng_key: A PRNGKey to use for random number generation.
      states: An array representing the current state of the environment.
      actions: An array of integer actions with one less dimension than
        `states` and the same leading dimensions, representing the actions
        to be taken.
      stop_gradient: Boolean indicating whether to wrap resulting states
        in a `jax.lax.stop_gradient` so that the environment cannot be
        differentiated through (defaults to True).
    Returns:
      A new set of states with the same shape as `states`, resulting from
      applying `actions` and advancing the dynamics one step.
    """
    chex.assert_axis_dimension(states, -1, Environment.STATE_DIM)
    chex.assert_equal_shape_prefix([states, actions], actions.ndim)

    # Sample from the ball twice for each batch member and reshape to
    # horizontally stack pairs beside each other.
    noise = (
        sample_ball(rng_key, (2,) + actions.shape, 2)
        .reshape(actions.shape + (4,))
    )
    # Scale state dimensions 0 and 1 by noise_size, 2 and 3 by action_size.
    noise_gain = jnp.array([self.noise_size] * 2 + [self.action_size] * 2)

    # Convert the lookup table to a JAX array before indexing: a NumPy array
    # cannot be indexed by traced `actions`, which would otherwise happen
    # whenever `action_size` is a concrete Python float under tracing.
    all_action_effects = (
        self.action_size * jnp.asarray(self.UNSCALED_ACTION_EFFECTS)
    )
    controllable_effect = all_action_effects[actions]

    # Add zeros for distractor dimensions.
    action_effect = jnp.concatenate([
      controllable_effect,
      jnp.zeros_like(controllable_effect),
    ], axis=-1)

    # New state is the sum of action effects and noise, kept inside the box.
    displacement = action_effect + noise * noise_gain
    s_prime = jnp.clip(states + displacement, -1, 1)

    # `lax.select` rather than a Python `if`: `unroll` only marks `policy`
    # as static, so `stop_gradient` may arrive here as a traced value.
    return jax.lax.select(stop_gradient, lax.stop_gradient(s_prime), s_prime)

  @functools.partial(jax.jit, static_argnums=1)
  def unroll(
      self,
      policy: Policy,
      rng_keys: chex.PRNGKey,
      initial: States,
      stop_gradient: bool = True,
  ) -> Unroll:
    """Unroll a trajectory from an initial state according to a policy.

    The method is jit-compiled with `policy` as a static argument, so
    `policy` must be hashable.

    Args:
      policy: A callable taking a PRNGKey and a batch of states of the
        environment and returning actions.
      rng_keys: A pre-split PRNGKey with leading dimension equal to the
        length of the desired trajectory.
      initial: A batch of initial states for the trajectories.
      stop_gradient: A boolean indicating whether to place a `stop_gradient`
        around the environment state so that it cannot be differentiated
        (defaults to True).
    Returns:
      An Unroll containing `len(rng_keys)` actions and `len(rng_keys) + 1`
      states for each member of the batch.
    """
    def loop_body(
        states: States,
        key: chex.PRNGKey,
    ) -> Tuple[States, Unroll]:
      """Sample an action from the policy and step the environment dynamics.

      Args:
        states: The current state(s) of the environment.
        key: The PRNGKey to use for this step.

      Returns:
        A tuple of the next state(s) of the environment and an `Unroll` pair
        containing `states` and the action(s) sampled from it.
      """
      action_key, state_key = jax.random.split(key)
      actions = policy(action_key, states)
      new_states = self.transition(state_key, states, actions, stop_gradient)

      # N.B. states returned as part of unroll are the ones passed in as an
      # argument. The new state is only passed to the next iteration.
      return new_states, Unroll(actions=actions, states=states)

    # We will want to concatenate the final state with the unroll, and thus
    # we will end up with one more state than action.
    final, unroll = jax.lax.scan(loop_body, initial, rng_keys)
    all_states = jnp.concatenate([unroll.states, final[jnp.newaxis]])
    return unroll.replace(states=all_states)


@dataclasses.dataclass(frozen=True)
class UniformRandomPolicy:
  """Policy that picks every action with equal probability.

  Being a frozen dataclass, equality and hashing depend only on
  `num_actions`. The class lives in a different cell from the plotting code
  so that re-running that cell does not redefine it, which would invalidate
  the JIT compile cache for `env.unroll`.
  """
  num_actions: int

  def __call__(self, rng_key: chex.PRNGKey, states: States) -> Actions:
    """Sample one uniformly random action index per state in the batch."""
    # One action per state: drop the trailing state-feature axis.
    batch_shape = states.shape[:-1]
    return jax.random.randint(
        rng_key,
        batch_shape,
        minval=0,
        maxval=self.num_actions,
    )

In [None]:
#@title Plot an example trajectory (uniform random policy) { run: "auto" }

# Colab form fields read by `plot_example_trajectory` below: the PRNG seed,
# the number of environment steps, and the environment's step/noise radii.
trajectory_seed = 0  # @param { 'type': 'slider' }
trajectory_length = 1000  # @param { 'type': 'slider', 'min': 100, 'max': 2000, 'step': 100 }
trajectory_action_size = 0.1  # @param { 'type': 'slider', 'min': 0.1, 'max': 0.5, 'step': 0.01 }
trajectory_noise_size = 0.2  # @param { 'type': 'slider', 'min': 0.0, 'max': 0.5, 'step': 0.01 }


def plot_example_trajectory():
  """Scatter-plot one random-policy trajectory using the form values above.

  Only the two controllable state dimensions are drawn; points are colored
  by their time index within the trajectory.
  """
  environment = Environment(
      action_size=trajectory_action_size,
      noise_size=trajectory_noise_size,
  )

  # One key for the initial state, followed by one key per environment step.
  root_key = jax.random.PRNGKey(trajectory_seed)
  keys = jax.random.split(root_key, trajectory_length + 1)
  init_key, step_keys = keys[0], keys[1:]

  # An empty leading shape yields a single, unbatched start state.
  start_state = environment.initialize(init_key, ())

  # N.B. Changing trajectory_length triggers a re-jit of `unroll`; all other
  # settings should be fast to recompute.
  random_policy = UniformRandomPolicy(environment.num_actions)
  trajectory = environment.unroll(random_policy, step_keys, start_state)

  # Rows 0 and 1 of the transposed states are the controllable x and y.
  controllable = trajectory.states.T[:2]
  x, y = controllable
  time_index = np.arange(trajectory_length + 1)
  plt.scatter(x, y, c=time_index, linewidths=1)
  plt.title("Trajectory in controllable state dimensions (color = time)")

plot_example_trajectory()

In [None]:
# @title Network definitions

class ConditionalQNetwork(Protocol):
  """Interface for a (conditional) Q-network."""
  def __call__(self, state: States, desired_z: LatentCodes) -> QValues:
    """Compute Q-values (action values) for an environment state and code.

    Args:
      state: A (batch of) environment state(s).
      desired_z: The (batch of) desired latent code(s) on which the
        Q-values are conditioned.
    Returns:
      A (batch of) Q-values, with leading dimensions `state.shape[:-1]`
      and a final axis with size equal to the number of environment actions.
    """


class CodePredictor(Protocol):
  """Interface for a mapping from environment states to latent codes."""
  def __call__(self, states: States) -> LatentCodes:
    """Predict a (batch of) latent code(s) from a (batch of) environment states.

    Args:
      states: A (batch of) environment state(s).
    Returns:
      A (batch of) latent code(s) corresponding to the environment state(s),
      one code per state.
    """


@dataclasses.dataclass(frozen=True)
class Networks:
  """Pairs a code-conditioned Q-network with a state-to-code predictor.

  Both members are Haiku transformed functions whose `apply` takes no RNG
  argument.
  """
  q_network: hk.Transformed
  predictor: hk.Transformed

  @classmethod
  def build(cls, env: Environment, hid_dim: int, code_dim: int) -> 'Networks':
    """Construct both networks for `env` with the given layer sizes.

    Args:
      env: The environment; supplies the number of actions (Q-value outputs).
      hid_dim: Width of each of the two hidden layers in both MLPs.
      code_dim: Dimensionality of the latent codes.
    Returns:
      A `Networks` bundle wrapping the two transformed functions.
    """

    def predictor(states: States) -> LatentCodes:
      """MLP mapping environment states to latent codes."""
      chex.assert_axis_dimension(states, -1, Environment.STATE_DIM)
      mlp = hk.nets.MLP([hid_dim, hid_dim, code_dim], name='predictor')
      return mlp(states)

    def q_network(state: States, desired_z: LatentCodes) -> QValues:
      """MLP over the concatenation of a state and a desired latent code."""
      chex.assert_axis_dimension(state, -1, Environment.STATE_DIM)
      chex.assert_axis_dimension(desired_z, -1, code_dim)

      # In batched evaluation `state` may carry extra leading axes (e.g.
      # time); repeat `desired_z` along them so the two can be concatenated.
      extra_dims = state.ndim - desired_z.ndim
      if extra_dims > 0:
        repeats = state.shape[:extra_dims] + (1,) * desired_z.ndim
        desired_z = jnp.tile(desired_z, repeats)

      mlp = hk.nets.MLP(
          [hid_dim, hid_dim, env.num_actions],
          name='q_network',
      )
      return mlp(jnp.concatenate([state, desired_z], axis=-1))

    # Both networks are deterministic, so drop the RNG argument from `apply`.
    transformed_predictor = hk.without_apply_rng(hk.transform(predictor))
    transformed_q_network = hk.without_apply_rng(hk.transform(q_network))
    return cls(
        q_network=transformed_q_network,
        predictor=transformed_predictor,
    )

  def init(self, rng_key: chex.PRNGKey) -> hk.Params:
    """Initialize both networks and return their merged parameters.

    Args:
      rng_key: A PRNGKey.
    Returns:
      A `hk.Params` holding the parameters of both the predictor and the
      Q-network; it can be passed to either `apply` function.
    """
    predictor_key, q_key = jax.random.split(rng_key)

    # States always have the same dimension, so a zero vector serves as a
    # dummy input for shape inference.
    dummy_state = jnp.zeros(Environment.STATE_DIM)

    # The predictor goes first: its output provides a dummy latent code with
    # which to initialize the Q-network.
    predictor_params = self.predictor.init(predictor_key, dummy_state)
    dummy_code = self.predictor.apply(predictor_params, dummy_state)
    q_network_params = self.q_network.init(q_key, dummy_state, dummy_code)

    # The two parameter namespaces do not overlap, so the merged container
    # works with either transformed function.
    return hk.data_structures.merge(predictor_params, q_network_params)

  def with_params(
      self,
      params: hk.Params
  ) -> Tuple[ConditionalQNetwork, CodePredictor]:
    """Bind `params` to both networks' `apply` functions for convenience.

    Args:
      params: A Haiku parameter set.
    Returns:
      A `(q_network, predictor)` pair of `functools.partial` objects that
      satisfy the `ConditionalQNetwork` and `CodePredictor` interfaces.
    """
    bound_q_network = functools.partial(self.q_network.apply, params)
    bound_predictor = functools.partial(self.predictor.apply, params)
    return bound_q_network, bound_predictor

In [None]:
# @title Actor Loop

@dataclasses.dataclass(frozen=True)
class TrainingConfig:
  """All of the constant configuration that doesn't change during training.

  Bundle this together so that our top-level functions don't have a dozen
  arguments and we don't rely on a cluttered global namespace, which leads to
  confusing bugs.
  """
  # Bundle of the Q-network and the latent code predictor.
  nets: Networks
  # Optax gradient transformation used to update the parameters.
  optimizer: optax.GradientTransformation
  # Environment whose dynamics are unrolled to gather data.
  env: Environment
  # Dimensionality of the latent codes.
  code_dim: int
  # Discount factor passed to the loss.
  gamma: float
  # Number of environment actions taken per goal period.
  goal_duration: int
  # Epsilon for epsilon-greedy action selection during training.
  train_epsilon: float
  # Epsilon for epsilon-greedy action selection during evaluation.
  evaluation_epsilon: float
  # Number of start states sampled for each evaluation batch.
  evaluation_batch_size: int


@chex.dataclass(frozen=True)
class TrainingState:
  """All of the state that _does_ change during training."""
  # Current batch of environment states; the next unroll starts from these.
  env_state: States
  # Joint parameters of the Q-network and the predictor.
  params: hk.Params
  # Optimizer state matching `params`.
  opt_state: optax.OptState
  # PRNG key consumed (split) by each training step.
  rng_key: chex.PRNGKey


@chex.dataclass(frozen=True)
class Batch:
  """A batch of data generated by sampling code deltas and acting."""
  # Actions taken at each step of the unroll.
  actions: Actions
  # States visited during the unroll, including the final state.
  states: States
  # Code offsets added to the encoded initial states to form desired codes.
  deltas: LatentCodeDeltas


def get_batch(
    env: Environment,
    nets: Networks,
    state: TrainingState,
    goal_duration: int,
    epsilon: float,
) -> Batch:
  """Gather one batch of experience by acting towards sampled goal codes.

  A code delta is drawn from the unit ball for every batch member and added
  to the encoding of its current state; the resulting desired code conditions
  an epsilon-greedy policy that is unrolled for `goal_duration` steps.

  Args:
    env: The `Environment` instance defining the dynamics.
    nets: A `Networks` bundle of a Q-network and predictor.
    state: The current `TrainingState`.
    goal_duration: The number of actions to take in the environment
      per goal period.
    epsilon: Value to use for epsilon-greedy exploration.
  Returns:
    A `Batch` holding the unrolled states and actions together with the code
    deltas that were added to the encoded initial states to condition the
    Q-network while acting.
  """
  chex.assert_rank(state.env_state, 2)
  batch_size = state.env_state.shape[0]
  q_network, predictor = nets.with_params(state.params)

  # The first key samples the deltas; the rest drive one step each.
  all_keys = jax.random.split(state.rng_key, goal_duration + 1)
  delta_key = all_keys[0]
  step_keys = all_keys[1:]

  start_codes = predictor(state.env_state)
  code_dim = start_codes.shape[-1]
  deltas = sample_ball(delta_key, batch_size, code_dim)
  desired_z = start_codes + deltas

  behaviour = rlax.epsilon_greedy(epsilon)

  def batched_epsilon_greedy(key: chex.PRNGKey, states: jnp.ndarray) -> Actions:
    """Epsilon-greedy behavior policy conditioned on `desired_z`.

    RLax's epsilon greedy handles a single example, so the key is split per
    batch member and the sampler is mapped over the batch with `jax.vmap`.
    """
    chex.assert_rank(states, 2)
    per_item_keys = jax.random.split(key, batch_size)
    action_values = q_network(states, desired_z)
    return jax.vmap(behaviour.sample)(per_item_keys, action_values)

  rollout = env.unroll(batched_epsilon_greedy, step_keys, state.env_state)

  return Batch(
      actions=rollout.actions,
      states=rollout.states,
      deltas=deltas,
  )

In [None]:
#@title Loss

# Name of the pmap axis over which losses and gradients are averaged.
PMAP_AXIS = 'devices'

@chex.dataclass(frozen=True)
class Statistics:
  """Statistics computed along with the loss."""
  # Loss of the Q-network (sum over time of squared TD errors, batch mean).
  value_loss: float
  # Loss of the predictor (mean squared code error).
  predictor_loss: float
  # Sum of `value_loss` and `predictor_loss`.
  total_loss: float


def eddict_loss(
    nets: Networks,
    params: hk.Params,
    batch: Batch,
    gamma: float,
) -> Tuple[float, Statistics]:
  """Compute the EDDICT loss and its diagnostics for one batch.

  The loss is the sum of a value (Q-learning) loss and a predictor loss.
  The two act on disjoint parameter sets and are only combined here for
  convenience. Must be called inside a `pmap` with axis name `PMAP_AXIS`,
  since the statistics are averaged across devices.

  Args:
    nets: A `Networks` bundle of a Q-network and predictor.
    params: The joint parameters for both networks.
    batch: Time-major batch of actions, states and the sampled code deltas.
    gamma: Discount applied to every transition except the last.
  Returns:
    A tuple of the device-averaged total loss and a `Statistics` containing
    the device-averaged individual losses.
  """
  chex.assert_rank([batch.actions, batch.deltas, batch.states], [2, 2, 3])
  chex.assert_equal_shape_prefix([batch.states[:-1], batch.actions], 2)

  q_network, predictor = nets.with_params(params)

  first_states = batch.states[0]
  last_states = batch.states[-1]
  desired_z = predictor(first_states) + batch.deltas
  achieved_z = predictor(last_states)

  # The Q-network treats the desired code as a fixed input, so no gradient
  # reaches the predictor through the value loss.
  q_values = q_network(batch.states, lax.stop_gradient(desired_z))

  # Squared Euclidean distance between achieved and desired codes: this is
  # the predictor's loss and, negated, the policy's reward.
  code_errors = jnp.sum(jnp.square(achieved_z - desired_z), axis=-1)

  # The only non-zero reward arrives on the final transition.
  num_transitions = batch.actions.shape[0]
  is_last_step = jnp.arange(num_transitions) == (num_transitions - 1)
  rewards = jnp.outer(is_last_step, -code_errors)

  # No bootstrapping from the final state (discount 0 there). Discounts are
  # shared by all batch items and therefore have no batch dimension.
  discounts = (1 - is_last_step) * gamma

  # Map Q-learning over batch (inner) and time (outer); the per-step discount
  # is broadcast across the batch instead of being mapped.
  q_learning_over_batch = jax.vmap(
      rlax.q_learning,
      in_axes=(0, 0, 0, None, 0),
  )
  td_errors = jax.vmap(q_learning_over_batch)(
      q_values[:-1],
      batch.actions,
      rewards,
      discounts,
      q_values[1:],
  )

  value_loss = jnp.mean(jnp.sum(jnp.square(td_errors), axis=0))
  predictor_loss = jnp.mean(code_errors)

  local_stats = Statistics(
      value_loss=value_loss,
      predictor_loss=predictor_loss,
      total_loss=value_loss + predictor_loss,
  )
  # Average every statistic (and hence the loss) across devices.
  stats = jax.lax.pmean(local_stats, PMAP_AXIS)
  return stats.total_loss, stats

In [None]:
# @title Training Step

@functools.partial(
    jax.pmap,
    axis_name=PMAP_AXIS,
    static_broadcasted_argnums=0,
    donate_argnums=1,
)
def training_step(
    config: TrainingConfig,
    state: TrainingState,
) -> TrainingState:
  """Do a full training iteration: sample a data batch and update parameters.

  The function is pmapped over devices. `config` is a static, broadcast
  argument, and the buffers of `state` are donated, so callers must continue
  with the returned state rather than the one passed in.

  Args:
    config: The constant training configuration.
    state: The per-device training state.
  Returns:
    The updated `TrainingState`: new parameters and optimizer state, the last
    environment states of the batch, and a fresh PRNG key.
  """
  # One half of the key is consumed by this step, the other is carried over.
  rng_key, next_key = jax.random.split(state.rng_key)
  batch = get_batch(
      config.env,
      config.nets,
      state.replace(rng_key=rng_key),
      config.goal_duration,
      config.train_epsilon,
  )

  grad_loss = jax.grad(eddict_loss, argnums=1, has_aux=True)
  dev_grads, _ = grad_loss(config.nets, state.params, batch, config.gamma)

  # The gradient of an average should be the average of the gradients.
  grads = jax.lax.pmean(dev_grads, PMAP_AXIS)
  updates, new_opt_state = config.optimizer.update(grads, state.opt_state)
  new_params = optax.apply_updates(state.params, updates)

  return state.replace(
      params=new_params,
      env_state=batch.states[-1],  # Next step starts from last batch of states.
      opt_state=new_opt_state,
      rng_key=next_key,
  )


# N.B. `state` is deliberately NOT donated here: the caller keeps training
# from the very same `TrainingState` after evaluating, and donated buffers
# must not be reused.
@functools.partial(
    jax.pmap,
    axis_name=PMAP_AXIS,
    static_broadcasted_argnums=0,
)
def evaluate(
    config: TrainingConfig,
    state: TrainingState,
) -> Statistics:
  """Gather a large batch of trajectories and evaluate diagnostics.

  The function is pmapped over devices with `config` as a static, broadcast
  argument. It leaves `state` untouched.

  Args:
    config: The constant training configuration.
    state: The per-device training state whose parameters are evaluated.
  Returns:
    The device-averaged loss `Statistics` on freshly sampled trajectories.
  """
  # We split the current key but don't replace it in the TrainingState.
  # This means the training results are independent of how often we evaluate.
  env_key, act_key = jax.random.split(state.rng_key)
  env_state = config.env.initialize(env_key, config.evaluation_batch_size)
  batch = get_batch(
      config.env,
      config.nets,
      # Substitute the batch of evaluation start states, and the RNG key
      # we just split. Note that this object isn't returned to the training
      # loop so this doesn't affect how the RNG evolves for training purposes,
      # and our training results should be the same independent of how often
      # we evaluate.
      state.replace(env_state=env_state, rng_key=act_key),
      config.goal_duration,
      config.evaluation_epsilon,
  )
  _, stats = eddict_loss(config.nets, state.params, batch, config.gamma)
  return stats


In [None]:
#@title Network & Environment Setup (changing these settings will also reinitialize parameters) { run: "auto" }

# EDDICT goal duration: the number of environment actions taken per sampled
# goal (i.e. the length of each unroll).
train_goal_duration = 10  # @param {'type': 'slider', 'min': 2, 'max': 100}

# RL hyperparameters: epsilon-greedy exploration rate and discount factor.
train_epsilon = 0.1  # @param {'type': 'slider', 'min': 0.05, 'max': 0.5, 'step': 0.05}
train_gamma = 0.99  # @param {'type': 'slider', 'min': 0.95, 'max': 0.999, 'step': 0.001}

# Network hyperparameters: hidden layer width and latent code dimension.
train_net_hid_dim = 128  # @param {'type': 'integer'}
train_net_code_dim = 2  # The visualizations are based on 2-dimensional codes.

# Environment hyperparameters: action step length and noise radius.
train_env_action_size = 0.1  # @param { 'type': 'slider', 'min': 0.1, 'max': 0.9, 'step': 0.1 }
train_env_noise_size = 0.2  # @param { 'type': 'slider', 'min': 0.1, 'max': 0.9, 'step': 0.1 }

# Optimizer hyperparameters. The batch size is per device.
train_batch_size_per_device = 256  # @param
train_adam_learning_rate = 1e-4  # @param
train_adam_eps = 1e-2  # @param

# Define the initialization seed separately so that it can vary independently.
# `train_initialization_seed` seeds the network parameters, while `train_seed`
# seeds the initial environment states and the training RNG stream.
train_initialization_seed = 0  # @param {'type': 'integer'}
train_seed = 1  # @param {'type': 'integer'}

# Evaluation hyperparameters. The batch size is per device.
evaluation_epsilon = 0.1  # @param {'type': 'slider', 'min': 0.05, 'max': 0.5, 'step': 0.05}
evaluation_batch_size_per_device = 1024  # @param {'type': 'integer'}


def _build_config() -> TrainingConfig:
  """Assemble the `TrainingConfig` from the form values defined above.

  Wrapped in a function to keep intermediate objects out of the global
  namespace.
  """
  environment = Environment(
      action_size=train_env_action_size,
      noise_size=train_env_noise_size,
  )
  return TrainingConfig(
      nets=Networks.build(environment, train_net_hid_dim, train_net_code_dim),
      optimizer=optax.adam(train_adam_learning_rate, eps=train_adam_eps),
      env=environment,
      code_dim=train_net_code_dim,
      gamma=train_gamma,
      goal_duration=train_goal_duration,
      train_epsilon=train_epsilon,
      evaluation_epsilon=evaluation_epsilon,
      evaluation_batch_size=evaluation_batch_size_per_device,
  )

def _construct_state(
    config: TrainingConfig,
) -> TrainingState:
  """Create the initial per-device `TrainingState`.

  Every local device gets its own environment and training PRNG keys (derived
  from `train_seed`), while network parameters are initialized from the same
  key (`train_initialization_seed`) on each device so the replicas start out
  identical.
  """
  num_devices = jax.local_device_count()

  # Two independent streams from the training seed, each split per device.
  env_root_key, train_root_key = jax.random.split(
      jax.random.PRNGKey(train_seed)
  )
  env_key = jax.random.split(env_root_key, num_devices)
  train_rng_key = jax.random.split(train_root_key, num_devices)

  # The same initialization key is replicated onto every device.
  init_key = jax.random.PRNGKey(train_initialization_seed)
  replicated_init_key = jax.device_put_replicated(
      init_key,
      jax.local_devices(),
  )
  params = jax.pmap(config.nets.init)(replicated_init_key)

  # The batch size is a Python int, hence a static broadcast argument.
  initialize_envs = jax.pmap(
      config.env.initialize,
      static_broadcasted_argnums=1,
  )
  env_state = initialize_envs(env_key, train_batch_size_per_device)
  opt_state = jax.pmap(config.optimizer.init)(params)

  return TrainingState(
      env_state=env_state,
      params=params,
      opt_state=opt_state,
      rng_key=train_rng_key,
  )

train_config = _build_config()
train_state = _construct_state(train_config)

In [None]:
#@title Training Loop

train_iterations = 10000  # @param {'type': 'integer'}
evaluation_period = 1000  # @param {'type': 'integer'}


def log_statistics(iteration: int, stats: Statistics, spaces: int = 2):
  """Print one line of evaluation statistics for the given iteration.

  Args:
    iteration: The training iteration the statistics belong to.
    stats: Per-device `Statistics`; only the first device's copy is printed.
    spaces: Number of spaces separating consecutive `name = value` entries.
  """
  # Fetch to host and keep the first device's entry of every statistic.
  stats = dict(
      jax.tree_util.tree_map(operator.itemgetter(0), jax.device_get(stats))
  )
  # Right-align the iteration number. The `max` guard avoids a math domain
  # error from `log10(0)` when `train_iterations` is set to zero.
  width = math.ceil(1 + math.log10(max(train_iterations, 1)))
  print(
      'Iteration ',
      str(iteration).rjust(width),
      (' ' * spaces).join(
          f'{k} = {v:9.7f}'
          for k, v in stats.items()
      )
  )


# Summarize the run: batch sizes are reported both per device and summed over
# all local devices.
print(('Training with goal duration %d, batch size %d on %d devices '
       '(total per-step batch size %d, %d transitions per batch)') %
      (train_goal_duration, train_batch_size_per_device,
       jax.local_device_count(),
       train_batch_size_per_device * jax.local_device_count(),
       (jax.local_device_count() * train_batch_size_per_device *
        train_goal_duration)))
print('Evaluating on a batch of size %d per device '
      '(for a total of %d trajectories per evaluation)' %
      (evaluation_batch_size_per_device,
       evaluation_batch_size_per_device * jax.local_device_count()))


# Evaluate the freshly initialized networks as a baseline (iteration 0).
train_stats = evaluate(train_config, train_state)
log_statistics(0, train_stats)
print('Beginning training.')
t_start = datetime.datetime.now()
for i in range(train_iterations):
  # Each step returns the state that the next step continues from.
  train_state = training_step(train_config, train_state)
  # Evaluate and log after every `evaluation_period` completed iterations.
  if (i + 1) % evaluation_period == 0:
    train_stats = evaluate(train_config, train_state)
    log_statistics(i + 1, train_stats)

# The measured wall-clock time includes the periodic evaluations.
t_end = datetime.datetime.now()
print(f'Training took {t_end - t_start}')

In [None]:
#@title Visualize Results

def _scatter_panels(suptitle, panels):
  """Draw a new figure with one row of axis-free scatter plots.

  Args:
    suptitle: Title for the whole figure.
    panels: A sequence of `(title, x, y, color)` tuples, one per subplot.
  """
  plt.figure(figsize=[20, 5])
  plt.suptitle(suptitle, y=1.05)
  for index, (title, xs, ys, colors) in enumerate(panels, start=1):
    plt.subplot(1, len(panels), index)
    plt.title(title)
    plt.axis('off')
    plt.scatter(xs, ys, c=colors)


def plot_results(
    config: TrainingConfig,
    train_state: TrainingState,
    grid_distractors: bool = False,
):
  """Visualize the learned latent codes against the environment state.

  Two of the four state dimensions are laid out on a regular grid over
  [-1, 1]^2 and the other two are sampled uniformly at random. The predictor
  then encodes every state, and two figures relate the codes to the
  controllable and to the distractor coordinates respectively.

  Args:
    config: The training configuration (supplies the networks).
    train_state: The per-device training state; the first device's
      parameters are used.
    grid_distractors: If True, the distractor coordinates are placed on the
      grid and the controllable ones are random; otherwise the reverse.
  """
  # Meshgrid over 2 dimensions, random sampling for the other 2.
  axis = np.linspace(-1., 1., dtype=np.float32)
  grid_points = np.stack(np.meshgrid(axis, axis), -1).reshape((-1, 2))
  random_points = np.random.uniform(-1, 1, size=grid_points.shape)
  if grid_distractors:
    states = np.concatenate([random_points, grid_points], axis=-1)
  else:
    states = np.concatenate([grid_points, random_points], axis=-1)

  # Retrieve a single copy of the parameters.
  params = jax.tree_util.tree_map(
      operator.itemgetter(0),
      jax.device_get(train_state.params),
  )
  z = config.nets.predictor.apply(params, states)

  _scatter_panels(
      "EDDICT $z$ vs $(x,y)$ position",
      [
          ("$z_0$ vs $z_1$ \n color=$x$", z[:,0], z[:,1], states[:,0]),
          ("$z_0$ vs $z_1$ \n color=$y$", z[:,0], z[:,1], states[:,1]),
          ("$x$ vs $y$ \n color=$z_0$", states[:,0], states[:,1], z[:,0]),
          ("$x$ vs $y$ \n color=$z_1$", states[:,0], states[:,1], z[:,1]),
      ],
  )

  _scatter_panels(
      "EDDICT Z vs Uncontrollable Distractor $(x,y)$ position",
      [
          ("$z_0$ vs $z_1$ \n color=distractor x",
           z[:,0], z[:,1], states[:,2]),
          ("$z_0$ vs $z_1$ \n color=distractor y",
           z[:,0], z[:,1], states[:,3]),
          ("distractor $x$ vs $y$ \n color=$z_0$",
           states[:,2], states[:,3], z[:,0]),
          ("distractor $x$ vs $y$ \n color=$z_1$",
           states[:,2], states[:,3], z[:,1]),
      ],
  )

plot_results(train_config, train_state)