# Tutorial: Penalties

In this tutorial you will learn how to use the MuJoCo physics simulator and Reinforcement Learning to teach two humanoids to take and stop penalties while competing against each other.

## 1 - Kaggle
It's recommended to run this notebook on www.kaggle.com where you can use two T4 GPUs for 30h/week for free.

To import this notebook into Kaggle you need to:
- Log in to your Kaggle account
- Create a new notebook
- Click on `File` and then `Import Notebook`
- Select the tab `GitHub`
- Search `goncalog/ai-robotics`
- Select the file `tutorials/penalties.ipynb`
- Click the `Import` button

To run this notebook you can either click the `Run All` button or run each cell individually by clicking the `Run current cell` button.

## 2 - Config

This is the configuration used to run the tutorial. It includes:
- Training hyperparameters
- MuJoCo environment variables
- Rendering variables

In [None]:
# ---------------------------------------------------------------------------
# Training hyperparameters (PPO)
# ---------------------------------------------------------------------------
# Step budget and evaluation count; presumably consumed by the training code
# defined later in the notebook.
num_timesteps = 20_000_000
num_evals = 9
# Number of environments run in parallel when collecting rollouts.
num_envs = 2048

# Learning rate applied to the PPO loss.
learning_rate = 3e-4
# Discounting rate.
discounting = 0.97
# Length of one environment episode.
episode_length = 1000
# Whether observations are normalized.
normalize_observations = True
# Number of timesteps each action is repeated for.
action_repeat = 1
# Number of timesteps unrolled in each environment; the PPO loss is computed
# over `unroll_length` timesteps.
unroll_length = 10
# Entropy reward in the PPO loss; higher values increase the policy's entropy.
entropy_cost = 1e-3
# Batch size of each minibatch SGD step.
batch_size = 1024
# Number of SGD steps, each run on a different minibatch whose leading
# dimension is `batch_size`.
num_minibatches = 32
# Number of gradient-update passes over all minibatches before a new
# environment rollout is collected.
num_updates_per_batch = 8
# Scaling factor applied to the reward.
reward_scaling = 1
# Clipping epsilon of the PPO loss.
clipping_epsilon = 0.3
# Lambda of the generalized advantage estimation.
gae_lambda = 0.95
# Whether the advantage estimate is normalized.
normalize_advantage = True

# Hidden layer widths: four layers of 32 units and five layers of 256 units.
policy_hidden_layer_sizes = (32, 32, 32, 32)
value_hidden_layer_sizes = (256, 256, 256, 256, 256)

# ---------------------------------------------------------------------------
# Environment geometry
# ---------------------------------------------------------------------------
# Agent whose reward function is used: "striker" or "keeper".
training_agent = "striker"
# Size of the ball sphere geom; also used as the ball's initial z coordinate.
ball_size = 0.15
# x coordinate of the ball's centre of mass.
ball_x = 0.3
# Goal dimensions; width and distance are half of 7.32 m and 11 m.
goal_width = 7.32 / 2
goal_distance = 11 / 2
goal_height = 2
# y coordinates of the two posts, symmetric around y = 0.
post2_y = goal_width / 2
post1_y = -post2_y
# Index used by the reward functions to read `data.q[torso_index]`.
# NOTE(review): it is applied to the generalized-position vector, where entry 2
# looks like the z coordinate of the striker's root joint - confirm.
torso_index = 2

# ---------------------------------------------------------------------------
# Simulation
# ---------------------------------------------------------------------------
# Simulation time step in seconds.
# This is the single most important parameter affecting the speed-accuracy
# trade-off which is inherent in every physics simulation.
# Smaller values result in better accuracy and stability.
mj_model_timestep = 0.004

# Not referenced in the visible part of the notebook; presumably used by the
# evaluation cells further down.
num_penalties = 5
tournament_mode = False

## 3 - Install MuJoCo, MJX, and Brax

In [None]:
!pip install mujoco
!pip install mujoco_mjx
!pip install brax

