Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions src/array_api_extra/_lib/_utils/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import math
import pickle
import types
from collections.abc import Callable, Generator, Iterable
from collections.abc import Callable, Generator, Iterable, Iterator
from functools import wraps
from types import ModuleType
from typing import (
Expand Down Expand Up @@ -512,13 +512,24 @@ class _AutoJITWrapper(Generic[T]): # numpydoc ignore=PR01
convert them to/from PyTrees.
"""

obj: T
_obj: Any
_is_iter: bool
_registered: ClassVar[bool] = False
__slots__: tuple[str, ...] = ("obj",)
__slots__: tuple[str, ...] = ("_is_iter", "_obj")

def __init__(self, obj: T) -> None: # numpydoc ignore=GL08
self._register()
self.obj = obj
if isinstance(obj, Iterator):
self._obj = list(obj)
self._is_iter = True
else:
self._obj = obj
self._is_iter = False

@property
def obj(self) -> T: # numpydoc ignore=RT01
"""Return wrapped object."""
return iter(self._obj) if self._is_iter else self._obj

@classmethod
def _register(cls) -> None: # numpydoc ignore=SS06
Expand All @@ -531,7 +542,7 @@ def _register(cls) -> None: # numpydoc ignore=SS06

jax.tree_util.register_pytree_node(
cls,
lambda obj: pickle_flatten(obj, jax.Array), # pyright: ignore[reportUnknownArgumentType]
lambda instance: pickle_flatten(instance, jax.Array), # pyright: ignore[reportUnknownArgumentType]
lambda aux_data, children: pickle_unflatten(children, aux_data), # pyright: ignore[reportUnknownArgumentType]
)
cls._registered = True
Expand All @@ -556,6 +567,7 @@ def jax_autojit(
- Automatically descend into non-array return values and find ``jax.Array`` objects
inside them, then rebuild them downstream of exiting the JIT, swapping the JAX
tracer objects with concrete arrays.
- Returned iterators are immediately completely consumed.

See Also
--------
Expand Down
14 changes: 14 additions & 0 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Iterator
from types import ModuleType
from typing import TYPE_CHECKING, Generic, TypeVar, cast

Expand Down Expand Up @@ -417,3 +418,16 @@ def f(x: list[int]) -> list[int]:
out = f([1, 2])
assert isinstance(out, list)
assert out == [3, 4]

def test_iterators(self, jnp: ModuleType):
@jax_autojit
def f(x: Array) -> Iterator[Array]:
return (x + i for i in range(2))

inp = jnp.asarray([1, 2])
out = f(inp)
assert isinstance(out, Iterator)
xp_assert_equal(next(out), jnp.asarray([1, 2]))
xp_assert_equal(next(out), jnp.asarray([2, 3]))
with pytest.raises(StopIteration):
_ = next(out)
21 changes: 20 additions & 1 deletion tests/test_testing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Callable
from collections.abc import Callable, Iterator
from types import ModuleType
from typing import cast

Expand Down Expand Up @@ -468,3 +468,22 @@ def test_patch_lazy_xp_functions_deprecated_monkeypatch(
monkeypatch.undo()
y = non_materializable5(x)
xp_assert_equal(y, x)


def my_iter(x: Array) -> Iterator[Array]:
yield x[0, :]
yield x[1, :]


lazy_xp_function(my_iter)


def test_patch_lazy_xp_functions_iter(xp: ModuleType):
x = xp.asarray([[1.0, 2.0], [3.0, 4.0]])
it = my_iter(x)

assert isinstance(it, Iterator)
xp_assert_equal(next(it), x[0, :])
xp_assert_equal(next(it), x[1, :])
with pytest.raises(StopIteration):
_ = next(it)