Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use py::type::handle_of(obj) rather than obj.get_type() #49

Merged
merged 2 commits into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ repos:
hooks:
- id: clang-format
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.259
rev: v0.0.260
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand All @@ -38,7 +38,7 @@ repos:
hooks:
- id: isort
- repo: https://github.com/psf/black
rev: 23.1.0
rev: 23.3.0
hooks:
- id: black
- repo: https://github.com/asottile/pyupgrade
Expand Down
8 changes: 2 additions & 6 deletions benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,18 +529,14 @@ def compare( # pylint: disable=too-many-locals
speedups = {name: time / base_time for name, time in times_us.items()}
best_speedups = {name: time / best_time for name, time in times_us.items()}
labels = {
name: (cmark if speedup == 1.0 else tie if speedup < 1.1 else ' ')
name: cmark if speedup == 1.0 else (tie if speedup < 1.1 else ' ')
for name, speedup in best_speedups.items()
}
colors = {
name: (
'green'
if speedup == 1.0
else 'cyan'
if speedup < 1.1
else 'yellow'
if speedup < 4.0
else 'red'
else ('cyan' if speedup < 1.1 else ('yellow' if speedup < 4.0 else 'red'))
)
for name, speedup in best_speedups.items()
}
Expand Down
25 changes: 13 additions & 12 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,22 +278,23 @@ inline void AssertExact<py::dict>(const py::handle& object) {
}

inline void AssertExactOrderedDict(const py::handle& object) {
if (!object.get_type().is(PyOrderedDictTypeObject)) [[unlikely]] {
if (!py::type::handle_of(object).is(PyOrderedDictTypeObject)) [[unlikely]] {
throw py::value_error(absl::StrFormat(
"Expected an instance of collections.OrderedDict, got %s.", py::repr(object)));
}
}

inline void AssertExactDefaultDict(const py::handle& object) {
if (!object.get_type().is(PyDefaultDictTypeObject)) [[unlikely]] {
if (!py::type::handle_of(object).is(PyDefaultDictTypeObject)) [[unlikely]] {
throw py::value_error(absl::StrFormat(
"Expected an instance of collections.defaultdict, got %s.", py::repr(object)));
}
}

inline void AssertExactStandardDict(const py::handle& object) {
if (!(PyDict_CheckExact(object.ptr()) || object.get_type().is(PyOrderedDictTypeObject) ||
object.get_type().is(PyDefaultDictTypeObject))) [[unlikely]] {
if (!(PyDict_CheckExact(object.ptr()) ||
py::type::handle_of(object).is(PyOrderedDictTypeObject) ||
py::type::handle_of(object).is(PyDefaultDictTypeObject))) [[unlikely]] {
throw py::value_error(
absl::StrFormat("Expected an instance of "
"dict, collections.OrderedDict, or collections.defaultdict, "
Expand All @@ -303,7 +304,7 @@ inline void AssertExactStandardDict(const py::handle& object) {
}

inline void AssertExactDeque(const py::handle& object) {
if (!object.get_type().is(PyDequeTypeObject)) [[unlikely]] {
if (!py::type::handle_of(object).is(PyDequeTypeObject)) [[unlikely]] {
throw py::value_error(absl::StrFormat("Expected an instance of collections.deque, got %s.",
py::repr(object)));
}
Expand Down Expand Up @@ -334,10 +335,10 @@ inline bool IsNamedTupleClass(const py::handle& type) {
return PyType_Check(type.ptr()) && IsNamedTupleClassImpl(type);
}
inline bool IsNamedTupleInstance(const py::handle& object) {
return IsNamedTupleClass(object.get_type());
return IsNamedTupleClass(py::type::handle_of(object));
}
inline bool IsNamedTuple(const py::handle& object) {
py::handle type = (PyType_Check(object.ptr()) ? object : object.get_type());
py::handle type = (PyType_Check(object.ptr()) ? object : py::type::handle_of(object));
return IsNamedTupleClass(type);
}
inline void AssertExactNamedTuple(const py::handle& object) {
Expand All @@ -355,7 +356,7 @@ inline py::tuple NamedTupleGetFields(const py::handle& object) {
py::repr(object)));
}
} else [[likely]] {
type = object.get_type();
type = py::type::handle_of(object);
if (!IsNamedTupleClass(type)) [[unlikely]] {
throw py::type_error(absl::StrFormat(
"Expected an instance of collections.namedtuple type, got %s.", py::repr(object)));
Expand Down Expand Up @@ -391,10 +392,10 @@ inline bool IsStructSequenceClass(const py::handle& type) {
return PyType_Check(type.ptr()) && IsStructSequenceClassImpl(type);
}
inline bool IsStructSequenceInstance(const py::handle& object) {
return IsStructSequenceClass(object.get_type());
return IsStructSequenceClass(py::type::handle_of(object));
}
inline bool IsStructSequence(const py::handle& object) {
py::handle type = (PyType_Check(object.ptr()) ? object : object.get_type());
py::handle type = (PyType_Check(object.ptr()) ? object : py::type::handle_of(object));
return IsStructSequenceClass(type);
}
inline void AssertExactStructSequence(const py::handle& object) {
Expand All @@ -412,7 +413,7 @@ inline py::tuple StructSequenceGetFields(const py::handle& object) {
absl::StrFormat("Expected a PyStructSequence type, got %s.", py::repr(object)));
}
} else [[likely]] {
type = object.get_type();
type = py::type::handle_of(object);
if (!IsStructSequenceClass(type)) [[unlikely]] {
throw py::type_error(absl::StrFormat(
"Expected an instance of PyStructSequence type, got %s.", py::repr(object)));
Expand Down Expand Up @@ -442,7 +443,7 @@ inline void TotalOrderSort(py::list& list) { // NOLINT[runtime/references]
try {
// Sort with `(f'{o.__class__.__module__}.{o.__class__.__qualname__}', o)`
auto sort_key_fn = py::cpp_function([](const py::object& o) {
py::handle t = o.get_type();
py::handle t = py::type::handle_of(o);
py::str qualname{absl::StrFormat(
"%s.%s",
static_cast<std::string>(py::getattr(t, "__module__").cast<py::str>()),
Expand Down
22 changes: 11 additions & 11 deletions src/treespec/flatten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle,
case PyTreeKind::StructSequence: {
auto tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = GET_SIZE<py::tuple>(tuple);
node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
node.node_data = py::reinterpret_borrow<py::object>(py::type::handle_of(tuple));
for (ssize_t i = 0; i < node.arity; ++i) {
recurse(GET_ITEM_HANDLE<py::tuple>(tuple, i));
}
Expand Down Expand Up @@ -283,7 +283,7 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle,
case PyTreeKind::StructSequence: {
auto tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = GET_SIZE<py::tuple>(tuple);
node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
node.node_data = py::reinterpret_borrow<py::object>(py::type::handle_of(tuple));
for (ssize_t i = 0; i < node.arity; ++i) {
recurse(GET_ITEM_HANDLE<py::tuple>(tuple, i), py::int_(i));
}
Expand Down Expand Up @@ -506,11 +506,11 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
GET_SIZE<py::tuple>(tuple),
py::repr(object)));
}
if (object.get_type().not_equal(node.node_data)) [[unlikely]] {
if (py::type::handle_of(object).not_equal(node.node_data)) [[unlikely]] {
throw py::value_error(absl::StrFormat(
"namedtuple type mismatch; expected type: %s, got type: %s; tuple: %s.",
py::repr(node.node_data),
py::repr(object.get_type()),
py::repr(py::type::handle_of(object)),
py::repr(object)));
}
for (ssize_t i = 0; i < node.arity; ++i) {
Expand Down Expand Up @@ -545,12 +545,12 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
GET_SIZE<py::tuple>(tuple),
py::repr(object)));
}
if (object.get_type().not_equal(node.node_data)) [[unlikely]] {
if (py::type::handle_of(object).not_equal(node.node_data)) [[unlikely]] {
throw py::value_error(
absl::StrFormat("PyStructSequence type mismatch; "
"expected type: %s, got type: %s; tuple: %s.",
py::repr(node.node_data),
py::repr(object.get_type()),
py::repr(py::type::handle_of(object)),
py::repr(object)));
}
for (ssize_t i = 0; i < node.arity; ++i) {
Expand All @@ -562,17 +562,17 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
case PyTreeKind::Custom: {
const PyTreeTypeRegistry::Registration* registration = nullptr;
if (m_none_is_leaf) [[unlikely]] {
registration =
PyTreeTypeRegistry::Lookup<NONE_IS_LEAF>(object.get_type(), m_namespace);
registration = PyTreeTypeRegistry::Lookup<NONE_IS_LEAF>(
py::type::handle_of(object), m_namespace);
} else [[likely]] {
registration =
PyTreeTypeRegistry::Lookup<NONE_IS_NODE>(object.get_type(), m_namespace);
registration = PyTreeTypeRegistry::Lookup<NONE_IS_NODE>(
py::type::handle_of(object), m_namespace);
}
if (registration != node.custom) [[unlikely]] {
throw py::value_error(absl::StrFormat(
"Custom node type mismatch; expected type: %s, got type: %s; value: %s.",
py::repr(node.custom->type),
py::repr(object.get_type()),
py::repr(py::type::handle_of(object)),
py::repr(object)));
}
py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(object));
Expand Down
2 changes: 1 addition & 1 deletion src/treespec/treespec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ template <bool NoneIsLeaf>
PyTreeTypeRegistry::Registration const** custom,
const std::string& registry_namespace) {
const PyTreeTypeRegistry::Registration* registration =
PyTreeTypeRegistry::Lookup<NoneIsLeaf>(handle.get_type(), registry_namespace);
PyTreeTypeRegistry::Lookup<NoneIsLeaf>(py::type::handle_of(handle), registry_namespace);
if (registration) [[likely]] {
if (registration->kind == PyTreeKind::Custom) [[unlikely]] {
*custom = registration;
Expand Down
4 changes: 2 additions & 2 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ def test_round_trip_with_flatten_up_to(tree, none_is_leaf, namespace):
[0, ((1, 2), 3, (4, (5, 6, 7))), 8, 9],
[0, ((1, 2), 3, (4, (5, 6, 7))), 8, 9],
[0, ((1, (2, 3)), (4, (5, 6, 7))), 8, 9],
[0, {'d': (4, (5, 6, 7)), 'c': ((1, {'b': 3, 'a': 2})), 'e': [8, 9]}],
[0, {1: (4, (5, 6, 7)), 1.1: ((1, {'b': 3, 'a': 2})), 'c': [8, 9]}],
[0, {'d': (4, (5, 6, 7)), 'c': (1, {'b': 3, 'a': 2}), 'e': [8, 9]}],
[0, {1: (4, (5, 6, 7)), 1.1: (1, {'b': 3, 'a': 2}), 'c': [8, 9]}],
[0, OrderedDict([(1, (1, (2, 3, 4))), (1.1, ((5, {'b': 7, 'a': 6}))), ('c', [8, 9])])],
],
none_is_leaf=[False, True],
Expand Down
8 changes: 4 additions & 4 deletions tests/test_treespec.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_with_namespace():
for namespace in ('', 'undefined'):
leaves, treespec = optree.tree_flatten(tree, none_is_leaf=False, namespace=namespace)
assert leaves == [tree]
assert str(treespec) == ('PyTreeSpec(*)')
assert str(treespec) == 'PyTreeSpec(*)'
paths, leaves, treespec = optree.tree_flatten_with_path(
tree,
none_is_leaf=False,
Expand All @@ -129,11 +129,11 @@ def test_with_namespace():
assert paths == [()]
assert leaves == [tree]
assert paths == treespec.paths()
assert str(treespec) == ('PyTreeSpec(*)')
assert str(treespec) == 'PyTreeSpec(*)'
for namespace in ('', 'undefined'):
leaves, treespec = optree.tree_flatten(tree, none_is_leaf=True, namespace=namespace)
assert leaves == [tree]
assert str(treespec) == ('PyTreeSpec(*, NoneIsLeaf)')
assert str(treespec) == 'PyTreeSpec(*, NoneIsLeaf)'
paths, leaves, treespec = optree.tree_flatten_with_path(
tree,
none_is_leaf=True,
Expand All @@ -142,7 +142,7 @@ def test_with_namespace():
assert paths == [()]
assert leaves == [tree]
assert paths == treespec.paths()
assert str(treespec) == ('PyTreeSpec(*, NoneIsLeaf)')
assert str(treespec) == 'PyTreeSpec(*, NoneIsLeaf)'

expected_string = "PyTreeSpec(CustomTreeNode(MyAnotherDict[['foo', 'baz']], [CustomTreeNode(MyAnotherDict[['c', 'b', 'a']], [None, *, *]), *]), namespace='namespace')"
leaves, treespec = optree.tree_flatten(tree, none_is_leaf=False, namespace='namespace')
Expand Down