Skip to content

Commit

Permalink
refactor: use py::type::handle_of(obj) rather than obj.get_type() (
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Mar 30, 2023
1 parent 577ad3e commit ddfaaf9
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 38 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ repos:
hooks:
- id: clang-format
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.259
rev: v0.0.260
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand All @@ -38,7 +38,7 @@ repos:
hooks:
- id: isort
- repo: https://github.com/psf/black
rev: 23.1.0
rev: 23.3.0
hooks:
- id: black
- repo: https://github.com/asottile/pyupgrade
Expand Down
8 changes: 2 additions & 6 deletions benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,18 +529,14 @@ def compare( # pylint: disable=too-many-locals
speedups = {name: time / base_time for name, time in times_us.items()}
best_speedups = {name: time / best_time for name, time in times_us.items()}
labels = {
name: (cmark if speedup == 1.0 else tie if speedup < 1.1 else ' ')
name: cmark if speedup == 1.0 else (tie if speedup < 1.1 else ' ')
for name, speedup in best_speedups.items()
}
colors = {
name: (
'green'
if speedup == 1.0
else 'cyan'
if speedup < 1.1
else 'yellow'
if speedup < 4.0
else 'red'
else ('cyan' if speedup < 1.1 else ('yellow' if speedup < 4.0 else 'red'))
)
for name, speedup in best_speedups.items()
}
Expand Down
25 changes: 13 additions & 12 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,22 +278,23 @@ inline void AssertExact<py::dict>(const py::handle& object) {
}

inline void AssertExactOrderedDict(const py::handle& object) {
if (!object.get_type().is(PyOrderedDictTypeObject)) [[unlikely]] {
if (!py::type::handle_of(object).is(PyOrderedDictTypeObject)) [[unlikely]] {
throw py::value_error(absl::StrFormat(
"Expected an instance of collections.OrderedDict, got %s.", py::repr(object)));
}
}

inline void AssertExactDefaultDict(const py::handle& object) {
if (!object.get_type().is(PyDefaultDictTypeObject)) [[unlikely]] {
if (!py::type::handle_of(object).is(PyDefaultDictTypeObject)) [[unlikely]] {
throw py::value_error(absl::StrFormat(
"Expected an instance of collections.defaultdict, got %s.", py::repr(object)));
}
}

inline void AssertExactStandardDict(const py::handle& object) {
if (!(PyDict_CheckExact(object.ptr()) || object.get_type().is(PyOrderedDictTypeObject) ||
object.get_type().is(PyDefaultDictTypeObject))) [[unlikely]] {
if (!(PyDict_CheckExact(object.ptr()) ||
py::type::handle_of(object).is(PyOrderedDictTypeObject) ||
py::type::handle_of(object).is(PyDefaultDictTypeObject))) [[unlikely]] {
throw py::value_error(
absl::StrFormat("Expected an instance of "
"dict, collections.OrderedDict, or collections.defaultdict, "
Expand All @@ -303,7 +304,7 @@ inline void AssertExactStandardDict(const py::handle& object) {
}

inline void AssertExactDeque(const py::handle& object) {
if (!object.get_type().is(PyDequeTypeObject)) [[unlikely]] {
if (!py::type::handle_of(object).is(PyDequeTypeObject)) [[unlikely]] {
throw py::value_error(absl::StrFormat("Expected an instance of collections.deque, got %s.",
py::repr(object)));
}
Expand Down Expand Up @@ -334,10 +335,10 @@ inline bool IsNamedTupleClass(const py::handle& type) {
return PyType_Check(type.ptr()) && IsNamedTupleClassImpl(type);
}
inline bool IsNamedTupleInstance(const py::handle& object) {
return IsNamedTupleClass(object.get_type());
return IsNamedTupleClass(py::type::handle_of(object));
}
inline bool IsNamedTuple(const py::handle& object) {
py::handle type = (PyType_Check(object.ptr()) ? object : object.get_type());
py::handle type = (PyType_Check(object.ptr()) ? object : py::type::handle_of(object));
return IsNamedTupleClass(type);
}
inline void AssertExactNamedTuple(const py::handle& object) {
Expand All @@ -355,7 +356,7 @@ inline py::tuple NamedTupleGetFields(const py::handle& object) {
py::repr(object)));
}
} else [[likely]] {
type = object.get_type();
type = py::type::handle_of(object);
if (!IsNamedTupleClass(type)) [[unlikely]] {
throw py::type_error(absl::StrFormat(
"Expected an instance of collections.namedtuple type, got %s.", py::repr(object)));
Expand Down Expand Up @@ -391,10 +392,10 @@ inline bool IsStructSequenceClass(const py::handle& type) {
return PyType_Check(type.ptr()) && IsStructSequenceClassImpl(type);
}
inline bool IsStructSequenceInstance(const py::handle& object) {
return IsStructSequenceClass(object.get_type());
return IsStructSequenceClass(py::type::handle_of(object));
}
inline bool IsStructSequence(const py::handle& object) {
py::handle type = (PyType_Check(object.ptr()) ? object : object.get_type());
py::handle type = (PyType_Check(object.ptr()) ? object : py::type::handle_of(object));
return IsStructSequenceClass(type);
}
inline void AssertExactStructSequence(const py::handle& object) {
Expand All @@ -412,7 +413,7 @@ inline py::tuple StructSequenceGetFields(const py::handle& object) {
absl::StrFormat("Expected a PyStructSequence type, got %s.", py::repr(object)));
}
} else [[likely]] {
type = object.get_type();
type = py::type::handle_of(object);
if (!IsStructSequenceClass(type)) [[unlikely]] {
throw py::type_error(absl::StrFormat(
"Expected an instance of PyStructSequence type, got %s.", py::repr(object)));
Expand Down Expand Up @@ -442,7 +443,7 @@ inline void TotalOrderSort(py::list& list) { // NOLINT[runtime/references]
try {
// Sort with `(f'{o.__class__.__module__}.{o.__class__.__qualname__}', o)`
auto sort_key_fn = py::cpp_function([](const py::object& o) {
py::handle t = o.get_type();
py::handle t = py::type::handle_of(o);
py::str qualname{absl::StrFormat(
"%s.%s",
static_cast<std::string>(py::getattr(t, "__module__").cast<py::str>()),
Expand Down
22 changes: 11 additions & 11 deletions src/treespec/flatten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle,
case PyTreeKind::StructSequence: {
auto tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = GET_SIZE<py::tuple>(tuple);
node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
node.node_data = py::reinterpret_borrow<py::object>(py::type::handle_of(tuple));
for (ssize_t i = 0; i < node.arity; ++i) {
recurse(GET_ITEM_HANDLE<py::tuple>(tuple, i));
}
Expand Down Expand Up @@ -283,7 +283,7 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle,
case PyTreeKind::StructSequence: {
auto tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = GET_SIZE<py::tuple>(tuple);
node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
node.node_data = py::reinterpret_borrow<py::object>(py::type::handle_of(tuple));
for (ssize_t i = 0; i < node.arity; ++i) {
recurse(GET_ITEM_HANDLE<py::tuple>(tuple, i), py::int_(i));
}
Expand Down Expand Up @@ -506,11 +506,11 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
GET_SIZE<py::tuple>(tuple),
py::repr(object)));
}
if (object.get_type().not_equal(node.node_data)) [[unlikely]] {
if (py::type::handle_of(object).not_equal(node.node_data)) [[unlikely]] {
throw py::value_error(absl::StrFormat(
"namedtuple type mismatch; expected type: %s, got type: %s; tuple: %s.",
py::repr(node.node_data),
py::repr(object.get_type()),
py::repr(py::type::handle_of(object)),
py::repr(object)));
}
for (ssize_t i = 0; i < node.arity; ++i) {
Expand Down Expand Up @@ -545,12 +545,12 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
GET_SIZE<py::tuple>(tuple),
py::repr(object)));
}
if (object.get_type().not_equal(node.node_data)) [[unlikely]] {
if (py::type::handle_of(object).not_equal(node.node_data)) [[unlikely]] {
throw py::value_error(
absl::StrFormat("PyStructSequence type mismatch; "
"expected type: %s, got type: %s; tuple: %s.",
py::repr(node.node_data),
py::repr(object.get_type()),
py::repr(py::type::handle_of(object)),
py::repr(object)));
}
for (ssize_t i = 0; i < node.arity; ++i) {
Expand All @@ -562,17 +562,17 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
case PyTreeKind::Custom: {
const PyTreeTypeRegistry::Registration* registration = nullptr;
if (m_none_is_leaf) [[unlikely]] {
registration =
PyTreeTypeRegistry::Lookup<NONE_IS_LEAF>(object.get_type(), m_namespace);
registration = PyTreeTypeRegistry::Lookup<NONE_IS_LEAF>(
py::type::handle_of(object), m_namespace);
} else [[likely]] {
registration =
PyTreeTypeRegistry::Lookup<NONE_IS_NODE>(object.get_type(), m_namespace);
registration = PyTreeTypeRegistry::Lookup<NONE_IS_NODE>(
py::type::handle_of(object), m_namespace);
}
if (registration != node.custom) [[unlikely]] {
throw py::value_error(absl::StrFormat(
"Custom node type mismatch; expected type: %s, got type: %s; value: %s.",
py::repr(node.custom->type),
py::repr(object.get_type()),
py::repr(py::type::handle_of(object)),
py::repr(object)));
}
py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(object));
Expand Down
2 changes: 1 addition & 1 deletion src/treespec/treespec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ template <bool NoneIsLeaf>
PyTreeTypeRegistry::Registration const** custom,
const std::string& registry_namespace) {
const PyTreeTypeRegistry::Registration* registration =
PyTreeTypeRegistry::Lookup<NoneIsLeaf>(handle.get_type(), registry_namespace);
PyTreeTypeRegistry::Lookup<NoneIsLeaf>(py::type::handle_of(handle), registry_namespace);
if (registration) [[likely]] {
if (registration->kind == PyTreeKind::Custom) [[unlikely]] {
*custom = registration;
Expand Down
4 changes: 2 additions & 2 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ def test_round_trip_with_flatten_up_to(tree, none_is_leaf, namespace):
[0, ((1, 2), 3, (4, (5, 6, 7))), 8, 9],
[0, ((1, 2), 3, (4, (5, 6, 7))), 8, 9],
[0, ((1, (2, 3)), (4, (5, 6, 7))), 8, 9],
[0, {'d': (4, (5, 6, 7)), 'c': ((1, {'b': 3, 'a': 2})), 'e': [8, 9]}],
[0, {1: (4, (5, 6, 7)), 1.1: ((1, {'b': 3, 'a': 2})), 'c': [8, 9]}],
[0, {'d': (4, (5, 6, 7)), 'c': (1, {'b': 3, 'a': 2}), 'e': [8, 9]}],
[0, {1: (4, (5, 6, 7)), 1.1: (1, {'b': 3, 'a': 2}), 'c': [8, 9]}],
[0, OrderedDict([(1, (1, (2, 3, 4))), (1.1, ((5, {'b': 7, 'a': 6}))), ('c', [8, 9])])],
],
none_is_leaf=[False, True],
Expand Down
8 changes: 4 additions & 4 deletions tests/test_treespec.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_with_namespace():
for namespace in ('', 'undefined'):
leaves, treespec = optree.tree_flatten(tree, none_is_leaf=False, namespace=namespace)
assert leaves == [tree]
assert str(treespec) == ('PyTreeSpec(*)')
assert str(treespec) == 'PyTreeSpec(*)'
paths, leaves, treespec = optree.tree_flatten_with_path(
tree,
none_is_leaf=False,
Expand All @@ -129,11 +129,11 @@ def test_with_namespace():
assert paths == [()]
assert leaves == [tree]
assert paths == treespec.paths()
assert str(treespec) == ('PyTreeSpec(*)')
assert str(treespec) == 'PyTreeSpec(*)'
for namespace in ('', 'undefined'):
leaves, treespec = optree.tree_flatten(tree, none_is_leaf=True, namespace=namespace)
assert leaves == [tree]
assert str(treespec) == ('PyTreeSpec(*, NoneIsLeaf)')
assert str(treespec) == 'PyTreeSpec(*, NoneIsLeaf)'
paths, leaves, treespec = optree.tree_flatten_with_path(
tree,
none_is_leaf=True,
Expand All @@ -142,7 +142,7 @@ def test_with_namespace():
assert paths == [()]
assert leaves == [tree]
assert paths == treespec.paths()
assert str(treespec) == ('PyTreeSpec(*, NoneIsLeaf)')
assert str(treespec) == 'PyTreeSpec(*, NoneIsLeaf)'

expected_string = "PyTreeSpec(CustomTreeNode(MyAnotherDict[['foo', 'baz']], [CustomTreeNode(MyAnotherDict[['c', 'b', 'a']], [None, *, *]), *]), namespace='namespace')"
leaves, treespec = optree.tree_flatten(tree, none_is_leaf=False, namespace='namespace')
Expand Down

0 comments on commit ddfaaf9

Please sign in to comment.