# MuJoCo to Pytorch3D

MuJoCo XLA provides a diferentiable framework for differentiable physics. However, MuJoCo rasterizer is not differentiable, so it is not possible to supervise our learning process on 2D images. In this notebook I write a Dataparser to convert a MuJoCo model into a structure that Pytorch3D can process. I show how this can be used for optimizing a physics simulation based on 2D supervision. 

In [1]:
import mujoco
from mujoco import mjx
import mediapy as media
import jax
import os
import jax.numpy as jnp
import sys
from PIL import Image
import numpy as np
import trimesh
from pytorch3d.structures import Meshes
import torch
from pytorch3d.transforms import matrix_to_quaternion, Transform3d
import torch.nn.functional as F
from pytorch3d.renderer import (
    FoVPerspectiveCameras, look_at_view_transform,
    RasterizationSettings, BlendParams,
    MeshRenderer, MeshRasterizer, HardPhongShader, SoftPhongShader, TexturesVertex
)
from pytorch3d.structures.meshes import Meshes
%env MUJOCO_GL=egl

from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

env: MUJOCO_GL=egl
gpu


Define the Model

In [2]:
target_position = [10, 0, 1]

model_XML=f"""
<mujoco>
  <option iterations="100" solver="Newton"/>
  
  <asset>
    <texture type="skybox" builtin="gradient" rgb1=".3 .5 .7" rgb2="1 1 1" width="32" height="512"/>
    <texture name="grid" type="2d" builtin="checker" width="512" height="512" rgb1=".1 .2 .3" rgb2=".2 .3 .4"/>
    <material name="grid" texture="grid" texrepeat="2 2" texuniform="true" reflectance=".2"/>
  </asset>

  <visual>
    <global offheight="720" offwidth="1280"/>
    <quality offsamples="8"/>
  </visual>

  <asset>
  
    <mesh name="torus_complete" file="/home/mauro/Documents/Projects/JAX/assets/torus.obj" scale="2 2 2"/>
    
    <mesh name="torus0" file="/home/mauro/Documents/Projects/JAX/assets/obj/torus-dec/torus-dec_collision_0.obj" scale="2 2 2"/>
    <mesh name="torus1" file="/home/mauro/Documents/Projects/JAX/assets/obj/torus-dec/torus-dec_collision_1.obj" scale="2 2 2"/>
    <mesh name="torus2" file="/home/mauro/Documents/Projects/JAX/assets/obj/torus-dec/torus-dec_collision_2.obj" scale="2 2 2"/>
    <mesh name="torus3" file="/home/mauro/Documents/Projects/JAX/assets/obj/torus-dec/torus-dec_collision_3.obj" scale="2 2 2"/>
    <mesh name="torus4" file="/home/mauro/Documents/Projects/JAX/assets/obj/torus-dec/torus-dec_collision_4.obj" scale="2 2 2"/>
    <mesh name="torus5" file="/home/mauro/Documents/Projects/JAX/assets/obj/torus-dec/torus-dec_collision_5.obj" scale="2 2 2"/>
    <mesh name="torus6" file="/home/mauro/Documents/Projects/JAX/assets/obj/torus-dec/torus-dec_collision_6.obj" scale="2 2 2"/>
    <mesh name="torus7" file="/home/mauro/Documents/Projects/JAX/assets/obj/torus-dec/torus-dec_collision_7.obj" scale="2 2 2"/>
    <mesh name="torus8" file="/home/mauro/Documents/Projects/JAX/assets/obj/torus-dec/torus-dec_collision_8.obj" scale="2 2 2"/>
    <mesh name="torus9" file="/home/mauro/Documents/Projects/JAX/assets/obj/torus-dec/torus-dec_collision_9.obj" scale="2 2 2"/>
    <mesh name="torus10" file="/home/mauro/Documents/Projects/JAX/assets/obj/torus-dec/torus-dec_collision_10.obj" scale="2 2 2"/>
    
    <mesh name="sphere" file="/home/mauro/Documents/Projects/JAX/assets/obj/sphere.obj" scale="0.2 0.2 0.2"/>
    
  </asset>

  <worldbody>
    <light name="top" pos="1 0 0"/>
    
    <camera name="cam1" pos="5 -15 2.5" xyaxes="1 0 0 0 0 1"/>
    <camera name="cam2" pos="10 0 15" xyaxes="1 0 0 0 1 0"/>
    
    
    <body>
      <freejoint/>
      <geom mass="5000" type="mesh" mesh="sphere" rgba="1 0 0 1"/>
    </body>

    <body pos="{target_position[0]} {target_position[1]} {target_position[2]}">
        
        <geom type="mesh" mesh="torus_complete" contype="0" conaffinity="0"/>
        <!-- Keeping the convex hulls for collision detection -->
        <geom type="mesh" mesh="torus0" conaffinity="1" contype="1" rgba="0 0 0 0" />
        <geom type="mesh" mesh="torus1" conaffinity="1" contype="1" rgba="0 0 0 0" />
        <geom type="mesh" mesh="torus2" conaffinity="1" contype="1" rgba="0 0 0 0" />
        <geom type="mesh" mesh="torus3" conaffinity="1" contype="1" rgba="0 0 0 0" />
        <geom type="mesh" mesh="torus4" conaffinity="1" contype="1" rgba="0 0 0 0" />
        <geom type="mesh" mesh="torus5" conaffinity="1" contype="1" rgba="0 0 0 0" />
        <geom type="mesh" mesh="torus6" conaffinity="1" contype="1" rgba="0 0 0 0" />
        <geom type="mesh" mesh="torus7" conaffinity="1" contype="1" rgba="0 0 0 0" />
        <geom type="mesh" mesh="torus8" conaffinity="1" contype="1" rgba="0 0 0 0" />
        <geom type="mesh" mesh="torus9" conaffinity="1" contype="1" rgba="0 0 0 0" />
        <geom type="mesh" mesh="torus10" conaffinity="1" contype="1" rgba="0 0 0 0" />
    </body>

  </worldbody>
  
</mujoco>
"""

