![MuJoCo banner](https://raw.githubusercontent.com/google-deepmind/mujoco/main/banner.png)

# <h1><center>Rollout Tutorial <a href="https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/python/rollout.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" width="140" align="center"/></a></center></h1>

This notebook provides a tutorial for [**MuJoCo** physics](https://github.com/google-deepmind/mujoco#readme), using the native Python bindings.

This notebook describes the `rollout` module included in the MuJoCo Python library. It performs simulation "rollouts" with an underlying C++ function. The rollouts can be multithreaded.

Below, the usage of each argument is explained with examples. An example of using `rollout` with `minimize` is also given. Then `rollout` is benchmarked against pure Python and MJX. Finally, some examples of advanced use cases are provided.

Note that the benchmarks were designed to run on an AMD 5800X3D and an RTX 4090. They do not run in a reasonable amount of time on a typical free Colab runtime.

<!-- Copyright 2025 DeepMind Technologies Limited

     Licensed under the Apache License, Version 2.0 (the "License");
     you may not use this file except in compliance with the License.
     You may obtain a copy of the License at

         http://www.apache.org/licenses/LICENSE-2.0

     Unless required by applicable law or agreed to in writing, software
     distributed under the License is distributed on an "AS IS" BASIS,
     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     See the License for the specific language governing permissions and
     limitations under the License.
-->

# All Imports

In [None]:
#@title All imports

!pip install mujoco
!pip install mujoco_mjx
!pip install brax

# Set up GPU rendering.
#from google.colab import files
import os
import subprocess
if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

# Check if installation was successful.
try:
  print('Checking that the installation succeeded:')
  import mujoco
  from mujoco import minimize
  from mujoco import rollout
  from mujoco import mjx
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".')

print('Installation successful.')

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

# Other imports and helper functions
import copy
import time
from multiprocessing import cpu_count
import threading
import itertools
import numpy as np
import jax
import jax.numpy as jp

# Graphics and plotting.
print('Installing mediapy:')
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import mediapy as media
import matplotlib
import matplotlib.pyplot as plt

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

from IPython.display import clear_output
clear_output()

# Set the number of threads to the number of CPUs that the multiprocessing module reports
nthread = cpu_count()

# Get MuJoCo's standard humanoid model.
print('Getting MuJoCo humanoid XML description from GitHub:')
!git clone https://github.com/google-deepmind/mujoco

# Helper Functions

In [2]:
#@title helper functions

def get_state(model, data, nbatch=1):
  full_physics = mujoco.mjtState.mjSTATE_FULLPHYSICS
  state = np.zeros((mujoco.mj_stateSize(model, full_physics),))
  mujoco.mj_getState(model, data, state, full_physics)
  # Return nbatch copies of the state, shape (nbatch, nstate)
  return np.tile(state, (nbatch, 1))

def benchmark(f, x_list=[None], ntiming=1):
  x_times_list = []
  for x in x_list:
    times = [time.perf_counter()]
    for i in range(ntiming):
      f(x)
      times.append(time.perf_counter())
    x_times_list.append(np.mean(np.diff(times)))
  return np.array(x_times_list)

def render_many(model, data, state, framerate, camera=-1, shift_joint=None, ncols=10, spacing=(1., 1.), shape=(480, 640), transparent=True):
  nbatch = state.shape[0]

  if not isinstance(model, mujoco.MjModel):
    model = list(model)

  if isinstance(model, list) and len(model) == 1:
    model = model * nbatch
  elif isinstance(model, list):
    assert len(model) == nbatch
  else:
    model = [model] * nbatch

  if shift_joint is not None:
    data = copy.copy(data)

  # Visual options
  vopt = mujoco.MjvOption()
  vopt.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = transparent  # Transparent.
  pert = mujoco.MjvPerturb()  # Empty MjvPerturb object
  catmask = mujoco.mjtCatBit.mjCAT_DYNAMIC

  # Simulate and render.
  frames = []
  with mujoco.Renderer(model[0], *shape) as renderer:
    for i in range(state.shape[1]):
      if len(frames) < i * model[0].opt.timestep * framerate:
        for j in range(state.shape[0]):
          mujoco.mj_setState(model[j], data, state[j, i, :], mujoco.mjtState.mjSTATE_FULLPHYSICS)
          mujoco.mj_forward(model[j], data)

          if shift_joint is not None:
            grid_x = j % ncols
            grid_y = j // ncols
            #print(grid_x, grid_y)
            data.joint(shift_joint).qpos[:3] = data.joint(shift_joint).qpos[:3] + (grid_x * spacing[0], grid_y * spacing[1], 0)
            mujoco.mj_forward(model[j], data)

          # Add the first top to the scene
          if j == 0:
            renderer.update_scene(data, camera, scene_option=vopt)
          else:
            mujoco.mjv_addGeoms(model[j], data, vopt, pert, catmask, renderer.scene)
        # Render and add the frame.
        pixels = renderer.render()
        frames.append(pixels)
  return frames

# Using `rollout`

The `rollout.rollout` function in the `mujoco` Python library runs batches of simulations for a fixed number of steps. It can run in single-threaded or multithreaded mode. The speedup over pure Python is significant because `rollout` makes it easy to use a lightweight thread pool.

Below we load the "tippe top", "humanoid", and "humanoid100" models which will be used in the following usage examples and benchmarks.

The tippe top is copied from the [tutorial notebook](https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/python/tutorial.ipynb). The humanoid and humanoid100 models are distributed with MuJoCo.

In [None]:
#@title Benchmarked models
tippe_top = """
<mujoco model="tippe top">
  <option integrator="RK4"/>

  <asset>
    <texture name="grid" type="2d" builtin="checker" rgb1=".1 .2 .3"
     rgb2=".2 .3 .4" width="300" height="300"/>
    <material name="grid" texture="grid" texrepeat="40 40" reflectance=".2"/>
  </asset>

  <worldbody>
    <geom size="1 1 .01" type="plane" material="grid"/>
    <light pos="0 0 .6"/>
    <camera name="closeup" pos="0 -.1 .07" xyaxes="1 0 0 0 1 2"/>
    <camera name="distant" pos="0 -.4 .4" xyaxes="1 0 0 0 1 1"/>
    <body name="top" pos="0 0 .02">
      <freejoint name="top"/>
      <site name="top" pos="0 0 0"/>
      <geom name="ball" type="sphere" size=".02" />
      <geom name="stem" type="cylinder" pos="0 0 .02" size="0.004 .008"/>
      <geom name="ballast" type="box" size=".023 .023 0.005"  pos="0 0 -.015"
       contype="0" conaffinity="0" group="3"/>
    </body>
  </worldbody>

  <sensor>
    <gyro name="gyro" site="top"/>
  </sensor>

  <keyframe>
    <key name="spinning" qpos="0 0 0.02 1 0 0 0" qvel="0 0 0 0 1 200" />
  </keyframe>
</mujoco>
"""

# Create and initialize top model
top_model = mujoco.MjModel.from_xml_string(tippe_top)
def init_top(model):
  data = mujoco.MjData(model)
  mujoco.mj_resetDataKeyframe(model, data, 0)  # Set the state to a spinning, upside-down top
  return data
top_data = init_top(top_model)

# Create and initialize humanoid model
humanoid_xml_path = 'mujoco/model/humanoid/humanoid.xml'
humanoid_model = mujoco.MjModel.from_xml_path(humanoid_xml_path)
def init_humanoid(model):
  data = mujoco.MjData(model)
  data.qvel[2] = 4 # Make the humanoid jump
  return data
humanoid_data = init_humanoid(humanoid_model)

# Create and initialize humanoid100 model
humanoid100_xml_path = 'mujoco/model/humanoid/humanoid100.xml'
humanoid100_model = mujoco.MjModel.from_xml_path(humanoid100_xml_path)
def init_humanoid100(model):
  data = mujoco.MjData(model)
  return data
humanoid100_data = init_humanoid100(humanoid100_model)

start = time.time()
top_nstep = int(6 / top_model.opt.timestep)
top_state, _ = rollout.rollout(top_model, top_data, initial_state=get_state(top_model, top_data), nstep=top_nstep)

humanoid_nstep = int(3 / humanoid_model.opt.timestep)
humanoid_state, _ = rollout.rollout(humanoid_model, humanoid_data, initial_state=get_state(humanoid_model, humanoid_data), nstep=humanoid_nstep)

humanoid100_nstep = int(3 / humanoid100_model.opt.timestep)
humanoid100_state, _ = rollout.rollout(humanoid100_model, humanoid100_data, initial_state=get_state(humanoid100_model, humanoid100_data), nstep=humanoid100_nstep)
end = time.time()

start_render = time.time()
top_frames = render_many(top_model, top_data, top_state, framerate=60, shape=(240, 320), transparent=False)
humanoid_frames = render_many(humanoid_model, humanoid_data, humanoid_state, framerate=120, shape=(240, 320), transparent=False)
humanoid100_frames = render_many(humanoid100_model, humanoid100_data, humanoid100_state, framerate=120, shape=(240, 320), transparent=False)

media.show_video(np.concatenate((top_frames, humanoid_frames, humanoid100_frames), axis=2), fps=60) # humanoid and humanoid100 are shown at half speed
end_render = time.time()

print(f'Rollout took {end-start:.1f} seconds')
print(f'Rendering took {end_render-start_render:.1f} seconds')

## Detailed Usage

It is helpful to read `rollout`'s docstring before beginning. The main takeaways are that `rollout` runs `nbatch` rollouts for `nstep` steps. The MjModels can be different but should be the same up to parameter values. Passing multiple MjData enables multithreading, one thread per MjData.
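
To make the "one thread per MjData" idea concrete, here is a minimal pure-Python sketch. It is an illustration only and not `rollout`'s actual C++ implementation: `batched_rollout` is a hypothetical name, a small dict stands in for each MjData scratch workspace, and the "dynamics" simply halve a number at every step.

```python
from concurrent.futures import ThreadPoolExecutor
import queue


def batched_rollout(initial_values, workspaces, nstep):
  """Toy batched rollout: at most one worker thread per workspace."""
  free = queue.Queue()
  for w in workspaces:
    free.put(w)
  results = [None] * len(initial_values)

  def run_one(i):
    w = free.get()  # Borrow a workspace; it is never used by two threads at once.
    try:
      w["x"] = initial_values[i]
      trajectory = []
      for _ in range(nstep):
        w["x"] = 0.5 * w["x"]  # Toy dynamics: halve the state each step.
        trajectory.append(w["x"])
      results[i] = trajectory
    finally:
      free.put(w)  # Hand the workspace back for the next rollout.

  # The number of worker threads equals the number of workspaces.
  with ThreadPoolExecutor(max_workers=len(workspaces)) as pool:
    list(pool.map(run_one, range(len(initial_values))))
  return results


# Four rollouts share two workspaces, so at most two run concurrently.
trajectories = batched_rollout(
    [1.0, 2.0, 4.0, 8.0], [{"x": 0.0} for _ in range(2)], nstep=3)
print(trajectories[0])
```

The batch size (four) and the number of workspaces (two) are independent, which mirrors how a list of `nthread` MjData objects can serve a much larger `nbatch`.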

Next we give usage examples of the most common arguments. The more advanced arguments are discussed in the "Advanced Usage" section.

In [None]:
print(rollout.rollout.__doc__)

### Example: different initial states
`rollout` is designed to run `nbatch` rollouts in parallel for `nstep` steps. Let's simulate 100 tippe tops with different initial rotation speeds.

In [None]:
nbatch = 100 # Simulate this many tops

# Get nbatch initial states and scale the initial speed of the tippe top using the batch index
top_data = init_top(top_model)
initial_state = get_state(top_model, top_data)
initial_states = np.tile(initial_state, (nbatch, 1))
initial_states[:, -1] *= np.linspace(0.5, 1.5, num=nbatch)

# Run the rollout
start = time.time()
state, sensordata = rollout.rollout(top_model, [copy.copy(top_data) for _ in range(nthread)], # Create one MjData per thread
                                    initial_states, nstep=int(top_nstep*1.5))
end = time.time()

# Use state to render all the tops at once
start_render = time.time()
framerate = 60
media.show_video(render_many(top_model, top_data, state, framerate), fps=framerate)
end_render = time.time()

print(f'Rollout took {end-start:.1f} seconds')
print(f'Rendering took {end_render-start_render:.1f} seconds')

Our model has an angular velocity sensor in the middle of the top. Let's plot the response using the `sensordata` array that `rollout` returns.

In [None]:
plt.subplot(3,1,1)
for i in range(nbatch): plt.plot(sensordata[i, :, 0])
plt.subplot(3,1,2)
for i in range(nbatch): plt.plot(sensordata[i, :, 1])
plt.subplot(3,1,3)
for i in range(nbatch): plt.plot(sensordata[i, :, 2])

### Example: different models
One hundred gray tops are kind of boring. It would be better if they were colorful and had different sizes!

`rollout` supports using a different model for each rollout, so long as the models have compatible dimensions. Let's simulate 100 tippe tops with the same initial condition, but different sizes and colors.

**Note:** Strictly speaking, the models must have the same number of states, controls, degrees of freedom, and sensor outputs. The most common use case is multiple models of the same thing up to parameter values.
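
As a rough illustration of the note above, the sketch below compares a handful of size fields across a list of models. The helper `incompatible_fields` is hypothetical, and the choice of fields (`nq`, `nv`, `na`, `nu`, `nsensordata`) is an assumption about which sizes matter most; `rollout` performs its own checks. `SimpleNamespace` objects stand in for `MjModel` instances so the snippet runs on its own.

```python
from types import SimpleNamespace

# Assumed set of size fields to compare; adjust as needed.
SIZE_FIELDS = ("nq", "nv", "na", "nu", "nsensordata")


def incompatible_fields(models, fields=SIZE_FIELDS):
  """Returns the names of size fields that differ from models[0]."""
  reference = models[0]
  return [
      f for f in fields
      if any(getattr(m, f) != getattr(reference, f) for m in models[1:])
  ]


# Stand-ins for MjModel objects (illustrative values only).
small_top = SimpleNamespace(nq=7, nv=6, na=0, nu=0, nsensordata=3)
large_top = SimpleNamespace(nq=7, nv=6, na=0, nu=0, nsensordata=3)
other = SimpleNamespace(nq=28, nv=27, na=0, nu=21, nsensordata=3)

print(incompatible_fields([small_top, large_top]))
print(incompatible_fields([small_top, other]))
```

Two tops that differ only in size and color report no mismatched fields, while mixing in a model with a different number of joints or actuators is flagged.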

In [None]:
# Make 100 tippe tops with different colors and sizes
nbatch = 100
spec = mujoco.MjSpec.from_string(tippe_top)
models = []
for i in range(nbatch):
  for geom in spec.geoms:
    if geom.name in ['ball', 'stem', 'ballast']:
      geom.rgba[:3] = np.random.rand(3)
    if geom.name == 'stem':
      stem_geom = geom
    if geom.name == 'ball':
      ball_geom = geom

  # Save original geom size
  stem_geom_size = np.copy(stem_geom.size)
  ball_geom_size = np.copy(ball_geom.size)

  # Scale geoms and compile model
  size_scale = 0.75*np.random.rand(1) + 0.5
  stem_geom.size *= size_scale
  ball_geom.size *= size_scale
  models.append(spec.compile())

  # Restore original geom size
  stem_geom.size = stem_geom_size
  ball_geom.size = ball_geom_size

# Reset the initial state
top_data = init_top(top_model)

# Run the rollout
start = time.time()
state, sensordata = rollout.rollout(models, [copy.copy(top_data) for _ in range(nthread)], # Create one MjData per thread
                                    get_state(top_model, top_data), nstep=int(1.5*top_nstep))
end = time.time()

# Render video
start_render = time.time()
framerate = 60
cam = mujoco.MjvCamera()
mujoco.mjv_defaultCamera(cam)
cam.distance = 0.2
cam.azimuth = 135
cam.elevation = -25
cam.lookat = [0, 0, 0.07]
models[0].vis.global_.fovy = 60
frames = render_many(models, top_data, state, framerate, shift_joint='top', spacing=[-0.05, 0.05], camera=cam)
media.show_video(frames, fps=framerate)
end_render = time.time()

print(f'Rollout took {end-start:.1f} seconds')
print(f'Rendering took {end_render-start_render:.1f} seconds')

Because the models are now different, the measurements of the gyro sensor are not consistent even though the initial state for each rollout was the same.

In [None]:
plt.subplot(3,1,1)
for i in range(nbatch): plt.plot(sensordata[i, :, 0])
plt.subplot(3,1,2)
for i in range(nbatch): plt.plot(sensordata[i, :, 1])
plt.subplot(3,1,3)
for i in range(nbatch): plt.plot(sensordata[i, :, 2])

### Example: control inputs
Open loop controls can be passed to `rollout` via the `control` argument. If passed, `nstep` no longer needs to be specified as it can be inferred from the size of `control`.

Below we simulate 100 of the flailing humanoids from the [tutorial notebook](https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/python/tutorial.ipynb). Each humanoid uses a different control signal.

In [None]:
# Episode parameters.
duration = 3       # (seconds)
framerate = 120     # (Hz)
humanoid_data.qvel[2] = 4   # Initial vertical velocity (m/s)
ctrl_phase = 2 * np.pi * np.random.rand(humanoid_model.nu)  # Control phase
ctrl_freq = 1     # Control frequency

# Generate 100 different controls
nbatch = 100
nstep = int(duration / humanoid_model.opt.timestep)
times = np.arange(nstep) * humanoid_model.opt.timestep  # Time at each step, always nstep entries
control = np.sin((2 * np.pi * times * ctrl_freq).reshape(nstep, 1) + ctrl_phase.reshape(1, humanoid_model.nu))
control = np.stack([control]*nbatch, axis=0)
control += np.random.normal(size=control.shape)

# Initialize the model
humanoid_data = init_humanoid(humanoid_model)

# Run the rollout
start = time.time()
state, _ = rollout.rollout(humanoid_model, [copy.copy(humanoid_data) for _ in range(nthread)],
                           get_state(humanoid_model, humanoid_data), control)
end = time.time()

# Render the rollout
start_render = time.time()
framerate=120
cam = mujoco.MjvCamera()
mujoco.mjv_defaultCamera(cam)
cam.distance = 3.5
cam.azimuth = 132.5
cam.elevation = -45
cam.lookat = [0, 0, 3.0]
humanoid_model.vis.global_.fovy = 60
frames = render_many(humanoid_model, humanoid_data, state, framerate, shift_joint='root', spacing=[-1.0, 1.0], camera=cam)
media.show_video(frames, fps=framerate/2) # Show the video at half speed
end_render = time.time()

print(f'Rollout took {end-start:.1f} seconds')
print(f'Rendering took {end_render-start_render:.1f} seconds')

`rollout`'s `control_spec` argument can be used to indicate `control` contains values for actuators, generalized forces, cartesian forces, mocap poses, and/or the activation/deactivation of equality constraints. Internally, this is managed through [mj_setState](https://mujoco.readthedocs.io/en/stable/APIreference/APIfunctions.html#mj-setstate) and `control_spec` corresponds to `mj_setState`'s `spec` argument.

Let's try applying cartesian forces in addition to the control inputs. This will make the humanoids look like they are being dragged while waving their limbs.
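
The layout of the combined `control` array can be sketched with NumPy alone. The sizes and the body index below are hypothetical placeholders. The sketch assumes the actuator controls come first and the flattened `xfrc_applied` block follows, with six values per body (three force components, then three torque components), so the force on body `b` starts `6*b` entries into that block.

```python
import numpy as np

nbatch, nstep = 4, 10
nu, nbody = 21, 17  # Hypothetical sizes
head_id = 2         # Hypothetical body index

ctrl = np.zeros((nbatch, nstep, nu))

# Keep a (nbody, 6) view per step so indexing by body is explicit.
xfrc = np.zeros((nbatch, nstep, nbody, 6))
# A constant-in-time horizontal (x, y) force, different for each rollout.
xfrc[:, :, head_id, :2] = np.random.normal(scale=150.0, size=(nbatch, 1, 2))

# Flatten the per-body block and append it after the actuator controls.
control = np.concatenate(
    (ctrl, xfrc.reshape(nbatch, nstep, 6 * nbody)), axis=2)

# Column where the chosen body's force begins in the combined array.
start = nu + 6 * head_id
print(control.shape, start)
```

Writing through the 4-D view and flattening afterwards avoids computing flat offsets by hand.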

In [None]:
xfrc = np.zeros((control.shape[0], control.shape[1], mujoco.mj_stateSize(humanoid_model, mujoco.mjtState.mjSTATE_XFRC_APPLIED)))
head_id = humanoid_model.body('head').id

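# Apply a constant but different force to each model.
# xfrc_applied stores 6 values per body (force, then torque), so the head's force starts at 6*head_id.
xfrc[:, :, 6*head_id:6*head_id+2] = np.random.normal(scale=150.0, size=(control.shape[0], 1, 2))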

humanoid_data = init_humanoid(humanoid_model)

control_xfrc = np.concatenate((control, xfrc), axis=2)
control_spec = mujoco.mjtState.mjSTATE_CTRL.value + mujoco.mjtState.mjSTATE_XFRC_APPLIED.value

start = time.time()
state, _ = rollout.rollout(humanoid_model, [copy.copy(humanoid_data) for _ in range(nthread)],
                           get_state(humanoid_model, humanoid_data), control_xfrc, control_spec=control_spec)
end = time.time()

start_render = time.time()
frames = render_many(humanoid_model, humanoid_data, state, framerate, shift_joint='root', spacing=[-1.0, 1.0], camera=cam)
media.show_video(frames, fps=framerate/2) # Show the video at half speed
end_render = time.time()

print(f'Rollout took {end-start:.1f} seconds')
print(f'Rendering took {end_render-start_render:.1f} seconds')

# Application: `rollout` + `minimize.least_squares`

`rollout` can easily be used with MuJoCo's nonlinear least squares utility, `minimize.least_squares`. Because `minimize` uses finite differencing to estimate Jacobians, it benefits greatly from multithreaded rollouts.
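
To see why finite differencing pairs well with batched rollouts, consider the sketch below. It is not `minimize`'s implementation: `fd_jacobian` and `toy_residual` are hypothetical, and the only assumption is that the residual accepts several candidate points as columns and evaluates them in one call. All perturbed points are then independent, which is exactly the kind of work a batched, multithreaded rollout can absorb.

```python
import numpy as np


def fd_jacobian(residual, x, eps=1e-6):
  """Forward-difference Jacobian from a single batched residual call.

  residual maps an (n, k) array of k column points to an (m, k) array.
  """
  n = x.size
  # Column 0 is the nominal point; column j + 1 perturbs coordinate j.
  points = np.tile(x.reshape(n, 1), (1, n + 1))
  points[np.arange(n), np.arange(1, n + 1)] += eps
  r = residual(points)  # One batched evaluation, shape (m, n + 1)
  return (r[:, 1:] - r[:, :1]) / eps


def toy_residual(points):
  # Hypothetical residual; each column is evaluated independently.
  return np.vstack((points[0] * points[1], np.sin(points[0])))


x = np.array([1.0, 2.0])
jac = fd_jacobian(toy_residual, x)
print(jac)
```

For an `n`-dimensional decision variable this needs `n + 1` residual evaluations per Jacobian, so the batch handed to the residual grows with the problem size.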

As an example let's consider the "reach" sample from the [least squares notebook](https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/python/least_squares.ipynb). The code is copied here with a small modification that allows multithreading.

The goal is for the humanoid to reach a target with one of its hands. By default, the humanoid does a jump, but does not reach its hand to the target.

In [None]:
#@title Humanoid Reaching XML and code
xml = """
<mujoco model="Humanoid">
  <option timestep="0.005"/>

  <visual>
    <map force="0.1" zfar="30"/>
    <rgba haze="0.15 0.25 0.35 1"/>
    <global offwidth="2560" offheight="1440" elevation="-20" azimuth="120"/>
  </visual>

  <statistic center="0 0 0.7"/>

  <asset>
    <texture type="skybox" builtin="gradient" rgb1=".3 .5 .7" rgb2="0 0 0" width="32" height="512"/>
    <texture name="body" type="cube" builtin="flat" mark="cross" width="128" height="128" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" markrgb="1 1 1" random="0.01"/>
    <material name="body" texture="body" texuniform="true" rgba="0.8 0.6 .4 1"/>
    <texture name="grid" type="2d" builtin="checker" width="512" height="512" rgb1=".1 .2 .3" rgb2=".2 .3 .4"/>
    <material name="grid" texture="grid" texrepeat="1 1" texuniform="true" reflectance=".2"/>
  </asset>

  <default>
    <position inheritrange="0.95"/>
    <default class="body">

      <!-- geoms -->
      <geom type="capsule" condim="1" friction=".7" solimp="0.9 .99 .003" solref=".015 1" material="body" group="1"/>
      <default class="thigh">
        <geom size=".06"/>
      </default>
      <default class="shin">
        <geom fromto="0 0 0 0 0 -.3"  size=".049"/>
      </default>
      <default class="foot">
        <geom size=".027"/>
        <default class="foot1">
          <geom fromto="-.07 -.01 0 .14 -.03 0"/>
        </default>
        <default class="foot2">
          <geom fromto="-.07 .01 0 .14  .03 0"/>
        </default>
      </default>
      <default class="arm_upper">
        <geom size=".04"/>
      </default>
      <default class="arm_lower">
        <geom size=".031"/>
      </default>
      <default class="hand">
        <geom type="sphere" size=".04"/>
      </default>

      <!-- joints -->
      <joint type="hinge" damping=".2" stiffness="1" armature=".01" limited="true" solimplimit="0 .99 .01"/>
      <default class="joint_big">
        <joint damping="5" stiffness="10"/>
        <default class="hip_x">
          <joint range="-30 10"/>
        </default>
        <default class="hip_z">
          <joint range="-60 35"/>
        </default>
        <default class="hip_y">
          <joint axis="0 1 0" range="-150 20"/>
        </default>
        <default class="joint_big_stiff">
          <joint stiffness="20"/>
        </default>
      </default>
      <default class="knee">
        <joint pos="0 0 .02" axis="0 -1 0" range="-160 2"/>
      </default>
      <default class="ankle">
        <joint range="-50 50"/>
        <default class="ankle_y">
          <joint pos="0 0 .08" axis="0 1 0" stiffness="6"/>
        </default>
        <default class="ankle_x">
          <joint pos="0 0 .04" stiffness="3"/>
        </default>
      </default>
      <default class="shoulder">
        <joint range="-85 60"/>
      </default>
      <default class="elbow">
        <joint range="-100 50" stiffness="0"/>
      </default>
    </default>
  </default>

  <worldbody>
    <body name="target" pos=".2 -.2 1" mocap="true">
      <site name="target" size=".05" rgba="1 0 1 .4"/>
    </body>
    <geom name="floor" size="0 0 .05" type="plane" material="grid" condim="3"/>
    <light name="spotlight" mode="targetbodycom" target="torso" diffuse=".8 .8 .8" specular="0.3 0.3 0.3" pos="0 -6 4" cutoff="30"/>
    <body name="torso" pos="0 0 1.282" childclass="body">
      <light name="top" pos="0 0 2" mode="trackcom"/>
      <camera name="back" pos="-3 0 1" xyaxes="0 -1 0 1 0 2" mode="trackcom"/>
      <camera name="side" pos="0 -3 1" xyaxes="1 0 0 0 1 2" mode="trackcom"/>
      <freejoint name="root"/>
      <geom name="torso" fromto="0 -.07 0 0 .07 0" size=".07"/>
      <geom name="waist_upper" fromto="-.01 -.06 -.12 -.01 .06 -.12" size=".06"/>
      <body name="head" pos="0 0 .19">
        <geom name="head" type="sphere" size=".09"/>
        <camera name="egocentric" pos=".09 0 0" xyaxes="0 -1 0 .1 0 1" fovy="80"/>
      </body>
      <body name="waist_lower" pos="-.01 0 -.26">
        <geom name="waist_lower" fromto="0 -.06 0 0 .06 0" size=".06"/>
        <joint name="abdomen_z" pos="0 0 .065" axis="0 0 1" range="-45 45" class="joint_big_stiff"/>
        <joint name="abdomen_y" pos="0 0 .065" axis="0 1 0" range="-75 30" class="joint_big"/>
        <body name="pelvis" pos="0 0 -.165">
          <joint name="abdomen_x" pos="0 0 .1" axis="1 0 0" range="-35 35" class="joint_big"/>
          <geom name="butt" fromto="-.02 -.07 0 -.02 .07 0" size=".09"/>
          <body name="thigh_right" pos="0 -.1 -.04">
            <joint name="hip_x_right" axis="1 0 0" class="hip_x"/>
            <joint name="hip_z_right" axis="0 0 1" class="hip_z"/>
            <joint name="hip_y_right" class="hip_y"/>
            <geom name="thigh_right" fromto="0 0 0 0 .01 -.34" class="thigh"/>
            <body name="shin_right" pos="0 .01 -.4">
              <joint name="knee_right" class="knee"/>
              <geom name="shin_right" class="shin"/>
              <body name="foot_right" pos="0 0 -.39">
                <joint name="ankle_y_right" class="ankle_y"/>
                <joint name="ankle_x_right" class="ankle_x" axis="1 0 .5"/>
                <geom name="foot1_right" class="foot1"/>
                <geom name="foot2_right" class="foot2"/>
              </body>
            </body>
          </body>
          <body name="thigh_left" pos="0 .1 -.04">
            <joint name="hip_x_left" axis="-1 0 0" class="hip_x"/>
            <joint name="hip_z_left" axis="0 0 -1" class="hip_z"/>
            <joint name="hip_y_left" class="hip_y"/>
            <geom name="thigh_left" fromto="0 0 0 0 -.01 -.34" class="thigh"/>
            <body name="shin_left" pos="0 -.01 -.4">
              <joint name="knee_left" class="knee"/>
              <geom name="shin_left" fromto="0 0 0 0 0 -.3" class="shin"/>
              <body name="foot_left" pos="0 0 -.39">
                <joint name="ankle_y_left" class="ankle_y"/>
                <joint name="ankle_x_left" class="ankle_x" axis="-1 0 -.5"/>
                <geom name="foot1_left" class="foot1"/>
                <geom name="foot2_left" class="foot2"/>
              </body>
            </body>
          </body>
        </body>
      </body>
      <body name="upper_arm_right" pos="0 -.17 .06">
        <joint name="shoulder1_right" axis="2 1 1"  class="shoulder"/>
        <joint name="shoulder2_right" axis="0 -1 1" class="shoulder"/>
        <geom name="upper_arm_right" fromto="0 0 0 .16 -.16 -.16" class="arm_upper"/>
        <body name="lower_arm_right" pos=".18 -.18 -.18">
          <joint name="elbow_right" axis="0 -1 1" class="elbow"/>
          <geom name="lower_arm_right" fromto=".01 .01 .01 .17 .17 .17" class="arm_lower"/>
          <body name="hand_right" pos=".18 .18 .18">
            <geom name="hand_right" zaxis="1 1 1" class="hand" rgba="1 0 1 1"/>
          </body>
        </body>
      </body>
      <body name="upper_arm_left" pos="0 .17 .06">
        <joint name="shoulder1_left" axis="-2 1 -1" class="shoulder"/>
        <joint name="shoulder2_left" axis="0 -1 -1"  class="shoulder"/>
        <geom name="upper_arm_left" fromto="0 0 0 .16 .16 -.16" class="arm_upper"/>
        <body name="lower_arm_left" pos=".18 .18 -.18">
          <joint name="elbow_left" axis="0 -1 -1" class="elbow"/>
          <geom name="lower_arm_left" fromto=".01 -.01 .01 .17 -.17 .17" class="arm_lower"/>
          <body name="hand_left" pos=".18 -.18 .18">
            <geom name="hand_left" zaxis="1 -1 1" class="hand"/>
          </body>
        </body>
      </body>
    </body>
  </worldbody>

  <contact>
    <exclude body1="waist_lower" body2="thigh_right"/>
    <exclude body1="waist_lower" body2="thigh_left"/>
  </contact>

  <tendon>
    <fixed name="hamstring_right" limited="true" range="-0.3 2">
      <joint joint="hip_y_right" coef=".5"/>
      <joint joint="knee_right" coef="-.5"/>
    </fixed>
    <fixed name="hamstring_left" limited="true" range="-0.3 2">
      <joint joint="hip_y_left" coef=".5"/>
      <joint joint="knee_left" coef="-.5"/>
    </fixed>
  </tendon>

  <actuator>
    <position name="abdomen_z"       kp="40"  joint="abdomen_z"/>
    <position name="abdomen_y"       kp="40"  joint="abdomen_y"/>
    <position name="abdomen_x"       kp="40"  joint="abdomen_x"/>
    <position name="hip_x_right"     kp="40"  joint="hip_x_right"/>
    <position name="hip_z_right"     kp="40"  joint="hip_z_right"/>
    <position name="hip_y_right"     kp="120" joint="hip_y_right"/>
    <position name="knee_right"      kp="80"  joint="knee_right"/>
    <position name="ankle_y_right"   kp="20"  joint="ankle_y_right"/>
    <position name="ankle_x_right"   kp="20"  joint="ankle_x_right"/>
    <position name="hip_x_left"      kp="40"  joint="hip_x_left"/>
    <position name="hip_z_left"      kp="40"  joint="hip_z_left"/>
    <position name="hip_y_left"      kp="120" joint="hip_y_left"/>
    <position name="knee_left"       kp="80"  joint="knee_left"/>
    <position name="ankle_y_left"    kp="20"  joint="ankle_y_left"/>
    <position name="ankle_x_left"    kp="20"  joint="ankle_x_left"/>
    <position name="shoulder1_right" kp="20"  joint="shoulder1_right"/>
    <position name="shoulder2_right" kp="20"  joint="shoulder2_right"/>
    <position name="elbow_right"     kp="40"  joint="elbow_right"/>
    <position name="shoulder1_left"  kp="20"  joint="shoulder1_left"/>
    <position name="shoulder2_left"  kp="20"  joint="shoulder2_left"/>
    <position name="elbow_left"      kp="40"  joint="elbow_left"/>
  </actuator>

  <sensor>
    <framepos objtype="geom" objname="hand_right" reftype="xbody" refname="target"/>
    <actuatorfrc actuator="abdomen_z"/>
    <actuatorfrc actuator="abdomen_y"/>
    <actuatorfrc actuator="abdomen_x"/>
    <actuatorfrc actuator="hip_x_right"/>
    <actuatorfrc actuator="hip_z_right"/>
    <actuatorfrc actuator="hip_y_right"/>
    <actuatorfrc actuator="knee_right"/>
    <actuatorfrc actuator="ankle_y_right"/>
    <actuatorfrc actuator="ankle_x_right"/>
    <actuatorfrc actuator="hip_x_left"/>
    <actuatorfrc actuator="hip_z_left"/>
    <actuatorfrc actuator="hip_y_left"/>
    <actuatorfrc actuator="knee_left"/>
    <actuatorfrc actuator="ankle_y_left"/>
    <actuatorfrc actuator="ankle_x_left"/>
    <actuatorfrc actuator="shoulder1_right"/>
    <actuatorfrc actuator="shoulder2_right"/>
    <actuatorfrc actuator="elbow_right"/>
    <actuatorfrc actuator="shoulder1_left"/>
    <actuatorfrc actuator="shoulder2_left"/>
    <actuatorfrc actuator="elbow_left"/>
  </sensor>

  <keyframe>
    <!--
    The values below are split into rows for readability:
      torso position
      torso orientation
      spinal
      right leg
      left leg
      arms
    -->
    <key name="squat"
         qpos="0 0 0.596
               0.988015 0 0.154359 0
               0 0.4 0
               -0.25 -0.5 -2.5 -2.65 -0.8 0.56
               -0.25 -0.5 -2.5 -2.65 -0.8 0.56
               0 0 0 0 0 0"/>
    <key name="stand_on_left_leg"
         qpos="0 0 1.21948
               0.971588 -0.179973 0.135318 -0.0729076
               -0.0516 -0.202 0.23
               -0.24 -0.007 -0.34 -1.76 -0.466 -0.0415
               -0.08 -0.01 -0.37 -0.685 -0.35 -0.09
               0.109 -0.067 -0.7 -0.05 0.12 0.16"/>
    <key name="prone"
         qpos="0.4 0 0.0757706
               0.7325 0 0.680767 0
               0 0.0729 0
               0.0077 0.0019 -0.026 -0.351 -0.27 0
               0.0077 0.0019 -0.026 -0.351 -0.27 0
               0.56 -0.62 -1.752
               0.56 -0.62 -1.752"/>
    <key name="supine"
         qpos="-0.4 0 0.08122
               0.722788 0 -0.69107 0
               0 -0.25 0
               0.0182 0.0142 0.3 0.042 -0.44 -0.02
               0.0182 0.0142 0.3 0.042 -0.44 -0.02
               0.186 -0.73 -1.73
               0.186 -0.73 -1.73"/>
  </keyframe>
</mujoco>
"""

# Load model, make data, make list of data for multithreading
model = mujoco.MjModel.from_xml_string(xml)
data = mujoco.MjData(model)
data_list = [mujoco.MjData(model) for _ in range(nthread)]

# Set the state to the "squat" keyframe, call mj_forward.
key = model.key('squat').id
mujoco.mj_resetDataKeyframe(model, data, key)
mujoco.mj_forward(model, data)

# If a renderer exists, close it.
if 'renderer' in locals():
  renderer.close()

# Make a Renderer and a camera.
renderer = mujoco.Renderer(model)
camera = mujoco.MjvCamera()
mujoco.mjv_defaultFreeCamera(model, camera)
camera.distance = 3
camera.elevation = -10

# Point the camera at the humanoid, render.
# camera.lookat = data.body('torso').subtree_com
# renderer.update_scene(data, camera)
# media.show_image(renderer.render())

def reach(ctrl0T, target, T, torque_scale, traj=None, multithread=False):
  """Residual for target-reaching task.

  Args:
    ctrl0T: concatenation of the first and last control vectors.
    target: target to which the right hand should reach.
    T: final time for the rollout.
    torque_scale: coefficient by which to scale the torques.
    traj: optional list of positions to be recorded.
    multithread: if True, roll out using `data_list` (one MjData per thread).

  Returns:
    The residual of the target-reaching task.
  """
  # Extract the initial and final ctrl vectors, transpose to row vectors
  ctrl0 = ctrl0T[:model.nu, :].T
  ctrlT = ctrl0T[model.nu:, :].T

  # Move the mocap body to the target
  mocapid = model.body('target').mocapid
  data.mocap_pos[mocapid] = target

  # Append the mocap targets to the controls
  nroll  = ctrl0.shape[0]
  mocap = np.tile(data.mocap_pos[mocapid], (nroll, 1))
  ctrl0 = np.hstack((ctrl0, mocap))
  ctrlT = np.hstack((ctrlT, mocap))

  # Define control spec (ctrl + mocap_pos)
  mjtState = mujoco.mjtState
  control_spec = mjtState.mjSTATE_CTRL | mjtState.mjSTATE_MOCAP_POS

  # Interpolate and stack the control sequences
  nstep = int(np.round(T / model.opt.timestep))
  control = np.stack(np.linspace(ctrl0, ctrlT, nstep), axis=1)

  if not multithread:
    datas = [data]
  else:
    datas = data_list

  # Reset to the "squat" keyframe, get the initial state
  for d in datas:
    key = model.key('squat').id
    mujoco.mj_resetDataKeyframe(model, d, key)
    spec = mjtState.mjSTATE_FULLPHYSICS
    nstate = mujoco.mj_stateSize(model, spec)
    state = np.empty(nstate)
    mujoco.mj_getState(model, d, state, spec)

  # Perform rollouts (sensors.shape == nroll, nstep, nsensordata)
  states, sensors = rollout.rollout(model, datas, state, control,
                                    control_spec=control_spec)

  # If requested, extract qpos into traj
  if traj is not None:
    assert states.shape[0] == 1
    # Skip the first element in state (mjData.time)
    traj.extend(np.split(states[0, :, 1:model.nq+1], nstep))

  # Scale torque sensors
  sensors[:, :, 3:] *= torque_scale

  # Reshape to stack the sensor values, transpose to column vectors
  sensors = sensors.reshape((sensors.shape[0], -1)).T

  # The normalizer keeps objective values similar when changing T or timestep.
  normalizer = 100 * model.opt.timestep / T
  return normalizer * sensors

def render_solution(x, target):
  # Ask reach to save positions to traj.
  traj = []
  reach(x, target, T, torque_scale, traj=traj)

  frames = []
  counter = 0
  print('Rendering frames:', flush=True, end='')
  for qpos in traj:
    # Set positions, call mj_forward to update kinematics.
    data.qpos = qpos
    mujoco.mj_forward(model, data)

    # Render and save frames.
    camera.lookat = data.body('torso').subtree_com
    renderer.update_scene(data, camera)
    pixels = renderer.render()
    frames.append(pixels)
    counter += 1
    if counter % 10 == 0:
      print(f' {counter}', flush=True, end='')
  return frames

# Settings for the optimization
T = 0.7               # Rollout length (seconds)
torque_scale = 0.003  # Scaling for the torques

# Bounds are the stacked control bounds.
lower = np.atleast_2d(model.actuator_ctrlrange[:,0]).T
upper = np.atleast_2d(model.actuator_ctrlrange[:,1]).T
bounds = [np.vstack((lower, lower)), np.vstack((upper, upper))]

# Initial guess is midpoint of the bounds
x0 = 0.5 * (bounds[1] + bounds[0])
target = (.4, -.3, 1.2)

# Use default target.
target = data.mocap_pos[model.body('target').mocapid]

# Visualize the initial guess.
media.show_video(render_solution(x0, target))

Next, let's run the optimization in single-threaded and multithreaded modes and render the resulting solutions.

In [None]:
reach_target = lambda x: reach(x, target, T, torque_scale, traj=None, multithread=False)
reach_target_multithread = lambda x: reach(x, target, T, torque_scale, traj=None, multithread=True)

print('Using 1 thread')
x_single, _ = minimize.least_squares(x0, reach_target, bounds, verbose=minimize.Verbosity.FINAL)
print()

print(f'Using {nthread} threads')
x_multi, _ = minimize.least_squares(x0, reach_target_multithread, bounds, verbose=minimize.Verbosity.FINAL)

# Render the solution to verify the results are the same
print()
frames_single = render_solution(x_single, target)
frames_multi = render_solution(x_multi, target)
media.show_video(np.concatenate((frames_single, frames_multi), axis=2))

By using multithreaded `rollout`, the minimization completed ~4x faster on a 5800X3D, and the results are the same.
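
For reference, the control sequence that `reach` passes to `rollout` is a linear interpolation between the first and last control vectors. The NumPy-only sketch below reproduces that step on made-up data; the sizes `nroll`, `ncontrol`, and `nstep` are illustrative assumptions, not values taken from the humanoid model.

```python
import numpy as np

# Illustrative sizes only (not taken from the humanoid model).
nroll, ncontrol, nstep = 2, 3, 5

# First and last control vectors, one row per rollout.
ctrl0 = np.zeros((nroll, ncontrol))
ctrlT = np.ones((nroll, ncontrol))

# np.linspace interpolates along a new leading axis: (nstep, nroll, ncontrol).
interp = np.linspace(ctrl0, ctrlT, nstep)

# Stacking the nstep slices along axis 1 gives the layout used for `control`
# in `reach`: (nroll, nstep, ncontrol).
control = np.stack(interp, axis=1)

print(control.shape)  # (2, 5, 3)
```

The first and last time slices of `control` equal `ctrl0` and `ctrlT`, so only the two endpoint vectors need to be optimized.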

# Benchmarking `rollout`

The `rollout.rollout` function in the `mujoco` Python library runs batches of simulations for a fixed number of steps. It can run in single-threaded or multithreaded mode. The speedup over pure Python is significant because `rollout` can be easily configured to use multithreading.

To show the speedup, we will run benchmarks with the "tippe top", "humanoid", and "humanoid100" models.

## Python rollouts versus `rollout`

The benchmark runs the three models with varying batch and step counts.
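
The plots in this section report throughput in simulation steps per second. As a minimal sketch of that metric (the helper names `steps_per_second` and `time_call` are hypothetical and are not part of this notebook or of MuJoCo):

```python
import time

def steps_per_second(nbatch, nstep, elapsed):
  """Total simulated steps divided by wall-clock seconds."""
  return nbatch * nstep / elapsed

def time_call(f):
  """Returns the wall-clock time in seconds taken by one call to f()."""
  start = time.perf_counter()
  f()
  return time.perf_counter() - start

# Example: 100 rollouts of 1000 steps that took 2 seconds in total.
print(steps_per_second(100, 1000, 2.0))  # 50000.0
```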

The Python code for nbatch rollouts of nstep steps is:

In [None]:
def python_rollout(model, init_model, nbatch, nstep):
  for _ in range(nbatch):
    data = init_model(model)
    for _ in range(nstep):
      mujoco.mj_step(model, data)

To run nbatch rollouts with `rollout`, we need to make an array of nbatch initial states to start the rollouts from.

Additionally, to use `rollout`'s parallelism, we must pass one MjData per thread.

The resulting `rollout` call parameterized by nbatch, nstep, and nthread is:

In [None]:
def nthread_rollout(model, init_model, nbatch, nstep, nthread):
  # Initialize the MjData for the given model using the provided initializer
  data = init_model(model)
  rollout.rollout(model,
                  [copy.copy(data) for _ in range(nthread)], # Create one MjData per thread
                  np.tile(get_state(model, data), (nbatch, 1)), # Tile the initial condition nbatch times
                  nstep=nstep)

Next, we benchmark the Python loop and `rollout` in both single-threaded and multithreaded modes. The three benchmarks take about 2.5 minutes in total to run on an AMD 5800X3D.

In [None]:
#@title Benchmarking and plotting code

def benchmark_rollout(model, data, init_model, nbatch, nstep, nominal_nbatch, nominal_nstep, ntiming=1):
  print('Benchmarking pure python', end='\r')
  start = time.time()
  t_python_nbatch = benchmark(lambda x: python_rollout(model, init_model, x,  nominal_nstep), nbatch, ntiming)
  t_python_nstep  = benchmark(lambda x: python_rollout(model, init_model, nominal_nbatch, x), nstep,  ntiming)
  end = time.time()
  print(f'Benchmarking pure python took {end-start:0.1f} seconds')

  print('Benchmarking single threaded rollout', end='\r')
  start = time.time()
  t_rollout_single_nbatch = benchmark(lambda x: nthread_rollout(model, init_model, x, nominal_nstep,  nthread=1), nbatch, ntiming)
  t_rollout_single_nstep  = benchmark(lambda x: nthread_rollout(model, init_model, nominal_nbatch, x, nthread=1), nstep,  ntiming)
  end = time.time()
  print(f'Benchmarking single threaded rollout took {end-start:0.1f} seconds')

  print(f'Benchmarking multithreaded rollout using {nthread} threads', end='\r')
  start = time.time()
  t_rollout_multi_nbatch = benchmark(lambda x: nthread_rollout(model, init_model, x, nominal_nstep,  nthread), nbatch, ntiming)
  t_rollout_multi_nstep  = benchmark(lambda x: nthread_rollout(model, init_model, nominal_nbatch, x, nthread), nstep,  ntiming)
  end = time.time()
  print(f'Benchmarking multithreaded rollout using {nthread} threads took {end-start:0.1f} seconds')

  return (t_python_nbatch, t_rollout_single_nbatch, t_rollout_multi_nbatch,
          t_python_nstep, t_rollout_single_nstep, t_rollout_multi_nstep)

def plot_benchmark(results, nbatch, nstep, nominal_nbatch, nominal_nstep):
  (t_python_nbatch, t_rollout_single_nbatch, t_rollout_multi_nbatch,
   t_python_nstep, t_rollout_single_nstep, t_rollout_multi_nstep) = results

  width = 0.25
  x = np.array([i for i in range(len(nbatch))])

  ticker = matplotlib.ticker.EngFormatter(unit='')

  fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
  steps_per_t = np.array(nbatch) * nominal_nstep
  steps_per_t_python = steps_per_t / t_python_nbatch
  steps_per_t_single = steps_per_t / t_rollout_single_nbatch
  steps_per_t_multi  = steps_per_t / t_rollout_multi_nbatch
  ax1.bar(x + 0*width, steps_per_t_python, width=width, label='python')
  ax1.bar(x + 1*width, steps_per_t_single, width=width, label='rollout single threaded')
  ax1.bar(x + 2*width, steps_per_t_multi, width=width, label='rollout multithreaded')
  ax1.set_xticks(x + width, nbatch)
  ax1.yaxis.set_major_formatter(ticker)
  ax1.grid()
  ax1.set_xlabel('nbatch')
  ax1.set_ylabel('steps per second')
  ax1.set_title(f'nbatch varied, nstep = {nominal_nstep}')

  x = np.array([i for i in range(len(nstep))])
  steps_per_t = np.array(nstep) * nominal_nbatch
  steps_per_t_python = steps_per_t / t_python_nstep
  steps_per_t_single = steps_per_t / t_rollout_single_nstep
  steps_per_t_multi  = steps_per_t / t_rollout_multi_nstep
  ax2.bar(x + 0*width, steps_per_t_python, width=width, label='python')
  ax2.bar(x + 1*width, steps_per_t_single, width=width, label='rollout single threaded')
  ax2.bar(x + 2*width, steps_per_t_multi, width=width, label='rollout multithreaded')
  ax2.set_xticks(x + width, nstep)
  ax2.yaxis.set_major_formatter(ticker)
  ax2.grid()
  ax2.set_xlabel('nstep')
  ax2.set_title(f'nstep varied, nbatch = {nominal_nbatch}')

  ax2.legend(loc=(1.04, 0.0))
  fig.set_size_inches(10, 4)
  plt.tight_layout()

### Tippe Top Benchmark

In [None]:
nominal_nbatch = 100 # Batch size to use when testing different nstep
nominal_nstep = 1000 # Step count to use when testing different nbatch
nbatch = [1, 10, 100, 500, 1000] # Batch sizes to benchmark
nstep = sorted([1, 10, 100, 1000, 2000, 4000]) # Step counts to benchmark

top_benchmark_results = benchmark_rollout(top_model, top_data, init_top, nbatch, nstep, nominal_nbatch, nominal_nstep)
plot_benchmark(top_benchmark_results, nbatch, nstep, nominal_nbatch, nominal_nstep)

### Humanoid Benchmark

In [None]:
nominal_nbatch = 200 # Batch size to use when testing different nstep
nominal_nstep = 500 # Step count to use when testing different nbatch
nbatch = [1, 10, 100, 200, 400] # Batch sizes to benchmark
nstep = sorted([1, 10, 100, 500, 1000]) # Step counts to benchmark

humanoid_benchmark_results = benchmark_rollout(humanoid_model, humanoid_data, init_humanoid, nbatch, nstep, nominal_nbatch, nominal_nstep)
plot_benchmark(humanoid_benchmark_results, nbatch, nstep, nominal_nbatch, nominal_nstep)

### Humanoid100 Benchmark

In [None]:
nominal_nbatch = 100 # Batch size to use when testing different nstep
nominal_nstep = 200 # Step count to use when testing different nbatch
nbatch = [1, 10, 50, 100, 200] # Batch sizes to benchmark
nstep = sorted([1, 10, 100, 200, 400]) # Step counts to benchmark

humanoid100_benchmark_results = benchmark_rollout(humanoid100_model, humanoid100_data, init_humanoid100, nbatch, nstep, nominal_nbatch, nominal_nstep)
plot_benchmark(humanoid100_benchmark_results, nbatch, nstep, nominal_nbatch, nominal_nstep)

## MJX versus `rollout`

Next we will benchmark `rollout` and MJX using the tippe top and humanoid models (humanoid100 is not supported by MJX).

The benchmark below takes about 5.5 minutes on an AMD 5800X3D and an NVIDIA 4090. Almost half the time is spent compiling the JIT functions. The JIT functions are cached so that subsequent runs of the benchmark run much faster.

**Note:** MJX is most useful when coupled with something else that runs best on a GPU, like a neural network. Without any such additional workload, CPU-based simulation will sometimes be faster, especially when using less than state-of-the-art GPUs. In the results below, the tippe top model runs faster on the 4090 with batch sizes in the thousands; the humanoid model, however, always runs slower on the GPU than on the CPU.

In [None]:
#@title MJX helper functions
def init_mjx_batch(model, init_model, nbatch, skip_jit=False):
  data = init_model(model)

  # Make MJX versions of model and data
  mjx_model = mjx.put_model(model)
  mjx_data = mjx.put_data(model, data)

  jit_step = jax.jit(jax.vmap(mjx.step, in_axes=(None, 0)))
  batch = jax.vmap(lambda x: mjx_data)(jp.array(list(range(nbatch))))

  # Trigger JIT for model/batch so as not to include JIT time in benchmarking information
  if not skip_jit:
    batch = jit_step(mjx_model, batch)

  return mjx_model, mjx_data, jit_step, batch

def mjx_rollout(model, init_model, nbatch, nstep, jit_step=None):
  # Initialize model, skip JIT of stepping function if possible
  if jit_step is None:
    mjx_model, _, jit_step, batch = init_mjx_batch(model, init_model, nbatch)
  else:
    mjx_model, _, _, batch = init_mjx_batch(model, init_model, nbatch, skip_jit=True)

  for _ in range(nstep):
    batch = jit_step(mjx_model, batch)

def benchmark_mjx(model, init_model, nbatch, nstep, nominal_nbatch, nominal_nstep, ntiming=1, jit_steps=None):
  print(f'Benchmarking multithreaded rollout using {nthread} threads', end="\r")
  start = time.time()
  t_rollout_multi_nbatch = benchmark(lambda x: nthread_rollout(model, init_model, x, nominal_nstep,  nthread), nbatch, ntiming)
  t_rollout_multi_nstep  = benchmark(lambda x: nthread_rollout(model, init_model, nominal_nbatch, x, nthread), nstep,  ntiming)
  end = time.time()
  print(f'Benchmarking multithreaded rollout using {nthread} threads took {end-start:0.1f} seconds')

  print('Running JIT for MJX', end='\r')
  start = time.time()
  if jit_steps is None: jit_steps = {}
  for n in nbatch + [nominal_nbatch,]:
    if n not in jit_steps:
      _, _, jit_steps[n], _ = init_mjx_batch(model, init_model, n)
  end = time.time()
  print(f'Running JIT for MJX took {end-start:0.1f} seconds')

  print('Benchmarking MJX', end='\r')
  start = time.time()
  t_mjx_nbatch = benchmark(lambda x: mjx_rollout(model, init_model, x, nominal_nstep, jit_steps[x]), nbatch, ntiming)
  t_mjx_nstep  = benchmark(lambda x: mjx_rollout(model, init_model, nominal_nbatch, x, jit_steps[nominal_nbatch]), nstep, ntiming)
  end = time.time()
  print(f'Benchmarking MJX took {end-start:0.1f} seconds')

  return t_rollout_multi_nbatch, t_rollout_multi_nstep, t_mjx_nbatch, t_mjx_nstep

def plot_mjx_benchmark(results, nbatch, nstep, nominal_nbatch, nominal_nstep):
  t_rollout_multi_nbatch, t_rollout_multi_nstep, t_mjx_nbatch, t_mjx_nstep = results

  width = 0.333
  x = np.array([i for i in range(len(nbatch))])

  ticker = matplotlib.ticker.EngFormatter(unit='')

  fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
  steps_per_t = np.array(nbatch) * nominal_nstep
  steps_per_t_mjx = steps_per_t / t_mjx_nbatch
  steps_per_t_multi  = steps_per_t / t_rollout_multi_nbatch
  ax1.bar(x + 0*width, steps_per_t_mjx, width=width, label='mjx')
  ax1.bar(x + 1*width, steps_per_t_multi, width=width, label='rollout multithreaded')
  ax1.set_xticks(x + width / 2, nbatch)
  ax1.yaxis.set_major_formatter(ticker)
  ax1.grid()
  ax1.set_xlabel('nbatch')
  ax1.set_ylabel('steps per second')
  ax1.set_title(f'nbatch varied, nstep = {nominal_nstep}')

  x = np.array([i for i in range(len(nstep))])
  steps_per_t = np.array(nstep) * nominal_nbatch
  steps_per_t_mjx = steps_per_t / t_mjx_nstep
  steps_per_t_multi  = steps_per_t / t_rollout_multi_nstep
  ax2.bar(x + 0*width, steps_per_t_mjx, width=width, label='mjx')
  ax2.bar(x + 1*width, steps_per_t_multi, width=width, label='rollout multithreaded')
  ax2.set_xticks(x + width / 2, nstep)
  ax2.yaxis.set_major_formatter(ticker)
  ax2.grid()
  ax2.set_xlabel('nstep')
  ax2.set_title(f'nstep varied, nbatch = {nominal_nbatch}')

  ax2.legend(loc=(1.04, 0.0))
  fig.set_size_inches(10, 4)
  plt.tight_layout()

# Caches for jit_step functions, they take a long time to compile
top_jit_steps = {}
humanoid_jit_steps = {}
humanoid100_jit_steps = {}

### MJX Tippe Top Benchmark

In [None]:
nominal_nbatch = 50000 # Batch size to use when testing different nstep
nominal_nstep = 200 # Step count to use when testing different nbatch
nbatch = [100, 1000, 10000, 50000, 100000] # Batch sizes to benchmark
nstep = [1, 10, 100, 200] # Step counts to benchmark

mjx_top_results = benchmark_mjx(top_model, init_top, nbatch, nstep, nominal_nbatch, nominal_nstep, jit_steps=top_jit_steps)
plot_mjx_benchmark(mjx_top_results, nbatch, nstep, nominal_nbatch, nominal_nstep)

### MJX Humanoid Benchmark

In [None]:
nominal_nbatch = 10000 # Batch size to use when testing different nstep
nominal_nstep = 200 # Step count to use when testing different nbatch
nbatch = [100, 1000, 10000, 30000] # Batch sizes to benchmark
nstep = [1, 10, 100, 200, 400] # Step counts to benchmark

mjx_humanoid_results = benchmark_mjx(humanoid_model, init_humanoid, nbatch, nstep, nominal_nbatch, nominal_nstep, jit_steps=humanoid_jit_steps)
plot_mjx_benchmark(mjx_humanoid_results, nbatch, nstep, nominal_nbatch, nominal_nstep)

### MJX Multiple Humanoids in One Model

The MJX [documentation](https://mujoco.readthedocs.io/en/stable/mjx.html#mjx-the-sharp-bits) contains a chart comparing the speed of native MuJoCo vs MJX on a variety of devices.

Here we will produce a similar plot to compare MJX with `rollout`. On a 5800X3D and a 4090, the benchmark takes about 6.5 minutes to run.

**Note:** These results are not directly comparable with the plot in the documentation, which was produced on different devices, in particular an A100. Additionally, to run on a 4090, the batch size was reduced from 8192 to 4096.

In [None]:
max_humanoids = 10
nbatch = 8192 // 2 # The original benchmark ran with a batch size of 8192, but on a 4090 we can only fit about 4096 humanoids
nstep = 200

jit_step = jax.jit(jax.vmap(mjx.step, in_axes=(None, 0)))
t_rollout = []
t_mjx = []
for i in range(1, max_humanoids+1):
  print(f'Running benchmark on {i} humanoids')
  model = mujoco.MjModel.from_xml_path(f'mujoco/mjx/mujoco/mjx/test_data/humanoid/{i:02d}_humanoids.xml')
  data = mujoco.MjData(model)

  mjx_model = mjx.put_model(model)
  mjx_data = mjx.put_data(model, data)
  batch = jax.vmap(lambda x: mjx_data)(jp.array(list(range(nbatch))))

  start = time.perf_counter()
  rollout.rollout(model, [copy.copy(data) for _ in range(nthread)], initial_state=get_state(model, data, nbatch), nstep=nstep)
  end = time.perf_counter()
  t_rollout.append(end-start)

  # Trigger JIT for model/batch so as not to include JIT time in benchmarking information
  batch = jit_step(mjx_model, batch)

  start = time.perf_counter()
  for _ in range(nstep):
    batch = jit_step(mjx_model, batch)
  end = time.perf_counter()
  t_mjx.append(end-start)

In [None]:
#@title Plot MJX nhumanoid benchmark

def plot_mjx_nhumanoid_benchmark(t_rollout, t_mjx, nbatch, nstep, max_humanoids):
  nhumanoids = [i for i in range(1, max_humanoids+1)]

  width = 0.333
  x = np.array([i for i in range(len(nhumanoids))])

  ticker = matplotlib.ticker.EngFormatter(unit='')

  fig, ax1 = plt.subplots(1, 1, sharey=True)
  steps_per_t = nbatch * nstep
  steps_per_t_mjx = steps_per_t / np.array(t_mjx)
  steps_per_t_multi  = steps_per_t / np.array(t_rollout)
  ax1.bar(x + 0*width, steps_per_t_mjx, width=width, label='mjx')
  ax1.bar(x + 1*width, steps_per_t_multi, width=width, label='rollout multithreaded')
  ax1.set_xticks(x + width / 2, nhumanoids)
  ax1.yaxis.set_major_formatter(ticker)
  ax1.set_yscale('log')
  ax1.grid()
  ax1.set_xlabel('number of humanoids')
  ax1.set_ylabel('steps per second')
  ax1.set_title(f'nhumanoids varied, nbatch = {nbatch}, nstep = {nstep}')

  ax1.legend(loc=(1.04, 0.0))
  fig.set_size_inches(8, 4)
  plt.tight_layout()

plot_mjx_nhumanoid_benchmark(t_rollout, t_mjx, nbatch, nstep, max_humanoids)

# Advanced Usage
## skip_checks=True

By default, `rollout` performs many checks on the dimensions of its arguments. This allows it to infer dimensions such as `nbatch` and `nstep`, tile arguments that were not fully specified, and allocate the returned `state` and `sensordata` arrays.

However, these checks take time, particularly if `state` and `sensordata` are large or if there are many models and `nstep` is low. Advanced users may therefore want to pass `skip_checks=True` to achieve additional performance.

If used, certain arguments become non-optional, and all signals must be fully defined (no implicit tiling). In particular:
* `model` must be a list of length `nbatch`
* `data` must be a list of length `nthread`
* `nstep` must be specified
* `initial_state` must be an array of shape `nbatch x nstate`
* `control` is optional, but if passed must be an array of shape `nbatch x nstep x ncontrol`
* `state` is optional, but must be passed if state is to be returned and must be of shape `nbatch x nstep x nstate`
* `sensordata` is optional, but must be passed if sensor data is to be returned and must be of shape `nbatch x nstep x nsensordata`
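
The shape requirements above can be sketched with NumPy as follows. The helper name `preallocate_outputs` and the sizes are illustrative assumptions; in practice `nstate` and `nsensordata` come from the model.

```python
import numpy as np

def preallocate_outputs(nbatch, nstep, nstate, nsensordata):
  """Allocates output arrays with the shapes listed above."""
  state = np.empty((nbatch, nstep, nstate))
  sensordata = np.empty((nbatch, nstep, nsensordata))
  return state, sensordata

# Illustrative sizes only.
nbatch, nstep, nstate, nsensordata = 4, 10, 7, 3

# A single initial state must be tiled explicitly to nbatch x nstate.
initial_state = np.zeros(nstate)
initial_state_tiled = np.tile(initial_state, (nbatch, 1))

state, sensordata = preallocate_outputs(nbatch, nstep, nstate, nsensordata)
print(initial_state_tiled.shape, state.shape, sensordata.shape)
```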

As an extreme example, we pass 1,000 copies of the tippe top model to `rollout` and compare throughput with and without checks for rollouts of 1 to 500 steps.

In [None]:
nbatch = 1000
nstep = [1, 10, 100, 500]
ntiming = 5

top_data = init_top(top_model)
top_datas = [copy.copy(top_data) for _ in range(nthread)]
initial_state = get_state(top_model, top_data)
initial_state_tiled = np.tile(initial_state, (nbatch, 1))

# Note: state and sensordata arrays are automatically allocated and returned
def rollout_with_checks(nstep):
  state, sensordata = rollout.rollout([top_model]*nbatch, top_datas, initial_state, nstep=nstep)

# Note: with skip_checks=True, state and sensordata are only returned if
# preallocated arrays are passed; None means they are not returned
state = None
sensordata = None
def rollout_skip_checks(nstep):
  # Note initial state must be tiled
  rollout.rollout([top_model]*nbatch, top_datas, initial_state_tiled, nstep=nstep,
                  state=state, sensordata=sensordata, skip_checks=True)

t_with_checks = benchmark(lambda x: rollout_with_checks(x), nstep, ntiming=ntiming)
t_skip_checks = benchmark(lambda x: rollout_skip_checks(x), nstep, ntiming=ntiming)

steps_per_second = (nbatch * np.array(nstep)) / np.array(t_with_checks)
steps_per_second_skip_checks = (nbatch * np.array(nstep)) / np.array(t_skip_checks)

plt.loglog(nstep, steps_per_second, label='with checks')
plt.loglog(nstep, steps_per_second_skip_checks, label='skip checks')
plt.ylabel('steps per second')
plt.xlabel('nstep')
plt.legend()
plt.grid()

As expected, the benefit of skipping checks fades quickly as `nstep` increases. However, at low `nstep` and high batch sizes, it can make a significant difference.

Notice that the version with checks can use the non-tiled `initial_state`, whereas the skip-checks version must use the tiled version, `initial_state_tiled`.

## Warmstarting

The `initial_warmstart` parameter can be used to warmstart the constraint solver as described in the [computation chapter](https://mujoco.readthedocs.io/en/stable/computation/index.html#warmstart-acceleration) of the documentation. This can be useful when rolling out models in chunks of steps. Without warmstarting, chaotic systems involving multi-body contact may diverge.

Below we demonstrate this with the tippe top model, with the contact solver changed to CG. This makes the contact force calculation less repeatable than with the default Newton solver, which lets us demonstrate the benefits of warmstarting.

The simulation is run three times: once as a single 6000-step rollout, once in 100 chunks of 60 steps with warmstarting, and once in 100 chunks of 60 steps without warmstarting.

In [None]:
model = copy.copy(top_model)
model.opt.solver = mujoco.mjtSolver.mjSOL_CG
data = init_top(model)

chunks = 100
steps_per_chunk = 60
nstep = steps_per_chunk*chunks

initial_state = get_state(model, data)

start = time.time()
state_all, _  = rollout.rollout(model, data, initial_state, nstep=nstep)

state_chunks = []
state_chunk, _ = rollout.rollout(model, data, initial_state, nstep=steps_per_chunk)
state_chunks.append(state_chunk)
for _ in range(chunks-1):
  state_chunk, _ = rollout.rollout(model, data, state_chunks[-1][0, -1, :], nstep=steps_per_chunk, initial_warmstart=data.qacc_warmstart)
  state_chunks.append(state_chunk)
state_all_chunked_warmstart = np.concatenate(state_chunks, axis=1)

state_chunks = []
state_chunk, _ = rollout.rollout(model, data, initial_state, nstep=steps_per_chunk)
state_chunks.append(state_chunk)
for i in range(chunks-1):
  state_chunk, _ = rollout.rollout(model, data, state_chunks[-1][0, -1, :], nstep=steps_per_chunk)
  state_chunks.append(state_chunk)
state_all_chunked = np.concatenate(state_chunks, axis=1)
end = time.time()

start_render = time.time()
framerate = 60
state_render = np.concatenate((state_all, state_all_chunked, state_all_chunked_warmstart), axis=0)
camera = 'distant'
frames1 = render_many(model, data, state_all, framerate, shape=(240, 320), transparent=False, camera=camera)
frames2 = render_many(model, data, state_all_chunked_warmstart, framerate, shape=(240, 320), transparent=False, camera=camera)
frames3 = render_many(model, data, state_all_chunked, framerate, shape=(240, 320), transparent=False, camera=camera)
media.show_video(np.concatenate((frames1, frames2, frames3), axis=2))
end_render = time.time()

print(f'Rollout took {end-start:.1f} seconds')
print(f'Rendering took {end_render-start_render:.1f} seconds')

As expected, the middle animation (with warmstarting) matches the continuous rollout on the left. However, the model that did not use warmstarting diverged.
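
The effect can be illustrated without MuJoCo. The toy below is a sketch under strong assumptions and is not MuJoCo's solver: each step refines a solver variable `w` with a single iteration starting from its previous value, so the result depends on the starting guess. Carrying `w` across chunk boundaries reproduces the continuous rollout exactly, while resetting it at every chunk boundary does not.

```python
def step(x, w, dt=0.1):
  """One toy step: a single solver iteration started from the guess w."""
  w = w + 0.5 * (-x - w)  # inexact solve; the result depends on the guess
  x = x + dt * w
  return x, w

def rollout_chunked(x, nstep, chunk, warmstart):
  """Rolls out nstep steps in chunks, optionally carrying w across chunks."""
  w = 0.0
  for i in range(nstep):
    if i % chunk == 0 and not warmstart:
      w = 0.0  # discard the solver's previous solution at chunk boundaries
    x, w = step(x, w)
  return x

continuous = rollout_chunked(1.0, nstep=60, chunk=60, warmstart=True)
chunked_warm = rollout_chunked(1.0, nstep=60, chunk=6, warmstart=True)
chunked_cold = rollout_chunked(1.0, nstep=60, chunk=6, warmstart=False)

print(chunked_warm == continuous)  # True: identical operations, identical result
print(chunked_cold == continuous)  # False: the reset changes the trajectory
```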

## chunk_size

To minimize communication overhead, `rollout` distributes rollouts to threads in groups of rollouts called chunks. By default, `max(1, 0.1 * (nbatch / nthread))` rollouts are assigned to each chunk. While this chunking rule works well for most workloads it is not always optimal, especially when doing short rollouts with small models.
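
As a small illustration of the default rule, the sketch below computes the default chunk size and splits a batch into chunks of that size. The helper names are hypothetical, and the contiguous partitioning is an assumption made for illustration; how `rollout` assigns chunks to threads internally may differ.

```python
def default_chunk_size(nbatch, nthread):
  """Default rule quoted above, truncated to an integer."""
  return int(max(1.0, 0.1 * nbatch / nthread))

def split_into_chunks(nbatch, chunk_size):
  """Returns (start, stop) index pairs covering range(nbatch)."""
  return [(start, min(start + chunk_size, nbatch))
          for start in range(0, nbatch, chunk_size)]

chunk_size = default_chunk_size(nbatch=1000, nthread=16)
chunks = split_into_chunks(1000, chunk_size)
print(chunk_size, len(chunks))  # 6 167
```

With 1000 rollouts and 16 threads, the default rule yields many small chunks, which is the regime where per-chunk overhead can dominate very short rollouts.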

Below we plot the steps per second versus `chunk_size` when running 1000 hoppers for 1 step each. In this case, the default `chunk_size` turns out to be quite a bit slower than an increased chunk size.

In [None]:
nbatch = 1000
nstep = 1
ntiming = 20

#print('Getting Hopper XML description from GitHub:')
!git clone https://github.com/google-deepmind/dm_control
hopper_model = mujoco.MjModel.from_xml_path('dm_control/dm_control/suite/hopper.xml')
hopper_data = mujoco.MjData(hopper_model)

initial_state = get_state(hopper_model, hopper_data)
initial_states = np.tile(initial_state, (nbatch, 1))

hopper_datas = [copy.copy(hopper_data) for _ in range(nthread)]

def rollout_chunk_size(chunk_size=None):
  rollout.rollout(hopper_model, hopper_datas, initial_states, nstep=nstep, chunk_size=chunk_size)

default_chunk_size = int(max(1.0, 0.1 * nbatch / nthread))
chunk_sizes = sorted([1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, default_chunk_size])
t_chunk_size = benchmark(lambda x: rollout_chunk_size(x), chunk_sizes, ntiming=ntiming)

steps_per_second = nbatch * nstep / t_chunk_size
default_index = [i for i, c in enumerate(chunk_sizes) if c == default_chunk_size][0]
optimal_index = np.argmax(steps_per_second)
plt.loglog(chunk_sizes, steps_per_second, color='b')
plt.plot(chunk_sizes[default_index], steps_per_second[default_index], marker='o', color='r', label='default chunk size')
plt.plot(chunk_sizes[optimal_index], steps_per_second[optimal_index], marker='o', color='g', label='optimal chunk size')
plt.ylabel('steps per second')
plt.xlabel('chunk size')
plt.legend()
plt.grid()

print(f'default chunk size: {default_chunk_size} \t steps per second: {steps_per_second[default_index]:0.1f}')
print(f'optimal chunk size: {chunk_sizes[optimal_index]} \t steps per second: {steps_per_second[optimal_index]:0.1f}')

## Reusing threadpools with the class `Rollout`

The `rollout` module provides the class `Rollout` in addition to the method `rollout`. The class `Rollout` is designed to allow safe reuse of the internally managed thread pool.
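
The idea of pool reuse can be illustrated with Python's standard `concurrent.futures.ThreadPoolExecutor`. This is only an analogy under stated assumptions (the `work` function is a hypothetical stand-in for a short rollout); it does not show how `Rollout` is implemented.

```python
from concurrent.futures import ThreadPoolExecutor

def work(x):
  """Hypothetical stand-in for one short rollout."""
  return x * x

def new_pool_every_call(inputs, ncall, nthread):
  """Creates and tears down a thread pool on every call."""
  results = []
  for _ in range(ncall):
    with ThreadPoolExecutor(max_workers=nthread) as pool:
      results.append(list(pool.map(work, inputs)))
  return results

def reuse_pool(inputs, ncall, nthread):
  """Creates the pool once and reuses it for all calls."""
  results = []
  with ThreadPoolExecutor(max_workers=nthread) as pool:
    for _ in range(ncall):
      results.append(list(pool.map(work, inputs)))
  return results

inputs = list(range(8))
# Both variants compute the same results; the second pays the setup cost once.
print(new_pool_every_call(inputs, 3, 4) == reuse_pool(inputs, 3, 4))
```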

Reuse can speed things up considerably when rollouts are short. Let's find out how the speedup changes for the tippe top model by rolling it out with increasing numbers of steps.

In [None]:
nbatch = 100
nsteps = [2**i for i in [2, 3, 4, 5, 6, 7]]
ntiming = 5

top_data = init_top(top_model)

initial_state = get_state(top_model, top_data)
initial_states = np.tile(initial_state, (nbatch, 1))

top_datas = [copy.copy(top_data) for _ in range(nthread)]

def rollout_method(nstep):
  for i in range(20):
    rollout.rollout(top_model, top_datas, initial_states, nstep=nstep)

def rollout_class(nstep):
  with rollout.Rollout(nthread=nthread) as rollout_:
    for i in range(20):
      rollout_.rollout(top_model, top_datas, initial_states, nstep=nstep)

t_method = benchmark(lambda x: rollout_method(x), nsteps, ntiming)
t_class = benchmark(lambda x: rollout_class(x), nsteps, ntiming)

plt.loglog(nsteps, nbatch * np.array(nsteps) / t_method, label='recreating threadpools')
plt.loglog(nsteps, nbatch * np.array(nsteps) / t_class, label='reusing threadpool')
plt.xlabel('nstep')
plt.ylabel('steps per second')
plt.legend()
plt.grid()

## Reusing threadpools with the method `rollout`

The `rollout` method will create and reuse a persistent threadpool when `persistent_pool=True` is passed. However, there are some caveats.

First, because `rollout` is a function and does not know when the user is done calling it, the threadpool needs to be shut down manually like this:

In [None]:
nbatch = 1000
nstep = 1

top_data = init_top(top_model)
top_datas = [copy.copy(top_data) for _ in range(nthread)]

initial_state = get_state(top_model, top_data)
initial_states = np.tile(initial_state, (nbatch, 1))

rollout.rollout(top_model, top_datas, initial_states, nstep=nstep, persistent_pool=True) # Creates a pool
rollout.rollout(top_model, top_datas, initial_states, nstep=nstep, persistent_pool=True) # Reuses the previously created pool
rollout.shutdown_persistent_pool() # Shutdown the pool manually when finished

Second, if `rollout` reuses the same threadpool between calls, it is no longer safe to call `rollout` from multiple threads. For example the following is not allowed (the offending lines are commented out to avoid crashing the interpreter):

In [None]:
thread1 = threading.Thread(target=lambda: rollout.rollout(top_model, top_datas, initial_states, nstep=nstep, persistent_pool=True))
thread2 = threading.Thread(target=lambda: rollout.rollout(top_model, top_datas, initial_states, nstep=nstep, persistent_pool=True))

thread1.start()
#thread2.start() # Do not do this! rollout will be using the same persistent threadpool from two threads and may crash the interpreter
thread1.join()
#thread2.join()
rollout.shutdown_persistent_pool()