Skip to content

Commit

Permalink
Reverts 0dde8f7
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 619416732
  • Loading branch information
mattjj authored and jax authors committed Mar 27, 2024
1 parent 0dde8f7 commit fa9f02b
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 71 deletions.
4 changes: 1 addition & 3 deletions jax/_src/core.py
Expand Up @@ -2383,8 +2383,6 @@ def get_bind_params(self, params):

closed_call_p: ClosedCallPrimitive = ClosedCallPrimitive('closed_call')
closed_call_p.def_impl(call_impl)
closed_call_p.def_effectful_abstract_eval(
lambda *_, call_jaxpr: (call_jaxpr.out_avals, call_jaxpr.effects))


outfeed_primitives: set[Primitive] = set()
Expand Down Expand Up @@ -2828,7 +2826,7 @@ class JaxprTypeError(TypeError): pass

def _check_closed_call(_, *in_atoms, call_jaxpr):
in_avals = [x.aval for x in in_atoms]
if not all(map(typecompat, call_jaxpr.in_avals, in_avals)):
if list(in_avals) != list(call_jaxpr.in_avals):
raise JaxprTypeError("Closed call in_avals mismatch")
return call_jaxpr.out_avals, call_jaxpr.effects
custom_typechecks[closed_call_p] = _check_closed_call
Expand Down
7 changes: 2 additions & 5 deletions jax/_src/interpreters/mlir.py
Expand Up @@ -1426,10 +1426,7 @@ def aval_to_types(aval):
args.append([hlo.create_token()])
else:
args.append(arg)
if name is not None:
callee_name_stack = name_stack.extend(util.wrap_name(name, api_name))
else:
callee_name_stack = name_stack
callee_name_stack = name_stack.extend(util.wrap_name(name, api_name))
consts = [ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
out_vals, tokens_out = jaxpr_subcomp(
ctx, jaxpr.jaxpr, callee_name_stack, tokens_in,
Expand Down Expand Up @@ -1886,7 +1883,7 @@ def core_call_lowering(ctx: LoweringRuleContext,

register_lowering(core.call_p, partial(core_call_lowering, name="core_call"))
register_lowering(core.closed_call_p,
partial(core_call_lowering, name=None))
partial(core_call_lowering, name="core_closed_call"))

def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, *,
broadcast_dimensions) -> ir.Value:
Expand Down
213 changes: 154 additions & 59 deletions jax/_src/lax/control_flow/loops.py
Expand Up @@ -227,7 +227,7 @@ def scan(f, init, xs, length=None):
ys = []
maybe_reversed = reversed if reverse else lambda x: x
for i in maybe_reversed(range(length)):
xs_slice = [slicing.index_in_dim(x, i, keepdims=False) for x in xs_flat]
xs_slice = [_index_array(i, core.get_aval(x), x) for x in xs_flat]
carry, y = f(carry, tree_unflatten(xs_tree, xs_slice))
ys.append(y)
stack = lambda *ys: jax.numpy.stack(ys)
Expand Down Expand Up @@ -361,68 +361,163 @@ def _aval_mismatch_extra(a1: core.AbstractValue, a2: core.AbstractValue) -> str:
'the shapes do not match' * shape_mismatch)
return ''

# TODO(mattjj): re-land #19819 version? simpler, but caused ~1 perf regression.