# Check if MuJoCo installation was successful
import distutils.util
import os
import subprocess
if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but this
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/'
NVIDIA_ICD_CONFIG_FILE = '10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  os.makedirs(NVIDIA_ICD_CONFIG_PATH)
  file_path = os.path.join(NVIDIA_ICD_CONFIG_PATH, NVIDIA_ICD_CONFIG_FILE)
  with open(file_path, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

try:
  print('Checking that the installation succeeded:')
  import mujoco
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".')

print('Installation successful.')

In [None]:
# Import packages for plotting and creating graphics
import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List

# Graphics and plotting.
print('Installing mediapy:')
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

# Import MuJoCo, MJX, and Brax
from datetime import datetime
import functools
import jax
from jax import numpy as jp
import numpy as np
from typing import Any, Dict, Sequence, Tuple, Union

from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.io import html, mjcf, model
from brax.training import distribution, networks

from etils import epath
from flax import linen, struct
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx

## 4 - Setting up the two-Humanoid environment with MJX

MJX is an implementation of MuJoCo written in [JAX](https://jax.readthedocs.io/en/latest/index.html), enabling large batch training on GPU/TPU. In this notebook, we train RL policies with MJX.

Here we implement our environment by adapting the original [Humanoid](https://github.com/google-deepmind/mujoco/blob/546a27ca72397b888e314ee4549bcf12d9fd5957/model/humanoid/humanoid.xml) environment to also include a ball, a goal, and a second Humanoid who acts as a keeper. Notice that `reset` initializes a `State`, and `step` steps through the physics step and reward logic. The reward and stepping logic train the Humanoids to take and stop penalties.


The cells below build this environment in two steps: first the MuJoCo XML model (both Humanoids, the ball and the goal), then the `Humanoid` environment class. The class contains one reward function per agent, and the `training_agent` setting from the config selects which one is used in `step`.

In [None]:
# Humanoid XML
def get_humanoid(player_type: str) -> str:
    """Returns the humanoid XML.

    Every body, geom, joint, camera and light name is suffixed with
    `player_type` so that two humanoids can live in the same model. The
    striker stands at x = 0 facing positive x; the keeper stands at
    x = `goal_distance` (module-level setting) facing negative x, which is
    achieved by multiplying the relevant coordinates and axes by `x_front`.

    Parameters:
        player_type (str): can be 'striker' or 'keeper'.

    Returns:
        str: The humanoid XML to generate a MJX model.

    Raises:
        ValueError: if `player_type` is neither 'striker' nor 'keeper'.
    """
    if player_type == "striker":
        x_pos = 0
        x_front = 1 # facing positive x
        cameras = f"""
        <camera name="back_{player_type}" pos="-4.5 -2 2.5" xyaxes="2 -5 0 1 0 3" mode="fixed"/>
        <camera name="side_{player_type}" pos="0.5 -4.5 2.5" xyaxes="1 0 0 0 1 3" mode="fixed"/>
        """
    elif player_type == "keeper":
        x_pos = goal_distance
        x_front = -1 # facing negative x
        cameras = f"""
        <camera name="side_{player_type}" pos="{goal_distance - 0.5} -4.5 2.5" xyaxes="1 0 0 0 1 3" mode="fixed"/>
        """
    else:
        # ValueError is still caught by callers that catch Exception.
        raise ValueError(f"{player_type=} isn't supported.")

    # NOTE(review): the egocentric camera position and axes are not multiplied
    # by `x_front`, so they are identical for both players - confirm intended.
    humanoid_xml = f"""
    <light name="top_{player_type}" pos="0 0 2" mode="trackcom"/>
    {cameras}
    <body name="torso_{player_type}" pos="{x_pos + 0} 0 1.282" childclass="body">
      <freejoint name="root_{player_type}"/>
      <geom name="torso_{player_type}" fromto="0 -.07 0 0 .07 0" size=".07"/>
      <geom name="waist_upper_{player_type}" fromto="{x_front * -.01} -.06 -.12 {x_front * -.01} .06 -.12" size=".06"/>
      <body name="head_{player_type}" pos="0 0 .19">
        <geom name="head_{player_type}" type="sphere" size=".09"/>
        <camera name="egocentric_{player_type}" pos=".09 0 0" xyaxes="0 -1 0 .1 0 1" fovy="80"/>
      </body>
      <body name="waist_lower_{player_type}" pos="{x_front * 0.01} 0 -.26">
        <geom name="waist_lower_{player_type}" fromto="0 -.06 0 0 .06 0" size=".06"/>
        <joint name="abdomen_z_{player_type}" pos="0 0 .065" axis="0 0 {x_front * 1}" range="-45 45" class="joint_big_stiff"/>
        <joint name="abdomen_y_{player_type}" pos="0 0 .065" axis="0 {x_front * 1} 0" range="-75 30" class="joint_big"/>
        <body name="pelvis_{player_type}" pos="0 0 -.165">
          <joint name="abdomen_x_{player_type}" pos="0 0 .1" axis="1 0 0" range="-35 35" class="joint_big"/>
          <geom name="butt_{player_type}" fromto="{x_front * -.02} -.07 0 {x_front * -.02} .07 0" size=".09"/>
          <body name="thigh_right_{player_type}" pos="0 -.1 -.04">
            <joint name="hip_x_right_{player_type}" axis="1 0 0" class="hip_x"/>
            <joint name="hip_z_right_{player_type}" axis="0 0 {x_front * 1}" class="hip_z"/>
            <joint name="hip_y_right_{player_type}" class="hip_y" axis="0 {x_front * 1} 0"/>
            <geom name="thigh_right_{player_type}" fromto="0 0 0 0 .01 -.34" class="thigh"/>
            <body name="shin_right_{player_type}" pos="0 .01 -.4">
              <joint name="knee_right_{player_type}" class="knee" axis="0 {x_front * -1} 0"/>
              <geom name="shin_right_{player_type}" class="shin"/>
              <body name="foot_right_{player_type}" pos="0 0 -.39">
                <joint name="ankle_y_right_{player_type}" class="ankle_y" axis="0 {x_front * 1} 0"/>
                <joint name="ankle_x_right_{player_type}" class="ankle_x" axis="1 0 {x_front * .5}"/>
                <geom name="foot1_right_{player_type}" class="foot1_{player_type}"/>
                <geom name="foot2_right_{player_type}" class="foot2_{player_type}"/>
              </body>
            </body>
          </body>
          <body name="thigh_left_{player_type}" pos="0 .1 -.04">
            <joint name="hip_x_left_{player_type}" axis="-1 0 0" class="hip_x"/>
            <joint name="hip_z_left_{player_type}" axis="0 0 {x_front * -1}" class="hip_z"/>
            <joint name="hip_y_left_{player_type}" class="hip_y" axis="0 {x_front * 1} 0"/>
            <geom name="thigh_left_{player_type}" fromto="0 0 0 0 -.01 -.34" class="thigh"/>
            <body name="shin_left_{player_type}" pos="0 -.01 -.4">
              <joint name="knee_left_{player_type}" class="knee" axis="0 {x_front * -1} 0"/>
              <geom name="shin_left_{player_type}" fromto="0 0 0 0 0 -.3" class="shin"/>
              <body name="foot_left_{player_type}" pos="0 0 -.39">
                <joint name="ankle_y_left_{player_type}" class="ankle_y" axis="0 {x_front * 1} 0"/>
                <joint name="ankle_x_left_{player_type}" class="ankle_x" axis="-1 0 {x_front * -.5}"/>
                <geom name="foot1_left_{player_type}" class="foot1_{player_type}"/>
                <geom name="foot2_left_{player_type}" class="foot2_{player_type}"/>
              </body>
            </body>
          </body>
        </body>
      </body>
      <body name="upper_arm_right_{player_type}" pos="0 -.17 .06">
        <joint name="shoulder1_right_{player_type}" axis="2 {x_front * 1} {x_front * 1}"  class="shoulder"/>
        <joint name="shoulder2_right_{player_type}" axis="0 {x_front * -1} {x_front * 1}" class="shoulder"/>
        <geom name="upper_arm_right_{player_type}" fromto="0 0 0 {x_front * .16} -.16 -.16" class="arm_upper"/>
        <body name="lower_arm_right_{player_type}" pos="{x_front * .18} -.18 -.18">
          <joint name="elbow_right_{player_type}" axis="0 {x_front * -1} {x_front * 1}" class="elbow"/>
          <geom name="lower_arm_right_{player_type}" fromto="{x_front * 0.01} .01 .01 {x_front * 0.17} .17 .17" class="arm_lower"/>
          <body name="hand_right_{player_type}" pos="{x_front * 0.18} .18 .18">
            <geom name="hand_right_{player_type}" zaxis="1 {x_front * 1} 1" class="hand"/>
          </body>
        </body>
      </body>
      <body name="upper_arm_left_{player_type}" pos="0 .17 .06">
        <joint name="shoulder1_left_{player_type}" axis="-2 {x_front * 1} {x_front * -1}" class="shoulder"/>
        <joint name="shoulder2_left_{player_type}" axis="0 {x_front * -1} {x_front * -1}"  class="shoulder"/>
        <geom name="upper_arm_left_{player_type}" fromto="0 0 0 {x_front * 0.16} .16 -.16" class="arm_upper"/>
        <body name="lower_arm_left_{player_type}" pos="{x_front * 0.18} .18 -.18">
          <joint name="elbow_left_{player_type}" axis="0 {x_front * -1} {x_front * -1}" class="elbow"/>
          <geom name="lower_arm_left_{player_type}" fromto="{x_front * 0.01} -.01 .01 {x_front * 0.17} -.17 .17" class="arm_lower"/>
          <body name="hand_left_{player_type}" pos="{x_front * 0.18} -.18 .18">
            <geom name="hand_left_{player_type}" zaxis="1 {x_front * -1} 1" class="hand"/>
          </body>
        </body>
      </body>
    </body>
    """
    return humanoid_xml


def get_contacts(player_type: str) -> str:
    """Returns the contacts XML.

    Both players get the thigh/waist collision exclusions, foot-floor pairs
    and left-foot-ball pairs. The keeper additionally gets floor and ball
    pairs for the rest of its geoms (plus right-foot-ball pairs).

    Parameters:
        player_type (str): can be 'striker' or 'keeper'.

    Returns:
        str: The contacts XML to generate a MJX model.

    Raises:
        ValueError: if `player_type` is neither 'striker' nor 'keeper'.
    """
    if player_type == "striker":
        extra_contacts = ""
    elif player_type == "keeper":
        extra_contacts = f"""
        <pair geom1="torso_{player_type}" geom2="floor"/>
        <pair geom1="waist_upper_{player_type}" geom2="floor"/>
        <pair geom1="head_{player_type}" geom2="floor"/>
        <pair geom1="waist_lower_{player_type}" geom2="floor"/>
        <pair geom1="butt_{player_type}" geom2="floor"/>
        <pair geom1="thigh_right_{player_type}" geom2="floor"/>
        <pair geom1="shin_right_{player_type}" geom2="floor"/>
        <pair geom1="thigh_left_{player_type}" geom2="floor"/>
        <pair geom1="shin_left_{player_type}" geom2="floor"/>
        <pair geom1="upper_arm_right_{player_type}" geom2="floor"/>
        <pair geom1="lower_arm_right_{player_type}" geom2="floor"/>
        <pair geom1="hand_right_{player_type}" geom2="floor"/>
        <pair geom1="upper_arm_left_{player_type}" geom2="floor"/>
        <pair geom1="lower_arm_left_{player_type}" geom2="floor"/>
        <pair geom1="hand_left_{player_type}" geom2="floor"/>
        
        <pair geom1="torso_{player_type}" geom2="ball"/>
        <pair geom1="waist_upper_{player_type}" geom2="ball"/>
        <pair geom1="head_{player_type}" geom2="ball"/>
        <pair geom1="waist_lower_{player_type}" geom2="ball"/>
        <pair geom1="butt_{player_type}" geom2="ball"/>
        <pair geom1="thigh_right_{player_type}" geom2="ball"/>
        <pair geom1="shin_right_{player_type}" geom2="ball"/>
        <pair geom1="foot1_right_{player_type}" geom2="ball"/>
        <pair geom1="foot2_right_{player_type}" geom2="ball"/>
        <pair geom1="thigh_left_{player_type}" geom2="ball"/>
        <pair geom1="shin_left_{player_type}" geom2="ball"/>
        <pair geom1="upper_arm_right_{player_type}" geom2="ball"/>
        <pair geom1="lower_arm_right_{player_type}" geom2="ball"/>
        <pair geom1="hand_right_{player_type}" geom2="ball"/>
        <pair geom1="upper_arm_left_{player_type}" geom2="ball"/>
        <pair geom1="lower_arm_left_{player_type}" geom2="ball"/>
        <pair geom1="hand_left_{player_type}" geom2="ball"/>
        """
    else:
        # ValueError is still caught by callers that catch Exception.
        raise ValueError(f"{player_type=} isn't supported.")

    # Contacts shared by both players; the player-specific pairs come last.
    contact_xml = f"""
    <exclude body1="waist_lower_{player_type}" body2="thigh_right_{player_type}"/>
    <exclude body1="waist_lower_{player_type}" body2="thigh_left_{player_type}"/>
    <pair geom1="foot1_left_{player_type}" geom2="floor"/>
    <pair geom1="foot1_right_{player_type}" geom2="floor"/>
    <pair geom1="foot2_left_{player_type}" geom2="floor"/>
    <pair geom1="foot2_right_{player_type}" geom2="floor"/>
    <pair geom1="foot1_left_{player_type}" geom2="ball"/>
    <pair geom1="foot2_left_{player_type}" geom2="ball"/>
    {extra_contacts}
    """
    return contact_xml


def get_actuators(player_type: str) -> str:
    """Returns the actuators XML.

    One `<motor>` element is emitted per actuated joint; the motor and the
    joint it drives share the same name, suffixed with `player_type`.

    Parameters:
        player_type (str): suffix appended to every motor and joint name,
            e.g. 'striker' or 'keeper'.

    Returns:
        str: The actuators XML to generate a MJX model.
    """
    # (joint base name, gear), listed in the order the motors are emitted.
    motor_specs = (
        ("abdomen_y", 40),
        ("abdomen_z", 40),
        ("abdomen_x", 40),
        ("hip_x_right", 40),
        ("hip_z_right", 40),
        ("hip_y_right", 120),
        ("knee_right", 80),
        ("ankle_x_right", 20),
        ("ankle_y_right", 20),
        ("hip_x_left", 40),
        ("hip_z_left", 40),
        ("hip_y_left", 120),
        ("knee_left", 80),
        ("ankle_x_left", 20),
        ("ankle_y_left", 20),
        ("shoulder1_right", 20),
        ("shoulder2_right", 20),
        ("elbow_right", 40),
        ("shoulder1_left", 20),
        ("shoulder2_left", 20),
        ("elbow_left", 40),
    )

    rows = []
    for base_name, gear in motor_specs:
        full_name = f"{base_name}_{player_type}"
        # Padding that keeps the gear/joint attributes column-aligned.
        name_padding = " " * (16 - len(base_name))
        gear_attribute = f'gear="{gear}"'.ljust(11)
        rows.append(
            f'    <motor name="{full_name}"{name_padding}'
            f'{gear_attribute}joint="{full_name}"/>\n'
        )

    # Leading newline and trailing indentation frame the motor rows.
    return "\n" + "".join(rows) + "    "


# XML snippets for the ball and the goal. They are interpolated into the
# `xml` model string below (asset, default and worldbody sections).

# Texture and material used to draw the ball.
ball_material = """
    <texture name="texgeom" type="cube" builtin="flat" mark="cross" width="128" height="128" 
        rgb1="0.6 0.6 0.6" rgb2="0.6 0.6 0.6" markrgb="1 1 1"/>
    <material name="ball" texture="texgeom" texuniform="true" rgba=".1 .9 .1 1" />
    """
# Default class "ball": a sphere geom whose size is `ball_size`.
ball_default = f"""
    <default class="ball">
      <geom type="sphere" material="ball" size="{ball_size}" mass="0.045" friction="0.7 0.075 0.075" solref="0.02 1.0"/>
    </default>
    """
# Free-floating ball body placed at x = `ball_x`, y = 0, z = `ball_size`.
ball_body = f"""
    <body name="ball" pos="{ball_x} 0 {ball_size}" quat="0.632456 -0.632456 0.316228 0.316228">
      <freejoint/>
      <geom class="ball" name="ball"/>
    </body>
    """
# Goal frame at x = `goal_distance`: two posts at y = `post1_y` / `post2_y`
# and a crossbar at z = `goal_height`, all box geoms without joints.
goal_body = f"""
    <body name="post1" pos="{goal_distance} {post1_y} 0">
        <geom name="post_geom1" type="box" size="0.05 0.05 {goal_height}" rgba="1 1 1 1"/>
    </body>
    <body name="post2" pos="{goal_distance} {post2_y} 0">
        <geom name="post_geom2" type="box" size="0.05 0.05 {goal_height}" rgba="1 1 1 1"/>
    </body>

    <body name="bar" pos="{goal_distance} 0 {goal_height}">
        <geom name="bar_geom" type="box" size="0.05 {goal_width / 2} 0.05" rgba="1 1 1 1"/>
    </body>
"""
    
# Complete MuJoCo model used by the `Humanoid` environment. It is assembled
# from the snippets and helper functions defined above:
#   - assets/defaults: `ball_material`, `ball_default` plus the shared
#     humanoid default classes (geoms, joints, per-player foot shapes)
#   - worldbody: floor, striker, keeper, goal and ball, in that order (the
#     ball is therefore the last body of the model)
#   - contacts: per-player pairs plus ball vs. posts/crossbar
#   - actuators: striker motors first, then keeper motors
# The solver `iterations` / `ls_iterations` set in <option> are overridden in
# `Humanoid.__init__`.
xml = f"""
<mujoco model="Humanoid and a ball">
  <option timestep="{mj_model_timestep}" iterations="1" ls_iterations="4">
    <flag eulerdamp="disable"/>
  </option>

  <visual>
    <map force="0.1" zfar="30" znear="0.1" />
    <rgba haze="0.15 0.25 0.35 1"/>
    <global offwidth="2560" offheight="1440" elevation="-20" azimuth="120"/>
    <quality shadowsize="4096" offsamples="8"/>
  </visual>

  <statistic center="0 0 0.7" extent="4"/>

  <asset>
    <texture type="skybox" builtin="gradient" rgb1=".3 .5 .7" rgb2="0 0 0" width="32" height="512"/>
    <texture name="body" type="cube" builtin="flat" mark="cross" width="128" height="128" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" markrgb="1 1 1" random="0.01"/>
    <material name="body" texture="body" texuniform="true" rgba="0.8 0.6 .4 1"/>

    <texture type="skybox" builtin="gradient" width="512" height="512" rgb1=".4 .6 .8" rgb2="0 0 0"/>
    <texture name="texplane" type="2d" builtin="checker" rgb1=".4 .4 .4" rgb2=".6 .6 .6"
             width="512" height="512"/>
    <material name='MatPlane' reflectance='0.3' texture="texplane" texrepeat="1 1" texuniform="true"
              rgba=".7 .7 .7 1"/>
    {ball_material}
  </asset>

  <default>
    {ball_default}
    <motor ctrlrange="-1 1" ctrllimited="true"/>
    <default class="body">

      <!-- geoms -->
      <geom type="capsule" condim="3" friction=".7" solimp=".9 .99 .003" solref=".015 1" material="body" contype="0" conaffinity="0"/>
      <default class="thigh">
        <geom size=".06"/>
      </default>
      <default class="shin">
        <geom fromto="0 0 0 0 0 -.3"  size=".049"/>
      </default>
      <default class="foot">
        <geom size=".027"/>
        <default class="foot1_striker">
          <geom fromto="-.07 -.01 0 .14 -.03 0"/>
        </default>
        <default class="foot2_striker">
          <geom fromto="-.07 .01 0 .14  .03 0"/>
        </default>
        <default class="foot1_keeper">
          <geom fromto="-.14 -.03 0 .07 -.01 0"/>
        </default>
        <default class="foot2_keeper">
          <geom fromto="-.14 .03 0 .07  .01 0"/>
        </default>
      </default>
      <default class="arm_upper">
        <geom size=".04"/>
      </default>
      <default class="arm_lower">
        <geom size=".031"/>
      </default>
      <default class="hand">
        <geom type="sphere" size=".04"/>
      </default>

      <!-- joints -->
      <joint type="hinge" damping=".2" stiffness="1" armature=".01" limited="true" solimplimit="0 .99 .01"/>
      <default class="joint_big">
        <joint damping="5" stiffness="10"/>
        <default class="hip_x">
          <joint range="-30 10"/>
        </default>
        <default class="hip_z">
          <joint range="-60 35"/>
        </default>
        <default class="hip_y">
          <joint range="-150 20"/>
        </default>
        <default class="joint_big_stiff">
          <joint stiffness="20"/>
        </default>
      </default>
      <default class="knee">
        <joint pos="0 0 .02" range="-160 2"/>
      </default>
      <default class="ankle">
        <joint range="-50 50"/>
        <default class="ankle_y">
          <joint pos="0 0 .08" stiffness="6"/>
        </default>
        <default class="ankle_x">
          <joint pos="0 0 .04" stiffness="3"/>
        </default>
      </default>
      <default class="shoulder">
        <joint range="-85 60"/>
      </default>
      <default class="elbow">
        <joint range="-100 50" stiffness="0"/>
      </default>
    </default>
  </default>

  <worldbody>
    <light directional="true" diffuse=".8 .8 .8" pos="0 0 10" dir="0 0 -10"/>
    <geom name="floor" type="plane" size="{goal_distance * 2} 5 .05" material="MatPlane" condim="3"/>

    {get_humanoid("striker")}
    {get_humanoid("keeper")}

    {goal_body}
    {ball_body}
  </worldbody>

  <contact>
    {get_contacts("striker")}
    {get_contacts("keeper")}
    <pair geom1="post_geom1" geom2="ball"/>
    <pair geom1="post_geom2" geom2="ball"/>
    <pair geom1="bar_geom" geom2="ball"/>
  </contact>

  <actuator>
    {get_actuators("striker")}
    {get_actuators("keeper")}
  </actuator>
</mujoco>
"""

In [None]:
# Humanoid Env

class Humanoid(PipelineEnv):
  """Penalty environment with two humanoids, a ball and a goal, run on MJX.

  The MuJoCo model is compiled from the module-level `xml` string. The reward
  function (striker or keeper) is chosen in `step` from the module-level
  `training_agent` setting.

  Parameters:
      terminate_when_unhealthy: when True (default) `done` becomes 1 as soon
          as the health checks of the active reward function fail; when
          False `done` always stays 0.
      reset_noise_scale: half-width of the uniform noise that `reset` adds to
          the initial positions and uses for the initial velocities.
      exclude_current_positions_from_observation: when True the first two
          entries of `qpos` are left out of the observation.
      **kwargs: forwarded to `PipelineEnv.__init__`. `n_frames` defaults to 5
          and `backend` is always forced to 'mjx'.
  """

  def __init__(
      self,
      terminate_when_unhealthy=True,
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      **kwargs,
  ):
    mj_model = mujoco.MjModel.from_xml_string(xml)
    # These override the solver settings declared in the XML <option>.
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 6
    mj_model.opt.ls_iterations = 6

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = 5
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    kwargs['backend'] = 'mjx'

    super().__init__(sys, **kwargs)

    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )

    # Resolve the bodies used by the reward functions by name instead of
    # hard-coding their indices, so that reordering the XML cannot silently
    # point the rewards at the wrong body. With the current model these are
    # the last body (ball) and bodies 32 / 29 (keeper hands).
    self._ball_body_id = int(mj_model.body("ball").id)
    self._keeper_left_hand_body_id = int(mj_model.body("hand_left_keeper").id)
    self._keeper_right_hand_body_id = int(mj_model.body("hand_right_keeper").id)

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state.

    Uniform noise in [-reset_noise_scale, reset_noise_scale] is added to the
    model's default positions and used as the initial velocities.
    """
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)

    # No action has been applied yet, so observe with a zero action.
    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    reward, done, zero = jp.zeros(3)
    metrics = {
        "reward": zero,
        "goal": zero,
    }
    return State(data, obs, reward, done, metrics)

  def step(self, state: State, action: jp.ndarray) -> State:
    """Runs one timestep of the environment's dynamics.

    Raises:
        ValueError: if the module-level `training_agent` is neither
            "striker" nor "keeper".
    """
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)

    # `training_agent` is a plain Python string, so this branch is resolved
    # when the function is traced.
    if training_agent == "striker":
      reward, done = self._get_striker_reward(state, action, data0, data)
    elif training_agent == "keeper":
      reward, done = self._get_keeper_reward(state, action, data0, data)
    else:
      raise ValueError(f"{training_agent=} isn't supported.")

    obs = self._get_obs(data, action)
    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done
    )

  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes humanoid body and ball position, velocities, and angles.

    `action` is accepted for interface symmetry but is not used here.
    """
    position = data.qpos
    if self._exclude_current_positions_from_observation:
      # Drops the first two qpos entries; given the joint order in the XML
      # these are the x and y of the striker's root joint.
      position = position[2:]

    # external_contact_forces are excluded
    return jp.concatenate([
        # qpos: position / nq: number of generalized coordinates = dim(qpos)
        position,
        # qvel: velocity / nv: number of degrees of freedom = dim(qvel)
        data.qvel,
        # cinert: com-based body inertia and mass / (nbody, 10)
        data.cinert[1:].ravel(),
        # cvel: com-based velocity [3D rot; 3D tran] / (nbody, 6)
        data.cvel[1:].ravel(),
        # qfrc_actuator: actuator force / nv: number of degrees of freedom
        data.qfrc_actuator,
    ])

  def _get_done(self, is_healthy: jp.ndarray) -> jp.ndarray:
    """Turns the health flag into `done`, honouring terminate_when_unhealthy."""
    if self._terminate_when_unhealthy:
      return 1.0 - is_healthy
    return jp.zeros_like(is_healthy)

  def _get_striker_reward(
      self, state: State,  action: jp.ndarray, data0: mjx.Data, data: mjx.Data
  ) -> Tuple[jp.ndarray, jp.ndarray]:
    """Apply reward func for striker.

    Based on distance to goal, ball speed and whether it's a goal.
    Also updates `state.metrics` in place with the reward and goal flag.

    Returns:
        (reward, done) for this step.
    """
    ctrl_cost_weight = 0.1
    healthy_reward = 1.0
    healthy_z_range = (1.0, 3.0)
    ball_reward_weight = 5.0
    ball_healthy_z_range = (0.0, goal_height)
    ball_reward_target_x = goal_distance
    goal_reward_value = 1000

    com_before_ball = data0.subtree_com[self._ball_body_id]
    com_after_ball = data.subtree_com[self._ball_body_id]

    # NOTE(review): q[torso_index] looks like the z coordinate of the
    # striker's root joint (entry 2 of the generalized positions) - confirm.
    min_z, max_z = healthy_z_range
    is_healthy = jp.where(data.q[torso_index] < min_z, 0.0, 1.0)
    is_healthy = jp.where(data.q[torso_index] > max_z, 0.0, is_healthy)

    # The ball must stay between the floor and the crossbar height ...
    ball_min_z, ball_max_z = ball_healthy_z_range
    is_healthy = jp.where(com_after_ball[2] < ball_min_z, 0.0, is_healthy)
    is_healthy = jp.where(com_after_ball[2] > ball_max_z, 0.0, is_healthy)

    # ... and must not travel further than 1.5x the goal distance.
    is_healthy = jp.where(com_after_ball[0] > goal_distance * 1.5, 0.0, is_healthy)

    # `action` is the full action vector passed to `step`.
    ctrl_cost = ctrl_cost_weight * jp.sum(jp.square(action))

    velocity = (com_after_ball - com_before_ball) / self.dt
    # Absolute distance along x between the ball and the goal line.
    distance_goal = jp.sqrt(jp.square(ball_reward_target_x - com_after_ball[0]))

    # Both ball terms are switched off once the ball is past the goal line.
    past_goal_line = com_after_ball[0] > goal_distance
    ball_velocity_reward = jp.where(
        past_goal_line, 0.0, ball_reward_weight * velocity[0])
    ball_distance_reward = jp.where(
        past_goal_line,
        0.0,
        ball_reward_weight * (1 - (distance_goal / ball_reward_target_x)))

    is_goal = self._is_goal(com_before_ball, com_after_ball)
    goal_reward = goal_reward_value * is_goal

    reward = ball_distance_reward + ball_velocity_reward + healthy_reward - ctrl_cost + goal_reward

    state.metrics.update(
        reward=reward,
        goal=is_goal,
    )

    done = self._get_done(is_healthy)
    return reward, done

  def _get_keeper_reward(
      self, state: State,  action: jp.ndarray, data0: mjx.Data, data: mjx.Data
  ) -> Tuple[jp.ndarray, jp.ndarray]:
    """Apply reward func for keeper.

    Based on distance to y_z ball coordinates, and whether it's a goal.
    Also updates `state.metrics` in place with the reward and goal flag.

    Returns:
        (reward, done) for this step.
    """
    ctrl_cost_weight = 0.1
    healthy_z_range = (1.0, 3.0)
    distance_ball_weight = 5.0
    goal_reward_value = -1000

    com_before_ball = data0.subtree_com[self._ball_body_id]
    com_after_ball = data.subtree_com[self._ball_body_id]
    com_after_left_hand = data.subtree_com[self._keeper_left_hand_body_id]
    com_after_right_hand = data.subtree_com[self._keeper_right_hand_body_id]
    # Hand-to-ball distances measured in the y-z plane only.
    distance_left_hand = jp.sqrt(jp.square(com_after_ball[1] - com_after_left_hand[1]) + jp.square(com_after_ball[2] - com_after_left_hand[2]))
    distance_right_hand = jp.sqrt(jp.square(com_after_ball[1] - com_after_right_hand[1]) + jp.square(com_after_ball[2] - com_after_right_hand[2]))
    distance_hands = jp.mean(jp.array([distance_left_hand, distance_right_hand]))

    # NOTE(review): this reads the same q[torso_index] entry as the striker
    # reward, i.e. it does not look at the keeper's own height - confirm that
    # is intended.
    min_z, max_z = healthy_z_range
    is_healthy = jp.where(data.q[torso_index] < min_z, 0.0, 1.0)
    is_healthy = jp.where(data.q[torso_index] > max_z, 0.0, is_healthy)
    is_healthy = jp.where(com_after_ball[0] > goal_distance * 1.5, 0.0, is_healthy)

    # `action` is the full action vector passed to `step`.
    ctrl_cost = ctrl_cost_weight * jp.sum(jp.square(action))

    distance_ball_reward = distance_ball_weight * (1 - (distance_hands / goal_width))

    is_goal = self._is_goal(com_before_ball, com_after_ball)
    goal_reward = goal_reward_value * is_goal

    reward = distance_ball_reward - ctrl_cost + goal_reward

    state.metrics.update(
        reward=reward,
        goal=is_goal,
    )

    done = self._get_done(is_healthy)
    return reward, done

  def _is_goal(
      self, com_before_ball: jp.ndarray, com_after_ball: jp.ndarray
  ) -> jp.ndarray:
    """Check if it's a goal.

    Returns 1.0 when the ball crossed the goal line during this step (it was
    in front of it before and is on or behind it after) while being between
    the posts and below the crossbar height, otherwise 0.0.
    """
    is_goal = jp.where(com_before_ball[0] < goal_distance, 1.0, 0.0)
    is_goal = jp.where(com_after_ball[0] >= goal_distance, is_goal, 0.0)
    is_goal = jp.where(com_after_ball[1] > post1_y, is_goal, 0.0)
    is_goal = jp.where(com_after_ball[1] < post2_y, is_goal, 0.0)
    is_goal = jp.where(com_after_ball[2] < goal_height, is_goal, 0.0)
    return is_goal


envs.register_environment("humanoid", Humanoid)

## 5 - Visualize a rollout

Let's instantiate the environment and visualize a short rollout.

In [None]:
# instantiate the environment
# Look the environment up under the name it was registered with.
env_name = "humanoid"
env = envs.get_environment(env_name)

# JIT-compiled versions of the reset and step functions.
jit_reset, jit_step = jax.jit(env.reset), jax.jit(env.step)

In [None]:
# initialize the state
# Start from a freshly reset state and report the problem dimensions.
state = jit_reset(jax.random.PRNGKey(0))
print(f"Observations size: {len(state.obs)}")
print(f"Actions size: {env.sys.nu}")

# ctrl: control / nu: number of actuators/controls = dim(ctrl)
# The same constant control is applied at every step, so it is built once.
ctrl = -0.1 * jp.ones(env.sys.nu)

# Collect the pipeline state of the initial state and of 50 steps.
rollout = [state.pipeline_state]
for _ in range(50):
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)

media.show_video(
    env.render(rollout, camera="back_striker", height=480, width=640),
    fps=1.0 / env.dt,
)

## 6 - Define the training functions

Let's define the training functions using [PPO](https://openai.com/research/openai-baselines-ppo) to make the Humanoids take and stop penalties.

In [None]:
# Define the Acting/Evaluator (adapted from https://github.com/google/brax)

"""Brax training acting functions."""

import time
from typing import Callable, Sequence, Tuple, Union

from brax import envs
from brax.training.types import Metrics
from brax.training.types import Policy
from brax.training.types import PolicyParams
from brax.training.types import PRNGKey
from brax.training.types import Transition
from brax.v1 import envs as envs_v1
import jax
import jax.numpy as jp
import numpy as np

# Type aliases used by the acting functions below: they accept states and
# environments from both the current brax API and the legacy `brax.v1` API.
ActingState = Union[envs.State, envs_v1.State]
ActingEnv = Union[envs.Env, envs_v1.Env, envs_v1.Wrapper]


def actor_step(
    env: ActingEnv,
    env_state: ActingState,
    training_policy: Policy,
    non_training_policy: Policy,
    key: PRNGKey,
    extra_fields: Sequence[str] = ()
) -> Tuple[ActingState, Transition]:
  """Step the environment once with the actions of both agents.

  Both policies receive the same observation. Their actions are concatenated
  in the order (striker, keeper); which policy is the striker is decided by
  the module-level `training_agent`.

  Args:
    env: environment to step.
    env_state: current (batched) environment state.
    training_policy: policy of the agent being trained.
    non_training_policy: policy of the opponent.
    key: RNG key; it is split so each policy gets its own independent key.
    extra_fields: names of entries of `nstate.info` to copy into the
      transition's `state_extras`.

  Returns:
    Tuple of (next environment state, Transition). The transition's `action`
    holds the concatenated actions of both agents, while `policy_extras` come
    from the training policy only.

  Raises:
    ValueError: if the module-level `training_agent` is neither "striker" nor
      "keeper".
  """
  # Never feed the same key to two samplers: with a shared key the two agents
  # would draw their exploration noise from the same random stream.
  key_training, key_non_training = jax.random.split(key)

  # We care about ['policy_extras']['raw_action'] and
  # ['policy_extras']['log_prob'] for computing the PPO loss, so the relevant
  # policy_extras come from the training_policy.
  training_agent_actions, policy_extras = training_policy(
      env_state.obs, key_training)
  non_training_agent_actions, _ = non_training_policy(
      env_state.obs, key_non_training)

  # This is prone to error as the way to concatenate the actions
  # depends on the order of the actuators in the xml.
  if training_agent == "striker":
    actions = jp.concatenate(
        (training_agent_actions, non_training_agent_actions), axis=1)
  elif training_agent == "keeper":
    actions = jp.concatenate(
        (non_training_agent_actions, training_agent_actions), axis=1)
  else:
    raise ValueError(f"{training_agent=} isn't supported.")

  # Shape check: the joint action must cover every actuator of the env.
  assert actions.shape[1] == env.action_size
  nstate = env.step(env_state, actions)
  state_extras = {x: nstate.info[x] for x in extra_fields}
  return nstate, Transition(  # pytype: disable=wrong-arg-types  # jax-ndarray
      observation=env_state.obs,
      action=actions,
      reward=nstate.reward,
      discount=1 - nstate.done,
      next_observation=nstate.obs,
      extras={
          'policy_extras': policy_extras,
          'state_extras': state_extras
      })


def generate_unroll(
    env: ActingEnv,
    env_state: ActingState,
    training_policy: Policy,
    non_training_policy: Policy,
    key: PRNGKey,
    unroll_length: int,
    extra_fields: Sequence[str] = ()
) -> Tuple[ActingState, Transition]:
  """Roll the environment forward for `unroll_length` steps.

  Args:
    env: environment to step.
    env_state: state to start the unroll from.
    training_policy: policy of the agent being trained.
    non_training_policy: policy of the opponent.
    key: RNG key; a fresh sub-key is derived for every step.
    unroll_length: number of `actor_step` calls to scan over.
    extra_fields: forwarded to `actor_step`.

  Returns:
    Tuple of (state after the last step, transitions stacked over the steps).
  """

  @jax.jit
  def scan_body(carry, unused_t):
    step_state, rng = carry
    # One half of the split drives this step, the other is carried forward.
    step_key, carried_key = jax.random.split(rng)
    following_state, transition = actor_step(
        env,
        step_state,
        training_policy,
        non_training_policy,
        step_key,
        extra_fields=extra_fields)
    return (following_state, carried_key), transition

  initial_carry = (env_state, key)
  final_carry, data = jax.lax.scan(
      scan_body, initial_carry, (), length=unroll_length)
  final_state = final_carry[0]
  return final_state, data


# TODO: Consider moving this to its own file.
class Evaluator:
  """Runs evaluation episodes for a (training, non-training) policy pair."""

  def __init__(self, eval_env: envs.Env,
               eval_policy_fn: Callable[[PolicyParams],
                                        Policy], num_eval_envs: int,
               episode_length: int, action_repeat: int, key: PRNGKey):
    """Init.

    Args:
      eval_env: Batched environment to run evals on.
      eval_policy_fn: Function returning the policy from the policy parameters.
      num_eval_envs: Each env will run 1 episode in parallel for each eval.
      episode_length: Maximum length of an episode.
      action_repeat: Number of physics steps per env step.
      key: RNG key.
    """
    self._key = key
    self._eval_walltime = 0.
    # Used to report evaluation steps-per-second.
    self._steps_per_unroll = episode_length * num_eval_envs

    wrapped_env = envs.training.EvalWrapper(eval_env)
    steps_per_episode = episode_length // action_repeat

    def generate_eval_unroll(training_policy_params: PolicyParams,
                             non_training_policy_params: PolicyParams,
                             key: PRNGKey) -> ActingState:
      # Reset every eval env with its own key, then play one full episode and
      # keep only the final state (it carries the accumulated eval metrics).
      first_state = wrapped_env.reset(jax.random.split(key, num_eval_envs))
      final_state, _ = generate_unroll(
          wrapped_env,
          first_state,
          eval_policy_fn(training_policy_params),
          eval_policy_fn(non_training_policy_params),
          key,
          unroll_length=steps_per_episode)
      return final_state

    self._generate_eval_unroll = jax.jit(generate_eval_unroll)

  def run_evaluation(self,
                     training_policy_params: PolicyParams,
                     non_training_policy_params: PolicyParams,
                     training_metrics: Metrics,
                     aggregate_episodes: bool = True) -> Metrics:
    """Run one epoch of evaluation.

    Args:
      training_policy_params: params of the policy being trained.
      non_training_policy_params: params of the opponent policy.
      training_metrics: metrics from training, merged into the result.
      aggregate_episodes: if True, episode metrics are reduced with mean, std
        and max; otherwise the raw per-episode values are stored under all
        three keys.

    Returns:
      Dict with 'eval/walltime', the training metrics and the eval metrics.
    """
    self._key, unroll_key = jax.random.split(self._key)

    started_at = time.time()
    eval_state = self._generate_eval_unroll(training_policy_params,
                                            non_training_policy_params,
                                            unroll_key)
    eval_metrics = eval_state.info['eval_metrics']
    # Wait for the device computation so the timing below is meaningful.
    eval_metrics.active_episodes.block_until_ready()
    epoch_eval_time = time.time() - started_at

    reducers = (('', np.mean), ('_std', np.std), ('_max', np.max))
    eval_results = {}
    for suffix, reduce_fn in reducers:
      for name, value in eval_metrics.episode_metrics.items():
        eval_results[f'eval/episode_{name}{suffix}'] = (
            reduce_fn(value) if aggregate_episodes else value
        )
    eval_results['eval/avg_episode_length'] = np.mean(
        eval_metrics.episode_steps)
    eval_results['eval/epoch_eval_time'] = epoch_eval_time
    eval_results['eval/sps'] = self._steps_per_unroll / epoch_eval_time

    self._eval_walltime = self._eval_walltime + epoch_eval_time
    combined = {
        'eval/walltime': self._eval_walltime,
        **training_metrics,
        **eval_results
    }

    return combined  # pytype: disable=bad-return-type  # jax-ndarray

In [None]:
# Define the Training Function (adapted from https://github.com/google/brax)
# This module was changed to allow training two agents in the same environment

"""Proximal policy optimization training.

See: https://arxiv.org/pdf/1707.06347.pdf
"""

import functools
import time
from typing import Callable, Optional, Tuple, Union

from absl import logging
from brax import base
from brax import envs
from brax.training import gradients
from brax.training import pmap
from brax.training import types
from brax.training.acme import running_statistics
from brax.training.acme import specs
from brax.training.agents.ppo import losses as ppo_losses
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.types import Params, PolicyParams, PreprocessorParams
from brax.training.types import PRNGKey
from brax.v1 import envs as envs_v1
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax


# (normalizer statistics, network params) pair.
InferenceParams = Tuple[running_statistics.NestedMeanStd, Params]
Metrics = types.Metrics
# Value-network parameters, left untyped.
# NOTE(review): `Any` is not among the names imported from `typing` in this
# cell; presumably it is imported earlier in the notebook — confirm.
ValueParams = Any

# Axis name given to `jax.pmap` and to the cross-device reductions in `train`.
_PMAP_AXIS_NAME = 'i'


@flax.struct.dataclass
class TrainingState:
  """Contains training state for the learner."""
  # State of the optax optimizer for `params`.
  optimizer_state: optax.OptState
  # Policy and value network parameters of the agent being trained.
  params: ppo_losses.PPONetworkParams
  # Running observation statistics, updated from each batch of rollouts.
  normalizer_params: running_statistics.RunningStatisticsState
  # Number of environment steps consumed so far.
  env_steps: jnp.ndarray


def _unpmap(v):
  """Return `v` with every leaf replaced by its first entry along axis 0."""

  def take_first(leaf):
    return leaf[0]

  return jax.tree_util.tree_map(take_first, v)


def _strip_weak_type(tree):
  """Convert every leaf of `tree` to a jax array cast to its own dtype.

  brax user code is sometimes ambiguous about weak_type; in order to avoid
  extra jit recompilations all weak types are stripped from user input.
  """

  def as_explicit_dtype(leaf):
    array = jnp.asarray(leaf)
    return array.astype(array.dtype)

  return jax.tree_util.tree_map(as_explicit_dtype, tree)


def train(
    environment: Union[envs_v1.Env, envs.Env],
    num_timesteps: int,
    episode_length: int,
    action_repeat: int = 1,
    num_envs: int = 1,
    max_devices_per_host: Optional[int] = None,
    num_eval_envs: int = 128,
    learning_rate: float = 1e-4,
    entropy_cost: float = 1e-4,
    discounting: float = 0.9,
    seed: int = 0,
    unroll_length: int = 10,
    batch_size: int = 32,
    num_minibatches: int = 16,
    num_updates_per_batch: int = 2,
    num_evals: int = 1,
    num_resets_per_eval: int = 0,
    normalize_observations: bool = False,
    reward_scaling: float = 1.0,
    clipping_epsilon: float = 0.3,
    gae_lambda: float = 0.95,
    deterministic_eval: bool = False,
    network_factory: types.NetworkFactory[
        ppo_networks.PPONetworks
    ] = ppo_networks.make_ppo_networks,
    progress_fn: Callable[[int, Metrics], None] = lambda *args: None,
    normalize_advantage: bool = True,
    eval_env: Optional[envs.Env] = None,
    policy_params_fn: Callable[..., None] = lambda *args: None,
    randomization_fn: Optional[
        Callable[[base.System, jnp.ndarray], Tuple[base.System, base.System]]
    ] = None,
    striker_params: Optional[
        Tuple[PreprocessorParams, PolicyParams, ValueParams]
    ] = None,
    keeper_params: Optional[
        Tuple[PreprocessorParams, PolicyParams, ValueParams]
    ] = None,
    training_agent: str = "striker",
    humanoid_actuators: int = 21,
):
  """PPO training.

  Only the agent named by `training_agent` is optimized; the other agent acts
  with fixed params (the ones passed in, or freshly initialized ones when none
  are given).

  Args:
    environment: the environment to train
    num_timesteps: the total number of environment steps to use during training
    episode_length: the length of an environment episode
    action_repeat: the number of timesteps to repeat an action
    num_envs: the number of parallel environments to use for rollouts
      NOTE: `num_envs` must be divisible by the total number of chips since each
        chip gets `num_envs // total_number_of_chips` environments to roll out
      NOTE: `batch_size * num_minibatches` must be divisible by `num_envs` since
        data generated by `num_envs` parallel envs gets used for gradient
        updates over `num_minibatches` of data, where each minibatch has a
        leading dimension of `batch_size`
    max_devices_per_host: maximum number of chips to use per host process
    num_eval_envs: the number of envs to use for evaluation. Each env will run 1
      episode, and all envs run in parallel during eval.
    learning_rate: learning rate for ppo loss
    entropy_cost: entropy reward for ppo loss, higher values increase entropy
      of the policy
    discounting: discounting rate
    seed: random seed
    unroll_length: the number of timesteps to unroll in each environment. The
      PPO loss is computed over `unroll_length` timesteps
    batch_size: the batch size for each minibatch SGD step
    num_minibatches: the number of times to run the SGD step, each with a
      different minibatch with leading dimension of `batch_size`
    num_updates_per_batch: the number of times to run the gradient update over
      all minibatches before doing a new environment rollout
    num_evals: the number of evals to run during the entire training run.
      Increasing the number of evals increases total training time
    num_resets_per_eval: the number of environment resets to run between each
      eval. The environment resets occur on the host
    normalize_observations: whether to normalize observations
    reward_scaling: float scaling for reward
    clipping_epsilon: clipping epsilon for PPO loss
    gae_lambda: General advantage estimation lambda
    deterministic_eval: whether to run the eval with a deterministic policy
    network_factory: function that generates networks for policy and value
      functions
    progress_fn: a user-defined callback function for reporting/plotting metrics
    normalize_advantage: whether to normalize advantage estimate
    eval_env: an optional environment for eval only, defaults to `environment`
    policy_params_fn: a user-defined callback function that can be used for
      saving policy checkpoints
    randomization_fn: a user-defined callback function that generates randomized
      environments
    striker_params: striker params; includes normalizer_params
      and policy and value network params
    keeper_params: keeper params; includes normalizer_params
      and policy and value network params
    training_agent: agent to train; can be 'striker' or 'keeper'
      NOTE: `actor_step` orders the actions using the module-level
        `training_agent`, not this argument, so the two must agree
    humanoid_actuators: num of humanoid actuators; both for 'striker' and 'keeper'

  Returns:
    Tuple of (make_policy function, network params, metrics, max score network params)
    where network params is (normalizer params, policy params, value params)
    and max score network params is a dict with keys "score" and "params"
    (left empty if no evaluation ever updated it).

  Raises:
    Exception: if `training_agent` is neither 'striker' nor 'keeper'.
  """
  assert batch_size * num_minibatches % num_envs == 0
  xt = time.time()

  process_count = jax.process_count()
  process_id = jax.process_index()
  local_device_count = jax.local_device_count()
  local_devices_to_use = local_device_count
  if max_devices_per_host:
    local_devices_to_use = min(local_devices_to_use, max_devices_per_host)
  logging.info(
      'Device count: %d, process count: %d (id %d), local device count: %d, '
      'devices to be used count: %d', jax.device_count(), process_count,
      process_id, local_device_count, local_devices_to_use)
  device_count = local_devices_to_use * process_count

  # The number of environment steps executed for every training step.
  env_step_per_training_step = (
      batch_size * unroll_length * num_minibatches * action_repeat)
  num_evals_after_init = max(num_evals - 1, 1)
  # The number of training_step calls per training_epoch call.
  # equals to ceil(num_timesteps / (num_evals * env_step_per_training_step *
  #                                 num_resets_per_eval))
  num_training_steps_per_epoch = np.ceil(
      num_timesteps
      / (
          num_evals_after_init
          * env_step_per_training_step
          * max(num_resets_per_eval, 1)
      )
  ).astype(int)

  key = jax.random.PRNGKey(seed)
  global_key, local_key = jax.random.split(key)
  del key
  local_key = jax.random.fold_in(local_key, process_id)
  local_key, key_env, eval_key = jax.random.split(local_key, 3)
  # key_networks should be global, so that networks are initialized the same
  # way for different processes.
  key_policy, key_value = jax.random.split(global_key)
  # Separate init keys so the two agents do not start from identical policies.
  key_training_policy, key_non_training_policy = jax.random.split(key_policy)
  del global_key
  del key_policy

  assert num_envs % device_count == 0

  v_randomization_fn = None
  if randomization_fn is not None:
    randomization_batch_size = num_envs // local_device_count
    # all devices gets the same randomization rng
    randomization_rng = jax.random.split(key_env, randomization_batch_size)
    v_randomization_fn = functools.partial(
        randomization_fn, rng=randomization_rng
    )

  if isinstance(environment, envs.Env):
    wrap_for_training = envs.training.wrap
  else:
    wrap_for_training = envs_v1.wrappers.wrap_for_training

  env = wrap_for_training(
      environment,
      episode_length=episode_length,
      action_repeat=action_repeat,
      randomization_fn=v_randomization_fn,
  )

  reset_fn = jax.jit(jax.vmap(env.reset))
  key_envs = jax.random.split(key_env, num_envs // process_count)
  # Give the keys a leading device axis, as expected by the pmapped epoch.
  key_envs = jnp.reshape(key_envs,
                         (local_devices_to_use, -1) + key_envs.shape[1:])
  env_state = reset_fn(key_envs)

  normalize = lambda x, y: x
  if normalize_observations:
    normalize = running_statistics.normalize
  
  # The env is driven by both humanoids, but each network only outputs the
  # actions of one of them.
  assert humanoid_actuators * 2 == env.action_size
  ppo_network = network_factory(
      env_state.obs.shape[-1],
      humanoid_actuators,
      preprocess_observations_fn=normalize)
  make_policy = ppo_networks.make_inference_fn(ppo_network)

  optimizer = optax.adam(learning_rate=learning_rate)

  loss_fn = functools.partial(
      ppo_losses.compute_ppo_loss,
      ppo_network=ppo_network,
      entropy_cost=entropy_cost,
      discounting=discounting,
      reward_scaling=reward_scaling,
      gae_lambda=gae_lambda,
      clipping_epsilon=clipping_epsilon,
      normalize_advantage=normalize_advantage)

  gradient_update_fn = gradients.gradient_update_fn(
      loss_fn, optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True)

  # Decide which of the two param sets is optimized and which stays fixed.
  if training_agent == "striker":
    training_params = striker_params
    non_training_params = keeper_params
  elif training_agent == "keeper":
    training_params = keeper_params
    non_training_params = striker_params
  else:
    raise Exception(f"{training_agent=} isn't supported.") 
  
  # Start the trained agent from scratch, or resume from the given params
  # laid out as (normalizer params, policy params, value params).
  if training_params is None:
    init_params = ppo_losses.PPONetworkParams(
      policy=ppo_network.policy_network.init(key_training_policy),
      value=ppo_network.value_network.init(key_value))
    normalizer_params = running_statistics.init_state(
      specs.Array(env_state.obs.shape[-1:], jnp.dtype('float32')))
  else:
    init_params = ppo_losses.PPONetworkParams(
      policy=training_params[1],
      value=training_params[2])
    normalizer_params = training_params[0]

  # Without saved opponent params the opponent acts with a freshly
  # initialized policy; only its normalizer and policy params are needed.
  if non_training_params is None:
    non_training_params = []
    non_training_params.append(running_statistics.init_state(
      specs.Array(env_state.obs.shape[-1:], jnp.dtype('float32'))))
    non_training_params.append(ppo_network.policy_network.init(key_non_training_policy))

    
  def minibatch_step(
      carry, data: types.Transition,
      normalizer_params: running_statistics.RunningStatisticsState):
    # One gradient update on a single minibatch.
    optimizer_state, params, key = carry
    key, key_loss = jax.random.split(key)
    (_, metrics), params, optimizer_state = gradient_update_fn(
        params,
        normalizer_params,
        data,
        key_loss,
        optimizer_state=optimizer_state)

    return (optimizer_state, params, key), metrics

  def sgd_step(carry, unused_t, data: types.Transition,
               normalizer_params: running_statistics.RunningStatisticsState):
    # One pass over the whole batch: shuffle, split into minibatches, update.
    optimizer_state, params, key = carry
    key, key_perm, key_grad = jax.random.split(key, 3)

    def convert_data(x: jnp.ndarray):
      # The same key_perm is used for every leaf so all fields of the
      # transitions are shuffled with the same permutation.
      x = jax.random.permutation(key_perm, x)
      x = jnp.reshape(x, (num_minibatches, -1) + x.shape[1:])
      return x

    shuffled_data = jax.tree_util.tree_map(convert_data, data)
    (optimizer_state, params, _), metrics = jax.lax.scan(
        functools.partial(minibatch_step, normalizer_params=normalizer_params),
        (optimizer_state, params, key_grad),
        shuffled_data,
        length=num_minibatches)
    return (optimizer_state, params, key), metrics

  def training_step(
      carry: Tuple[TrainingState, envs.State, PRNGKey],
      unused_t) -> Tuple[Tuple[TrainingState, envs.State, PRNGKey], Metrics]:
    # Collect rollouts with the current policies, then run the SGD updates.
    training_state, state, key = carry
    key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3)

    training_policy = make_policy(
        (training_state.normalizer_params, training_state.params.policy))
    # The opponent's params are closed over and never updated here.
    non_training_policy = make_policy(
        (non_training_params[0], non_training_params[1]))

    def f(carry, unused_t):
      current_state, current_key = carry
      current_key, next_key = jax.random.split(current_key)
      next_state, data = generate_unroll(
          env,
          current_state,
          training_policy,
          non_training_policy,
          current_key,
          unroll_length,
          extra_fields=('truncation',))
      return (next_state, next_key), data

    (state, _), data = jax.lax.scan(
        f, (state, key_generate_unroll), (),
        length=batch_size * num_minibatches // num_envs)
    # Have leading dimensions (batch_size * num_minibatches, unroll_length)
    data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 1, 2), data)
    data = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]),
                                  data)
    assert data.discount.shape[1:] == (unroll_length,)

    # Update normalization params and normalize observations.
    normalizer_params = running_statistics.update(
        training_state.normalizer_params,
        data.observation,
        pmap_axis_name=_PMAP_AXIS_NAME)

    (optimizer_state, params, _), metrics = jax.lax.scan(
        functools.partial(
            sgd_step, data=data, normalizer_params=normalizer_params),
        (training_state.optimizer_state, training_state.params, key_sgd), (),
        length=num_updates_per_batch)

    new_training_state = TrainingState(
        optimizer_state=optimizer_state,
        params=params,
        normalizer_params=normalizer_params,
        env_steps=training_state.env_steps + env_step_per_training_step)
    return (new_training_state, state, new_key), metrics

  def training_epoch(training_state: TrainingState, state: envs.State,
                     key: PRNGKey) -> Tuple[TrainingState, envs.State, Metrics]:
    # Run `num_training_steps_per_epoch` training steps and average the
    # loss metrics over them.
    (training_state, state, _), loss_metrics = jax.lax.scan(
        training_step, (training_state, state, key), (),
        length=num_training_steps_per_epoch)
    loss_metrics = jax.tree_util.tree_map(jnp.mean, loss_metrics)
    return training_state, state, loss_metrics

  training_epoch = jax.pmap(training_epoch, axis_name=_PMAP_AXIS_NAME)

  # Note that this is NOT a pure jittable method.
  def training_epoch_with_timing(
      training_state: TrainingState, env_state: envs.State,
      key: PRNGKey) -> Tuple[TrainingState, envs.State, Metrics]:
    nonlocal training_walltime
    t = time.time()
    training_state, env_state = _strip_weak_type((training_state, env_state))
    result = training_epoch(training_state, env_state, key)
    training_state, env_state, metrics = _strip_weak_type(result)

    metrics = jax.tree_util.tree_map(jnp.mean, metrics)
    # Block on the results so the measured time covers the device work.
    jax.tree_util.tree_map(lambda x: x.block_until_ready(), metrics)

    epoch_training_time = time.time() - t
    training_walltime += epoch_training_time
    sps = (num_training_steps_per_epoch *
           env_step_per_training_step *
           max(num_resets_per_eval, 1)) / epoch_training_time
    metrics = {
        'training/sps': sps,
        'training/walltime': training_walltime,
        **{f'training/{name}': value for name, value in metrics.items()}
    }
    return training_state, env_state, metrics  # pytype: disable=bad-return-type  # py311-upgrade
    

  training_state = TrainingState(  # pytype: disable=wrong-arg-types  # jax-ndarray
      optimizer_state=optimizer.init(init_params),  # pytype: disable=wrong-arg-types  # numpy-scalars
      params=init_params,
      normalizer_params=normalizer_params,
      env_steps=0)
  training_state = jax.device_put_replicated(
      training_state,
      jax.local_devices()[:local_devices_to_use])

  if not eval_env:
    eval_env = environment
  if randomization_fn is not None:
    v_randomization_fn = functools.partial(
        randomization_fn, rng=jax.random.split(eval_key, num_eval_envs)
    )
  eval_env = wrap_for_training(
      eval_env,
      episode_length=episode_length,
      action_repeat=action_repeat,
      randomization_fn=v_randomization_fn,
  )

  evaluator = Evaluator(
      eval_env,
      functools.partial(make_policy, deterministic=deterministic_eval),
      num_eval_envs=num_eval_envs,
      episode_length=episode_length,
      action_repeat=action_repeat,
      key=eval_key)

  # Run initial eval
  metrics = {}
  if process_id == 0 and num_evals > 1:
    metrics = evaluator.run_evaluation(
        _unpmap(
            (training_state.normalizer_params, training_state.params.policy)),
        (non_training_params[0], non_training_params[1]),
        training_metrics={})
    logging.info(metrics)
    progress_fn(0, metrics)

  training_metrics = {}
  training_walltime = 0
  current_step = 0
  # Initialize variables to allow saving params of run with max score
  max_score = -99999
  max_score_params = {}
  for it in range(num_evals_after_init):
    logging.info('starting iteration %s %s', it, time.time() - xt)

    for _ in range(max(num_resets_per_eval, 1)):
      # optimization
      epoch_key, local_key = jax.random.split(local_key)
      epoch_keys = jax.random.split(epoch_key, local_devices_to_use)
      (training_state, env_state, training_metrics) = (
          training_epoch_with_timing(training_state, env_state, epoch_keys)
      )
      current_step = int(_unpmap(training_state.env_steps))

      key_envs = jax.vmap(
          lambda x, s: jax.random.split(x[0], s),
          in_axes=(0, None))(key_envs, key_envs.shape[1])
      # TODO: move extra reset logic to the AutoResetWrapper.
      env_state = reset_fn(key_envs) if num_resets_per_eval > 0 else env_state

    if process_id == 0:
      # Run evals.
      metrics = evaluator.run_evaluation(
          _unpmap(
              (training_state.normalizer_params, training_state.params.policy)),
          (non_training_params[0], non_training_params[1]),
          training_metrics)
      logging.info(metrics)
      progress_fn(current_step, metrics)
      params = _unpmap(
          (training_state.normalizer_params, training_state.params.policy,
           training_state.params.value))

      # Save params if this is the max score
      eval_score = metrics['eval/episode_reward']
      if eval_score > max_score:
        max_score = eval_score
        max_score_params = {
            "score": max_score,
            "params": params,
        }
      policy_params_fn(current_step, make_policy, params)

  total_steps = current_step
  assert total_steps >= num_timesteps

  # If there was no mistakes the training_state should still be identical on all
  # devices.
  pmap.assert_is_replicated(training_state)
  params = _unpmap(
      (training_state.normalizer_params, training_state.params.policy,
       training_state.params.value))
  logging.info('total steps: %s', total_steps)
  pmap.synchronize_hosts()
  return (make_policy, params, metrics, max_score_params)


