# Preparation

## Install MuJoCo, MJX, and Brax

In [1]:
!pip install mujoco
!pip install mujoco_mjx
!pip install brax

Collecting mujoco_mjx
  Downloading mujoco_mjx-3.3.4-py3-none-any.whl.metadata (3.4 kB)
Collecting jax (from mujoco_mjx)
  Downloading jax-0.7.0-py3-none-any.whl.metadata (13 kB)
Collecting jaxlib (from mujoco_mjx)
  Downloading jaxlib-0.7.0-cp313-cp313-macosx_11_0_arm64.whl.metadata (1.3 kB)
Collecting trimesh (from mujoco_mjx)
  Using cached trimesh-4.7.1-py3-none-any.whl.metadata (18 kB)
Collecting ml_dtypes>=0.5.0 (from jax->mujoco_mjx)
  Downloading ml_dtypes-0.5.3-cp313-cp313-macosx_10_13_universal2.whl.metadata (8.9 kB)
Collecting opt_einsum (from jax->mujoco_mjx)
  Downloading opt_einsum-3.4.0-py3-none-any.whl.metadata (6.3 kB)
Downloading mujoco_mjx-3.3.4-py3-none-any.whl (6.7 MB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.7/6.7 MB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m[36m0:00:01[0m[36m0:00:01[0m01[0m
[?25hDownloading jax-0.7.0-py3-none-any.whl (2.8 MB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.8/2.

## Import packages for plotting and creating graphics

In [20]:
import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List

# Graphics and plotting.
print('Installing mediapy:')
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

Installing mediapy:


  pid, fd = os.forkpty()


## Import MuJoCo, MJX, and Brax

In [21]:
from datetime import datetime
from etils import epath
import functools
from IPython.display import HTML
from typing import Any, Dict, Sequence, Tuple, Union
import os
from ml_collections import config_dict


import jax
from jax import numpy as jp
import numpy as np
from flax.training import orbax_utils
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from orbax import checkpoint as ocp

import mujoco
from mujoco import mjx

from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.base import State as PipelineState
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model


# LeLamp Model

Load the LeLamp MuJoCo model from its scene XML.

In [23]:
# Parse the LeLamp scene XML into a MuJoCo model. The path is relative to the
# notebook's working directory.
mj_model = mujoco.MjModel.from_xml_path('../models/lelamp/scene.xml')

# Display the number of joints in the model (cell output).
mj_model.njnt

6

In [24]:
# Allocate the simulation state that belongs to mj_model.
mj_data = mujoco.MjData(mj_model)

# Display generalized positions (qpos), generalized velocities (qvel) and
# sensor readings (sensordata) of the freshly created state.
mj_data.qpos, mj_data.qvel, mj_data.sensordata

(array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 array([0., 0., 0., 0., 0., 0.]))

In [8]:
# Centre of mass of the subtree rooted at body id 1.
# NOTE(review): the value is all zeros here, presumably because no
# mj_forward/mj_step has been run on mj_data yet -- confirm.
mj_data.subtree_com[1]

array([0., 0., 0.])

In [9]:
# Off-screen renderer bound to mj_model; reused by the control test below.
renderer = mujoco.Renderer(mj_model)

In [10]:
# Copy the MuJoCo model and state into their MJX (JAX) counterparts.
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

# Compare the host-side qpos with the MJX copy: values, array type, and the
# device(s) holding the MJX array.
print(mj_data.qpos, type(mj_data.qpos))
print(mjx_data.qpos, type(mjx_data.qpos), mjx_data.qpos.devices())



[0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] <class 'numpy.ndarray'>
[0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] <class 'jaxlib._jax.ArrayImpl'> {CpuDevice(id=0)}




## Test Model Control

In [17]:
# Look up the integer id of a body by name; the id is what array fields such
# as subtree_com and xpos are indexed by.
mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, "dc15_a01_dummy_assy_idle_asm")

1

In [None]:
import numpy as np

# Show joint axes in the rendered frames.
scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = True

duration = 3.8  # seconds
framerate = 60  # Hz

frames = []
mujoco.mj_resetData(mj_model, mj_data)

com = []

# One control value per actuator. Taking the count from the model keeps this
# cell valid if the actuator set changes (it was hard-coded to 5 before).
n_ctrl = mj_model.nu

# Per-actuator oscillation frequency in Hz: (i - 6) for i = 0..n_ctrl-1.
# The values are negative for small i, which only flips the sign of the sine.
freqs = np.arange(n_ctrl) - 6

while mj_data.time < duration:
    t = mj_data.time

    # Small-amplitude sine on every actuator, each at its own frequency.
    mj_data.ctrl[:] = 0.5 * np.sin(2 * np.pi * freqs * t)

    mujoco.mj_step(mj_model, mj_data)

    # To record the centre of mass, index subtree_com by integer body id
    # (a body name is not a valid array index), e.g.:
    # com.append(mj_data.subtree_com[1].copy())

    # Render only when the simulation time has advanced past the next frame
    # slot, so the video plays back at `framerate`.
    if len(frames) < mj_data.time * framerate:
        renderer.update_scene(mj_data, scene_option=scene_option)
        pixels = renderer.render()
        frames.append(pixels)

media.show_video(frames, fps=framerate)


0
This browser does not support the video tag.


In [31]:
# RL Environment

# Scene XML loaded by LeLampEnv; relative to the notebook's working directory.
MODEL_PATH = '../models/lelamp/scene.xml'

In [36]:
class LeLampEnv(PipelineEnv):
    """Brax environment that simulates the LeLamp model with the MJX backend.

    The per-step reward is the sum of:

    * a forward term: ``forward_reward_weight`` times the x-velocity of the
      centre of mass of the subtree rooted at ``com_body_name``;
    * a healthy term: ``healthy_reward`` while the z-coordinate of
      ``head_body_name`` lies inside ``healthy_z_range``, otherwise 0.

    Args:
        forward_reward_weight: Scale applied to the forward (x) velocity.
        healthy_reward: Reward granted on each step the lamp is healthy.
        terminate_when_unhealthy: If true, ``done`` is set to 1 on the step
            the head leaves ``healthy_z_range``; if false, ``done`` stays 0.
        healthy_z_range: ``(min_z, max_z)`` bounds for the head position.
            NOTE(review): the default (1.0, 2.0) may not suit the lamp's
            size -- confirm, since it now also drives termination.
        reset_noise_scale: Half-width of the uniform noise added to the
            initial qpos and used for the initial qvel.
        com_body_name: Body whose subtree centre of mass is tracked.
        head_body_name: Body whose height defines the healthy check.
        **kwargs: Forwarded to ``PipelineEnv``. ``n_frames`` defaults to 5
            and ``backend`` is always forced to ``'mjx'``.

    Raises:
        ValueError: If either body name does not exist in the model.
    """

    def __init__(
        self,
        forward_reward_weight=1.25,
        healthy_reward=5.0,
        terminate_when_unhealthy=True,
        healthy_z_range=(1.0, 2.0),
        reset_noise_scale=1e-2,
        com_body_name='dc15_a01_dummy_assy_idle_asm',
        head_body_name='lamp_head',
        **kwargs,
    ):
        # Load the model and select the solver settings used for simulation.
        mj_model = mujoco.MjModel.from_xml_path(MODEL_PATH)
        mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
        mj_model.opt.iterations = 6
        mj_model.opt.ls_iterations = 6

        # Resolve body names to integer ids once, up front: the per-body
        # arrays read in step() (subtree_com, xpos) are indexed by id and
        # cannot be indexed by name.
        com_body_id = self._body_id(mj_model, com_body_name)
        head_body_id = self._body_id(mj_model, head_body_name)

        sys = mjcf.load_model(mj_model)

        physics_steps_per_control_step = 5
        kwargs['n_frames'] = kwargs.get('n_frames', physics_steps_per_control_step)
        kwargs['backend'] = 'mjx'

        # PipelineEnv stores `sys` on the instance, so it is not assigned here.
        super().__init__(sys, **kwargs)

        self._forward_reward_weight = forward_reward_weight
        self._healthy_reward = healthy_reward
        self._terminate_when_unhealthy = terminate_when_unhealthy
        self._healthy_z_range = healthy_z_range
        self._reset_noise_scale = reset_noise_scale
        self._com_body_id = com_body_id
        self._head_body_id = head_body_id

    @staticmethod
    def _body_id(mj_model, name):
        """Returns the integer id of body `name`, raising if it is missing."""
        body_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, name)
        if body_id < 0:
            raise ValueError(f'Body {name!r} not found in model {MODEL_PATH!r}')
        return body_id

    def reset(self, rng: jp.ndarray) -> State:
        """Resets the environment to a randomly perturbed initial state.

        Args:
            rng: JAX PRNG key used to draw the initial-state noise.

        Returns:
            The initial ``State`` with zero reward, zero ``done`` and all
            metrics zeroed.
        """
        rng, rng1, rng2 = jax.random.split(rng, 3)

        low, hi = -self._reset_noise_scale, self._reset_noise_scale

        # Generalized positions: the model's default pose plus uniform noise.
        qpos = self.sys.qpos0 + jax.random.uniform(
            rng1, (self.sys.nq,), minval=low, maxval=hi
        )

        # Generalized velocities: uniform noise around zero.
        qvel = jax.random.uniform(
            rng2, (self.sys.nv,), minval=low, maxval=hi
        )

        # Initialise the brax pipeline state from the sampled pose.
        data = self.pipeline_init(qpos, qvel)

        # Initial observation, computed with a zero action.
        obs = self._get_obs(data, jp.zeros(self.sys.nu))

        # Initialise reward, done and metrics; step() overwrites every
        # metric key, so the key set here must match the one used there.
        reward, done, zero = jp.zeros(3)
        metrics = {
            'forward_reward': zero,
            'reward_linvel': zero,
            'reward_alive': zero,
            'x_position': zero,
            'y_position': zero,
            'distance_from_origin': zero,
            'x_velocity': zero,
            'y_velocity': zero,
        }
        return State(data, obs, reward, done, metrics)

    def step(self, state: State, action: jp.ndarray) -> State:
        """Advances the simulation by one control step.

        Args:
            state: Current environment state.
            action: Control vector passed to the physics pipeline.

        Returns:
            The next ``State`` with updated pipeline state, observation,
            reward, ``done`` flag and metrics.
        """
        data0 = state.pipeline_state
        data = self.pipeline_step(data0, action)

        # Forward reward: finite-difference velocity of the tracked subtree
        # centre of mass over one control step, projected on x.
        com_before = data0.subtree_com[self._com_body_id]
        com_after = data.subtree_com[self._com_body_id]
        velocity = (com_after - com_before) / self.dt
        forward_reward = self._forward_reward_weight * velocity[0]

        # Healthy reward: 1 while the head height is inside the range, else 0.
        lamp_head_pos = data.xpos[self._head_body_id]

        min_z, max_z = self._healthy_z_range
        is_healthy = jp.where(lamp_head_pos[2] < min_z, 0.0, 1.0)
        is_healthy = jp.where(lamp_head_pos[2] > max_z, 0.0, is_healthy)
        healthy_reward = self._healthy_reward * is_healthy

        obs = self._get_obs(data, action)

        reward = forward_reward + healthy_reward

        # End the episode on an unhealthy step only when that was requested.
        if self._terminate_when_unhealthy:
            done = 1.0 - is_healthy
        else:
            done = jp.zeros_like(is_healthy)

        state.metrics.update({
            'forward_reward': forward_reward,
            'reward_linvel': forward_reward,
            'reward_alive': healthy_reward,
            'x_position': com_after[0],
            'y_position': com_after[1],
            'distance_from_origin': jp.linalg.norm(com_after),
            'x_velocity': velocity[0],
            'y_velocity': velocity[1],
        })

        return state.replace(
            pipeline_state=data,
            obs=obs,
            reward=reward,
            done=done,
        )

    def _get_obs(
        self, data: mjx.Data, action: jp.ndarray
    ) -> jp.ndarray:
        """Builds the observation vector from the pipeline state.

        The observation concatenates ``qpos[7:]``, ``qvel[6:]`` and the
        first six sensor readings. ``action`` is accepted for interface
        symmetry but is not part of the observation.
        """
        # Skip the first 7 position entries.
        # NOTE(review): presumably the floating-base pose (3 position +
        # 4 quaternion values) -- confirm against the model.
        position = data.qpos[7:]

        # Skip the first 6 velocity entries (the matching base velocities,
        # under the same assumption).
        velocities = data.qvel[6:]

        # First six sensor readings.
        sensor_data = data.sensordata[:6]

        return jp.concatenate([position, velocities, sensor_data])

# Register the class under the name 'lelamp' so envs.get_environment can build it.
envs.register_environment('lelamp', LeLampEnv)

In [37]:
# Instantiate the environment through the brax registry (default constructor
# arguments are used since no overrides are passed).
env_name = 'lelamp'
env = envs.get_environment(env_name)

# JIT-compile reset/step. Compilation happens lazily on the first call of each
# function, so the first reset and first step are much slower than later ones.
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)



In [None]:
import time

start = time.time()

state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

# Constant control applied at every step (built once, outside the loop).
ctrl = -0.1 * jp.ones(env.sys.nu)

for i in range(10):
    state = jit_step(state, ctrl)
    # JAX dispatches work asynchronously; wait for the step to finish so the
    # elapsed time below reflects actual compute. The first iteration also
    # includes the JIT compilation of `step`.
    jax.block_until_ready(state)
    rollout.append(state.pipeline_state)
    print(f"Step {i+1} done, elapsed: {time.time() - start:.2f} sec")

print("Rollout finished in", time.time() - start, "seconds")


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x1103642f0>>
Traceback (most recent call last):
  File "/Users/binhpham/miniconda3/envs/mujoco-tut/lib/python3.13/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 
