Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/set_setup_requires.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

PYPROJECT_FILE = ROOT / 'pyproject.toml'
PYBIND11_GIT_URL = 'https://github.com/pybind/pybind11.git'
PYBIND11_PACKAGE = f'pybind11 @ git+{PYBIND11_GIT_URL}#egg=pybind11'
PYBIND11_PACKAGE = f'pybind11 @ git+{PYBIND11_GIT_URL}'


if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ repos:
args: [--ignore-case]
files: ^docs/source/spelling_wordlist\.txt$
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v22.1.4
rev: v22.1.5
hooks:
- id: clang-format
- repo: https://github.com/cpplint/cpplint
Expand All @@ -50,7 +50,7 @@ repos:
- id: codespell
additional_dependencies: [".[toml]"]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.20.2
rev: v2.0.0
hooks:
- id: mypy
exclude: |
Expand Down
2 changes: 1 addition & 1 deletion optree/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def decorator(cls: _TypeT) -> _TypeT:
if namespace == '':
namespace = GLOBAL_NAMESPACE

cls = dataclasses.dataclass(cls, **kwargs) # type: ignore[assignment]
cls = dataclasses.dataclass(cls, **kwargs)
return register_node(cls, namespace=namespace)


Expand Down
23 changes: 14 additions & 9 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2668,15 +2668,20 @@ def tree_flatten_one_level(
if handler is None:
raise ValueError(f'Cannot flatten leaf-type: {node_type} (node: {tree!r}).')

flattened = tuple(handler.flatten_func(tree))
if len(flattened) == 2:
flattened = (*flattened, None)
elif len(flattened) != 3:
flattened: tuple[Iterable[PyTree[T]], MetaData, Iterable[Any] | None]
returned: (
tuple[Iterable[PyTree[T]], MetaData, Iterable[Any] | None]
| tuple[Iterable[PyTree[T]], MetaData]
) = tuple(handler.flatten_func(tree)) # type: ignore[assignment]
if len(returned) == 2:
flattened = (*returned, None)
elif len(returned) != 3:
raise RuntimeError(
f'PyTree custom flatten function for type {node_type} should return a 2- or 3-tuple, '
f'got {len(flattened)}.',
f'got {len(returned)}.',
)
flattened: tuple[Iterable[PyTree[T]], MetaData, Iterable[Any] | None]
else:
flattened = returned
children, metadata, entries = flattened
children = list(children)
entries = tuple(range(len(children)) if entries is None else entries)
Expand Down Expand Up @@ -3408,7 +3413,7 @@ def treespec_deque(
>>> treespec_deque([treespec_leaf(), treespec_leaf(), treespec_none()], maxlen=5)
PyTreeSpec(deque([*, *, None], maxlen=5))
>>> treespec_deque()
PyTreeSpec(deque([]))
PyTreeSpec(deque())
>>> treespec_deque([treespec_leaf(), treespec_tuple([treespec_leaf(), treespec_leaf()])])
PyTreeSpec(deque([*, (*, *)]))
>>> treespec_deque([treespec_leaf(), tree_structure({'a': 1, 'b': 2})], maxlen=5)
Expand Down Expand Up @@ -3656,7 +3661,7 @@ def helper( # pylint: disable=too-many-locals
return # don't look for more errors in this subtree

# If the keys agree, we should ensure that the children are in the same order:
full_tree_children = [full_subtree[k] for k in prefix_tree_keys] # type: ignore[misc]
full_tree_children = [full_subtree[k] for k in prefix_tree_keys]

if len(prefix_tree_children) != len(full_tree_children):
yield lambda name: ValueError(
Expand Down Expand Up @@ -3729,6 +3734,6 @@ def helper( # pylint: disable=too-many-locals
), f'equal pytree nodes gave different keys: {entries} and {entries_}'
# pylint: disable-next=invalid-name
for e, t1, t2 in zip(entries, prefix_tree_children, full_tree_children):
yield from helper(accessor + e, t1, t2)
yield from helper(accessor + e, t1, t2) # type: ignore[arg-type]

return list(helper(PyTreeAccessor(), prefix_tree, full_tree))
2 changes: 1 addition & 1 deletion optree/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def count(self, key: Any, /) -> int:
"""Emulate sequence-like behavior."""
raise NotImplementedError

def get(self, key: Any, /, default: S | None = None) -> PyTree[T] | T | S | None:
def get(self, key: Any, default: S | None = None, /) -> PyTree[T] | T | S | None:
"""Emulate mapping-like behavior."""
raise NotImplementedError

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def cmake_context(
f'Could not run `{cmake}` directly. '
'Unset the `PYTHONPATH` environment variable in the build environment.',
)
spawn_context = unset_python_path # type: ignore[assignment]
spawn_context = unset_python_path
with unset_python_path():
# CMake in the parent virtual environment
output = subprocess.check_output( # noqa: S603
Expand Down
10 changes: 8 additions & 2 deletions src/treespec/serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,15 @@ std::string PyTreeSpec::ToStringImpl() const {
}

case PyTreeKind::Deque: {
sstream << "deque([" << children << "]";
sstream << "deque(";
if (node.arity > 0) [[likely]] {
sstream << "[" << children << "]";
}
if (!node.node_data.is_none()) [[unlikely]] {
sstream << ", maxlen=" << PyRepr(node.node_data);
if (node.arity > 0) [[likely]] {
sstream << ", ";
}
sstream << "maxlen=" << PyRepr(node.node_data);
}
sstream << ")";
break;
Expand Down
4 changes: 2 additions & 2 deletions tests/concurrent/test_threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import itertools
import pickle
import weakref
from collections import OrderedDict, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed

import pytest
Expand All @@ -28,6 +27,7 @@
from helpers import (
GLOBAL_NAMESPACE,
PYPY,
STANDARD_DICT_TYPES,
TREES,
WASM,
Py_DEBUG,
Expand Down Expand Up @@ -324,7 +324,7 @@ def check3():
actual = pickle.loads(expected_serialized)
concurrent_run(check1)
concurrent_run(check2)
if expected.type in {dict, OrderedDict, defaultdict}:
if expected.type in STANDARD_DICT_TYPES:
concurrent_run(check3)


Expand Down
11 changes: 6 additions & 5 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
Py_GIL_DISABLED,
get_registry_size,
)
from optree.ops import STANDARD_DICT_TYPES as STANDARD_DICT_TYPES
from optree.registry import __GLOBAL_NAMESPACE as GLOBAL_NAMESPACE
from optree.registry import _NODETYPE_REGISTRY as NODETYPE_REGISTRY

Expand Down Expand Up @@ -241,7 +242,7 @@ def is_dict(dct):
def is_primitive_collection(obj):
if type(obj) in {tuple, list, deque}:
return all(isinstance(item, (int, float, str, bool, type(None))) for item in obj)
if type(obj) in {dict, OrderedDict, defaultdict}:
if type(obj) in STANDARD_DICT_TYPES:
return all(isinstance(value, (int, float, str, bool, type(None))) for value in obj.values())
return False

Expand Down Expand Up @@ -1552,8 +1553,8 @@ def __next__(self):
'PyTreeSpec(defaultdict(None, {}))',
"PyTreeSpec(defaultdict(<class 'int'>, {}))",
"PyTreeSpec(defaultdict(<class 'dict'>, {'baz': *, 'foo': *, 'something': *}))",
'PyTreeSpec(deque([]))',
'PyTreeSpec(deque([], maxlen=0))',
'PyTreeSpec(deque())',
'PyTreeSpec(deque(maxlen=0))',
'PyTreeSpec(deque([None, *, *]))',
'PyTreeSpec(deque([None, *], maxlen=2))',
"PyTreeSpec(CustomTreeNode(MyDict[['foo', 'baz']], [CustomTreeNode(MyDict[['c', 'b', 'a']], [None, *, *]), *]))",
Expand Down Expand Up @@ -1595,8 +1596,8 @@ def __next__(self):
'PyTreeSpec(defaultdict(None, {}), NoneIsLeaf)',
"PyTreeSpec(defaultdict(<class 'int'>, {}), NoneIsLeaf)",
"PyTreeSpec(defaultdict(<class 'dict'>, {'baz': *, 'foo': *, 'something': *}), NoneIsLeaf)",
'PyTreeSpec(deque([]), NoneIsLeaf)',
'PyTreeSpec(deque([], maxlen=0), NoneIsLeaf)',
'PyTreeSpec(deque(), NoneIsLeaf)',
'PyTreeSpec(deque(maxlen=0), NoneIsLeaf)',
'PyTreeSpec(deque([*, *, *]), NoneIsLeaf)',
'PyTreeSpec(deque([*, *], maxlen=2), NoneIsLeaf)',
"PyTreeSpec(CustomTreeNode(MyDict[['foo', 'baz']], [CustomTreeNode(MyDict[['c', 'b', 'a']], [*, *, *]), *]), NoneIsLeaf)",
Expand Down
9 changes: 5 additions & 4 deletions tests/test_prefix_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import optree
from helpers import (
GLOBAL_NAMESPACE,
STANDARD_DICT_TYPES,
TREES,
CustomTuple,
FlatCache,
Expand Down Expand Up @@ -484,10 +485,10 @@ def build_subtree(x):
return

def shuffle_dictionary(x):
if type(x) in {dict, OrderedDict, defaultdict}:
if type(x) in STANDARD_DICT_TYPES:
items = list(x.items())
random.shuffle(items)
dict_type = random.choice([dict, OrderedDict, defaultdict])
dict_type = random.choice(list(STANDARD_DICT_TYPES))
if dict_type is defaultdict:
return defaultdict(getattr(x, 'default_factory', int), items)
return dict_type(items)
Expand All @@ -496,7 +497,7 @@ def shuffle_dictionary(x):
shuffled_tree = optree.tree_map(
shuffle_dictionary,
tree,
is_leaf=lambda x: type(x) in {dict, OrderedDict, defaultdict},
is_leaf=lambda x: type(x) in STANDARD_DICT_TYPES,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
Expand All @@ -508,7 +509,7 @@ def shuffle_dictionary(x):
shuffled_suffix_tree = optree.tree_map(
shuffle_dictionary,
suffix_tree,
is_leaf=lambda x: type(x) in {dict, OrderedDict, defaultdict},
is_leaf=lambda x: type(x) in STANDARD_DICT_TYPES,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_treespec.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
GLOBAL_NAMESPACE,
NAMESPACED_TREE,
PYPY,
STANDARD_DICT_TYPES,
TEST_ROOT,
TREE_STRINGS,
TREES,
Expand Down Expand Up @@ -511,7 +512,7 @@ def test_treespec_pickle_roundtrip(
else:
actual = pickle.loads(pickle.dumps(expected))
assert actual == expected
if expected.type in {dict, OrderedDict, defaultdict}:
if expected.type in STANDARD_DICT_TYPES:
assert list(optree.tree_unflatten(actual, range(len(actual)))) == list(
optree.tree_unflatten(expected, range(len(expected))),
)
Expand Down
Loading