def _scan_impl_unrolled(*args, reverse, length, num_consts, num_carry, linear,
                        f_impl, x_avals, y_avals):
  """Run ``length`` scan steps as a plain Python loop.

  ``args`` is split into ``num_consts`` constants, ``num_carry`` initial carry
  values, and the remaining scanned-over inputs. At each step the inputs are
  indexed along their leading axis, ``f_impl(*consts, *carry, *x)`` is called,
  and its outputs are split into the next carry and that step's ``y`` values.
  When ``reverse`` is true the steps visit indices ``length - 1`` down to ``0``.

  ``linear`` is accepted but not used here.

  Returns:
    A tuple of the final carry values followed by the per-output stacked
    ``y`` values, stacked in index order regardless of ``reverse``.
  """
  consts, carry, xs = split_list(args, [num_consts, num_carry])

  # Visit the leading-axis indices in the order the scan should consume them.
  steps = range(length - 1, -1, -1) if reverse else range(length)

  per_step_ys = []
  for step in steps:
    x_slices = _map(partial(_index_array, step), x_avals, xs)
    step_out = f_impl(*consts, *carry, *x_slices)
    carry, y = split_list(step_out, [num_carry])
    per_step_ys.append(y)

  # Outputs were collected in visiting order; restore index order.
  if reverse:
    per_step_ys.reverse()

  # Transpose from "per step" to "per output" before stacking each output.
  per_output_ys = list(zip(*per_step_ys))
  stacked = _map(_stack, y_avals, per_output_ys)
  return (*carry, *stacked)

def _scan_impl_loop(*args, reverse, length, num_consts, num_carry, linear,
                    f_impl, x_avals, y_avals):
  """Run ``length`` scan steps via ``while_loop``.

  ``args`` is split into ``num_consts`` constants, ``num_carry`` initial carry
  values, and the remaining scanned-over inputs. The loop state is a flat list
  ``[i, *carry, *ys]``; each iteration dynamically indexes the inputs, calls
  ``f_impl(*consts, *carry, *x)``, and writes that step's outputs into the
  preallocated ``ys`` buffers. When ``reverse`` is true, iteration ``i``
  operates on index ``length - i - 1``.

  ``linear`` is accepted but not used here.

  Returns:
    The final carry values followed by the filled ``ys`` buffers. When
    ``length == 0`` the initial carry and the empty buffers are returned
    without running a loop.
  """
  consts, init, xs = split_list(args, [num_consts, num_carry])

  # Output buffers with a leading axis of size `length`, one per y output.
  ys_init = _map(partial(_empty_array, length), y_avals)
  if length == 0:
    return init + ys_init

  def cond_fun(vals):
    # The loop counter is the first element of the state.
    return vals[0] < length

  def body_fun(vals):
    [i], carry, ys = split_list(vals, [1, num_carry])
    idx = length - i - 1 if reverse else i
    # TODO(jakevdp)[key-reuse]: this key reuse logic is not quite right,
    # because the scan body may consume any keys within it.
    xs_unconsumed = _map(jax.random.clone, xs)
    x = _map(partial(_dynamic_index_array, idx), x_avals, xs_unconsumed)
    step_out = f_impl(*consts, *carry, *x)
    carry_out, y_updates = split_list(step_out, [num_carry])
    ys_out = _map(partial(_update_array, idx), y_avals, ys, y_updates)
    return [i + 1, *carry_out, *ys_out]

  # TODO(jakevdp)[key-reuse]: mark xs consumed here if f_impl consumes them.

  init_val = [lax._const(length, 0)] + init + ys_init
  _, *outs = while_loop(cond_fun, body_fun, init_val)
  return outs

def _scan_impl_block_unrolled(*args, reverse, length, num_consts, num_carry,
                              linear, block_length, f_impl, x_avals, y_avals):
  """Run a scan as a loop over blocks, each block unrolled in Python.

  ``length`` must be an exact multiple of ``block_length`` (asserted). The
  scanned-over inputs have their leading axis reshaped from ``length`` to
  ``(num_blocks, block_length)``. ``_scan_impl_loop`` then iterates over the
  ``num_blocks`` axis, using ``_scan_impl_unrolled`` with
  ``length=block_length`` as its per-iteration body. The resulting outputs
  have their two leading axes merged back into one.

  Returns:
    A tuple of the final carry values followed by the ``ys`` outputs.
  """
  consts, init, xs = split_list(args, [num_consts, num_carry])

  num_blocks, leftover = divmod(length, block_length)
  assert leftover == 0

  # Reshape each input's leading axis into (num_blocks, block_length).
  to_blocks = partial(_partition_leading, num_blocks, block_length)
  blocked_xs = _map(to_blocks, x_avals, xs)

  # Avals seen by the outer loop carry an extra leading block_length axis.
  def with_block_dim(aval):
    return _prepend_dim_to_aval(block_length, aval)

  blocked_x_avals = _map(with_block_dim, x_avals)
  blocked_y_avals = _map(with_block_dim, y_avals)

  # One outer-loop iteration runs a whole block as an unrolled inner scan.
  run_block = partial(
      _scan_impl_unrolled, reverse=reverse, length=block_length,
      num_consts=num_consts, num_carry=num_carry, linear=linear,
      f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)

  loop_outs = _scan_impl_loop(
      *consts, *init, *blocked_xs, reverse=reverse, length=num_blocks,
      num_consts=num_consts, num_carry=num_carry, linear=linear,
      f_impl=run_block, x_avals=blocked_x_avals, y_avals=blocked_y_avals)

  carry, blocked_ys = split_list(loop_outs, [num_carry])
  # Merge the (num_blocks, block_length) leading axes back into one.
  from_blocks = partial(_combine_leading, num_blocks, block_length)
  ys = _map(from_blocks, y_avals, blocked_ys)
  return (*carry, *ys)

