Skip to content

Commit

Permalink
Repair various type errors
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Mar 13, 2024
1 parent a8e2ee9 commit ddd6d55
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 17 deletions.
22 changes: 15 additions & 7 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
import math
import operator
import types
from typing import (overload, Any, Callable, Literal, NamedTuple, Protocol,
TypeVar, Union)
from typing import (cast, overload, Any, Callable, Literal, NamedTuple,
Protocol, TypeVar, Union)
from textwrap import dedent as _dedent
import warnings

Expand Down Expand Up @@ -851,12 +851,15 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]:
try:
shape = list(shape)
except TypeError:
shape = [shape]
# TODO: Consider warning here since shape is supposed to be a sequence, so
# this should not happen.
shape = cast(list[Any], [shape])
if any(ndim(s) != 0 for s in shape):
raise ValueError("unravel_index: shape should be a scalar or 1D sequence.")
out_indices = [0] * len(shape)
for i, s in reversed(list(enumerate(shape))):
indices_arr, out_indices[i] = ufuncs.divmod(indices_arr, s)
out_indices: list[Array] = []
for s in reversed(shape):
indices_arr, this_out_index = ufuncs.divmod(indices_arr, s)
out_indices.append(this_out_index)
oob_pos = indices_arr > 0
oob_neg = indices_arr < -1
return tuple(where(oob_pos, s - 1, where(oob_neg, 0, i))
Expand Down Expand Up @@ -1137,7 +1140,12 @@ def where(
else:
util.check_arraylike("where", acondition, if_true, if_false)
if size is not None or fill_value is not None:
raise ValueError("size and fill_value arguments cannot be used in three-term where function.")
raise ValueError("size and fill_value arguments cannot be used in "
"three-term where function.")
if if_true is None or if_false is None:
raise ValueError("Either both or neither of the x and y arguments "
"should be provided to jax.numpy.where, got "
f"{if_true} and {if_false}.")
return util._where(acondition, if_true, if_false)


Expand Down
6 changes: 3 additions & 3 deletions jax/_src/numpy/polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _roots_no_zeros(p: Array) -> Array:


@jit
def _roots_with_zeros(p: Array, num_leading_zeros: int) -> Array:
def _roots_with_zeros(p: Array, num_leading_zeros: Array | int) -> Array:
# Avoid lapack errors when p is all zero
p = _where(len(p) == num_leading_zeros, 1.0, p)
# Roll any leading zeros to the end & compute the roots
Expand Down Expand Up @@ -85,7 +85,7 @@ def _roots_with_zeros(p: Array, num_leading_zeros: int) -> Array:
""")
def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
check_arraylike("roots", p)
p_arr = atleast_1d(*promote_dtypes_inexact(p))
p_arr = atleast_1d(promote_dtypes_inexact(p)[0])
if p_arr.ndim != 1:
raise ValueError("Input must be a rank-1 array.")
if p_arr.size < 2:
Expand All @@ -96,7 +96,7 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
num_leading_zeros = core.concrete_or_error(int, num_leading_zeros,
"The error occurred in the jnp.roots() function. To use this within a "
"JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros "
"will be result in some returned roots being set to NaN.")
"will result in some returned roots being set to NaN.")
return _roots_no_zeros(p_arr[num_leading_zeros:])
else:
return _roots_with_zeros(p_arr, num_leading_zeros)
Expand Down
8 changes: 3 additions & 5 deletions jax/_src/state/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from collections.abc import Sequence
import dataclasses
import math
from typing import Any, Generic, TypeVar, Union
from typing import Any, Union

from jax._src import core
from jax._src import effects
Expand Down Expand Up @@ -74,8 +74,6 @@ class AccumEffect(RefEffect):

# ## `Ref`s

Aval = TypeVar("Aval", bound=core.AbstractValue)

@dataclasses.dataclass
class RefIndexer:
ref_or_view: Any
Expand Down Expand Up @@ -124,7 +122,7 @@ def __setitem__(self, slc, value):


# We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs.
class AbstractRef(core.AbstractValue, Generic[Aval]):
class AbstractRef(core.AbstractValue):
__slots__ = ["inner_aval"]

def __init__(self, inner_aval: core.AbstractValue):
Expand Down Expand Up @@ -212,6 +210,6 @@ def get_ref_state_effects(

def shaped_array_ref(shape: tuple[int, ...], dtype,
weak_type: bool = False,
named_shape = None) -> AbstractRef[core.AbstractValue]:
named_shape = None) -> AbstractRef:
return AbstractRef(core.ShapedArray(shape, dtype, weak_type=weak_type,
named_shape=named_shape))
8 changes: 6 additions & 2 deletions jax/_src/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from __future__ import annotations

Expand Down Expand Up @@ -138,10 +139,13 @@ def partition_list(bs: Sequence[bool], l: Sequence[T]) -> tuple[list[T], list[T]
lists[b].append(x)
return lists

def merge_lists(bs: Sequence[bool], l0: Sequence[T], l1: Sequence[T]) -> list[T]:
def merge_lists(bs: Sequence[bool],
l0: Sequence[T1],
l1: Sequence[T2]
) -> list[T1 | T2]:
assert sum(bs) == len(l1) and len(bs) - sum(bs) == len(l0)
i0, i1 = iter(l0), iter(l1)
out = [next(i1) if b else next(i0) for b in bs]
out: list[T1 | T2] = [next(i1) if b else next(i0) for b in bs]
sentinel = object()
assert next(i0, sentinel) is next(i1, sentinel) is sentinel
return out
Expand Down

0 comments on commit ddd6d55

Please sign in to comment.