Skip to content

Commit

Permalink
Add proper handling of function pointer types in autowrap.
Browse files Browse the repository at this point in the history
Currently, all function pointers are marshaled as `ctypes.c_void_p`, and `set_callback` always wraps a Python callable into a CFUNCTYPE with a fixed signature. However, different MuJoCo callbacks have different function signatures. In particular, sensor and actuator callbacks take in an additional argument compared to the "generic" function pointer type.

In this CL, we modify autowrap to generate code that wraps callbacks into CFUNCTYPE with the correct signature. The generated wrapper code exposes callbacks (both global ones and those that appear in UI-related structs) as Python properties, where the getters and setters automatically handles the wrapping of Python callables.

PiperOrigin-RevId: 233789558
  • Loading branch information
saran-t authored and alimuldal committed Feb 18, 2019
1 parent c44cac1 commit 5cc8c9b
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 102 deletions.
46 changes: 28 additions & 18 deletions dm_control/autowrap/autowrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@

import six

_MJMODEL_H = "mjmodel.h"
_MJXMACRO_H = "mjxmacro.h"

FLAGS = flags.FLAGS

flags.DEFINE_spaceseplist(
Expand All @@ -60,17 +63,24 @@


def main(unused_argv):
# Get the path to the xmacro header file.
xmacro_hdr_path = None
for path in FLAGS.header_paths:
if path.endswith("mjxmacro.h"):
xmacro_hdr_path = path
break
if xmacro_hdr_path is None:
logging.fatal("List of inputs must contain a path to mjxmacro.h")

special_header_paths = {}

# Get the path to the mjmodel and mjxmacro header files.
# These header files need special handling.
for header in (_MJMODEL_H, _MJXMACRO_H):
for path in FLAGS.header_paths:
if path.endswith(header):
special_header_paths[header] = path
break
if header not in special_header_paths:
logging.fatal("List of inputs must contain a path to %s", header)

# Make sure mjmodel.h is parsed first, since it is included by other headers.
srcs = codegen_util.UniqueOrderedDict()
for p in sorted(FLAGS.header_paths):
sorted_header_paths = sorted(FLAGS.header_paths)
sorted_header_paths.remove(special_header_paths[_MJMODEL_H])
sorted_header_paths.insert(0, special_header_paths[_MJMODEL_H])
for p in sorted_header_paths:
with io.open(p, "r", errors="ignore") as f:
srcs[p] = f.read()

Expand All @@ -92,30 +102,30 @@ def main(unused_argv):

# Parse enums.
for pth, src in six.iteritems(srcs):
if pth is not xmacro_hdr_path:
if pth is not special_header_paths[_MJXMACRO_H]:
parser.parse_enums(src)

# Parse constants and type declarations.
for pth, src in six.iteritems(srcs):
if pth is not xmacro_hdr_path:
if pth is not special_header_paths[_MJXMACRO_H]:
parser.parse_consts_typedefs(src)

# Get shape hints from mjxmacro.h.
parser.parse_hints(srcs[xmacro_hdr_path])
parser.parse_hints(srcs[special_header_paths[_MJXMACRO_H]])

# Parse structs.
# Parse structs and function pointer type declarations.
for pth, src in six.iteritems(srcs):
if pth is not xmacro_hdr_path:
parser.parse_structs(src)
if pth is not special_header_paths[_MJXMACRO_H]:
parser.parse_structs_and_function_pointer_typedefs(src)

# Parse functions.
for pth, src in six.iteritems(srcs):
if pth is not xmacro_hdr_path:
if pth is not special_header_paths[_MJXMACRO_H]:
parser.parse_functions(src)

# Parse global strings and function pointers.
for pth, src in six.iteritems(srcs):
if pth is not xmacro_hdr_path:
if pth is not special_header_paths[_MJXMACRO_H]:
parser.parse_global_strings(src)
parser.parse_function_pointers(src)

Expand Down
95 changes: 59 additions & 36 deletions dm_control/autowrap/binding_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(self,
consts_dict=None,
typedefs_dict=None,
hints_dict=None,
structs_dict=None,
types_dict=None,
funcs_dict=None,
strings_dict=None,
func_ptrs_dict=None,
Expand All @@ -70,7 +70,7 @@ def __init__(self,
consts_dict: Mapping from {const_name: value}.
typedefs_dict: Mapping from {type_name: ctypes_typename}.
hints_dict: Mapping from {var_name: shape_tuple}.
structs_dict: Mapping from {struct_name: Struct_instance}.
types_dict: Mapping from {type_name: type_instance}.
funcs_dict: Mapping from {func_name: Function_instance}.
strings_dict: Mapping from {var_name: StaticStringArray_instance}.
func_ptrs_dict: Mapping from {var_name: FunctionPtr_instance}.
Expand All @@ -84,8 +84,8 @@ def __init__(self,
else codegen_util.UniqueOrderedDict())
self.hints_dict = (hints_dict if hints_dict is not None
else codegen_util.UniqueOrderedDict())
self.structs_dict = (structs_dict if structs_dict is not None
else codegen_util.UniqueOrderedDict())
self.types_dict = (types_dict if types_dict is not None
else codegen_util.UniqueOrderedDict())
self.funcs_dict = (funcs_dict if funcs_dict is not None
else codegen_util.UniqueOrderedDict())
self.strings_dict = (strings_dict if strings_dict is not None
Expand Down Expand Up @@ -211,7 +211,7 @@ def get_type_from_token(self, token, parent=None):
out.sub_structs[member.name] = member

# Add to dict of unions
self.structs_dict[out.ctypes_typename] = out
self.types_dict[out.ctypes_typename] = out

# A struct declaration
elif token.members:
Expand Down Expand Up @@ -258,7 +258,7 @@ def get_type_from_token(self, token, parent=None):
out.sub_structs[member.name] = member

# Add to dict of structs
self.structs_dict[out.ctypes_typename] = out
self.types_dict[out.ctypes_typename] = out

else:

Expand Down Expand Up @@ -303,10 +303,14 @@ def get_type_from_token(self, token, parent=None):
parent, is_const)

# A struct we've already encountered
elif typename in self.structs_dict:
s = self.structs_dict[typename]
out = c_declarations.Struct(name, s.typename, s.members, s.sub_structs,
comment, parent)
elif typename in self.types_dict:
s = self.types_dict[typename]
if isinstance(s, c_declarations.FunctionPtrTypedef):
out = c_declarations.FunctionPtr(
name, token.name, s.typename, comment)
else:
out = c_declarations.Struct(name, s.typename, s.members,
s.sub_structs, comment, parent)

# Presumably this is a scalar primitive
else:
Expand Down Expand Up @@ -358,8 +362,7 @@ def parse_enums(self, src):
def parse_consts_typedefs(self, src):
"""Updates self.consts_dict, self.typedefs_dict."""
parser = (header_parsing.COND_DECL |
header_parsing.UNCOND_DECL |
header_parsing.FUNCTION_PTR_TYPE_DECL)
header_parsing.UNCOND_DECL)
for tokens, _, _ in parser.scanString(src):
self.recurse_into_conditionals(tokens)

Expand All @@ -375,11 +378,7 @@ def recurse_into_conditionals(self, tokens):
self.recurse_into_conditionals(token.if_false)
# One or more declarations
else:
# A type declaration for a function pointer.
if token.arguments:
self.typedefs_dict.update(
{token.typename: header_parsing.CTYPES_FUNCTION_PTR})
elif token.typename:
if token.typename:
self.typedefs_dict.update({token.name: token.typename})
elif token.value:
value = codegen_util.try_coerce_to_num(token.value)
Expand All @@ -391,12 +390,21 @@ def recurse_into_conditionals(self, tokens):
else:
self.consts_dict.update({token.name: True})

def parse_structs(self, src):
"""Updates self.structs_dict."""
parser = header_parsing.NESTED_STRUCTS
def parse_structs_and_function_pointer_typedefs(self, src):
"""Updates self.types_dict."""
parser = (header_parsing.NESTED_STRUCTS |
header_parsing.FUNCTION_PTR_TYPE_DECL)
for tokens, _, _ in parser.scanString(src):
for token in tokens:
self.get_type_from_token(token)
if token.return_type:
# This is a function type declaration.
self.types_dict[token.typename] = c_declarations.FunctionPtrTypedef(
token.typename,
self.get_type_from_token(token.return_type),
tuple(self.get_type_from_token(arg) for arg in token.arguments))
else:
# This is a struct or a union.
self.get_type_from_token(token)

def parse_functions(self, src):
"""Updates self.funcs_dict."""
Expand Down Expand Up @@ -434,7 +442,8 @@ def parse_function_pointers(self, src):
for token, _, _ in parser.scanString(src):
name = codegen_util.mangle_varname(token.name)
self.func_ptrs_dict[name] = c_declarations.FunctionPtr(
name, symbol_name=token.name)
name, symbol_name=token.name,
type_name=token.typename, comment=token.comment)

# Code generation methods
# ----------------------------------------------------------------------------
Expand Down Expand Up @@ -487,15 +496,16 @@ def write_enums(self, fname):
f.write("\n" + codegen_util.comment_line("End of generated code"))

def write_types(self, fname):
"""Write ctypes struct declarations."""
"""Write ctypes struct and function type declarations."""
imports = [
"import ctypes",
]
with open(fname, "w") as f:
f.write(self.make_header(imports))
f.write(codegen_util.comment_line("ctypes struct and union declarations"))
for struct in six.itervalues(self.structs_dict):
f.write("\n" + struct.ctypes_decl)
f.write(codegen_util.comment_line(
"ctypes struct, union, and function type declarations"))
for type_decl in six.itervalues(self.types_dict):
f.write("\n" + type_decl.ctypes_decl)
f.write("\n" + codegen_util.comment_line("End of generated code"))

def write_wrappers(self, fname):
Expand All @@ -510,9 +520,9 @@ def write_wrappers(self, fname):
]
f.write(self.make_header(imports))
f.write(codegen_util.comment_line("Low-level wrapper classes"))
for struct_or_union in six.itervalues(self.structs_dict):
if isinstance(struct_or_union, c_declarations.Struct):
f.write("\n" + struct_or_union.wrapper_class)
for type_decl in six.itervalues(self.types_dict):
if isinstance(type_decl, c_declarations.Struct):
f.write("\n" + type_decl.wrapper_class)
f.write("\n" + codegen_util.comment_line("End of generated code"))

