Skip to content

Commit

Permalink
Check that element is a nullptr before casting.
Browse files Browse the repository at this point in the history
The bug was caused by the fact that the `mjs_as*` functions would return a pointer to a `mjC*` object, even if the input `mjsElement` pointer was null. This would cause a segfault.

PiperOrigin-RevId: 651802736
Change-Id: Ibec00d269a4befadff8a68e0513fd381c7e38a2a
  • Loading branch information
quagla authored and Copybara-Service committed Jul 12, 2024
1 parent e76f3ce commit 5d51b4c
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 23 deletions.
16 changes: 16 additions & 0 deletions python/mujoco/specs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,5 +234,21 @@ def test_element_list(self):
self.assertEqual(spec.sensors[1].name, 'sensor2')
self.assertEqual(spec.sensors[2].name, 'sensor3')

def test_iterators(self):
spec = mujoco.MjSpec()
geom1 = spec.worldbody.add_geom()
geom2 = spec.worldbody.add_geom()
geom3 = spec.worldbody.add_geom()
geom1.name = 'geom1'
geom2.name = 'geom2'
geom3.name = 'geom3'
geom = spec.worldbody.first_geom()
i = 1
while geom:
self.assertEqual(geom.name, 'geom' + str(i))
geom = spec.worldbody.next_geom(geom)
i += 1