In [None]:
# Define the PPO networks (adapted from https://github.com/google/brax)

@flax.struct.dataclass
class PPONetworks:
  """Policy network, value network and action distribution used by PPO."""
  # Network whose output parameterizes the action distribution.
  policy_network: networks.FeedForwardNetwork
  # Network estimating the value of an observation.
  value_network: networks.FeedForwardNetwork
  # Distribution that turns the policy network output into actions.
  parametric_action_distribution: distribution.ParametricDistribution

def make_ppo_networks(
    observation_size: int,
    action_size: int,
    preprocess_observations_fn: types.PreprocessObservationFn = types
    .identity_observation_preprocessor,
    policy_hidden_layer_sizes: Sequence[int] = policy_hidden_layer_sizes,
    value_hidden_layer_sizes: Sequence[int] = value_hidden_layer_sizes,
    activation: networks.ActivationFn = linen.swish) -> PPONetworks:
  """Build the policy and value networks used by PPO.

  Args:
    observation_size: size of the observation fed to both networks.
    action_size: number of action dimensions produced by the policy.
    preprocess_observations_fn: function applied to observations before they
      reach the networks.
    policy_hidden_layer_sizes: hidden layer sizes of the policy network.
    value_hidden_layer_sizes: hidden layer sizes of the value network.
    activation: activation function of both networks.

  Returns:
    A `PPONetworks` holding both networks and the action distribution.
  """
  action_distribution = distribution.NormalTanhDistribution(
      event_size=action_size)

  # Both networks share the observation preprocessing and the activation.
  common_kwargs = dict(
      preprocess_observations_fn=preprocess_observations_fn,
      activation=activation)

  # The policy outputs the parameters of the action distribution.
  policy = networks.make_policy_network(
      action_distribution.param_size,
      observation_size,
      hidden_layer_sizes=policy_hidden_layer_sizes,
      **common_kwargs)
  value = networks.make_value_network(
      observation_size,
      hidden_layer_sizes=value_hidden_layer_sizes,
      **common_kwargs)

  return PPONetworks(
      policy_network=policy,
      value_network=value,
      parametric_action_distribution=action_distribution)

