Remove the canonicalize_dtypes argument from mlir.ir_constant(s).
Instead, force the caller to explicitly canonicalize the argument if that's what they want.

The current behavior (canonicalize by default) is not the behavior we want to encourage: we want to canonicalize exactly where we need to and nowhere else.
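
For callers that do want the old behavior, the migration is mechanical: canonicalize the value explicitly before building the constant. A minimal sketch, using the same internal helper (`jax._src.interpreters.xla.canonicalize_dtype`) that the updated call sites below use:

    # Before: dtype canonicalization happened implicitly inside ir_constants.
    consts = map(mlir.ir_constants, jaxpr.consts)

    # After: canonicalize explicitly, only where it is actually wanted.
    consts = [mlir.ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]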

PiperOrigin-RevId: 557806903
hawkinsp authored and jax authors committed Aug 17, 2023
1 parent ab9555e commit 8894892
Showing 11 changed files with 62 additions and 66 deletions.
5 changes: 2 additions & 3 deletions jax/_src/array.py
@@ -677,9 +677,8 @@ def make_array_from_single_device_arrays(
 basearray.Array.register(ArrayImpl)
 
 
-def _array_mlir_constant_handler(val, canonicalize_types=True):
-  return mlir.ir_constants(val._value,
-                           canonicalize_types=canonicalize_types)
+def _array_mlir_constant_handler(val):
+  return mlir.ir_constants(val._value)
 mlir.register_constant_handler(ArrayImpl, _array_mlir_constant_handler)
 
 
57 changes: 26 additions & 31 deletions jax/_src/interpreters/mlir.py
@@ -202,7 +202,7 @@ def aval_to_ir_type(aval: core.AbstractValue) -> ir.Type:
 # Constants
 
 class ConstantHandler(Protocol):
-  def __call__(self, val: Any, canonicalize_types: bool) -> Sequence[ir.Value]:
+  def __call__(self, val: Any) -> Sequence[ir.Value]:
     """Builds an IR representation for a constant `val`.
 
     A JAX value is represented by zero or more IR values."""
@@ -215,8 +215,7 @@ def register_constant_handler(type_: type, handler_fun: ConstantHandler):
 def get_constant_handler(type_: type) -> ConstantHandler:
   return _constant_handlers[type_]
 
-def ir_constants(val: Any,
-                 canonicalize_types: bool = True) -> Sequence[ir.Value]:
+def ir_constants(val: Any) -> Sequence[ir.Value]:
   """Translate a Python `val` to an IR constant, canonicalizing its dtype.
 
   Args:
@@ -228,26 +227,23 @@ def ir_constants(val: Any,
   for t in type(val).__mro__:
     handler = _constant_handlers.get(t)
     if handler:
-      out = handler(val, canonicalize_types)
+      out = handler(val)
       assert all(isinstance(v, ir.Value) for v in out), (type(val), out)
       return out
   if hasattr(val, '__jax_array__'):
-    return ir_constants(val.__jax_array__(), canonicalize_types)
+    return ir_constants(val.__jax_array__())
   raise TypeError(f"No constant handler for type: {type(val)}")
 
-def ir_constant(val: Any, canonicalize_types: bool = True) -> ir.Value:
+def ir_constant(val: Any) -> ir.Value:
   """Convenience wrapper around ir_constants for singleton values."""
-  values = ir_constants(val, canonicalize_types=canonicalize_types)
+  values = ir_constants(val)
   if len(values) != 1:
     raise TypeError(f"ir_constant called on {val} which corresponds to "
                     f"multiple IR values {values}")
   return values[0]
 
 
-def _numpy_array_constant(x: np.ndarray, canonicalize_types
-                          ) -> Sequence[ir.Value]:
-  if canonicalize_types:
-    x = np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
+def _numpy_array_constant(x: np.ndarray) -> Sequence[ir.Value]:
   element_type = dtype_to_ir_type(x.dtype)
   shape = x.shape
   if x.dtype == np.bool_:
@@ -263,8 +259,7 @@ def _masked_array_constant_handler(*args, **kwargs):
 
 register_constant_handler(np.ma.MaskedArray, _masked_array_constant_handler)
 
-def _ndarray_constant_handler(val: np.ndarray, canonicalize_types
-                              ) -> Sequence[ir.Value]:
+def _ndarray_constant_handler(val: np.ndarray) -> Sequence[ir.Value]:
   """Constant handler for ndarray literals, handling zero-size strides.
 
   In most cases this function calls _numpy_array_constant(val) except it has
@@ -282,24 +277,20 @@ def _ndarray_constant_handler(val: np.ndarray, canonicalize_types
   staged into the XLA Computation.
   """
   if dtypes.result_type(val) == dtypes.float0:
-    return _numpy_array_constant(np.zeros(val.shape, dtype=np.bool_),
-                                 canonicalize_types=False)
+    return _numpy_array_constant(np.zeros(val.shape, dtype=np.bool_))
   elif np.any(np.equal(0, val.strides)) and val.size > 0:
     zero_stride_axes, = np.where(np.equal(0, val.strides))
     other_axes, = np.where(np.not_equal(0, val.strides))
     collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None)  # type: ignore
                               for ax in range(val.ndim))]  # type: ignore
-    if canonicalize_types:
-      collapsed_val = np.asarray(
-          collapsed_val, dtypes.canonicalize_dtype(collapsed_val.dtype))
     out = hlo.BroadcastInDimOp(
         ir.RankedTensorType.get(
             val.shape, dtype_to_ir_type(collapsed_val.dtype)),
-        _numpy_array_constant(collapsed_val, canonicalize_types=False)[0],
+        _numpy_array_constant(collapsed_val)[0],
         dense_int_elements(other_axes)).result
     return (out,)
   else:
-    return _numpy_array_constant(val, canonicalize_types)
+    return _numpy_array_constant(val)
 
 register_constant_handler(np.ndarray, _ndarray_constant_handler)
 
@@ -310,13 +301,13 @@ def _ndarray_constant_handler(val: np.ndarray, canonicalize_types
                      np.bool_, np.longlong, dtypes.bfloat16]:
   register_constant_handler(_scalar_type, _ndarray_constant_handler)  # type: ignore
 
-def _python_scalar_handler(dtype, val, canonicalize_dtypes):
-  return _numpy_array_constant(np.array(val, dtype), canonicalize_dtypes)
+def _python_scalar_handler(dtype, val):
+  return _numpy_array_constant(np.array(val, dtype))
 
 for ptype, dtype in dtypes.python_scalar_dtypes.items():
   register_constant_handler(ptype, partial(_python_scalar_handler, dtype))
 
-def _token_constant_handler(val, canonicalize_types):
+def _token_constant_handler(val):
   return [hlo.CreateTokenOp().result]
 register_constant_handler(core.Token, _token_constant_handler)
@@ -1110,9 +1101,10 @@ def aval_to_types(aval):
     else:
       args.append(arg)
   callee_name_stack = ctx.name_stack.extend(util.wrap_name(name, api_name))
-  out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
-      jaxpr.jaxpr, tokens_in, map(ir_constants, jaxpr.consts),
-      *args, dim_var_values=dim_var_values)
+  consts = [ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
+  out_vals, tokens_out = jaxpr_subcomp(
+      ctx.replace(name_stack=callee_name_stack), jaxpr.jaxpr, tokens_in,
+      consts, *args, dim_var_values=dim_var_values)
   outs = []
   if create_tokens:
     for _ in range(num_output_tokens):
@@ -1229,7 +1221,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
   assert ctx.platform != "gpu"
   def read(v: core.Atom) -> Sequence[ir.Value]:
     if type(v) is core.Literal:
-      return ir_constants(v.val, canonicalize_types=True)
+      return ir_constants(xla.canonicalize_dtype(v.val))
     else:
       assert isinstance(v, core.Var)
       return env[v]
@@ -1329,12 +1321,16 @@ def get_lowering(primitive: core.Primitive) -> LoweringRule | None:
     core.clean_up_dead_vars(eqn, env, last_used)
   return map(read, jaxpr.outvars), tokens
 
 
 def _ir_consts(consts):
   unique_consts = {id(const): const for const in consts}
   ir_consts = {
-      id_: ir_constants(const) for id_, const in unique_consts.items()}
+      id_: ir_constants(xla.canonicalize_dtype(const))
+      for id_, const in unique_consts.items()
+  }
   return [ir_consts[id(const)] for const in consts]
 
 
 def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
   """Converts a traceable JAX function `fun` into a lowering rule.
@@ -1606,7 +1602,7 @@ def iota(ctx: LoweringRuleContext, aval_out, *, dimension: int):
 
 def full_like_aval(ctx: LoweringRuleContext, value, aval: core.ShapedArray) -> ir.Value:
   """Returns an IR constant shaped full of `value` shaped like `aval`."""
-  zero = ir_constant(np.array(value, aval.dtype))
+  zero = ir_constant(np.array(value, dtypes.canonicalize_dtype(aval.dtype)))
   return broadcast_in_dim(ctx, zero, aval, broadcast_dimensions=())
 
 def zeros_like_lowering(ctx, x):
@@ -2083,8 +2079,7 @@ def _wrapped_callback(token, *args):  # type: ignore  # pylint: disable=function
       backend.get_emit_python_callback_descriptor(_wrapped_callback,
                                                   operand_shapes,
                                                   result_shapes))
-  descriptor_operand = ir_constant(
-      callback_descriptor, canonicalize_types=False)
+  descriptor_operand = ir_constant(callback_descriptor)
   callback_operands = [descriptor_operand, *operands]
   if operand_mlir_layouts is not None:
     operand_mlir_layouts = [_layout_to_mlir_layout([]), *operand_mlir_layouts]
8 changes: 2 additions & 6 deletions jax/_src/lax/ann.py
@@ -325,12 +325,8 @@ def _approx_top_k_lowering(ctx, operand, *, k,
                              dimension=reduction_dimension)
 
   init_arg = hlo.ConstantOp(ir.DenseElementsAttr.get(np.int32(-1))).result
-  # Can't write bf16 literals, so we write a f64 literal and convert it.
-  init_val_literal = _get_init_val_literal(np.float64, is_max_k)
-  init_val_array = np.array(init_val_literal, dtype=np.float64).reshape(())
-  init_val = mlir.ir_constant(init_val_array)
-  init_val = hlo.ConvertOp(ir.RankedTensorType.get([],
-      mlir.dtype_to_ir_type(ctx.avals_in[0].dtype)), init_val).result
+  init_val_array = _get_init_val_literal(ctx.avals_in[0].dtype, is_max_k)
+  init_val = mlir.ir_constant(init_val_array.reshape(()))
 
   backend_config = {
     "top_k" : mlir.i64_attr(k),
4 changes: 2 additions & 2 deletions jax/_src/lax/control_flow/conditionals.py
@@ -850,10 +850,10 @@ def _cond_lowering(ctx, index, *args, branches, linear):
     with ir.InsertionPoint(branch):
       sub_ctx = ctx.module_context.replace(
           name_stack=name_stack.extend(f'branch_{i}_fun'))
+      consts = [mlir.ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
       out_vals, tokens_out = mlir.jaxpr_subcomp(
           sub_ctx, jaxpr.jaxpr, tokens_in,
-          map(mlir.ir_constants, jaxpr.consts),
-          *map(mlir.wrap_singleton_ir_values, args),
+          consts, *map(mlir.wrap_singleton_ir_values, args),
           dim_var_values=ctx.dim_var_values)
       out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
       out_vals = [*out_tokens, *out_vals]
24 changes: 17 additions & 7 deletions jax/_src/lax/control_flow/loops.py
@@ -1611,9 +1611,17 @@ def fun(*args):
     cond_args = cond_args[num_tokens:]
     x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts])
     cond_ctx = ctx.module_context.replace(name_stack=name_stack.extend('cond'))
-    ((pred,),), _ = mlir.jaxpr_subcomp(cond_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(),
-        _map(mlir.ir_constants, cond_jaxpr.consts),
-        *(x + z), dim_var_values=ctx.dim_var_values)
+    cond_consts = [
+        mlir.ir_constants(xla.canonicalize_dtype(x)) for x in cond_jaxpr.consts
+    ]
+    ((pred,),), _ = mlir.jaxpr_subcomp(
+        cond_ctx,
+        cond_jaxpr.jaxpr,
+        mlir.TokenSet(),
+        cond_consts,
+        *(x + z),
+        dim_var_values=ctx.dim_var_values,
+    )
     if batched:
       pred_ctx = mlir.LoweringRuleContext(
           module_context=ctx.module_context,
@@ -1642,17 +1650,19 @@ def fun(*args):
     tokens_in = mlir.TokenSet(zip(body_effects, token_args))
     x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts])
     body_ctx = ctx.module_context.replace(name_stack=name_stack.extend('body'))
+    body_consts = [mlir.ir_constants(xla.canonicalize_dtype(x))
+                   for x in body_jaxpr.consts]
     new_z, tokens_out = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr,
-        tokens_in, _map(mlir.ir_constants, body_jaxpr.consts),
-        *(y + z), dim_var_values=ctx.dim_var_values)
+        tokens_in, body_consts, *(y + z), dim_var_values=ctx.dim_var_values)
     out_tokens = [tokens_out.get(eff) for eff in body_effects]
     if batched:
       body_pred_ctx = ctx.module_context.replace(
          name_stack=name_stack.extend('body_pred'))
+      cond_consts = [mlir.ir_constants(xla.canonicalize_dtype(x))
+                     for x in cond_jaxpr.consts]
       ((body_pred,),), _ = mlir.jaxpr_subcomp(
           body_pred_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(),
-          _map(mlir.ir_constants, cond_jaxpr.consts),
-          *(x + z), dim_var_values=ctx.dim_var_values)
+          cond_consts, *(x + z), dim_var_values=ctx.dim_var_values)
       new_z = _map(
           partial(_pred_bcast_select_hlo, ctx, pred_aval, body_pred), new_z, z,
           body_jaxpr.out_avals)
5 changes: 2 additions & 3 deletions jax/_src/lax/fft.py
@@ -146,8 +146,7 @@ def _fft_lowering_cpu(ctx, x, *, fft_type, fft_lengths):
   assert np.issubdtype(dtype, np.complexfloating), dtype
   out_dtype = dtype
 
-  zero = mlir.ir_constant(np.array(0, dtype=out_dtype),
-                          canonicalize_types=False)
+  zero = mlir.ir_constant(np.array(0, dtype=out_dtype))
   return [
       mlir.broadcast_in_dim(ctx, zero, out_aval, broadcast_dimensions=[])]
 
@@ -172,7 +171,7 @@ def _fft_lowering_cpu(ctx, x, *, fft_type, fft_lengths):
   size_fft_length_prod = np.prod(fft_lengths) if fft_lengths else 1
   size_fft_lengths, = mlir.eval_dynamic_shape_as_vals(ctx, (size_fft_length_prod,))
   size_fft_lengths = hlo.ConvertOp(double_type, size_fft_lengths)
-  one = mlir.ir_constant(np.float64(1.), canonicalize_types=False)
+  one = mlir.ir_constant(np.float64(1.))
   scale = one if forward else hlo.DivOp(one, size_fft_lengths)
   scale = hlo.ReshapeOp(
       mlir.ir.RankedTensorType.get((1,), mlir.ir.F64Type.get()),
3 changes: 1 addition & 2 deletions jax/_src/lax/lax.py
@@ -4396,8 +4396,7 @@ def _rng_uniform_abstract_eval(a, b, *, shape):
 
 def _rng_uniform_lowering(ctx, a, b, *, shape):
   aval_out, = ctx.avals_out
-  shape, = mlir.ir_constants(np.array(aval_out.shape, np.int64),
-                             canonicalize_types=False)
+  shape, = mlir.ir_constants(np.array(aval_out.shape, np.int64))
   return hlo.RngOp(a, b, shape,
                    hlo.RngDistributionAttr.get('UNIFORM')).results
 
3 changes: 1 addition & 2 deletions jax/_src/lax/windowed_reductions.py
@@ -668,8 +668,7 @@ def _select_and_gather_add_lowering(
   assert nbits <= max_bits
   double_word_reduction = nbits * 2 <= max_bits
 
-  const = lambda dtype, x: mlir.ir_constant(np.array(x, dtype=dtype),
-                                            canonicalize_types=False)
+  const = lambda dtype, x: mlir.ir_constant(np.array(x, dtype=dtype))
 
   def _broadcast_scalar_const(x, aval_out):
     return mlir.broadcast_in_dim(ctx, const(aval_out.dtype, x),
7 changes: 4 additions & 3 deletions jax/_src/maps.py
@@ -45,6 +45,7 @@
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
+from jax._src.interpreters import xla
 from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters.partial_eval import (
     trace_to_subjaxpr_dynamic, DynamicJaxprTracer,
@@ -1335,7 +1336,7 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
   # them!
   vectorized_jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(f, local_avals)
   _check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
-  const_nodes = map(mlir.ir_constants, consts)
+  const_nodes = [mlir.ir_constants(xla.canonicalize_dtype(x)) for x in consts]
 
   local_mesh_shape = mesh.local_mesh.shape
   tiled_ins = (
@@ -1418,7 +1419,7 @@ def add_spmd_axes(
       if aval_axes else [node]
       for node, aval, aval_axes in zip(global_in_nodes, global_in_avals, mesh_in_axes)
   ]
-  const_nodes = map(mlir.ir_constants, consts)
+  const_nodes = [mlir.ir_constants(xla.canonicalize_dtype(x)) for x in consts]
 
   # We in-line here rather than generating a Call HLO as in the xla_call
   # translation rule just because the extra tuple stuff is a pain.
@@ -1469,7 +1470,7 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
   # them!
   global_in_avals = ctx.avals_in
   vectorized_jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(f, global_in_avals)
-  const_nodes = map(mlir.ir_constants, consts)
+  const_nodes = [mlir.ir_constants(xla.canonicalize_dtype(x)) for x in consts]
 
   # We in-line here rather than generating a Call HLO as in the xla_call
   # translation rule just because the extra tuple stuff is a pain.
10 changes: 4 additions & 6 deletions jax/_src/prng.py
@@ -630,9 +630,9 @@ def key_array_shard_arg_handler(x: PRNGKeyArrayImpl, devices, indices, sharding)
 pxla.shard_arg_handlers[PRNGKeyArrayImpl] = key_array_shard_arg_handler
 
 
-def key_array_constant_handler(x, canonicalize_dtypes):
+def key_array_constant_handler(x):
   arr = x.unsafe_raw_array()
-  return mlir.get_constant_handler(type(arr))(arr, canonicalize_dtypes)
+  return mlir.get_constant_handler(type(arr))(arr)
 mlir.register_constant_handler(PRNGKeyArrayImpl, key_array_constant_handler)
 
 
@@ -1178,8 +1178,7 @@ def _add(x: ir.Value, y: ir.Value) -> ir.Value:
 
 def _mul(x: core.DimSize, y: ir.Value) -> ir.Value:
   if core.is_constant_dim(x):
-    x_const = mlir.ir_constant(np.array(x, np.dtype('uint64')),
-                               canonicalize_types=False)
+    x_const = mlir.ir_constant(np.array(x, np.dtype('uint64')))
   else:
     x_const, = mlir.eval_dynamic_shape(ctx, (x,))
     x_const = hlo.ConvertOp(
@@ -1195,8 +1194,7 @@ def _mul(x: core.DimSize, y: ir.Value) -> ir.Value:
   iotas = [mlir.iota(ctx, aval_u64, dimension=dimension)
            for dimension in range(len(shape))]
   counts = bcast_iotas_to_reshaped_iota(_add, _mul, shape, iotas)
-  shift = mlir.ir_constant(np.array(32, np.dtype('uint64')),
-                           canonicalize_types=False)
+  shift = mlir.ir_constant(np.array(32, np.dtype('uint64')))
   shift = mlir.broadcast_in_dim(ctx, shift, aval_u64,
                                 broadcast_dimensions=[])
   counts_shifted = mlir.hlo.ShiftRightLogicalOp(counts, shift).result
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/call_tf.py
@@ -494,7 +494,7 @@ def _call_tf_lowering(
     captured_inputs.append(inp)
 
   captured_ops = tuple(
-      mlir.ir_constant(np.asarray(inp), canonicalize_types=False)
+      mlir.ir_constant(np.asarray(inp))
       for inp in captured_inputs
   )
 