def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
unroll):
consts, carry, xs_ = split_list(args, [num_consts, num_carry])
_, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
_, y_avals = split_list(jaxpr.out_avals, [num_carry])
num_trips, remainder = divmod(length, unroll)
if remainder:
if not reverse:
xs_, xs_rem = unzip2(_map(partial(_split_leading, num_trips*unroll), xs_))
f_impl = core.jaxpr_as_fun(jaxpr)

if unroll == 1:
return _scan_impl_loop(
*args, reverse=reverse, length=length, num_consts=num_consts,
num_carry=num_carry, linear=linear, f_impl=f_impl, x_avals=x_avals,
y_avals=y_avals)

consts, init, xs = split_list(args, [num_consts, num_carry])
num_blocks, rem = divmod(length, unroll)
length_div = num_blocks * unroll

if rem > 0:
if reverse:
split = partial(_split_leading_dim, rem)
xs_rem, xs = unzip2(_map(split, x_avals, xs))
else:
xs_rem, xs_ = unzip2(_map(partial(_split_leading, remainder), xs_))
xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_]
yss = _map(partial(_empty_array, (num_trips, unroll)), y_avals)

def cond_fun(while_carry):
i, _, _ = while_carry
return i < num_trips
def body_fun(while_carry):
i_, carry, yss = while_carry
i = num_trips - i_ - 1 if reverse else i_
xs = [slicing.dynamic_index_in_dim(xs, i, keepdims=False) for xs in xss]
carry, ys = inner(unroll, carry, xs)
yss = [slicing.dynamic_update_index_in_dim(ys, upd, i, 0)
for ys, upd in zip(yss, ys)]
return i_ + 1, carry, yss
def inner(n, carry, xs):
ys = []
for i_ in range(n):
i = n - i_ - 1 if reverse else i_
x = [slicing.index_in_dim(x, i, keepdims=False) for x in xs]
carry_y = eval_jaxpr_p.bind(*consts, *carry, *x, jaxpr=jaxpr)
carry, y = split_list(carry_y, [num_carry])
ys.append(y)
ys = list(reversed(ys)) if reverse else ys
return carry, _map(jax.numpy.stack, zip(*ys))

if num_trips:
i = lax._const(num_trips, 0)
_, carry, yss = jax.lax.while_loop(cond_fun, body_fun, (i, carry, yss))
ys = [lax.reshape(ys, (num_trips * unroll, *ys.shape[2:])) for ys in yss]
if remainder:
carry, ys_rem = inner(remainder, carry, xs_rem)
ys = _map(_concat, ys, ys_rem) if not reverse else _map(_concat, ys_rem, ys)
return [*carry, *ys]

def _split_leading(sz, x):
return (slicing.slice_in_dim(x, 0, sz),
slicing.slice_in_dim(x, sz, x.shape[0]))

def _concat(a, b): return lax.concatenate([a, b], 0)

def _empty_array(prefix, aval):
return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape))

