Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix annotations #18395

Merged
merged 1 commit into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
import math
import operator
import types
from typing import (overload, Any, Callable, Literal, NamedTuple, Protocol,
TypeVar, Union)
from typing import (cast, overload, Any, Callable, Literal, NamedTuple,
Protocol, TypeVar, Union)
from textwrap import dedent as _dedent
import warnings

Expand Down Expand Up @@ -851,10 +851,12 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]:
try:
shape = list(shape)
except TypeError:
shape = [shape]
# TODO: Consider warning here since shape is supposed to be a sequence, so
# this should not happen.
jakevdp marked this conversation as resolved.
Show resolved Hide resolved
shape = cast(list[Any], [shape])
if any(ndim(s) != 0 for s in shape):
raise ValueError("unravel_index: shape should be a scalar or 1D sequence.")
out_indices = [0] * len(shape)
out_indices: list[ArrayLike] = [0] * len(shape)
for i, s in reversed(list(enumerate(shape))):
indices_arr, out_indices[i] = ufuncs.divmod(indices_arr, s)
oob_pos = indices_arr > 0
Expand Down Expand Up @@ -1137,7 +1139,12 @@ def where(
else:
util.check_arraylike("where", acondition, if_true, if_false)
if size is not None or fill_value is not None:
raise ValueError("size and fill_value arguments cannot be used in three-term where function.")
raise ValueError("size and fill_value arguments cannot be used in "
"three-term where function.")
if if_true is None or if_false is None:
raise ValueError("Either both or neither of the x and y arguments "
"should be provided to jax.numpy.where, got "
f"{if_true} and {if_false}.")
return util._where(acondition, if_true, if_false)


Expand Down
6 changes: 3 additions & 3 deletions jax/_src/numpy/polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _roots_no_zeros(p: Array) -> Array:


@jit
def _roots_with_zeros(p: Array, num_leading_zeros: int) -> Array:
def _roots_with_zeros(p: Array, num_leading_zeros: Array | int) -> Array:
NeilGirdhar marked this conversation as resolved.
Show resolved Hide resolved
# Avoid lapack errors when p is all zero
p = _where(len(p) == num_leading_zeros, 1.0, p)
# Roll any leading zeros to the end & compute the roots
Expand Down Expand Up @@ -85,7 +85,7 @@ def _roots_with_zeros(p: Array, num_leading_zeros: int) -> Array:
""")
def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
check_arraylike("roots", p)
p_arr = atleast_1d(*promote_dtypes_inexact(p))
p_arr = atleast_1d(promote_dtypes_inexact(p)[0])
NeilGirdhar marked this conversation as resolved.
Show resolved Hide resolved
if p_arr.ndim != 1:
raise ValueError("Input must be a rank-1 array.")
if p_arr.size < 2:
Expand All @@ -96,7 +96,7 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
num_leading_zeros = core.concrete_or_error(int, num_leading_zeros,
"The error occurred in the jnp.roots() function. To use this within a "
"JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros "
"will be result in some returned roots being set to NaN.")
"will result in some returned roots being set to NaN.")
return _roots_no_zeros(p_arr[num_leading_zeros:])
else:
return _roots_with_zeros(p_arr, num_leading_zeros)
Expand Down
8 changes: 3 additions & 5 deletions jax/_src/state/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from collections.abc import Sequence
import dataclasses
import math
from typing import Any, Generic, TypeVar, Union
from typing import Any, Union

from jax._src import core
from jax._src import effects
Expand Down Expand Up @@ -74,8 +74,6 @@ class AccumEffect(RefEffect):

# ## `Ref`s

Aval = TypeVar("Aval", bound=core.AbstractValue)

@dataclasses.dataclass
class RefIndexer:
ref_or_view: Any
Expand Down Expand Up @@ -124,7 +122,7 @@ def __setitem__(self, slc, value):


# We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs.
class AbstractRef(core.AbstractValue, Generic[Aval]):
class AbstractRef(core.AbstractValue):
jakevdp marked this conversation as resolved.
Show resolved Hide resolved
__slots__ = ["inner_aval"]

def __init__(self, inner_aval: core.AbstractValue):
Expand Down Expand Up @@ -212,6 +210,6 @@ def get_ref_state_effects(

def shaped_array_ref(shape: tuple[int, ...], dtype,
weak_type: bool = False,
named_shape = None) -> AbstractRef[core.AbstractValue]:
named_shape = None) -> AbstractRef:
return AbstractRef(core.ShapedArray(shape, dtype, weak_type=weak_type,
named_shape=named_shape))
7 changes: 5 additions & 2 deletions jax/_src/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,13 @@ def partition_list(bs: Sequence[bool], l: Sequence[T]) -> tuple[list[T], list[T]
lists[b].append(x)
return lists

def merge_lists(bs: Sequence[bool], l0: Sequence[T], l1: Sequence[T]) -> list[T]:
def merge_lists(bs: Sequence[bool],
l0: Sequence[T1],
l1: Sequence[T2]
) -> list[T1 | T2]:
NeilGirdhar marked this conversation as resolved.
Show resolved Hide resolved
assert sum(bs) == len(l1) and len(bs) - sum(bs) == len(l0)
i0, i1 = iter(l0), iter(l1)
out = [next(i1) if b else next(i0) for b in bs]
out: list[T1 | T2] = [next(i1) if b else next(i0) for b in bs]
sentinel = object()
assert next(i0, sentinel) is next(i1, sentinel) is sentinel
return out
Expand Down