diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 6fa887f1b376..918a734a130d 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -32,8 +32,8 @@ import math import operator import types -from typing import (overload, Any, Callable, Literal, NamedTuple, Protocol, - TypeVar, Union) +from typing import (cast, overload, Any, Callable, Literal, NamedTuple, + Protocol, TypeVar, Union) from textwrap import dedent as _dedent import warnings @@ -851,10 +851,12 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: try: shape = list(shape) except TypeError: - shape = [shape] + # TODO: Consider warning here since shape is supposed to be a sequence, so + # this should not happen. + shape = cast(list[Any], [shape]) if any(ndim(s) != 0 for s in shape): raise ValueError("unravel_index: shape should be a scalar or 1D sequence.") - out_indices = [0] * len(shape) + out_indices: list[ArrayLike] = [0] * len(shape) for i, s in reversed(list(enumerate(shape))): indices_arr, out_indices[i] = ufuncs.divmod(indices_arr, s) oob_pos = indices_arr > 0 @@ -1137,7 +1139,12 @@ def where( else: util.check_arraylike("where", acondition, if_true, if_false) if size is not None or fill_value is not None: - raise ValueError("size and fill_value arguments cannot be used in three-term where function.") + raise ValueError("size and fill_value arguments cannot be used in " + "three-term where function.") + if if_true is None or if_false is None: + raise ValueError("Either both or neither of the x and y arguments " + "should be provided to jax.numpy.where, got " + f"{if_true} and {if_false}.") return util._where(acondition, if_true, if_false) diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index dc2ff57c0d28..9e82284f7cc4 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -46,7 +46,7 @@ def _roots_no_zeros(p: Array) -> Array: @jit -def _roots_with_zeros(p: Array, num_leading_zeros: int) -> Array: +def _roots_with_zeros(p: Array, num_leading_zeros: Array | int) -> Array: # Avoid lapack errors when p is all zero p = _where(len(p) == num_leading_zeros, 1.0, p) # Roll any leading zeros to the end & compute the roots @@ -85,7 +85,7 @@ def _roots_with_zeros(p: Array, num_leading_zeros: int) -> Array: """) def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array: check_arraylike("roots", p) - p_arr = atleast_1d(*promote_dtypes_inexact(p)) + p_arr = atleast_1d(promote_dtypes_inexact(p)[0]) if p_arr.ndim != 1: raise ValueError("Input must be a rank-1 array.") if p_arr.size < 2: @@ -96,7 +96,7 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array: num_leading_zeros = core.concrete_or_error(int, num_leading_zeros, "The error occurred in the jnp.roots() function. To use this within a " "JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros " - "will be result in some returned roots being set to NaN.") + "will result in some returned roots being set to NaN.") return _roots_no_zeros(p_arr[num_leading_zeros:]) else: return _roots_with_zeros(p_arr, num_leading_zeros) diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index dabc567ae6ea..1665e0fed1f8 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -18,7 +18,7 @@ from collections.abc import Sequence import dataclasses import math -from typing import Any, Generic, TypeVar, Union +from typing import Any, Union from jax._src import core from jax._src import effects @@ -74,8 +74,6 @@ class AccumEffect(RefEffect): # ## `Ref`s -Aval = TypeVar("Aval", bound=core.AbstractValue) - @dataclasses.dataclass class RefIndexer: ref_or_view: Any @@ -124,7 +122,7 @@ def __setitem__(self, slc, value): # We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs. -class AbstractRef(core.AbstractValue, Generic[Aval]): +class AbstractRef(core.AbstractValue): __slots__ = ["inner_aval"] def __init__(self, inner_aval: core.AbstractValue): @@ -212,6 +210,6 @@ def get_ref_state_effects( def shaped_array_ref(shape: tuple[int, ...], dtype, weak_type: bool = False, - named_shape = None) -> AbstractRef[core.AbstractValue]: + named_shape = None) -> AbstractRef: return AbstractRef(core.ShapedArray(shape, dtype, weak_type=weak_type, named_shape=named_shape)) diff --git a/jax/_src/util.py b/jax/_src/util.py index 3bfea618ceb9..c09231fc2cf1 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -138,10 +138,13 @@ def partition_list(bs: Sequence[bool], l: Sequence[T]) -> tuple[list[T], list[T] lists[b].append(x) return lists -def merge_lists(bs: Sequence[bool], l0: Sequence[T], l1: Sequence[T]) -> list[T]: +def merge_lists(bs: Sequence[bool], + l0: Sequence[T1], + l1: Sequence[T2] + ) -> list[T1 | T2]: assert sum(bs) == len(l1) and len(bs) - sum(bs) == len(l0) i0, i1 = iter(l0), iter(l1) - out = [next(i1) if b else next(i0) for b in bs] + out: list[T1 | T2] = [next(i1) if b else next(i0) for b in bs] sentinel = object() assert next(i0, sentinel) is next(i1, sentinel) is sentinel return out