eval_jaxpr_p = core.Primitive('eval_jaxpr')
eval_jaxpr_p.multiple_results = True
def _stage_jaxpr(trace, *tracers, jaxpr):
params = dict(call_jaxpr=jaxpr)
return trace.default_process_primitive(core.closed_call_p, tracers, params)
pe.custom_staging_rules[eval_jaxpr_p] = _stage_jaxpr
@eval_jaxpr_p.def_effectful_abstract_eval # abstract eval only used for jax2tf
def _stage_jaxpr_abstract_eval(*_, jaxpr): return jaxpr.out_avals, jaxpr.effects
split = partial(_split_leading_dim, length_div)
xs, xs_rem = unzip2(_map(split, x_avals, xs))

outs = _scan_impl_block_unrolled(
*consts, *init, *xs, reverse=reverse, length=length_div,
num_consts=num_consts, num_carry=num_carry, linear=linear,
block_length=unroll, f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)

carry, ys = split_list(outs, [num_carry])

if rem > 0:
outs = _scan_impl_unrolled(
*consts, *carry, *xs_rem, reverse=reverse, length=rem,
num_consts=num_consts, num_carry=num_carry, linear=linear,
f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
carry, ys_rem = split_list(outs, [num_carry])
if reverse:
ys = _map(_concatenate, y_avals, ys_rem, ys)
else:
ys = _map(_concatenate, y_avals, ys, ys_rem)

return (*carry, *ys)

def _stack(aval, vals):
vals = [lax.expand_dims(x, (0,)) for x in vals]
return lax.concatenate(vals, 0)

def _concatenate(aval, x1, x2):
return lax.concatenate([x1, x2], 0)

def _split_leading_dim(i, aval, x):
assert x.ndim >= 1
return (slicing.slice_in_dim(x, 0, i),
slicing.slice_in_dim(x, i, x.shape[0]))

def _dynamic_index_array(i, aval, x):
return slicing.dynamic_index_in_dim(x, i, keepdims=False)

def _index_array(i, aval, x):
return slicing.index_in_dim(x, i, keepdims=False)

def _empty_array(sz, aval):
  """Build a buffer of shape ``(sz, *aval.shape)`` with dtype ``aval.dtype``.

  The buffer is produced by broadcasting a ``lax.empty`` value of that dtype.
  """
  out_shape = (sz, *aval.shape)
  filler = lax.empty(aval.dtype)
  return lax.broadcast(filler, out_shape)

def _update_array(i, aval, xs, x):
return slicing.dynamic_update_index_in_dim(xs, x, i, 0)

def _partition_leading(sz0, sz1, aval, x):
assert x.ndim >= 1
assert x.shape[0] == sz0 * sz1
return lax.reshape(x, (sz0, sz1, *x.shape[1:]))

def _combine_leading(sz0, sz1, aval, x):
assert x.ndim >= 2
assert x.shape[0] == sz0
assert x.shape[1] == sz1
return lax.collapse(x, 0, 2)

def _prepend_dim_to_aval(sz, aval):
  """Return ``core.unmapped_aval`` of ``aval`` for size ``sz`` at axis 0.

  The call is made with ``core.no_axis_name`` as the axis name.
  """
  axis = 0
  return core.unmapped_aval(sz, core.no_axis_name, axis, aval)
Expand Down
6 changes: 2 additions & 4 deletions jax/experimental/jax2tf/jax2tf.py
Expand Up @@ -1465,13 +1465,11 @@ def get_primitive_impl(self, p: core.Primitive) -> tuple[Callable, bool]:
def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
assert False, f"Encountered unexpected primitive {p}"


# Call primitives are inlined
for unexpected in [core.call_p, maps.xmap_p]:
tf_impl[unexpected] = partial(_unexpected_primitive, unexpected)

tf_impl[lax_control_flow.loops.eval_jaxpr_p] = \
lambda *args, jaxpr: _interpret_jaxpr(
jaxpr, *args, fresh_constant_cache=False, extra_name_stack=None)

# Primitives that are not yet implemented must be explicitly declared here.
tf_not_yet_impl = [
"clz",
Expand Down

0 comments on commit fa9f02b

Please sign in to comment.