From 4ffc1136b9607ee93121229496e2f4a98c188d42 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 28 Oct 2025 12:02:19 -0500 Subject: [PATCH 1/3] Add more special functions to BaseFakeNumpyNamespace And use positional-only where appropriate --- arraycontext/fake_numpy.py | 185 ++++++++++++++--------- arraycontext/impl/jax/fake_numpy.py | 4 +- arraycontext/impl/numpy/fake_numpy.py | 4 +- arraycontext/impl/pyopencl/fake_numpy.py | 4 +- arraycontext/impl/pytato/fake_numpy.py | 4 +- 5 files changed, 118 insertions(+), 83 deletions(-) diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index b7b47017..0f4f967f 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -313,54 +313,101 @@ def inner(a: ArrayOrScalar) -> ArrayOrScalar: # as attributes, making __getattr__ fail to retrieve the intended function. def broadcast_to(self, - array: ArrayOrContainerOrScalar, + array: ArrayOrContainerOrScalar, /, shape: tuple[int, ...] ) -> ArrayOrContainerOrScalar: ... def concatenate(self, - arrays: Sequence[ArrayOrContainerT], + arrays: Sequence[ArrayOrContainerT], /, axis: int = 0 ) -> ArrayOrContainerT: ... def stack(self, - arrays: Sequence[ArrayOrContainerT], + arrays: Sequence[ArrayOrContainerT], /, axis: int = 0 ) -> ArrayOrContainerT: ... def ravel(self, - a: ArrayOrContainerOrScalarT, + a: ArrayOrContainerOrScalarT, /, order: OrderCF = "C" ) -> ArrayOrContainerOrScalarT: ... def array_equal(self, a: ArrayOrContainerOrScalar, - b: ArrayOrContainerOrScalar + b: ArrayOrContainerOrScalar, + / ) -> Array: ... - def sqrt(self, - a: ArrayOrContainerOrScalarT, + def sqrt(self, a: ArrayOrContainerOrScalarT, /, + ) -> ArrayOrContainerOrScalarT: ... + def abs(self, a: ArrayOrContainerOrScalarT, /, ) -> ArrayOrContainerOrScalarT: ... - def abs(self, - a: ArrayOrContainerOrScalarT, + def sin(self, a: ArrayOrContainerOrScalarT, /, + ) -> ArrayOrContainerOrScalarT: ... + def cos(self, a: ArrayOrContainerOrScalarT, /, + ) -> ArrayOrContainerOrScalarT: ... + def tan(self, a: ArrayOrContainerOrScalarT, /, + ) -> ArrayOrContainerOrScalarT: ... + def arcsin(self, a: ArrayOrContainerOrScalarT, /, + ) -> ArrayOrContainerOrScalarT: ... + def arccos(self, a: ArrayOrContainerOrScalarT, /, + ) -> ArrayOrContainerOrScalarT: ... + def arctan(self, a: ArrayOrContainerOrScalarT, /, ) -> ArrayOrContainerOrScalarT: ... - def sin(self, + def hypot(self, a: ArrayOrContainerOrScalarT, + b: ArrayOrContainerOrScalarT, + /, ) -> ArrayOrContainerOrScalarT: ... - - def cos(self, + def arctan2(self, a: ArrayOrContainerOrScalarT, + b: ArrayOrContainerOrScalarT, + /, ) -> ArrayOrContainerOrScalarT: ... - def floor(self, - a: ArrayOrContainerOrScalarT, + def deg2rad(self, a: ArrayOrContainerOrScalarT, /, + ) -> ArrayOrContainerOrScalarT: ... + def rad2deg(self, a: ArrayOrContainerOrScalarT, /, ) -> ArrayOrContainerOrScalarT: ... - def ceil(self, - a: ArrayOrContainerOrScalarT, + def sinh(self, a: ArrayOrContainerOrScalarT, /, + ) -> ArrayOrContainerOrScalarT: ... + def cosh(self, a: ArrayOrContainerOrScalarT, /, + ) -> ArrayOrContainerOrScalarT: ... + def tanh(self, a: ArrayOrContainerOrScalarT, /, + ) -> ArrayOrContainerOrScalarT: ... + def arcsinh(self, a: ArrayOrContainerOrScalarT, /, + ) -> ArrayOrContainerOrScalarT: ... + def arccosh(self, a: ArrayOrContainerOrScalarT, /, + ) -> ArrayOrContainerOrScalarT: ... + def arctanh(self, a: ArrayOrContainerOrScalarT, /, ) -> ArrayOrContainerOrScalarT: ... + def ceil(self, a: ArrayOrContainerOrScalarT, /, + ) -> ArrayOrContainerOrScalarT: ... + def floor(self, a: ArrayOrContainerOrScalarT, /, + ) -> ArrayOrContainerOrScalarT: ... + + def exp(self, a: ArrayOrContainerOrScalarT, / + ) -> ArrayOrContainerOrScalarT: ... + def expm1(self, a: ArrayOrContainerOrScalarT, / + ) -> ArrayOrContainerOrScalarT: ... + def exp2(self, a: ArrayOrContainerOrScalarT, / + ) -> ArrayOrContainerOrScalarT: ... + def log(self, a: ArrayOrContainerOrScalarT, / + ) -> ArrayOrContainerOrScalarT: ... + def log10(self, a: ArrayOrContainerOrScalarT, / + ) -> ArrayOrContainerOrScalarT: ... + def log2(self, a: ArrayOrContainerOrScalarT, / + ) -> ArrayOrContainerOrScalarT: ... + def log1p(self, a: ArrayOrContainerOrScalarT, / + ) -> ArrayOrContainerOrScalarT: ... + def logaddexp(self, a: ArrayOrContainerOrScalarT, / + ) -> ArrayOrContainerOrScalarT: ... + def logaddexp2(self, a: ArrayOrContainerOrScalarT, / + ) -> ArrayOrContainerOrScalarT: ... # {{{ binary/ternary ufuncs # FIXME: These are more restrictive than necessary, but they'll do the job @@ -397,73 +444,71 @@ def where(self, @overload def sum(self, - a: ArrayOrContainer, + a: ArrayOrContainer, /, axis: int | tuple[int, ...] | None = None, dtype: DTypeLike = None, ) -> Array: ... @overload def sum(self, - a: ScalarLike, + a: ScalarLike, /, axis: int | tuple[int, ...] | None = None, dtype: DTypeLike = None, ) -> ScalarLike: ... def sum(self, - a: ArrayOrContainerOrScalar, + a: ArrayOrContainerOrScalar, /, axis: int | tuple[int, ...] | None = None, dtype: DTypeLike = None, ) -> ArrayOrScalar: ... @overload def min(self, - a: ArrayOrContainer, + a: ArrayOrContainer, /, axis: int | tuple[int, ...] | None = None, ) -> Array: ... @overload def min(self, - a: ScalarLike, + a: ScalarLike, /, axis: int | tuple[int, ...] | None = None, ) -> ScalarLike: ... def min(self, - a: ArrayOrContainerOrScalar, + a: ArrayOrContainerOrScalar, /, axis: int | tuple[int, ...] | None = None, ) -> ArrayOrScalar: ... @overload def max(self, - a: ArrayOrContainer, + a: ArrayOrContainer, /, axis: int | tuple[int, ...] | None = None, ) -> Array: ... @overload def max(self, - a: ScalarLike, + a: ScalarLike, /, axis: int | tuple[int, ...] | None = None, ) -> ScalarLike: ... def max(self, - a: ArrayOrContainerOrScalar, + a: ArrayOrContainerOrScalar, /, axis: int | tuple[int, ...] | None = None, ) -> ArrayOrScalar: ... @deprecated("use min instead") def amin(self, - a: ArrayOrContainerOrScalar, + a: ArrayOrContainerOrScalar, /, axis: int | tuple[int, ...] | None = None, ) -> ArrayOrScalar: ... @deprecated("use max instead") def amax(self, - a: ArrayOrContainerOrScalar, + a: ArrayOrContainerOrScalar, /, axis: int | tuple[int, ...] | None = None, ) -> ArrayOrScalar: ... - def any(self, - a: ArrayOrContainerOrScalar, + def any(self, a: ArrayOrContainerOrScalar, /, ) -> ArrayOrScalar: ... - def all(self, - a: ArrayOrContainerOrScalar, + def all(self, a: ArrayOrContainerOrScalar, /, ) -> ArrayOrScalar: ... # }}} @@ -475,50 +520,40 @@ def all(self, # These operations provide access to numpy-style comparisons in that # case. - def greater( - self, x: ArrayOrContainerOrScalar, y: ArrayOrContainerOrScalar - ) -> ArrayOrContainerOrScalar: - ... - - def greater_equal( - self, x: ArrayOrContainerOrScalar, y: ArrayOrContainerOrScalar - ) -> ArrayOrContainerOrScalar: - ... - - def less( - self, x: ArrayOrContainerOrScalar, y: ArrayOrContainerOrScalar - ) -> ArrayOrContainerOrScalar: - ... - - def less_equal( - self, x: ArrayOrContainerOrScalar, y: ArrayOrContainerOrScalar - ) -> ArrayOrContainerOrScalar: - ... - - def equal( - self, x: ArrayOrContainerOrScalar, y: ArrayOrContainerOrScalar - ) -> ArrayOrContainerOrScalar: - ... - - def not_equal( - self, x: ArrayOrContainerOrScalar, y: ArrayOrContainerOrScalar - ) -> ArrayOrContainerOrScalar: - ... - - def logical_or( - self, x: ArrayOrContainerOrScalar, y: ArrayOrContainerOrScalar - ) -> ArrayOrContainerOrScalar: - ... - - def logical_and( - self, x: ArrayOrContainerOrScalar, y: ArrayOrContainerOrScalar - ) -> ArrayOrContainerOrScalar: - ... - - def logical_not( - self, x: ArrayOrContainerOrScalar - ) -> ArrayOrContainerOrScalar: - ... + def greater(self, + x: ArrayOrContainerOrScalar, + y: ArrayOrContainerOrScalar, /, + ) -> ArrayOrContainerOrScalar: ... + def greater_equal(self, + x: ArrayOrContainerOrScalar, + y: ArrayOrContainerOrScalar, / + ) -> ArrayOrContainerOrScalar: ... + def less(self, + x: ArrayOrContainerOrScalar, + y: ArrayOrContainerOrScalar, / + ) -> ArrayOrContainerOrScalar: ... + def less_equal(self, + x: ArrayOrContainerOrScalar, + y: ArrayOrContainerOrScalar, / + ) -> ArrayOrContainerOrScalar: ... + def equal(self, + x: ArrayOrContainerOrScalar, + y: ArrayOrContainerOrScalar, / + ) -> ArrayOrContainerOrScalar: ... + def not_equal(self, + x: ArrayOrContainerOrScalar, + y: ArrayOrContainerOrScalar, / + ) -> ArrayOrContainerOrScalar: ... + def logical_or(self, + x: ArrayOrContainerOrScalar, + y: ArrayOrContainerOrScalar, / + ) -> ArrayOrContainerOrScalar: ... + def logical_and(self, + x: ArrayOrContainerOrScalar, + y: ArrayOrContainerOrScalar, / + ) -> ArrayOrContainerOrScalar: ... + def logical_not(self, x: ArrayOrContainerOrScalar, / + ) -> ArrayOrContainerOrScalar: ... # }}} diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 76c22062..af3ca5c9 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -157,11 +157,11 @@ def _rec_vdot(ary1, ary2): # {{{ logic functions - def all(self, a): + def all(self, a, /): return rec_map_reduce_array_container( partial(reduce, jnp.logical_and), jnp.all, a) - def any(self, a): + def any(self, a, /): return rec_map_reduce_array_container( partial(reduce, jnp.logical_or), jnp.any, a) diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py index 4fe94bc9..f923a4d6 100644 --- a/arraycontext/impl/numpy/fake_numpy.py +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -236,11 +236,11 @@ def inner_ravel(ary: ArrayOrScalar) -> ArrayOrScalar: def vdot(self, x, y): return rec_multimap_reduce_array_container(sum, np.vdot, x, y) - def any(self, a): + def any(self, a, /): return rec_map_reduce_array_container(partial(reduce, np.logical_or), lambda subary: np.any(subary), a) - def all(self, a): + def all(self, a, /): return rec_map_reduce_array_container(partial(reduce, np.logical_and), lambda subary: np.all(subary), a) diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 22b475f6..2f063984 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -197,7 +197,7 @@ def vdot(self, x, y, dtype=None): # {{{ logic functions - def all(self, a): + def all(self, a, /): queue = self._array_context.queue def _all(ary): @@ -210,7 +210,7 @@ def _all(ary): _all, a) - def any(self, a): + def any(self, a, /): queue = self._array_context.queue def _any(ary): diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index 259dd911..c96aaf62 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -173,12 +173,12 @@ def stack(self, arrays, axis=0): # {{{ logic functions - def all(self, a): + def all(self, a, /): return rec_map_reduce_array_container( partial(reduce, pt.logical_and), lambda subary: pt.all(subary), a) - def any(self, a): + def any(self, a, /): return rec_map_reduce_array_container( partial(reduce, pt.logical_or), lambda subary: pt.any(subary), a) From b45c156cfc143e3b526b27817914e0d5f0e5156a Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 28 Oct 2025 12:07:56 -0500 Subject: [PATCH 2/3] Improve error message for inconsistent types in _multimap_array_container_impl --- arraycontext/container/traversal.py | 13 ++++++++++--- test/test_arraycontext.py | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 37bd01c3..2e0d428b 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -244,9 +244,16 @@ def rec(*args_: Any) -> Any: except NotAnArrayContainerError: return f(*args_) - assert all( - type(args_[i]) is type(template_ary) for i in container_indices[1:] - ), f"expected type '{type(template_ary).__name__}'" + if __debug__: # noqa: SIM102 + if not all( + type(args_[i]) is type(template_ary) + for i in container_indices[1:] + ): + arg_summary = ", ".join( + f"arg{i+1}: {type(arg)}" + for i, arg in enumerate(args_)) + raise TypeError( + f"inconsistent types in multiple traversal: {arg_summary}") result = [] new_args = list(args_) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 06aa2061..ddb8df05 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -802,7 +802,7 @@ def check_leaf(a, subary1, b, subary2): for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]: check_leaf(2, ary, 2, ary) - with pytest.raises(AssertionError): + with pytest.raises(TypeError): rec_multimap_array_container(func_multiple_scalar, 2, ary_dof, 2, dc_of_dofs) # }}} From b7092d35979ce1d0d9f9823adc83cb628c5a5918 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 28 Oct 2025 12:02:22 -0500 Subject: [PATCH 3/3] Update baseline --- .basedpyright/baseline.json | 124 ++++++++++++++++++++++++++++++++++-- 1 file changed, 118 insertions(+), 6 deletions(-) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 3aa3bf72..8aaa5228 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -388,24 +388,32 @@ { "code": "reportAny", "range": { - "startColumn": 21, - "endColumn": 29, + "startColumn": 29, + "endColumn": 37, "lineCount": 1 } }, { "code": "reportAny", "range": { - "startColumn": 39, - "endColumn": 51, + "startColumn": 47, + "endColumn": 59, "lineCount": 1 } }, { "code": "reportAny", "range": { - "startColumn": 42, - "endColumn": 54, + "startColumn": 38, + "endColumn": 41, + "lineCount": 1 + } + }, + { + "code": "reportAny", + "range": { + "startColumn": 27, + "endColumn": 30, "lineCount": 1 } }, @@ -2461,6 +2469,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownParameterType", + "range": { + "startColumn": 20, + "endColumn": 21, + "lineCount": 1 + } + }, { "code": "reportMissingParameterType", "range": { @@ -2469,6 +2485,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownParameterType", + "range": { + "startColumn": 23, + "endColumn": 24, + "lineCount": 1 + } + }, { "code": "reportMissingParameterType", "range": { @@ -2493,6 +2517,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownParameterType", + "range": { + "startColumn": 24, + "endColumn": 25, + "lineCount": 1 + } + }, { "code": "reportMissingParameterType", "range": { @@ -2501,6 +2533,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownParameterType", + "range": { + "startColumn": 27, + "endColumn": 28, + "lineCount": 1 + } + }, { "code": "reportMissingParameterType", "range": { @@ -2525,6 +2565,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownParameterType", + "range": { + "startColumn": 22, + "endColumn": 23, + "lineCount": 1 + } + }, { "code": "reportMissingParameterType", "range": { @@ -2533,6 +2581,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownParameterType", + "range": { + "startColumn": 25, + "endColumn": 26, + "lineCount": 1 + } + }, { "code": "reportMissingParameterType", "range": { @@ -2557,6 +2613,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownParameterType", + "range": { + "startColumn": 28, + "endColumn": 29, + "lineCount": 1 + } + }, { "code": "reportMissingParameterType", "range": { @@ -2565,6 +2629,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownParameterType", + "range": { + "startColumn": 31, + "endColumn": 32, + "lineCount": 1 + } + }, { "code": "reportMissingParameterType", "range": { @@ -2589,6 +2661,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownParameterType", + "range": { + "startColumn": 19, + "endColumn": 20, + "lineCount": 1 + } + }, { "code": "reportMissingParameterType", "range": { @@ -2597,6 +2677,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownParameterType", + "range": { + "startColumn": 22, + "endColumn": 23, + "lineCount": 1 + } + }, { "code": "reportMissingParameterType", "range": { @@ -2621,6 +2709,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownParameterType", + "range": { + "startColumn": 25, + "endColumn": 26, + "lineCount": 1 + } + }, { "code": "reportMissingParameterType", "range": { @@ -2629,6 +2725,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownParameterType", + "range": { + "startColumn": 28, + "endColumn": 29, + "lineCount": 1 + } + }, { "code": "reportMissingParameterType", "range": { @@ -3641,6 +3745,14 @@ "lineCount": 5 } }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 12, + "endColumn": 18, + "lineCount": 1 + } + }, { "code": "reportAny", "range": {