diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index 5e5d2fc1031b..dd9ea3e18917 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -299,6 +299,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
   if config.jax_dynamic_shapes:
     keep_unused = True
     has_outfeed = False
+    donated_invars = [False] * len(fun.in_type)
   else:
     has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
     jaxpr = apply_outfeed_rewriter(jaxpr)
@@ -318,8 +319,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
   device = _xla_callable_device(nreps, backend, device, arg_devices)
   backend = xb.get_device_backend(device) if device else xb.get_backend(backend)
 
-  if (config.jax_dynamic_shapes and jaxpr_has_bints(jaxpr) and
-      not _backend_supports_unbounded_dynamic_shapes(backend)):
+  if config.jax_dynamic_shapes and jaxpr_has_bints(jaxpr):
     jaxpr, consts = pe.pad_jaxpr(jaxpr, consts)
 
   map(prefetch, itertools.chain(consts, jaxpr_literals(jaxpr)))
@@ -517,6 +517,7 @@ def aval_to_num_buffers(aval: core.AbstractValue) -> int:
 num_buffers_handlers[core.ShapedArray] = lambda _: 1
 num_buffers_handlers[core.DShapedArray] = lambda _: 1
 num_buffers_handlers[core.ConcreteArray] = lambda _: 1
+num_buffers_handlers[core.AbstractBInt] = lambda _: 1
 
 
 def _input_handler(backend: Backend,
@@ -649,19 +650,24 @@ def dynamic_array_result_handler(sticky_device: Optional[Device],
     return partial(_dynamic_array_result_handler, sticky_device, aval)
 
 def _dynamic_array_result_handler(sticky_device, aval, env, buf):
-  if all(type(d) is int for d in aval.shape):
-    del env
+  in_env, out_env = env or (None, None)
+  shape = [in_env[d.val] if type(d) is core.InDBIdx else
+           out_env[d.val] if type(d) is core.OutDBIdx else d
+           for d in aval.shape]
+  if all(type(d) is int for d in shape):
+    aval = core.ShapedArray(tuple(shape), aval.dtype)
     return _maybe_create_array_from_da(buf, aval, sticky_device)
+  elif any(type(d) is core.BInt for d in shape):
+    padded_shape = [d.bound if type(d) is core.BInt else d for d in shape]
+    buf_aval = core.ShapedArray(tuple(padded_shape), aval.dtype, aval.weak_type)
+    data = _maybe_create_array_from_da(buf, buf_aval, sticky_device)
+    return core.PaddedArray(aval.update(shape=tuple(shape)), data)
   else:
-    assert env is not None
-    in_env, out_env = env
-    shape = [in_env[d.val] if type(d) is core.InDBIdx else
-             out_env[d.val] if type(d) is core.OutDBIdx else d
-             for d in aval.shape]
     aval = core.ShapedArray(tuple(shape), aval.dtype)
     return _maybe_create_array_from_da(buf, aval, sticky_device)
 
+
 result_handlers: Dict[
     Type[core.AbstractValue],
     Callable[[Optional[Device], Any], ResultHandler]] = {}
@@ -669,6 +675,8 @@ def _dynamic_array_result_handler(sticky_device, aval, env, buf):
 result_handlers[core.ShapedArray] = array_result_handler
 result_handlers[core.DShapedArray] = dynamic_array_result_handler
 result_handlers[core.ConcreteArray] = array_result_handler
+result_handlers[core.AbstractBInt] = \
+    lambda _, a: lambda _, b: core.BInt(int(b), a.bound)
 
 
 def needs_check_special():
@@ -1005,6 +1013,7 @@ def _device_put_token(_, device):
 device_put_handlers.update((t, _device_put_array) for t in array_types)
 device_put_handlers.update((t, _device_put_scalar) for t in _scalar_types)
 device_put_handlers[core.Token] = _device_put_token
+device_put_handlers[core.BInt] = lambda x, d: _device_put_scalar(x.val, d)
 
 def _device_put_device_array(x: Union[device_array.DeviceArrayProtocol, device_array._DeviceArray],
                              device: Optional[Device]):
@@ -1012,6 +1021,7 @@ def _device_put_device_array(x: Union[device_array.DeviceArrayProtocol, device_a
   return (x.device_buffer,)
 for t in device_array.device_array_types:
   device_put_handlers[t] = _device_put_device_array
+device_put_handlers[core.PaddedArray] = lambda x, d: device_put(x._data, d)
 
 def _copy_device_array_to_device(
     x: Union[device_array.DeviceArrayProtocol, device_array._DeviceArray],
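The result-handler change above is the heart of the runtime story: axis sizes
spelled as InDBIdx/OutDBIdx are resolved against the input/output environments,
and any output with a BInt-sized axis comes back as a core.PaddedArray wrapping
the physically padded buffer. A minimal standalone sketch of that wrapping
step, using toy stand-ins (Bnd, PaddedValue) rather than JAX's real
BInt/PaddedArray/Buffer types:

from typing import Any, NamedTuple
import numpy as np

class Bnd(NamedTuple):          # toy stand-in for core.BInt
  val: int                      # logical size
  bound: int                    # padded (physical) size

class PaddedValue(NamedTuple):  # toy stand-in for core.PaddedArray
  logical_shape: tuple          # may contain Bnd entries
  data: np.ndarray              # physically padded storage

def handle_result(shape, buf):
  # Mirror of the new branch: if any axis is bounded-dynamic, keep the
  # padded buffer and remember the logical shape; otherwise return as-is.
  if any(isinstance(d, Bnd) for d in shape):
    assert buf.shape == tuple(d.bound if isinstance(d, Bnd) else d
                              for d in shape)
    return PaddedValue(shape, buf)
  return buf

out = handle_result((Bnd(3, 5),), np.zeros(5))
print(out.logical_shape)        # (Bnd(val=3, bound=5),)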
diff --git a/jax/_src/iree.py b/jax/_src/iree.py
index dda183e8eb18..ff8f3431f200 100644
--- a/jax/_src/iree.py
+++ b/jax/_src/iree.py
@@ -104,8 +104,8 @@ def block_until_ready(self) -> IreeBuffer:
     return self  # no async
 
   # overrides repr on base class which expects _value and aval attributes
-  def __repr__(self):
-    return f'IreeBuffer({self.to_py()})'
+  def __repr__(self): return f'IreeBuffer({self.to_py()})'
+
   _value = property(to_py)
 
 class IreeExecutable:
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index a9410eba2b5f..1d6e549e6def 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -1440,6 +1440,7 @@ def unop(result_dtype, accepted_dtypes, name):
                             weak_type_rule=weak_type_rule)
   batching.defvectorized(prim)
   masking.defvectorized(prim)
+  pe.padding_rules[prim] = lambda _, __, x, **kw: [prim.bind(x, **kw)]
   return prim
 standard_unop = partial(unop, _identity)
 _attrgetter = lambda name: lambda x, **kwargs: getattr(x, name)
@@ -1515,6 +1516,7 @@ def naryop(result_dtype, accepted_dtypes, name):
                             weak_type_rule=weak_type_rule)
   batching.defbroadcasting(prim)
   masking.defnaryop(prim)
+  pe.padding_rules[prim] = lambda _, __, *xs, **kw: [prim.bind(*xs, **kw)]
   return prim
 standard_naryop = partial(naryop, _input_dtype)
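With unop/naryop now installing a generic padding rule for every elementwise
primitive, the hand-written per-primitive rules for add/sub/mul/max become
redundant and are deleted in the hunks below. The generic rule is sound for
elementwise ops because padding never mixes values across positions; a numpy
sketch of the idea:

import numpy as np

logical_n, bound = 3, 5
x = np.zeros(bound); x[:logical_n] = [1., 2., 3.]     # padded operand
y = np.zeros(bound); y[:logical_n] = [10., 20., 30.]  # padded operand
z = x + y   # applying the elementwise op directly to the padded data
print(z[:logical_n])  # [11. 22. 33.]: the logical region is exactly right;
                      # the padding region may hold garbage, which downstream
                      # masked consumers (e.g. reduce_sum's rule) ignore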
@@ -2080,7 +2082,6 @@ def _add_inverse(r, x, y):
 ad.primitive_jvps[add_p] = _add_jvp
 ad.primitive_transposes[add_p] = _add_transpose
 mlir.register_lowering(add_p, partial(_nary_lower_mhlo, mhlo.AddOp))
-pe.padding_rules[add_p] = lambda _, __, x, y: [add(x, y)]
 
 def _sub_jvp(primals, tangents):
   x, y = primals
@@ -2110,7 +2111,6 @@ def _sub_transpose(t, x, y):
 ad.primitive_jvps[sub_p] = _sub_jvp
 ad.primitive_transposes[sub_p] = _sub_transpose
 mlir.register_lowering(sub_p, partial(_nary_lower_mhlo, mhlo.SubOp))
-pe.padding_rules[sub_p] = lambda _, __, x, y: [sub(x, y)]
 
 
 def _mul_transpose(ct, x, y):
@@ -2137,7 +2137,6 @@ def _mul_inverse(r, x, y):
              lambda ydot, x, y: mul(x, ydot))
 ad.primitive_transposes[mul_p] = _mul_transpose
 mlir.register_lowering(mul_p, partial(_nary_lower_mhlo, mhlo.MulOp))
-pe.padding_rules[mul_p] = lambda _, __, x, y: [mul(x, y)]
 
 def _div_transpose_rule(cotangent, x, y):
   assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y)
@@ -2174,7 +2173,6 @@ def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x):
           lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
           lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
 mlir.register_lowering(max_p, partial(_nary_lower_mhlo, mlir.max_mhlo))
-pe.padding_rules[max_p] = lambda _, __, x, y: [max(x, y)]
 
 min_p: core.Primitive = standard_naryop([_any, _any], 'min')
 ad.defjvp2(min_p,
@@ -2297,9 +2295,13 @@ def _convert_elt_type_pp_rule(eqn, context, settings):
   printed_params = {}
   if eqn.params['weak_type']:
     printed_params['weak_type'] = True
-  return [pp.text(eqn.primitive.name),
-          core.pp_kv_pairs(sorted(printed_params.items()), context, settings),
-          pp.text(" ") + core.pp_vars(eqn.invars, context)]
+  lhs = core.pp_vars(eqn.outvars, context, print_shapes=settings.print_shapes)
+  rhs = [pp.text(eqn.primitive.name),
+         core.pp_kv_pairs(sorted(printed_params.items()), context, settings),
+         pp.text(" ") + core.pp_vars(eqn.invars, context)]
+  annotation = (source_info_util.summarize(eqn.source_info)
+                if settings.source_info else None)
+  return [lhs, pp.text(" = ", annotation=annotation), *rhs]
 
 convert_element_type_p = Primitive('convert_element_type')
@@ -2756,7 +2758,7 @@ def _broadcast_in_dim_padding_rule(in_avals, out_avals, x, *dyn_shape,
       assert isinstance(d, core.Tracer)
       new_shape.append(None)
       new_dyn_shape.append(d)
-  return [broadcast_in_dim_p.bind(x, *new_dyn_shape, shape=new_shape,
+  return [broadcast_in_dim_p.bind(x, *new_dyn_shape, shape=tuple(new_shape),
                                   broadcast_dimensions=broadcast_dimensions)]
 
 def _broadcast_in_dim_jvp_rule(primals, tangents, *, shape, broadcast_dimensions):
@@ -2820,8 +2822,18 @@ def _broadcast_in_dim_pp_rule(eqn, context, settings):
                 if settings.source_info else None)
   return [lhs, pp.text(" = ", annotation=annotation), *rhs]
 
+def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions):
+  if not any(isinstance(d, core.BInt) for d in shape):
+    shape = _broadcast_in_dim_shape_rule(  # error checking
+        x, shape=shape, broadcast_dimensions=broadcast_dimensions)
+    return core.ShapedArray(shape, x.dtype, x.weak_type, x.named_shape)
+  # If any BInts in shape, produce a DShapedArray (even if x is a ShapedArray)
+  # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
+  return core.DShapedArray(shape, x.dtype, x.weak_type)
+
 broadcast_in_dim_p = standard_primitive(
     _broadcast_in_dim_shape_rule, _input_dtype, 'broadcast_in_dim')
+broadcast_in_dim_p.def_abstract_eval(_broadcast_in_dim_abstract_eval)
 ad.primitive_jvps[broadcast_in_dim_p] = _broadcast_in_dim_jvp_rule
 ad.primitive_transposes[broadcast_in_dim_p] = _broadcast_in_dim_transpose_rule
 batching.primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule
@@ -3605,8 +3617,8 @@ def _reduce_sum_padding_rule(in_avals, out_avals, operand, *, axes):
 
 def _replace_masked_values(x, val, padded_axes):
   if not padded_axes: return x
-  masks = [broadcasted_iota(np.dtype('int32'), x.shape, i) < d
-           for i, d in padded_axes]
+  dtype = dtypes._scalar_type_to_dtype(int)
+  masks = [broadcasted_iota(dtype, x.shape, i) < d for i, d in padded_axes]
   return select(_reduce(operator.and_, masks), x, full_like(x, val))
 
@@ -4384,7 +4396,10 @@ def _iota_abstract_eval(*, dtype, shape, dimension):
   if not 0 <= dimension < len(shape):
     raise ValueError("iota dimension must be between 0 and len(shape), got "
                      f"dimension={dimension} for shape {shape}")
-  return ShapedArray(shape, dtype)
+  if not any(isinstance(d, core.BInt) for d in shape):
+    return ShapedArray(shape, dtype)
+  # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
+  return core.DShapedArray(shape, dtype, False)
 
 iota_p = Primitive('iota')
 iota_p.def_impl(partial(xla.apply_primitive, iota_p))
@@ -4425,12 +4440,46 @@ def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension):
       mlir.i64_attr(dimension)).results
 mlir.register_lowering(iota_p, _iota_lower)
 
+def _iota_pp_rule(eqn, context, settings):
+  printed_params = {}
+  if len(eqn.params['shape']) > 1:
+    printed_params['dimension'] = eqn.params['dimension']
+  lhs = core.pp_vars(eqn.outvars, context, print_shapes=settings.print_shapes)
+  rhs = [pp.text(eqn.primitive.name),
+         core.pp_kv_pairs(sorted(printed_params.items()), context, settings),
+         pp.text(" ") + core.pp_vars(eqn.invars, context)]
+  annotation = (source_info_util.summarize(eqn.source_info)
+                if settings.source_info else None)
+  return [lhs, pp.text(" = ", annotation=annotation), *rhs]
+# core.pp_eqn_rules[iota_p] = _iota_pp_rule
+
+def _iota_padding_rule(in_avals, out_avals, *dyn_shape, dtype, shape, dimension):
+  out_aval, = out_avals
+  new_shape = []
+  new_dyn_shape = []
+  for d in out_aval.shape:
+    if type(d) is pe.BoundedAxisSize:
+      new_shape.append(d.bound)
+    elif type(d) is int:
+      new_shape.append(d)
+    else:
+      assert isinstance(d, core.Tracer)
+      new_shape.append(None)
+      new_dyn_shape.append(d)
+  return [iota_p.bind(*new_dyn_shape, shape=tuple(new_shape),
+                      dtype=dtype, dimension=dimension)]
+pe.padding_rules[iota_p] = _iota_padding_rule
+
 
 def make_bint(i, bd: int):
   return bint_p.bind(i, bd=bd)
 
 bint_p = core.Primitive('bint')
 
+@bint_p.def_impl
+def _bint_impl(i, *, bd):
+  return core.BInt(i, bd)
+
 @bint_p.def_abstract_eval
 def bint_abstract_eval(_, *, bd: int):
   return core.AbstractBInt(bound=bd)
@@ -4566,7 +4615,7 @@ def _check_shapelike(fun_name, arg_name, obj, non_zero_shape=False):
   if not len(obj):  # pylint: disable=g-explicit-length-test
     return
   if (config.jax_dynamic_shapes and isinstance(obj, (tuple, list)) and
-      any(isinstance(d, core.Tracer) for d in obj)):
+      any(isinstance(d, (core.Tracer, core.BInt)) for d in obj)):
     return  # TODO(mattjj): handle more checks in the dynamic shape case
   obj_arr = np.array(obj)
   if obj_arr.ndim != 1:
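Taken together, the lax.py changes let broadcast_in_dim and iota produce arrays
with BInt-sized axes: the new abstract-eval rules return a DShapedArray when a
BInt appears in the shape, and the padding rules rebind the primitive at the
bound. A usage sketch mirroring the tests added at the bottom of this patch
(it assumes the jax_dynamic_shapes config flag is enabled):

import jax
from jax import core, lax
import jax.numpy as jnp

jax.config.update('jax_dynamic_shapes', True)

d = lax.make_bint(3, 5)                # logical size 3, bound 5
x = lax.broadcast_in_dim(0, (d,), ())  # logical shape (3,), stored padded
assert isinstance(x, core.PaddedArray)
y = jnp.arange(d, dtype='int32')       # iota over a BInt-sized axis
assert y._data.shape == (5,)           # physical storage has the bound shape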
diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py
index 0a624505e62f..8a7d1e7520d1 100644
--- a/jax/_src/lax/slicing.py
+++ b/jax/_src/lax/slicing.py
@@ -903,7 +903,7 @@ def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes):
 dynamic_slice_p = standard_primitive(
     _dynamic_slice_shape_rule, _dynamic_slice_dtype_rule, 'dynamic_slice',
     weak_type_rule=_argnum_weak_type(0))
-ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp  # TODO
+ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp
 ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule
 batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule
diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 33a5db692540..1c63e3553236 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -2101,8 +2101,12 @@ def arange(start: core.DimSize, stop: Optional[core.DimSize]=None,
   dtype = _jnp_dtype(dtype)
   if stop is None and step is None:
     if (jax.config.jax_dynamic_shapes and
+        not isinstance(core.get_aval(start), core.AbstractBInt) and
         not isinstance(core.get_aval(start), core.ConcreteArray)):
       start = ceil(start).astype(int)  # note using jnp here
+    elif (isinstance(start, core.BInt) or isinstance(start, core.Tracer) and
+          isinstance(core.get_aval(start), core.AbstractBInt)):
+      pass
     else:
       start = require(start, msg("stop"))
       start = np.ceil(start).astype(int)
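The new arange() branches skip the ceil()/astype coercion when start is a BInt
(or a tracer carrying an AbstractBInt aval), since such a value is already
integral. Note the added elif leans on Python's operator precedence, where
`and` binds tighter than `or`; an explicitly parenthesized restatement:

def _start_is_bint(start, core):
  # Equivalent, explicitly grouped form of the elif condition above.
  return isinstance(start, core.BInt) or (
      isinstance(start, core.Tracer) and
      isinstance(core.get_aval(start), core.AbstractBInt))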
diff --git a/jax/core.py b/jax/core.py
index efe48b2d3018..9e0be5bca817 100644
--- a/jax/core.py
+++ b/jax/core.py
@@ -1104,22 +1104,9 @@ def _jaxpr_type_to_callable_annotation(jaxpr: Jaxpr) -> InputType:
          for v in jaxpr.invars]
   return tuple(out)
 
-
 class Bot(AbstractValue): pass
-
 bot = Bot()
 
-class AbstractBInt(AbstractValue):
-  __slots__ = ['bound']
-  bound: int
-  def __init__(self, bound):
-    self.bound = bound
-  def str_short(self, short_dtypes=False) -> str:
-    return f'bint{{≤{self.bound}}}[]'
-  def __eq__(self, other):
-    return type(other) is AbstractBInt and self.bound == other.bound
-  def __hash__(self) -> int:
-    return hash((type(self), self.bound))
 
 def lattice_join(x: Optional[AbstractValue],
                  y: Optional[AbstractValue]) -> AbstractValue:
@@ -1171,9 +1158,6 @@ def get_aval(x):
     return concrete_aval(x)
 
 
-pytype_aval_mappings: Dict[type, Callable[[Any], AbstractValue]] = {}
-
-
 def concretization_function_error(fun, suggest_astype=False):
   fname = getattr(fun, "__name__", fun)
   fname_context = f"The problem arose with the `{fname}` function. "
@@ -1204,7 +1188,7 @@ def _short_dtype_name(dtype):
 
 class UnshapedArray(AbstractValue):
   __slots__ = ['dtype', 'weak_type']
-  array_abstraction_level = 3
+  array_abstraction_level = 4
 
   def __init__(self, dtype, weak_type=False):
     self.dtype = np.dtype(dtype)
@@ -1269,77 +1253,9 @@ def shape(self):
       raise TypeError(msg)
 
 
-# We have a convention of reusing AbsractValues as types, in particular reusing
-# ShapedArrays as types, even though we could make a distinction and use
-# abstract values during tracing only. This reuse becomes a bit more extreme
-# with DShapedArrays. A DShapedArray's shape attribute is a tuple which can
-# contain several different types: ints, other AbstractValues (specifically at
-# the input and output to pe.trace_to_jaxpr_dynamic), Tracers (while tracing),
-# or Vars (when used as jaxpr type annotations). We could reduce this
-# polymorphism if it seems cleaner, though it's kind of convenient!
-AxisSizeForTracing = Union[int, Tracer]
-AxisSizeForJaxprType = Union[int, Var]
-AxisSizeForJaxprTracingSpec = Union[int, AbstractValue]
-AxisSize = Union[AxisSizeForTracing, AxisSizeForJaxprType,
-                 AxisSizeForJaxprTracingSpec]
-
-class DShapedArray(UnshapedArray):
-  __slots__ = ['shape']
-  shape: Tuple[AxisSize, ...]  # noqa: F821
-  array_abstraction_level: int = 2
-
-  def __init__(self, shape, dtype, weak_type):
-    self.shape = shape
-    self.dtype = dtype
-    self.weak_type = weak_type
-
-  ndim = property(lambda self: len(self.shape))
-  size = property(lambda self: prod(self.shape))
-
-  def str_short(self, short_dtypes=False) -> str:
-    del short_dtypes  # ignored
-    shape = f'{",".join(str(d) for d in self.shape)}' if self.shape else ''
-    dtype = _short_dtype_name(self.dtype)
-    return f'{dtype}[{shape}]'
-  __str__ = __repr__ = str_short
-
-  def update(self, shape=None, dtype=None, weak_type=None):
-    if shape is None:
-      shape = self.shape
-    if dtype is None:
-      dtype = self.dtype
-    if weak_type is None:
-      weak_type = self.weak_type
-    return DShapedArray(shape, dtype, weak_type)
-
-  def __eq__(self, other):
-    return (type(self) is type(other)
-            and self.dtype == other.dtype and self.shape == other.shape
-            and self.weak_type == other.weak_type)
-
-  def __hash__(self):
-    return hash((self.shape, self.dtype, self.weak_type))
-
-  def join(self, other):
-    if (symbolic_equal_shape(self.shape, other.shape) and
-        self.dtype == other.dtype):
-      weak_type = self.weak_type and other.weak_type
-      return self.update(weak_type=weak_type)
-    elif self.dtype == other.dtype:
-      return UnshapedArray(self.dtype)
-    else:
-      raise TypeError(self, other)
-
-  def at_least_vspace(self):
-    return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
-                        self.weak_type)
-
-del AxisSize, AxisSizeForTracing, AxisSizeForJaxprType, \
-    AxisSizeForJaxprTracingSpec
-
 class ShapedArray(UnshapedArray):
   __slots__ = ['shape', 'named_shape']
-  array_abstraction_level = 1
+  array_abstraction_level = 2
 
   def __init__(self, shape, dtype, weak_type=False, named_shape=None):
     self.shape = canonicalize_shape(shape)
@@ -1415,6 +1331,7 @@ def _len(self, ignored_tracer):
   def _forward_to_value(self, fun, ignored_tracer, *args):
     return fun(self.val, *args)
 
+
 class ConcreteArray(ShapedArray):
   __slots__ = ['val']
   array_abstraction_level = 0
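The renumbering of array_abstraction_level above makes room in the join
lattice for the two dynamic-shape avals added in the next hunk (lower level =
more concrete):

# The resulting ordering, from most to least concrete:
abstraction_levels = {
    'ConcreteArray':  0,
    'DConcreteArray': 1,   # added below: a DShapedArray carrying its value
    'ShapedArray':    2,   # was 1
    'DShapedArray':   3,   # was 2; re-added below among the dynamic avals
    'UnshapedArray':  4,   # was 3
}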
@@ -1477,6 +1394,135 @@ def primal_dtype_to_tangent_dtype(primal_dtype):
   else:
     return primal_dtype
 
+
+# Dynamic shape stuff below here! We keep the abstract values distinct just so
+# as not to interfere with any static shape machinery.
+
+# We have a convention of reusing AbstractValues as types, even though we could
+# make a distinction and use abstract values during tracing only. This reuse
+# becomes a bit more extreme with DShapedArrays. A DShapedArray's shape
+# attribute is a tuple which can contain several different types: int, BInt,
+# Tracer (while tracing), Var (when used as jaxpr type annotations), or
+# DBIdx/InDBIdx/OutDBIdx (when used in InputType or OutputType). We could reduce
+# this polymorphism if it seems cleaner, though it's kind of convenient!
+AxisSize = Any
+
+class DShapedArray(UnshapedArray):
+  __slots__ = ['shape']
+  shape: Tuple[AxisSize, ...]  # noqa: F821
+  array_abstraction_level: int = 3
+
+  def __init__(self, shape, dtype, weak_type):
+    self.shape = shape
+    self.dtype = dtype
+    self.weak_type = weak_type
+
+  ndim = property(lambda self: len(self.shape))
+  size = property(lambda self: prod(self.shape))
+
+  def str_short(self, short_dtypes=False) -> str:
+    del short_dtypes  # ignored
+    shape = f'{",".join(str(d) for d in self.shape)}' if self.shape else ''
+    dtype = _short_dtype_name(self.dtype)
+    return f'{dtype}[{shape}]'
+  __str__ = __repr__ = str_short
+
+  def update(self, shape=None, dtype=None, weak_type=None):
+    if shape is None:
+      shape = self.shape
+    if dtype is None:
+      dtype = self.dtype
+    if weak_type is None:
+      weak_type = self.weak_type
+    return DShapedArray(shape, dtype, weak_type)
+
+  def __eq__(self, other):
+    return (type(self) is type(other)
+            and self.dtype == other.dtype and self.shape == other.shape
+            and self.weak_type == other.weak_type)
+
+  def __hash__(self):
+    return hash((self.shape, self.dtype, self.weak_type))
+
+  def join(self, other):
+    if (symbolic_equal_shape(self.shape, other.shape) and
+        self.dtype == other.dtype):
+      weak_type = self.weak_type and other.weak_type
+      return self.update(weak_type=weak_type)
+    elif self.dtype == other.dtype:
+      return UnshapedArray(self.dtype)
+    else:
+      raise TypeError(self, other)
+
+  def at_least_vspace(self):
+    return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
+                        self.weak_type)
+
+class DConcreteArray(DShapedArray):
+  __slots__ = ['val']
+  array_abstraction_level = 1
+  def __init__(self, shape, dtype, weak_type, val):
+    super().__init__(shape, dtype, weak_type)
+    self.val = val
+
+
+pytype_aval_mappings: Dict[type, Callable[[Any], AbstractValue]] = {}
+
+
+class AbstractBInt(AbstractValue):
+  __slots__ = ['bound']
+  bound: int
+  def __init__(self, bound):
+    self.bound = bound
+  def str_short(self, short_dtypes=False) -> str:
+    return f'bint{{≤{self.bound}}}[]'
+  __repr__ = str_short
+  def __eq__(self, other):
+    return type(other) is AbstractBInt and self.bound == other.bound
+  def __hash__(self) -> int:
+    return hash((type(self), self.bound))
+
+class BInt:
+  val: Any  # Union[int, Array]
+  bound: int
+  def __init__(self, val, bound):
+    self.val = val
+    self.bound = bound
+  def __repr__(self) -> str:
+    return f'{self.val}{{≤{self.bound}}}'
+  def __int__(self) -> int:
+    return self.val
+  def __eq__(self, other) -> bool:
+    return (isinstance(other, BInt) and
+            (self.val, self.bound) == (other.val, other.bound))
+  def __hash__(self):
+    return hash((self.val, self.bound))
+pytype_aval_mappings[BInt] = lambda x: AbstractBInt(x.bound)
+
+
+# DShapedArray w/ BInt in shapes => PaddedArray runtime representation
+class PaddedArray:
+  _aval: DShapedArray
+  _data: Any  # standard array type
+  def __init__(self, aval, data):
+    padded_shape = tuple(d.bound if type(d) is BInt else d for d in aval.shape)
+    assert data.shape == padded_shape
+    self._aval = aval
+    self._data = data
+  shape = property(lambda self: self._aval.shape)
+  dtype = property(lambda self: self._aval.dtype)
+  def __repr__(self) -> str:
+    dtypestr = _short_dtype_name(self._aval.dtype)
+    shapestr = ','.join(map(str, self.shape))
+    slices = tuple(slice(d.val) if type(d) is BInt else slice(None)
+                   for d in self.shape)
+    data = self._data[slices]
+    return f'{dtypestr}[{shapestr}] with value:\n{data}'
+pytype_aval_mappings[PaddedArray] = \
+    lambda x: DConcreteArray(x._aval.shape, x._aval.dtype, x._aval.weak_type,
+                             x._data)
+
+
 class AbstractToken(AbstractValue):
   def join(self, other):
     if isinstance(other, AbstractToken):
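The PaddedArray invariant above is worth spelling out: _data always has the
fully padded shape obtained by replacing each BInt axis with its bound, while
_aval records the logical BInt sizes, and __repr__ slices the data back down
to the logical region. A toy numpy model of that contract:

import numpy as np

val, bound = 3, 5
data = np.arange(bound)    # physical storage: always the padded (bound) shape
logical = data[:val]       # the logical region, as __repr__ displays it
assert data.shape == (bound,)
assert logical.shape == (val,)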
@@ -1485,7 +1531,6 @@ def join(self, other):
       assert False, f"Cannot join {self} with {other}"
   def str_short(self, short_dtypes=False): return 'Tok'
   def at_least_vspace(self): return self
-
 abstract_token: AbstractToken = AbstractToken()
 
 # Concrete token object
@@ -1759,6 +1804,21 @@ def _invalid_shape_error(shape: Shape, context: str=""):
            "smaller subfunctions.")
   return TypeError(msg)
 
+class BIntDimensionHandler(DimensionHandler):
+  def symbolic_equal(self, d1, d2) -> bool:
+    return isinstance(d2, BInt) and d1.val == d2.val and d1.bound == d2.bound
+  def sum(self, *ds) -> BInt:
+    if not all(isinstance(d, BInt) for d in ds):
+      raise InconclusiveDimensionOperation
+    if len({d.bound for d in ds}) != 1:
+      raise InconclusiveDimensionOperation
+    return BInt(sum(d.val for d in ds), ds[0].bound)
+  def fail(self, *_): raise InconclusiveDimensionOperation
+  greater_equal = diff = divide_shape_sizes = stride = dilate = as_value = fail
+_SPECIAL_DIMENSION_HANDLERS[BInt] = BIntDimensionHandler()
+
+
+
 # ------------------- Named shapes -------------------
 
@@ -2436,7 +2496,7 @@ def check_type(
   if isinstance(ty, DShapedArray):
     # Check all elements in the shape tuple are well-typed.
     for d in ty.shape:
-      if isinstance(d, int):
+      if isinstance(d, (int, BInt)):
         continue
       elif isinstance(d, Var):
         if d not in env:
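BIntDimensionHandler deliberately supports very little: symbolic_equal, plus
sum over BInts that all share a single bound; every other shape operation
(greater_equal, diff, divide_shape_sizes, stride, dilate, as_value) raises
InconclusiveDimensionOperation. A standalone restatement of the sum rule:

def bint_sum(ds):
  # ds: list of (val, bound) pairs standing in for BInts
  if len({bound for _, bound in ds}) != 1:
    raise ValueError('inconclusive: summands have mixed bounds')
  return (sum(val for val, _ in ds), ds[0][1])

assert bint_sum([(2, 5), (1, 5)]) == (3, 5)   # ok: shared bound
# bint_sum([(2, 5), (1, 6)]) would raise      # mixed bounds are rejected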
diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py
index 3e57b82eb5cc..f8844d86e91e 100644
--- a/jax/interpreters/mlir.py
+++ b/jax/interpreters/mlir.py
@@ -122,15 +122,19 @@ def dtype_to_ir_type(dtype: Union[np.dtype, np.generic]) -> ir.Type:
         f"No dtype_to_ir_type handler for dtype: {dtype}") from err
   return ir_type_factory()
 
-def _array_ir_types(aval: core.ShapedArray) -> Sequence[ir.Type]:
+def _array_ir_types(aval: Union[core.ShapedArray, core.DShapedArray]
+                    ) -> Sequence[ir.Type]:
   return (ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)),)
 
 def _dynamic_array_ir_types(aval: core.ShapedArray) -> Sequence[ir.Type]:
-  shape = [d if type(d) is int else -1 for d in aval.shape]
+  # in the MHLO builder, -1 indicates a '?' axis size
+  shape = [d if type(d) is int else d.bound if type(d) is core.BInt else -1
+           for d in aval.shape]
   return (ir.RankedTensorType.get(shape, dtype_to_ir_type(aval.dtype)),)
 
 def _bint_ir_types(aval: core.AbstractBInt) -> Sequence[ir.Type]:
-  return (ir.RankedTensorType.get((), dtype_to_ir_type(dtypes.dtype('int32'))),)
+  dtype = dtypes._scalar_type_to_dtype(int)
+  return (ir.RankedTensorType.get((), dtype_to_ir_type(dtype)),)
 
 ir_type_handlers: Dict[Type[core.AbstractValue],
                        Callable[[Any], Sequence[ir.Type]]] = {}
diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py
index 3ea332fbea0b..8dd31a3eb0d2 100644
--- a/jax/interpreters/partial_eval.py
+++ b/jax/interpreters/partial_eval.py
@@ -2414,7 +2414,7 @@ def pad_jaxpr(jaxpr: Jaxpr, consts: Sequence[Const]
 
   def substitute(aval: AbstractValue) -> AbstractValue:
     if isinstance(aval, AbstractBInt):
-      return ShapedArray((), np.dtype('int32'))
+      return ShapedArray((), dtypes._scalar_type_to_dtype(int))
     elif isinstance(aval, DShapedArray):
       shape = [bounds.get(d, idxs.get(d, d)) for d in aval.shape]  # type: ignore
       typ = ShapedArray if all(type(d) is int for d in shape) else DShapedArray
diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py
index 55e8842ca08e..b542dfac569b 100644
--- a/jax/interpreters/xla.py
+++ b/jax/interpreters/xla.py
@@ -253,12 +253,15 @@ def _canonicalize_python_scalar_dtype(typ, x):
 
 canonicalize_dtype_handlers: Dict[Any, Callable] = {}
 for t in device_array.device_array_types:
-  canonicalize_dtype_handlers[t] = lambda x: x
+  canonicalize_dtype_handlers[t] = identity
 canonicalize_dtype_handlers.update(
     (t, _canonicalize_ndarray_dtype) for t in array_types)
 canonicalize_dtype_handlers.update(
     (t, partial(_canonicalize_python_scalar_dtype, t)) for t in _scalar_types)
-canonicalize_dtype_handlers[core.Token] = lambda x: x
+canonicalize_dtype_handlers[core.Token] = identity
+canonicalize_dtype_handlers[core.PaddedArray] = identity
+canonicalize_dtype_handlers[core.BInt] = \
+    lambda x: core.BInt(_canonicalize_python_scalar_dtype(int, x.val), x.bound)
 
 def abstractify(x) -> core.AbstractValue:
   typ = type(x)
@@ -277,6 +280,8 @@ def _make_abstract_python_scalar(typ, val):
 pytype_aval_mappings: Dict[Any, Callable[[Any], core.AbstractValue]] = {}
 for t in device_array.device_array_types:
   pytype_aval_mappings[t] = operator.attrgetter('aval')
+pytype_aval_mappings[core.BInt] = lambda x: core.AbstractBInt(x.bound)
+pytype_aval_mappings[core.PaddedArray] = operator.attrgetter('_aval')
 pytype_aval_mappings[core.Token] = lambda _: core.abstract_token
 pytype_aval_mappings.update((t, make_shaped_array) for t in array_types)
 pytype_aval_mappings.update(
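On the lowering side, the mlir.py changes fix the IR types: a BInt-sized axis
lowers at its bound (a static, padded size) rather than as a '?' dynamic
dimension, and a scalar AbstractBInt lowers as a rank-0 integer tensor. A toy
restatement of the axis-size mapping in _dynamic_array_ir_types:

def lower_axis_size(d):
  if isinstance(d, int):
    return d          # static axis: keep as-is
  if hasattr(d, 'bound'):
    return d.bound    # BInt axis: lower at the padded, static bound
  return -1           # anything else: '?' (dynamic) in the MHLO builder

assert [lower_axis_size(d) for d in (2, 7)] == [2, 7]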
diff --git a/tests/api_test.py b/tests/api_test.py
index 111efbcf51a7..31383c09f4fa 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -9142,6 +9142,7 @@ def f(n):
     self.assertAllClose(y, jnp.arange(2 * 4), check_dtypes=False)
     self.assertEqual(count, 1)
 
+  @jtu.skip_on_devices('iree')  # TODO(mattjj): update getslice, no bints
   def test_slicing_basic(self):
     f = jax.jit(lambda x, n: jnp.sum(x[:n]))
     # TODO(mattjj): revise getslice, add typecheck rule for it, enable checks
@@ -9509,6 +9510,98 @@ def loss_ref(params, batch):
     expected = grad(loss_ref)(params, batch1)
     self.assertAllClose(ans, expected)
 
+  def test_bint_basic(self):
+    d = lax.make_bint(3, 5)
+    self.assertEqual(str(d), '3{≤5}')
+
+    @jax.jit
+    def f(d):
+      jnp.sin(3.)  # don't have an empty jaxpr
+      return d
+    f(d)  # doesn't crash
+
+  def test_bint_broadcast(self):
+    d = lax.make_bint(3, 5)
+
+    x = lax.broadcast_in_dim(0, (d,), ())  # doesn't crash
+    self.assertIsInstance(x, core.PaddedArray)
+    self.assertAllClose(x._data, np.zeros(5, dtype='int32'), check_dtypes=False)
+    self.assertEqual(
+        x._aval, core.DShapedArray((core.BInt(3, 5),), x._data.dtype, True))
+
+    def f(n):
+      return jnp.zeros(n)
+    x = jax.jit(f)(d)
+    self.assertIsInstance(x, core.PaddedArray)
+    self.assertAllClose(x._data, np.zeros(5, dtype='int32'), check_dtypes=False)
+    self.assertEqual(
+        x._aval, core.DShapedArray((core.BInt(3, 5),), x._data.dtype, False))
+
+    jaxpr = jax.make_jaxpr(f)(d).jaxpr
+    # { lambda ; a:bint{≤5}[]. let
+    #     b:f32[a] = broadcast_in_dim[...] 0.0 a
+    #   in (b,) }
+    self.assertLen(jaxpr.invars, 1)
+    a, = jaxpr.invars
+    self.assertEqual(a.aval, core.AbstractBInt(5))
+    self.assertLen(jaxpr.eqns, 1)
+    eqn, = jaxpr.eqns
+    self.assertLen(eqn.outvars, 1)
+    b, = eqn.outvars
+    self.assertEqual(b.aval.shape, (a,))
+
+  def test_bint_iota(self):
+    def f(d):
+      return jnp.arange(d, dtype='int32')
+
+    y = f(lax.make_bint(3, 5))
+    self.assertIsInstance(y, core.PaddedArray)
+    self.assertAllClose(y._data, np.arange(5), check_dtypes=False)
+
+    d = lax.make_bint(3, 5)
+    y = jax.jit(f)(d)
+    self.assertIsInstance(y, core.PaddedArray)
+    self.assertAllClose(y._data, np.arange(5), check_dtypes=False)
+
+  def test_bint_compilation_cache(self):
+    count = 0
+
+    @jax.jit
+    def f(n):
+      nonlocal count
+      count += 1
+      return jnp.zeros(n)
+    f(lax.make_bint(3, 5))
+    f(lax.make_bint(4, 5))
+    self.assertEqual(count, 1)
+
+  def test_bint_compilation_cache2(self):
+    count = 0
+
+    @partial(jax.jit, abstracted_axes=('n',))
+    def f(x):
+      nonlocal count
+      count += 1
+      return x.sum()
+
+    d = lax.make_bint(3, 5)
+    x = jnp.arange(d)
+    y = f(x)
+    self.assertEqual(y, 3)
+    self.assertEqual(count, 1)
+
+    d = lax.make_bint(4, 5)
+    x = jnp.arange(d)
+    y = f(x)
+    self.assertEqual(y, 6)
+    self.assertEqual(count, 1)
+
+    d = lax.make_bint(4, 6)
+    x = jnp.arange(d)
+    y = f(x)
+    self.assertEqual(y, 6)
+    self.assertEqual(count, 2)
+
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
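The two compilation-cache tests pin down the intended jit behavior: a BInt
argument keys the cache on its bound, not its value, so bumping the value
under a fixed bound reuses the compiled executable, while changing the bound
recompiles. Restated as a toy cache key (an illustration, not the real cache
logic):

def toy_cache_key(val, bound):
  return ('bint', bound)   # the value is deliberately not part of the key

assert toy_cache_key(3, 5) == toy_cache_key(4, 5)   # one compilation
assert toy_cache_key(4, 5) != toy_cache_key(4, 6)   # bound change recompiles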