From 9106f40e2bd70aea6e9c7bb2042a3132f6f88d2c Mon Sep 17 00:00:00 2001 From: Alessio Quaglino Date: Mon, 1 Jul 2024 12:04:52 -0700 Subject: [PATCH] Add mjSpec bindings. Co-authored-by: Saran Tunyasuvunakool PiperOrigin-RevId: 648444641 Change-Id: I08ee1d3f4fae1efbc7934ba94702b1c06d8cb92a --- doc/APIreference/functions.rst | 63 ++ doc/changelog.rst | 1 - doc/includes/references.h | 7 + include/mujoco/mujoco.h | 21 + introspect/codegen/generate_structs.py | 12 +- introspect/functions.py | 112 +++ python/MANIFEST.in | 2 +- python/make_sdist.sh | 2 + python/mujoco/CMakeLists.txt | 25 + python/mujoco/__init__.py | 1 + .../mujoco/codegen/generate_spec_bindings.py | 210 +++++ python/mujoco/raw.h | 30 + python/mujoco/specs.cc | 768 ++++++++++++++++++ python/mujoco/specs_test.py | 222 +++++ python/mujoco/structs.cc | 34 + python/mujoco/structs.h | 17 +- python/setup.py | 1 + src/user/user_api.cc | 49 ++ src/user/user_api.h | 21 + 19 files changed, 1588 insertions(+), 10 deletions(-) create mode 100644 python/mujoco/codegen/generate_spec_bindings.py create mode 100644 python/mujoco/specs.cc create mode 100644 python/mujoco/specs_test.py diff --git a/doc/APIreference/functions.rst b/doc/APIreference/functions.rst index 87a8d02b32..e90a46db79 100644 --- a/doc/APIreference/functions.rst +++ b/doc/APIreference/functions.rst @@ -4167,6 +4167,69 @@ mjs_nextChild Return body's next child of the same type; return NULL if child is last. +.. _mjs_asBody: + +mjs_asBody +~~~~~~~~~~ + +.. mujoco-include:: mjs_asBody + +Safely cast an element as mjsBody, or return NULL if the element is not an mjsBody. + +.. _mjs_asGeom: + +mjs_asGeom +~~~~~~~~~~ + +.. mujoco-include:: mjs_asGeom + +Safely cast an element as mjsGeom, or return NULL if the element is not an mjsGeom. + +.. _mjs_asJoint: + +mjs_asJoint +~~~~~~~~~~~ + +.. mujoco-include:: mjs_asJoint + +Safely cast an element as mjsJoint, or return NULL if the element is not an mjsJoint. + +.. _mjs_asSite: + +mjs_asSite +~~~~~~~~~~ + +.. mujoco-include:: mjs_asSite + +Safely cast an element as mjsSite, or return NULL if the element is not an mjsSite. + +.. _mjs_asCamera: + +mjs_asCamera +~~~~~~~~~~~~ + +.. mujoco-include:: mjs_asCamera + +Safely cast an element as mjsCamera, or return NULL if the element is not an mjsCamera. + +.. _mjs_asLight: + +mjs_asLight +~~~~~~~~~~~ + +.. mujoco-include:: mjs_asLight + +Safely cast an element as mjsLight, or return NULL if the element is not an mjsLight. + +.. _mjs_asFrame: + +mjs_asFrame +~~~~~~~~~~~ + +.. mujoco-include:: mjs_asFrame + +Safely cast an element as mjsFrame, or return NULL if the element is not an mjsFrame. + .. _AttributeSetters: Attribute setters diff --git a/doc/changelog.rst b/doc/changelog.rst index ae29a32343..541c3c2283 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -26,7 +26,6 @@ General Still missing: - Detailed documentation. - - Python bindings. .. youtube:: ZXBTEIDWHhs :align: right diff --git a/doc/includes/references.h b/doc/includes/references.h index 6cbecfd432..ba87e8fab1 100644 --- a/doc/includes/references.h +++ b/doc/includes/references.h @@ -3544,6 +3544,13 @@ mjsDefault* mjs_getSpecDefault(mjSpec* s); int mjs_getId(mjsElement* element); mjsElement* mjs_firstChild(mjsBody* body, mjtObj type); mjsElement* mjs_nextChild(mjsBody* body, mjsElement* child); +mjsBody* mjs_asBody(mjsElement* element); +mjsGeom* mjs_asGeom(mjsElement* element); +mjsJoint* mjs_asJoint(mjsElement* element); +mjsSite* mjs_asSite(mjsElement* element); +mjsCamera* mjs_asCamera(mjsElement* element); +mjsLight* mjs_asLight(mjsElement* element); +mjsFrame* mjs_asFrame(mjsElement* element); void mjs_setString(mjString* dest, const char* text); void mjs_setStringVec(mjStringVec* dest, const char* text); mjtByte mjs_setInStringVec(mjStringVec* dest, int i, const char* text); diff --git a/include/mujoco/mujoco.h b/include/mujoco/mujoco.h index 4a45bc7249..0964111029 100644 --- a/include/mujoco/mujoco.h +++ b/include/mujoco/mujoco.h @@ -1548,6 +1548,27 @@ MJAPI mjsElement* mjs_firstChild(mjsBody* body, mjtObj type); // Return body's next child of the same type; return NULL if child is last. MJAPI mjsElement* mjs_nextChild(mjsBody* body, mjsElement* child); +// Safely cast an element as mjsBody, or return NULL if the element is not an mjsBody. +MJAPI mjsBody* mjs_asBody(mjsElement* element); + +// Safely cast an element as mjsGeom, or return NULL if the element is not an mjsGeom. +MJAPI mjsGeom* mjs_asGeom(mjsElement* element); + +// Safely cast an element as mjsJoint, or return NULL if the element is not an mjsJoint. +MJAPI mjsJoint* mjs_asJoint(mjsElement* element); + +// Safely cast an element as mjsSite, or return NULL if the element is not an mjsSite. +MJAPI mjsSite* mjs_asSite(mjsElement* element); + +// Safely cast an element as mjsCamera, or return NULL if the element is not an mjsCamera. +MJAPI mjsCamera* mjs_asCamera(mjsElement* element); + +// Safely cast an element as mjsLight, or return NULL if the element is not an mjsLight. +MJAPI mjsLight* mjs_asLight(mjsElement* element); + +// Safely cast an element as mjsFrame, or return NULL if the element is not an mjsFrame. +MJAPI mjsFrame* mjs_asFrame(mjsElement* element); + //---------------------------------- Attribute setters --------------------------------------------- diff --git a/introspect/codegen/generate_structs.py b/introspect/codegen/generate_structs.py index 42a5dcd140..7e2038b563 100644 --- a/introspect/codegen/generate_structs.py +++ b/introspect/codegen/generate_structs.py @@ -166,9 +166,15 @@ def visit(self, node: ClangJsonNode) -> None: elif (node.get('kind') == 'TypedefDecl' and node['type']['qualType'].startswith('struct mj') and node['name'] not in _EXCLUDED): - struct = self._structs[node['type']['qualType']] - self._typedefs[node['name']] = ast_nodes.StructDecl( - name=node['name'], declname=struct.declname, fields=struct.fields) + declname = node['type']['qualType'] + try: + struct = self._structs[declname] + except KeyError: + self._typedefs[node['name']] = ast_nodes.StructDecl( + name=node['name'], declname=declname, fields=()) + else: + self._typedefs[node['name']] = ast_nodes.StructDecl( + name=node['name'], declname=struct.declname, fields=struct.fields) def resolve_all_anonymous(self) -> None: """Replaces anonymous struct placeholders with corresponding decl.""" diff --git a/introspect/functions.py b/introspect/functions.py index a4808e5318..0bd071c33f 100644 --- a/introspect/functions.py +++ b/introspect/functions.py @@ -9813,6 +9813,118 @@ ), doc="Return body's next child of the same type; return NULL if child is last.", # pylint: disable=line-too-long )), + ('mjs_asBody', + FunctionDecl( + name='mjs_asBody', + return_type=PointerType( + inner_type=ValueType(name='mjsBody'), + ), + parameters=( + FunctionParameterDecl( + name='element', + type=PointerType( + inner_type=ValueType(name='mjsElement'), + ), + ), + ), + doc='Safely cast an element as mjsBody, or return NULL if the element is not an mjsBody.', # pylint: disable=line-too-long + )), + ('mjs_asGeom', + FunctionDecl( + name='mjs_asGeom', + return_type=PointerType( + inner_type=ValueType(name='mjsGeom'), + ), + parameters=( + FunctionParameterDecl( + name='element', + type=PointerType( + inner_type=ValueType(name='mjsElement'), + ), + ), + ), + doc='Safely cast an element as mjsGeom, or return NULL if the element is not an mjsGeom.', # pylint: disable=line-too-long + )), + ('mjs_asJoint', + FunctionDecl( + name='mjs_asJoint', + return_type=PointerType( + inner_type=ValueType(name='mjsJoint'), + ), + parameters=( + FunctionParameterDecl( + name='element', + type=PointerType( + inner_type=ValueType(name='mjsElement'), + ), + ), + ), + doc='Safely cast an element as mjsJoint, or return NULL if the element is not an mjsJoint.', # pylint: disable=line-too-long + )), + ('mjs_asSite', + FunctionDecl( + name='mjs_asSite', + return_type=PointerType( + inner_type=ValueType(name='mjsSite'), + ), + parameters=( + FunctionParameterDecl( + name='element', + type=PointerType( + inner_type=ValueType(name='mjsElement'), + ), + ), + ), + doc='Safely cast an element as mjsSite, or return NULL if the element is not an mjsSite.', # pylint: disable=line-too-long + )), + ('mjs_asCamera', + FunctionDecl( + name='mjs_asCamera', + return_type=PointerType( + inner_type=ValueType(name='mjsCamera'), + ), + parameters=( + FunctionParameterDecl( + name='element', + type=PointerType( + inner_type=ValueType(name='mjsElement'), + ), + ), + ), + doc='Safely cast an element as mjsCamera, or return NULL if the element is not an mjsCamera.', # pylint: disable=line-too-long + )), + ('mjs_asLight', + FunctionDecl( + name='mjs_asLight', + return_type=PointerType( + inner_type=ValueType(name='mjsLight'), + ), + parameters=( + FunctionParameterDecl( + name='element', + type=PointerType( + inner_type=ValueType(name='mjsElement'), + ), + ), + ), + doc='Safely cast an element as mjsLight, or return NULL if the element is not an mjsLight.', # pylint: disable=line-too-long + )), + ('mjs_asFrame', + FunctionDecl( + name='mjs_asFrame', + return_type=PointerType( + inner_type=ValueType(name='mjsFrame'), + ), + parameters=( + FunctionParameterDecl( + name='element', + type=PointerType( + inner_type=ValueType(name='mjsElement'), + ), + ), + ), + doc='Safely cast an element as mjsFrame, or return NULL if the element is not an mjsFrame.', # pylint: disable=line-too-long + )), ('mjs_setString', FunctionDecl( name='mjs_setString', diff --git a/python/MANIFEST.in b/python/MANIFEST.in index a55f15f191..e7ecd95366 100644 --- a/python/MANIFEST.in +++ b/python/MANIFEST.in @@ -1,3 +1,3 @@ include LICENSE *.md -recursive-include mujoco *.h *.cc *.mm CMakeLists.txt *.cmake +recursive-include mujoco *.h *.cc *.cc.inc *.mm CMakeLists.txt *.cmake recursive-include mujoco/mjpython mjpython.* Info.plist diff --git a/python/make_sdist.sh b/python/make_sdist.sh index 1466cb3d7d..7abe811d51 100755 --- a/python/make_sdist.sh +++ b/python/make_sdist.sh @@ -44,6 +44,8 @@ python "${package_dir}"/mujoco/codegen/generate_enum_traits.py > \ mujoco/enum_traits.h python "${package_dir}"/mujoco/codegen/generate_function_traits.py > \ mujoco/function_traits.h +python "${package_dir}"/mujoco/codegen/generate_spec_bindings.py > \ + mujoco/specs.cc.inc export PYTHONPATH="${old_pythonpath}" # Copy over the LICENSE file. diff --git a/python/mujoco/CMakeLists.txt b/python/mujoco/CMakeLists.txt index d31bd9032a..480865d410 100644 --- a/python/mujoco/CMakeLists.txt +++ b/python/mujoco/CMakeLists.txt @@ -405,6 +405,29 @@ target_link_libraries( structs_header ) +if(NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/specs.cc.inc) + add_custom_command( + OUTPUT specs.cc.inc + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${mujoco_SOURCE_DIR}/mujoco ${Python3_EXECUTABLE} + ${CMAKE_CURRENT_SOURCE_DIR}/codegen/generate_spec_bindings.py > specs.cc.inc + DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/codegen/generate_spec_bindings.py + ) +endif() + +mujoco_pybind11_module( + _specs + specs.cc + specs.cc.inc +) +target_link_libraries( + _specs + PRIVATE mujoco + Eigen3::Eigen + errors_header + raw + structs_header +) + mujoco_pybind11_module(_simulate simulate.cc) target_link_libraries( _simulate @@ -424,6 +447,7 @@ set(LIBRARIES_FOR_WHEEL "$" "$" "$" + "$" "$" "$" ) @@ -459,6 +483,7 @@ if(MUJOCO_PYTHON_MAKE_WHEEL) _render _rollout _simulate + _specs _structs mujoco ) diff --git a/python/mujoco/__init__.py b/python/mujoco/__init__.py index d9e0e5cf7f..dd3ad7137e 100644 --- a/python/mujoco/__init__.py +++ b/python/mujoco/__init__.py @@ -43,6 +43,7 @@ from mujoco._errors import * from mujoco._functions import * from mujoco._render import * +from mujoco._specs import * from mujoco._structs import * from mujoco.gl_context import * from mujoco.renderer import Renderer diff --git a/python/mujoco/codegen/generate_spec_bindings.py b/python/mujoco/codegen/generate_spec_bindings.py new file mode 100644 index 0000000000..064e83786e --- /dev/null +++ b/python/mujoco/codegen/generate_spec_bindings.py @@ -0,0 +1,210 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Generates the bindings for the MuJoCo specs.""" + +from collections.abc import Sequence + +from absl import app + +from introspect import ast_nodes +from introspect import structs + + +def _scalar_binding_code( + field: ast_nodes.ValueType, classname: str = '', varname: str = '' +) -> str: + """Creates a string that defines Python bindings for a scalar type.""" + fulltype = field.name + '&' # default return type is by reference + fullvarname = varname + rawclassname = classname.replace('mjs', 'raw::Mjs') + if classname == 'mjSpec': # raw mjSpec has a wrapper + rawclassname = classname.replace('mjS', 'MjS') + fullvarname = 'ptr->' + varname + if field.name.startswith('mjs'): # all other mjs are raw structs + fulltype = field.name.replace('mjs', 'raw::Mjs') + if field.name != 'mjsPlugin' and field.name != 'mjsOrientation': + fulltype = fulltype + '*' # plugin and orientation are pointers + return f"""\ + {classname}.def_property( + "{varname}", + []({rawclassname}& self) -> {fulltype} {{ + return self.{fullvarname}; + }}, + []({rawclassname}& self, {fulltype} {varname}) {{ + (self.{fullvarname}) = {varname}; + }}, py::return_value_policy::reference_internal);""" + + +def _array_binding_code( + field: ast_nodes.ArrayType, classname: str = '', varname: str = '' +) -> str: + """Creates a string that declares Python bindings for an array type.""" + if len(field.extents) > 1: + raise NotImplementedError() + innertype = field.inner_type.decl() + rawclassname = classname.replace('mjs', 'raw::Mjs') + fullvarname = varname + if classname == 'mjSpec': # raw mjSpec has a wrapper + rawclassname = classname.replace('mjS', 'MjS') + fullvarname = 'ptr->' + varname + if innertype == 'double' or innertype == 'mjtNum': + innertype = 'MjDouble' # custom Eigen type + elif innertype == 'float': + innertype = 'MjFloat' # custom Eigen type + elif innertype == 'int': + innertype = 'MjInt' # custom Eigen type + elif innertype == 'char': + # char array special case + return f"""\ + {classname}.def_property( + "{varname}", + []({rawclassname}& self) -> py::array_t {{ + return py::array_t({field.extents[0]}, self.{fullvarname}); + }}, + []({rawclassname}& self, py::object rhs) {{ + int i = 0; + for (auto val : rhs) {{ + self.{fullvarname}[i++] = (py::cast(val)); + }} + }}, py::return_value_policy::reference_internal);""" + # all other array types + return f"""\ + {classname}.def_property( + "{varname}", + []({rawclassname}& self) -> {innertype}{field.extents[0]} {{ + return {innertype}{field.extents[0]}(self.{fullvarname}); + }}, + []({rawclassname}& self, {innertype}Ref{field.extents[0]} {varname}) {{ + {innertype}{field.extents[0]}(self.{fullvarname}) = {varname}; + }}, py::return_value_policy::reference_internal);""" + + +def _ptr_binding_code( + field: ast_nodes.PointerType, classname: str = '', varname: str = '' +) -> str: + """Creates a string that declares Python bindings for a pointer type.""" + vartype = field.inner_type.decl() + rawclassname = classname.replace('mjs', 'raw::Mjs') + fullvarname = varname + if classname == 'mjSpec': # raw mjSpec has a wrapper + rawclassname = classname.replace('mjS', 'MjS') + fullvarname = 'ptr->' + varname + if vartype == 'mjsElement': # this is ignored by the caller + return 'mjsElement' + if vartype.startswith('mjs'): # for structs, use the scalar case + return _scalar_binding_code(field.inner_type, classname, varname) + elif vartype == 'mjString': # C++ string -> Python string + return f"""\ + {classname}.def_property( + "{varname}", + []({rawclassname}& self) -> std::string_view {{ + return *self.{fullvarname}; + }}, + []({rawclassname}& self, std::string_view {varname}) {{ + *(self.{fullvarname}) = {varname}; + }}, py::return_value_policy::reference_internal);""" + elif ( # C++ vectors of scalars -> Python array + vartype == 'mjDoubleVec' + or vartype == 'mjFloatVec' + or vartype == 'mjIntVec' + ): + vartype = vartype.replace('mj', '').replace('Vec', '').lower() + return f"""\ + {classname}.def_property( + "{varname}", + []({rawclassname}& self) -> py::array_t<{vartype}> {{ + return py::array_t<{vartype}>(self.{fullvarname}->size(), + self.{fullvarname}->data()); + }}, + []({rawclassname}& self, py::object rhs) {{ + self.{fullvarname}->clear(); + self.{fullvarname}->reserve(py::len(rhs)); + for (auto val : rhs) {{ + self.{fullvarname}->push_back(py::cast<{vartype}>(val)); + }} + }}, py::return_value_policy::reference_internal);""" + elif vartype == 'mjStringVec': # C++ vector of strings -> Python list + return f"""\ + {classname}.def_property( + "{varname}", + []({rawclassname}& self) -> py::list {{ + py::list list; + for (auto val : *self.{fullvarname}) {{ + list.append(val); + }} + return list; + }}, + []({rawclassname}& self, py::object rhs) {{ + self.{fullvarname}->clear(); + self.{fullvarname}->reserve(py::len(rhs)); + for (auto val : rhs) {{ + self.{fullvarname}->push_back(py::cast(val)); + }} + }}, py::return_value_policy::reference_internal);""" + elif 'VecVec' in vartype: # C++ vector of vectors -> Python list of lists + vartype = vartype.replace('mj', '').replace('VecVec', '').lower() + return f"""\ + {classname}.def_property( + "{varname}", + []({rawclassname}& self) -> py::list {{ + py::list list; + for (auto inner_vec : *self.{fullvarname}) {{ + py::list inner_list; + for (auto val : inner_vec) {{ + inner_list.append(val); + }} + list.append(inner_list); + }} + return list; + }}, + []({rawclassname}& self, py::object rhs) {{ + self.{fullvarname}->clear(); + self.{fullvarname}->reserve(py::len(rhs)); + for (auto inner_list : rhs) {{ + auto inner_vec = py::cast>(inner_list); + self.{fullvarname}->push_back(inner_vec); + }} + }}, py::return_value_policy::reference_internal);""" + + raise NotImplementedError() + + +def _binding_code(field: ast_nodes.StructFieldDecl, key: str) -> str: + if isinstance(field.type, ast_nodes.ValueType): + return _scalar_binding_code(field.type, key, field.name) + elif isinstance(field.type, ast_nodes.PointerType): + return _ptr_binding_code(field.type, key, field.name) + elif isinstance(field.type, ast_nodes.ArrayType): + return _array_binding_code(field.type, key, field.name) + return '' + + +def generate() -> None: + for key in structs.STRUCTS.keys(): + if (key.startswith('mjs') or key == 'mjSpec') and key != 'mjsElement': + print('\n // ' + key) + for field in structs.STRUCTS[key].fields: + code = _binding_code(field, key) + if code != 'mjsElement': + print(code) + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + generate() + +if __name__ == '__main__': + app.run(main) diff --git a/python/mujoco/raw.h b/python/mujoco/raw.h index 5aa146f43a..3d481a12bf 100644 --- a/python/mujoco/raw.h +++ b/python/mujoco/raw.h @@ -18,6 +18,7 @@ #include #include #include +#include #include // Type aliases for MuJoCo C structs to allow us refer to consistently refer @@ -28,6 +29,35 @@ using MjContact = ::mjContact; using MjData = ::mjData; using MjLROpt = ::mjLROpt; using MjModel = ::mjModel; +using MjSpec = ::mjSpec; +using MjsElement = ::mjsElement; +using MjsOrientation = ::mjsOrientation; +using MjsPlugin = ::mjsPlugin; +using MjsBody = ::mjsBody; +using MjsFrame = ::mjsFrame; +using MjsJoint = ::mjsJoint; +using MjsGeom = ::mjsGeom; +using MjsSite = ::mjsSite; +using MjsCamera = ::mjsCamera; +using MjsLight = ::mjsLight; +using MjsFlex = ::mjsFlex; +using MjsMesh = ::mjsMesh; +using MjsHField = ::mjsHField; +using MjsSkin = ::mjsSkin; +using MjsTexture = ::mjsTexture; +using MjsMaterial = ::mjsMaterial; +using MjsPair = ::mjsPair; +using MjsExclude = ::mjsExclude; +using MjsEquality = ::mjsEquality; +using MjsTendon = ::mjsTendon; +using MjsWrap = ::mjsWrap; +using MjsActuator = ::mjsActuator; +using MjsSensor = ::mjsSensor; +using MjsNumeric = ::mjsNumeric; +using MjsText = ::mjsText; +using MjsTuple = ::mjsTuple; +using MjsKey = ::mjsKey; +using MjsDefault = ::mjsDefault; using MjOption = ::mjOption; using MjSolverStat = ::mjSolverStat; using MjStatistic = ::mjStatistic; diff --git a/python/mujoco/specs.cc b/python/mujoco/specs.cc new file mode 100644 index 0000000000..a0e95cd22e --- /dev/null +++ b/python/mujoco/specs.cc @@ -0,0 +1,768 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include "errors.h" +#include "indexers.h" +#include "raw.h" +#include "structs.h" +#include +#include +#include +#include +#include +#include + +namespace py = ::pybind11; + +namespace mujoco::python { +using MjInt2 = Eigen::Map; +using MjInt3 = Eigen::Map; +using MjFloat2 = Eigen::Map; +using MjFloat3 = Eigen::Map; +using MjFloat4 = Eigen::Map; +using MjDouble2 = Eigen::Map; +using MjDouble3 = Eigen::Map; +using MjDouble4 = Eigen::Map; +using MjDouble5 = Eigen::Map>; +using MjDouble6 = Eigen::Map>; +using MjDouble10 = Eigen::Map>; +using MjDouble11 = Eigen::Map>; +using MjDoubleVec = Eigen::Map; + +using MjIntRef2 = Eigen::Ref; +using MjIntRef3 = Eigen::Ref; +using MjFloatRef2 = Eigen::Ref; +using MjFloatRef3 = Eigen::Ref; +using MjFloatRef4 = Eigen::Ref; +using MjDoubleRef2 = Eigen::Ref; +using MjDoubleRef3 = Eigen::Ref; +using MjDoubleRef4 = Eigen::Ref; +using MjDoubleRef5 = Eigen::Ref>; +using MjDoubleRef6 = Eigen::Ref>; +using MjDoubleRef10 = Eigen::Ref>; +using MjDoubleRef11 = Eigen::Ref>; +using MjDoubleRefVec = Eigen::Ref; + +struct MjSpec { + MjSpec() : ptr(mj_makeSpec()) {} + ~MjSpec() { mj_deleteSpec(ptr); } + raw::MjSpec* ptr; +}; + +PYBIND11_MODULE(_specs, m) { + auto structs_m = py::module::import("mujoco._structs"); + py::function mjmodel_from_spec_ptr = + structs_m.attr("MjModel").attr("_from_spec_ptr"); + py::function mjmodel_mjdata_from_spec_ptr = + structs_m.attr("_recompile_spec_addr"); + + py::class_ mjSpec(m, "MjSpec"); + py::class_ mjsElement(m, "MjsElement"); + py::class_ mjsDefault(m, "MjsDefault"); + py::class_ mjsBody(m, "MjsBody"); + py::class_ mjsFrame(m, "MjsFrame"); + py::class_ mjsGeom(m, "MjsGeom"); + py::class_ mjsJoint(m, "MjsJoint"); + py::class_ mjsLight(m, "MjsLight"); + py::class_ mjsMaterial(m, "MjsMaterial"); + py::class_ mjsSite(m, "MjsSite"); + py::class_ mjsMesh(m, "MjsMesh"); + py::class_ mjsSkin(m, "MjsSkin"); + py::class_ mjsTexture(m, "MjsTexture"); + py::class_ mjsText(m, "MjsText"); + py::class_ mjsTuple(m, "MjsTuple"); + py::class_ mjsCamera(m, "MjsCamera"); + py::class_ mjsFlex(m, "MjsFlex"); + py::class_ mjsHField(m, "MjsHField"); + py::class_ mjsKey(m, "MjsKey"); + py::class_ mjsNumeric(m, "MjsNumeric"); + py::class_ mjsPair(m, "MjsPair"); + py::class_ mjsExclude(m, "MjsExclude"); + py::class_ mjsEquality(m, "MjsEquality"); + py::class_ mjsTendon(m, "MjsTendon"); + py::class_ mjsSensor(m, "MjsSensor"); + py::class_ mjsActuator(m, "MjsActuator"); + py::class_ mjsPlugin(m, "MjsPlugin"); + py::class_ mjsOrientation(m, "MjsOrientation"); + py::class_ mjsWrap(m, "MjsWrap"); + + // ============================= MJSPEC ===================================== + mjSpec.def(py::init<>()); + mjSpec.def("recompile", [mjmodel_mjdata_from_spec_ptr]( + const MjSpec& self, py::object m, py::object d) { + return mjmodel_mjdata_from_spec_ptr(reinterpret_cast(self.ptr), + m, d); + }); + mjSpec.def( + "copy", + [](const MjSpec& self) -> raw::MjSpec* { return mj_copySpec(self.ptr); }, + py::return_value_policy::reference_internal); + mjSpec.def_property_readonly( + "worldbody", + [](MjSpec& self) -> raw::MjsBody* { + return mjs_findBody(self.ptr, "world"); + }, + py::return_value_policy::reference_internal); + mjSpec.def( + "find_body", + [](MjSpec& self, std::string& name) -> raw::MjsBody* { + return mjs_findBody(self.ptr, name.c_str()); + }, + py::return_value_policy::reference_internal); + mjSpec.def( + "find_mesh", + [](MjSpec& self, std::string& name) -> raw::MjsMesh* { + return mjs_findMesh(self.ptr, name.c_str()); + }, + py::return_value_policy::reference_internal); + mjSpec.def( + "find_frame", + [](MjSpec& self, std::string& name) -> raw::MjsFrame* { + return mjs_findFrame(self.ptr, name.c_str()); + }, + py::return_value_policy::reference_internal); + mjSpec.def( + "find_default", + [](MjSpec& self, std::string& classname) -> raw::MjsDefault* { + return mjs_findDefault(self.ptr, classname.c_str()); + }, + py::return_value_policy::reference_internal); + mjSpec.def("compile", [mjmodel_from_spec_ptr](MjSpec& self) { + return mjmodel_from_spec_ptr(reinterpret_cast(self.ptr)); + }); + mjSpec.def( + "copy_back", + [](MjSpec& self, raw::MjModel& model) { + return mj_copyBack(self.ptr, &model); + }, + py::return_value_policy::reference_internal); + mjSpec.def("to_xml", [](MjSpec& self) -> std::string { + int size = mj_saveXMLString(self.ptr, nullptr, 0, nullptr, 0); + std::unique_ptr buf(new char[size + 1]); + std::array err; + buf[0] = '\0'; + err[0] = '\0'; + mj_saveXMLString(self.ptr, buf.get(), size + 1, err.data(), err.size()); + if (err[0] != '\0') { + throw FatalError(std::string(err.data())); + } + return std::string(buf.get()); + }); + mjSpec.def("from_file", [](MjSpec& self, std::string& filename) -> void { + std::array err; + err[0] = '\0'; + mj_deleteSpec(self.ptr); + self.ptr = mj_parseXML(filename.c_str(), 0, err.data(), err.size()); + if (!self.ptr) { + throw FatalError(std::string(err.data())); + } + }); + mjSpec.def("from_string", [](MjSpec& self, std::string& xml) -> void { + std::array err; + err[0] = '\0'; + mj_deleteSpec(self.ptr); + self.ptr = mj_parseXMLString(xml.c_str(), 0, err.data(), err.size()); + if (!self.ptr) { + throw FatalError(std::string(err.data())); + } + }); + mjSpec.def( + "add_default", + [](MjSpec* spec, std::string& classname, + raw::MjsDefault* parent) -> raw::MjsDefault* { + return mjs_addDefault(spec->ptr, classname.c_str(), parent); + }, + py::return_value_policy::reference_internal); + mjSpec.def( + "default", + [](MjSpec& self) -> raw::MjsDefault* { + return mjs_getSpecDefault(self.ptr); + }, + py::return_value_policy::reference_internal); + mjSpec.def( + "add_material", + [](MjSpec& self, raw::MjsDefault* default_) -> raw::MjsMaterial* { + return mjs_addMaterial(self.ptr, default_); + }, + py::arg_v("default", nullptr), + py::return_value_policy::reference_internal); + mjSpec.def( + "add_mesh", + [](MjSpec& self, raw::MjsDefault* default_) -> raw::MjsMesh* { + return mjs_addMesh(self.ptr, default_); + }, + py::arg_v("default", nullptr), + py::return_value_policy::reference_internal); + mjSpec.def( + "add_skin", + [](MjSpec& self) -> raw::MjsSkin* { return mjs_addSkin(self.ptr); }, + py::return_value_policy::reference_internal); + mjSpec.def( + "add_texture", + [](MjSpec& self) -> raw::MjsTexture* { return mjs_addTexture(self.ptr); }, + py::return_value_policy::reference_internal); + mjSpec.def( + "add_text", + [](MjSpec& self) -> raw::MjsText* { return mjs_addText(self.ptr); }, + py::return_value_policy::reference_internal); + mjSpec.def( + "add_tuple", + [](MjSpec& self) -> raw::MjsTuple* { return mjs_addTuple(self.ptr); }, + py::return_value_policy::reference_internal); + mjSpec.def( + "add_flex", + [](MjSpec& self) -> raw::MjsFlex* { return mjs_addFlex(self.ptr); }, + py::return_value_policy::reference_internal); + mjSpec.def( + "add_hfield", + [](MjSpec& self) -> raw::MjsHField* { return mjs_addHField(self.ptr); }, + py::return_value_policy::reference_internal); + mjSpec.def( + "add_key", + [](MjSpec& self) -> raw::MjsKey* { return mjs_addKey(self.ptr); }, + py::return_value_policy::reference_internal); + mjSpec.def( + "add_numeric", + [](MjSpec& self) -> raw::MjsNumeric* { return mjs_addNumeric(self.ptr); }, + py::return_value_policy::reference_internal); + mjSpec.def( + "add_pair", + [](MjSpec& self, raw::MjsDefault* default_) -> raw::MjsPair* { + return mjs_addPair(self.ptr, default_); + }, + py::arg_v("default", nullptr), + py::return_value_policy::reference_internal); + mjSpec.def( + "add_exclude", + [](MjSpec& self) -> raw::MjsExclude* { return mjs_addExclude(self.ptr); }, + py::return_value_policy::reference_internal); + mjSpec.def( + "add_equality", + [](MjSpec& self, raw::MjsDefault* default_) -> raw::MjsEquality* { + return mjs_addEquality(self.ptr, default_); + }, + py::arg_v("default", nullptr), + py::return_value_policy::reference_internal); + mjSpec.def( + "add_tendon", + [](MjSpec& self, raw::MjsDefault* default_) -> raw::MjsTendon* { + return mjs_addTendon(self.ptr, default_); + }, + py::arg_v("default", nullptr), + py::return_value_policy::reference_internal); + mjSpec.def( + "add_sensor", + [](MjSpec& self) -> raw::MjsSensor* { return mjs_addSensor(self.ptr); }, + py::return_value_policy::reference_internal); + mjSpec.def( + "add_actuator", + [](MjSpec& self, raw::MjsDefault* default_) -> raw::MjsActuator* { + return mjs_addActuator(self.ptr, default_); + }, + py::arg_v("default", nullptr), + py::return_value_policy::reference_internal); + mjSpec.def( + "add_plugin", + [](MjSpec& self) -> raw::MjsPlugin* { return mjs_addPlugin(self.ptr); }, + py::return_value_policy::reference_internal); + mjSpec.def("detach_body", [](MjSpec& self, raw::MjsBody& body) { + mjs_detachBody(self.ptr, &body); + }); + + // ============================= MJSBODY ==================================== + mjsBody.def("id", [](raw::MjsBody& self) -> int { + return mjs_getId(self.element); + }); + mjsBody.def("delete", [](raw::MjsBody& self) { mjs_delete(self.element); }); + mjsBody.def( + "add_body", + [](raw::MjsBody& self, raw::MjsDefault* default_) -> raw::MjsBody* { + return mjs_addBody(&self, default_); + }, + py::arg_v("default", nullptr), + py::return_value_policy::reference_internal); + mjsBody.def( + "add_frame", + [](raw::MjsBody& self, raw::MjsFrame* parentframe_) -> raw::MjsFrame* { + return mjs_addFrame(&self, parentframe_); + }, + py::arg_v("default", nullptr), + py::return_value_policy::reference_internal); + mjsBody.def( + "add_geom", + [](raw::MjsBody& self, raw::MjsDefault* default_) -> raw::MjsGeom* { + return mjs_addGeom(&self, default_); + }, + py::arg_v("default", nullptr), + py::return_value_policy::reference_internal); + mjsBody.def( + "add_joint", + [](raw::MjsBody& self, raw::MjsDefault* default_) -> raw::MjsJoint* { + return mjs_addJoint(&self, default_); + }, + py::arg_v("default", nullptr), + py::return_value_policy::reference_internal); + mjsBody.def( + "add_freejoint", + [](raw::MjsBody& self) -> raw::MjsJoint* { + return mjs_addFreeJoint(&self); + }, + py::return_value_policy::reference_internal); + mjsBody.def( + "add_light", + [](raw::MjsBody& self, raw::MjsDefault* default_) -> raw::MjsLight* { + return mjs_addLight(&self, default_); + }, + py::arg_v("default", nullptr), + py::return_value_policy::reference_internal); + mjsBody.def( + "add_site", + [](raw::MjsBody& self, raw::MjsDefault* default_) -> raw::MjsSite* { + return mjs_addSite(&self, default_); + }, + py::arg_v("default", nullptr), + py::return_value_policy::reference_internal); + mjsBody.def( + "add_camera", + [](raw::MjsBody& self, raw::MjsDefault* default_) -> raw::MjsCamera* { + return mjs_addCamera(&self, default_); + }, + py::arg_v("default", nullptr), + py::return_value_policy::reference_internal); + mjsBody.def("set_frame", + [](raw::MjsBody& self, raw::MjsFrame& frame) -> void { + mjs_setFrame(self.element, &frame); + }); + mjsBody.def("set_default", + [](raw::MjsBody& self, raw::MjsDefault& default_) -> void { + mjs_setDefault(self.element, &default_); + }); + mjsBody.def( + "default", + [](raw::MjsBody& self) -> raw::MjsDefault* { + return mjs_getDefault(self.element); + }, + py::return_value_policy::reference_internal); + mjsBody.def( + "find_child", + [](raw::MjsBody& self, std::string& name) -> raw::MjsBody* { + return mjs_findChild(&self, name.c_str()); + }, + py::return_value_policy::reference_internal); + mjsBody.def( + "first_body", + [](raw::MjsBody& self) -> raw::MjsBody* { + return mjs_asBody(mjs_firstChild(&self, mjOBJ_BODY)); + }, + py::return_value_policy::reference_internal); + mjsBody.def( + "next_body", + [](raw::MjsBody& self, raw::MjsBody& child) -> raw::MjsBody* { + return mjs_asBody(mjs_nextChild(&self, child.element)); + }, + py::return_value_policy::reference_internal); + mjsBody.def( + "first_camera", + [](raw::MjsBody& self) -> raw::MjsCamera* { + return mjs_asCamera(mjs_firstChild(&self, mjOBJ_CAMERA)); + }, + py::return_value_policy::reference_internal); + mjsBody.def( + "next_camera", + [](raw::MjsBody& self, raw::MjsCamera& child) -> raw::MjsCamera* { + return mjs_asCamera(mjs_nextChild(&self, child.element)); + }, + py::return_value_policy::reference_internal); + mjsBody.def( + "first_light", + [](raw::MjsBody& self) -> raw::MjsLight* { + return mjs_asLight(mjs_firstChild(&self, mjOBJ_LIGHT)); + }, + py::return_value_policy::reference_internal); + mjsBody.def( + "next_light", + [](raw::MjsBody& self, raw::MjsLight& child) -> raw::MjsLight* { + return mjs_asLight(mjs_nextChild(&self, child.element)); + }, + py::return_value_policy::reference_internal); + mjsBody.def( + "first_joint", + [](raw::MjsBody& self) -> raw::MjsJoint* { + return mjs_asJoint(mjs_firstChild(&self, mjOBJ_JOINT)); + }, + py::return_value_policy::reference_internal); + mjsBody.def( + "next_joint", + [](raw::MjsBody& self, raw::MjsJoint& child) -> raw::MjsJoint* { + return mjs_asJoint(mjs_nextChild(&self, child.element)); + }, + py::return_value_policy::reference_internal); + mjsBody.def( + "first_geom", + [](raw::MjsBody& self) -> raw::MjsGeom* { + return mjs_asGeom(mjs_firstChild(&self, mjOBJ_GEOM)); + }, + py::return_value_policy::reference_internal); + mjsBody.def( + "next_geom", + [](raw::MjsBody& self, raw::MjsGeom& child) -> raw::MjsGeom* { + return mjs_asGeom(mjs_nextChild(&self, child.element)); + }, + py::return_value_policy::reference_internal); + mjsBody.def( + "first_site", + [](raw::MjsBody& self) -> raw::MjsSite* { + return mjs_asSite(mjs_firstChild(&self, mjOBJ_SITE)); + }, + py::return_value_policy::reference_internal); + mjsBody.def( + "next_site", + [](raw::MjsBody& self, raw::MjsSite& child) -> raw::MjsSite* { + return mjs_asSite(mjs_nextChild(&self, child.element)); + }, + py::return_value_policy::reference_internal); + mjsBody.def( + "first_frame", + [](raw::MjsBody& self) -> raw::MjsFrame* { + return mjs_asFrame(mjs_firstChild(&self, mjOBJ_FRAME)); + }, + py::return_value_policy::reference_internal); + mjsBody.def( + "next_frame", + [](raw::MjsBody& self, raw::MjsFrame& child) -> raw::MjsFrame* { + return mjs_asFrame(mjs_nextChild(&self, child.element)); + }, + py::return_value_policy::reference_internal); + mjsBody.def( + "spec", + [](raw::MjsBody& self) -> raw::MjSpec* { return mjs_getSpec(&self); }, + py::return_value_policy::reference_internal); + mjsBody.def("attach_frame", + [](raw::MjsBody& self, raw::MjsFrame& frame, std::string& prefix, + std::string& suffix) -> void { + mjs_attachFrame(&self, &frame, prefix.c_str(), suffix.c_str()); + }); + + // ============================= MJSFRAME ==================================== + mjsFrame.def("id", [](raw::MjsFrame& self) -> int { + return mjs_getId(self.element); + }); + mjsFrame.def("delete", [](raw::MjsFrame& self) { mjs_delete(self.element); }); + mjsFrame.def("set_frame", [](raw::MjsFrame& self, raw::MjsFrame& frame) { + mjs_setFrame(self.element, &frame); + }); + mjsFrame.def("attach_body", [](raw::MjsFrame& self, raw::MjsBody& body, + std::string& prefix, std::string& suffix) { + mjs_attachBody(&self, &body, prefix.c_str(), suffix.c_str()); + }); + + // ============================= MJSGEOM ==================================== + mjsGeom.def("id", [](raw::MjsGeom& self) -> int { + return mjs_getId(self.element); + }); + mjsGeom.def("delete", [](raw::MjsGeom& self) { mjs_delete(self.element); }); + mjsGeom.def("set_frame", [](raw::MjsGeom& self, raw::MjsFrame& frame) { + mjs_setFrame(self.element, &frame); + }); + mjsGeom.def("set_default", [](raw::MjsGeom& self, raw::MjsDefault& def) { + mjs_setDefault(self.element, &def); + }); + mjsGeom.def( + "default", + [](raw::MjsGeom& self) -> raw::MjsDefault* { + return mjs_getDefault(self.element); + }, + py::return_value_policy::reference_internal); + + // ============================= MJSJOINT ==================================== + mjsJoint.def("id", [](raw::MjsJoint& self) -> int { + return mjs_getId(self.element); + }); + mjsJoint.def("delete", [](raw::MjsJoint& self) { mjs_delete(self.element); }); + mjsJoint.def("set_frame", [](raw::MjsJoint& self, raw::MjsFrame& frame) { + mjs_setFrame(self.element, &frame); + }); + mjsJoint.def("set_default", [](raw::MjsJoint& self, raw::MjsDefault& def) { + mjs_setDefault(self.element, &def); + }); + mjsJoint.def( + "default", + [](raw::MjsJoint& self) -> raw::MjsDefault* { + return mjs_getDefault(self.element); + }, + py::return_value_policy::reference_internal); + + // ============================= MJSSITE ==================================== + mjsSite.def("id", [](raw::MjsSite& self) -> int { + return mjs_getId(self.element); + }); + mjsSite.def("delete", [](raw::MjsSite& self) { mjs_delete(self.element); }); + mjsSite.def("set_frame", [](raw::MjsSite& self, raw::MjsFrame& frame) { + mjs_setFrame(self.element, &frame); + }); + mjsSite.def("set_default", [](raw::MjsSite& self, raw::MjsDefault& def) { + mjs_setDefault(self.element, &def); + }); + mjsSite.def( + "default", + [](raw::MjsSite& self) -> raw::MjsDefault* { + return mjs_getDefault(self.element); + }, + py::return_value_policy::reference_internal); + + // ============================= MJSCAMERA ================================== + mjsCamera.def("id", [](raw::MjsCamera& self) -> int { + return mjs_getId(self.element); + }); + mjsCamera.def("delete", + [](raw::MjsCamera& self) { mjs_delete(self.element); }); + mjsCamera.def("set_frame", [](raw::MjsCamera& self, raw::MjsFrame& frame) { + mjs_setFrame(self.element, &frame); + }); + mjsCamera.def("set_default", [](raw::MjsCamera& self, raw::MjsDefault& def) { + mjs_setDefault(self.element, &def); + }); + mjsCamera.def( + "default", + [](raw::MjsCamera& self) -> raw::MjsDefault* { + return mjs_getDefault(self.element); + }, + py::return_value_policy::reference_internal); + + // ============================= MJSLIGHT ==================================== + mjsLight.def("id", [](raw::MjsLight& self) -> int { + return mjs_getId(self.element); + }); + mjsLight.def("delete", [](raw::MjsLight& self) { mjs_delete(self.element); }); + mjsLight.def("set_frame", [](raw::MjsLight& self, raw::MjsFrame& frame) { + mjs_setFrame(self.element, &frame); + }); + mjsLight.def("set_default", [](raw::MjsLight& self, raw::MjsDefault& def) { + mjs_setDefault(self.element, &def); + }); + mjsLight.def( + "default", + [](raw::MjsLight& self) -> raw::MjsDefault* { + return mjs_getDefault(self.element); + }, + py::return_value_policy::reference_internal); + + // ============================= MJSMATERIAL ================================ + mjsMaterial.def("id", [](raw::MjsMaterial& self) -> int { + return mjs_getId(self.element); + }); + mjsMaterial.def("delete", + [](raw::MjsMaterial& self) { mjs_delete(self.element); }); + mjsMaterial.def("set_default", + [](raw::MjsMaterial& self, raw::MjsDefault& def) { + mjs_setDefault(self.element, &def); + }); + mjsMaterial.def( + "default", + [](raw::MjsMaterial& self) -> raw::MjsDefault* { + return mjs_getDefault(self.element); + }, + py::return_value_policy::reference_internal); + + // ============================= MJSMESH ==================================== + mjsMesh.def("id", [](raw::MjsMesh& self) -> int { + return mjs_getId(self.element); + }); + mjsMesh.def("delete", [](raw::MjsMesh& self) { mjs_delete(self.element); }); + mjsMesh.def("set_default", [](raw::MjsMesh& self, raw::MjsDefault& def) { + mjs_setDefault(self.element, &def); + }); + mjsMesh.def( + "default", + [](raw::MjsMesh& self) -> raw::MjsDefault* { + return mjs_getDefault(self.element); + }, + py::return_value_policy::reference_internal); + + // ============================= MJSPAIR ==================================== + mjsPair.def("id", [](raw::MjsPair& self) -> int { + return mjs_getId(self.element); + }); + mjsPair.def("delete", [](raw::MjsPair& self) { mjs_delete(self.element); }); + mjsPair.def("set_default", [](raw::MjsPair& self, raw::MjsDefault& def) { + mjs_setDefault(self.element, &def); + }); + mjsPair.def( + "default", + [](raw::MjsPair& self) -> raw::MjsDefault* { + return mjs_getDefault(self.element); + }, + py::return_value_policy::reference_internal); + + // ============================= MJSEQUAL ==================================== + mjsEquality.def("id", [](raw::MjsEquality& self) -> int { + return mjs_getId(self.element); + }); + mjsEquality.def("delete", + [](raw::MjsEquality& self) { mjs_delete(self.element); }); + mjsEquality.def("set_default", + [](raw::MjsEquality& self, raw::MjsDefault& def) { + mjs_setDefault(self.element, &def); + }); + mjsEquality.def( + "default", + [](raw::MjsEquality& self) -> raw::MjsDefault* { + return mjs_getDefault(self.element); + }, + py::return_value_policy::reference_internal); + + // ============================= MJSACTUATOR ================================ + mjsActuator.def("id", [](raw::MjsActuator& self) -> int { + return mjs_getId(self.element); + }); + mjsActuator.def("delete", + [](raw::MjsActuator& self) { mjs_delete(self.element); }); + mjsActuator.def("set_default", + [](raw::MjsActuator& self, raw::MjsDefault& def) { + mjs_setDefault(self.element, &def); + }); + mjsActuator.def( + "default", + [](raw::MjsActuator& self) -> raw::MjsDefault* { + return mjs_getDefault(self.element); + }, + py::return_value_policy::reference_internal); + + // ============================= MJSTENDON ================================== + mjsTendon.def("id", [](raw::MjsTendon& self) -> int { + return mjs_getId(self.element); + }); + mjsTendon.def("delete", + [](raw::MjsTendon& self) { mjs_delete(self.element); }); + mjsTendon.def("set_default", [](raw::MjsTendon& self, raw::MjsDefault& def) { + mjs_setDefault(self.element, &def); + }); + mjsTendon.def( + "default", + [](raw::MjsTendon& self) -> raw::MjsDefault* { + return mjs_getDefault(self.element); + }, + py::return_value_policy::reference_internal); + mjsTendon.def( + "wrap_site", + [](raw::MjsTendon& self, std::string& name) { + return mjs_wrapSite(&self, name.c_str()); + }, + py::return_value_policy::reference_internal); + mjsTendon.def( + "wrap_geom", + [](raw::MjsTendon& self, std::string& name, std::string& sidesite) { + return mjs_wrapGeom(&self, name.c_str(), sidesite.c_str()); + }, + py::return_value_policy::reference_internal); + mjsTendon.def( + "wrap_joint", + [](raw::MjsTendon& self, std::string& name, double coef) { + return mjs_wrapJoint(&self, name.c_str(), coef); + }, + py::return_value_policy::reference_internal); + mjsTendon.def( + "wrap_pulley", + [](raw::MjsTendon& self, double divisor) { + return mjs_wrapPulley(&self, divisor); + }, + py::return_value_policy::reference_internal); + + // ============================= MJSSENSOR ================================== + mjsSensor.def("id", [](raw::MjsSensor& self) -> int { + return mjs_getId(self.element); + }); + mjsSensor.def("delete", + [](raw::MjsSensor& self) { mjs_delete(self.element); }); + + // ============================= MJSFLEX ==================================== + mjsFlex.def("id", [](raw::MjsFlex& self) -> int { + return mjs_getId(self.element); + }); + mjsFlex.def("delete", [](raw::MjsFlex& self) { mjs_delete(self.element); }); + + // ============================= MJSHFIELD ================================== + mjsHField.def("id", [](raw::MjsHField& self) -> int { + return mjs_getId(self.element); + }); + mjsHField.def("delete", + [](raw::MjsHField& self) { mjs_delete(self.element); }); + + // ============================= MJSSKIN ==================================== + mjsSkin.def("id", [](raw::MjsSkin& self) -> int { + return mjs_getId(self.element); + }); + mjsSkin.def("delete", [](raw::MjsSkin& self) { mjs_delete(self.element); }); + + // ============================= MJSTEXTURE ================================= + mjsTexture.def("id", [](raw::MjsTexture& self) -> int { + return mjs_getId(self.element); + }); + mjsTexture.def("delete", + [](raw::MjsTexture& self) { mjs_delete(self.element); }); + + // ============================= MJSKEY ===================================== + mjsKey.def("id", + [](raw::MjsKey& self) -> int { return mjs_getId(self.element); }); + mjsKey.def("delete", [](raw::MjsKey& self) { mjs_delete(self.element); }); + + // ============================= MJSTEXT ==================================== + mjsText.def("id", [](raw::MjsText& self) -> int { + return mjs_getId(self.element); + }); + mjsText.def("delete", [](raw::MjsText& self) { mjs_delete(self.element); }); + + // ============================= MJSNUMERIC ================================= + mjsNumeric.def("id", [](raw::MjsNumeric& self) -> int { + return mjs_getId(self.element); + }); + mjsNumeric.def("delete", + [](raw::MjsNumeric& self) { mjs_delete(self.element); }); + + // ============================= MJSEXCLUDE ================================ + mjsExclude.def("id", [](raw::MjsExclude& self) -> int { + return mjs_getId(self.element); + }); + mjsExclude.def("delete", + [](raw::MjsExclude& self) { mjs_delete(self.element); }); + + // ============================= MJSTUPLE =================================== + mjsTuple.def("id", [](raw::MjsTuple& self) -> int { + return mjs_getId(self.element); + }); + mjsTuple.def("delete", [](raw::MjsTuple& self) { mjs_delete(self.element); }); + + // ============================= MJSPLUGIN ================================== + mjsPlugin.def("id", [](raw::MjsPlugin& self) -> int { + return mjs_getId(self.instance); + }); + mjsPlugin.def("delete", + [](raw::MjsPlugin& self) { mjs_delete(self.instance); }); + +#include "specs.cc.inc" + +} // PYBIND11_MODULE // NOLINT +} // namespace mujoco::python diff --git a/python/mujoco/specs_test.py b/python/mujoco/specs_test.py new file mode 100644 index 0000000000..6abb503a6f --- /dev/null +++ b/python/mujoco/specs_test.py @@ -0,0 +1,222 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for mjSpec bindings.""" + +import inspect +import textwrap + +from absl.testing import absltest +import mujoco +import numpy as np + + +def get_linenumber(): + cf = inspect.currentframe() + return cf.f_back.f_lineno + + +class SpecsTest(absltest.TestCase): + + def test_basic(self): + # Create a spec. + spec = mujoco.MjSpec() + + # Check that euler sequence order is set correctly. + self.assertEqual(spec.euler[0], ord('x')) + spec.euler = ['z', 'y', 'x'] + self.assertEqual(spec.euler[0], ord('z')) + + # Add a body, check that it has default orientation. + body = spec.worldbody.add_body() + self.assertEqual(body.name, '') + np.testing.assert_array_equal(body.quat, [1, 0, 0, 0]) + + # Change the name of the body and read it back twice. + body.name = 'foobar' + self.assertEqual(body.name, 'foobar') + body.name = 'baz' + self.assertEqual(body.name, 'baz') + + # Change the position of the body and read it back. + body.pos = [1, 2, 3] + np.testing.assert_array_equal(body.pos, [1, 2, 3]) + self.assertEqual(body.pos.shape, (3,)) + + # Change the orientation of the body and read it back. + body.quat = [0, 1, 0, 0] + np.testing.assert_array_equal(body.quat, [0, 1, 0, 0]) + self.assertEqual(body.quat.shape, (4,)) + + # Add a site to the body with user data and read it back. + site = body.add_site() + site.name = 'sitename' + site.userdata = [1, 2, 3, 4, 5, 6] + self.assertEqual(site.name, 'sitename') + np.testing.assert_array_equal(site.userdata, [1, 2, 3, 4, 5, 6]) + + # Check that the site has no id before compilation. + self.assertEqual(body.id(), -1) + + # Compile the spec and check for expected values in the model. + model = spec.compile() + self.assertEqual(spec.worldbody.id(), 0) + self.assertEqual(body.id(), 1) + self.assertEqual(model.nbody, 2) # 2 bodies, including the world body + np.testing.assert_array_equal(model.body_pos[1], [1, 2, 3]) + np.testing.assert_array_equal(model.body_quat[1], [0, 1, 0, 0]) + self.assertEqual(model.nsite, 1) + self.assertEqual(model.nuser_site, 6) + np.testing.assert_array_equal(model.site_user[0], [1, 2, 3, 4, 5, 6]) + + self.assertEqual(spec.to_xml(), textwrap.dedent("""\ + + + + + + + + + + + + """),) + + def test_compile_errors_with_line_info(self): + spec = mujoco.MjSpec() + + added_on_line = get_linenumber() + 1 + geom = spec.worldbody.add_geom() + geom.name = 'MyGeom' + geom.info = f'geom added on line {added_on_line}' + + # Try to compile, get error. + expected_error = ( + 'Error: size 0 must be positive in geom\n' + + f'Element name \'MyGeom\', id 0, geom added on line {added_on_line}' + ) + with self.assertRaisesRegex(ValueError, expected_error): + spec.compile() + + def test_recompile(self): + # Create a spec. + spec = mujoco.MjSpec() + + # Add movable body1. + body1 = spec.worldbody.add_body() + geom = body1.add_geom() + geom.size[0] = 1 + geom.pos = [1, 1, 0] + joint = body1.add_joint() + joint.type = mujoco.mjtJoint.mjJNT_BALL + + # Compile model, make data. + model = spec.compile() + data = mujoco.MjData(model) + + # Simulate for 1 second. + while data.time < 1: + mujoco.mj_step(model, data) + + # Add movable body2. + body2 = spec.worldbody.add_body() + body2.pos[1] = 3 + geom = body2.add_geom() + geom.size[0] = 1 + geom.pos = [0, 1, 0] + joint = body2.add_joint() + joint.type = mujoco.mjtJoint.mjJNT_BALL + + # Recompile model and data while maintaining the state. + model_new, data_new = spec.recompile(model, data) + + # Check that the state is preserved. + np.testing.assert_array_equal(model_new.body_pos[1], model.body_pos[1]) + np.testing.assert_array_equal(data_new.qpos[:4], data.qpos) + np.testing.assert_array_equal(data_new.qvel[:3], data.qvel) + + def test_uncompiled_spec_cannot_be_written(self): + spec = mujoco.MjSpec() + + # Cannot write XML of an uncompiled spec. + expected_error = 'XML Write error: Only compiled model can be written' + with self.assertRaisesWithLiteralMatch(mujoco.FatalError, expected_error): + spec.to_xml() + + def test_modelname_default_class(self): + spec = mujoco.MjSpec() + spec.modelname = 'test' + + main = spec.default() + main.geom.size[0] = 2 + + def1 = spec.add_default('def1', main) + def1.geom.size[0] = 3 + + spec.worldbody.add_geom(def1) + spec.worldbody.add_geom(main) + + spec.compile() + self.assertEqual(spec.to_xml(), textwrap.dedent("""\ + + + + + + + + + + + + + + + + """)) + spec = mujoco.MjSpec() + spec.modelname = 'test' + + main = spec.default() + main.geom.size[0] = 2 + + def1 = spec.add_default('def1', main) + def1.geom.size[0] = 3 + + spec.worldbody.add_geom(def1) + spec.worldbody.add_geom(main) + + spec.compile() + self.assertEqual(spec.to_xml(), textwrap.dedent("""\ + + + + + + + + + + + + + + + + """)) + + +if __name__ == '__main__': + absltest.main() diff --git a/python/mujoco/structs.cc b/python/mujoco/structs.cc index a49d4321f5..12412acb9f 100644 --- a/python/mujoco/structs.cc +++ b/python/mujoco/structs.cc @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -430,6 +431,27 @@ MjModelWrapper MjModelWrapper::LoadXML( return MjModelWrapper(model); } +MjModelWrapper MjModelWrapper::CompileSpec(raw::MjSpec* spec) { + auto m = mj_compile(spec, nullptr); + if (!m || mjs_isWarning(spec)) { + throw py::value_error(mjs_getError(spec)); + } + return MjModelWrapper(m); +} + +py::tuple RecompileSpec(raw::MjSpec* spec, const MjModelWrapper& old_m, + const MjDataWrapper& old_d) { + raw::MjModel* m = static_cast(mju_malloc(sizeof(mjModel))); + m->buffer = nullptr; + raw::MjData* d = mj_copyData(nullptr, old_m.get(), old_d.get()); + mj_recompile(spec, nullptr, m, d); + + py::object m_pyobj = py::cast((MjModelWrapper(m))); + py::object d_pyobj = + py::cast((MjDataWrapper(py::cast(m_pyobj), d))); + return py::make_tuple(m_pyobj, d_pyobj); +} + namespace { // A byte at the start of serialized mjModel structs, which can be incremented // when we change the serialization logic to reject pickles from an unsupported @@ -1557,6 +1579,11 @@ PYBIND11_MODULE(_structs, m) { py::arg("xml"), py::arg_v("assets", py::none()), py::doc( R"(Loads an MjModel from an XML string and an optional assets dictionary.)")); + mjModel.def_static( + "_from_spec_ptr", [](uintptr_t addr) { + return MjModelWrapper::CompileSpec( + reinterpret_cast(addr)); + }); mjModel.def_static( "from_xml_path", &MjModelWrapper::LoadXMLFile, py::arg("filename"), py::arg_v("assets", py::none()), @@ -2406,5 +2433,12 @@ This is useful for example when the MJB is not available as a file on disk.)")); }, py::arg("cam1"), py::arg("cam2"), py::doc(python_traits::mjv_averageCamera::doc)); + + m.def( + "_recompile_spec_addr", + [](uintptr_t spec_addr, const MjModelWrapper& m, const MjDataWrapper& d) { + return RecompileSpec(reinterpret_cast(spec_addr), m, d); + } + ); } // PYBIND11_MODULE NOLINT(readability/fn_size) } // namespace mujoco::python::_impl diff --git a/python/mujoco/structs.h b/python/mujoco/structs.h index f27e17ff2b..9d3fe2b904 100644 --- a/python/mujoco/structs.h +++ b/python/mujoco/structs.h @@ -462,6 +462,10 @@ class MjWrapper : public WrapperBase { public: MjWrapper(const MjWrapper&); MjWrapper(MjWrapper&&); + + // Takes ownership of the raw mjModel pointer. + explicit MjWrapper(raw::MjModel* ptr); + ~MjWrapper(); MjModelIndexer& indexer() { return indexer_; } @@ -485,6 +489,8 @@ class MjWrapper : public WrapperBase { const std::optional< std::unordered_map>& assets); + static MjWrapper CompileSpec(raw::MjSpec* spec); + static constexpr char kFromRawPointer[] = "__MUJOCO_STRUCTS_MJMODELWRAPPER_LOOKUP"; static MjWrapper* FromRawPointer(raw::MjModel* m) noexcept; @@ -502,8 +508,6 @@ class MjWrapper : public WrapperBase { pybind11::bytes paths_bytes; protected: - explicit MjWrapper(raw::MjModel* ptr); - MjModelIndexer indexer_; }; @@ -591,8 +595,14 @@ class MjWrapper: public WrapperBase { explicit MjWrapper(MjModelWrapper* model); MjWrapper(const MjWrapper& other); MjWrapper(MjWrapper&&); + // Used for deepcopy MjWrapper(const MjWrapper& other, MjModelWrapper* model); + + // Internal constructor which takes ownership of given mjData pointer. + // Used for deserialization and recompile. + explicit MjWrapper(MjModelWrapper* model, raw::MjData* d); + ~MjWrapper(); const MjModelWrapper& model() const { return *model_; } @@ -622,9 +632,6 @@ class MjWrapper: public WrapperBase { py_array_or_tuple_t energy; protected: - // Internal constructor which takes ownership of given mjData pointer. - // Used for deserialization. - explicit MjWrapper(MjModelWrapper* model, raw::MjData* d); raw::MjData* Copy() const; // A reference to the model that was used to create this mjData. diff --git a/python/setup.py b/python/setup.py index bc1934bb49..7807c410e0 100644 --- a/python/setup.py +++ b/python/setup.py @@ -347,6 +347,7 @@ def run(self): CMakeExtension('mujoco._render'), CMakeExtension('mujoco._rollout'), CMakeExtension('mujoco._simulate'), + CMakeExtension('mujoco._specs'), CMakeExtension('mujoco._structs'), ], scripts=[ diff --git a/src/user/user_api.cc b/src/user/user_api.cc index f71fa17666..e84fff6685 100644 --- a/src/user/user_api.cc +++ b/src/user/user_api.cc @@ -581,6 +581,55 @@ mjsElement* mjs_nextChild(mjsBody* body, mjsElement* child) { +// return body given mjsElement +mjsBody* mjs_asBody(mjsElement* element) { + return element->elemtype == mjOBJ_BODY ? &(static_cast(element)->spec) : nullptr; +} + + + +// return geom given mjsElement +mjsGeom* mjs_asGeom(mjsElement* element) { + return element->elemtype == mjOBJ_GEOM ? &(static_cast(element)->spec) : nullptr; +} + + + +// return joint given mjsElement +mjsJoint* mjs_asJoint(mjsElement* element) { + return element->elemtype == mjOBJ_JOINT ? &(static_cast(element)->spec) : nullptr; +} + + + +// Return site given mjsElement +mjsSite* mjs_asSite(mjsElement* element) { + return element->elemtype == mjOBJ_SITE ? &(static_cast(element)->spec) : nullptr; +} + + + +// return camera given mjsElement +mjsCamera* mjs_asCamera(mjsElement* element) { + return element->elemtype == mjOBJ_CAMERA ? &(static_cast(element)->spec) : nullptr; +} + + + +// return light given mjsElement +mjsLight* mjs_asLight(mjsElement* element) { + return element->elemtype == mjOBJ_LIGHT ? &(static_cast(element)->spec) : nullptr; +} + + + +// return frame given mjsElement +mjsFrame* mjs_asFrame(mjsElement* element) { + return element->elemtype == mjOBJ_FRAME ? &(static_cast(element)->spec) : nullptr; +} + + + // set string void mjs_setString(mjString* dest, const char* text) { std::string* str = static_cast(dest); diff --git a/src/user/user_api.h b/src/user/user_api.h index 9e30b105fd..676fe4695c 100644 --- a/src/user/user_api.h +++ b/src/user/user_api.h @@ -215,6 +215,27 @@ MJAPI mjsElement* mjs_firstChild(mjsBody* body, mjtObj type); // Return body's next child of the same type; return NULL if child is last. MJAPI mjsElement* mjs_nextChild(mjsBody* body, mjsElement* child); +// Safely cast an element as mjsBody, or return NULL if the element is not an mjsBody. +MJAPI mjsBody* mjs_asBody(mjsElement* element); + +// Safely cast an element as mjsGeom, or return NULL if the element is not an mjsGeom. +MJAPI mjsGeom* mjs_asGeom(mjsElement* element); + +// Safely cast an element as mjsJoint, or return NULL if the element is not an mjsJoint. +MJAPI mjsJoint* mjs_asJoint(mjsElement* element); + +// Safely cast an element as mjsSite, or return NULL if the element is not an mjsSite. +MJAPI mjsSite* mjs_asSite(mjsElement* element); + +// Safely cast an element as mjsCamera, or return NULL if the element is not an mjsCamera. +MJAPI mjsCamera* mjs_asCamera(mjsElement* element); + +// Safely cast an element as mjsLight, or return NULL if the element is not an mjsLight. +MJAPI mjsLight* mjs_asLight(mjsElement* element); + +// Safely cast an element as mjsFrame, or return NULL if the element is not an mjsFrame. +MJAPI mjsFrame* mjs_asFrame(mjsElement* element); + //---------------------------------- Attribute setters ---------------------------------------------