diff --git a/jax/_src/core.py b/jax/_src/core.py
index 918bf04661da..868806e8dbf8 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -2383,8 +2383,6 @@ def get_bind_params(self, params):
 
 closed_call_p: ClosedCallPrimitive = ClosedCallPrimitive('closed_call')
 closed_call_p.def_impl(call_impl)
-closed_call_p.def_effectful_abstract_eval(
-    lambda *_, call_jaxpr: (call_jaxpr.out_avals, call_jaxpr.effects))
 
 
 outfeed_primitives: set[Primitive] = set()
@@ -2828,7 +2826,7 @@ class JaxprTypeError(TypeError): pass
 
 def _check_closed_call(_, *in_atoms, call_jaxpr):
   in_avals = [x.aval for x in in_atoms]
-  if not all(map(typecompat, call_jaxpr.in_avals, in_avals)):
+  if list(in_avals) != list(call_jaxpr.in_avals):
     raise JaxprTypeError("Closed call in_avals mismatch")
   return call_jaxpr.out_avals, call_jaxpr.effects
 custom_typechecks[closed_call_p] = _check_closed_call
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index b0d6929ba704..cf16383bead0 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -1426,10 +1426,7 @@ def aval_to_types(aval):
         args.append([hlo.create_token()])
       else:
         args.append(arg)
-    if name is not None:
-      callee_name_stack = name_stack.extend(util.wrap_name(name, api_name))
-    else:
-      callee_name_stack = name_stack
+    callee_name_stack = name_stack.extend(util.wrap_name(name, api_name))
     consts = [ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
     out_vals, tokens_out = jaxpr_subcomp(
         ctx, jaxpr.jaxpr, callee_name_stack, tokens_in,
@@ -1886,7 +1883,7 @@ def core_call_lowering(ctx: LoweringRuleContext,
 
 register_lowering(core.call_p, partial(core_call_lowering, name="core_call"))
 register_lowering(core.closed_call_p,
-                  partial(core_call_lowering, name=None))
+                  partial(core_call_lowering, name="core_closed_call"))
 
 def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, *,
                      broadcast_dimensions) -> ir.Value:
diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py
index b228ce9ec450..14a8a1c03606 100644
--- a/jax/_src/lax/control_flow/loops.py
+++ b/jax/_src/lax/control_flow/loops.py
@@ -227,7 +227,7 @@ def scan(f, init, xs, length=None):
     ys = []
     maybe_reversed = reversed if reverse else lambda x: x
     for i in maybe_reversed(range(length)):
-      xs_slice = [slicing.index_in_dim(x, i, keepdims=False) for x in xs_flat]
+      xs_slice = [_index_array(i, core.get_aval(x), x) for x in xs_flat]
      carry, y = f(carry, tree_unflatten(xs_tree, xs_slice))
      ys.append(y)
    stack = lambda *ys: jax.numpy.stack(ys)
@@ -361,68 +361,163 @@ def _aval_mismatch_extra(a1: core.AbstractValue, a2: core.AbstractValue) -> str:
             'the shapes do not match' * shape_mismatch)
   return ''
 
-# TODO(mattjj): re-land #19819 version? simpler, but caused ~1 perf regression.
+
+def _scan_impl_unrolled(*args, reverse, length, num_consts, num_carry, linear,
+                        f_impl, x_avals, y_avals):
+  consts, init, xs = split_list(args, [num_consts, num_carry])
+
+  carry = init
+  ys = []
+
+  for i in range(length):
+    i_ = length - i - 1 if reverse else i
+    x = _map(partial(_index_array, i_), x_avals, xs)
+    out = f_impl(*consts, *carry, *x)
+    carry, y = split_list(out, [num_carry])
+    ys.append(y)
+
+  ys = list(reversed(ys)) if reverse else ys
+  ys = list(zip(*ys))
+  ys = _map(_stack, y_avals, ys)
+  return (*carry, *ys)
+
+def _scan_impl_loop(*args, reverse, length, num_consts, num_carry, linear,
+                    f_impl, x_avals, y_avals):
+  consts, init, xs = split_list(args, [num_consts, num_carry])
+
+  def cond_fun(vals):
+    i, *_ = vals
+    return i < length
+
+  def body_fun(vals):
+    [i], carry, ys = split_list(vals, [1, num_carry])
+    i_ = length - i - 1 if reverse else i
+    # TODO(jakevdp)[key-reuse]: this key reuse logic is not quite right,
+    # because the scan body may consume any keys within it.
+    xs_unconsumed = _map(jax.random.clone, xs)
+    x = _map(partial(_dynamic_index_array, i_), x_avals, xs_unconsumed)
+    out_flat = f_impl(*consts, *carry, *x)
+    carry_out, y_updates = split_list(out_flat, [num_carry])
+    ys_out = _map(partial(_update_array, i_), y_avals, ys, y_updates)
+    return [i + 1] + carry_out + ys_out
+
+  # TODO(jakevdp)[key-reuse]: mark xs consumed here if f_impl consumes them.
+
+  ys_init = _map(partial(_empty_array, length), y_avals)
+  if length == 0:
+    return init + ys_init
+  else:
+    init_val = [lax._const(length, 0)] + init + ys_init
+    _, *outs = while_loop(cond_fun, body_fun, init_val)
+    return outs
+
+def _scan_impl_block_unrolled(*args, reverse, length, num_consts, num_carry,
+                              linear, block_length, f_impl, x_avals, y_avals):
+  consts, init, xs = split_list(args, [num_consts, num_carry])
+
+  num_blocks, rem = divmod(length, block_length)
+  assert rem == 0
+
+  partition = partial(_partition_leading, num_blocks, block_length)
+  xs_block = _map(partition, x_avals, xs)
+
+  prepend_aval = partial(_prepend_dim_to_aval, block_length)
+  x_block_avals = _map(prepend_aval, x_avals)
+  y_block_avals = _map(prepend_aval, y_avals)
+
+  f_impl_block = partial(
+      _scan_impl_unrolled, reverse=reverse, length=block_length,
+      num_consts=num_consts, num_carry=num_carry, linear=linear,
+      f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
+
+  outs = _scan_impl_loop(
+      *consts, *init, *xs_block, reverse=reverse, length=num_blocks,
+      num_consts=num_consts, num_carry=num_carry, linear=linear,
+      f_impl=f_impl_block, x_avals=x_block_avals, y_avals=y_block_avals)
+
+  carry, ys_blocks = split_list(outs, [num_carry])
+  combine = partial(_combine_leading, num_blocks, block_length)
+  ys = _map(combine, y_avals, ys_blocks)
+  return (*carry, *ys)
+
 def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
                unroll):
-  consts, carry, xs_ = split_list(args, [num_consts, num_carry])
+  _, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
   _, y_avals = split_list(jaxpr.out_avals, [num_carry])
-  num_trips, remainder = divmod(length, unroll)
-  if remainder:
-    if not reverse:
-      xs_, xs_rem = unzip2(_map(partial(_split_leading, num_trips*unroll), xs_))
+  f_impl = core.jaxpr_as_fun(jaxpr)
+
+  if unroll == 1:
+    return _scan_impl_loop(
+        *args, reverse=reverse, length=length, num_consts=num_consts,
+        num_carry=num_carry, linear=linear, f_impl=f_impl, x_avals=x_avals,
+        y_avals=y_avals)
+
+  consts, init, xs = split_list(args, [num_consts, num_carry])
+  num_blocks, rem = divmod(length, unroll)
+  length_div = num_blocks * unroll
+
+  if rem > 0:
+    if reverse:
+      split = partial(_split_leading_dim, rem)
+      xs_rem, xs = unzip2(_map(split, x_avals, xs))
     else:
-      xs_rem, xs_ = unzip2(_map(partial(_split_leading, remainder), xs_))
-  xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_]
-  yss = _map(partial(_empty_array, (num_trips, unroll)), y_avals)
-
-  def cond_fun(while_carry):
-    i, _, _ = while_carry
-    return i < num_trips
-  def body_fun(while_carry):
-    i_, carry, yss = while_carry
-    i = num_trips - i_ - 1 if reverse else i_
-    xs = [slicing.dynamic_index_in_dim(xs, i, keepdims=False) for xs in xss]
-    carry, ys = inner(unroll, carry, xs)
-    yss = [slicing.dynamic_update_index_in_dim(ys, upd, i, 0)
-           for ys, upd in zip(yss, ys)]
-    return i_ + 1, carry, yss
-  def inner(n, carry, xs):
-    ys = []
-    for i_ in range(n):
-      i = n - i_ - 1 if reverse else i_
-      x = [slicing.index_in_dim(x, i, keepdims=False) for x in xs]
-      carry_y = eval_jaxpr_p.bind(*consts, *carry, *x, jaxpr=jaxpr)
-      carry, y = split_list(carry_y, [num_carry])
-      ys.append(y)
-    ys = list(reversed(ys)) if reverse else ys
-    return carry, _map(jax.numpy.stack, zip(*ys))
-
-  if num_trips:
-    i = lax._const(num_trips, 0)
-    _, carry, yss = jax.lax.while_loop(cond_fun, body_fun, (i, carry, yss))
-    ys = [lax.reshape(ys, (num_trips * unroll, *ys.shape[2:])) for ys in yss]
-  if remainder:
-    carry, ys_rem = inner(remainder, carry, xs_rem)
-    ys = _map(_concat, ys, ys_rem) if not reverse else _map(_concat, ys_rem, ys)
-  return [*carry, *ys]
-
-def _split_leading(sz, x):
-  return (slicing.slice_in_dim(x, 0, sz),
-          slicing.slice_in_dim(x, sz, x.shape[0]))
-
-def _concat(a, b): return lax.concatenate([a, b], 0)
-
-def _empty_array(prefix, aval):
-  return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape))
-
-eval_jaxpr_p = core.Primitive('eval_jaxpr')
-eval_jaxpr_p.multiple_results = True
-def _stage_jaxpr(trace, *tracers, jaxpr):
-  params = dict(call_jaxpr=jaxpr)
-  return trace.default_process_primitive(core.closed_call_p, tracers, params)
-pe.custom_staging_rules[eval_jaxpr_p] = _stage_jaxpr
-@eval_jaxpr_p.def_effectful_abstract_eval  # abstract eval only used for jax2tf
-def _stage_jaxpr_abstract_eval(*_, jaxpr): return jaxpr.out_avals, jaxpr.effects
+      split = partial(_split_leading_dim, length_div)
+      xs, xs_rem = unzip2(_map(split, x_avals, xs))
+
+  outs = _scan_impl_block_unrolled(
+      *consts, *init, *xs, reverse=reverse, length=length_div,
+      num_consts=num_consts, num_carry=num_carry, linear=linear,
+      block_length=unroll, f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
+
+  carry, ys = split_list(outs, [num_carry])
+
+  if rem > 0:
+    outs = _scan_impl_unrolled(
+        *consts, *carry, *xs_rem, reverse=reverse, length=rem,
+        num_consts=num_consts, num_carry=num_carry, linear=linear,
+        f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
+    carry, ys_rem = split_list(outs, [num_carry])
+    if reverse:
+      ys = _map(_concatenate, y_avals, ys_rem, ys)
+    else:
+      ys = _map(_concatenate, y_avals, ys, ys_rem)
+
+  return (*carry, *ys)
+
+def _stack(aval, vals):
+  vals = [lax.expand_dims(x, (0,)) for x in vals]
+  return lax.concatenate(vals, 0)
+
+def _concatenate(aval, x1, x2):
+  return lax.concatenate([x1, x2], 0)
+
+def _split_leading_dim(i, aval, x):
+  assert x.ndim >= 1
+  return (slicing.slice_in_dim(x, 0, i),
+          slicing.slice_in_dim(x, i, x.shape[0]))
+
+def _dynamic_index_array(i, aval, x):
+  return slicing.dynamic_index_in_dim(x, i, keepdims=False)
+
+def _index_array(i, aval, x):
+  return slicing.index_in_dim(x, i, keepdims=False)
+
+def _empty_array(sz, aval):
+  return lax.broadcast(lax.empty(aval.dtype), (sz, *aval.shape))
+
+def _update_array(i, aval, xs, x):
+  return slicing.dynamic_update_index_in_dim(xs, x, i, 0)
+
+def _partition_leading(sz0, sz1, aval, x):
+  assert x.ndim >= 1
+  assert x.shape[0] == sz0 * sz1
+  return lax.reshape(x, (sz0, sz1, *x.shape[1:]))
+
+def _combine_leading(sz0, sz1, aval, x):
+  assert x.ndim >= 2
+  assert x.shape[0] == sz0
+  assert x.shape[1] == sz1
+  return lax.collapse(x, 0, 2)
 
 def _prepend_dim_to_aval(sz, aval):
   return core.unmapped_aval(sz, core.no_axis_name, 0, aval)
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index 7f092cd2026c..84d67ae446ff 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -1465,13 +1465,11 @@ def get_primitive_impl(self, p: core.Primitive) -> tuple[Callable, bool]:
 def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
   assert False, f"Encountered unexpected primitive {p}"
 
+
+# Call primitives are inlined
 for unexpected in [core.call_p, maps.xmap_p]:
   tf_impl[unexpected] = partial(_unexpected_primitive, unexpected)
 
-tf_impl[lax_control_flow.loops.eval_jaxpr_p] = \
-  lambda *args, jaxpr: _interpret_jaxpr(
-    jaxpr, *args, fresh_constant_cache=False, extra_name_stack=None)
-
 # Primitives that are not yet implemented must be explicitly declared here.
 tf_not_yet_impl = [
     "clz",
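
Not part of the patch above, and only a hedged illustration: the reworked `_scan_impl` handles `unroll > 1` by running a `while_loop` over blocks of `unroll` steps and then an unrolled remainder when `unroll` does not divide `length` (the `rem > 0` branches). The sketch below checks the user-visible behaviour this has to preserve, using only the public `jax.lax.scan` API; the function and variable names are illustrative, not taken from the patch.

import jax
import jax.numpy as jnp

def body(carry, x):
  # Running sum carried through the scan; also emit each partial sum.
  new_carry = carry + x
  return new_carry, new_carry

xs = jnp.arange(10, dtype=jnp.float32)  # 10 is not a multiple of unroll=4,
                                        # so the remainder path is exercised
init = jnp.float32(0.0)
carry, ys = jax.lax.scan(body, init, xs, unroll=4)

# Reference: the same computation as an ordinary Python loop.
ref_carry, ref_ys = init, []
for x in xs:
  ref_carry = ref_carry + x
  ref_ys.append(ref_carry)

assert jnp.allclose(carry, ref_carry)
assert jnp.allclose(ys, jnp.stack(ref_ys))

Whatever the unroll factor, the scan result should match the plain loop; the block/remainder decomposition is an implementation detail of the lowering path touched by this patch.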