# Tutorial: Robot Tricks

In this tutorial you will learn how to use the MuJoCo physics simulator and Reinforcement Learning to teach a robot to bounce a ball with its foot - an essential skill in the art of football.

## 1 - Kaggle
It's recommended to run this notebook on www.kaggle.com where you can use two T4 GPUs for 30h/week for free.

To import this notebook into Kaggle you need to:
- Log in to your Kaggle account
- Create a new notebook
- Click on `File` and then `Import Notebook`
- Select the tab `GitHub`
- Search `goncalog/ai-robotics`
- Select the file `tutorials/robot_tricks.ipynb`
- Click the `Import` button

To run this notebook you can either click the `Run All` button or run each cell individually by clicking the `Run current cell` button.

## 2 - Config

This is the configuration used to run the tutorial. It includes:
- Training hyperparameters
- MuJoCo environment variables
- File paths
- Rendering variables

In [None]:
# --- Training hyperparameters -------------------------------------------------
# NOTE(review): these names match keyword arguments of Brax's PPO trainer and are
# presumably forwarded to it in a later cell - confirm against that cell.

# num_timesteps: presumably the total number of environment steps to train for
num_timesteps = 60_000_000
# num_evals: presumably the number of evaluation passes run during training
num_evals = 9
# num_envs: the number of parallel environments to use for rollouts
num_envs = 2048

# learning_rate: learning rate for ppo loss
learning_rate = 3e-4
# discounting: discounting rate
discounting = 0.97
# episode_length: the length of an environment episode
episode_length = 1000
# normalize_observations: whether to normalize observations
normalize_observations = True
# action_repeat: the number of timesteps to repeat an action
action_repeat = 1
# unroll_length: the number of timesteps to unroll in each environment.
# The PPO loss is computed over `unroll_length` timesteps
unroll_length = 10
# entropy_cost: entropy reward for ppo loss, higher values increase entropy of the policy
entropy_cost = 1e-3
# batch_size: the batch size for each minibatch SGD step
batch_size = 1024
# num_minibatches: the number of times to run the SGD step,
# each with a different minibatch with leading dimension of `batch_size`
num_minibatches = 32
# num_updates_per_batch: the number of times to run the gradient update over
# all minibatches before doing a new environment rollout
num_updates_per_batch = 8
# reward_scaling: float scaling for reward
reward_scaling = 1
# clipping_epsilon: clipping epsilon for PPO loss
clipping_epsilon = 0.3
# gae_lambda: General advantage estimation lambda
gae_lambda = 0.95
# normalize_advantage: whether to normalize advantage estimate
normalize_advantage = True

# Hidden layer widths: a tuple of four 32s for the policy network and a tuple of
# five 256s for the value network (presumably consumed by the PPO network
# factory in a later cell).
policy_hidden_layer_sizes = (32,) * 4
value_hidden_layer_sizes = (256,) * 5

# --- MuJoCo environment variables -------------------------------------------------
# ball_size: interpolated into the `size` attribute of the ball's sphere geom
ball_size = 0.04
# torso_index: read as `data.q[torso_index]` by the health check in OP3._get_reward.
# NOTE(review): the inline comment describes a *body* index, but the value is used
# to index `data.q` - confirm which quantity is intended.
torso_index = 3 # index of torso body in mjx data (it contains the head geom)
# ball_x / ball_y / ball_height: interpolated into the `pos` attribute of the ball body
ball_height = 0.5 # z coordinate of centre of mass
ball_x = 0.12 # x coordinate of centre of mass
ball_y = 0.05 # y coordinate of centre of mass
# foot_left_index: used to index `data.subtree_com` in OP3._get_reward
foot_left_index = 15 # index of foot_left body in mjx data
# op3_contacts: when True the OP3 collision meshes are added to the XML and the
# foot geoms use the "collision" default class; when False the collision geoms are
# omitted and the foot default class is named "no_collision"
op3_contacts = False

# Simulation time step in seconds.
# This is the single most important parameter affecting the speed-accuracy trade-off
# which is inherent in every physics simulation.
# Smaller values result in better accuracy and stability
# NOTE(review): OP3.__init__ in this section does not reference this value -
# confirm where (or whether) it is applied to the model.
mj_model_timestep = 0.005

# --- File paths -------------------------------------------------------------------
# save_path: presumably where the trained parameters are written by a later cell
save_path = "/kaggle/working/mjx_brax_nn"
# op3_assets_path: used as the `meshdir` of the MJCF compiler (location of the OP3 meshes)
op3_assets_path = "/kaggle/input/assets-op3"

# --- Rendering variables ----------------------------------------------------------
# NOTE(review): not used in this section; presumably consumed by the rendering
# cells later in the notebook.
num_rollouts = 1
num_bounces_threshold = 0

## 3 - Install MuJoCo, MJX, and Brax

In [None]:
!pip install mujoco==3.1.2
!pip install mujoco_mjx==3.1.2
!pip install brax==0.10.0

# Check that a GPU is reachable before configuring GPU rendering.
# `nvidia-smi` is run as a probe: a non-zero exit code, or the binary being
# absent altogether, is treated as "no usable GPU".
import distutils.util
import os
import subprocess

_NO_GPU_MESSAGE = (
    'Cannot communicate with GPU. '
    'Make sure you are using a GPU runtime. '
    'Go to the Runtime menu and select Choose runtime type.')

try:
  _gpu_probe = subprocess.run('nvidia-smi')
except FileNotFoundError as err:
  # Without this, a missing nvidia-smi binary surfaces as a bare
  # FileNotFoundError and the explanatory message below is never shown.
  raise RuntimeError(_NO_GPU_MESSAGE) from err
if _gpu_probe.returncode:
  raise RuntimeError(_NO_GPU_MESSAGE)

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but this
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/'
NVIDIA_ICD_CONFIG_FILE = '10_nvidia.json'
file_path = os.path.join(NVIDIA_ICD_CONFIG_PATH, NVIDIA_ICD_CONFIG_FILE)