def write_funcs_and_globals(self, fname):
Expand Down Expand Up @@ -542,17 +552,30 @@ def write_funcs_and_globals(self, fname):
for string_arr in six.itervalues(self.strings_dict):
f.write(string_arr.ctypes_var_decl(cdll_name="mjlib"))

f.write("\n" + codegen_util.comment_line("Function pointers"))
f.write("\n" + codegen_util.comment_line("Callback function pointers"))

fields = [repr(name) for name in self.func_ptrs_dict.keys()]
fields = ["'_{0}'".format(func_ptr.name)
for func_ptr in self.func_ptrs_dict.values()]
values = [func_ptr.ctypes_var_decl(cdll_name="mjlib")
for func_ptr in self.func_ptrs_dict.values()]
f.write(textwrap.dedent("""
function_pointers = collections.namedtuple(
'FunctionPointers',
[{0}]
)({1})
""").format(",\n ".join(fields), ",\n ".join(values)))
class _Callbacks(object):
__slots__ = [
{0}
]
def __init__(self):
{1}
""").format(",\n ".join(fields), "\n ".join(values)))

indent = codegen_util.Indenter()
with indent:
for func_ptr in self.func_ptrs_dict.values():
f.write(indent(func_ptr.getters_setters_with_custom_prefix("self._")))

f.write("\n\ncallbacks = _Callbacks() # pylint: disable=invalid-name")
f.write("\ndel _Callbacks\n")

f.write("\n" + codegen_util.comment_line("End of generated code"))

Expand Down
53 changes: 49 additions & 4 deletions dm_control/autowrap/c_declarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,15 +461,60 @@ def ctypes_var_decl(self, cdll_name=""):
self.name, ptr_str, cdll_name, self.symbol_name)


class FunctionPtrTypedef(CDeclBase):
"""A type declaration for a C function pointer."""

def __init__(self, typename, return_type, argument_types):
super(FunctionPtrTypedef, self).__init__(
typename=typename, return_type=return_type,
argument_types=argument_types)

@property
def ctypes_decl(self):
"""Generates a ctypes.CFUNCTYPE declaration for self."""
types = (self.return_type,) + self.argument_types
types_decl = ", ".join(t.arg for t in types)
return "{0} = ctypes.CFUNCTYPE({1})".format(self.typename, types_decl)


class FunctionPtr(CDeclBase):
"""A pointer to an externally defined C function."""

def __init__(self, name, symbol_name, type_name=None):
def __init__(self, name, symbol_name, type_name, comment=""):
super(FunctionPtr, self).__init__(
name=name, symbol_name=symbol_name, type_name=type_name)
name=name, symbol_name=symbol_name,
type_name=type_name, comment=comment)

@property
def ctypes_field_decl(self):
"""Generates a declaration for self as a field of a ctypes.Structure."""
return "('{0.name}', {0.type_name})".format(self) # pylint: disable=missing-format-attribute

def ctypes_var_decl(self, cdll_name=""):
"""Generates a ctypes export statement."""

return "ctypes.c_void_p.in_dll({0}, {1!r})".format(
cdll_name, self.symbol_name)
return "self._{0} = ctypes.c_void_p.in_dll({1}, {2!r})".format(
self.name, cdll_name, self.symbol_name)

def getters_setters_with_custom_prefix(self, prefix):
return textwrap.dedent("""
@property
def {0.name}(self):
if {1}{0.name}.value:
return {0.type_name}({1}{0.name}.value)
else:
return None
@{0.name}.setter
def {0.name}(self, value):
new_func_ptr, wrapped_pyfunc = util.cast_func_to_c_void_p(
value, {0.type_name})
# Prevents wrapped_pyfunc from being inadvertently garbage collected.
{1}{0.name}._wrapped_pyfunc = wrapped_pyfunc
{1}{0.name}.value = new_func_ptr.value
""".format(self, prefix)) # pylint: disable=missing-format-attribute

@property
def getters_setters(self):
"""Populates a Python class with getter & setter methods for self."""
return self.getters_setters_with_custom_prefix(prefix="self._ptr.contents.")
3 changes: 1 addition & 2 deletions dm_control/autowrap/header_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

NONE = "None"
CTYPES_CHAR = "ctypes.c_char"
CTYPES_FUNCTION_PTR = "ctypes.c_void_p"

C_TO_CTYPES = {
# integers
Expand Down Expand Up @@ -292,7 +291,7 @@ def _nested_if_else(if_, pred, else_, endif, match_if_true, match_if_false):
FUNCTION_PTR_TYPE_DECL = pp.Group(
pp.Optional(MULTILINE_COMMENT("comment")) +
TYPEDEF +
(NATIVE_TYPENAME | NAME)("return_typename") +
RET("return_type") +
LPAREN +
PTR +
NAME("typename") +
Expand Down
Loading

0 comments on commit 5cc8c9b

Please sign in to comment.