if __name__ == '__main__':
absltest.main()
115 changes: 92 additions & 23 deletions src/user/user_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -636,161 +636,230 @@ mjsElement* mjs_nextElement(mjSpec* s, mjsElement* element) {

// return body given mjsElement
mjsBody* mjs_asBody(mjsElement* element) {
return element->elemtype == mjOBJ_BODY ? &(static_cast<mjCBody*>(element)->spec) : nullptr;
if (element && element->elemtype == mjOBJ_BODY) {
return &(static_cast<mjCBody*>(element)->spec);
}
return nullptr;
}



// return geom given mjsElement
mjsGeom* mjs_asGeom(mjsElement* element) {
return element->elemtype == mjOBJ_GEOM ? &(static_cast<mjCGeom*>(element)->spec) : nullptr;
if (element && element->elemtype == mjOBJ_GEOM) {
return &(static_cast<mjCGeom*>(element)->spec);
}
return nullptr;
}



// return joint given mjsElement
mjsJoint* mjs_asJoint(mjsElement* element) {
return element->elemtype == mjOBJ_JOINT ? &(static_cast<mjCJoint*>(element)->spec) : nullptr;
if (element && element->elemtype == mjOBJ_JOINT) {
return &(static_cast<mjCJoint*>(element)->spec);
}
return nullptr;
}



// Return site given mjsElement
mjsSite* mjs_asSite(mjsElement* element) {
return element->elemtype == mjOBJ_SITE ? &(static_cast<mjCSite*>(element)->spec) : nullptr;
if (element && element->elemtype == mjOBJ_SITE) {
return &(static_cast<mjCSite*>(element)->spec);
}
return nullptr;
}



// return camera given mjsElement
mjsCamera* mjs_asCamera(mjsElement* element) {
return element->elemtype == mjOBJ_CAMERA ? &(static_cast<mjCCamera*>(element)->spec) : nullptr;
if (element && element->elemtype == mjOBJ_CAMERA) {
return &(static_cast<mjCCamera*>(element)->spec);
}
return nullptr;
}



// return light given mjsElement
mjsLight* mjs_asLight(mjsElement* element) {
return element->elemtype == mjOBJ_LIGHT ? &(static_cast<mjCLight*>(element)->spec) : nullptr;
if (element && element->elemtype == mjOBJ_LIGHT) {
return &(static_cast<mjCLight*>(element)->spec);
}
return nullptr;
}



// return frame given mjsElement
mjsFrame* mjs_asFrame(mjsElement* element) {
return element->elemtype == mjOBJ_FRAME ? &(static_cast<mjCFrame*>(element)->spec) : nullptr;
if (element && element->elemtype == mjOBJ_FRAME) {
return &(static_cast<mjCFrame*>(element)->spec);
}
return nullptr;
}



// return actuator given mjsElement
mjsActuator* mjs_asActuator(mjsElement* element) {
return element->elemtype == mjOBJ_ACTUATOR ? &(static_cast<mjCActuator*>(element)->spec) : nullptr;
if (element && element->elemtype == mjOBJ_ACTUATOR) {
return &(static_cast<mjCActuator*>(element)->spec);
}
return nullptr;
}



// return sensor given mjsElement
mjsSensor* mjs_asSensor(mjsElement* element) {
return element->elemtype == mjOBJ_SENSOR ? &(static_cast<mjCSensor*>(element)->spec) : nullptr;
if (element && element->elemtype == mjOBJ_SENSOR) {
return &(static_cast<mjCSensor*>(element)->spec);
}
return nullptr;
}



// return flex given mjsElement
mjsFlex* mjs_asFlex(mjsElement* element) {
return element->elemtype == mjOBJ_FLEX ? &(static_cast<mjCFlex*>(element)->spec) : nullptr;
if (element && element->elemtype == mjOBJ_FLEX) {
return &(static_cast<mjCFlex*>(element)->spec);
}
return nullptr;
}



// return pair given mjsElement
mjsPair* mjs_asPair(mjsElement* element) {
return element->elemtype == mjOBJ_PAIR ? &(static_cast<mjCPair*>(element)->spec) : nullptr;
if (element && element->elemtype == mjOBJ_PAIR) {
return &(static_cast<mjCPair*>(element)->spec);
}
return nullptr;
}



// return equality given mjsElement
mjsEquality* mjs_asEquality(mjsElement* element) {
return element->elemtype == mjOBJ_EQUALITY ? &(static_cast<mjCEquality*>(element)->spec) : nullptr;
if (element && element->elemtype == mjOBJ_EQUALITY) {
return &(static_cast<mjCEquality*>(element)->spec);
}
return nullptr;
}



// return exclude given mjsElement
mjsExclude* mjs_asExclude(mjsElement* element) {
return element->elemtype == mjOBJ_EXCLUDE ? &(static_cast<mjCBodyPair*>(element)->spec) : nullptr;
if (element && element->elemtype == mjOBJ_EXCLUDE) {
return &(static_cast<mjCBodyPair*>(element)->spec);
}
return nullptr;
}



// return tendon given mjsElement
mjsTendon* mjs_asTendon(mjsElement* element) {
return element->elemtype == mjOBJ_TENDON ? &(static_cast<mjCTendon*>(element)->spec) : nullptr;
if (element && element->elemtype == mjOBJ_TENDON) {
return &(static_cast<mjCTendon*>(element)->spec);
}
return nullptr;
}



// return numeric given mjsElement
mjsNumeric* mjs_asNumeric(mjsElement* element) {
return element->elemtype == mjOBJ_NUMERIC ? &(static_cast<mjCNumeric*>(element)->spec) : nullptr;
if (element && element->elemtype == mjOBJ_NUMERIC) {
return &(static_cast<mjCNumeric*>(element)->spec);
}
return nullptr;
}



// return text given mjsElement
mjsText* mjs_asText(mjsElement* element) {
return element->elemtype == mjOBJ_TEXT ? &(static_cast<mjCText*>(element)->spec) : nullptr;
if (element && element->elemtype == mjOBJ_TEXT) {
return &(static_cast<mjCText*>(element)->spec);
}
return nullptr;
}



// return tuple given mjsElement
mjsTuple* mjs_asTuple(mjsElement* element) {
return element->elemtype == mjOBJ_TUPLE ? &(static_cast<mjCTuple*>(element)->spec) : nullptr;
if (element && element->elemtype == mjOBJ_TUPLE) {
return &(static_cast<mjCTuple*>(element)->spec);
}
return nullptr;
}



// return key given mjsElement
mjsKey* mjs_asKey(mjsElement* element) {
return element->elemtype == mjOBJ_KEY ? &(static_cast<mjCKey*>(element)->spec) : nullptr;
if (element && element->elemtype == mjOBJ_KEY) {
return &(static_cast<mjCKey*>(element)->spec);
}
return nullptr;
}



// return mesh given mjsElement
mjsMesh* mjs_asMesh(mjsElement* element) {
return element->elemtype == mjOBJ_MESH ? &(static_cast<mjCMesh*>(element)->spec) : nullptr;
if (element && element->elemtype == mjOBJ_MESH) {
return &(static_cast<mjCMesh*>(element)->spec);
}
return nullptr;
}



// return hfield given mjsElement
mjsHField* mjs_asHField(mjsElement* element) {
return element->elemtype == mjOBJ_HFIELD ? &(static_cast<mjCHField*>(element)->spec) : nullptr;
if (element && element->elemtype == mjOBJ_HFIELD) {
return &(static_cast<mjCHField*>(element)->spec);
}
return nullptr;
}



// return skin given mjsElement
mjsSkin* mjs_asSkin(mjsElement* element) {
return element->elemtype == mjOBJ_SKIN ? &(static_cast<mjCSkin*>(element)->spec) : nullptr;
if (element && element->elemtype == mjOBJ_SKIN) {
return &(static_cast<mjCSkin*>(element)->spec);
}
return nullptr;
}



// return texture given mjsElement
mjsTexture* mjs_asTexture(mjsElement* element) {
return element->elemtype == mjOBJ_TEXTURE ? &(static_cast<mjCTexture*>(element)->spec) : nullptr;
if (element && element->elemtype == mjOBJ_TEXTURE) {
return &(static_cast<mjCTexture*>(element)->spec);
}
return nullptr;
}



// return material given mjsElement
mjsMaterial* mjs_asMaterial(mjsElement* element) {
return element->elemtype == mjOBJ_MATERIAL ? &(static_cast<mjCMaterial*>(element)->spec) : nullptr;
if (element && element->elemtype == mjOBJ_MATERIAL) {
return &(static_cast<mjCMaterial*>(element)->spec);
}
return nullptr;
}


Expand Down

0 comments on commit 5d51b4c

Please sign in to comment.