# Create the directory if needed, then write the config only when the *file* is
# missing. Checking the directory alone would skip the write whenever the
# directory already exists without the JSON file in it.
os.makedirs(NVIDIA_ICD_CONFIG_PATH, exist_ok=True)
if not os.path.exists(file_path):
  with open(file_path, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs.
# The flag is appended only when it is not already present, so re-running this
# cell does not keep growing XLA_FLAGS with duplicate entries.
xla_flags = os.environ.get('XLA_FLAGS', '')
if '--xla_gpu_triton_gemm_any=True' not in xla_flags:
  xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

# Smoke-test the MuJoCo installation: import the package and compile an empty
# model. Any failure is re-raised with a troubleshooting hint attached.
try:
  print('Checking that the installation succeeded:')
  import mujoco
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  # NOTE: the chaining is deliberately "reversed" - the original exception `e`
  # is what propagates (its type is preserved), and the RuntimeError carrying
  # the hint is attached to it as `__cause__`.
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".')

print('Installation successful.')

# Import packages for plotting and creating graphics
import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List

# Graphics and plotting.
print('Installing mediapy:')
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

# Import MuJoCo, MJX, and Brax

from datetime import datetime
import functools
import jax
from jax import numpy as jp
import numpy as np
from typing import Any, Dict, Sequence, Tuple, Union

from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.io import html, mjcf, model
from brax.training import distribution, networks

from etils import epath
from flax import linen, struct
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx

## 4 - Setting up the OP3 environment with MJX
MJX is an implementation of MuJoCo written in [JAX](https://jax.readthedocs.io/en/latest/index.html), enabling large batch training on GPU/TPU. In this notebook, we train RL policies with MJX.

Here we implement our environment by adapting the original [OP3](https://github.com/google-deepmind/mujoco_menagerie/blob/main/robotis_op3/op3.xml) environment to also include a ball. Notice that `reset` initializes a `State`, and `step` steps through the physics step and reward logic. The reward and stepping logic train the robot to bounce a ball with its left foot.

Note: if this is the first time you're running this environment on Kaggle, you'll have to upload the OP3 [assets](https://github.com/google-deepmind/mujoco_menagerie/tree/main/robotis_op3/assets) into the `op3_assets_path` set in section `2 - Config` (you can do this by clicking the `Upload` button on the right-hand side and then `New Dataset`)

In [None]:
# OP3 XML

# MJCF fragments describing the ball. They are spliced into the main `xml`
# string below: `ball_material` into <asset>, `ball_default` into <default>,
# and `ball_body` at the end of <worldbody>.

# Texture and material used to draw the ball.
ball_material = """
    <texture name="texgeom" type="cube" builtin="flat" mark="cross" width="128" height="128" 
        rgb1="0.6 0.6 0.6" rgb2="0.6 0.6 0.6" markrgb="1 1 1"/>
    <material name="ball" texture="texgeom" texuniform="true" rgba=".1 .9 .1 1" />
    """
# Default class for the ball geom; `ball_size` (from the config) sets its size.
ball_default = f"""
    <default class="ball">
      <geom type="sphere" material="ball" size="{ball_size}" mass="0.045" friction="0.7 0.075 0.075" solref="0.02 1.0"/>
    </default>
    """
# The ball body itself: a free-floating body placed at (ball_x, ball_y, ball_height).
ball_body = f"""
    <body name="ball" pos="{ball_x} {ball_y} {ball_height}" quat="0.632456 -0.632456 0.316228 0.316228">
      <freejoint/>
      <geom class="ball" name="ball"/>
    </body>
    """
# Optional collision geometry for the OP3 body parts. Each `*_collision` name is
# spliced into the main `xml` string: with `op3_contacts` enabled it holds the
# collision <geom> tag(s) for that body part, otherwise an empty string.


def _collision_geom(mesh_name):
    """Return the collision <geom> tag for `mesh_name`, or "" when contacts are off."""
    if not op3_contacts:
        return ""
    return f'<geom mesh="{mesh_name}" class="collision"/>'


# Body parts made of several convex pieces keep their multi-line fragments.
body_collision = """
        <geom mesh="bodyc" class="collision"/>
        <geom mesh="body1c" class="collision"/>
        <geom mesh="body2c" class="collision"/>
        <geom mesh="body3c" class="collision"/>
        <geom mesh="body4c" class="collision"/>
        """ if op3_contacts else ""
h2_collision = """
        <geom mesh="h2c" class="collision"/>
        <geom mesh="h21c" class="collision"/>
        <geom mesh="h22c" class="collision"/>
        """ if op3_contacts else ""

# Head
h1c_collision = _collision_geom("h1c")

# Left / right arm
la1c_collision = _collision_geom("la1c")
la2c_collision = _collision_geom("la2c")
la3c_collision = _collision_geom("la3c")
ra1c_collision = _collision_geom("ra1c")
ra2c_collision = _collision_geom("ra2c")
ra3c_collision = _collision_geom("ra3c")

# Left / right leg
ll1c_collision = _collision_geom("ll1c")
ll2c_collision = _collision_geom("ll2c")
ll3c_collision = _collision_geom("ll3c")
ll4c_collision = _collision_geom("ll4c")
ll5c_collision = _collision_geom("ll5c")
rl1c_collision = _collision_geom("rl1c")
rl2c_collision = _collision_geom("rl2c")
rl3c_collision = _collision_geom("rl3c")
rl4c_collision = _collision_geom("rl4c")
rl5c_collision = _collision_geom("rl5c")

# Name of the default class that wraps the foot geoms in the XML.
foot_collision = 'class="collision"' if op3_contacts else 'class="no_collision"'

# Full MJCF model: the OP3 robot plus a ball. It is an f-string, so the config
# values (`op3_assets_path`), the ball fragments (`ball_default`, `ball_material`,
# `ball_body`) and the optional `*_collision` fragments are interpolated here.
# Contacts are restricted to the explicit <pair> entries in <contact>:
# feet vs. floor, and left foot / left ankle / left shin vs. ball.
xml = f"""
<mujoco model="op3 and a ball">
  <compiler angle="radian" meshdir="{op3_assets_path}" autolimits="true"/>

  <default>
    <mesh scale="0.001 0.001 0.001"/>
    <geom type="mesh" solref=".004 1"/>
    <joint damping="1.084" armature="0.045"/>
    <site group="5" type="sphere"/>
    <position kp="21.1" ctrlrange="-3.141592 3.141592" forcerange="-5 5"/>
    <default {foot_collision}>
      <geom group="3"/>
      <default class="foot">
        <geom mass="0" type="box"/>
      </default>
    </default>
    <default class="visual">
      <geom material="black" contype="0" conaffinity="0" group="2"/>
    </default>
    {ball_default}
  </default>

  <asset>
    <texture type="skybox" builtin="gradient" rgb1="0.3 0.5 0.7" rgb2="0 0 0" width="512" height="3072"/>
    <texture type="2d" name="groundplane" builtin="checker" mark="edge" rgb1="0.2 0.3 0.4" rgb2="0.1 0.2 0.3"
      markrgb="0.8 0.8 0.8" width="300" height="300"/>
    <material name="groundplane" texture="groundplane" texuniform="true" texrepeat="5 5" reflectance="0.2"/>
    {ball_material}
    <material name="black" rgba="0.2 0.2 0.2 1"/>

    <mesh file="body.stl"/>
    <mesh file="ll1.stl"/>
    <mesh file="ll2.stl"/>
    <mesh file="ll3.stl"/>
    <mesh file="ll4.stl"/>
    <mesh file="ll5.stl"/>
    <mesh file="ll6.stl"/>
    <mesh file="rl1.stl"/>
    <mesh file="rl2.stl"/>
    <mesh file="rl3.stl"/>
    <mesh file="rl4.stl"/>
    <mesh file="rl5.stl"/>
    <mesh file="rl6.stl"/>
    <mesh file="la1.stl"/>
    <mesh file="la2.stl"/>
    <mesh file="la3.stl"/>
    <mesh file="ra1.stl"/>
    <mesh file="ra2.stl"/>
    <mesh file="ra3.stl"/>
    <mesh file="h1.stl"/>
    <mesh file="h2.stl"/>
    <mesh name="bodyc" file="simplified_convex/body.stl"/>
    <mesh name="body1c" file="simplified_convex/body_sub1.stl"/>
    <mesh name="body2c" file="simplified_convex/body_sub2.stl"/>
    <mesh name="body3c" file="simplified_convex/body_sub3.stl"/>
    <mesh name="body4c" file="simplified_convex/body_sub4.stl"/>
    <mesh name="ll1c" file="simplified_convex/ll1.stl"/>
    <mesh name="ll2c" file="simplified_convex/ll2.stl"/>
    <mesh name="ll3c" file="simplified_convex/ll3.stl"/>
    <mesh name="ll4c" file="simplified_convex/ll4.stl"/>
    <mesh name="ll5c" file="simplified_convex/ll5.stl"/>
    <mesh name="ll6c" file="simplified_convex/ll6.stl"/>
    <mesh name="rl1c" file="simplified_convex/rl1.stl"/>
    <mesh name="rl2c" file="simplified_convex/rl2.stl"/>
    <mesh name="rl3c" file="simplified_convex/rl3.stl"/>
    <mesh name="rl4c" file="simplified_convex/rl4.stl"/>
    <mesh name="rl5c" file="simplified_convex/rl5.stl"/>
    <mesh name="rl6c" file="simplified_convex/rl6.stl"/>
    <mesh name="la1c" file="simplified_convex/la1.stl"/>
    <mesh name="la2c" file="simplified_convex/la2.stl"/>
    <mesh name="la3c" file="simplified_convex/la3.stl"/>
    <mesh name="ra1c" file="simplified_convex/ra1.stl"/>
    <mesh name="ra2c" file="simplified_convex/ra2.stl"/>
    <mesh name="ra3c" file="simplified_convex/ra3.stl"/>
    <mesh name="h1c" file="simplified_convex/h1.stl"/>
    <mesh name="h2c" file="simplified_convex/h2.stl"/>
    <mesh name="h21c" file="simplified_convex/h2_sub1.stl"/>
    <mesh name="h22c" file="simplified_convex/h2_sub2.stl"/>
  </asset>

  <worldbody>
    <light pos="0 0 3.5" dir="0 0 -1" directional="true"/>
    <geom name="floor" size="20 20 0.005" type="plane" material="groundplane"/>

    <light name="spotlight" mode="targetbodycom" target="body_link" pos="0 -1 2"/>
    <body name="body_link" pos="0 0 0.3">
      <camera name="back" pos="-1.5 0 0.5" xyaxes="0 -1 0 1 0 3" mode="trackcom"/>
      <camera name="side" pos="0 -1.5 0.5" xyaxes="1 0 0 0 1 3" mode="trackcom"/>
      <inertial pos="-0.01501 0.00013 0.06582" quat="0.704708 0.704003 0.0667707 -0.0575246" mass="1.34928"
        diaginertia="0.00341264 0.00316574 0.00296931"/>
      <freejoint/>
      <geom mesh="body" class="visual"/>
      {body_collision}
      <site name="torso"/>
      <body name="head_pan_link" pos="-0.001 0 0.1365">
        <inertial pos="0.00233 0 0.00823" quat="0.663575 0.663575 0.244272 -0.244272" mass="0.01176"
          diaginertia="4.23401e-06 3.60599e-06 1.65e-06"/>
        <joint name="head_pan" axis="0 0 1"/>
        <geom mesh="h1" class="visual"/>
        {h1c_collision}
        <body name="head_tilt_link" pos="0.01 0.019 0.0285">
          <inertial pos="0.0023 -0.01863 0.0277" quat="0.997312 0.00973825 0.0726131 -0.00102702" mass="0.13631"
            diaginertia="0.000107452 8.72266e-05 4.39413e-05"/>
          <joint name="head_tilt" axis="0 -1 0"/>
          <geom mesh="h2" class="visual"/>
          {h2_collision}
          <camera name="egocentric" pos="0.01425 -0.019 0.04975" fovy="43.3" mode="fixed"
            euler="0.0 -1.570796 -1.570796"/>
        </body>
      </body>
      <body name="l_sho_pitch_link" pos="-0.001 0.06 0.111">
        <inertial pos="0 0.00823 -0.00233" quat="0.244272 0.663575 0.244272 0.663575" mass="0.01176"
          diaginertia="4.23401e-06 3.60599e-06 1.65e-06"/>
        <joint name="l_sho_pitch" axis="0 1 0"/>
        <geom mesh="la1" class="visual"/>
        {la1c_collision}
        <body name="l_sho_roll_link" pos="0.019 0.0285 -0.01">
          <inertial pos="-0.01844 0.04514 0.00028" quat="0.501853 0.50038 -0.498173 0.499588" mass="0.17758"
            diaginertia="0.000234742 0.00022804 3.04183e-05"/>
          <joint name="l_sho_roll" axis="-1 0 0"/>
          <geom mesh="la2" class="visual"/>
          {la2c_collision}
          <body name="l_el_link" pos="0 0.0904 -0.0001">
            <inertial pos="-0.019 0.07033 0.0038" quat="0.483289 0.51617 -0.51617 0.483289" mass="0.04127"
              diaginertia="6.8785e-05 6.196e-05 1.2065e-05"/>
            <joint name="l_el" axis="1 0 0"/>
            <geom mesh="la3" class="visual"/>
            {la3c_collision}
          </body>
        </body>
      </body>
      <body name="r_sho_pitch_link" pos="-0.001 -0.06 0.111">
        <inertial pos="0 -0.00823 -0.00233" quat="-0.244272 0.663575 -0.244272 0.663575" mass="0.01176"
          diaginertia="4.23401e-06 3.60599e-06 1.65e-06"/>
        <joint name="r_sho_pitch" axis="0 -1 0"/>
        <geom mesh="ra1" class="visual"/>
        {ra1c_collision}
        <body name="r_sho_roll_link" pos="0.019 -0.0285 -0.01">
          <inertial pos="-0.01844 -0.04514 0.00028" quat="0.50038 0.501853 -0.499588 0.498173" mass="0.17758"
            diaginertia="0.000234742 0.00022804 3.04183e-05"/>
          <joint name="r_sho_roll" axis="-1 0 0"/>
          <geom mesh="ra2" class="visual"/>
          {ra2c_collision}
          <body name="r_el_link" pos="0 -0.0904 -0.0001">
            <inertial pos="-0.019 -0.07033 0.0038" quat="0.51617 0.483289 -0.483289 0.51617" mass="0.04127"
              diaginertia="6.8785e-05 6.196e-05 1.2065e-05"/>
            <joint name="r_el" axis="1 0 0"/>
            <geom mesh="ra3" class="visual"/>
            {ra3c_collision}
          </body>
        </body>
      </body>
      <body name="l_hip_yaw_link" pos="0 0.035 0">
        <inertial pos="-0.00157 0 -0.00774" quat="0.499041 0.500957 0.500957 0.499041" mass="0.01181"
          diaginertia="4.3e-06 4.12004e-06 1.50996e-06"/>
        <joint name="l_hip_yaw" axis="0 0 -1"/>
        <geom mesh="ll1" class="visual"/>
        {ll1c_collision}
        <body name="l_hip_roll_link" pos="-0.024 0 -0.0285">
          <inertial pos="0.00388 0.00028 -0.01214" quat="0.502657 0.490852 0.498494 0.507842" mass="0.17886"
            diaginertia="0.000125243 0.000108598 4.65693e-05"/>
          <joint name="l_hip_roll" axis="-1 0 0"/>
          <geom mesh="ll2" class="visual"/>
          {ll2c_collision}
          <body name="l_hip_pitch_link" pos="0.0241 0.019 0">
            <inertial pos="0.00059 -0.01901 -0.08408" quat="0.999682 0.0246915 0.00447825 -0.002482" mass="0.11543"
              diaginertia="0.000104996 9.63044e-05 2.47492e-05"/>
            <joint name="l_hip_pitch" axis="0 1 0"/>
            <geom mesh="ll3" class="visual"/>
            {ll3c_collision}
            <body name="l_knee_link" pos="0 0 -0.11015">
              <inertial pos="0 -0.02151 -0.055" mass="0.04015" diaginertia="3.715e-05 2.751e-05 1.511e-05"/>
              <joint name="l_knee" axis="0 1 0"/>
              <geom name="left_shin" mesh="ll4" class="visual"/>
              {ll4c_collision}
              <body name="l_ank_pitch_link" pos="0 0 -0.11">
                <inertial pos="-0.02022 -0.01872 0.01214" quat="0.490852 0.502657 0.507842 0.498494" mass="0.17886"
                  diaginertia="0.000125243 0.000108598 4.65693e-05"/>
                <joint name="l_ank_pitch" axis="0 -1 0"/>
                <geom name="left_ankle" mesh="ll5" class="visual"/>
                {ll5c_collision}
                <body name="l_ank_roll_link" pos="-0.0241 -0.019 0">
                  <inertial pos="0.02373 0.01037 -0.0276" quat="0.0078515 0.707601 0.0113965 0.706477" mass="0.06934"
                    diaginertia="0.000115818 7.87135e-05 4.03389e-05"/>
                  <joint name="l_ank_roll" axis="1 0 0"/>
                  <geom name="left_foot" mesh="ll6" class="visual"/>
                  <geom name="foot1_left" class="foot" pos="0.024 0.013 -0.0265" size="0.0635 0.028 0.004"/>
                  <geom name="foot2_left" class="foot" pos="0.024 0.0125 -0.0265" size="0.057 0.039 0.004"/>
                </body>
              </body>
            </body>
          </body>
        </body>
      </body>
      <body name="r_hip_yaw_link" pos="0 -0.035 0">
        <inertial pos="-0.00157 0 -0.00774" quat="0.499041 0.500957 0.500957 0.499041" mass="0.01181"
          diaginertia="4.3e-06 4.12004e-06 1.50996e-06"/>
        <joint name="r_hip_yaw" axis="0 0 -1"/>
        <geom mesh="rl1" class="visual"/>
        {rl1c_collision}
        <body name="r_hip_roll_link" pos="-0.024 0 -0.0285">
          <inertial pos="0.00388 -0.00028 -0.01214" quat="0.507842 0.498494 0.490852 0.502657" mass="0.17886"
            diaginertia="0.000125243 0.000108598 4.65693e-05"/>
          <joint name="r_hip_roll" axis="-1 0 0"/>
          <geom mesh="rl2" class="visual"/>
          {rl2c_collision}
          <body name="r_hip_pitch_link" pos="0.0241 -0.019 0">
            <inertial pos="0.00059 0.01901 -0.08408" quat="0.999682 -0.0246915 0.00447825 0.002482" mass="0.11543"
              diaginertia="0.000104996 9.63044e-05 2.47492e-05"/>
            <joint name="r_hip_pitch" axis="0 -1 0"/>
            <geom mesh="rl3" class="visual"/>
            {rl3c_collision}
            <body name="r_knee_link" pos="0 0 -0.11015">
              <inertial pos="0 0.02151 -0.055" mass="0.04015" diaginertia="3.715e-05 2.751e-05 1.511e-05"/>
              <joint name="r_knee" axis="0 -1 0"/>
              <geom mesh="rl4" class="visual"/>
              {rl4c_collision}
              <body name="r_ank_pitch_link" pos="0 0 -0.11">
                <inertial pos="-0.02022 0.01872 0.01214" quat="0.498494 0.507842 0.502657 0.490852" mass="0.17886"
                  diaginertia="0.000125243 0.000108598 4.65693e-05"/>
                <joint name="r_ank_pitch" axis="0 1 0"/>
                <geom mesh="rl5" class="visual"/>
                {rl5c_collision}
                <body name="r_ank_roll_link" pos="-0.0241 0.019 0">
                  <inertial pos="0.02373 -0.01037 -0.0276" quat="-0.0078515 0.707601 -0.0113965 0.706477" mass="0.06934"
                    diaginertia="0.000115818 7.87135e-05 4.03389e-05"/>
                  <joint name="r_ank_roll" axis="1 0 0"/>
                  <geom mesh="rl6" class="visual"/>
                  <geom name="foot1_right" class="foot" pos="0.024 -0.013 -0.0265" size="0.0635 0.028 0.004"/>
                  <geom name="foot2_right" class="foot" pos="0.024 -0.0125 -0.0265" size="0.057 0.039 0.004"/>
                </body>
              </body>
            </body>
          </body>
        </body>
      </body>
    </body>
    {ball_body}
  </worldbody>

  <contact>
    <exclude body1="l_hip_yaw_link" body2="l_hip_pitch_link"/>
    <exclude body1="r_hip_yaw_link" body2="r_hip_pitch_link"/>
    <pair geom1="foot1_left" geom2="floor"/>
    <pair geom1="foot1_right" geom2="floor"/>
    <pair geom1="foot2_left" geom2="floor"/>
    <pair geom1="foot2_right" geom2="floor"/>
    <pair geom1="foot1_left" geom2="ball"/>
    <pair geom1="foot2_left" geom2="ball"/>
    <pair geom1="left_ankle" geom2="ball"/>
    <pair geom1="left_shin" geom2="ball"/>
  </contact>

  <actuator>
    <position name="head_pan_act" joint="head_pan"/>
    <position name="head_tilt_act" joint="head_tilt"/>
    <position name="l_sho_pitch_act" joint="l_sho_pitch"/>
    <position name="l_sho_roll_act" joint="l_sho_roll"/>
    <position name="l_el_act" joint="l_el"/>
    <position name="r_sho_pitch_act" joint="r_sho_pitch"/>
    <position name="r_sho_roll_act" joint="r_sho_roll"/>
    <position name="r_el_act" joint="r_el"/>
    <position name="l_hip_yaw_act" joint="l_hip_yaw"/>
    <position name="l_hip_roll_act" joint="l_hip_roll"/>
    <position name="l_hip_pitch_act" joint="l_hip_pitch"/>
    <position name="l_knee_act" joint="l_knee"/>
    <position name="l_ank_pitch_act" joint="l_ank_pitch"/>
    <position name="l_ank_roll_act" joint="l_ank_roll"/>
    <position name="r_hip_yaw_act" joint="r_hip_yaw"/>
    <position name="r_hip_roll_act" joint="r_hip_roll"/>
    <position name="r_hip_pitch_act" joint="r_hip_pitch"/>
    <position name="r_knee_act" joint="r_knee"/>
    <position name="r_ank_pitch_act" joint="r_ank_pitch"/>
    <position name="r_ank_roll_act" joint="r_ank_roll"/>
  </actuator>

</mujoco>
"""

In [None]:
# OP3 Env

class OP3(PipelineEnv):
  """Brax/MJX environment: an OP3 robot that has to keep a ball bouncing.

  The model is compiled from the module-level `xml` string. Rewards and
  termination are computed in `_get_reward`; observations in `_get_obs`.
  The module-level config values `ball_size`, `torso_index` and
  `foot_left_index` are read at step time.
  """

  def __init__(
      self,
      terminate_when_unhealthy=True,
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      **kwargs,
  ):
    """Builds the MuJoCo model and initialises the MJX pipeline.

    Args:
      terminate_when_unhealthy: stored on the instance.
        NOTE(review): no method of this class reads it; `done` is always
        derived from the health checks - confirm whether that is intended.
      reset_noise_scale: half-width of the uniform noise added to the initial
        positions and used for the initial velocities in `reset`.
      exclude_current_positions_from_observation: when True, the first two
        entries of `qpos` are dropped from the observation.
      **kwargs: forwarded to `PipelineEnv.__init__`. `n_frames` defaults to 5
        when not supplied; `backend` is always overwritten with 'mjx'.
    """
    mj_model = mujoco.MjModel.from_xml_string(xml)
    # Solver settings: conjugate-gradient solver with 6 solver iterations and
    # 6 line-search iterations.
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 6
    mj_model.opt.ls_iterations = 6

    sys = mjcf.load_model(mj_model)

    # Number of physics steps per env step unless the caller passed `n_frames`.
    physics_steps_per_control_step = 5
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    kwargs['backend'] = 'mjx'

    super().__init__(sys, **kwargs)

    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )


  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state.

    The initial `qpos` is `sys.qpos0` plus uniform noise in
    [-reset_noise_scale, reset_noise_scale]; the initial `qvel` is uniform
    noise in the same range. Reward, done and all metrics start at zero.
    """
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )
    
    data = self.pipeline_init(qpos, qvel)

    # There is no previous action at reset time, so a zero action is passed.
    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    reward, done, zero = jp.zeros(3)
    metrics = {
        'ball_reward': zero,
        'reward_quadctrl': zero,
        'reward_alive': zero,
        'reward': zero,
        'bounces': zero,
    }
    return State(data, obs, reward, done, metrics)


  def step(self, state: State, action: jp.ndarray) -> State:
    """Runs one timestep of the environment's dynamics.

    Steps the physics pipeline with `action`, then computes reward/done
    (which also updates `state.metrics` in place) and the new observation.
    """
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)
    
    reward, done = self._get_reward(state, action, data0, data)
    obs = self._get_obs(data, action)
    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done
    )


  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes robot body and ball position, velocities, and angles.

    Returns the concatenation of `qpos` (optionally without its first two
    entries), `qvel`, and the flattened `cinert`, `cvel` (both skipping row 0)
    and `qfrc_actuator`. `action` is accepted but not used.
    """
    position = data.qpos
    if self._exclude_current_positions_from_observation:
      # Drops the first two entries of qpos (presumably the x/y position of
      # the robot's root free joint - TODO confirm).
      position = position[2:]

    # external_contact_forces are excluded
    return jp.concatenate([
        # qpos: position / nq: number of generalized coordinates = dim(qpos)
        position,
        # qvel: velocity / nv: number of degrees of freedom = dim(qvel)
        data.qvel,
        # cinert: com-based body inertia and mass / (nbody, 10)
        data.cinert[1:].ravel(),
        # cvel: com-based velocity [3D rot; 3D tran] / (nbody, 6)
        data.cvel[1:].ravel(),
        # qfrc_actuator: actuator force / nv: number of degrees of freedom
        data.qfrc_actuator,
    ])


  def _get_reward(
      self, state: State,  action: jp.ndarray, data0: mjx.Data, data: mjx.Data
  ) -> Tuple[jp.ndarray, jp.ndarray]:
    """Computes the step reward and the termination flag.

    reward = ball term + alive bonus - control cost + foot-distance term, where
      * ball term: 5 * (1 - |ball_z - 0.5| / 0.5), counted only while
        ball_z >= 2.1 * ball_size;
      * alive bonus: constant 5;
      * control cost: 0.1 * sum(action ** 2);
      * foot-distance term: 5 * (1 - d / 1.0), with d the horizontal (x, y)
        distance between the ball and `subtree_com[foot_left_index]`.

    The episode is flagged done when `data.q[torso_index]` leaves (0.4, 1.5),
    when ball_z leaves (2.1 * ball_size, 1.0), or when d > 1.0.

    Side effect: `state.metrics` is updated in place.

    Returns:
      (reward, done), with done equal to 1.0 on termination and 0.0 otherwise.
    """
    ctrl_cost_weight = 0.1
    healthy_reward = 5.0
    healthy_z_range = (0.4, 1.5)
    ball_reward = 5.0
    ball_healthy_z_range = (ball_size*2.1, 1.0)
    ball_reward_min_z = ball_size*2.1
    ball_reward_target_z = 0.5
    distance_feet_reward = 5.0
    distance_feet_max = 1.0
    bounce_threshold = ball_reward_target_z - 0.05 # z coordinate
    
    # The ball is the last body declared in the XML, hence index -1.
    com_before_ball = data0.subtree_com[-1]
    com_after_ball = data.subtree_com[-1]
    com_after_foot = data.subtree_com[foot_left_index]
    # Horizontal (x, y) distance between ball and left foot after the step.
    distance_foot = jp.sqrt(jp.square(com_after_ball[0] - com_after_foot[0]) + jp.square(com_after_ball[1] - com_after_foot[1]))

    # NOTE(review): this indexes `data.q` with `torso_index` (3 in the config)
    # and compares it with a z range. The robot body is spawned at z = 0.3 in
    # the XML, below `min_z`, so this entry is presumably not the torso height
    # (for a free-joint root it may be a quaternion component) - TODO confirm.
    min_z, max_z = healthy_z_range
    is_healthy = jp.where(data.q[torso_index] < min_z, 0.0, 1.0)
    is_healthy = jp.where(data.q[torso_index] > max_z, 0.0, is_healthy)

    # Ball too low (on or near the floor) or too high ends the episode.
    ball_min_z, ball_max_z = ball_healthy_z_range
    is_healthy = jp.where(com_after_ball[2] < ball_min_z, 0.0, is_healthy)
    is_healthy = jp.where(com_after_ball[2] > ball_max_z, 0.0, is_healthy)
    
    # Ball drifted too far from the left foot.
    is_healthy = jp.where(distance_foot > distance_feet_max, 0.0, is_healthy)

    ctrl_cost = ctrl_cost_weight * jp.sum(jp.square(action))
    
    # `ball_reward` is rebound here: weight in, shaped reward out. It peaks at
    # the target height and falls off linearly with the vertical distance.
    distance_target_height = jp.sqrt(jp.square(com_after_ball[2] - ball_reward_target_z))
    ball_reward = ball_reward * (1 - (distance_target_height / (ball_max_z - ball_reward_target_z)))
    is_ball_reward = jp.where(com_after_ball[2] >= ball_reward_min_z, 1.0, 0.0)
    
    # `distance_feet_reward` is likewise rebound from weight to shaped reward.
    distance_feet_reward = distance_feet_reward * (1 - (distance_foot / distance_feet_max))
    
    reward = ball_reward * is_ball_reward + healthy_reward - ctrl_cost + distance_feet_reward

    state.metrics.update(
        ball_reward=ball_reward * is_ball_reward,
        reward_quadctrl=-ctrl_cost,
        reward_alive=healthy_reward,
        reward=reward,
        bounces=self._is_bounce(com_before_ball, com_after_ball, bounce_threshold),
    )
    
    done = 1.0 - is_healthy
    return reward, done
  
  
  # There's a lot of room to improve this function as it should check for contacts
  # between the ball and the lower left limb of the robot
  # (at the time of implementation the contacts data wasn't easily accessible in MJX)
  def _is_bounce(
      self, com_before_ball: jp.ndarray, com_after_ball: jp.ndarray, bounce_threshold: jp.ndarray
  ) -> jp.ndarray:
    """Check if ball bounced.

    Returns 1.0 when the ball's z coordinate was below `bounce_threshold`
    before the step and is at or above it after the step (an upward crossing
    of the threshold), otherwise 0.0.
    """
    is_bounce = jp.where(com_before_ball[2] < bounce_threshold, 1.0, 0.0)
    is_bounce = jp.where(com_after_ball[2] >= bounce_threshold, is_bounce, 0.0)
    return is_bounce
    

# Register the environment so it can be created with envs.get_environment("op3").
envs.register_environment("op3", OP3)

## 5 - Visualize a rollout

Let's instantiate the environment and visualize a short rollout.

NOTE: Since episodes terminate early if the torso is below the healthy z-range, the only relevant contacts for this task are between the feet and the plane, and the lower left limb and the ball. The other contacts weren't included. This also speeds up the training later on.

In [None]:
# Instantiate the environment through the Brax registry (OP3 is registered
# under the name "op3").
env_name = "op3"
env = envs.get_environment(env_name)

# Define the jit reset/step functions so that repeated calls reuse the
# compiled versions.
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

In [None]:
# Initialize the state
state = jit_reset(jax.random.PRNGKey(0))
print(f"Observations size: {len(state.obs)}")
print(f"Actions size: {env.sys.nu}")

rollout = [state.pipeline_state]

# The same constant control is applied at every step, so it is built once
# instead of being re-allocated on every loop iteration.
# ctrl: control / nu: number of actuators/controls = dim(ctrl)
ctrl = -0.1 * jp.ones(env.sys.nu)

# Grab a trajectory: 50 env steps, keeping every pipeline state for rendering
for i in range(50):
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)

# Play the rollout back in real time (one frame per env step of length env.dt)
media.show_video(env.render(rollout, camera='side', height=480, width=640), fps=1.0 / env.dt)

## 6 - Define the training functions

Let's define the training functions using [PPO](https://openai.com/research/openai-baselines-ppo) to make the robot bounce the ball with its foot.

In [None]:
# Define the Acting/Evaluator (adapted from https://github.com/google/brax)

"""Brax training acting functions."""

import time
from typing import Callable, Sequence, Tuple, Union

from brax import envs
from brax.training.types import Metrics
from brax.training.types import Policy
from brax.training.types import PolicyParams
from brax.training.types import PRNGKey
from brax.training.types import Transition
from brax.v1 import envs as envs_v1
import jax
import numpy as np

# Type aliases covering both the current Brax API and the legacy `brax.v1` API,
# so the acting helpers below can be annotated for either kind of environment.
ActingState = Union[envs.State, envs_v1.State]
ActingEnv = Union[envs.Env, envs_v1.Env, envs_v1.Wrapper]


def actor_step(
    env: ActingEnv,
    env_state: ActingState,
    policy: Policy,
    key: PRNGKey,
    extra_fields: Sequence[str] = ()
) -> Tuple[ActingState, Transition]:
  """Take one policy-driven env step and record it as a `Transition`.

  Args:
    env: environment to step.
    env_state: state the policy observes and the env is stepped from.
    policy: callable mapping (observation, key) to (actions, policy extras).
    key: RNG key handed to the policy.
    extra_fields: names of `info` entries of the new state to copy into the
      transition's `state_extras`.

  Returns:
    The new env state and the transition that led to it.
  """
  action, extras_from_policy = policy(env_state.obs, key)
  next_state = env.step(env_state, action)

  # Copy only the requested `info` entries of the new state.
  extras_from_state = {}
  for field in extra_fields:
    extras_from_state[field] = next_state.info[field]

  transition = Transition(  # pytype: disable=wrong-arg-types  # jax-ndarray
      observation=env_state.obs,
      action=action,
      reward=next_state.reward,
      # discount is 0 on the step that ends an episode, 1 otherwise
      discount=1 - next_state.done,
      next_observation=next_state.obs,
      extras=dict(
          policy_extras=extras_from_policy,
          state_extras=extras_from_state,
      ),
  )
  return next_state, transition


def generate_unroll(
    env: ActingEnv,
    env_state: ActingState,
    policy: Policy,
    key: PRNGKey,
    unroll_length: int,
    extra_fields: Sequence[str] = ()
) -> Tuple[ActingState, Transition]:
  """Roll the policy out for `unroll_length` consecutive steps.

  Args:
    env: environment to step.
    env_state: state the rollout starts from.
    policy: callable mapping (observation, key) to (actions, policy extras).
    key: RNG key; it is split once per step.
    unroll_length: number of `actor_step` calls to chain.
    extra_fields: forwarded to `actor_step`.

  Returns:
    Tuple of (state after the last step, transitions stacked by
    `jax.lax.scan` over the steps).
  """

  @jax.jit
  def scan_body(carry, unused_t):
    state, incoming_key = carry
    # First half of the split drives this step, second half is carried on.
    step_key, carried_key = jax.random.split(incoming_key)
    next_state, transition = actor_step(
        env, state, policy, step_key, extra_fields=extra_fields)
    return (next_state, carried_key), transition

  initial_carry = (env_state, key)
  final_carry, data = jax.lax.scan(
      scan_body, initial_carry, (), length=unroll_length)
  final_state, _ = final_carry
  return final_state, data


# TODO: Consider moving this to its own file.
class Evaluator:
  """Runs batched evaluation rollouts and summarises their metrics."""

  def __init__(self, eval_env: envs.Env,
               eval_policy_fn: Callable[[PolicyParams],
                                        Policy], num_eval_envs: int,
               episode_length: int, action_repeat: int, key: PRNGKey):
    """Set up the jitted evaluation rollout.

    Args:
      eval_env: Batched environment to run evals on.
      eval_policy_fn: Function returning the policy from the policy parameters.
      num_eval_envs: Each env will run 1 episode in parallel for each eval.
      episode_length: Maximum length of an episode.
      action_repeat: Number of physics steps per env step.
      key: RNG key.
    """
    self._key = key
    self._eval_walltime = 0.

    wrapped_env = envs.training.EvalWrapper(eval_env)
    steps_per_rollout = episode_length // action_repeat

    def generate_eval_unroll(policy_params: PolicyParams,
                             key: PRNGKey) -> ActingState:
      # One reset key per evaluation environment.
      first_state = wrapped_env.reset(jax.random.split(key, num_eval_envs))
      policy = eval_policy_fn(policy_params)
      final_state, _ = generate_unroll(
          wrapped_env,
          first_state,
          policy,
          key,
          unroll_length=steps_per_rollout)
      return final_state

    self._generate_eval_unroll = jax.jit(generate_eval_unroll)
    self._steps_per_unroll = episode_length * num_eval_envs

  def run_evaluation(self,
                     policy_params: PolicyParams,
                     training_metrics: Metrics,
                     aggregate_episodes: bool = True) -> Metrics:
    """Run one evaluation rollout and return the merged metrics.

    Args:
      policy_params: parameters handed to `eval_policy_fn`.
      training_metrics: metrics merged into the result; evaluation entries
        with the same key take precedence.
      aggregate_episodes: when True each episode metric is reduced with
        mean/std/max; when False the raw value is stored under every key.

    Returns:
      Dict with 'eval/walltime', the training metrics and the 'eval/...'
      entries computed here.
    """
    self._key, unroll_key = jax.random.split(self._key)

    started_at = time.time()
    eval_state = self._generate_eval_unroll(policy_params, unroll_key)
    eval_metrics = eval_state.info['eval_metrics']
    # Wait for the computation to finish so the timing below is meaningful.
    eval_metrics.active_episodes.block_until_ready()
    epoch_eval_time = time.time() - started_at

    reducers = (('', np.mean), ('_std', np.std), ('_max', np.max))
    metrics = {}
    for suffix, reduce_fn in reducers:
      for name, value in eval_metrics.episode_metrics.items():
        if aggregate_episodes:
          value = reduce_fn(value)
        metrics[f'eval/episode_{name}{suffix}'] = value

    metrics['eval/avg_episode_length'] = np.mean(eval_metrics.episode_steps)
    metrics['eval/epoch_eval_time'] = epoch_eval_time
    metrics['eval/sps'] = self._steps_per_unroll / epoch_eval_time
    self._eval_walltime += epoch_eval_time

    merged = {'eval/walltime': self._eval_walltime}
    merged.update(training_metrics)
    merged.update(metrics)
    return merged  # pytype: disable=bad-return-type  # jax-ndarray

In [None]:
# Define the Training Function (adapted from https://github.com/google/brax)

"""Proximal policy optimization training.

See: https://arxiv.org/pdf/1707.06347.pdf
"""

import functools
import time
from typing import Callable, Optional, Tuple, Union

from absl import logging
from brax import base
from brax import envs
from brax.training import gradients
from brax.training import pmap
from brax.training import types
from brax.training.acme import running_statistics
from brax.training.acme import specs
from brax.training.agents.ppo import losses as ppo_losses
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.types import Params, PolicyParams, PreprocessorParams
from brax.training.types import PRNGKey
from brax.v1 import envs as envs_v1
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax


InferenceParams = Tuple[running_statistics.NestedMeanStd, Params]
Metrics = types.Metrics
ValueParams = Any

_PMAP_AXIS_NAME = 'i'


@flax.struct.dataclass
class TrainingState:
  """Contains training state for the learner."""
  # Optimizer state; `train` creates it with `optimizer.init(init_params)`.
  optimizer_state: optax.OptState
  # Policy and value network parameters.
  params: ppo_losses.PPONetworkParams
  # Running statistics passed to the policy and loss as normalizer params.
  normalizer_params: running_statistics.RunningStatisticsState
  # Environment-step counter; `train` adds `env_step_per_training_step` to it
  # on every training step.
  env_steps: jnp.ndarray


def _unpmap(v):
  return jax.tree_util.tree_map(lambda x: x[0], v)


def _strip_weak_type(tree):
  # brax user code is sometimes ambiguous about weak_type.  in order to
  # avoid extra jit recompilations we strip all weak types from user input
  def f(leaf):
    leaf = jnp.asarray(leaf)
    return leaf.astype(leaf.dtype)
  return jax.tree_util.tree_map(f, tree)


def train(
    environment: Union[envs_v1.Env, envs.Env],
    num_timesteps: int,
    episode_length: int,
    action_repeat: int = 1,
    num_envs: int = 1,
    max_devices_per_host: Optional[int] = None,
    num_eval_envs: int = 128,
    learning_rate: float = 1e-4,
    entropy_cost: float = 1e-4,
    discounting: float = 0.9,
    seed: int = 0,
    unroll_length: int = 10,
    batch_size: int = 32,
    num_minibatches: int = 16,
    num_updates_per_batch: int = 2,
    num_evals: int = 1,
    num_resets_per_eval: int = 0,
    normalize_observations: bool = False,
    reward_scaling: float = 1.0,
    clipping_epsilon: float = 0.3,
    gae_lambda: float = 0.95,
    deterministic_eval: bool = False,
    network_factory: types.NetworkFactory[
        ppo_networks.PPONetworks
    ] = ppo_networks.make_ppo_networks,
    progress_fn: Callable[[int, Metrics], None] = lambda *args: None,
    normalize_advantage: bool = True,
    eval_env: Optional[envs.Env] = None,
    policy_params_fn: Callable[..., None] = lambda *args: None,
    randomization_fn: Optional[
        Callable[[base.System, jnp.ndarray], Tuple[base.System, base.System]]
    ] = None,
    saved_params: Optional[
        Tuple[PreprocessorParams, PolicyParams, ValueParams]
    ] = None,
):
  """PPO training.

  Args:
    environment: the environment to train
    num_timesteps: the total number of environment steps to use during training
    episode_length: the length of an environment episode
    action_repeat: the number of timesteps to repeat an action
    num_envs: the number of parallel environments to use for rollouts
      NOTE: `num_envs` must be divisible by the total number of chips since each
        chip gets `num_envs // total_number_of_chips` environments to roll out
      NOTE: `batch_size * num_minibatches` must be divisible by `num_envs` since
        data generated by `num_envs` parallel envs gets used for gradient
        updates over `num_minibatches` of data, where each minibatch has a
        leading dimension of `batch_size`
    max_devices_per_host: maximum number of chips to use per host process
    num_eval_envs: the number of envs to use for evaluation. Each env will run 1
      episode, and all envs run in parallel during eval.
    learning_rate: learning rate for ppo loss
    entropy_cost: entropy reward for ppo loss, higher values increase entropy
      of the policy
    discounting: discounting rate
    seed: random seed
    unroll_length: the number of timesteps to unroll in each environment. The
      PPO loss is computed over `unroll_length` timesteps
    batch_size: the batch size for each minibatch SGD step
    num_minibatches: the number of times to run the SGD step, each with a
      different minibatch with leading dimension of `batch_size`
    num_updates_per_batch: the number of times to run the gradient update over
      all minibatches before doing a new environment rollout
    num_evals: the number of evals to run during the entire training run.
      Increasing the number of evals increases total training time
    num_resets_per_eval: the number of environment resets to run between each
      eval. The environment resets occur on the host
    normalize_observations: whether to normalize observations
    reward_scaling: float scaling for reward
    clipping_epsilon: clipping epsilon for PPO loss
    gae_lambda: General advantage estimation lambda
    deterministic_eval: whether to run the eval with a deterministic policy
    network_factory: function that generates networks for policy and value
      functions
    progress_fn: a user-defined callback function for reporting/plotting metrics
    normalize_advantage: whether to normalize advantage estimate
    eval_env: an optional environment for eval only, defaults to `environment`
    policy_params_fn: a user-defined callback function that can be used for
      saving policy checkpoints
    randomization_fn: a user-defined callback function that generates randomized
      environments
    saved_params: params to init the training with, as a tuple of
      (normalizer_params, policy network params, value network params)

  Returns:
    Tuple of (make_policy function, network params, metrics,
    max_score_params). `network params` is the tuple (normalizer params,
    policy params, value params) taken from the final training state.
    `max_score_params` is a dict with keys "score" and "params" describing
    the evaluation with the highest 'eval/episode_reward'; it is empty when
    no evaluation was recorded (evaluations only run on process 0).
  """
  assert batch_size * num_minibatches % num_envs == 0
  xt = time.time()

  process_count = jax.process_count()
  process_id = jax.process_index()
  local_device_count = jax.local_device_count()
  local_devices_to_use = local_device_count
  if max_devices_per_host:
    local_devices_to_use = min(local_devices_to_use, max_devices_per_host)
  logging.info(
      'Device count: %d, process count: %d (id %d), local device count: %d, '
      'devices to be used count: %d', jax.device_count(), process_count,
      process_id, local_device_count, local_devices_to_use)
  device_count = local_devices_to_use * process_count

  # The number of environment steps executed for every training step.
  env_step_per_training_step = (
      batch_size * unroll_length * num_minibatches * action_repeat)
  num_evals_after_init = max(num_evals - 1, 1)
  # The number of training_step calls per training_epoch call.
  # equals to ceil(num_timesteps / (num_evals * env_step_per_training_step *
  #                                 num_resets_per_eval))
  num_training_steps_per_epoch = np.ceil(
      num_timesteps
      / (
          num_evals_after_init
          * env_step_per_training_step
          * max(num_resets_per_eval, 1)
      )
  ).astype(int)

  key = jax.random.PRNGKey(seed)
  global_key, local_key = jax.random.split(key)
  del key
  local_key = jax.random.fold_in(local_key, process_id)
  local_key, key_env, eval_key = jax.random.split(local_key, 3)
  # key_networks should be global, so that networks are initialized the same
  # way for different processes.
  key_policy, key_value = jax.random.split(global_key)
  del global_key

  assert num_envs % device_count == 0

  v_randomization_fn = None
  if randomization_fn is not None:
    randomization_batch_size = num_envs // local_device_count
    # all devices gets the same randomization rng
    randomization_rng = jax.random.split(key_env, randomization_batch_size)
    v_randomization_fn = functools.partial(
        randomization_fn, rng=randomization_rng
    )

  if isinstance(environment, envs.Env):
    wrap_for_training = envs.training.wrap
  else:
    wrap_for_training = envs_v1.wrappers.wrap_for_training

  env = wrap_for_training(
      environment,
      episode_length=episode_length,
      action_repeat=action_repeat,
      randomization_fn=v_randomization_fn,
  )

  reset_fn = jax.jit(jax.vmap(env.reset))
  key_envs = jax.random.split(key_env, num_envs // process_count)
  # Give the reset keys a leading axis with one entry per local device.
  key_envs = jnp.reshape(key_envs,
                         (local_devices_to_use, -1) + key_envs.shape[1:])
  env_state = reset_fn(key_envs)

  normalize = lambda x, y: x
  if normalize_observations:
    normalize = running_statistics.normalize
  ppo_network = network_factory(
      env_state.obs.shape[-1],
      env.action_size,
      preprocess_observations_fn=normalize)
  make_policy = ppo_networks.make_inference_fn(ppo_network)

  optimizer = optax.adam(learning_rate=learning_rate)

  loss_fn = functools.partial(
      ppo_losses.compute_ppo_loss,
      ppo_network=ppo_network,
      entropy_cost=entropy_cost,
      discounting=discounting,
      reward_scaling=reward_scaling,
      gae_lambda=gae_lambda,
      clipping_epsilon=clipping_epsilon,
      normalize_advantage=normalize_advantage)

  gradient_update_fn = gradients.gradient_update_fn(
      loss_fn, optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True)

  def minibatch_step(
      carry, data: types.Transition,
      normalizer_params: running_statistics.RunningStatisticsState):
    # One gradient update on a single minibatch.
    optimizer_state, params, key = carry
    key, key_loss = jax.random.split(key)
    (_, metrics), params, optimizer_state = gradient_update_fn(
        params,
        normalizer_params,
        data,
        key_loss,
        optimizer_state=optimizer_state)

    return (optimizer_state, params, key), metrics

  def sgd_step(carry, unused_t, data: types.Transition,
               normalizer_params: running_statistics.RunningStatisticsState):
    # One pass over all minibatches of a freshly shuffled copy of `data`.
    optimizer_state, params, key = carry
    key, key_perm, key_grad = jax.random.split(key, 3)

    def convert_data(x: jnp.ndarray):
      # The same permutation key is used for every leaf, then the leading
      # axis is split into `num_minibatches` chunks.
      x = jax.random.permutation(key_perm, x)
      x = jnp.reshape(x, (num_minibatches, -1) + x.shape[1:])
      return x

    shuffled_data = jax.tree_util.tree_map(convert_data, data)
    (optimizer_state, params, _), metrics = jax.lax.scan(
        functools.partial(minibatch_step, normalizer_params=normalizer_params),
        (optimizer_state, params, key_grad),
        shuffled_data,
        length=num_minibatches)
    return (optimizer_state, params, key), metrics

  def training_step(
      carry: Tuple[TrainingState, envs.State, PRNGKey],
      unused_t) -> Tuple[Tuple[TrainingState, envs.State, PRNGKey], Metrics]:
    # Collect rollouts with the current policy, then run the SGD updates.
    training_state, state, key = carry
    key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3)

    policy = make_policy(
        (training_state.normalizer_params, training_state.params.policy))

    def f(carry, unused_t):
      current_state, current_key = carry
      current_key, next_key = jax.random.split(current_key)
      next_state, data = generate_unroll(
          env,
          current_state,
          policy,
          current_key,
          unroll_length,
          extra_fields=('truncation',))
      return (next_state, next_key), data

    (state, _), data = jax.lax.scan(
        f, (state, key_generate_unroll), (),
        length=batch_size * num_minibatches // num_envs)
    # Have leading dimensions (batch_size * num_minibatches, unroll_length)
    data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 1, 2), data)
    data = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]),
                                  data)
    assert data.discount.shape[1:] == (unroll_length,)

    # Update normalization params and normalize observations.
    normalizer_params = running_statistics.update(
        training_state.normalizer_params,
        data.observation,
        pmap_axis_name=_PMAP_AXIS_NAME)

    (optimizer_state, params, _), metrics = jax.lax.scan(
        functools.partial(
            sgd_step, data=data, normalizer_params=normalizer_params),
        (training_state.optimizer_state, training_state.params, key_sgd), (),
        length=num_updates_per_batch)

    new_training_state = TrainingState(
        optimizer_state=optimizer_state,
        params=params,
        normalizer_params=normalizer_params,
        env_steps=training_state.env_steps + env_step_per_training_step)
    return (new_training_state, state, new_key), metrics

  def training_epoch(training_state: TrainingState, state: envs.State,
                     key: PRNGKey) -> Tuple[TrainingState, envs.State, Metrics]:
    (training_state, state, _), loss_metrics = jax.lax.scan(
        training_step, (training_state, state, key), (),
        length=num_training_steps_per_epoch)
    loss_metrics = jax.tree_util.tree_map(jnp.mean, loss_metrics)
    return training_state, state, loss_metrics

  training_epoch = jax.pmap(training_epoch, axis_name=_PMAP_AXIS_NAME)

  # Note that this is NOT a pure jittable method.
  def training_epoch_with_timing(
      training_state: TrainingState, env_state: envs.State,
      key: PRNGKey) -> Tuple[TrainingState, envs.State, Metrics]:
    nonlocal training_walltime
    t = time.time()
    training_state, env_state = _strip_weak_type((training_state, env_state))
    result = training_epoch(training_state, env_state, key)
    training_state, env_state, metrics = _strip_weak_type(result)

    metrics = jax.tree_util.tree_map(jnp.mean, metrics)
    # Block on the results so the elapsed time covers the whole epoch.
    jax.tree_util.tree_map(lambda x: x.block_until_ready(), metrics)

    epoch_training_time = time.time() - t
    training_walltime += epoch_training_time
    sps = (num_training_steps_per_epoch *
           env_step_per_training_step *
           max(num_resets_per_eval, 1)) / epoch_training_time
    metrics = {
        'training/sps': sps,
        'training/walltime': training_walltime,
        **{f'training/{name}': value for name, value in metrics.items()}
    }
    return training_state, env_state, metrics  # pytype: disable=bad-return-type  # py311-upgrade


  # Start either from freshly initialised networks or from `saved_params`.
  if saved_params is None:
    init_params = ppo_losses.PPONetworkParams(
      policy=ppo_network.policy_network.init(key_policy),
      value=ppo_network.value_network.init(key_value))
    normalizer_params = running_statistics.init_state(
      specs.Array(env_state.obs.shape[-1:], jnp.dtype('float32')))
  else:
    init_params = ppo_losses.PPONetworkParams(
      policy=saved_params[1],
      value=saved_params[2])
    normalizer_params = saved_params[0]

  training_state = TrainingState(  # pytype: disable=wrong-arg-types  # jax-ndarray
      optimizer_state=optimizer.init(init_params),  # pytype: disable=wrong-arg-types  # numpy-scalars
      params=init_params,
      normalizer_params=normalizer_params,
      env_steps=0)
  training_state = jax.device_put_replicated(
      training_state,
      jax.local_devices()[:local_devices_to_use])

  if not eval_env:
    eval_env = environment
  if randomization_fn is not None:
    v_randomization_fn = functools.partial(
        randomization_fn, rng=jax.random.split(eval_key, num_eval_envs)
    )
  eval_env = wrap_for_training(
      eval_env,
      episode_length=episode_length,
      action_repeat=action_repeat,
      randomization_fn=v_randomization_fn,
  )

  evaluator = Evaluator(
      eval_env,
      functools.partial(make_policy, deterministic=deterministic_eval),
      num_eval_envs=num_eval_envs,
      episode_length=episode_length,
      action_repeat=action_repeat,
      key=eval_key)

  # Run initial eval
  metrics = {}
  if process_id == 0 and num_evals > 1:
    metrics = evaluator.run_evaluation(
        _unpmap(
            (training_state.normalizer_params, training_state.params.policy)),
        training_metrics={})
    logging.info(metrics)
    progress_fn(0, metrics)

  training_metrics = {}
  training_walltime = 0
  current_step = 0
  # Initialize variables to allow saving params of run with max score.
  # Start from -inf (not 0) so the best evaluation is recorded even when
  # every episode reward is zero or negative.
  max_score = float('-inf')
  max_score_params = {}
  for it in range(num_evals_after_init):
    logging.info('starting iteration %s %s', it, time.time() - xt)

    for _ in range(max(num_resets_per_eval, 1)):
      # optimization
      epoch_key, local_key = jax.random.split(local_key)
      epoch_keys = jax.random.split(epoch_key, local_devices_to_use)
      (training_state, env_state, training_metrics) = (
          training_epoch_with_timing(training_state, env_state, epoch_keys)
      )
      current_step = int(_unpmap(training_state.env_steps))

      key_envs = jax.vmap(
          lambda x, s: jax.random.split(x[0], s),
          in_axes=(0, None))(key_envs, key_envs.shape[1])
      # TODO: move extra reset logic to the AutoResetWrapper.
      env_state = reset_fn(key_envs) if num_resets_per_eval > 0 else env_state

    if process_id == 0:
      # Run evals.
      metrics = evaluator.run_evaluation(
          _unpmap(
              (training_state.normalizer_params, training_state.params.policy)),
          training_metrics)
      logging.info(metrics)
      progress_fn(current_step, metrics)
      params = _unpmap(
          (training_state.normalizer_params, training_state.params.policy,
           training_state.params.value))

      # Save params if this is the max score
      eval_score = metrics['eval/episode_reward']
      if eval_score > max_score:
        max_score = eval_score
        max_score_params = {
            "score": max_score,
            "params": params,
        }
      policy_params_fn(current_step, make_policy, params)

  total_steps = current_step
  assert total_steps >= num_timesteps

  # If there were no mistakes the training_state should still be identical on
  # all devices.
  pmap.assert_is_replicated(training_state)
  params = _unpmap(
      (training_state.normalizer_params, training_state.params.policy,
       training_state.params.value))
  logging.info('total steps: %s', total_steps)
  pmap.synchronize_hosts()
  return (make_policy, params, metrics, max_score_params)


