MJX False Collision in 3.1.4 and 3.1.5 only (no collision in MuJoCo) #1695
This appears to be present only in 3.1.4 and 3.1.5 -- if I explicitly specify version 3.1.3 when … Hopefully this helps narrow down the search space for the bug (especially since it looks like some recent work was done to refactor collision checks and add condim support)! Also, the original example had …
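To make this version dependence explicit in a test suite, a small guard can record which releases this comment flags. This is a minimal sketch: `AFFECTED_VERSIONS`, `parse_version`, and `has_reported_false_collision` are hypothetical names (not part of MuJoCo or MJX), and only plain `X.Y.Z` release strings are handled.

```python
# Hypothetical helper: flag the releases this report identifies as showing
# the MJX false collision (3.1.4 and 3.1.5). Not part of any library.
AFFECTED_VERSIONS = ((3, 1, 4), (3, 1, 5))


def parse_version(version: str) -> tuple:
    """Turn a release string like '3.1.3' into a comparable tuple (3, 1, 3).

    Assumes the first three dot-separated parts are plain integers.
    """
    return tuple(int(part) for part in version.split(".")[:3])


def has_reported_false_collision(version: str) -> bool:
    """True if `version` is one of the releases named in this report."""
    return parse_version(version) in AFFECTED_VERSIONS


# Usage (requires mujoco installed):
#   import mujoco
#   if has_reported_false_collision(mujoco.__version__):
#       print("This version was reported to show the MJX false collision")
```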
Thanks @smurthas for the bug report; we have a fix that will be pushed out shortly!
Thanks! I rolled back to …

Video from MuJoCo (no issue): mujoco-10x-slowdown.mp4
Video from MJX (box jumps): mjx-10x-slowdown.mp4

Python test-case repro (same as the Colab code):

import copy

import mujoco
from mujoco import mjx
import jax
import mediapy
import numpy as np

block_only = """
<mujoco model="block_only">
<worldbody>
<camera name="side" pos="0. -.8 0.1" xyaxes="1 0 0 0 0 1" mode="trackcom"/>
<body name="starting_plane">
<geom type="box" size=".2 .2 .01" pos="0.0 0 -.02" rgba=".5 .8 .5 1"/>
</body>
<body name="box" pos="0 0 .03">
<freejoint/>
<geom type="box" size="0.04 0.04 0.04" rgba=".8 .8 .5 1"/>
</body>
</worldbody>
</mujoco>
"""


def test_block_only_env_mujoco_vs_mjx():
    # Init MuJoCo.
    mj_model = mujoco.MjModel.from_xml_string(block_only)
    mj_data = mujoco.MjData(mj_model)
    renderer = mujoco.Renderer(mj_model, width=640, height=480)
    mujoco.mj_resetData(mj_model, mj_data)

    # Init MJX from the same model and data.
    jit_step = jax.jit(mjx.step)
    mjx_model = mjx.put_model(mj_model)
    mjx_data = mjx.put_data(mj_model, mj_data)

    mj_datas = []
    mjx_mj_datas = []
    # Step forward to 0.25 seconds (125 steps at the default 0.002 s timestep).
    for _ in range(125):
        # Step MuJoCo. MjData is updated in place, so store a copy of each
        # state; appending mj_data itself would store 125 references to the
        # final state.
        mujoco.mj_step(mj_model, mj_data)
        mj_datas.append(copy.copy(mj_data))
        # Step MJX and fetch the result as a new MjData.
        mjx_data = jit_step(mjx_model, mjx_data)
        mjx_mj_data = mjx.get_data(mj_model, mjx_data)
        mjx_mj_datas.append(mjx_mj_data)

    # Render videos.
    mujoco_frames = []
    mjx_frames = []
    for mj_data, mjx_mj_data in zip(mj_datas, mjx_mj_datas):
        # Render MuJoCo.
        renderer.update_scene(mj_data, camera="side")
        mujoco_frames.append(renderer.render())
        # Render MJX.
        renderer.update_scene(mjx_mj_data, camera="side")
        mjx_frames.append(renderer.render())

    # Display videos at a 10x slowdown since the motion happens quickly
    # (500 steps per simulated second, divided by 10).
    mediapy.show_video(mujoco_frames, fps=500 / 10)
    mediapy.show_video(mjx_frames, fps=500 / 10)

    # Assert that the body positions match at every step.
    for mj_data, mjx_mj_data in zip(mj_datas, mjx_mj_datas):
        np.testing.assert_allclose(
            mj_data.xpos,
            mjx_mj_data.xpos,
            rtol=1e-3,
            atol=1e-3,
        )


test_block_only_env_mujoco_vs_mjx()

I realize you are not going to fix issues that are no longer present, but I figured an isolated repro might make a useful regression test: the changes that introduced the false-collision bug seem to have fixed a previous bug, so as you work on the false collision, this might help ensure you don't inadvertently reintroduce that other bug.
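Because `np.testing.assert_allclose` inside the loop stops at the first mismatch, a regression test may be more informative if it reports the step at which the two rollouts split apart. This is a minimal sketch, assuming body positions from each rollout are stacked into arrays of shape `(num_steps, nbody, 3)`; `first_divergent_step` is a hypothetical helper, not a MuJoCo or MJX API.

```python
import numpy as np


def first_divergent_step(mj_xpos, mjx_xpos, rtol=1e-3, atol=1e-3):
    """Return the index of the first step whose body positions differ, or None.

    Both arguments have shape (num_steps, nbody, 3). The per-element test is
    the same formula np.testing.assert_allclose uses:
    |actual - desired| <= atol + rtol * |desired|, with MJX as "desired".
    """
    mj_xpos = np.asarray(mj_xpos, dtype=float)
    mjx_xpos = np.asarray(mjx_xpos, dtype=float)
    if mj_xpos.shape != mjx_xpos.shape:
        raise ValueError("rollouts must have the same shape")
    # Element-wise closeness, then require every coordinate of every body
    # to match for a step to count as "in agreement".
    close = np.isclose(mj_xpos, mjx_xpos, rtol=rtol, atol=atol)
    step_ok = close.reshape(close.shape[0], -1).all(axis=1)
    bad_steps = np.flatnonzero(~step_ok)
    return int(bad_steps[0]) if bad_steps.size else None
```

In the repro, one way to collect the inputs (an assumption about how you would wire it in) is to append `mj_data.xpos.copy()` and `np.asarray(mjx_data.xpos)` after each step, then call `first_divergent_step(np.stack(...), np.stack(...))` and assert that the result is `None`.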
Nice, thank you for the clean repros, @smurthas. Looks like we're not regressing to this bug, but please let us know if you find any issues!
With the default settings of MJX and MuJoCo, this model and a basic rollout produce a false collision in MJX that is (correctly) absent when the same rollout is executed in MuJoCo directly.
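One way to detect this kind of false collision automatically is to record a per-step contact count from each simulator and look for steps where only MJX reports contacts. This is a minimal sketch: `false_contact_steps` is a hypothetical helper, and it assumes that `ncon` on the `MjData` returned by `mjx.get_data` counts only the contacts MJX treats as active. How MJX's fixed-size contact buffer is filtered during that conversion may differ between versions, so verify this assumption before relying on it.

```python
import numpy as np


def false_contact_steps(mj_ncon, mjx_ncon):
    """Return step indices where MJX reports contacts but MuJoCo reports none.

    Each argument is a sequence of per-step contact counts, e.g. `mj_data.ncon`
    recorded after every `mujoco.mj_step`, and `.ncon` of the MjData returned
    by `mjx.get_data` after every MJX step (see the caveat above).
    """
    mj = np.asarray(mj_ncon)
    mjx = np.asarray(mjx_ncon)
    if mj.shape != mjx.shape:
        raise ValueError("both rollouts must have the same number of steps")
    # A step is suspicious when MJX has contacts and MuJoCo has none.
    return np.flatnonzero((mjx > 0) & (mj == 0)).tolist()
```

An empty list means the two simulators agree on when contact occurs. A non-empty list points at the first steps worth inspecting with the renderer.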
Colab repro; the same code is also pasted below with the model XML inline.
Here is a video of what this looks like (I captured it by moving the body by hand in the MuJoCo GUI on Mac, but it looks about the same when rendered in Python with a camera and colors added):
bug.mov
Note that there is no visible collision -- the objects are well clear of each other.
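Since the reported contact appears while the geoms are visibly apart, a cross-check is to compute the geometric clearance directly. This is a minimal sketch, assuming both geoms are boxes that stay axis-aligned (true for the `block_only` model at its initial pose, but not for a rotated body in general). `box_gap` is a hypothetical helper. Centers come from each geom's world position, and half-extents come from the box `size` attribute (MuJoCo box sizes are half-sizes).

```python
import numpy as np


def box_gap(center_a, half_a, center_b, half_b):
    """Signed separation between two axis-aligned boxes.

    Positive: the boxes are clear, and the value is the Euclidean distance
    between their closest points. Zero or negative: they touch or overlap,
    and the value is minus the shallowest per-axis overlap (a penetration
    estimate).
    """
    delta = np.abs(np.asarray(center_a, float) - np.asarray(center_b, float))
    # Per-axis gap between the two boxes' extents along that axis.
    axis_gap = delta - (np.asarray(half_a, float) + np.asarray(half_b, float))
    if np.all(axis_gap <= 0):
        # Overlap on every axis: report the shallowest overlap as negative.
        return float(axis_gap.max())
    # Separated on at least one axis: only the positive gaps contribute.
    return float(np.linalg.norm(np.clip(axis_gap, 0.0, None)))


# block_only at its initial pose: plane top and box bottom both sit at z = -0.01.
plane = ((0.0, 0.0, -0.02), (0.2, 0.2, 0.01))
block = ((0.0, 0.0, 0.03), (0.04, 0.04, 0.04))
print(box_gap(*plane, *block))  # approximately 0: the faces touch
```

If a simulator reports a penetrating contact between two boxes while this value is clearly positive, that contact is a strong false-collision candidate.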
Here is the minimal repro (same as the colab):