
Support for Rendering in MJX for Simulated Camera Reinforcement Learning #1682

Closed
markusheimerl opened this issue May 21, 2024 · 6 comments
Labels
enhancement New feature or request

Comments

markusheimerl commented May 21, 2024

import jax
import time
import mujoco
import functools
from brax import envs
from mujoco import mjx
from etils import epath
from jax import numpy as jp
from brax.io import mjcf, model
from brax.envs.base import PipelineEnv, State
from brax.training.agents.ppo import train as ppo

class Humanoid(PipelineEnv):

    def __init__(self, **kwargs):
        mj_model = mujoco.MjModel.from_xml_path((epath.Path(epath.resource_path("mujoco")) / ("mjx/test_data/humanoid") / "humanoid.xml").as_posix())
        mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
        mj_model.opt.iterations = 6
        mj_model.opt.ls_iterations = 6

        sys = mjcf.load_model(mj_model)

        physics_steps_per_control_step = 5
        kwargs["n_frames"] = kwargs.get("n_frames", physics_steps_per_control_step)
        kwargs["backend"] = "mjx"

        super().__init__(sys, **kwargs)

    def reset(self, rng: jp.ndarray) -> State:
        """Resets the environment to an initial state."""
        rng, rng1, rng2 = jax.random.split(rng, 3)

        low, hi = -1e-2, 1e-2 # reset_noise_scale
        qpos = self.sys.qpos0 + jax.random.uniform(rng1, (self.sys.nq,), minval=low, maxval=hi)
        qvel = jax.random.uniform(rng2, (self.sys.nv,), minval=low, maxval=hi)

        data = self.pipeline_init(qpos, qvel)

        obs = self._get_obs(data, jp.zeros(self.sys.nu), None)
        reward, done, zero = jp.zeros(3)
        metrics = {
            "forward_reward": zero, "reward_linvel": zero, "reward_quadctrl": zero, "reward_alive": zero,
            "x_position": zero, "y_position": zero,"distance_from_origin": zero,"x_velocity": zero,"y_velocity": zero
        }
        return State(data, obs, reward, done, metrics)

    def step(self, state: State, action: jp.ndarray) -> State:
        """Runs one timestep of the environment's dynamics."""
        data0 = state.pipeline_state
        data = self.pipeline_step(data0, action)

        com_before = data0.subtree_com[1] # center of mass
        com_after = data.subtree_com[1]
        velocity = (com_after - com_before) / self.dt
        forward_reward = 1.25 * velocity[0]

        min_z, max_z = (1.0, 2.0) # healthy_z_range
        is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
        is_healthy = jp.where(data.q[2] > max_z, 0.0, is_healthy)
        healthy_reward = 5.0
        ctrl_cost = (0.1 * jp.sum(jp.square(action)))

        state.metrics.update(
            forward_reward=forward_reward, reward_linvel=forward_reward, reward_quadctrl=-ctrl_cost, 
            reward_alive=healthy_reward, x_position=com_after[0], y_position=com_after[1], 
            distance_from_origin=jp.linalg.norm(com_after), x_velocity=velocity[0], y_velocity=velocity[1]
        )
        return state.replace(
            pipeline_state=data, obs=self._get_obs(data, action, state), 
            reward=(forward_reward + healthy_reward - ctrl_cost), done=(1.0 - is_healthy)
        )

    def _get_obs(self, data: mjx.Data, action: jp.ndarray, state) -> jp.ndarray:
        """Observes humanoid body position, velocities, angles, and camera pixels."""
        if state is None:
            # Placeholder must match the rendered RGB frame shape (240, 320, 3)
            # so the observation size is the same at reset and step time.
            pixels = jax.numpy.zeros((240, 320, 3)).ravel()
        else:
            pixels = jax.numpy.array(self.render(state.pipeline_state, camera="egocentric")).ravel()

        return jp.concatenate([
            data.qpos[2:], data.qvel, data.cinert[1:].ravel(),
            data.cvel[1:].ravel(), data.qfrc_actuator, pixels,
        ])


train_fn = functools.partial(
    ppo.train, num_timesteps=3_000_000, num_evals=5, reward_scaling=0.1, episode_length=1000, normalize_observations=True,
    action_repeat=1, unroll_length=10, num_minibatches=16, num_updates_per_batch=8, discounting=0.97, learning_rate=3e-4, entropy_cost=1e-3,
    num_envs=1024, batch_size=512, seed=0
)

start_time = time.time()

print(f"Registering environment... ({time.time() - start_time:.2f}s)")
envs.register_environment("humanoid_mjx", Humanoid)

print(f"Starting training... ({time.time() - start_time:.2f}s)")
make_inference_fn, params, _ = train_fn(environment=envs.get_environment("humanoid_mjx"))

print(f"Training completed. Saving parameters... ({time.time() - start_time:.2f}s)")
model.save_params("/home/markusheimerl/mjx_brax_policy_v2", params)

print(f"Parameters saved. ({time.time() - start_time:.2f}s)")

results in:

(venv_mjx) markusheimerl@tower:~/sim/hum$ python brax_humanoid.py 
Registering environment... (0.00s)
Starting training... (0.00s)
Traceback (most recent call last):
  File "/home/markusheimerl/sim/hum/brax_humanoid.py", line 95, in <module>
    make_inference_fn, params, _ = train_fn(environment=envs.get_environment("humanoid_mjx"))
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/markusheimerl/venv_mjx/lib/python3.12/site-packages/brax/training/agents/ppo/train.py", line 405, in train
    metrics = evaluator.run_evaluation(
              ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/markusheimerl/venv_mjx/lib/python3.12/site-packages/brax/training/acting.py", line 125, in run_evaluation
    eval_state = self._generate_eval_unroll(policy_params, unroll_key)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/markusheimerl/venv_mjx/lib/python3.12/site-packages/brax/training/acting.py", line 107, in generate_eval_unroll
    return generate_unroll(
           ^^^^^^^^^^^^^^^^
  File "/home/markusheimerl/venv_mjx/lib/python3.12/site-packages/brax/training/acting.py", line 75, in generate_unroll
    (final_state, _), data = jax.lax.scan(
                             ^^^^^^^^^^^^^
  File "/home/markusheimerl/venv_mjx/lib/python3.12/site-packages/brax/training/acting.py", line 71, in f
    nstate, transition = actor_step(
                         ^^^^^^^^^^^
  File "/home/markusheimerl/venv_mjx/lib/python3.12/site-packages/brax/training/acting.py", line 43, in actor_step
    nstate = env.step(env_state, actions)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/markusheimerl/venv_mjx/lib/python3.12/site-packages/brax/envs/wrappers/training.py", line 176, in step
    nstate = self.env.step(state, action)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/markusheimerl/venv_mjx/lib/python3.12/site-packages/brax/envs/wrappers/training.py", line 122, in step
    state = self.env.step(state, action)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/markusheimerl/venv_mjx/lib/python3.12/site-packages/brax/envs/wrappers/training.py", line 71, in step
    return jax.vmap(self.env.step)(state, action)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/markusheimerl/venv_mjx/lib/python3.12/site-packages/brax/envs/wrappers/training.py", line 93, in step
    state, rewards = jax.lax.scan(f, state, (), self.action_repeat)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/markusheimerl/venv_mjx/lib/python3.12/site-packages/brax/envs/wrappers/training.py", line 90, in f
    nstate = self.env.step(state, action)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/markusheimerl/sim/hum/brax_humanoid.py", line 69, in step
    pipeline_state=data, obs=self._get_obs(data, action, state), 
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/markusheimerl/sim/hum/brax_humanoid.py", line 78, in _get_obs
    pixels = jax.numpy.array(self.render(state.pipeline_state, camera="egocentric")).ravel()
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/markusheimerl/venv_mjx/lib/python3.12/site-packages/brax/envs/base.py", line 159, in render
    return image.render_array(self.sys, trajectory, height, width, camera)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/markusheimerl/venv_mjx/lib/python3.12/site-packages/brax/io/image.py", line 48, in render_array
    return get_image(trajectory)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/markusheimerl/venv_mjx/lib/python3.12/site-packages/brax/io/image.py", line 40, in get_image
    d.qpos, d.qvel = state.q, state.qd
    ^^^^^^
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[28].
The error occurred while tracing the function f at /home/markusheimerl/venv_mjx/lib/python3.12/site-packages/brax/envs/wrappers/training.py:89 for scan. This concrete value was not available in Python because it depends on the value of the argument state.pipeline_state.q.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Description:
I'm encountering an issue while using MuJoCo with JAX (mjx) to train a humanoid model in a Brax environment. The problem arises when attempting to render the environment state and retrieve camera images during training: the MuJoCo renderer cannot be called from traced (jitted) code when using mjx.

Problem Details:
When calling the render method within the custom Humanoid class, an error is thrown:

jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[28].

This error indicates that NumPy's __array__() conversion is being invoked on a traced array whose value depends on state.pipeline_state.q. The underlying cause is that brax.io.image renders on the host with the classic MuJoCo renderer, which needs concrete NumPy values; inside jit/scan tracing those values do not exist yet, so the conversion fails.
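
For reference, the same failure mode can be reproduced without MuJoCo at all. This toy sketch (my own, not from brax) triggers the identical TracerArrayConversionError whenever host-side NumPy tries to materialize a traced value:

import jax
import numpy as np

@jax.jit
def f(x):
    # Host-side NumPy conversion of a traced value calls __array__ on the
    # tracer, just like brax.io.image does when it assigns d.qpos = state.q.
    return np.asarray(x)

f(jax.numpy.ones(3))  # raises jax.errors.TracerArrayConversionError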

Importance of Camera Input in Reinforcement Learning:
Camera input is crucial when training robots with reinforcement learning: real robots perceive and interact with the world through visual sensing. Including camera pixels in the observation space during training yields policies that transfer more robustly to real-world conditions.

Proposed Solution:
To enable training with camera input under mjx, the incompatibility between the MuJoCo renderer and JAX traced arrays needs to be addressed. Possible solutions include:

  1. Enhancing the mjx renderer to handle JAX traced arrays directly, ensuring seamless integration with the Brax environment.
  2. Providing alternative rendering methods or workarounds designed for mjx that allow retrieving camera images during training.
  3. Improving documentation and examples that demonstrate proper usage of camera input with mjx, with guidance on handling traced arrays in this context.

Alternatives Considered:

  • Rendering the camera images outside of the JAX computation graph, e.g. through a host callback (a minimal sketch follows below); this introduces a host/device round trip and extra complexity in the training pipeline.
  • Using alternative rendering libraries or techniques, which may limit flexibility and compatibility with the existing MuJoCo and Brax ecosystem.
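
For the first alternative, one conceivable shape of the workaround is jax.pure_callback, which lets jitted code call back into ordinary Python with concrete values. This is only a sketch under stated assumptions: render_rgb below is a hypothetical host-side helper (e.g. built on the classic MuJoCo renderer), not an existing API, and the fixed (240, 320, 3) output shape matches the placeholder used in _get_obs above:

import jax
import jax.numpy as jp
import numpy as onp

H, W = 240, 320  # camera resolution, matching the zero placeholder above

def render_rgb(qpos):
    # Hypothetical host-side renderer: receives concrete NumPy values and
    # would run the classic (non-JAX) MuJoCo renderer; dummy frame here.
    return onp.zeros((H, W, 3), dtype=onp.uint8)

def get_pixels(qpos):
    # pure_callback escapes the trace, so qpos arrives as concrete values;
    # JAX only needs the output shape/dtype declared up front.
    return jax.pure_callback(
        render_rgb, jax.ShapeDtypeStruct((H, W, 3), jp.uint8), qpos
    )

Since brax wraps environments in jax.vmap, the callback would additionally have to handle batched qpos, and the per-step host/device round trip makes this slow, which is exactly why native on-device rendering would be preferable.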

Additional Context:
Integrating camera input into reinforcement learning pipelines is central to developing adaptable robots, and JAX with mjx can accelerate that training substantially. The current incompatibility between the MuJoCo renderer and JAX traced arrays blocks this use case; resolving it would let researchers train policies that process visual information directly, to the benefit of the wider robotics and reinforcement learning community.

Thank you for considering this issue.

@markusheimerl markusheimerl added the enhancement New feature or request label May 21, 2024
@markusheimerl markusheimerl changed the title Support for Rendering with MJX for Enhanced Training Pipelines Support for Rendering in MJX for Simulated Camera Reinforcement Learning May 21, 2024
@yuvaltassa
Collaborator

You are 100% right.
We are working on option 2.

markusheimerl commented May 22, 2024

@yuvaltassa Would you mind making the corresponding feature branch public in this repo? I'd love to contribute.

@yuvaltassa
Collaborator

The moment we have something that works it will be OSS and we would love you to contribute!

@erikfrey is leading this effort, perhaps there is something he'd like to add.

@erikfrey
Collaborator

Hello! Please see #1604 and #1485 for related discussions. You can do visual observations today using mjx.ray, although this only works for toy environments.

We are working on integrating Madrona as a means for high throughput tiled rendering on GPU, but this is still very much a work in progress. We'll share more once we have a good proof of concept - no ETA but this is actively under development.
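
Concretely, visual observations with mjx.ray look roughly like this: a small depth image built by vmapping one ray per pixel. This is a sketch under assumptions, not an official API beyond mjx.ray itself; it assumes mjx.ray(m, d, pnt, vec) returns (distance, geom_id) as with classic mj_ray, and the camera origin, resolution, and field of view are illustrative:

import jax
import jax.numpy as jp
from mujoco import mjx

def depth_image(m: mjx.Model, d: mjx.Data, origin, h=16, w=16, fov=0.8):
    # One normalized view direction per pixel on a pinhole-style grid,
    # looking down the +x axis from `origin`.
    xs = jp.tan(jp.linspace(-fov / 2, fov / 2, w))
    ys = jp.tan(jp.linspace(-fov / 2, fov / 2, h))
    gx, gy = jp.meshgrid(xs, -ys)  # each (h, w)
    dirs = jp.stack([jp.ones((h, w)), gx, gy], -1).reshape(-1, 3)
    dirs = dirs / jp.linalg.norm(dirs, axis=-1, keepdims=True)
    # mjx.ray is pure JAX, so it can be vmapped and used inside jit/scan.
    dist, _ = jax.vmap(lambda v: mjx.ray(m, d, origin, v))(dirs)
    return dist.reshape(h, w)  # distance per pixel (-1 where nothing is hit)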

markusheimerl commented May 25, 2024

@yuvaltassa @erikfrey thanks for your input! Have you considered https://github.com/JoeyTeng/jaxrenderer?

[image: jaxrenderer classes]

import re
import jax
import numpy as onp
from PIL import Image
from jax import numpy as jp
from renderer import Model
from renderer import ModelObject
from renderer import LightParameters
from renderer.geometry import rotation_matrix
from renderer import CameraParameters
from renderer import ShadowParameters
from renderer import Renderer, transpose_for_display
from numpngw import write_apng

# Load model and textures
obj_path, texture_path, spec_path = "african_head.obj", "african_head_diffuse.tga", "african_head_spec.tga"
image = Image.open(texture_path)
width, height = image.size
# Read pixels directly as a (height, width, 3) array; the original per-pixel
# getpixel loop was slow and filled a (width, height, 3) buffer, which only
# works for square textures.
texture = jp.array(onp.asarray(image.convert("RGB")), dtype=jp.single) / 255

image = Image.open(spec_path)
specular_map = jp.array(onp.asarray(image.convert("RGB")), dtype=jp.single)[..., 0]

verts, norms, uv, faces, faces_norm, faces_uv = [], [], [], [], [], []
_float, _integer, _one_vertex = re.compile(r"(-?\d+\.?\d*(?:e[+-]\d+)?)"), re.compile(r"\d+"), re.compile(r"\d+/\d*/\d*")

with open(obj_path, 'r') as file:
    for line in file:
        if line.startswith("v "):
            verts.append(tuple(map(float, _float.findall(line, 2)[:3])))
        elif line.startswith("vn "):
            norms.append(tuple(map(float, _float.findall(line, 2)[:3])))
        elif line.startswith("vt "):
            uv.append(tuple(map(float, _float.findall(line, 2)[:2])))
        elif line.startswith("f "):
            face, face_norm, face_uv = [], [], []
            vertices = _one_vertex.findall(line)
            assert len(vertices) == 3, f"Expected 3 vertices, got {len(vertices)}"
            for vertex in vertices:
                v, vt, vn = list(map(int, _integer.findall(vertex)))
                face.append(v - 1)
                face_norm.append(vn - 1)
                face_uv.append(vt - 1)
            faces.append(face)
            faces_norm.append(face_norm)
            faces_uv.append(face_uv)

model = Model(
    verts=jp.array(verts),
    norms=jp.array(norms),
    uvs=jp.array(uv),
    faces=jp.array(faces),
    faces_norm=jp.array(faces_norm),
    faces_uv=jp.array(faces_uv),
    diffuse_map=jax.numpy.swapaxes(texture, 0, 1)[:, ::-1, :],
    specular_map=jax.numpy.swapaxes(specular_map, 0, 1)[:, ::-1],
)

canvas_width, canvas_height, frames, rotation_axis = 1920, 1080, 30, "Y"
rotation_axis = dict(X=(1., 0., 0.), Y=(0., 1., 0.), Z=(0., 0., 1.))[rotation_axis]
degrees = jax.lax.iota(float, frames) * 360. / frames

eye, center, up = jp.array((0, 0, 3.)), jp.array((0, 0, 0)), jp.array((0, 1, 0))
camera = CameraParameters(viewWidth=canvas_width, viewHeight=canvas_height, position=eye, target=center, up=up)
light = LightParameters(direction=jp.array([0.57735, -0.57735, 0.57735]), ambient=0.1, diffuse=0.85, specular=0.05)
shadow = ShadowParameters(centre=center)

@jax.default_matmul_precision("float32")
def render_instances(instances, width, height, camera, light, shadow):
    img = Renderer.get_camera_image(objects=instances, light=light, camera=camera, width=width, height=height, shadow_param=shadow, colour_default=jp.zeros(3, dtype=jp.single))
    return jax.lax.clamp(0., img, 1.)

def rotate(model, rotation_axis, degree):
    instance = ModelObject(model=model)
    return instance.replace_with_orientation(rotation_matrix=rotation_matrix(rotation_axis, degree))

batch_rotation = jax.jit(jax.vmap(lambda degree: rotate(model, rotation_axis, degree))).lower(degrees).compile()
instances = [batch_rotation(degrees)]

@jax.jit
def render(batched_instances):
    def _render(instances):
        _render = jax.jit(render_instances, static_argnames=("width", "height"), inline=True)
        img = _render(instances=instances, width=canvas_width, height=canvas_height, camera=camera, light=light, shadow=shadow)
        return transpose_for_display((img * 255).astype(jp.uint8))

    return jax.jit(jax.vmap(_render))(batched_instances)

render_compiled = jax.jit(render).lower(instances).compile()
images = list(map(onp.asarray, jax.device_get(render_compiled(instances))))

write_apng('animation.png', images, delay=1/30.)
# ffmpeg -i animation.png intermediate.gif
# gifsicle --optimize=3 --delay=5 intermediate.gif > output.gif

All these views were rendered in parallel using jax as the only dependency:

[animation: output.gif]

@markusheimerl
Author

Thank you for your input. Godspeed on integrating Madrona.