In [None]:
# Define the PPO networks (adapted from https://github.com/google/brax)

@flax.struct.dataclass
class PPONetworks:
  """Bundle of the networks and action distribution built by `make_ppo_networks`."""
  # Policy network; `make_ppo_networks` builds it with
  # `networks.make_policy_network`.
  policy_network: networks.FeedForwardNetwork
  # Value network; `make_ppo_networks` builds it with
  # `networks.make_value_network`.
  value_network: networks.FeedForwardNetwork
  # Action distribution whose `param_size` sets the policy network's output
  # size.
  parametric_action_distribution: distribution.ParametricDistribution

def make_ppo_networks(
    observation_size: int,
    action_size: int,
    preprocess_observations_fn: types.PreprocessObservationFn = types
    .identity_observation_preprocessor,
    policy_hidden_layer_sizes: Sequence[int] = policy_hidden_layer_sizes,
    value_hidden_layer_sizes: Sequence[int] = value_hidden_layer_sizes,
    activation: networks.ActivationFn = linen.swish) -> PPONetworks:
  """Build the policy network, value network and action distribution for PPO.

  Args:
    observation_size: size of the observation passed to both networks.
    action_size: event size of the action distribution.
    preprocess_observations_fn: observation preprocessor given to both
      networks.
    policy_hidden_layer_sizes: hidden layer sizes of the policy network.
    value_hidden_layer_sizes: hidden layer sizes of the value network.
    activation: activation function given to both networks.

  Returns:
    A `PPONetworks` holding the two networks and the action distribution.
  """
  action_distribution = distribution.NormalTanhDistribution(
      event_size=action_size)

  # Settings that the policy and value networks have in common.
  shared_kwargs = dict(
      preprocess_observations_fn=preprocess_observations_fn,
      activation=activation)

  policy = networks.make_policy_network(
      action_distribution.param_size,
      observation_size,
      hidden_layer_sizes=policy_hidden_layer_sizes,
      **shared_kwargs)
  value = networks.make_value_network(
      observation_size,
      hidden_layer_sizes=value_hidden_layer_sizes,
      **shared_kwargs)

  return PPONetworks(
      policy_network=policy,
      value_network=value,
      parametric_action_distribution=action_distribution)

