Merge pull request #21261 from superbobry:mypy-ruff
PiperOrigin-RevId: 634654578
jax authors committed May 17, 2024
2 parents 1829a66 + c3bc88d commit 5e2710c
Showing 23 changed files with 75 additions and 65 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -26,17 +26,17 @@ repos:
files: \.py$

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.5
rev: v0.4.4
hooks:
- id: ruff

- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.9.0'
rev: 'v1.10.0'
hooks:
- id: mypy
files: (jax/|tests/typing_test\.py)
exclude: jax/_src/basearray.py|jax/numpy/__init__.py # Use pyi instead
additional_dependencies: [types-requests==2.31.0, jaxlib==0.4.23, ml_dtypes==0.3.2, numpy==1.26.3, scipy==1.11.4]
additional_dependencies: [types-requests==2.31.0, jaxlib==0.4.27, ml_dtypes==0.3.2, numpy==1.26.3, scipy==1.11.4]
args: [--config=pyproject.toml]

- repo: https://github.com/mwouts/jupytext
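A minimal usage sketch (not part of the commit itself): with pre-commit installed and run from the repository root, the updated ruff and mypy hooks above can be exercised locally with

  pre-commit run ruff --all-files
  pre-commit run mypy --all-files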
2 changes: 1 addition & 1 deletion jax/_src/api_util.py
@@ -672,7 +672,7 @@ def result_paths(*args, **kwargs):
yield ans, [keystr(path) for path, _ in generate_key_paths(ans)]

def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None,
result_paths: tuple[str | None, ...] | None = None,
result_paths: tuple[str, ...] | None = None,
) -> core.Jaxpr:
"""Add debug info to jaxpr, given trace-time debug info and result paths."""
if trace_debug is None:
3 changes: 2 additions & 1 deletion jax/_src/cache_key.py
@@ -18,6 +18,7 @@
import logging
import os
import sys
from typing import cast as type_cast

from jax._src import config
from jax._src.lib import version_str as jaxlib_version_str
@@ -136,7 +137,7 @@ def _serialize_ir(m: ir.Module) -> bytes:

def _canonicalize_ir(m_original: ir.Module) -> bytes:
with m_original.context:
m = m_original.operation.clone()
m = type_cast(ir.Module, m_original.operation.clone())
passes = pm.PassManager.parse(
"builtin.module(strip-debuginfo)"
)
8 changes: 5 additions & 3 deletions jax/_src/compiler.py
@@ -217,9 +217,11 @@ def backend_compile(
options: xc.CompileOptions,
host_callbacks: Sequence[Any],
) -> xc.LoadedExecutable:
# Convert ir.Module to a string representation, unless the
# back-end expliclity flags the ability to handle a module directly
# (avoiding the overhead of back and forth conversions)
# Convert ir.Module to a string representation, unless the backend
# explicitly flags the ability to handle a module directly (avoiding the
# overhead of back and forth conversions).
# TODO(slebedev): Change the backend.compile() to accept ir.Module.
built_c: Any
if getattr(backend, "needs_str_ir", True):
built_c = mlir.module_to_bytecode(module)
else:
2 changes: 1 addition & 1 deletion jax/_src/core.py
@@ -79,7 +79,7 @@ class JaxprDebugInfo(NamedTuple):
traced_for: str # e.g. 'jit', 'scan', etc
func_src_info: str # e.g. f'{fun.__name__} at {filename}:{lineno}'
arg_names: tuple[str | None, ...] # e.g. ('args[0]', ... )
result_paths: tuple[str | None, ...] # e.g. ('[0]', '[1]', ...)
result_paths: tuple[str, ...] # e.g. ('[0]', '[1]', ...)

class Jaxpr:
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns',
6 changes: 3 additions & 3 deletions jax/_src/debugging.py
@@ -391,7 +391,7 @@ def _hlo_sharding_callback(hlo_sharding: xc.HloSharding):
has_side_effect=ir.BoolAttr.get(True),
api_version=mlir.i32_attr(1),
called_computations=ir.ArrayAttr.get([]),
backend_config=ir.StringAttr.get(key),
backend_config=ir.StringAttr.get(key), # type: ignore[arg-type]
operand_layouts=None,
result_layouts=None)
return []
@@ -511,11 +511,11 @@ def visualize_sharding(shape: Sequence[int], sharding: Sharding, *,
heights[chunk_idxs] = None
widths[chunk_idxs] = horiz_size / shape[0]
slices.setdefault(chunk_idxs, set()).add(dev.id)
num_rows = max([a[0] for a in slices.keys()]) + 1
num_rows = max(a[0] for a in slices.keys()) + 1
if len(list(slices.keys())[0]) == 1:
num_cols = 1
else:
num_cols = max([a[1] for a in slices.keys()]) + 1
num_cols = max(a[1] for a in slices.keys()) + 1

color_iter = make_color_iter(color_map, num_rows, num_cols)
table = rich.table.Table(show_header=False, show_lines=not use_color,
51 changes: 26 additions & 25 deletions jax/_src/interpreters/mlir.py
@@ -27,7 +27,7 @@
import re
import types
import typing
from typing import Any, Callable, NamedTuple, Protocol, Union
from typing import Any, Callable, NamedTuple, Protocol, Union, cast as type_cast
import warnings

import numpy as np
@@ -87,19 +87,20 @@
# IR Helpers

def dense_int_elements(xs) -> ir.DenseIntElementsAttr:
return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
return type_cast(ir.DenseIntElementsAttr,
ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64)))

def dense_int_array(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr:
def dense_int_array(xs) -> ir.DenseElementsAttr | ir.DenseI64ArrayAttr:
# TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v5 or higher
if hlo.get_api_version() < 5:
return dense_int_elements(xs)
return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))
return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64)) # type: ignore

# TODO: b/321794305 - delete this when jaxlib is on StableHLO API v6 or higher
def dense_int_array_v6(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr:
if hlo.get_api_version() < 6:
return dense_int_elements(xs)
return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))
return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64)) # type: ignore

def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
a = np.packbits(np.array(xs, np.bool_), bitorder='little')
@@ -114,7 +115,7 @@ def dense_bool_array(xs: Sequence[bool]) -> ir.DenseElementsAttr | ir.DenseBoolA
# TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v6 or higher
if hlo.get_api_version() < 6:
return dense_bool_elements(xs)
return ir.DenseBoolArrayAttr.get(xs)
return ir.DenseBoolArrayAttr.get(xs) # type: ignore

def i32_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), i)
def i64_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), i)
@@ -132,7 +133,7 @@ def lower_dim(d):
return hlo.reshape(int1d, d)
ds = map(lower_dim, sizes)
if not ds:
return ir_constant(np.array([], np.int32))
return type_cast(ir.RankedTensorType, ir_constant(np.array([], np.int32)))
elif len(ds) == 1:
return ds[0]
else:
@@ -195,7 +196,7 @@ def _array_ir_types(aval: core.ShapedArray | core.DShapedArray
aval = core.physical_aval(aval) # type: ignore
if not core.is_constant_shape(aval.shape):
return _dynamic_array_ir_types(aval) # type: ignore
return (ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)),)
return (ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)),) # type: ignore

def _dynamic_array_ir_types(aval: core.ShapedArray) -> Sequence[ir.Type]:
dyn_size = ir.ShapedType.get_dynamic_size()
@@ -282,7 +283,7 @@ def _numpy_array_constant(x: np.ndarray | np.generic) -> Sequence[ir.Value]:
if x.dtype == np.bool_:
x = np.packbits(x, bitorder='little') # type: ignore
x = np.ascontiguousarray(x)
attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape)
attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore
return (hlo.constant(attr),)


@@ -314,11 +315,11 @@ def _ndarray_constant_handler(val: np.ndarray | np.generic) -> Sequence[ir.Value
elif np.any(np.equal(0, val.strides)) and val.size > 0:
zero_stride_axes, = np.where(np.equal(0, val.strides))
other_axes, = np.where(np.not_equal(0, val.strides))
collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None) # type: ignore
for ax in range(val.ndim))] # type: ignore
collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None) # type: ignore
for ax in range(val.ndim))]
out = hlo.broadcast_in_dim(
ir.RankedTensorType.get(
val.shape, dtype_to_ir_type(collapsed_val.dtype)),
val.shape, dtype_to_ir_type(collapsed_val.dtype)), # type: ignore
_numpy_array_constant(collapsed_val)[0],
dense_int_array_v6(other_axes))
return (out,)
@@ -738,7 +739,7 @@ def wrap_singleton_ir_values(x: ir.Value | Sequence[ir.Value]

def flatten_lowering_ir_args(
xs: Sequence[ir.Value | Sequence[ir.Value]]
) -> Sequence[Sequence[ir.Value]]:
) -> Sequence[ir.Value]:
return util.flatten(map(wrap_singleton_ir_values, xs))

_module_name_regex = re.compile(r"[^\w.-]")
@@ -863,7 +864,7 @@ def lower_jaxpr_to_module(
in_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
out_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
arg_names: Sequence[str | None] | None = None,
result_names: Sequence[str | None] | None = None,
result_names: Sequence[str] | None = None,
num_replicas: int = 1,
num_partitions: int = 1,
all_default_mem_kind: bool = True,
@@ -1106,7 +1107,7 @@ def lower_jaxpr_to_fun(
xla_donated_args: Sequence[bool] | None = None,
api_name: str = "jit",
arg_names: Sequence[str | None] | None = None,
result_names: Sequence[str | None] | None = None,
result_names: Sequence[str] | None = None,
arg_memory_kinds: Sequence[str | None] | None = None,
result_memory_kinds: Sequence[str | None] | None = None,
arg_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
@@ -1618,7 +1619,7 @@ def lower_per_platform(ctx: LoweringRuleContext,
default_rule: LoweringRule | None,
effects: effects_lib.Effects,
*rule_args: ir.Value,
**rule_kwargs) -> ir.Value:
**rule_kwargs) -> Sequence[ir.Value]:
"""Emits code for a primitive for the current lowering platform(s).
For example, given
@@ -2039,9 +2040,8 @@ def compare_hlo(x, y, direction: str, comparison_type: str | None = None):
"""Creates CompareOp."""
if comparison_type is None:
elem_type = ir.RankedTensorType(x.type).element_type
if ir.IntegerType.isinstance(elem_type):
comparison_type = ("UNSIGNED" if ir.IntegerType.is_unsigned(elem_type)
else "SIGNED")
if isinstance(elem_type, ir.IntegerType):
comparison_type = "UNSIGNED" if elem_type.is_unsigned else "SIGNED"
else:
comparison_type = "FLOAT"

@@ -2129,7 +2129,7 @@ def get_sharding_attr(sharding_proto: xc.OpSharding):
# The MHLO to HLO conversion supports both, and the proto representation is
# more compact.
if len(sharding_proto.tile_assignment_devices) > 100:
return ir.StringAttr.get(sharding_proto.SerializeToString())
return ir.StringAttr.get(sharding_proto.SerializeToString()) # type: ignore
else:
return ir.StringAttr.get(repr(xc.HloSharding.from_proto(sharding_proto)))

@@ -2315,7 +2315,8 @@ def send_to_host(channel: int, token: hlo.TokenType, operand: Any,

def receive_from_host(channel: int, token: hlo.TokenType,
out_aval: core.ShapedArray, name: str, *,
sharding: xc.OpSharding | None = None) -> ir.Value:
sharding: xc.OpSharding | None = None,
) -> tuple[ir.Value, ir.Value]:
channel_handle = hlo.ChannelHandle.get(channel, RECV_FROM_HOST_TYPE)
recv_op = hlo.RecvOp([aval_to_ir_type(out_aval),
hlo.TokenType.get()], token, channel_handle,
@@ -2592,7 +2593,7 @@ def custom_call(
if backend_config is None:
backend_config_attr = ir.StringAttr.get("")
elif isinstance(backend_config, (str, bytes)):
backend_config_attr = ir.StringAttr.get(backend_config)
backend_config_attr = ir.StringAttr.get(backend_config) # type: ignore
elif isinstance(backend_config, dict):
# TODO(necula): it seems that the CustomCallOp constructor requires that
# backend_config_attr be a string attribute, even though in some cases we
@@ -2661,8 +2662,8 @@ def custom_call(
op = hlo.CustomCallOp.build_generic(results=result_types, operands=operands,
attributes=attributes)
if isinstance(backend_config, dict):
backend_config_attr = ir.DictAttr.get(backend_config)
op.operation.attributes["mhlo.backend_config"] = backend_config_attr
op.operation.attributes["mhlo.backend_config"] = ir.DictAttr.get(
backend_config)
return op


@@ -2721,7 +2722,7 @@ def prep_one_pad(pad_lo_hi: tuple[core.DimSize, core.DimSize]):
base_dilations=dense_int_array_v6(base_dilation),
window_dilations=dense_int_array_v6(window_dilation),
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
shape=(len(padding), 2)))
shape=[len(padding), 2]))
reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
with ir.InsertionPoint(reducer):
hlo.return_(reducer_body(reducer))
6 changes: 3 additions & 3 deletions jax/_src/interpreters/pxla.py
@@ -2000,7 +2000,7 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
return False
return True

memory_kind_propagate_rule = {} # type: ignore
memory_kind_propagate_rule: dict[Any, Any] = {}

@weakref_lru_cache
def get_out_memory_kinds_via_propagation(closed_jaxpr: core.ClosedJaxpr
@@ -2386,10 +2386,10 @@ def lower_mesh_computation(
all_args_info=None)

class MeshComputation(stages.XlaLowering):
_hlo: ir.Module | None
_hlo: ir.Module
_executable: MeshExecutable | None

def __init__(self, name: str, hlo: ir.Module | None,
def __init__(self, name: str, hlo: ir.Module,
donated_invars: Sequence[bool], **compile_args):
self._name = name
self._hlo = hlo
4 changes: 3 additions & 1 deletion jax/_src/lax/control_flow/conditionals.py
@@ -975,7 +975,9 @@ def _platform_index_lowering(ctx: mlir.LoweringRuleContext,
*,
platforms: Sequence[Sequence[str]],
has_default: bool):
def lower_constant(ctx: mlir.LoweringRuleContext, *, i: int) -> mlir.ir.Value:
def lower_constant(
ctx: mlir.LoweringRuleContext, *, i: int
) -> Sequence[ir.Value]:
return mlir.ir_constants(np.int32(i))
platform_rules: dict[str, mlir.LoweringRule] = {}
for i, ps in enumerate(platforms):
2 changes: 1 addition & 1 deletion jax/_src/lax/lax.py
@@ -1774,7 +1774,7 @@ def broadcast_hlo(
return out

def _nary_lower_hlo(op: Callable, ctx,
*args: ir.Value | Sequence[ir.Value],
*args: ir.Value,
explicit_type=False, **params) -> Sequence[ir.Value]:
"""Lowers an elementwise operator to its MLIR equivalent.
2 changes: 1 addition & 1 deletion jax/_src/lax/parallel.py
@@ -727,7 +727,7 @@ def _replica_groups(axis_env, axis_name, axis_index_groups):
return replica_groups

def _replica_groups_hlo(replica_groups: Sequence[Sequence[int]]
) -> ir.DenseIntElementsAttr:
) -> ir.DenseElementsAttr:
# Uneven replica groups are padded with -1.
groups = np.array(list(itertools.zip_longest(*replica_groups, fillvalue=-1)),
dtype=np.int64).T
4 changes: 3 additions & 1 deletion jax/_src/lax/windowed_reductions.py
@@ -902,7 +902,9 @@ def snd(t, t_aval):
double_word_out_aval = out_aval.update(dtype=double_word_dtype)

def reducer_body(reducer: ir.Block) -> Sequence[ir.Value]:
x, y = reducer.arguments
x: ir.Value
y: ir.Value
x, y = reducer.arguments # type: ignore
assert select_prim is lax.ge_p or select_prim is lax.le_p
cmp_op = "GE" if select_prim is lax.ge_p else "LE"
out = hlo.SelectOp(mlir.compare_hlo(fst(x), fst(y), cmp_op), x, y)
4 changes: 2 additions & 2 deletions jax/_src/maps.py
@@ -20,7 +20,7 @@
from functools import wraps, partial, partialmethod, lru_cache
import itertools as it
import math
from typing import Callable, Any, NamedTuple, Union
from typing import Callable, Any, NamedTuple, Union, cast as type_cast

import numpy as np

@@ -631,7 +631,7 @@ def lower(*args, **kwargs):
no_kwargs=True)

fun_mapped.lower = lower
return fun_mapped
return type_cast(stages.Wrapped, fun_mapped)

def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_invars,
global_axis_sizes, axis_resources, resource_env, backend,
2 changes: 1 addition & 1 deletion jax/_src/numpy/lax_numpy.py
@@ -868,7 +868,7 @@ def gradient_along_axis(a, h, axis):
if len(axis_tuple) == 0:
return []

if min([s for i, s in enumerate(a.shape) if i in axis_tuple]) < 2:
if min(s for i, s in enumerate(a.shape) if i in axis_tuple) < 2:
raise ValueError("Shape of array too small to calculate "
"a numerical gradient, "
"at least 2 elements are required.")
6 changes: 4 additions & 2 deletions jax/_src/pallas/mosaic/lowering.py
@@ -20,6 +20,8 @@
from typing import Any, Callable
from collections.abc import Sequence

from jaxlib.mlir.ir import Module

import jax
from jax import core as jax_core
from jax import lax
@@ -341,7 +343,7 @@ def lower_jaxpr_to_module(
jaxpr: jax_core.Jaxpr,
dimension_semantics: tuple[str | None, ...] | None,
mesh: mesh_lib.Mesh | None = None
) -> ir.Module:
) -> tuple[Module, tuple[Any, ...]]:
mosaic_grid_mapping = MosaicGridMapping(
jaxpr, grid_mapping, dimension_semantics, mesh)
mosaic_grid_mapping.maybe_compress_grid()
@@ -2199,7 +2201,7 @@ def _device_id_to_logical(
device_ids = tree_util.tree_leaves(device_id)
mesh_strides = ctx.lowering_context.mesh_context.mesh_strides
def _linearize_mesh_indices(*indices):
return sum([a * b for a, b in zip(indices, mesh_strides)])
return sum(a * b for a, b in zip(indices, mesh_strides))
lower_ctx = LoweringRuleContext(
lowering_context=ctx.lowering_context,
avals_in=[jax_core.ShapedArray((), jnp.int32)] * len(device_ids),