## 7 - Training the Humanoids

My suggestion is to train the `striker` and `keeper` alternately for a few rounds: 
1. Start with the `striker` and train it until it scores a high percentage of penalties (set `training_agent` to `striker` in section `2 - Config`)
2. Save and load the neural network params (change `upload_striker_model` to `True` and update `striker_model_path`)
3. Switch and train the `keeper` until the percentage of scored penalties goes down significantly (change `training_agent` to `keeper` in section `2 - Config`)
4. Save and load the neural network params (change `upload_keeper_model` to `True` and update `keeper_model_path`)
5. Go back to step 1

Training the striker or the keeper for 20M timesteps with 9 evals takes about 35min with two T4 GPUs. Depending on the starting point, that should be enough for the agent being trained to succeed in ~65% of the penalties (i.e. the striker scores, or the keeper stops, ~65% of them).

In [None]:
# Load params to restart training from a saved checkpoint
# (i.e. from the saved policy and value neural networks' weights)

# After you train the striker and keeper you'll be able to save and load
# their neural nets, we recommend to save these as a Dataset
# (click on the button Upload in the right pane under Input),
# and update the paths below accordingly.
# To resume from a checkpoint saved earlier in this same session, point the
# relevant path at `training_save_path` instead.
striker_model_path = "/kaggle/input/<dataset_name>/<neural_nets_filename>"
keeper_model_path = "/kaggle/input/<dataset_name>/<neural_nets_filename>"