## 7 - Training the robot

Training for 60M timesteps with 9 evals takes about 50 minutes with two T4 GPUs. That can be enough for the robot to learn to do 2 bounces on average and a max of 3 (although it can take longer in some training runs as the optimization is non-deterministic).

Learning to do better (~10 bounces on average) is possible in about 4 hours.

In [None]:
# Set `upload_model` to True to restart training from a saved checkpoint
# (i.e. from the saved policy and value neural networks' weights);
# otherwise training starts without saved params.
upload_model = False
saved_params = model.load_params(save_path) if upload_model else None

In [None]:
# Train
# Bind every hyperparameter from the config to `train`; only the environment
# and the progress callback are supplied when it is called.
train_fn = functools.partial(
    train,
    # Training budget and evaluation schedule
    num_timesteps=num_timesteps,
    num_evals=num_evals,
    seed=0,
    # Environment rollouts
    num_envs=num_envs,
    episode_length=episode_length,
    action_repeat=action_repeat,
    unroll_length=unroll_length,
    normalize_observations=normalize_observations,
    # Optimisation
    batch_size=batch_size,
    num_minibatches=num_minibatches,
    num_updates_per_batch=num_updates_per_batch,
    learning_rate=learning_rate,
    # PPO loss
    discounting=discounting,
    entropy_cost=entropy_cost,
    reward_scaling=reward_scaling,
    clipping_epsilon=clipping_epsilon,
    gae_lambda=gae_lambda,
    normalize_advantage=normalize_advantage,
    # Networks and optional checkpoint to start from
    network_factory=make_ppo_networks,
    saved_params=saved_params,
)

