From 60a9c88e0b0b3e5b854b91d8f2581c4e6ef4b57d Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 8 Sep 2025 14:53:28 +0100 Subject: [PATCH] ENH: `lazy_xp_function` support for iterators --- src/array_api_extra/_lib/_utils/_helpers.py | 22 ++++++++++++++++----- tests/test_helpers.py | 14 +++++++++++++ tests/test_testing.py | 21 +++++++++++++++++++- 3 files changed, 51 insertions(+), 6 deletions(-) diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py index d177b376..6dd94a38 100644 --- a/src/array_api_extra/_lib/_utils/_helpers.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -6,7 +6,7 @@ import math import pickle import types -from collections.abc import Callable, Generator, Iterable +from collections.abc import Callable, Generator, Iterable, Iterator from functools import wraps from types import ModuleType from typing import ( @@ -512,13 +512,24 @@ class _AutoJITWrapper(Generic[T]): # numpydoc ignore=PR01 convert them to/from PyTrees. """ - obj: T + _obj: Any + _is_iter: bool _registered: ClassVar[bool] = False - __slots__: tuple[str, ...] = ("obj",) + __slots__: tuple[str, ...] = ("_is_iter", "_obj") def __init__(self, obj: T) -> None: # numpydoc ignore=GL08 self._register() - self.obj = obj + if isinstance(obj, Iterator): + self._obj = list(obj) + self._is_iter = True + else: + self._obj = obj + self._is_iter = False + + @property + def obj(self) -> T: # numpydoc ignore=RT01 + """Return wrapped object.""" + return iter(self._obj) if self._is_iter else self._obj @classmethod def _register(cls) -> None: # numpydoc ignore=SS06 @@ -531,7 +542,7 @@ def _register(cls) -> None: # numpydoc ignore=SS06 jax.tree_util.register_pytree_node( cls, - lambda obj: pickle_flatten(obj, jax.Array), # pyright: ignore[reportUnknownArgumentType] + lambda instance: pickle_flatten(instance, jax.Array), # pyright: ignore[reportUnknownArgumentType] lambda aux_data, children: pickle_unflatten(children, aux_data), # pyright: ignore[reportUnknownArgumentType] ) cls._registered = True @@ -556,6 +567,7 @@ def jax_autojit( - Automatically descend into non-array return values and find ``jax.Array`` objects inside them, then rebuild them downstream of exiting the JIT, swapping the JAX tracer objects with concrete arrays. + - Returned iterators are immediately completely consumed. See Also -------- diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 9f5c924d..77ba8cd8 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,3 +1,4 @@ +from collections.abc import Iterator from types import ModuleType from typing import TYPE_CHECKING, Generic, TypeVar, cast @@ -417,3 +418,16 @@ def f(x: list[int]) -> list[int]: out = f([1, 2]) assert isinstance(out, list) assert out == [3, 4] + + def test_iterators(self, jnp: ModuleType): + @jax_autojit + def f(x: Array) -> Iterator[Array]: + return (x + i for i in range(2)) + + inp = jnp.asarray([1, 2]) + out = f(inp) + assert isinstance(out, Iterator) + xp_assert_equal(next(out), jnp.asarray([1, 2])) + xp_assert_equal(next(out), jnp.asarray([2, 3])) + with pytest.raises(StopIteration): + _ = next(out) diff --git a/tests/test_testing.py b/tests/test_testing.py index 3a93e287..be7b6103 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -1,4 +1,4 @@ -from collections.abc import Callable +from collections.abc import Callable, Iterator from types import ModuleType from typing import cast @@ -468,3 +468,22 @@ def test_patch_lazy_xp_functions_deprecated_monkeypatch( monkeypatch.undo() y = non_materializable5(x) xp_assert_equal(y, x) + + +def my_iter(x: Array) -> Iterator[Array]: + yield x[0, :] + yield x[1, :] + + +lazy_xp_function(my_iter) + + +def test_patch_lazy_xp_functions_iter(xp: ModuleType): + x = xp.asarray([[1.0, 2.0], [3.0, 4.0]]) + it = my_iter(x) + + assert isinstance(it, Iterator) + xp_assert_equal(next(it), x[0, :]) + xp_assert_equal(next(it), x[1, :]) + with pytest.raises(StopIteration): + _ = next(it)