training_save_path = "/kaggle/working/mjx_brax_nn_goals"
upload_striker_model = False
upload_keeper_model = False

# Fail early on a typo in the config, before any file is read.
if training_agent not in ("striker", "keeper"):
  raise Exception(f"{training_agent=} isn't supported.")

# Each agent's params come from its own path regardless of which agent is
# being trained; `None` makes `train` initialise that agent from scratch.
striker_params = (
    model.load_params(striker_model_path) if upload_striker_model else None)
keeper_params = (
    model.load_params(keeper_model_path) if upload_keeper_model else None)

In [None]:
# Train
# Bind the configured hyperparameters so that only the environment and the
# progress callback remain to be passed when training starts.
train_fn = functools.partial(
    train,
    num_timesteps=num_timesteps,
    num_evals=num_evals,
    episode_length=episode_length,
    normalize_observations=normalize_observations,
    action_repeat=action_repeat,
    unroll_length=unroll_length,
    num_minibatches=num_minibatches,
    num_updates_per_batch=num_updates_per_batch,
    discounting=discounting,
    learning_rate=learning_rate,
    entropy_cost=entropy_cost,
    num_envs=num_envs,
    reward_scaling=reward_scaling,
    clipping_epsilon=clipping_epsilon,
    gae_lambda=gae_lambda,
    normalize_advantage=normalize_advantage,
    batch_size=batch_size,
    seed=0,
    network_factory=make_ppo_networks,
    striker_params=striker_params,
    keeper_params=keeper_params,
    training_agent=training_agent,
)