# Buffers filled by the progress callback: steps, mean reward, reward std
x_data, y_data, ydataerr = [], [], []
# Timestamps: one now, plus one per progress callback
times = [datetime.now()]

# Fixed y-axis range of the reward plot
min_y, max_y = 0, 15000
def progress(num_steps, metrics):
  """Record the latest evaluation, redraw the reward plot and print metrics.

  Args:
    num_steps: number of environment steps, used as the x value.
    metrics: dict of metrics; must contain 'eval/episode_reward' and
      'eval/episode_reward_std'. The extra printout only happens when it
      also contains 'training/policy_loss'.
  """
  times.append(datetime.now())
  x_data.append(num_steps)
  latest_reward = metrics['eval/episode_reward']
  y_data.append(latest_reward)
  ydataerr.append(metrics['eval/episode_reward_std'])

  # Leave 10% headroom beyond the configured number of timesteps.
  x_upper = train_fn.keywords['num_timesteps'] * 1.1
  plt.xlim([0, x_upper])
  plt.ylim([min_y, max_y])

  plt.xlabel('# environment steps')
  plt.ylabel('reward per episode')
  plt.title(f'y={latest_reward:.1f}')

  plt.errorbar(x_data, y_data, yerr=ydataerr)
  plt.show()

  # Without training metrics there is nothing more to report.
  if 'training/policy_loss' not in metrics:
    return

  print("Other metrics")
  print(f"entropy loss: {metrics['training/entropy_loss']:.2f}")
  print(f"value loss: {metrics['training/v_loss']:.2f}")
  print(f"max episode reward: {int(metrics['eval/episode_reward_max'])}")
  print(f"avg bounces: {metrics['eval/episode_bounces']:.2f}")
  print(f"max bounces: {metrics['eval/episode_bounces_max']}\n")

