From 5d51b4cd5fb4a8edc8f045baf44ae647e2b45542 Mon Sep 17 00:00:00 2001 From: Alessio Quaglino Date: Fri, 12 Jul 2024 09:45:18 -0700 Subject: [PATCH] Check that element is a nullptr before casting. The bug was caused by the fact that the `mjs_as*` functions would return a pointer to a `mjC*` object, even if the input `mjsElement` pointer was null. This would cause a segfault. PiperOrigin-RevId: 651802736 Change-Id: Ibec00d269a4befadff8a68e0513fd381c7e38a2a --- python/mujoco/specs_test.py | 16 +++++ src/user/user_api.cc | 115 ++++++++++++++++++++++++++++-------- 2 files changed, 108 insertions(+), 23 deletions(-) diff --git a/python/mujoco/specs_test.py b/python/mujoco/specs_test.py index aae10380ce..0f36653c2c 100644 --- a/python/mujoco/specs_test.py +++ b/python/mujoco/specs_test.py @@ -234,5 +234,21 @@ def test_element_list(self): self.assertEqual(spec.sensors[1].name, 'sensor2') self.assertEqual(spec.sensors[2].name, 'sensor3') + def test_iterators(self): + spec = mujoco.MjSpec() + geom1 = spec.worldbody.add_geom() + geom2 = spec.worldbody.add_geom() + geom3 = spec.worldbody.add_geom() + geom1.name = 'geom1' + geom2.name = 'geom2' + geom3.name = 'geom3' + geom = spec.worldbody.first_geom() + i = 1 + while geom: + self.assertEqual(geom.name, 'geom' + str(i)) + geom = spec.worldbody.next_geom(geom) + i += 1 + + if __name__ == '__main__': absltest.main() diff --git a/src/user/user_api.cc b/src/user/user_api.cc index 62e0bd5e23..454e3199d5 100644 --- a/src/user/user_api.cc +++ b/src/user/user_api.cc @@ -636,161 +636,230 @@ mjsElement* mjs_nextElement(mjSpec* s, mjsElement* element) { // return body given mjsElement mjsBody* mjs_asBody(mjsElement* element) { - return element->elemtype == mjOBJ_BODY ? &(static_cast(element)->spec) : nullptr; + if (element && element->elemtype == mjOBJ_BODY) { + return &(static_cast(element)->spec); + } + return nullptr; } // return geom given mjsElement mjsGeom* mjs_asGeom(mjsElement* element) { - return element->elemtype == mjOBJ_GEOM ? &(static_cast(element)->spec) : nullptr; + if (element && element->elemtype == mjOBJ_GEOM) { + return &(static_cast(element)->spec); + } + return nullptr; } // return joint given mjsElement mjsJoint* mjs_asJoint(mjsElement* element) { - return element->elemtype == mjOBJ_JOINT ? &(static_cast(element)->spec) : nullptr; + if (element && element->elemtype == mjOBJ_JOINT) { + return &(static_cast(element)->spec); + } + return nullptr; } // Return site given mjsElement mjsSite* mjs_asSite(mjsElement* element) { - return element->elemtype == mjOBJ_SITE ? &(static_cast(element)->spec) : nullptr; + if (element && element->elemtype == mjOBJ_SITE) { + return &(static_cast(element)->spec); + } + return nullptr; } // return camera given mjsElement mjsCamera* mjs_asCamera(mjsElement* element) { - return element->elemtype == mjOBJ_CAMERA ? &(static_cast(element)->spec) : nullptr; + if (element && element->elemtype == mjOBJ_CAMERA) { + return &(static_cast(element)->spec); + } + return nullptr; } // return light given mjsElement mjsLight* mjs_asLight(mjsElement* element) { - return element->elemtype == mjOBJ_LIGHT ? &(static_cast(element)->spec) : nullptr; + if (element && element->elemtype == mjOBJ_LIGHT) { + return &(static_cast(element)->spec); + } + return nullptr; } // return frame given mjsElement mjsFrame* mjs_asFrame(mjsElement* element) { - return element->elemtype == mjOBJ_FRAME ? &(static_cast(element)->spec) : nullptr; + if (element && element->elemtype == mjOBJ_FRAME) { + return &(static_cast(element)->spec); + } + return nullptr; } // return actuator given mjsElement mjsActuator* mjs_asActuator(mjsElement* element) { - return element->elemtype == mjOBJ_ACTUATOR ? &(static_cast(element)->spec) : nullptr; + if (element && element->elemtype == mjOBJ_ACTUATOR) { + return &(static_cast(element)->spec); + } + return nullptr; } // return sensor given mjsElement mjsSensor* mjs_asSensor(mjsElement* element) { - return element->elemtype == mjOBJ_SENSOR ? &(static_cast(element)->spec) : nullptr; + if (element && element->elemtype == mjOBJ_SENSOR) { + return &(static_cast(element)->spec); + } + return nullptr; } // return flex given mjsElement mjsFlex* mjs_asFlex(mjsElement* element) { - return element->elemtype == mjOBJ_FLEX ? &(static_cast(element)->spec) : nullptr; + if (element && element->elemtype == mjOBJ_FLEX) { + return &(static_cast(element)->spec); + } + return nullptr; } // return pair given mjsElement mjsPair* mjs_asPair(mjsElement* element) { - return element->elemtype == mjOBJ_PAIR ? &(static_cast(element)->spec) : nullptr; + if (element && element->elemtype == mjOBJ_PAIR) { + return &(static_cast(element)->spec); + } + return nullptr; } // return equality given mjsElement mjsEquality* mjs_asEquality(mjsElement* element) { - return element->elemtype == mjOBJ_EQUALITY ? &(static_cast(element)->spec) : nullptr; + if (element && element->elemtype == mjOBJ_EQUALITY) { + return &(static_cast(element)->spec); + } + return nullptr; } // return exclude given mjsElement mjsExclude* mjs_asExclude(mjsElement* element) { - return element->elemtype == mjOBJ_EXCLUDE ? &(static_cast(element)->spec) : nullptr; + if (element && element->elemtype == mjOBJ_EXCLUDE) { + return &(static_cast(element)->spec); + } + return nullptr; } // return tendon given mjsElement mjsTendon* mjs_asTendon(mjsElement* element) { - return element->elemtype == mjOBJ_TENDON ? &(static_cast(element)->spec) : nullptr; + if (element && element->elemtype == mjOBJ_TENDON) { + return &(static_cast(element)->spec); + } + return nullptr; } // return numeric given mjsElement mjsNumeric* mjs_asNumeric(mjsElement* element) { - return element->elemtype == mjOBJ_NUMERIC ? &(static_cast(element)->spec) : nullptr; + if (element && element->elemtype == mjOBJ_NUMERIC) { + return &(static_cast(element)->spec); + } + return nullptr; } // return text given mjsElement mjsText* mjs_asText(mjsElement* element) { - return element->elemtype == mjOBJ_TEXT ? &(static_cast(element)->spec) : nullptr; + if (element && element->elemtype == mjOBJ_TEXT) { + return &(static_cast(element)->spec); + } + return nullptr; } // return tuple given mjsElement mjsTuple* mjs_asTuple(mjsElement* element) { - return element->elemtype == mjOBJ_TUPLE ? &(static_cast(element)->spec) : nullptr; + if (element && element->elemtype == mjOBJ_TUPLE) { + return &(static_cast(element)->spec); + } + return nullptr; } // return key given mjsElement mjsKey* mjs_asKey(mjsElement* element) { - return element->elemtype == mjOBJ_KEY ? &(static_cast(element)->spec) : nullptr; + if (element && element->elemtype == mjOBJ_KEY) { + return &(static_cast(element)->spec); + } + return nullptr; } // return mesh given mjsElement mjsMesh* mjs_asMesh(mjsElement* element) { - return element->elemtype == mjOBJ_MESH ? &(static_cast(element)->spec) : nullptr; + if (element && element->elemtype == mjOBJ_MESH) { + return &(static_cast(element)->spec); + } + return nullptr; } // return hfield given mjsElement mjsHField* mjs_asHField(mjsElement* element) { - return element->elemtype == mjOBJ_HFIELD ? &(static_cast(element)->spec) : nullptr; + if (element && element->elemtype == mjOBJ_HFIELD) { + return &(static_cast(element)->spec); + } + return nullptr; } // return skin given mjsElement mjsSkin* mjs_asSkin(mjsElement* element) { - return element->elemtype == mjOBJ_SKIN ? &(static_cast(element)->spec) : nullptr; + if (element && element->elemtype == mjOBJ_SKIN) { + return &(static_cast(element)->spec); + } + return nullptr; } // return texture given mjsElement mjsTexture* mjs_asTexture(mjsElement* element) { - return element->elemtype == mjOBJ_TEXTURE ? &(static_cast(element)->spec) : nullptr; + if (element && element->elemtype == mjOBJ_TEXTURE) { + return &(static_cast(element)->spec); + } + return nullptr; } // return material given mjsElement mjsMaterial* mjs_asMaterial(mjsElement* element) { - return element->elemtype == mjOBJ_MATERIAL ? &(static_cast(element)->spec) : nullptr; + if (element && element->elemtype == mjOBJ_MATERIAL) { + return &(static_cast(element)->spec); + } + return nullptr; }