# Series filled in by `progress` after every evaluation.
x_data, y_data, ydataerr = [], [], []
# Timestamps: the first entry marks the start, `progress` appends the rest.
times = [datetime.now()]

# Fixed y-axis range of the reward plot.
max_y, min_y = 5000, 0
def progress(num_steps, metrics):
  """Record one data point, redraw the reward curve and print extra metrics.

  Passed to `train_fn` as `progress_fn`.

  Args:
    num_steps: number of environment steps so far; used as the x coordinate.
    metrics: mapping from metric name to value. Must contain
      'eval/episode_reward' and 'eval/episode_reward_std'.
  """
  times.append(datetime.now())

  latest_reward = metrics['eval/episode_reward']
  x_data.append(num_steps)
  y_data.append(latest_reward)
  ydataerr.append(metrics['eval/episode_reward_std'])

  plt.title(f'y={latest_reward:.1f}')
  plt.xlabel('# environment steps')
  plt.ylabel('reward per episode')

  # Leave 10% of headroom to the right of the last training step
  plt.xlim([0, train_fn.keywords['num_timesteps'] * 1.1])
  plt.ylim([min_y, max_y])

  plt.errorbar(x_data, y_data, yerr=ydataerr)
  plt.show()

  # The extra metrics are only printed when training metrics are present
  if 'training/policy_loss' not in metrics:
    return

  print("Other metrics")
  print(f"value loss: {metrics['training/v_loss']:.2f}")
  print(f"max episode reward: {metrics['eval/episode_reward_max']:.0f}")
  print(f"goals: {int(metrics['eval/episode_goal'] * 100)} / 100")

