[dynamic-shapes] revive basic bounded int machinery, add tests
mattjj committed Jul 7, 2022
1 parent 89a6766 commit 98e71fe
Showing 10 changed files with 343 additions and 118 deletions.
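
The pieces this commit revives fit together roughly like this: a bounded int (core.BInt) is a runtime integer carrying a static upper bound, and an array whose shape mentions a bounded int is physically allocated at the bound and wrapped as a core.PaddedArray. A rough mental model in plain Python, using mock classes rather than the jax internals touched in the diff below (everything here is illustrative, not part of the commit):

# Illustrative sketch only; mock types, not jax internals.
from dataclasses import dataclass
import numpy as np

@dataclass
class BInt:
    """Mock of core.BInt: a runtime value with a static upper bound."""
    val: int     # e.g. 3 at runtime
    bound: int   # e.g. 5; buffers are sized at the bound, never at val

@dataclass
class PaddedArray:
    """Mock of core.PaddedArray: a logical shape (may contain BInt) plus padded data."""
    logical_shape: tuple
    data: np.ndarray  # physically allocated at the bounds

n = BInt(3, bound=5)
x = PaddedArray((n,), np.zeros(5))  # only the first 3 entries are meaningful
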
28 changes: 19 additions & 9 deletions jax/_src/dispatch.py
@@ -299,6 +299,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
if config.jax_dynamic_shapes:
keep_unused = True
has_outfeed = False
donated_invars = [False] * len(fun.in_type)
else:
has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
jaxpr = apply_outfeed_rewriter(jaxpr)
@@ -318,8 +319,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
device = _xla_callable_device(nreps, backend, device, arg_devices)
backend = xb.get_device_backend(device) if device else xb.get_backend(backend)

if (config.jax_dynamic_shapes and jaxpr_has_bints(jaxpr) and
not _backend_supports_unbounded_dynamic_shapes(backend)):
if config.jax_dynamic_shapes and jaxpr_has_bints(jaxpr):
jaxpr, consts = pe.pad_jaxpr(jaxpr, consts)

map(prefetch, itertools.chain(consts, jaxpr_literals(jaxpr)))
@@ -517,6 +517,7 @@ def aval_to_num_buffers(aval: core.AbstractValue) -> int:
num_buffers_handlers[core.ShapedArray] = lambda _: 1
num_buffers_handlers[core.DShapedArray] = lambda _: 1
num_buffers_handlers[core.ConcreteArray] = lambda _: 1
num_buffers_handlers[core.AbstractBInt] = lambda _: 1


def _input_handler(backend: Backend,
@@ -649,26 +650,33 @@ def dynamic_array_result_handler(sticky_device: Optional[Device],
return partial(_dynamic_array_result_handler, sticky_device, aval)

def _dynamic_array_result_handler(sticky_device, aval, env, buf):
if all(type(d) is int for d in aval.shape):
del env
in_env, out_env = env or (None, None)
shape = [in_env[d.val] if type(d) is core.InDBIdx else
out_env[d.val] if type(d) is core.OutDBIdx else d
for d in aval.shape]
if all(type(d) is int for d in shape):
aval = core.ShapedArray(tuple(shape), aval.dtype)
return _maybe_create_array_from_da(buf, aval, sticky_device)
elif any(type(d) is core.BInt for d in shape):
padded_shape = [d.bound if type(d) is core.BInt else d for d in shape]
buf_aval = core.ShapedArray(tuple(padded_shape), aval.dtype, aval.weak_type)
data = _maybe_create_array_from_da(buf, buf_aval, sticky_device)
return core.PaddedArray(aval.update(shape=tuple(shape)), data)
else:
assert env is not None
in_env, out_env = env
shape = [in_env[d.val] if type(d) is core.InDBIdx else
out_env[d.val] if type(d) is core.OutDBIdx else d
for d in aval.shape]
aval = core.ShapedArray(tuple(shape), aval.dtype)
return _maybe_create_array_from_da(buf, aval, sticky_device)



result_handlers: Dict[
Type[core.AbstractValue],
Callable[[Optional[Device], Any], ResultHandler]] = {}
result_handlers[core.AbstractToken] = lambda _, __: lambda _, __: core.token
result_handlers[core.ShapedArray] = array_result_handler
result_handlers[core.DShapedArray] = dynamic_array_result_handler
result_handlers[core.ConcreteArray] = array_result_handler
result_handlers[core.AbstractBInt] = \
lambda _, a: lambda _, b: core.BInt(int(b), a.bound)


def needs_check_special():
@@ -1005,13 +1013,15 @@ def _device_put_token(_, device):
device_put_handlers.update((t, _device_put_array) for t in array_types)
device_put_handlers.update((t, _device_put_scalar) for t in _scalar_types)
device_put_handlers[core.Token] = _device_put_token
device_put_handlers[core.BInt] = lambda x, d: _device_put_scalar(x.val, d)


def _device_put_device_array(x: Union[device_array.DeviceArrayProtocol, device_array._DeviceArray], device: Optional[Device]):
x = _copy_device_array_to_device(x, device)
return (x.device_buffer,)
for t in device_array.device_array_types:
device_put_handlers[t] = _device_put_device_array
device_put_handlers[core.PaddedArray] = lambda x, d: device_put(x._data, d)

def _copy_device_array_to_device(
x: Union[device_array.DeviceArrayProtocol, device_array._DeviceArray],
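The reworked result handler in this file first resolves InDBIdx/OutDBIdx placeholders in the output aval's shape against the input/output environments, and only then decides between a plain array and a PaddedArray. A self-contained sketch of that resolution step, with mock classes standing in for the jax types (illustrative only, not the library code):

# Illustrative sketch only; mock types, not jax internals.
from dataclasses import dataclass

@dataclass
class InDBIdx:   # refers to an entry of the call's input environment
    val: int

@dataclass
class OutDBIdx:  # refers to an entry of the call's output environment
    val: int

def resolve_shape(shape, in_env, out_env):
    # Mirrors the list comprehension in _dynamic_array_result_handler:
    # replace environment references with the dimension values they name.
    return tuple(in_env[d.val] if isinstance(d, InDBIdx) else
                 out_env[d.val] if isinstance(d, OutDBIdx) else d
                 for d in shape)

# An output whose leading axis is the first input-environment entry:
print(resolve_shape((InDBIdx(0), 3), in_env=[2], out_env=[]))  # -> (2, 3)
# If the resolved entry were a BInt rather than 2, the handler would wrap the
# bound-sized buffer as a PaddedArray with this logical shape.
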
4 changes: 2 additions & 2 deletions jax/_src/iree.py
@@ -104,8 +104,8 @@ def block_until_ready(self) -> IreeBuffer:
return self # no async

# overrides repr on base class which expects _value and aval attributes
def __repr__(self):
return f'IreeBuffer({self.to_py()})'
def __repr__(self): return f'IreeBuffer({self.to_py()})'
_value = property(to_py)

class IreeExecutable:

73 changes: 61 additions & 12 deletions jax/_src/lax/lax.py
@@ -1440,6 +1440,7 @@ def unop(result_dtype, accepted_dtypes, name):
weak_type_rule=weak_type_rule)
batching.defvectorized(prim)
masking.defvectorized(prim)
pe.padding_rules[prim] = lambda _, __, x, **kw: [prim.bind(x, **kw)]
return prim
standard_unop = partial(unop, _identity)
_attrgetter = lambda name: lambda x, **kwargs: getattr(x, name)
@@ -1515,6 +1516,7 @@ def naryop(result_dtype, accepted_dtypes, name):
weak_type_rule=weak_type_rule)
batching.defbroadcasting(prim)
masking.defnaryop(prim)
pe.padding_rules[prim] = lambda _, __, *xs, **kw: [prim.bind(*xs, **kw)]
return prim
standard_naryop = partial(naryop, _input_dtype)

@@ -2080,7 +2082,6 @@ def _add_inverse(r, x, y):
ad.primitive_jvps[add_p] = _add_jvp
ad.primitive_transposes[add_p] = _add_transpose
mlir.register_lowering(add_p, partial(_nary_lower_mhlo, mhlo.AddOp))
pe.padding_rules[add_p] = lambda _, __, x, y: [add(x, y)]

def _sub_jvp(primals, tangents):
x, y = primals
@@ -2110,7 +2111,6 @@ def _sub_transpose(t, x, y):
ad.primitive_jvps[sub_p] = _sub_jvp
ad.primitive_transposes[sub_p] = _sub_transpose
mlir.register_lowering(sub_p, partial(_nary_lower_mhlo, mhlo.SubOp))
pe.padding_rules[sub_p] = lambda _, __, x, y: [sub(x, y)]


def _mul_transpose(ct, x, y):
@@ -2137,7 +2137,6 @@ def _mul_inverse(r, x, y):
lambda ydot, x, y: mul(x, ydot))
ad.primitive_transposes[mul_p] = _mul_transpose
mlir.register_lowering(mul_p, partial(_nary_lower_mhlo, mhlo.MulOp))
pe.padding_rules[mul_p] = lambda _, __, x, y: [mul(x, y)]

def _div_transpose_rule(cotangent, x, y):
assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y)
@@ -2174,7 +2173,6 @@ def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x):
lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
mlir.register_lowering(max_p, partial(_nary_lower_mhlo, mlir.max_mhlo))
pe.padding_rules[max_p] = lambda _, __, x, y: [max(x, y)]

min_p: core.Primitive = standard_naryop([_any, _any], 'min')
ad.defjvp2(min_p,
@@ -2297,9 +2295,13 @@ def _convert_elt_type_pp_rule(eqn, context, settings):
printed_params = {}
if eqn.params['weak_type']:
printed_params['weak_type'] = True
return [pp.text(eqn.primitive.name),
core.pp_kv_pairs(sorted(printed_params.items()), context, settings),
pp.text(" ") + core.pp_vars(eqn.invars, context)]
lhs = core.pp_vars(eqn.outvars, context, print_shapes=settings.print_shapes)
rhs = [pp.text(eqn.primitive.name),
core.pp_kv_pairs(sorted(printed_params.items()), context, settings),
pp.text(" ") + core.pp_vars(eqn.invars, context)]
annotation = (source_info_util.summarize(eqn.source_info)
if settings.source_info else None)
return [lhs, pp.text(" = ", annotation=annotation), *rhs]


convert_element_type_p = Primitive('convert_element_type')
@@ -2756,7 +2758,7 @@ def _broadcast_in_dim_padding_rule(in_avals, out_avals, x, *dyn_shape,
assert isinstance(d, core.Tracer)
new_shape.append(None)
new_dyn_shape.append(d)
return [broadcast_in_dim_p.bind(x, *new_dyn_shape, shape=new_shape,
return [broadcast_in_dim_p.bind(x, *new_dyn_shape, shape=tuple(new_shape),
broadcast_dimensions=broadcast_dimensions)]

def _broadcast_in_dim_jvp_rule(primals, tangents, *, shape, broadcast_dimensions):
@@ -2820,8 +2822,18 @@ def _broadcast_in_dim_pp_rule(eqn, context, settings):
if settings.source_info else None)
return [lhs, pp.text(" = ", annotation=annotation), *rhs]

def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions):
if not any(isinstance(d, core.BInt) for d in shape):
shape = _broadcast_in_dim_shape_rule( # error checking
x, shape=shape, broadcast_dimensions=broadcast_dimensions)
return core.ShapedArray(shape, x.dtype, x.weak_type, x.named_shape)
# If any BInts in shape, produce a DShapedArray (even if x is a ShapedArray)
# TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
return core.DShapedArray(shape, x.dtype, x.weak_type)

broadcast_in_dim_p = standard_primitive(
_broadcast_in_dim_shape_rule, _input_dtype, 'broadcast_in_dim')
broadcast_in_dim_p.def_abstract_eval(_broadcast_in_dim_abstract_eval)
ad.primitive_jvps[broadcast_in_dim_p] = _broadcast_in_dim_jvp_rule
ad.primitive_transposes[broadcast_in_dim_p] = _broadcast_in_dim_transpose_rule
batching.primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule
@@ -3605,8 +3617,8 @@ def _reduce_sum_padding_rule(in_avals, out_avals, operand, *, axes):

def _replace_masked_values(x, val, padded_axes):
if not padded_axes: return x
masks = [broadcasted_iota(np.dtype('int32'), x.shape, i) < d
for i, d in padded_axes]
dtype = dtypes._scalar_type_to_dtype(int)
masks = [broadcasted_iota(dtype, x.shape, i) < d for i, d in padded_axes]
return select(_reduce(operator.and_, masks), x, full_like(x, val))


@@ -4384,7 +4396,10 @@ def _iota_abstract_eval(*, dtype, shape, dimension):
if not 0 <= dimension < len(shape):
raise ValueError("iota dimension must be between 0 and len(shape), got "
f"dimension={dimension} for shape {shape}")
return ShapedArray(shape, dtype)
if not any(isinstance(d, core.BInt) for d in shape):
return ShapedArray(shape, dtype)
# TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
return core.DShapedArray(shape, dtype, False)

iota_p = Primitive('iota')
iota_p.def_impl(partial(xla.apply_primitive, iota_p))
@@ -4425,12 +4440,46 @@ def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension):
mlir.i64_attr(dimension)).results
mlir.register_lowering(iota_p, _iota_lower)

def _iota_pp_rule(eqn, context, settings):
printed_params = {}
if len(eqn.params['shape']) > 1:
printed_params['dimension'] = eqn.params['dimension']
lhs = core.pp_vars(eqn.outvars, context, print_shapes=settings.print_shapes)
rhs = [pp.text(eqn.primitive.name),
core.pp_kv_pairs(sorted(printed_params.items()), context, settings),
pp.text(" ") + core.pp_vars(eqn.invars, context)]
annotation = (source_info_util.summarize(eqn.source_info)
if settings.source_info else None)
return [lhs, pp.text(" = ", annotation=annotation), *rhs]
# core.pp_eqn_rules[iota_p] = _iota_pp_rule

def _iota_padding_rule(in_avals, out_avals, *dyn_shape, dtype, shape, dimension):
out_aval, = out_avals
new_shape = []
new_dyn_shape = []
for d in out_aval.shape:
if type(d) is pe.BoundedAxisSize:
new_shape.append(d.bound)
elif type(d) is int:
new_shape.append(d)
else:
assert isinstance(d, core.Tracer)
new_shape.append(None)
new_dyn_shape.append(d)
return [iota_p.bind(*new_dyn_shape, shape=tuple(new_shape),
dtype=dtype, dimension=dimension)]
pe.padding_rules[iota_p] = _iota_padding_rule


def make_bint(i, bd: int):
return bint_p.bind(i, bd=bd)

bint_p = core.Primitive('bint')

@bint_p.def_impl
def _bint_impl(i, *, bd):
return core.BInt(i, bd)

@bint_p.def_abstract_eval
def bint_abstract_eval(_, *, bd: int):
return core.AbstractBInt(bound=bd)
Expand Down Expand Up @@ -4566,7 +4615,7 @@ def _check_shapelike(fun_name, arg_name, obj, non_zero_shape=False):
if not len(obj): # pylint: disable=g-explicit-length-test
return
if (config.jax_dynamic_shapes and isinstance(obj, (tuple, list)) and
any(isinstance(d, core.Tracer) for d in obj)):
any(isinstance(d, (core.Tracer, core.BInt)) for d in obj)):
return # TODO(mattjj): handle more checks in the dynamic shape case
obj_arr = np.array(obj)
if obj_arr.ndim != 1:
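The per-primitive padding lambdas deleted above (add, sub, mul, max) are subsumed by the generic registrations in unop/naryop: elementwise primitives can simply be re-applied to bound-padded operands, whereas reductions such as reduce_sum must neutralize the padding first, which is what the updated _replace_masked_values does with an iota comparison. A small numpy illustration of both points (illustrative only, not the jax rules themselves):

# Illustrative sketch only; toy registry, not pe.padding_rules.
import numpy as np

padding_rules = {}  # op name -> rule over bound-padded operands

def elementwise_rule(op):
    # Elementwise ops commute with padding, so the rule just re-applies the op
    # to the padded buffers; the padded region stays garbage, which is fine.
    def rule(in_avals, out_avals, *padded, **params):
        del in_avals, out_avals
        return [op(*padded, **params)]
    return rule

padding_rules['add'] = elementwise_rule(np.add)
padding_rules['mul'] = elementwise_rule(np.multiply)

def masked_sum(padded, logical_size):
    # Reductions must ignore padded entries: build an iota < size mask, in the
    # spirit of _replace_masked_values, and zero out the padding before summing.
    mask = np.arange(padded.shape[0]) < logical_size
    return np.where(mask, padded, 0).sum()

x = np.array([1., 2., 3., 9., 9.])      # bound 5, logical size 3, garbage tail
print(masked_sum(x, logical_size=3))    # 6.0
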
2 changes: 1 addition & 1 deletion jax/_src/lax/slicing.py
@@ -903,7 +903,7 @@ def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes):
dynamic_slice_p = standard_primitive(
_dynamic_slice_shape_rule, _dynamic_slice_dtype_rule, 'dynamic_slice',
weak_type_rule=_argnum_weak_type(0))
ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp # TODO
ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp
ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule
batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule

4 changes: 4 additions & 0 deletions jax/_src/numpy/lax_numpy.py
@@ -2101,8 +2101,12 @@ def arange(start: core.DimSize, stop: Optional[core.DimSize]=None,
dtype = _jnp_dtype(dtype)
if stop is None and step is None:
if (jax.config.jax_dynamic_shapes and
not isinstance(core.get_aval(start), core.AbstractBInt) and
not isinstance(core.get_aval(start), core.ConcreteArray)):
start = ceil(start).astype(int) # note using jnp here
elif (isinstance(start, core.BInt) or isinstance(start, core.Tracer) and
isinstance(core.get_aval(start), core.AbstractBInt)):
pass
else:
start = require(start, msg("stop"))
start = np.ceil(start).astype(int)
