Skip to content

Commit

Permalink
Add contact force support function. #1555
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 632591177
Change-Id: I15aa00686d9f885330506d2d63443c700cc029f6
  • Loading branch information
btaba authored and Copybara-Service committed May 10, 2024
1 parent 65d4f04 commit c6b1293
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 4 deletions.
40 changes: 40 additions & 0 deletions mjx/mujoco/mjx/_src/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,43 @@ def name2id(
}

return names_map.get(name, -1)


def _decode_pyramid(
pyramid: jax.Array, mu: jax.Array, condim: int
) -> jax.Array:
"""Converts pyramid representation to contact force."""
force = jp.zeros(6, dtype=float)
if condim == 1:
return force.at[0].set(pyramid[0])

# force_normal = sum(pyramid0_i + pyramid1_i)
force = force.at[0].set(pyramid[0 : 2 * (condim - 1)].sum())

# force_tangent_i = (pyramid0_i - pyramid1_i) * mu_i
i = np.arange(0, condim)
force = force.at[i + 1].set((pyramid[2 * i] - pyramid[2 * i + 1]) * mu[i])

return force


def contact_force(
m: Model, d: Data, contact_id: int, to_world_frame: bool = False
) -> jax.Array:
"""Extract 6D force:torque for one contact, in contact frame by default."""
efc_address = d.contact.efc_address[contact_id]
condim = d.contact.dim[contact_id]
if m.opt.cone == mujoco.mjtCone.mjCONE_PYRAMIDAL:
force = _decode_pyramid(
d.efc_force[efc_address:], d.contact.friction[contact_id], condim
)
elif m.opt.cone == mujoco.mjtCone.mjCONE_ELLIPTIC:
raise NotImplementedError('Elliptic cone force is not implemented yet.')
else:
raise ValueError(f'Unknown cone type: {m.opt.cone}')

if to_world_frame:
force = force.reshape((-1, 3)) @ d.contact.frame[contact_id]
force = force.reshape(-1)

return force * (efc_address >= 0)
59 changes: 59 additions & 0 deletions mjx/mujoco/mjx/_src/support_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,65 @@ def test_names_and_ids(self):
i = i if n is not None else -1
self.assertEqual(support.name2id(mx, obj, n), i)

_CONTACTS = """
<mujoco>
<worldbody>
<body pos="0 0 0.55" euler="1 0 0">
<joint axis="1 0 0" type="free"/>
<geom fromto="-0.4 0 0 0.4 0 0" size="0.05" type="capsule" condim="6"/>
</body>
<body pos="0 0 0.5" euler="0 1 0">
<joint axis="1 0 0" type="free"/>
<geom fromto="-0.4 0 0 0.4 0 0" size="0.05" type="capsule" condim="3"/>
</body>
<body pos="0 0 0.445" euler="0 90 0">
<joint axis="1 0 0" type="free"/>
<geom fromto="-0.4 0 0 0.4 0 0" size="0.05" type="capsule" condim="1"/>
</body>
</worldbody>
</mujoco>
"""

def test_contact_force(self):
m = mujoco.MjModel.from_xml_string(self._CONTACTS)
d = mujoco.MjData(m)
mujoco.mj_step(m, d)
assert (
np.unique(d.contact.geom).shape[0] == 3
), 'This test assumes all capsule are in contact.'
mx = mjx.put_model(m)
dx = mjx.put_data(m, d)
mujoco.mj_step(m, d)
dx = mjx.step(mx, dx)

# map MJX contacts to MJ ones
def _find(g):
val = (g == dx.contact.geom).sum(axis=1)
return np.where(val == 2)[0][0]

contact_id_map = {i: _find(d.contact.geom[i]) for i in range(d.ncon)}

for i in range(d.ncon):
result = np.zeros(6, dtype=float)
mujoco.mj_contactForce(m, d, i, result)

j = contact_id_map[i]
force = jax.jit(support.contact_force, static_argnums=(2,))(mx, dx, j)
np.testing.assert_allclose(result, force, rtol=1e-5, atol=2)

# test world conversion
force = jax.jit(
support.contact_force,
static_argnums=(
2,
3,
),
)(mx, dx, j, True)
# back to contact frame
force = force.at[:3].set(dx.contact.frame[j] @ force[:3])
force = force.at[3:].set(dx.contact.frame[j] @ force[3:])
np.testing.assert_allclose(result, force, rtol=1e-5, atol=2)


if __name__ == '__main__':
absltest.main()
4 changes: 0 additions & 4 deletions mjx/mujoco/mjx/test_data/convex.xml
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
<mujoco model="convex_convex">
<custom>
<numeric data="0.2" name="baumgarte_erp"/>
<numeric data="0.5" name="elasticity"/>
</custom>
<default>
<geom friction="0.5 0.0 0.0"/>
</default>
Expand Down

0 comments on commit c6b1293

Please sign in to comment.