# Launch training, with `progress` as the callback that appends to `times`
make_inference_fn, train_params, _, max_score_params = train_fn(
    progress_fn=progress, environment=env)

# times[0] was taken before the call, times[1] at the first `progress` call,
# and times[-1] at the last one
t_start, t_first_eval, t_end = times[0], times[1], times[-1]
print(f'time to jit: {t_first_eval - t_start}')
print(f'time to train: {t_end - t_first_eval}')
print(f'total time: {t_end - t_start}\n')
print(f"max score: {int(max_score_params['score'])}")

## 8 - Save and load the policy

We can save and load the policy using the Brax model API.

In [None]:
# Save the model
# Writes the params returned by training to `training_save_path`, which is the
# same path the cells below load them back from.
model.save_params(training_save_path, train_params)
# Alternative: save the params stored alongside the max score instead
# (presumably the best-scoring ones seen during training - confirm in `train`)
# model.save_params(training_save_path, max_score_params["params"])

In [None]:
# Load the model and define the inference functions
# Only the first two entries of the saved params are handed to `make_inference_fn`
trained_params = model.load_params(training_save_path)
training_inference_fn = make_inference_fn(trained_params[:2])
jit_training_inference_fn = jax.jit(training_inference_fn)

if training_agent not in ("striker", "keeper"):
    raise Exception(f"{training_agent=} isn't supported.")
