# UR5 MuJoCo Inference on Google Colab

This notebook installs the dependencies, downloads a UR5e MuJoCo model, and wires the UR5-specific observation and action transforms described in the repository into an `openpi` policy for simulation inference. It reuses the UR5e normalization statistics that ship with the `pi0_base` checkpoint, so the policy receives states and actions at the scale it was trained on.
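To show what "the correct scale" means, the sketch below assumes z-score normalization with per-dimension mean and standard deviation, which is the kind of statistic stored alongside a checkpoint. openpi's own `Normalize`/`Unnormalize` transforms may differ in details such as the epsilon or an optional quantile mode. The statistics here are made up and are not the real `ur5e` values.

```python
import numpy as np


def normalize(x: np.ndarray, mean: np.ndarray, std: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Map raw states/actions into the model's normalized space (z-score, assumed)."""
    return (x - mean) / (std + eps)


def unnormalize(z: np.ndarray, mean: np.ndarray, std: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Invert `normalize` to recover raw-scale values from model outputs."""
    return z * (std + eps) + mean


# Hypothetical statistics for a 7-D UR5 state (6 joints + 1 gripper).
mean = np.array([0.0, -1.5, 1.5, -1.5, -1.5, 0.0, 0.5])
std = np.array([0.5, 0.4, 0.4, 0.4, 0.4, 0.6, 0.5])
state = np.array([0.1, -1.4, 1.6, -1.5, -1.5, 0.2, 1.0])

z = normalize(state, mean, std)      # what the model sees
restored = unnormalize(z, mean, std) # what the robot receives back
print(np.round(z, 3))
```

If the statistics belong to a different robot, the same raw joint angles land in a very different normalized range. That mismatch is why the notebook reloads the UR5e statistics instead of computing new ones.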

## 0. Runtime preparation

* In Colab choose **Runtime → Change runtime type → GPU (T4 or better)**.
* Connect the runtime before running the cells below.
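You can optionally run a quick check to confirm that the runtime actually has a GPU attached. The sketch below assumes an NVIDIA driver with `nvidia-smi` on `PATH`, as on Colab GPU runtimes. It reports `False` when the tool is missing or lists no GPUs.

```python
import shutil
import subprocess


def gpu_visible() -> bool:
    """Return True if `nvidia-smi` exists and lists at least one GPU."""
    if shutil.which('nvidia-smi') is None:
        return False
    result = subprocess.run(
        ['nvidia-smi', '--list-gpus'], capture_output=True, text=True, check=False
    )
    # Each attached device appears as a line like "GPU 0: Tesla T4 (UUID: ...)".
    return result.returncode == 0 and 'GPU' in result.stdout


print('GPU visible:', gpu_visible())
```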

In [1]:
# Detect whether we're running inside Google Colab and install Miniconda if needed.
# The first execution in Colab restarts the runtime; rerun after reconnecting.
import os
import subprocess
import sys

IS_COLAB = "COLAB_RELEASE_TAG" in os.environ
if IS_COLAB:
    subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', 'condacolab'], check=True)
    import condacolab
    condacolab.install_miniconda(python_version='3.11')
else:
    print('Running outside of Colab; skipping Miniconda setup.')

Running outside of Colab; skipping Miniconda setup.


In [2]:
import os
import pathlib
import sys

NOTEBOOK_DIR = pathlib.Path.cwd()

def _find_repo_root(start: pathlib.Path) -> pathlib.Path:
    for path in (start,) + tuple(start.parents):
        if (path / 'pyproject.toml').exists():
            return path
    return start

REPO_ROOT = _find_repo_root(NOTEBOOK_DIR)

if IS_COLAB:
    runtime_root = pathlib.Path('/content')
else:
    default_root = pathlib.Path(os.environ.get('OPENPI_NOTEBOOK_DATA_ROOT', str(pathlib.Path.home() / 'openpi_ur5_notebook')))
    if not default_root.is_absolute():
        default_root = NOTEBOOK_DIR / default_root
    runtime_root = default_root

RUNTIME_ROOT = runtime_root.expanduser().resolve()
RUNTIME_ROOT.mkdir(parents=True, exist_ok=True)

MENAGERIE_DIR = RUNTIME_ROOT / 'mujoco_menagerie'
CACHE_DIR = RUNTIME_ROOT / 'openpi_cache'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

src_path = REPO_ROOT / 'src'
if src_path.exists() and str(src_path) not in sys.path:
    sys.path.insert(0, str(src_path))

print(f'Notebook directory: {NOTEBOOK_DIR}')
print(f'Repository root: {REPO_ROOT}')
print(f'Runtime root: {RUNTIME_ROOT}')
print(f'Menagerie directory: {MENAGERIE_DIR}')
print(f'OpenPI cache directory: {CACHE_DIR}')

Notebook directory: /workspace/openpi/examples/ur5
Repository root: /workspace/openpi
Runtime root: /root/openpi_ur5_notebook
Menagerie directory: /root/openpi_ur5_notebook/mujoco_menagerie
OpenPI cache directory: /root/openpi_ur5_notebook/openpi_cache
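The `_find_repo_root` helper above walks up from the notebook directory until it finds a `pyproject.toml`. The standalone copy below, renamed `find_repo_root` to avoid shadowing it, runs that logic on a throwaway directory tree. The directory names are invented to mirror the output above.

```python
import pathlib
import tempfile


def find_repo_root(start: pathlib.Path) -> pathlib.Path:
    """Walk up from `start` and return the first directory holding a pyproject.toml."""
    for path in (start, *start.parents):
        if (path / 'pyproject.toml').exists():
            return path
    return start  # No marker anywhere above: fall back to the starting directory.


with tempfile.TemporaryDirectory() as tmp:
    # Made-up layout mirroring <repo>/examples/ur5.
    expected_root = pathlib.Path(tmp) / 'openpi'
    notebook_dir = expected_root / 'examples' / 'ur5'
    notebook_dir.mkdir(parents=True)
    (expected_root / 'pyproject.toml').write_text('[project]\nname = "demo"\n')
    found_root = find_repo_root(notebook_dir)
    print('Found repository root:', found_root)
```

Because of the fallback, running the notebook outside a checkout yields `REPO_ROOT == NOTEBOOK_DIR`. The `src` path is then only added to `sys.path` if such a directory happens to exist.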


## 1. Python packages

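After the runtime restarts, rerun the two cells above. The Miniconda install is now a no-op, but the restart clears all Python state, so `IS_COLAB` and the path variables must be redefined. Then execute the cell below to install the packages required for MuJoCo and `openpi`.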

In [3]:
import os
import subprocess
import sys

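# Skip Git LFS blob downloads when pip clones git-hosted dependencies.
os.environ.setdefault('GIT_LFS_SKIP_SMUDGE', '1')
# Let JAX allocate GPU memory on demand instead of reserving most of it up front.
os.environ.setdefault('XLA_PYTHON_CLIENT_PREALLOCATE', 'false')
# Headless Colab runtimes have no X display, so render MuJoCo offscreen through EGL.
# This must be set before `mujoco` is first imported.
if IS_COLAB:
    os.environ.setdefault('MUJOCO_GL', 'egl')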

pip_cmd = [sys.executable, '-m', 'pip']
if IS_COLAB:
    subprocess.run(pip_cmd + ['install', '--upgrade', 'pip'], check=True)
    subprocess.run(
        pip_cmd + ['install', '--quiet', 'mujoco==3.2.3', 'mediapy==1.2.4', 'imageio==2.36.1'],
        check=True,
    )
    subprocess.run(
        pip_cmd + ['install', '--quiet', 'git+https://github.com/Physical-Intelligence/openpi.git'],
        check=True,
    )
else:
    print('Non-Colab environment detected; skipping package installation.')

Non-Colab environment detected; skipping package installation.


## 2. Download MuJoCo UR5 assets

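This step fetches only the `universal_robots_ur5e` directory from the MuJoCo Menagerie repository, using a shallow, blobless clone with sparse checkout. If the directory already exists, it runs `git pull` instead.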

In [4]:
import subprocess

if not MENAGERIE_DIR.exists():
    MENAGERIE_DIR.parent.mkdir(parents=True, exist_ok=True)
    subprocess.run(
        [
            'git',
            'clone',
            '--filter=blob:none',
            '--depth',
            '1',
            '--sparse',
            'https://github.com/google-deepmind/mujoco_menagerie.git',
            str(MENAGERIE_DIR),
        ],
        check=True,
    )
    subprocess.run(['git', 'sparse-checkout', 'init', '--cone'], cwd=MENAGERIE_DIR, check=True)
    subprocess.run(['git', 'sparse-checkout', 'set', 'universal_robots_ur5e'], cwd=MENAGERIE_DIR, check=True)
else:
    subprocess.run(['git', 'pull'], cwd=MENAGERIE_DIR, check=True)

print(f'MuJoCo menagerie ready at {MENAGERIE_DIR}')

Already up to date.
MuJoCo menagerie ready at /root/openpi_ur5_notebook/mujoco_menagerie


In [5]:
import os

os.environ.setdefault('OPENPI_DATA_HOME', str(CACHE_DIR))
os.environ.setdefault('JAX_PLATFORM_NAME', 'gpu' if IS_COLAB else 'cpu')

UR5_ASSET_DIR = MENAGERIE_DIR / 'universal_robots_ur5e'
SCENE_XML = UR5_ASSET_DIR / 'scene.xml'
if not SCENE_XML.exists():
    raise FileNotFoundError(f'Missing MuJoCo scene at {SCENE_XML}')
print('Using MuJoCo scene:', SCENE_XML)

Using MuJoCo scene: /root/openpi_ur5_notebook/mujoco_menagerie/universal_robots_ur5e/scene.xml
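The cell above only checks for `scene.xml`. In the current Menagerie layout, that scene includes the arm model from `ur5e.xml` in the same directory, so a slightly broader check can catch an incomplete sparse checkout earlier. The helper name and the required-file list are assumptions. In the notebook you would call `missing_asset_files(UR5_ASSET_DIR)`.

```python
import pathlib
import tempfile


def missing_asset_files(
    asset_dir: pathlib.Path, required: tuple[str, ...] = ('scene.xml', 'ur5e.xml')
) -> list[str]:
    """Return the names in `required` that are not regular files inside `asset_dir`."""
    return [name for name in required if not (asset_dir / name).is_file()]


# Demo on a throwaway directory that contains only scene.xml.
with tempfile.TemporaryDirectory() as tmp:
    demo_dir = pathlib.Path(tmp)
    (demo_dir / 'scene.xml').write_text('<mujoco/>')
    demo_missing = missing_asset_files(demo_dir)
print('Missing files in demo directory:', demo_missing)
```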


In [6]:
# Provide a lightweight stub for `lerobot` when the package is unavailable.
import sys
import types

try:
    import lerobot.common.datasets.lerobot_dataset as _lerobot_dataset  # noqa: F401
except ModuleNotFoundError:
    root_mod = types.ModuleType('lerobot')
    common_mod = types.ModuleType('lerobot.common')
    datasets_mod = types.ModuleType('lerobot.common.datasets')

    class _StubLeRobotDatasetMetadata:
        def __init__(self, repo_id: str):
            self.repo_id = repo_id
            self.fps = 20.0
            self.tasks = []

    class _StubLeRobotDataset:
        def __init__(self, repo_id: str, delta_timestamps: dict | None = None, **_kwargs):
            self.repo_id = repo_id
            self.delta_timestamps = {} if delta_timestamps is None else delta_timestamps
            self.tasks = []

        def __getitem__(self, index):
            raise IndexError('LeRobot dataset stub has no samples')

        def __len__(self):
            return 0

        @classmethod
        def create(cls, *args, **kwargs):
            return cls(args[0] if args else kwargs.get('repo_id', 'stub'))

    dataset_mod = types.ModuleType('lerobot.common.datasets.lerobot_dataset')
    dataset_mod.LeRobotDatasetMetadata = _StubLeRobotDatasetMetadata
    dataset_mod.LeRobotDataset = _StubLeRobotDataset

    root_mod.common = common_mod
    common_mod.datasets = datasets_mod
    datasets_mod.lerobot_dataset = dataset_mod

    sys.modules['lerobot'] = root_mod
    sys.modules['lerobot.common'] = common_mod
    sys.modules['lerobot.common.datasets'] = datasets_mod
    sys.modules['lerobot.common.datasets.lerobot_dataset'] = dataset_mod

In [7]:
import dataclasses
import pathlib
from typing import Any

import numpy as np
from typing_extensions import override

from openpi import transforms as _transforms
from openpi.models import model as _model
from openpi.models import pi0_config
from openpi.policies import policy_config
from openpi.training import config as config_lib
from openpi.training import weight_loaders

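def _parse_image(image: Any) -> np.ndarray:
    """Convert an image to uint8 (H, W, C); LeRobot datasets store float32 (C, H, W) frames."""
    array = np.asarray(image)
    if np.issubdtype(array.dtype, np.floating):
        # Float images are assumed to lie in [0, 1].
        array = np.clip(array * 255.0, 0, 255).astype(np.uint8)
    if array.ndim == 3 and array.shape[0] in (1, 3):
        # Channels-first -> channels-last.
        array = np.transpose(array, (1, 2, 0))
    return array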

@dataclasses.dataclass(frozen=True)
class UR5Inputs(_transforms.DataTransformFn):
    model_type: _model.ModelType = _model.ModelType.PI0

    def __call__(self, data: dict) -> dict:
        state = np.concatenate([data['joints'], data['gripper']])
        base_image = _parse_image(data['base_rgb'])
        wrist_image = _parse_image(data['wrist_rgb'])
        inputs = {
            'state': state,
            'image': {
                'base_0_rgb': base_image,
                'left_wrist_0_rgb': wrist_image,
                'right_wrist_0_rgb': np.zeros_like(base_image),
            },
            'image_mask': {
                'base_0_rgb': np.True_,
                'left_wrist_0_rgb': np.True_,
                'right_wrist_0_rgb': np.True_ if self.model_type == _model.ModelType.PI0_FAST else np.False_,
            },
        }
        if 'prompt' in data:
            inputs['prompt'] = data['prompt']
        if 'actions' in data:
            inputs['actions'] = data['actions']
        return inputs

@dataclasses.dataclass(frozen=True)
class UR5Outputs(_transforms.DataTransformFn):
    def __call__(self, data: dict) -> dict:
        return {'actions': np.asarray(data['actions'][:, :7])}

@dataclasses.dataclass(frozen=True)
class LeRobotUR5DataConfig(config_lib.DataConfigFactory):
    @override
    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> config_lib.DataConfig:
        repack_transform = _transforms.Group(
            inputs=[
                _transforms.RepackTransform(
                    {
                        'base_rgb': 'image',
                        'wrist_rgb': 'wrist_image',
                        'joints': 'joints',
                        'gripper': 'gripper',
                        'prompt': 'prompt',
                    }
                )
            ]
        )
        data_transforms = _transforms.Group(
            inputs=[UR5Inputs(model_type=model_config.model_type)],
            outputs=[UR5Outputs()],
        )
        delta_action_mask = _transforms.make_bool_mask(6, -1)
        data_transforms = data_transforms.push(
            inputs=[_transforms.DeltaActions(delta_action_mask)],
            outputs=[_transforms.AbsoluteActions(delta_action_mask)],
        )
        model_transforms = config_lib.ModelTransformFactory()(model_config)
        return dataclasses.replace(
            self.create_base_config(assets_dirs, model_config),
            repack_transforms=repack_transform,
            data_transforms=data_transforms,
            model_transforms=model_transforms,
        )

def build_ur5_train_config() -> config_lib.TrainConfig:
    return config_lib.TrainConfig(
        name='pi0_ur5_mujoco',
        model=pi0_config.Pi0Config(),
        data=LeRobotUR5DataConfig(
            repo_id='local/ur5_mujoco',
            assets=config_lib.AssetsConfig(
                assets_dir='gs://openpi-assets/checkpoints/pi0_base/assets',
                asset_id='ur5e',
            ),
            base_config=config_lib.DataConfig(prompt_from_task=True),
        ),
        weight_loader=weight_loaders.CheckpointWeightLoader('gs://openpi-assets/checkpoints/pi0_base/params'),
    )
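To see what `_parse_image` and the state concatenation in `UR5Inputs` do to concrete arrays, the standalone copy below (renamed `parse_image`) runs on dummy data. Note the heuristic it relies on: any 3-D array whose first axis has length 1 or 3 is treated as channels-first. An (H, W, C) image that is only 1 or 3 pixels tall would therefore be transposed incorrectly.

```python
import numpy as np


def parse_image(image) -> np.ndarray:
    """Convert float (C, H, W) images in [0, 1] to uint8 (H, W, C); pass uint8 HWC through."""
    array = np.asarray(image)
    if np.issubdtype(array.dtype, np.floating):
        array = np.clip(array * 255.0, 0, 255).astype(np.uint8)
    if array.ndim == 3 and array.shape[0] in (1, 3):
        array = np.transpose(array, (1, 2, 0))
    return array


# A LeRobot-style float32 (C, H, W) frame and a renderer-style uint8 (H, W, C) frame.
chw_float = np.full((3, 4, 5), 0.5, dtype=np.float32)
hwc_uint8 = np.zeros((4, 5, 3), dtype=np.uint8)

# UR5Inputs builds the state as 6 joint angles followed by 1 gripper value.
joints = np.zeros(6, dtype=np.float32)
gripper = np.array([1.0], dtype=np.float32)
state = np.concatenate([joints, gripper])

print(parse_image(chw_float).shape, parse_image(hwc_uint8).shape, state.shape)
```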


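The `DeltaActions`/`AbsoluteActions` pair with `make_bool_mask(6, -1)` means the six joint dimensions are handled as offsets from the current state, while the gripper dimension stays absolute. The NumPy sketch below reproduces that round trip under this assumed semantics. `bool_mask`, `to_delta`, and `to_absolute` are hypothetical helpers, not openpi's implementation.

```python
import numpy as np


def bool_mask(*dims: int) -> tuple[bool, ...]:
    """Positive n -> n True entries, negative n -> |n| False entries (assumed make_bool_mask semantics)."""
    mask: list[bool] = []
    for d in dims:
        mask.extend([d > 0] * abs(d))
    return tuple(mask)


def to_delta(actions: np.ndarray, state: np.ndarray, mask) -> np.ndarray:
    """Subtract the current state from masked action dims; unmasked dims stay absolute."""
    mask = np.asarray(mask)
    dims = mask.shape[-1]
    out = actions.copy()
    out[..., :dims] -= np.where(mask, state[..., :dims], 0.0)  # Broadcasts over the chunk axis.
    return out


def to_absolute(actions: np.ndarray, state: np.ndarray, mask) -> np.ndarray:
    """Inverse of `to_delta`: add the current state back to the masked dims."""
    mask = np.asarray(mask)
    dims = mask.shape[-1]
    out = actions.copy()
    out[..., :dims] += np.where(mask, state[..., :dims], 0.0)
    return out


mask = bool_mask(6, -1)  # 6 joints as deltas, gripper stays absolute
state = np.array([0.1, -1.5, 1.5, -1.5, -1.5, 0.0, 0.8])
chunk = np.tile(state + 0.05, (4, 1))  # 4 future absolute targets
delta = to_delta(chunk, state, mask)
restored = to_absolute(delta, state, mask)
```

Because the output transform adds the state back, the policy's final `actions` are absolute joint targets. That is the form `UR5MujocoEnv.step` expects further below.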
In [8]:
ur5_config = build_ur5_train_config()


class _ZeroPolicy:
    """Fallback stub that ignores observations and returns one all-zero action of width `action_dim`."""

    def infer(self, inputs):
        return {'actions': np.zeros((1, ur5_config.model.action_dim), dtype=np.float32)}


if IS_COLAB:
    try:
        policy = policy_config.create_trained_policy(
            ur5_config,
            checkpoint_dir='gs://openpi-assets/checkpoints/pi0_base',
            default_prompt='move to a ready pose',
        )
        print('Policy ready; model type:', ur5_config.model.model_type)
    except Exception as exc:
        print('Failed to load pretrained policy, falling back to a zero-action stub.\n', exc)
        policy = _ZeroPolicy()
else:
    print('Skipping remote checkpoint loading outside Colab; using a zero-action stub policy.')
    policy = _ZeroPolicy()

Skipping remote checkpoint loading outside Colab; using a zero-action stub policy.
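The zero-action stub keeps the notebook runnable. However, `UR5MujocoEnv.step` treats actions as absolute joint targets, so all-zero actions pull the arm toward the all-zero joint configuration instead of holding it still. The sketch below offers a hypothetical alternative that echoes the observed state back as targets. It also adds a shape check for the `infer` contract that the rollout loop relies on. `HoldPositionPolicy` and `check_policy` are illustrative names, not openpi APIs.

```python
from typing import Any, Protocol

import numpy as np


class SupportsInfer(Protocol):
    """Anything with an `infer(obs) -> {'actions': array}` method can drive the rollout loop."""

    def infer(self, inputs: dict[str, Any]) -> dict[str, Any]: ...


class HoldPositionPolicy:
    """Hypothetical stub: repeats the observed joints + gripper as absolute targets."""

    def __init__(self, horizon: int = 10):
        self.horizon = horizon

    def infer(self, inputs: dict[str, Any]) -> dict[str, Any]:
        state = np.concatenate([inputs['joints'], inputs['gripper']]).astype(np.float32)
        return {'actions': np.tile(state, (self.horizon, 1))}  # Shape (horizon, 7).


def check_policy(policy: SupportsInfer, obs: dict[str, Any]) -> tuple[int, ...]:
    """Call `infer` once and verify the result is a 2-D (horizon, action_dim) chunk."""
    actions = np.asarray(policy.infer(obs)['actions'])
    if actions.ndim != 2:
        raise ValueError(f'expected (horizon, action_dim) actions, got shape {actions.shape}')
    return actions.shape
```

After `UR5Outputs`, a real pretrained policy returns a chunk of shape (horizon, 7). The stub above matches that layout, while `_ZeroPolicy` returns a single 32-wide row.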


In [9]:
import mediapy as media
import mujoco

class UR5MujocoEnv:
    def __init__(self, scene_path: str, control_hz: float = 20.0, render_size: int = 224):
        self.model = mujoco.MjModel.from_xml_path(scene_path)
        self.data = mujoco.MjData(self.model)
        mujoco.mj_resetDataKeyframe(self.model, self.data, 0)
        mujoco.mj_forward(self.model, self.data)
        self._render_size = render_size
        try:
            self.renderer = mujoco.Renderer(self.model, width=render_size, height=render_size)
        except mujoco.FatalError as exc:
            print('MuJoCo renderer unavailable, falling back to blank image output.', exc)
            self.renderer = None
        self.base_camera = self._make_overview_camera() if self.renderer is not None else None
        self.wrist_camera = self._make_wrist_camera() if self.renderer is not None else None
        self.control_hz = control_hz
        self.sim_substeps = max(1, int(round((1.0 / control_hz) / self.model.opt.timestep)))
        self._ctrl_target = np.zeros(self.model.nu, dtype=np.float32)
        self._joint_lower = self.model.jnt_range[:, 0].copy()
        self._joint_upper = self.model.jnt_range[:, 1].copy()
        self.reset()

    def _make_overview_camera(self):
        cam = mujoco.MjvCamera()
        mujoco.mjv_defaultCamera(cam)
        cam.type = mujoco.mjtCamera.mjCAMERA_FREE
        cam.distance = 1.0
        cam.azimuth = 90
        cam.elevation = -35
        cam.lookat = np.array([0.4, 0.0, 0.2])
        return cam

    def _make_wrist_camera(self):
        cam = mujoco.MjvCamera()
        mujoco.mjv_defaultCamera(cam)
        cam.type = mujoco.mjtCamera.mjCAMERA_TRACKING
        cam.trackbodyid = self.model.body('wrist_3_link').id
        cam.distance = 0.25
        cam.azimuth = 180
        cam.elevation = -15
        return cam

    def reset(self):
        mujoco.mj_resetDataKeyframe(self.model, self.data, 0)
        mujoco.mj_forward(self.model, self.data)
        self._ctrl_target = self.data.ctrl.copy()

    def render_images(self):
        if self.renderer is None:
            blank = np.zeros((self._render_size, self._render_size, 3), dtype=np.uint8)
            return blank, blank
        # mujoco.Renderer.render() already returns uint8 RGB arrays of shape (H, W, 3);
        # rescaling by 255 would overflow uint8 and corrupt the images.
        self.renderer.update_scene(self.data, camera=self.base_camera)
        base = np.asarray(self.renderer.render()).copy()
        self.renderer.update_scene(self.data, camera=self.wrist_camera)
        wrist = np.asarray(self.renderer.render()).copy()
        return base, wrist

    def observe(self, prompt: str) -> dict:
        base, wrist = self.render_images()
        return {
            'base_rgb': base,
            'wrist_rgb': wrist,
            'joints': self.data.qpos.copy().astype(np.float32),
            'gripper': np.zeros(1, dtype=np.float32),
            'prompt': prompt,
        }

    def step(self, action: np.ndarray, smoothing: float = 0.2):
        joint_target = np.array(action[: self.model.nq], dtype=np.float32)
        joint_target = np.clip(joint_target, self._joint_lower, self._joint_upper)
        self._ctrl_target = (1 - smoothing) * self._ctrl_target + smoothing * joint_target
        self.data.ctrl[:] = self._ctrl_target
        for _ in range(self.sim_substeps):
            mujoco.mj_step(self.model, self.data)
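Each `env.step` call runs `sim_substeps` physics steps and moves the control target a fraction `smoothing` of the way toward the commanded joints. After n calls, the remaining gap is therefore (1 - smoothing)^n. The sketch below isolates both computations. It assumes the scene keeps MuJoCo's default 2 ms timestep, and the helper names are hypothetical.

```python
def substeps_per_control(control_hz: float, timestep: float) -> int:
    """Physics steps per policy action, mirroring the computation in UR5MujocoEnv (at least 1)."""
    return max(1, int(round((1.0 / control_hz) / timestep)))


def smooth_toward(current: float, target: float, smoothing: float, steps: int) -> float:
    """Apply the first-order smoothing from UR5MujocoEnv.step `steps` times to a scalar."""
    for _ in range(steps):
        current = (1 - smoothing) * current + smoothing * target
    return current


# 20 Hz control with an assumed 2 ms physics timestep.
print('substeps:', substeps_per_control(20.0, 0.002))
# Fraction of a unit step change covered after five actions with smoothing=0.2.
print('covered:', smooth_toward(0.0, 1.0, 0.2, 5))
```

With the default `smoothing=0.2` and five actions executed per chunk, about a third (0.8^5 ≈ 0.33) of a sudden change in the commanded target is still outstanding when the next chunk is requested.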

In [10]:
env = UR5MujocoEnv(str(SCENE_XML))
prompt = 'Move the arm smoothly.'
warmup_obs = env.observe(prompt)
_ = policy.infer(warmup_obs)

frames = []
for cycle in range(3):
    obs = env.observe(prompt)
    action_chunk = policy.infer(obs)['actions']
    print(f'Cycle {cycle}: action chunk shape {action_chunk.shape}')
    for action in action_chunk[:5]:
        env.step(action)
        base_img, wrist_img = env.render_images()
        panel = np.concatenate([base_img, wrist_img], axis=1)
        frames.append(panel)

try:
    media.show_video(frames, fps=20)
except RuntimeError as exc:
    print('Video preview unavailable:', exc)

/root/.pyenv/versions/3.11.12/lib/python3.11/site-packages/glfw/__init__.py:917: GLFWError: (65550) b'X11: The DISPLAY environment variable is missing'
/root/.pyenv/versions/3.11.12/lib/python3.11/site-packages/glfw/__init__.py:917: GLFWError: (65537) b'The GLFW library is not initialized'
Exception ignored in: <function Renderer.__del__ at 0x7f4c44c507c0>
Traceback (most recent call last):
  File "/root/.pyenv/versions/3.11.12/lib/python3.11/site-packages/mujoco/renderer.py", line 335, in __del__
    self.close()
  File "/root/.pyenv/versions/3.11.12/lib/python3.11/site-packages/mujoco/renderer.py", line 323, in close
    if self._mjr_context:
       ^^^^^^^^^^^^^^^^^
AttributeError: 'Renderer' object has no attribute '_mjr_context'


MuJoCo renderer unavailable, falling back to blank image output. an OpenGL platform library has not been loaded into this process, this most likely means that a valid OpenGL context has not been created before mjr_makeContext was called
Cycle 0: action chunk shape (1, 32)
Cycle 1: action chunk shape (1, 32)
Cycle 2: action chunk shape (1, 32)
Video preview unavailable: Program 'ffmpeg' is not found; perhaps install ffmpeg using 'apt install ffmpeg'.


## 3. Next steps

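* Replace the MuJoCo scene with one containing your task objects; only the camera setup and the joint/gripper observations need to match the UR5 input mapping. `UR5MujocoEnv` reads all of `qpos` and slices actions by `nq`, so a scene with extra free bodies needs the six arm joints selected explicitly.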
* Swap `checkpoint_dir` for your own fine-tuned checkpoint if you trained on UR5 data.
* Integrate robot-specific control or gripper dynamics if your simulated setup differs from the simple position-controlled example here.
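If your scene adds a gripper, the policy's seventh action dimension has to be mapped onto that gripper's actuator. The sketch below makes a hypothetical assumption: that the gripper command is normalized to [0, 1]. It linearly rescales the command onto a control range such as the one MuJoCo stores per actuator in `model.actuator_ctrlrange`. Check your dataset's gripper convention before relying on it.

```python
import numpy as np


def gripper_to_ctrl(command: float, ctrl_low: float, ctrl_high: float) -> float:
    """Linearly map a gripper command (assumed to lie in [0, 1]) onto [ctrl_low, ctrl_high]."""
    clipped = float(np.clip(command, 0.0, 1.0))  # Guard against overshoot from the policy.
    return ctrl_low + clipped * (ctrl_high - ctrl_low)


# Hypothetical actuator range; in MuJoCo you might read model.actuator_ctrlrange[gripper_actuator_id].
print(gripper_to_ctrl(0.5, 0.0, 255.0))
```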