Support for Rendering in MJX for Simulated Camera Reinforcement Learning #1682
Comments
You are 100% right.
@yuvaltassa Would you mind making the corresponding feature branch public in this repo? I'd love to contribute.
The moment we have something that works it will be OSS, and we would love you to contribute! @erikfrey is leading this effort; perhaps there is something he'd like to add.
Hello! Please see #1604 and #1485 for related discussions. You can do visual observations today using We are working on integrating Madrona as a means for high-throughput tiled rendering on GPU, but this is still very much a work in progress. We'll share more once we have a good proof of concept - no ETA, but this is actively under development.
@yuvaltassa @erikfrey thanks for your input! Have you considered https://github.com/JoeyTeng/jaxrenderer ?

```python
import re

import jax
import numpy as onp
from jax import numpy as jp
from numpngw import write_apng
from PIL import Image

from renderer import (CameraParameters, LightParameters, Model, ModelObject,
                      Renderer, ShadowParameters, transpose_for_display)
from renderer.geometry import rotation_matrix

# Load model and textures.
obj_path = "african_head.obj"
texture_path = "african_head_diffuse.tga"
spec_path = "african_head_spec.tga"

image = Image.open(texture_path)
width, height = image.size
texture = onp.zeros((width, height, 3))
for y in range(height):
    for x in range(width):
        texture[y, x] = onp.array(image.getpixel((x, y)))
texture = jp.array(texture, dtype=jp.single) / 255

image = Image.open(spec_path)
specular_map = onp.zeros((width, height, 3))
for y in range(height):
    for x in range(width):
        specular_map[y, x] = onp.array(image.getpixel((x, y)))
specular_map = jp.array(specular_map, dtype=jp.single)[..., 0]

# Parse the Wavefront .obj file.
verts, norms, uv, faces, faces_norm, faces_uv = [], [], [], [], [], []
_float = re.compile(r"(-?\d+\.?\d*(?:e[+-]\d+)?)")
_integer = re.compile(r"\d+")
_one_vertex = re.compile(r"\d+/\d*/\d*")
with open(obj_path, 'r') as file:
    for line in file:
        if line.startswith("v "):
            verts.append(tuple(map(float, _float.findall(line, 2)[:3])))
        elif line.startswith("vn "):
            norms.append(tuple(map(float, _float.findall(line, 2)[:3])))
        elif line.startswith("vt "):
            uv.append(tuple(map(float, _float.findall(line, 2)[:2])))
        elif line.startswith("f "):
            face, face_norm, face_uv = [], [], []
            vertices = _one_vertex.findall(line)
            assert len(vertices) == 3, f"Expected 3 vertices, got {len(vertices)}"
            for vertex in vertices:
                v, vt, vn = list(map(int, _integer.findall(vertex)))
                face.append(v - 1)
                face_norm.append(vn - 1)
                face_uv.append(vt - 1)
            faces.append(face)
            faces_norm.append(face_norm)
            faces_uv.append(face_uv)

model = Model(
    verts=jp.array(verts),
    norms=jp.array(norms),
    uvs=jp.array(uv),
    faces=jp.array(faces),
    faces_norm=jp.array(faces_norm),
    faces_uv=jp.array(faces_uv),
    diffuse_map=jp.swapaxes(texture, 0, 1)[:, ::-1, :],
    specular_map=jp.swapaxes(specular_map, 0, 1)[:, ::-1],
)

# Camera, light and shadow parameters.
canvas_width, canvas_height, frames, rotation_axis = 1920, 1080, 30, "Y"
rotation_axis = dict(X=(1., 0., 0.), Y=(0., 1., 0.), Z=(0., 0., 1.))[rotation_axis]
degrees = jax.lax.iota(float, frames) * 360. / frames
eye, center, up = jp.array((0, 0, 3.)), jp.array((0, 0, 0)), jp.array((0, 1, 0))
camera = CameraParameters(viewWidth=canvas_width, viewHeight=canvas_height,
                          position=eye, target=center, up=up)
light = LightParameters(direction=jp.array([0.57735, -0.57735, 0.57735]),
                        ambient=0.1, diffuse=0.85, specular=0.05)
shadow = ShadowParameters(centre=center)

@jax.default_matmul_precision("float32")
def render_instances(instances, width, height, camera, light, shadow):
    img = Renderer.get_camera_image(
        objects=instances, light=light, camera=camera, width=width,
        height=height, shadow_param=shadow,
        colour_default=jp.zeros(3, dtype=jp.single))
    return jax.lax.clamp(0., img, 1.)

def rotate(model, rotation_axis, degree):
    instance = ModelObject(model=model)
    return instance.replace_with_orientation(
        rotation_matrix=rotation_matrix(rotation_axis, degree))

# Build one rotated instance per frame, batched with vmap.
batch_rotation = jax.jit(
    jax.vmap(lambda degree: rotate(model, rotation_axis, degree))
).lower(degrees).compile()
instances = [batch_rotation(degrees)]

@jax.jit
def render(batched_instances):
    def _render(instances):
        render_one = jax.jit(render_instances,
                             static_argnames=("width", "height"), inline=True)
        img = render_one(instances=instances, width=canvas_width,
                         height=canvas_height, camera=camera, light=light,
                         shadow=shadow)
        return transpose_for_display((img * 255).astype(jp.uint8))
    return jax.jit(jax.vmap(_render))(batched_instances)

render_compiled = jax.jit(render).lower(instances).compile()
images = list(map(onp.asarray, jax.device_get(render_compiled(instances))))
write_apng('animation.png', images, delay=1 / 30.)

# Convert the APNG to a gif:
#   ffmpeg -i animation.png intermediate.gif
#   gifsicle --optimize=3 --delay=5 intermediate.gif > output.gif
```

All these views were rendered in parallel, using JAX as the only runtime dependency:
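As a side note on the parsing in the snippet above: three small regular expressions do all the tokenizing of the .obj file. A quick stand-alone check of how they behave (the face and vertex lines below are made-up sample data, not from `african_head.obj`):

```python
import re

# Same regexes as in the snippet above.
_float = re.compile(r"(-?\d+\.?\d*(?:e[+-]\d+)?)")
_integer = re.compile(r"\d+")
_one_vertex = re.compile(r"\d+/\d*/\d*")

# A triangle face line: three v/vt/vn index triplets.
line = "f 24/1/24 25/2/25 26/3/26"
triplets = _one_vertex.findall(line)
indices = [tuple(map(int, _integer.findall(t))) for t in triplets]
print(indices)  # [(24, 1, 24), (25, 2, 25), (26, 3, 26)]

# A vertex line; findall(line, 2) starts matching after the "v " prefix.
vline = "v 0.5 -1.25 3e+00"
coords = tuple(map(float, _float.findall(vline, 2)[:3]))
print(coords)  # (0.5, -1.25, 3.0)
```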
Thank you for your input. Godspeed on integrating Madrona.
Description:
I'm encountering an issue while using MuJoCo with JAX (mjx) for training a humanoid model in the Brax environment. The problem arises when attempting to render the environment state and retrieve camera images during training. The MuJoCo renderer does not seem to work properly when using mjx.
Problem Details:
When calling the `render` method within the custom `Humanoid` class, an error is thrown:

`jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[28]`

This error indicates that the conversion method was called on a traced array whose value depends on the argument `state.pipeline_state.q`. The current implementation of the MuJoCo renderer does not handle traced arrays when used with mjx in this context.
Importance of Camera Input in Reinforcement Learning:
Using camera input is crucial when training robots with reinforcement learning. In real-world scenarios, robots rely on visual information captured by their cameras to perceive and interact with the environment. By incorporating camera pixels as part of the observation space during training, the learned policies can be more robust and adaptable to real-world conditions.
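For context, the failure mode above is not specific to Brax's Humanoid or to MuJoCo: converting any traced array to a NumPy array inside `jax.jit` reproduces the same exception. A minimal sketch (the function here just mimics what a host-side renderer implicitly does to its inputs):

```python
import jax
import jax.numpy as jnp
import numpy as onp

@jax.jit
def bad_render_obs(q):
    # Inside jit, `q` is an abstract tracer with no concrete values, so any
    # attempt to turn it into a numpy.ndarray (as a CPU renderer would need)
    # raises TracerArrayConversionError.
    pixels = onp.asarray(q)
    return pixels.sum()

raised = False
try:
    bad_render_obs(jnp.zeros(28))  # same shape as the float32[28] in the report
except jax.errors.TracerArrayConversionError:
    raised = True
print(raised)  # True
```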
Proposed Solution:
To enable effective training with camera input using mjx, it is essential to address the compatibility issue between the MuJoCo renderer and JAX traced arrays. Possible solutions include:
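One possible direction (a sketch of a general JAX pattern, not of how mjx actually implements anything) is to route rendering through `jax.pure_callback`, which materializes the traced values on the host, runs an ordinary Python function, and feeds the result back into the traced computation with a declared shape and dtype. `host_render` below is a hypothetical stand-in for a real renderer call:

```python
import jax
import jax.numpy as jnp
import numpy as onp

H, W = 4, 4  # tiny image for illustration

def host_render(q):
    # Hypothetical stand-in for a host-side renderer: it receives concrete
    # values, so NumPy conversion is safe here. A real renderer call would
    # go in its place.
    q = onp.asarray(q)
    return onp.full((H, W, 3), q.sum(), dtype=onp.float32)

@jax.jit
def step_with_pixels(q):
    # pure_callback escapes the trace: JAX ships `q` to the host, runs
    # host_render, and reinserts the result with the declared shape/dtype.
    img = jax.pure_callback(
        host_render, jax.ShapeDtypeStruct((H, W, 3), jnp.float32), q)
    return img.mean()

print(step_with_pixels(jnp.ones(3)))  # 3.0
```

The tradeoff is a host round-trip per call, which forfeits most of the GPU throughput that makes mjx attractive in the first place; that cost is presumably why a native GPU renderer such as Madrona is being pursued instead. (For a stateful renderer, `jax.experimental.io_callback` would be the closer fit, since `pure_callback` assumes a pure function.)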
Alternatives Considered:
Additional Context:
Integrating camera input into reinforcement learning is crucial for developing robots that can perceive and act on visual information, and JAX with mjx can accelerate that training substantially. However, the current incompatibility between the MuJoCo renderer and JAX traced arrays prevents camera observations from being used effectively in this setup.
Addressing this issue and providing seamless integration between the MuJoCo renderer and mjx would greatly benefit the robotics and reinforcement learning community, enabling researchers to train policies that process visual information directly.
Thank you for considering this issue.