if training_agent == "keeper" and not upload_striker_model:
    raise Exception("You need to upload a striker model to train the keeper")

# Pick the params that drive the agent that was NOT trained in this run
if training_agent == "keeper":
    opponent_params = striker_params
elif upload_keeper_model:
    opponent_params = keeper_params
else:
    # Ugly hack for when we don't have non_training_params: reuse the
    # checkpoint of the trained agent
    # (maybe we need some random initialisation for these)
    opponent_params = model.load_params(training_save_path)

non_training_inference_fn = make_inference_fn(opponent_params[:2])
jit_non_training_inference_fn = jax.jit(non_training_inference_fn)

## 9 - Visualize the policies

Finally we can visualize the Humanoids in action and watch while they take and stop penalties!

This can also be saved to an mp4 file, which you can then download from the `Output` section (found in the right-hand pane when running on a laptop).

In [None]:
# Look the environment up by name for the evaluation rollouts below
eval_env = envs.get_environment(env_name)

# Compile reset and step once so they can be reused for every rollout
jit_reset, jit_step = jax.jit(eval_env.reset), jax.jit(eval_env.step)

In [None]:
# Visualize the Humanoids and optionally save it to a mp4 file
init_time = datetime.now()
goals = 0
rollout_rewards = []

# Rollout settings are identical for every penalty, so define them once
n_steps = 100
render_every = 2
# Frames are sub-sampled by `render_every`, so the fps is divided by the same factor
fps = 1.0 / env.dt / render_every

for i in range(num_penalties):
  # Initialize the state (one seed per penalty)
  rng = jax.random.PRNGKey(i)
  state = jit_reset(rng)
  rollout = [state.pipeline_state]
  total_reward = 0

  # Grab a trajectory
  for _step in range(n_steps):
    # Give each policy its own key: a JAX PRNG key must not be reused,
    # otherwise both agents would sample from the same random bits
    rng, training_rng, non_training_rng = jax.random.split(rng, 3)
    training_ctrl, _ = jit_training_inference_fn(state.obs, training_rng)
    non_training_ctrl, _ = jit_non_training_inference_fn(state.obs, non_training_rng)
    # This is prone to error as the way to concatenate the actions
    # depends on the order of the actuators in the xml
    if training_agent == "striker":
      ctrl = jp.concatenate((training_ctrl, non_training_ctrl), axis=0)
    elif training_agent == "keeper":
      ctrl = jp.concatenate((non_training_ctrl, training_ctrl), axis=0)
    else:
      raise Exception(f"{training_agent=} isn't supported.")

    assert len(ctrl) == env.action_size
    state = jit_step(state, ctrl)
    total_reward += state.metrics["reward"]
    goals += state.metrics["goal"]
    rollout.append(state.pipeline_state)

    if state.done:
      break

  rollout_rewards.append(total_reward)

  print(f"Iteration with reward {int(total_reward)}")
  video = env.render(rollout[::render_every], camera="back_striker", height=480, width=640)
  media.show_video(video, fps=fps)
  media.write_video(f"/kaggle/working/goals_{int(total_reward)}.mp4", video, fps=fps)

# Take the real maximum so that a run where every reward is negative is not
# reported as 0; fall back to 0 only when no penalty was played
max_rollout_reward = max(rollout_rewards, default=0)
print(f"Max rollout reward was - {int(max_rollout_reward)}")
print(f"Scored {int(goals)} / {num_penalties} penalties")
print(f'total time: {datetime.now() - init_time}')

## 10 - Tournament Mode

As a bonus feature, you can also run the environment in Tournament Mode, in which two teams, each made up of a striker and a keeper, compete against each other in a penalty shootout. To enable it, set `tournament_mode` to `True` in section `2 - Config`. Let's see who wins!

Note: you also need to have saved `striker` and `keeper` neural nets to use in the tournament (see below).

In [None]:
def get_jit_inference_func(path: str) -> Callable[[jax.Array, jax.Array], tuple]:
    """Load saved neural nets from `path` and return a jitted inference function.

    Args:
        path: location of params previously written with `model.save_params`.

    Returns:
        The `jax.jit`-compiled inference function. In this notebook it is
        called as `fn(obs, rng)` and its result is unpacked as a pair whose
        first element is the control vector.
    """
    # Only the first two entries of the saved params are handed to
    # `make_inference_fn`, as in the loading cell of section 8
    params = model.load_params(path)
    return jax.jit(make_inference_fn(params[:2]))


def display_video(camera: str, rollout: list, render_every: int = 1) -> None:
    """Render `rollout` from `camera`, show it inline and save it as an mp4.

    Args:
        camera: name of the camera to render from.
        rollout: pipeline states collected while stepping the environment.
        render_every: keep one state out of every `render_every`. The fps is
            divided by the same factor. This used to be read from a global
            left behind by the visualization cell; it is now explicit.
    """
    fps = 1.0 / env.dt / render_every
    vid = env.render(rollout[::render_every], camera=camera, height=480, width=640)
    media.show_video(vid, fps=fps)
    media.write_video(f"/kaggle/working/goals_{camera}_{datetime.now()}.mp4", vid, fps=fps)


def penalty(
    striker_jit_inference_fn: Callable[[jax.Array, jax.Array], tuple],
    keeper_jit_inference_fn: Callable[[jax.Array, jax.Array], tuple],
    key: int,
) -> int:
    """Play one penalty, display it from two cameras and return the goals scored.

    Args:
        striker_jit_inference_fn: jitted policy of the striker, called as
            `fn(obs, rng)`; the first element of its result is the control.
        keeper_jit_inference_fn: jitted policy of the keeper, same convention.
        key: integer seed for the PRNG key that resets and drives the rollout.

    Returns:
        The sum of the "goal" metric over the rollout, as a Python int.

    Raises:
        AssertionError: if the joint control does not match `env.action_size`,
            or if more than one goal was counted.
    """
    # Initialize the state
    rng = jax.random.PRNGKey(key)
    state = jit_reset(rng)
    rollout = [state.pipeline_state]

    # Grab a trajectory
    n_steps = 200
    render_every = 1
    goal = 0
    for _step in range(n_steps):
        # Each policy gets a fresh key. The carried `rng` is never handed to
        # a policy, because it is split again on the next step.
        rng, striker_rng, keeper_rng = jax.random.split(rng, 3)
        striker_ctrl, _ = striker_jit_inference_fn(state.obs, striker_rng)
        keeper_ctrl, _ = keeper_jit_inference_fn(state.obs, keeper_rng)
        # Striker first, keeper second: assumes this matches the order of the
        # actuators in the xml
        ctrl = jp.concatenate((striker_ctrl, keeper_ctrl), axis=0)

        assert len(ctrl) == env.action_size
        state = jit_step(state, ctrl)
        goal += state.metrics["goal"]
        rollout.append(state.pipeline_state)

        if state.done:
            break

    display_video("back_striker", rollout, render_every)
    display_video("side_keeper", rollout, render_every)
    assert goal <= 1
    # Convert the accumulated metric so the return value matches the `int` annotation
    return int(goal)

In [None]:
if tournament_mode:
    init_time = datetime.now()

    # These need to be updated with the neural nets that you trained and saved
    teamA_striker_path = "/kaggle/input/<dataset_name>/<neural_nets_filename>"
    teamA_keeper_path = "/kaggle/input/<dataset_name>/<neural_nets_filename>"
    teamB_striker_path = "/kaggle/input/<dataset_name>/<neural_nets_filename>"
    teamB_keeper_path = "/kaggle/input/<dataset_name>/<neural_nets_filename>"

    teamA_striker_jit_inference_fn = get_jit_inference_func(teamA_striker_path)
    teamA_keeper_jit_inference_fn = get_jit_inference_func(teamA_keeper_path)
    teamB_striker_jit_inference_fn = get_jit_inference_func(teamB_striker_path)
    teamB_keeper_jit_inference_fn = get_jit_inference_func(teamB_keeper_path)

    goals_teamA = 0
    goals_teamB = 0
    for i in range(num_penalties):
        # Number the rounds from 1 so the last one reads "N/N"
        print(f"\nPenalty {i + 1}/{num_penalties}")

        # Even seeds for Team A's shots and odd seeds for Team B's, so no two
        # penalties of the shootout share a seed (`i` and `i*10` collided,
        # e.g. both were 0 in the first round)
        goals_teamA += penalty(teamA_striker_jit_inference_fn, teamB_keeper_jit_inference_fn, 2 * i)
        print(f"Team_A {int(goals_teamA)} - {int(goals_teamB)} Team_B")

        goals_teamB += penalty(teamB_striker_jit_inference_fn, teamA_keeper_jit_inference_fn, 2 * i + 1)
        print(f"Team_A {int(goals_teamA)} - {int(goals_teamB)} Team_B")

    print(f"\nFinal score: Team_A {int(goals_teamA)} - {int(goals_teamB)} Team_B")
    print(f"total time: {datetime.now() - init_time}")