/
smooth_test.py
165 lines (147 loc) · 5.42 KB
/
smooth_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for smooth dynamics functions."""
from absl.testing import absltest
import jax
import mujoco
from mujoco import mjx
from mujoco.mjx._src import test_util
import numpy as np
# tolerance for difference between MuJoCo and MJX smooth calculations - mostly
# due to float precision
_TOLERANCE = 5e-5
def _assert_eq(a, b, name):
tol = _TOLERANCE * 10 # avoid test noise
err_msg = f'mismatch: {name}'
np.testing.assert_allclose(a, b, err_msg=err_msg, atol=tol, rtol=tol)
def _assert_attr_eq(a, b, attr):
_assert_eq(getattr(a, attr), getattr(b, attr), attr)
class SmoothTest(absltest.TestCase):
def setUp(self):
super().setUp()
# although we already have generous padding of thresholds, it doesn't hurt
# to also fix the seed to reduce test flakiness
np.random.seed(0)
def test_smooth(self):
"""Tests MJX smooth functions match MuJoCo smooth functions."""
m = test_util.load_test_file('pendula.xml')
# # force MJX sparse for testing:
m.opt.jacobian = mujoco.mjtJacobian.mjJAC_SPARSE
d = mujoco.MjData(m)
# give the system a little kick to ensure we have non-identity rotations
d.qvel = np.random.random(m.nv)
mujoco.mj_step(m, d, 10) # let dynamics get state significantly non-zero
mujoco.mj_forward(m, d)
mx = mjx.put_model(m)
# kinematics
dx = jax.jit(mjx.kinematics)(mx, mjx.put_data(m, d))
_assert_attr_eq(d, dx, 'xanchor')
_assert_attr_eq(d, dx, 'xaxis')
_assert_attr_eq(d, dx, 'xpos')
_assert_attr_eq(d, dx, 'xquat')
_assert_eq(d.xmat.reshape((-1, 3, 3)), dx.xmat, 'xmat')
_assert_attr_eq(d, dx, 'xipos')
_assert_eq(d.ximat.reshape((-1, 3, 3)), dx.ximat, 'ximat')
_assert_attr_eq(d, dx, 'geom_xpos')
_assert_eq(d.geom_xmat.reshape((-1, 3, 3)), dx.geom_xmat, 'geom_xmat')
_assert_attr_eq(d, dx, 'site_xpos')
_assert_eq(d.site_xmat.reshape((-1, 3, 3)), dx.site_xmat, 'site_xmat')
# com_pos
dx = jax.jit(mjx.com_pos)(mx, mjx.put_data(m, d))
_assert_attr_eq(d, dx, 'subtree_com')
_assert_attr_eq(d, dx, 'cinert')
_assert_attr_eq(d, dx, 'cdof')
# crb
dx = jax.jit(mjx.crb)(mx, mjx.put_data(m, d))
_assert_attr_eq(d, dx, 'crb')
_assert_attr_eq(d, dx, 'qM')
# factor_m
dx = jax.jit(mjx.factor_m)(mx, mjx.put_data(m, d))
_assert_attr_eq(d, dx, 'qLD')
_assert_attr_eq(d, dx, 'qLDiagInv')
# com_vel
dx = jax.jit(mjx.com_vel)(mx, mjx.put_data(m, d))
_assert_attr_eq(d, dx, 'cvel')
_assert_attr_eq(d, dx, 'cdof_dot')
# rne
dx = jax.jit(mjx.rne)(mx, mjx.put_data(m, d))
_assert_attr_eq(d, dx, 'qfrc_bias')
# transmission
dx = jax.jit(mjx.transmission)(mx, mjx.put_data(m, d))
_assert_attr_eq(d, dx, 'actuator_length')
_assert_attr_eq(d, dx, 'actuator_moment')
# camlight
dx = jax.jit(mjx.camlight)(mx, mjx.put_data(m, d))
_assert_attr_eq(d, dx, 'cam_xpos')
_assert_eq(d.cam_xmat.reshape((-1, 3, 3)), dx.cam_xmat, 'cam_xmat')
def test_disable_gravity(self):
m = mujoco.MjModel.from_xml_string("""
<mujoco>
<option>
<flag gravity="disable"/>
</option>
<worldbody>
<body>
<joint type="free"/>
<geom size="0.1"/>
</body>
</worldbody>
</mujoco>
""")
d = mujoco.MjData(m)
mujoco.mj_forward(m, d)
mx = mjx.put_model(m)
dx = mjx.put_data(m, d)
dx = jax.jit(mjx.rne)(mx, dx)
np.testing.assert_allclose(dx.qfrc_bias, 0)
def test_site_transmission(self):
m = mujoco.MjModel.from_xml_string("""
<mujoco>
<compiler autolimits="true"/>
<worldbody>
<body>
<joint type="free"/>
<geom type="box" size=".05 .05 .05" mass="1"/>
<site name="site1"/>
<body>
<joint type="hinge"/>
<geom size="0.1" mass="1"/>
<site name="site2" pos="0.1 0.2 0.3"/>
</body>
</body>
<body pos="1 0 0">
<joint name="slide" type="hinge"/>
<geom type="box" size=".05 .05 .05" mass="1"/>
</body>
</worldbody>
<actuator>
<position site="site1" kv="0.1" gear="1 2 3 0 0 0"/>
<position site="site1" kv="0.2" gear="0 0 0 1 2 3"/>
<position site="site2" kv="0.3" gear="0 3 0 0 0 1"/>
<position joint="slide" kv="0.05" />
<position site="site2" refsite="site1" gear="1 2 3 0.5 0.4 0.6"/>
</actuator>
</mujoco>
""")
d = mujoco.MjData(m)
mujoco.mj_forward(m, d)
mx = mjx.put_model(m)
dx = mjx.put_data(m, d)
mujoco.mj_transmission(m, d)
dx = jax.jit(mjx.transmission)(mx, dx)
_assert_attr_eq(d, dx, 'actuator_length')
_assert_attr_eq(d, dx, 'actuator_moment')
if __name__ == '__main__':
absltest.main()