# Run the training; `progress` is called after each evaluation.
make_inference_fn, train_params, _, max_score_params = train_fn(
    environment=env, progress_fn=progress)

# times[0] is taken before training starts and times[1] at the first progress
# callback, so their difference is reported as the time to jit.
print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')
print(f'total time: {times[-1] - times[0]}\n')

# `max_score_params` is empty when no best-score entry was recorded, so guard
# the lookup instead of raising a KeyError after a long training run.
if 'score' in max_score_params:
  print(f"max score: {int(max_score_params['score'])}")
else:
  print("max score: not available (no best-score checkpoint was recorded)")

## 8 - Save and load the policy

We can save and load the policy using the brax model API.

In [None]:
# Save the params returned at the end of training
model.save_params(save_path, train_params)
# Alternative: save the params of the evaluation with the highest score instead
# model.save_params(save_path, max_score_params['params'])

In [None]:
# Load the saved params and build the inference function.
# Only the first two entries of the saved params are passed on; the saved
# tuple is (normalizer params, policy params, value params).
loaded_params = model.load_params(save_path)
inference_fn = make_inference_fn(loaded_params[:2])
jit_inference_fn = jax.jit(inference_fn)

## 9 - Visualize the policy

Finally we can visualize the robot in action and watch while it bounces the ball with its foot!