Run the simulation to collect GT data

In [3]:


########################## Initial params ##################################
vel = jnp.array([6., 0., 10.])
timesteps = 1500
fps = 48
width = 1280
height = 720
############################################################################

# print(os.environ)
# Instantiate models and data
mj_model = mujoco.MjModel.from_xml_string(model_XML)
mjx_model = mjx.put_model(mj_model)

def init_simulation(vel, mjx_model):
    mjx_data = mjx.make_data(mjx_model)
    qvel = mjx_data.qvel.at[0].set(vel[0]).at[1].set(vel[1]).at[2].set(vel[2])
    mjx_data = mjx_data.replace(qvel=qvel)
    return mjx_data

jit_step = jax.jit(mjx.step)

def sim(i, mjx_data):
    mjx_data = jit_step(mjx_model, mjx_data)
    return mjx_data

def run_simulation(mj_model, mjx_model, mjx_data):
    
    render_interval = int((1 / mj_model.opt.timestep) / fps)
    print(f"Render interval: {render_interval}")

    frames = {'cam1': [], 'cam2': []}
    
    with mujoco.Renderer(mj_model, height=height, width=width) as renderer:
        for t in range(timesteps):
            mjx_data = sim(t, mjx_data)
            # MuJoCo render
            if t % render_interval == 0:
                mj_data = mjx.get_data(mj_model, mjx_data)
                renderer.update_scene(mj_data, camera="cam1")
                pixels = renderer.render()   
                frames['cam1'].append(pixels)

                renderer.update_scene(mj_data, camera="cam2")
                pixels = renderer.render()   
                frames['cam2'].append(pixels)
                # im = Image.fromarray(pixels)
                # im.show()
    return frames
    
mjx_data = init_simulation(vel, mjx_model)
frames=run_simulation(mj_model, mjx_model, mjx_data) 



# Save videos
# media.write_video('/home/mauro/Downloads/cam1.mp4', frames['cam1'], fps=fps)
# media.write_video('/home/mauro/Downloads/cam2.mp4', frames['cam2'], fps=fps)

# gt_image = frames['cam1'][-1]
# im = Image.fromarray(gt_image)
# im.show()
# media.show_image(frames['cam1'][-1])




Render interval: 10


  x = np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


In [4]:
media.show_video(frames['cam1'], fps=fps)
media.show_video(frames['cam2'], fps=fps)

0
This browser does not support the video tag.


0
This browser does not support the video tag.


In [None]:
class DataParser:
    """data parsing from MuJoCo XLA to Pytorch3D"""
    def __init__(self, mjx_model, mjx_data):
        self.model = mjx_model
        self.data = mjx_data

    def parse(self):
        
        
        
    
        

In [114]:
renderer = mujoco.Renderer(mj_model, height=height, width=width)
mj_data = mjx.get_data(mj_model, mjx_data)
renderer.update_scene(mj_data, camera='cam1')

print(renderer)

<mujoco.renderer.Renderer object at 0x7b468d66d670>


In [24]:
def pad_tensor(values_all, element, max_values=None):
    """

    Args:
        values_all (list): it contains
    """
    assert element == 'verts' or element == 'faces', 'Element must be verts or faces'
    pad_value = 0 if element == 'verts' else -1
    if not max_values:
        len_values = [verts.shape[0] for values in values_all]
        max_values = max(len_values)
    
    padded_values = []
    for values in values_all:
        padding = (0, 0, 0, max_values - values.shape[0])   # For a 2D tensor, pad=(pad_left, pad_right, pad_top, pad_bottom)
        padded_values.append(F.pad(values, padding, "constant", pad_value))
    
    # Stack all padded tensors into a single tensor
    padded_values_tensor = torch.stack(padded_values)
    
    return padded_values_tensor
    
def jnp_to_torch(jnp_array):
    return torch.from_numpy(jnp_array.__array__().copy())
    
MESH_IDX = 7

