![MuJoCo banner](https://raw.githubusercontent.com/google-deepmind/mujoco/main/banner.png)

This notebook provides a tutorial for using Predictive Sampling with MJX.

### Copyright notice

Copyright 2024 DeepMind Technologies Limited.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0.

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

# Imports

In [0]:
!pip install mujoco
!pip install brax

# Set up GPU rendering.
from google.colab import files
import distutils.util
import os
import subprocess
if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

# Check if installation was succesful.
try:
  print('Checking that the installation succeeded:')
  import mujoco
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".')

print('Installation successful.')

from IPython.display import clear_output
clear_output()
print("mujoco_mpc.mjx isn't installed yet -- please make sure you install it and fetch the task files from the MuJoCo Menagerie GitHub repository.")
from brax import base as brax_base
from brax.io import html
from brax.io import mjcf
from IPython.display import HTML
import jax
import matplotlib.pyplot as plt
from mujoco import mjx
from mujoco_mpc.mjx import predictive_sampling
from mujoco_mpc.mjx.tasks import insert
import numpy as np

# Load task, run optimization

In [0]:
# Fetch everything the `insert` task module returns from
# `get_models_and_cost_fn()` and unpack it into the four names used by the
# cells below.
(
    sim_model_cpu,
    plan_model_cpu,
    cost_fn,
    instruction_fn,
) = insert.get_models_and_cost_fn()

In [0]:
# Rollout configuration. None of these depend on the noise scale being tried,
# so they are defined once instead of on every loop iteration.
nsteps = 2000
steps_per_plan = 10
batch_size = 1024
nsamples = 256
nplans = batch_size // nsamples  # number of rollouts run in parallel
print(f'nplans: {nplans}')

# Time axis shared by every plot: one entry per simulation step.
x_time = [i * sim_model_cpu.opt.timestep for i in range(nsteps)]


def set_qpos(data, qpos):
  """Returns the result of `data.replace(qpos=qpos)`."""
  return data.replace(qpos=qpos)


# Run `nplans` rollouts in parallel: policy, rng and sim data are batched along
# axis 0, everything else is shared across the batch.
mpc_rollout_multiplan = jax.vmap(
    predictive_sampling.mpc_rollout, in_axes=(
        None,  # nsteps
        None,  # steps_per_plan
        None,  # Planner
        0,     # init_policy
        0,     # rng
        None,  # sim_model
        0,     # sim_data
    )
)
# `nsteps` and `steps_per_plan` (positions 0 and 1) are compile-time constants.
mpc_rollout_multiplan_jit = jax.jit(
    mpc_rollout_multiplan, static_argnums=[0, 1]
)

# Maps each noise scale to its per-plan cost curves.
costs_to_compare = {}
for it in [0.8]:
  p = predictive_sampling.Planner(
      model=mjx.put_model(plan_model_cpu),
      cost=cost_fn,
      noise_scale=it,  # iterate on different values
      horizon=128,
      nspline=4,
      nsample=nsamples - 1,
      interp='zero',
      instruction_fn=instruction_fn,
  )

  # Start every rollout from keyframe 0 of the simulation model.
  sim_data = mujoco.MjData(sim_model_cpu)
  mujoco.mj_resetDataKeyframe(sim_model_cpu, sim_data, 0)
  # without kinematics, the first cost is off:
  mujoco.mj_forward(sim_model_cpu, sim_data)
  sim_data = mjx.put_data(sim_model_cpu, sim_data)
  # Replicate the initial qpos once per plan and build a batched data object.
  q0s = np.tile(sim_data.qpos, (nplans, 1))
  sim_datas = jax.vmap(set_qpos, in_axes=(None, 0))(sim_data, q0s)
  # Initial policy: keyframe-0 controls repeated for every plan and spline knot.
  multi_policy = np.tile(sim_model_cpu.key_ctrl[0, :], (nplans, p.nspline, 1))

  sim_datas, final_policy, costs, trajectories, terms = mpc_rollout_multiplan_jit(
      nsteps,
      steps_per_plan,
      p,
      jax.device_put(multi_policy),
      jax.random.split(jax.random.key(0), nplans),
      mjx.put_model(sim_model_cpu),
      sim_datas,
  )
  # Replace the returned costs with the sum of a subset of the cost terms:
  # every other term from index 2 up to (excluding) 13.
  # NOTE(review): the meaning of these term indices is defined by the task's
  # cost function, which is not visible here -- confirm the selection.
  costs = np.sum(terms.reshape(nplans, -1, terms.shape[-1])[:, :, 2:13:2], axis=-1)
  costs_to_compare[it] = costs

  plt.figure()
  plt.xlim([0, nsteps * sim_model_cpu.opt.timestep])
  plt.ylim([0, float(np.max(costs))])
  plt.xlabel('time')
  plt.ylabel('cost')
  # Faint line per individual plan, bold line for the mean across plans.
  for i in range(nplans):
    plt.plot(x_time, costs[i, :], alpha=0.1)
  avg = np.mean(costs, axis=0)
  plt.plot(x_time, avg, linewidth=2.0)
  # NOTE: the error bars show the variance across plans, not the standard
  # deviation.
  var = np.var(costs, axis=0)
  plt.errorbar(
      x_time,
      avg,
      yerr=var,
      fmt='none',
      ecolor='b',
      elinewidth=1,
      alpha=0.2,
      capsize=0,
  )

  plt.show()

In [0]:
# Overlay the mean cost curve of every run stored in `costs_to_compare`
# (keyed by noise scale) on a single figure.
plt.figure()
plt.xlim([0, nsteps * sim_model_cpu.opt.timestep])
# Scale the y-axis to the largest cost over *all* compared runs, rather than
# to whichever `costs` array the previous cell happened to leave behind.
y_max = max(float(np.max(run_costs)) for run_costs in costs_to_compare.values())
plt.ylim([0, y_max])
plt.xlabel('time')
plt.ylabel('cost')
x_time = [i * sim_model_cpu.opt.timestep for i in range(nsteps)]
# A dedicated loop variable keeps the `costs` defined by the rollout cell
# untouched, so re-running this cell gives the same figure.
for val, run_costs in costs_to_compare.items():
  avg = np.mean(run_costs, axis=0)
  plt.plot(x_time, avg, label=str(val))
  # NOTE: the error bars show the variance across plans, not the standard
  # deviation.
  var = np.var(run_costs, axis=0)
  plt.errorbar(
      x_time, avg, yerr=var, fmt='none', elinewidth=1, alpha=0.2, capsize=0
  )

plt.legend()
plt.show()

In [0]:
# Replay one recorded trajectory in the Brax HTML viewer.
d = mujoco.MjData(sim_model_cpu)
sys = mjcf.load_model(sim_model_cpu)

# Entry 2 of `trajectories.q`, flattened to one row of `nq` values per frame.
# NOTE(review): index 2 presumably selects one of the parallel plans -- confirm.
frames = trajectories.q[2].reshape(-1, sim_model_cpu.nq)

xstates = []
for qpos in frames:
  # Pose the model at this frame and recompute kinematics.
  d.qpos = qpos
  mujoco.mj_kinematics(sim_model_cpu, d)
  # `d` is reused on every iteration, so copy the arrays to give each frame
  # its own values. Index 0 is skipped (presumably the world body -- confirm).
  x = brax_base.Transform(pos=d.xpos[1:].copy(), rot=d.xquat[1:].copy())
  state = brax_base.State(q=None, qd=None, x=x, xd=None, contact=None)
  xstates.append(state)

HTML(html.render(sys, xstates))