
[shape_poly] DynamicReduceWindow: add shape poly support with native serialization.

PiperOrigin-RevId: 541216925
gnecula authored and jax authors committed Jun 17, 2023
1 parent b9c0658 commit b83e6fb
Showing 8 changed files with 492 additions and 89 deletions.
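
A minimal usage sketch of the feature this commit enables (not part of the diff; the function, shapes, and polymorphic spec below are illustrative, and running it assumes TensorFlow is installed): under jax2tf native serialization with shape polymorphism, a reduce_window whose window attributes depend on a symbolic dimension can now be lowered, via the stablehlo.dynamic_reduce_window custom call introduced below.

import numpy as np
import jax.numpy as jnp
from jax import lax
from jax.experimental import jax2tf

def full_width_max(x):
  # The window size depends on the (symbolic) width of x, so a static
  # ReduceWindowOp cannot express it.
  return lax.reduce_window(x, -jnp.inf, lax.max,
                           window_dimensions=(1, x.shape[1]),
                           window_strides=(1, 1), padding="VALID")

f_tf = jax2tf.convert(full_width_max, polymorphic_shapes=["(b, w)"],
                      native_serialization=True)
print(np.asarray(f_tf(np.ones((2, 10), np.float32))).shape)  # (2, 1)
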
68 changes: 67 additions & 1 deletion jax/_src/interpreters/mlir.py
@@ -1956,6 +1956,7 @@ def custom_call(
backend_config: Optional[str] = None,
has_side_effect: bool = False,
result_shapes: Optional[Sequence[ir.Value]] = None,
called_computations: Sequence[str] = (),
api_version: int = 2,
) -> ir.Operation:
"""Wraps a hlo.CustomCall.
@@ -1964,14 +1965,16 @@ def custom_call(
result_shapes: tensors that represent the result shapes, to be used when
the results have dynamic shapes. If not-None, its length must match the
number of the results.
called_computations: the list of function names called by the custom call.
"""
attributes = dict(
call_target_name=ir.StringAttr.get(call_target_name),
has_side_effect=ir.BoolAttr.get(has_side_effect),
backend_config=ir.StringAttr.get(
"" if backend_config is None else backend_config),
api_version=i32_attr(api_version),
called_computations=ir.ArrayAttr.get([]),
called_computations=ir.ArrayAttr.get([
ir.FlatSymbolRefAttr.get(name) for name in called_computations]),
)

if result_shapes is not None:
@@ -1985,6 +1988,69 @@

return hlo.CustomCallOp.build_generic(results=out_types, operands=operands, attributes=attributes)

def reduce_window(
ctx: LoweringRuleContext, *,
# Base name to be used for the reducer function
reducer_name: str,
# Compute the reducer body given the reducer.
reducer_body: Callable[[ir.Block], Sequence[ir.Value]],
operands: Sequence[ir.Value],
init_values: Sequence[ir.Value],
init_values_avals: Sequence[core.AbstractValue],
out_avals: Sequence[core.AbstractValue],
window_dimensions, window_strides, padding, base_dilation, window_dilation):
"""Builds a ReduceWindowOp, with support for dynamic shapes."""

scalar_types = [aval_to_ir_type(aval) for aval in init_values_avals]
if any(not core.is_constant_shape(s)
for s in [window_dimensions, window_dilation, window_strides, base_dilation, *padding]):
# d_padding will be an array i32[N, 2] with pad_lo and pad_hi for each
# spatial dimension.
int2d = aval_to_ir_type(core.ShapedArray((1, 2), np.int32))
def prep_one_pad(pad_lo_hi: Tuple[core.DimSize, core.DimSize]):
pads = shape_tensor(eval_dynamic_shape(ctx, pad_lo_hi)) # i32[2]
return hlo.ReshapeOp(int2d, pads)
d_padding = hlo.ConcatenateOp(list(map(prep_one_pad, padding)),
i64_attr(0)).result
# Build the reducer
reducer_type = ir.FunctionType.get(scalar_types + scalar_types,
scalar_types)
with ir.InsertionPoint.at_block_begin(ctx.module_context.module.body):
reducer = func_dialect.FuncOp(reducer_name, reducer_type)
ctx.module_context.symbol_table.insert(reducer)
entry_block = reducer.add_entry_block()
with ir.InsertionPoint(entry_block):
res = reducer_body(entry_block)
hlo.ReturnOp(res)

rw = custom_call(
"stablehlo.dynamic_reduce_window",
list(map(aval_to_ir_type, out_avals)),
[
*operands, *init_values,
shape_tensor(eval_dynamic_shape(ctx, window_dimensions)),
shape_tensor(eval_dynamic_shape(ctx, window_strides)),
shape_tensor(eval_dynamic_shape(ctx, base_dilation)),
shape_tensor(eval_dynamic_shape(ctx, window_dilation)),
d_padding],
called_computations=[reducer.name.value],
)
else: # Static shapes
rw = hlo.ReduceWindowOp(
list(map(aval_to_ir_type, out_avals)),
operands, init_values,
dense_int_elements(window_dimensions),
window_strides=dense_int_elements(window_strides),
base_dilations=dense_int_elements(base_dilation),
window_dilations=dense_int_elements(window_dilation),
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
shape=(len(padding), 2)))
reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
with ir.InsertionPoint(reducer):
res = reducer_body(reducer)
hlo.ReturnOp(res)
return rw.results


def refine_polymorphic_shapes(module: ir.Module) -> ir.Module:
"""Refine the polymorphic shapes inside a module.
4 changes: 4 additions & 0 deletions jax/_src/lax/control_flow/loops.py
@@ -2037,6 +2037,10 @@ def cumred_gpu_impl(window_reduce: Callable, reduce_fn: Callable, x, *,
# On small inputs reduce_window is faster being a single fusion,
# but on larger ones is slower because of O(n^2) complexity.
# This conservative value of the threshold was obtained via benchmarking.
if not core.is_constant_dim(x.shape[axis]):
raise NotImplementedError(
"associative scan reductions not implemented with shape polymorphism "
"and native serialization on GPU")
if x.shape[axis] > 32:
return associative_scan(reduce_fn, x, reverse=reverse, axis=axis)
return cumred_reduce_window_impl(window_reduce, x, axis=axis, reverse=reverse)
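
A hedged sketch of the constraint added above (not part of the diff): on GPU the cumulative-reduction lowering may switch to associative_scan for larger inputs, which is not supported when the scanned axis is symbolic. Keeping that axis concrete, as in the illustrative spec below, avoids the NotImplementedError while the batch dimension stays symbolic.

import jax.numpy as jnp
from jax.experimental import jax2tf

# Only the batch dimension is symbolic; the scanned axis (size 16) is concrete.
cumsum_tf = jax2tf.convert(lambda x: jnp.cumsum(x, axis=1),
                           polymorphic_shapes=["(b, 16)"],
                           native_serialization=True)
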
102 changes: 50 additions & 52 deletions jax/_src/lax/windowed_reductions.py
@@ -314,26 +314,26 @@ def _generic_reduce_window_lower(ctx, *args, jaxpr, consts,
base_dilation, window_dilation):
operands, init_values = util.split_list(args, [len(args) // 2])
_, init_value_avals = util.split_list(ctx.avals_in, [len(operands)])
scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
rw = hlo.ReduceWindowOp(
map(mlir.aval_to_ir_type, ctx.avals_out),
operands,
init_values,
mlir.dense_int_elements(window_dimensions),
window_strides=mlir.dense_int_elements(window_strides),
base_dilations=mlir.dense_int_elements(base_dilation),
window_dilations=mlir.dense_int_elements(window_dilation),
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
shape=(len(padding), 2)))
reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
with ir.InsertionPoint(reducer):

def reducer_body(reducer: ir.Block) -> Sequence[ir.Value]:
if jaxpr.effects:
raise NotImplementedError('Cannot lower effectful `reduce_window`.')
out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr,
mlir.TokenSet(), consts, *([a] for a in reducer.arguments),
dim_var_values=ctx.dim_var_values)
hlo.ReturnOp(util.flatten(out_nodes))
return rw.results
return util.flatten(out_nodes)

return mlir.reduce_window(
ctx,
reducer_name="generic_reduce_window_reducer",
reducer_body=reducer_body,
operands=operands,
init_values=init_values, init_values_avals=init_value_avals,
out_avals=ctx.avals_out,
window_dimensions=window_dimensions, window_strides=window_strides,
base_dilation=base_dilation, window_dilation=window_dilation,
padding=padding)


mlir.register_lowering(reduce_window_p, _generic_reduce_window_lower)
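
A usage sketch for the generic path above (not part of the diff; names and shapes are illustrative): a reducer that is not one of the recognized add/max/min monoids binds reduce_window_p and therefore goes through _generic_reduce_window_lower, which now delegates to mlir.reduce_window and so also supports the dynamic form.

import jax.numpy as jnp
from jax import lax

def logaddexp_pool(x):
  # jnp.logaddexp is not a recognized monoid, so this uses the generic lowering.
  return lax.reduce_window(x, -jnp.inf, jnp.logaddexp,
                           window_dimensions=(2, 2), window_strides=(2, 2),
                           padding="VALID")

print(logaddexp_pool(jnp.ones((4, 4), jnp.float32)).shape)  # (2, 2)
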

@@ -461,28 +461,26 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,


def _reduce_window_lower(
reduce_op, init_value, ctx, operand, *,
window_dimensions, window_strides, padding, base_dilation, window_dilation):
aval_out, = ctx.avals_out
reduce_op,
init_value, ctx, operand, *,
window_dimensions, window_strides, padding, base_dilation,
window_dilation):

operand_aval, = ctx.avals_in
scalar_aval = operand_aval.update(shape=())
scalar_type = mlir.aval_to_ir_type(scalar_aval)
if any(not core.is_constant_shape(s)
for s in [window_dimensions, window_dilation, window_strides, base_dilation, *padding]):
raise NotImplementedError("ReduceWindowOp for dynamic shapes")
rw = hlo.ReduceWindowOp(
mlir.aval_to_ir_types(aval_out), [operand],
[mlir.full_like_aval(ctx, init_value(scalar_aval.dtype), scalar_aval)],
mlir.dense_int_elements(window_dimensions),
window_strides=mlir.dense_int_elements(window_strides),
base_dilations=mlir.dense_int_elements(base_dilation),
window_dilations=mlir.dense_int_elements(window_dilation),
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
shape=(len(padding), 2)))
reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer):
hlo.ReturnOp(reduce_op(*reducer.arguments))
return rw.results

return mlir.reduce_window(ctx,
reducer_name=f"reduce_window_{scalar_aval.dtype}_reducer",
reducer_body=lambda reducer: reduce_op(*reducer.arguments),
operands=[operand],
init_values=[mlir.full_like_aval(ctx, init_value(scalar_aval.dtype),
scalar_aval)],
init_values_avals=[scalar_aval],
out_avals=ctx.avals_out,
window_dimensions=window_dimensions,
window_strides=window_strides, base_dilation=base_dilation,
window_dilation=window_dilation, padding=padding)


mlir.register_lowering(reduce_window_sum_p, partial(
_reduce_window_lower, hlo.AddOp, lambda _: 0))
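
For comparison, a sketch of the monoid path registered here (not part of the diff; the pooling function below is illustrative): computation=lax.add with a zero init value dispatches to reduce_window_sum_p, which lowers through _reduce_window_lower with hlo.AddOp.

import jax.numpy as jnp
from jax import lax

def sum_pool(x):
  return lax.reduce_window(x, 0., lax.add,
                           window_dimensions=(2, 2), window_strides=(2, 2),
                           padding="VALID")

print(sum_pool(jnp.ones((4, 4), jnp.float32)))  # a 2x2 array of 4.0
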
@@ -685,7 +683,6 @@ def _broadcast_scalar_const(x, aval_out):
word_dtype = lax._UINT_DTYPES[nbits]
double_word_dtype = lax._UINT_DTYPES[nbits * 2]
word_type = mlir.dtype_to_ir_type(word_dtype) # type: ignore
double_word_type = mlir.dtype_to_ir_type(double_word_dtype) # type: ignore
# Packs two values into a double_word_type.
def pack(a, b, ab_aval):
word_type_ab_aval = ab_aval.update(dtype=word_dtype)
@@ -727,7 +724,6 @@ def snd(t, t_aval):
nmant = r_nbits - nexp - 1

double_word_dtype = word_dtype = lax._UINT_DTYPES[nbits]
double_word_type = word_type = mlir.dtype_to_ir_type(word_dtype) # type: ignore

# Packs two values into a double_word_type.
def pack(a, b, ab_aval):
@@ -759,25 +755,27 @@ def snd(t, t_aval):
assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim
init = -np.inf if select_prim is lax.ge_p else np.inf
double_word_out_aval = out_aval.update(dtype=double_word_dtype)
rw = hlo.ReduceWindowOp(
[mlir.aval_to_ir_type(double_word_out_aval)],
pack(operand, tangents, operand_aval),
pack(const(dtype, init), const(dtype, 0), core.ShapedArray((), dtype)),
mlir.dense_int_elements(window_dimensions),
window_strides=mlir.dense_int_elements(window_strides),
base_dilations=mlir.dense_int_elements(base_dilation),
window_dilations=mlir.dense_int_elements(window_dilation),
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
shape=(len(padding), 2)))
scalar_type = ir.RankedTensorType.get([], double_word_type)
reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer):

def reducer_body(reducer: ir.Block) -> Sequence[ir.Value]:
x, y = reducer.arguments
assert select_prim is lax.ge_p or select_prim is lax.le_p
cmp_op = "GE" if select_prim is lax.ge_p else "LE"
out = hlo.SelectOp(mlir.compare_hlo(fst(x), fst(y), cmp_op), x, y)
hlo.ReturnOp(out)
return [snd(rw.result, double_word_out_aval)]
return out

res, = mlir.reduce_window(ctx,
reducer_name="reduce_window_select_and_gather_add",
reducer_body=reducer_body,
operands=[pack(operand, tangents, operand_aval)],
init_values=[pack(const(dtype, init), const(dtype, 0), core.ShapedArray((), dtype))],
init_values_avals=[core.ShapedArray((), double_word_dtype)],
out_avals=[double_word_out_aval],
window_dimensions=window_dimensions,
window_strides=window_strides,
base_dilation=base_dilation,
window_dilation=window_dilation,
padding=padding)
return [snd(res, double_word_out_aval)]

# TODO(phawkins): use this translation rule on all platforms.
def _select_and_gather_add_using_variadic_reducewindow(
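
A sketch of the code path that exercises the reducer above (not part of the diff; the example is illustrative): forward-mode differentiation of max pooling uses _select_and_gather_add, which packs (value, tangent) pairs into double words and selects by comparing the value halves.

import jax
import jax.numpy as jnp
from jax import lax

def max_pool(x):
  return lax.reduce_window(x, -jnp.inf, lax.max,
                           window_dimensions=(2, 2), window_strides=(2, 2),
                           padding="VALID")

x = jnp.arange(16., dtype=jnp.float32).reshape(4, 4)
# The JVP gathers the tangent at the argmax of each window.
primals, tangents = jax.jvp(max_pool, (x,), (jnp.ones_like(x),))
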
5 changes: 5 additions & 0 deletions jax/experimental/jax2tf/impl_no_xla.py
@@ -641,6 +641,11 @@ def _reduce_monoid(operand, window_dimensions, window_strides, padding,
has_only_spatial_dims=has_only_spatial_dims)

def tf_pool(inputs, pooling_type):
if any(not core.is_constant_shape(s) for s in
(window_dimensions, window_strides, dilations)):
raise NotImplementedError(
f"TODO: use tf.nn.pool with dynamic shapes¨{window_dimensions=} "
f" {window_strides=} {dilations=}")
result = tf.nn.pool(
inputs,
window_shape=window_dimensions,
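
A sketch of the path guarded above (not part of the diff; assumes TensorFlow is installed, and the shapes are illustrative): with enable_xla=False, max pooling goes through tf.nn.pool, which the new check limits to concrete window attributes; a symbolic batch dimension is expected to remain fine.

import jax.numpy as jnp
from jax import lax
from jax.experimental import jax2tf

def max_pool_nhwc(x):
  return lax.reduce_window(x, -jnp.inf, lax.max,
                           window_dimensions=(1, 2, 2, 1),
                           window_strides=(1, 2, 2, 1),
                           padding="VALID")

pool_tf = jax2tf.convert(max_pool_nhwc, polymorphic_shapes=["(b, 8, 8, 3)"],
                         enable_xla=False)
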
4 changes: 4 additions & 0 deletions jax/experimental/jax2tf/jax_export.py
@@ -715,6 +715,10 @@ def _check_lowering(lowering) -> None:
# ApproxTopK on TPU
"ApproxTopK",
"tf.call_tf_function", # From jax2tf.call_tf(func, call_tf_graph=True)
# TODO(burmako): maintain backwards compatibility for this, until it
# is upstreamed to StableHLO.
# See https://github.com/openxla/stablehlo/issues/8.
"stablehlo.dynamic_reduce_window",
}

def _check_module(mod: ir.Module, *,