verts_all = []
faces_all = []
# Loop through all geometries
for geom_idx in range(mjx_model.ngeom):

    if mjx_model.geom_type[geom_idx] == MESH_IDX:

        verts = mj_model.mesh_vert[mjx_model.mesh_vertadr[geom_idx] : mjx_model.mesh_vertadr[geom_idx] + mjx_model.mesh_vertnum[geom_idx], :]
        verts = jnp_to_torch(verts)
        
        faces = mj_model.mesh_face[mjx_model.mesh_faceadr[geom_idx] : mjx_model.mesh_faceadr[geom_idx] + mj_model.mesh_facenum[geom_idx], :]
        # faces = faces - faces.min()
        faces = jnp_to_torch(faces)
        
        # Get the world-space position of this geometry at the current timestep
        geom_pos = mjx_data.geom_xpos[geom_idx]  # (x, y, z) position for this geometry
        geom_pos = jnp_to_torch(geom_pos)
    
        # Get the world-space position of this geometry at the current timestep
        geom_mat = mjx_data.geom_xmat[geom_idx]  # (x, y, z) position for this geometry
        geom_mat = jnp_to_torch(geom_mat)
           
        # Now you have both the position and orientation (as a quaternion) for this geometry
        print(f"Geometry {geom_idx}: Position {geom_pos}, Rotation {geom_mat}")
    
        # Apply the translation and rotation transformation to your mesh or geometry
        transform = Transform3d().translate(*geom_pos).rotate(geom_mat)
        
        # If rendering a mesh, apply the transformation to its vertices
        # transformed_verts = transform.transform_points(verts)

        # verts_all.extend(verts)
        # faces_all.extend(faces)  
        verts_all.append(verts)
        faces_all.append(faces) 

# Initialize an OpenGL perspective camera.
R, T = look_at_view_transform(dist=2, elev=10, azim=80)
cameras = FoVPerspectiveCameras(device='cuda', R=R, T=T)

# Define the settings for rasterization and shading. Here we set the output image to be of size
# 512x512. As we are rendering images for visualization purposes only we will set faces_per_pixel=1
# and blur_radius=0.0. Refer to rasterize_meshes.py for explanations of these parameters.
raster_settings = RasterizationSettings(
    image_size=512,
    blur_radius=0.0,
    faces_per_pixel=1,
)

# Create a Phong renderer by composing a rasterizer and a shader. Here we can use a predefined
# PhongShader, passing in the device on which to initialize the default parameters
renderer = MeshRenderer(
    rasterizer=MeshRasterizer(cameras=cameras,raster_settings=raster_settings),
    shader=HardPhongShader(device='cuda', cameras=cameras)
)

verts_tensor = pad_tensor(verts_all, 'verts')
faces_tensor = pad_tensor(faces_all, 'faces')

# verts_tensor = torch.stack(verts_all)[None]
# faces_tensor = torch.stack(faces_all)[None]

verts_rgb = torch.ones_like(verts_tensor).cuda()
textures = TexturesVertex(verts_features=verts_rgb)
mesh = Meshes(verts=verts_tensor.cuda(), faces=faces_tensor.cuda(), textures=textures)
output = renderer(mesh)

image = np.array(output.cpu()[12]*255, dtype=np.uint8)
Image.fromarray(image).show()

# Render or visualize the transformed mesh as needed
# for verts, faces in zip(verts_all, faces_all):
#     verts_rgb = torch.ones_like(verts)[None].cuda() # (1, V, 3)
#     textures = TexturesVertex(verts_features=verts_rgb)
#     mesh = Meshes(verts=verts[None].cuda(), faces=faces[None].cuda(), textures=textures)
#     output = renderer(mesh)

#     image = np.array(output.cpu()[0]*255, dtype=np.uint8)
#     Image.fromarray(image)
    

AttributeError: 'Model' object has no attribute 'mesh_facenum'

In [31]:
mjx_data.geom_xpos

Array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)

In [124]:
print(mjx_model.mesh_vertadr)
print(mjx_model.mesh_vertnum)
print(mjx_model.mesh_vert)
print(mjx_model.geom_dataid)
print(mjx_model.geom_type)
print(mjx_model.mesh_face)
print(mjx_model.mesh_faceadr)
print(mj_model.mesh_facenum)

[   0  576  627  696  780  865  957 1003 1047 1096 1141 1196 1281]
[576  51  69  84  85  92  46  44  49  45  55  85 162]
[[-2.8272648e-10 -9.2864203e-01 -2.3211255e+00]
 [ 2.5000000e-01 -9.0375888e-01 -2.2589304e+00]
 [ 4.3301201e-01 -8.3577782e-01 -2.0890131e+00]
 ...
 [-2.1652332e-01 -1.6377217e-01 -4.3069282e-01]
 [-3.5036230e-01 -1.0314891e-01 -3.5469219e-01]
 [-2.9920775e-01 -2.5700620e-01 -3.2190454e-01]]
[12  0  1  2  3  4  5  6  7  8  9 10 11]
[7 7 7 7 7 7 7 7 7 7 7 7 7]
[[  0  12  13]
 [  0  13   1]
 [  1  13  14]
 ...
 [ 13 159 160]
 [ 14 160 161]
 [ 12 161 159]]
[   0 1152 1250 1384 1548 1714 1894 1982 2066 2160 2246 2352 2518]
[1152   98  134  164  166  180   88   84   94   86  106  166  320]
