Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions gtwrap/pybind_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,17 @@ def _wrap_method(self,
method,
(parser.StaticMethod, instantiator.InstantiatedStaticMethod))
return_void = method.return_type.is_void()
return_ref = getattr(
getattr(method.return_type, 'type1', None), 'is_ref', False)

# For methods returning const T&, use reference_internal policy
# to avoid unnecessary copies and keep the returned reference alive.
if return_ref and is_method:
lambda_ret = ' -> const auto&'
ref_policy = ', py::return_value_policy::reference_internal'
else:
lambda_ret = ''
ref_policy = ''

caller = cpp_class + "::" if not is_method else "self->"
function_call = ('{opt_return} {caller}{method_name}'
Expand All @@ -263,10 +274,10 @@ def _wrap_method(self,

result = (
'{prefix}.{cdef}("{py_method}",'
'[]({opt_self}{opt_comma}{args_signature_with_names}){{'
'[]({opt_self}{opt_comma}{args_signature_with_names}){lambda_ret}{{'
'{function_call}'
'}}'
'{py_args_names}{docstring}){suffix}'.format(
'{ref_policy}{py_args_names}{docstring}){suffix}'.format(
prefix=prefix,
cdef="def_static" if is_static else "def",
py_method=py_method,
Expand All @@ -275,7 +286,9 @@ def _wrap_method(self,
opt_comma=', '
if is_method and args_signature_with_names else '',
args_signature_with_names=args_signature_with_names,
lambda_ret=lambda_ret,
function_call=function_call,
ref_policy=ref_policy,
py_args_names=py_args_names,
suffix=suffix,
# Try to get the function's docstring from the Doxygen XML.
Expand Down
4 changes: 2 additions & 2 deletions tests/expected/python/class_pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ PYBIND11_MODULE(class_py, m_) {
.def("return_matrix1",[](Test* self, const gtsam::Matrix& value){return self->return_matrix1(value);}, py::arg("value"))
.def("return_vector2",[](Test* self, const gtsam::Vector& value){return self->return_vector2(value);}, py::arg("value"))
.def("return_matrix2",[](Test* self, const gtsam::Matrix& value){return self->return_matrix2(value);}, py::arg("value"))
.def("return_vector2",[](Test* self, const gtsam::Vector& value){return self->return_vector2(value);}, py::arg("value"))
.def("return_matrix2",[](Test* self, const gtsam::Matrix& value){return self->return_matrix2(value);}, py::arg("value"))
.def("return_vector2",[](Test* self, const gtsam::Vector& value) -> const auto&{return self->return_vector2(value);}, py::return_value_policy::reference_internal, py::arg("value"))
.def("return_matrix2",[](Test* self, const gtsam::Matrix& value) -> const auto&{return self->return_matrix2(value);}, py::return_value_policy::reference_internal, py::arg("value"))
.def("arg_EigenConstRef",[](Test* self, const gtsam::Matrix& value){ self->arg_EigenConstRef(value);}, py::arg("value"))
.def("push_back",[](Test* self, gtsam::Key key){ self->push_back(key);}, py::arg("key"))
.def("return_field",[](Test* self, const Test& t){return self->return_field(t);}, py::arg("t"))
Expand Down
36 changes: 36 additions & 0 deletions tests/test_pybind_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,5 +160,41 @@ def test_enum(self):
self.compare_and_diff('enum_pybind.cpp', output)


def test_const_ref_return_policy(self):
"""Test that methods returning const T& emit reference_internal policy.

Without this policy, pybind11 defaults to copying the returned reference.
With the policy, the binding keeps the reference alive via the parent object.

Expected emitted code difference:
Before: [](Cls* self, ...){return self->method(...);}, py::arg(...))
After: [](Cls* self, ...) -> const auto&{return self->method(...);},
py::return_value_policy::reference_internal, py::arg(...))
"""
source = osp.join(self.INTERFACE_DIR, 'class.i')
output = self.wrap_content([source], 'class_py',
self.PYTHON_ACTUAL_DIR)

with open(output, 'r') as f:
content = f.read()

# const Vector& return_vector2 should have reference_internal
self.assertIn('-> const auto&{return self->return_vector2', content)
self.assertIn('py::return_value_policy::reference_internal', content)

# const Matrix& return_matrix2 should also have reference_internal
self.assertIn('-> const auto&{return self->return_matrix2', content)

# Non-ref returns (e.g. return_vector1 which returns by value) should NOT
lines = content.split('\n')
for line in lines:
if 'return_vector1' in line:
self.assertNotIn('reference_internal', line)
self.assertNotIn('-> const auto&', line)
if 'return_matrix1' in line:
self.assertNotIn('reference_internal', line)
self.assertNotIn('-> const auto&', line)


if __name__ == '__main__':
unittest.main()