diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 63dc8313fe1e..5e5d2fc1031b 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -192,7 +192,10 @@ def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[De def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars, inline, keep_unused: bool): del inline # Only used at tracing time - arg_specs = unsafe_map(arg_spec, args) + if fun.in_type is None: + arg_specs = unsafe_map(arg_spec, args) + else: + arg_specs = [(None, getattr(x, '_device', None)) for x in args] compiled_fun = xla_callable(fun, device, backend, name, donated_invars, keep_unused, *arg_specs) try: @@ -283,25 +286,8 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name, in_type = tuple(unsafe_zip(abstract_args, itertools.repeat(True))) fun = lu.annotate(fun, in_type) else: - # Check that the provided abstract_args are consistent with in_type by first - # collecting values of axis size arguments, then substituting them in for - # DBIdx occurrences. - axis_sizes: Dict[core.DBIdx, int] = {} - abstract_args_iter = iter(abstract_args) - for expected_type, explicit in fun.in_type: - if explicit: - aval = next(abstract_args_iter) - if isinstance(expected_type, core.DShapedArray): - # Check the value for any DBIdx variables is consistent. - assert all(axis_sizes.setdefault(d1, d2) == d2 - for d1, d2 in zip(expected_type.shape, aval.shape) - if type(d1) is core.DBIdx) - # Check the type matches after substitution. - expected_shape = [axis_sizes.get(d, d) for d in expected_type.shape] # type: ignore - expected_aval = core.ShapedArray( - shape=tuple(expected_shape), dtype=expected_type.dtype, - weak_type=expected_type.weak_type) - assert core.typematch(expected_aval, aval) + assert abstract_args == (None,) * len(abstract_args) + abstract_args = [aval for aval, _ in fun.in_type] with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} " "for jit in {elapsed_time} sec"): jaxpr, out_type, consts = pe.trace_to_jaxpr_final2( @@ -326,7 +312,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name, if i in kept_var_idx] del kept_const_idx else: - kept_var_idx = set(range(len(abstract_args))) + kept_var_idx = set(range(len(fun.in_type))) nreps = jaxpr_replicas(jaxpr) device = _xla_callable_device(nreps, backend, device, arg_devices) @@ -430,12 +416,10 @@ def jaxpr_has_pmap(jaxpr): return False def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool: - return (any(type(d) is core.Var for v in jaxpr.invars - if type(v.aval) is core.DShapedArray for d in v.aval.shape) or - any(type(d) is core.Var + return (any(type(v.aval) is core.AbstractBInt for v in jaxpr.invars) or + any(type(v.aval) is core.AbstractBInt for j in itertools.chain([jaxpr], core.subjaxprs(jaxpr)) - for e in j.eqns for v in itertools.chain(e.invars, e.outvars) - if type(v.aval) is core.DShapedArray for d in v.aval.shape)) + for e in j.eqns for v in e.outvars)) def _prune_unused_inputs( jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]: @@ -545,7 +529,7 @@ def _input_handler(backend: Backend, in_avals, which_explicit = util.unzip2(in_type) # Check whether we actually need an input_handler. needs_implicit = which_explicit and not all(which_explicit) - needs_out_handling = any(type(d) is core.InDBIdx for a in out_type or [] + needs_out_handling = any(type(d) is core.InDBIdx for a, _ in out_type or [] if type(a) is core.DShapedArray for d in a.shape) if not needs_implicit and not needs_out_handling: @@ -565,7 +549,7 @@ def _input_handler(backend: Backend, # Precompute which input values are needed for output types. inputs_needed_for_out_types = out_type and [ - d.val for aval in out_type if type(aval) is core.DShapedArray # type: ignore + d.val for aval, _ in out_type if type(aval) is core.DShapedArray # type: ignore for d in aval.shape if type(d) is core.InDBIdx] def elaborate(explicit_args: Sequence[Any]) -> Tuple[Tuple, Optional[Tuple]]: diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index a0d78f2dbe0d..a9410eba2b5f 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -18,8 +18,8 @@ from functools import partial import itertools import operator -from typing import (Any, Callable, Dict, Optional, Sequence, Tuple, - List, TypeVar, Union, cast as type_cast) +from typing import (Any, Callable, Optional, Sequence, Tuple, List, Dict, + TypeVar, Union, cast as type_cast) import warnings import numpy as np @@ -106,13 +106,20 @@ def _try_broadcast_shapes( rank, *others = {len(shape) for shape in shapes} if others: return None # must have consistent rank if not rank: return () # scalar case - result_shape = [-1] * rank - for i, sizes in enumerate(zip(*shapes)): - non_1s = {d for d in sizes if not core.symbolic_equal_dim(d, 1)} - if len(non_1s) > 1: - return None # must have equal sizes other than 1-sized axes - result_shape[i] = next(iter(non_1s), 1) - + result_shape = [] + for ds in unsafe_zip(*shapes): + if all(core.same_referent(d, ds[0]) for d in ds[1:]): + # if all axes are identical objects, the resulting size is the object + result_shape.append(ds[0]) + else: + # if all dims are equal (or 1), the result is the non-1 size (or 1) + non_1s = [d for d in ds if not core.symbolic_equal_dim(d, 1)] + if not non_1s: + result_shape.append(1) + elif all(core.symbolic_equal_dim(non_1s[0], d) for d in non_1s[1:]): + result_shape.append(non_1s[0]) + else: + return None return tuple(result_shape) def broadcast_shapes(*shapes: Tuple[Union[int, core.Tracer], ...] @@ -156,60 +163,39 @@ def _broadcast_ranks(s1, s2): def _identity(x): return x -def _extract_tracers_dyn_shape(shape: Sequence[Union[int, core.Tracer]] - ) -> Tuple[Sequence[core.Tracer], - Sequence[Optional[int]]]: - """Returns the list of tracers in `shape`, and a static version of `shape` - with tracers replaced with None""" +def _extract_tracers_dyn_shape( + shape: Sequence[Union[int, core.Tracer]] + ) -> Tuple[List[core.Tracer], List[Optional[int]]]: + # Given a sequence representing a shape, pull out Tracers, replacing with None if config.jax_dynamic_shapes: # We must gate this behavior under a flag because otherwise the errors # raised are different (and have worse source provenance information). - dyn_shape = tuple(d for d in shape if isinstance(d, core.Tracer)) - static_shape = tuple(d if not isinstance(d, core.Tracer) else None for d in shape) + dyn_shape = [d for d in shape if isinstance(d, core.Tracer)] + static_shape = [None if isinstance(d, core.Tracer) else d for d in shape] return dyn_shape, static_shape else: - return (), shape # type: ignore[return-value] + return [], list(shape) # type: ignore - -def _merge_dyn_shape(static_shape: Sequence[Optional[int]], - dyn_shape: Sequence[mlir.Value], - ) -> Sequence[mlir.Value]: - """Returns static_shape with None values filled in from dyn_shape.""" +def _merge_dyn_shape( + static_shape: Sequence[Optional[int]], + dyn_shape: Sequence[Any], + ) -> Tuple[Union[int, mlir.Value], ...]: + # Replace Nones in static_shape with elements of dyn_shape, in order dyn_shape_it = iter(dyn_shape) shape = tuple(next(dyn_shape_it) if d is None else d for d in static_shape) assert next(dyn_shape_it, None) is None return shape -def _stage_with_dyn_shape(trace: core.Trace, - prim: core.Primitive, - args: Sequence[core.Tracer], - dyn_shape_args: Sequence[core.Tracer], - params: Dict[str, Any], - static_shape: Sequence[Optional[int]], - out_dtype: Any, - out_weak_type: bool, - ) -> core.Tracer: - """Stages out a primitive that takes dynamic shapes. - - dyn_shape_args are the tracers corresponding to the None values in static_shape. - """ - if not dyn_shape_args: - return trace.default_process_primitive(prim, args, params) # type: ignore - assert len(dyn_shape_args) == sum(d is None for d in static_shape) +def _dyn_shape_staging_rule(trace, prim, out_aval, *args, **params): source_info = source_info_util.current() - - ds = iter(dyn_shape_args) - out_shape_for_tracer: List[Union[int, core.Tracer]] = [ - next(ds) if d is None else d for d in static_shape] - aval = core.DShapedArray(tuple(out_shape_for_tracer), out_dtype, out_weak_type) - out_tracer = pe.DynamicJaxprTracer(trace, aval, source_info) - invars = [*(trace.getvar(x) for x in args), *(trace.getvar(d) for d in dyn_shape_args)] # type: ignore - eqn = pe.new_jaxpr_eqn(invars, [trace.makevar(out_tracer)], # type: ignore + out_tracer = pe.DynamicJaxprTracer(trace, out_aval, source_info) + eqn = pe.new_jaxpr_eqn([trace.getvar(x) for x in args], + [trace.makevar(out_tracer)], prim, params, core.no_effects, source_info) - trace.frame.eqns.append(eqn) # type: ignore - + trace.frame.add_eqn(eqn) return out_tracer + ### traceables def neg(x: Array) -> Array: @@ -794,7 +780,12 @@ def broadcast_in_dim(operand: Array, shape: Shape, if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and isinstance(operand, (device_array.DeviceArray, core.Tracer))): return operand - dyn_shape, static_shape = _extract_tracers_dyn_shape(shape) + if config.jax_dynamic_shapes: + # We must gate this behavior under a flag because otherwise the errors + # raised are different (and have worse source provenance information). + dyn_shape, static_shape = _extract_tracers_dyn_shape(shape) + else: + dyn_shape, static_shape = [], shape # type: ignore return broadcast_in_dim_p.bind( operand, *dyn_shape, shape=tuple(static_shape), broadcast_dimensions=tuple(broadcast_dimensions)) @@ -858,7 +849,7 @@ def reshape(operand: Array, new_sizes: Shape, dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes) return reshape_p.bind( - operand, *dyn_shape, new_sizes=static_new_sizes, + operand, *dyn_shape, new_sizes=tuple(static_new_sizes), dimensions=None if dims is None or same_dims else dims) def pad(operand: Array, padding_value: Array, @@ -1165,18 +1156,18 @@ def iota(dtype: DType, size: int) -> Array: `_ operator. """ - dtype = dtypes.canonicalize_dtype(dtype) - size, = canonicalize_shape((size,)) - dyn_shape, static_shape = _extract_tracers_dyn_shape((size,)) - return iota_p.bind(*dyn_shape, dtype=dtype, shape=static_shape, dimension=0) + return broadcasted_iota(dtype, (size,), 0) def broadcasted_iota(dtype: DType, shape: Shape, dimension: int) -> Array: """Convenience wrapper around ``iota``.""" dtype = dtypes.canonicalize_dtype(dtype) shape = canonicalize_shape(shape) + dynamic_shape = [d for d in shape if isinstance(d, core.Tracer)] + static_shape = [None if isinstance(d, core.Tracer) else d for d in shape] dimension = core.concrete_or_error( int, dimension, "dimension argument of lax.broadcasted_iota") - return iota_p.bind(dtype=dtype, shape=shape, dimension=dimension) + return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape), + dimension=dimension) def _eye(dtype: DType, shape: Shape, offset: int) -> Array: """Like numpy.eye, create a 2D array with ones on a diagonal.""" @@ -1484,6 +1475,7 @@ def _broadcasting_shape_rule(name, *avals): if len({len(shape) for shape in shapes}) != 1: msg = '{}: arrays must have same number of dimensions, got {}.' raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes))))) + # TODO(mattjj): de-duplicate with _try_broadcast_shapes result_shape = [] for ds in zip(*shapes): if all(core.same_referent(d, ds[0]) for d in ds[1:]): @@ -1492,14 +1484,13 @@ def _broadcasting_shape_rule(name, *avals): else: # if all dims are equal (or 1), the result is the non-1 size non_1s = [d for d in ds if not core.symbolic_equal_dim(d, 1)] - if non_1s: - first_non_1 = non_1s.pop() - if tuple(filter(lambda d: not core.symbolic_equal_dim(d, first_non_1), non_1s)): - raise TypeError(f'{name} got incompatible shapes for broadcasting: ' - f'{", ".join(map(str, map(tuple, shapes)))}.') - result_shape.append(first_non_1) - else: + if not non_1s: result_shape.append(1) + elif all(core.symbolic_equal_dim(non_1s[0], d) for d in non_1s[1:]): + result_shape.append(non_1s[0]) + else: + raise TypeError(f'{name} got incompatible shapes for broadcasting: ' + f'{", ".join(map(str, map(tuple, shapes)))}.') return tuple(result_shape) @@ -1582,9 +1573,14 @@ def broadcast_mhlo( assert len(aval.shape) <= len(aval_out.shape), (aval, aval_out) dims = mlir.dense_int_elements( range(len(aval_out.shape) - len(aval.shape), len(aval_out.shape))) - arg = mhlo.BroadcastInDimOp( - mlir.aval_to_ir_type(aval.update(shape=aval_out.shape)), arg, - dims).result + if any(isinstance(d, ir.Value) for d in aval_out.shape): + arg = mhlo.DynamicBroadcastInDimOp( + mlir.aval_to_ir_type(aval_out), arg, + mlir.shape_tensor(aval_out.shape), dims).result + else: + arg = mhlo.BroadcastInDimOp( + mlir.aval_to_ir_type(aval.update(shape=aval_out.shape)), arg, + dims).result out.append(arg) return out @@ -1598,13 +1594,23 @@ def _nary_lower_mhlo(op: Callable, ctx, provided? """ del params - aval_out, = ctx.avals_out - broadcasted_args = broadcast_mhlo(aval_out, ctx.avals_in, args) + avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out + if config.jax_dynamic_shapes: + substitute = partial(_substitute_axis_sizes_in_aval, ctx.axis_size_env) + avals_in = map(substitute, avals_in) + aval_out = substitute(aval_out) + broadcasted_args = broadcast_mhlo(aval_out, avals_in, args) if explicit_type: return op(mlir.aval_to_ir_type(aval_out), *broadcasted_args).results else: return op(*broadcasted_args).results +def _substitute_axis_sizes_in_aval( + env: Dict[core.Var, ir.Value], a: core.AbstractValue) -> core.AbstractValue: + if isinstance(a, core.DShapedArray): + return a.update(shape=tuple(env.get(d, d) for d in a.shape)) # type: ignore + return a + _float = {np.floating} _complex = {np.complexfloating} @@ -2727,11 +2733,13 @@ def _broadcast_in_dim_fwd_rule(eqn): return [None], eqn def _broadcast_in_dim_staging_rule( - trace, x, *dyn_shape, shape, broadcast_dimensions): + trace, x, *dyn, shape, broadcast_dimensions): params = dict(shape=shape, broadcast_dimensions=broadcast_dimensions) - return _stage_with_dyn_shape(trace, broadcast_in_dim_p, - (x,), dyn_shape, params, - shape, x.dtype, x.weak_type) + if not dyn: + return trace.default_process_primitive(broadcast_in_dim_p, (x,), params) + aval = core.DShapedArray(_merge_dyn_shape(shape, dyn), x.dtype, x.weak_type) + return _dyn_shape_staging_rule(trace, broadcast_in_dim_p, aval, x, *dyn, + **params) def _broadcast_in_dim_padding_rule(in_avals, out_avals, x, *dyn_shape, shape, broadcast_dimensions): @@ -2791,7 +2799,6 @@ def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions): mlir.aval_to_ir_type(aval_out), x, mlir.shape_tensor(shape), mlir.dense_int_elements(broadcast_dimensions), - None, None, ).results else: return mhlo.BroadcastInDimOp( @@ -2824,6 +2831,7 @@ def _broadcast_in_dim_pp_rule(eqn, context, settings): pe.padding_rules[broadcast_in_dim_p] = _broadcast_in_dim_padding_rule core.custom_typechecks[broadcast_in_dim_p] = _broadcast_in_dim_typecheck_rule mlir.register_lowering(broadcast_in_dim_p, _broadcast_in_dim_lower) +# TODO(mattjj): un-comment the next line # core.pp_eqn_rules[broadcast_in_dim_p] = _broadcast_in_dim_pp_rule @@ -3154,7 +3162,8 @@ def _reshape_shape_rule(operand, *, new_sizes, dimensions): msg = 'reshape new_sizes must all be positive, got {}.' raise TypeError(msg.format(new_sizes)) # TODO(necula): re-enable this check - if not config.jax_dynamic_shapes and not core.same_shape_sizes(np.shape(operand), new_sizes): + if (not config.jax_dynamic_shapes and + not core.same_shape_sizes(np.shape(operand), new_sizes)): msg = 'reshape total size must be unchanged, got new_sizes {} for shape {}.' raise TypeError(msg.format(new_sizes, np.shape(operand))) if dimensions is not None: @@ -3201,12 +3210,6 @@ def merge_const_sizes(shape): new_sizes=masking.padded_shape_as_value(new_sizes), dimensions=dimensions) -reshape_p = standard_primitive(_reshape_shape_rule, _reshape_dtype_rule, - 'reshape') -ad.deflinear2(reshape_p, _reshape_transpose_rule) -batching.primitive_batchers[reshape_p] = _reshape_batch_rule -masking.masking_rules[reshape_p] = _reshape_masking_rule - def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions): aval_out, = ctx.avals_out if dimensions is not None: @@ -3220,17 +3223,23 @@ def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions): else: return mhlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), x).results -mlir.register_lowering(reshape_p, _reshape_lower) - def _reshape_staging_rule( - trace, x, *dyn_shape, new_sizes, dimensions): + trace, x, *dyn, new_sizes, dimensions): params = dict(new_sizes=new_sizes, dimensions=dimensions) - # TODO(necula): shouldn't this include the same checks as in reshape_shape_rule? - return _stage_with_dyn_shape(trace, reshape_p, (x,), dyn_shape, params, - new_sizes, x.dtype, x.weak_type) + if not dyn: + return trace.default_process_primitive(reshape_p, (x,), params) + av = core.DShapedArray(_merge_dyn_shape(new_sizes, dyn), x.dtype, x.weak_type) + return _dyn_shape_staging_rule(trace, reshape_p, av, x, *dyn, **params) +reshape_p = standard_primitive(_reshape_shape_rule, _reshape_dtype_rule, + 'reshape') +ad.deflinear2(reshape_p, _reshape_transpose_rule) +batching.primitive_batchers[reshape_p] = _reshape_batch_rule +masking.masking_rules[reshape_p] = _reshape_masking_rule +mlir.register_lowering(reshape_p, _reshape_lower) pe.custom_staging_rules[reshape_p] = _reshape_staging_rule + def _rev_shape_rule(operand, *, dimensions): _check_shapelike('rev', 'dimensions', dimensions) if len(set(dimensions)) != len(dimensions): @@ -4381,11 +4390,12 @@ def _iota_abstract_eval(*, dtype, shape, dimension): iota_p.def_impl(partial(xla.apply_primitive, iota_p)) iota_p.def_abstract_eval(_iota_abstract_eval) -def _iota_staging_rule( - trace, *dyn_shape, dtype, shape, dimension): +def _iota_staging_rule(trace, *dyn_shape, dtype, shape, dimension): params = dict(dtype=dtype, shape=shape, dimension=dimension) - return _stage_with_dyn_shape(trace, iota_p, (), dyn_shape, params, - shape, dtype, False) + if not dyn_shape: + return trace.default_process_primitive(iota_p, (), params) + aval = core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), dtype, False) + return _dyn_shape_staging_rule(trace, iota_p, aval, *dyn_shape, **params) pe.custom_staging_rules[iota_p] = _iota_staging_rule def _iota_typecheck_rule(*dyn_shape, dtype, shape, dimension): diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 61ea7096596f..33a5db692540 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -71,9 +71,10 @@ reciprocal, remainder, right_shift, rint, sign, signbit, sin, sinc, sinh, sqrt, square, subtract, tan, tanh, true_divide) from jax._src.numpy.util import ( # noqa: F401 - _arraylike, _broadcast_arrays, _broadcast_to, _check_arraylike, _complex_elem_type, _promote_args, - _promote_args_inexact, _promote_dtypes, _promote_dtypes_inexact, _promote_shapes, _register_stackable, - _stackable, _where, _wraps) + _arraylike, _broadcast_arrays, _broadcast_to, _check_arraylike, + _complex_elem_type, _promote_args, _promote_args_inexact, _promote_dtypes, + _promote_dtypes_inexact, _promote_shapes, _register_stackable, _stackable, + _where, _wraps) from jax._src.numpy.vectorize import vectorize from jax._src.ops import scatter from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio, @@ -2099,8 +2100,12 @@ def arange(start: core.DimSize, stop: Optional[core.DimSize]=None, dtype = result_type(start, *(x for x in [stop, step] if x is not None)) dtype = _jnp_dtype(dtype) if stop is None and step is None: - start = require(start, msg("stop")) - start = np.ceil(start).astype(int) + if (jax.config.jax_dynamic_shapes and + not isinstance(core.get_aval(start), core.ConcreteArray)): + start = ceil(start).astype(int) # note using jnp here + else: + start = require(start, msg("stop")) + start = np.ceil(start).astype(int) return lax.iota(dtype, start) else: start = require(start, msg("start")) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index be231b13d792..4b0223bcdd21 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -24,13 +24,16 @@ from jax._src import dtypes from jax._src.lax import lax as lax_internal from jax._src.numpy.ndarray import ndarray -from jax._src.util import safe_zip +from jax._src.util import safe_zip, safe_map from jax._src import api from jax import core from jax._src.lax import lax import numpy as np +zip, unsafe_zip = safe_zip, zip +map, unsafe_map = safe_map, map + _T = TypeVar("_T") _parameter_break = re.compile("\n(?=[A-Za-z_])") @@ -219,20 +222,21 @@ def _promote_shapes(fun_name, *args): return args else: shapes = [np.shape(arg) for arg in args] - if all(len(shapes[0]) == len(s) for s in shapes[1:]): - return args # no need for rank promotion, so rely on lax promotion - nonscalar_ranks = {len(shp) for shp in shapes if shp} - if len(nonscalar_ranks) < 2: - return args + if config.jax_dynamic_shapes: + # With dynamic shapes we don't support singleton-dimension broadcasting; + # we instead broadcast out to the full shape as a temporary workaround. + # TODO(mattjj): revise this workaround + res_shape = lax.broadcast_shapes(*shapes) # Can raise an error! + return [_broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)] else: - if config.jax_numpy_rank_promotion != "allow": - _rank_promotion_warning_or_error(fun_name, shapes) - if config.jax_dynamic_shapes: - # With dynamic shapes we don't support singleton-dimension broadcasting; - # we instead broadcast out to the full shape as a temporary workaround. - res_shape = lax.broadcast_shapes(*shapes) - return [_broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)] + if all(len(shapes[0]) == len(s) for s in shapes[1:]): + return args # no need for rank promotion, so rely on lax promotion + nonscalar_ranks = {len(shp) for shp in shapes if shp} + if len(nonscalar_ranks) < 2: + return args # rely on lax scalar promotion else: + if config.jax_numpy_rank_promotion != "allow": + _rank_promotion_warning_or_error(fun_name, shapes) result_rank = len(lax.broadcast_shapes(*shapes)) return [_broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp) for arg, shp in zip(args, shapes)] diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 32f0532c1caa..c484a91f3f36 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -315,34 +315,33 @@ def process_primitive(self, primitive, tracers, params): def process_call(self, call_primitive, f, tracers, params): assert call_primitive.multiple_results primals, tangents = unzip2((t.primal, t.tangent) for t in tracers) - nonzero_tangents, tangent_tree_def = tree_flatten(tangents) - nz_tangents = [type(t) is not Zero for t in tangents] + which_nz = [ type(t) is not Zero for t in tangents] + tangents = [t if type(t) is not Zero else None for t in tangents] + args, in_tree = tree_flatten((primals, tangents)) if 'name' in params and not config.jax_experimental_name_stack: params = dict(params, name=wrap_name(params['name'], 'jvp')) f_jvp = jvp_subtrace(f, self.main) - f_jvp, nz_tangents_out = nonzero_tangent_outputs(f_jvp) + f_jvp, which_nz_out = nonzero_tangent_outputs(f_jvp) if isinstance(call_primitive, core.MapPrimitive): in_axes = params['in_axes'] - tangent_in_axes = [ax for ax, nz in zip(in_axes, nz_tangents) if nz] + tangent_in_axes = [ax for ax, nz in zip(in_axes, which_nz) if nz] out_axes_thunk = params['out_axes_thunk'] - # The new thunk depends deterministically on the old thunk and the wrapped function. - # Any caching already has to include the wrapped function as part of the key, so we - # only use the previous thunk for equality checks. - # NOTE: This assumes that the output tangents being zero is a deterministic - # function of which input tangents were zero. - @as_hashable_function(closure=(tuple(nz_tangents), out_axes_thunk)) + # NOTE: This assumes that the output tangents being zero is a + # deterministic function of which input tangents were zero. + @as_hashable_function(closure=out_axes_thunk) def new_out_axes_thunk(): - out_axes = out_axes_thunk() - return (*out_axes, *(ax for ax, nz in zip(out_axes, nz_tangents_out()) if nz)) - params = dict(params, - in_axes=(*in_axes, *tangent_in_axes), + out_ax = out_axes_thunk() + return (*out_ax, *(ax for ax, nz in zip(out_ax, which_nz_out()) if nz)) + params = dict(params, in_axes=(*in_axes, *tangent_in_axes), out_axes_thunk=new_out_axes_thunk) - f_jvp, out_tree_def = traceable(f_jvp, len(primals), tangent_tree_def) + f_jvp, out_tree = traceable(f_jvp, in_tree) update_params = call_param_updaters.get(call_primitive) - new_params = update_params(params, nz_tangents) if update_params else params - f_jvp = _update_annotation(f_jvp, f.in_type, nz_tangents) - result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params) - primal_out, tangent_out = tree_unflatten(out_tree_def(), result) + new_params = update_params(params, which_nz) if update_params else params + result = call_primitive.bind(_update_annotation(f_jvp, f.in_type, which_nz), + *args, **new_params) + primal_out, tangent_out = tree_unflatten(out_tree(), result) + tangent_out = [Zero(get_aval(p).at_least_vspace()) if t is None else t + for p, t in zip(primal_out, tangent_out)] return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)] def post_process_call(self, call_primitive, out_tracers, params): @@ -588,13 +587,14 @@ def instantiate_zeros_aval(aval, tangent): return tangent @lu.transformation_with_aux -def traceable(num_primals, in_tree_def, *primals_and_tangents): - new_primals = primals_and_tangents[:num_primals] - new_tangents = primals_and_tangents[num_primals:] - new_tangents = tree_unflatten(in_tree_def, new_tangents) - primal_out, tangent_out = yield (new_primals, new_tangents), {} - out_flat, tree_def = tree_flatten((primal_out, tangent_out)) - yield out_flat, tree_def +def traceable(in_tree, *primals_and_tangents): + primals, tangents = tree_unflatten(in_tree, primals_and_tangents) + tangents = [Zero(get_aval(p).at_least_vspace()) if t is None else t + for p, t in zip(primals, tangents)] + primals_out, tangents_out = yield (primals, tangents), {} + tangents_out = [None if type(t) is Zero else t for t in tangents_out] + out_flat, out_tree = tree_flatten((primals_out, tangents_out)) + yield out_flat, out_tree def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes): diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index 58caca354013..3e57b82eb5cc 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -418,6 +418,7 @@ class LoweringRuleContext: avals_out: Any # Usually Sequence[core.AbstractValue], but sometimes None. tokens_in: TokenSet tokens_out: Optional[TokenSet] # Mutable store for output containers + axis_size_env: Optional[Dict[core.Var, ir.Value]] = None # Dynamic axis sizes def set_tokens_out(self, tokens_out: TokenSet): assert self.tokens_out is None, 'Should only set `tokens_out` once.' @@ -928,13 +929,19 @@ def write(v: core.Var, node: Sequence[ir.Value]): f"found for platform {ctx.platform}") eqn_ctx = (ctx.replace(name_stack=source_info.name_stack) if - config.jax_experimental_name_stack else ctx) + config.jax_experimental_name_stack else ctx) effects = [eff for eff in eqn.effects if eff in core.ordered_effects] tokens_in = tokens.subset(effects) + avals_in = map(aval, eqn.invars) rule_ctx = LoweringRuleContext( - module_context=eqn_ctx, primitive=eqn.primitive, - avals_in=map(aval, eqn.invars), avals_out=map(aval, eqn.outvars), - tokens_in=tokens_in, tokens_out=None) + module_context=eqn_ctx, primitive=eqn.primitive, avals_in=avals_in, + avals_out=map(aval, eqn.outvars), tokens_in=tokens_in, + tokens_out=None) + if config.jax_dynamic_shapes: + axis_size_env = {d: read(d)[0] for a in avals_in + if type(a) is core.DShapedArray for d in a.shape + if type(d) is core.Var} + rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env) ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes), **eqn.params) if effects: @@ -976,16 +983,33 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable: The returned function does not use `avals_out`, so callers may pass any value as `avals_out`.""" def f_lowered(ctx, *args, **params): - if multiple_results: - f = fun - else: - f = lambda *args, **kw: (fun(*args, **kw),) + f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),) wrapped_fun = lu.wrap_init(f, params) axis_env = ctx.module_context.axis_env - with core.extend_axis_env_nd(zip(axis_env.names, axis_env.sizes)): - jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) - out, tokens = jaxpr_subcomp(ctx.module_context, jaxpr, ctx.tokens_in, _ir_consts(consts), - *map(wrap_singleton_ir_values, args)) + + if config.jax_dynamic_shapes: + # We might be applying this function to arguments with dynamic shapes, + # i.e. there might be Vars in the shape tuples of ctx.avals_in. In that + # case, we need to form a jaxpr with leading binders for those axis size + # arguments (by computing an InputType and using trace_to_jaxpr_dynamic2), + # and we need to call jaxpr_subcomp with these arguments made explicit. + args = (*ctx.axis_size_env.values(), *args) + idx = {d: core.DBIdx(i) for i, d in enumerate(ctx.axis_size_env)} + i32_aval = core.ShapedArray((), np.dtype('int32')) + implicit_args = [(i32_aval, False)] * len(ctx.axis_size_env) + explicit_args = [(a.update(shape=tuple(idx.get(d, d) for d in a.shape)) + if type(a) is core.DShapedArray else a, True) + for a in ctx.avals_in] + wrapped_fun = lu.annotate(wrapped_fun, (*implicit_args, *explicit_args)) + with core.extend_axis_env_nd(zip(axis_env.names, axis_env.sizes)): + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic2(wrapped_fun) + else: + with core.extend_axis_env_nd(zip(axis_env.names, axis_env.sizes)): + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) + + out, tokens = jaxpr_subcomp( + ctx.module_context, jaxpr, ctx.tokens_in, _ir_consts(consts), + *map(wrap_singleton_ir_values, args)) ctx.set_tokens_out(tokens) return out diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index e1a6a61c0f7d..3ea332fbea0b 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -667,7 +667,7 @@ def trace_to_subjaxpr_nounits_dyn( out_type = [(a.update(shape=tuple(idx_map.get(d, d) for d in a.shape)) # type: ignore if type(a) is DShapedArray else a, True) for a in out_avals] - # Which residuals are just forwarded inputs? Check obj id, then prune. + # Which residuals are just forwarded inputs? Check obj id, then prune. id_map = {id(c.recipe.val): i for i, c in enumerate(in_consts_full) # type: ignore if c is not None} fwds: List[Optional[int]] = [id_map.get(id(c)) for c in res] diff --git a/tests/api_test.py b/tests/api_test.py index b85584b1cb0a..111efbcf51a7 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -8522,7 +8522,8 @@ def f(x, y): x = np.ones(3) y = np.ones(3) - with self.assertRaisesRegex(TypeError, 'add got incompatible shapes for broadcasting'): + with self.assertRaisesRegex( + Exception, '[Ii]ncompatible shapes for broadcasting'): _ = jax.make_jaxpr(f, abstracted_axes=({0: 'n'}, {}))(x, y) def test_shape_errors_distinct_vars(self): @@ -8531,7 +8532,8 @@ def f(x, y): x = np.ones(3) y = np.ones(3) - with self.assertRaisesRegex(TypeError, 'add got incompatible shapes for broadcasting'): + with self.assertRaisesRegex( + Exception, '[Ii]ncompatible shapes for broadcasting'): _ = jax.make_jaxpr(f, abstracted_axes=({0: 'n'}, {0: 'm'}))(x, y) def test_basic_dot(self): @@ -9096,6 +9098,7 @@ def f(x): # f32[w] -> f32[w] f(np.ones((5,), dtype=np.float32)) # TODO: add assertions + @unittest.skip('failing w/ iree error') @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") def test_broadcast(self): @partial(jax.jit, abstracted_axes=({0: 'w'},)) @@ -9112,6 +9115,7 @@ def f(x): # f32[w] -> f32[w] f(np.ones((5,), dtype=np.float32)) # TODO: add assertions + @unittest.skip('failing w/ iree error') @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") def test_stack(self): @partial(jax.jit, abstracted_axes=({0: 'w'},)) @@ -9121,6 +9125,23 @@ def f(x): f(np.ones((5,), dtype=np.float32)) # TODO: add assertions + @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") + def test_jit_dependent_pair_output_iree(self): + # Like the above 'polymorhpic output' test, but now with a `2 * n`! + count = 0 + + @jax.jit + def f(n): + nonlocal count + count += 1 + return jnp.arange(2 * n) + + x = f(3) + y = f(4) + self.assertAllClose(x, jnp.arange(2 * 3), check_dtypes=False) + self.assertAllClose(y, jnp.arange(2 * 4), check_dtypes=False) + self.assertEqual(count, 1) + def test_slicing_basic(self): f = jax.jit(lambda x, n: jnp.sum(x[:n])) # TODO(mattjj): revise getslice, add typecheck rule for it, enable checks @@ -9446,6 +9467,48 @@ def loss_lin(params, batch): jaxpr = jax.make_jaxpr(jax.grad(loss))(params, batch) core.check_jaxpr(jaxpr.jaxpr) + @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") + def test_mlp_autodiff_dynamic_batch_iree(self): + count = 0 + + def predict(params, inputs): + for W, b in params: + outputs = jnp.dot(inputs, W) + b + inputs = jnp.maximum(0, outputs) + return outputs + + def loss_ref(params, batch): + nonlocal count + count += 1 # count retraces + inputs, targets = batch + predictions = predict(params, inputs) + return jnp.sum((predictions - targets) ** 2) + + loss = jax.jit(loss_ref, abstracted_axes=({}, {0: 'n'})) + + params = [(jnp.ones((784, 256)), jnp.ones(256)), + (jnp.ones((256, 10)), jnp.ones( 10))] + + # two different size batches + batch1 = (inputs, targets) = (jnp.ones((128, 784)), jnp.ones((128, 10))) + batch2 = (inputs, targets) = (jnp.ones((32, 784)), jnp.ones((32, 10))) + + _ = loss(params, batch1) + _ = loss(params, batch2) + self.assertEqual(count, 1) + + _ = grad(loss)(params, batch1) + _ = grad(loss)(params, batch2) + self.assertEqual(count, 2) + + ans = loss( params, batch1) + expected = loss_ref(params, batch1) + self.assertAllClose(ans, expected) + + ans = grad(loss )(params, batch1) + expected = grad(loss_ref)(params, batch1) + self.assertAllClose(ans, expected) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())