This can also be saved to an mp4 file which you can then download from the `Output` section (can be found on the right if running on a laptop).

In [None]:
# Get the environment by name for the visualization rollouts
eval_env = envs.get_environment(env_name)

# Rebind the jitted reset/step helpers to this environment
jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)

In [None]:
# Visualize the robot and save each displayed rollout to an mp4 file

init_time = datetime.now()
rollouts = []
max_rollout_reward = 0
max_bounces = 0

# Rollout settings are identical for every rollout, so define them once.
n_steps = 100000  # upper bound on steps; a rollout stops early once done
render_every = 2  # render every 2nd state of the trajectory
video_fps = 1.0 / env.dt / render_every

for i in range(num_rollouts):
  # Initialize the state, using the rollout index as the seed
  rng = jax.random.PRNGKey(i)
  state = jit_reset(rng)
  rollout = [state.pipeline_state]
  total_reward = 0
  total_bounces = 0

  # Grab a trajectory. The step counter is unused, so it is named `_` to
  # avoid shadowing the rollout index `i`.
  for _ in range(n_steps):
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_step(state, ctrl)
    total_reward += state.metrics["reward"]
    total_bounces += state.metrics["bounces"]
    rollout.append(state.pipeline_state)

    if state.done:
      break

  max_rollout_reward = max(max_rollout_reward, total_reward)
  max_bounces = max(max_bounces, total_bounces)

  # Only show and save rollouts with more bounces than the threshold
  if total_bounces > num_bounces_threshold:
    print(f"Iteration with reward {int(total_reward)} and {int(total_bounces)} bounces")
    video = env.render(rollout[::render_every], camera='side', height=480, width=640)
    media.show_video(video, fps=video_fps)
    media.write_video(f"/kaggle/working/ball_bounce_{int(total_reward)}_{int(total_bounces)}.mp4", video, fps=video_fps)

print(f"Max rollout reward was - {int(max_rollout_reward)}")
print(f"Max bounces was - {int(max_bounces)}")
print(f'total time: {datetime.now() - init_time}')