Merge pull request #12361 from sharadmv:for-unroll
PiperOrigin-RevId: 474506368
jax authors committed Sep 15, 2022
2 parents 2fb8695 + 08c5753 commit 9791199
Showing 2 changed files with 62 additions and 28 deletions.
89 changes: 61 additions & 28 deletions jax/_src/lax/control_flow/for_loop.py
@@ -109,7 +109,7 @@ def val_to_ref_aval(x) -> ShapedArrayRef:
 
 def for_loop(nsteps: Union[int, Sequence[int]],
              body: Callable[[Array, Ref[S]], None], init_state: S,
-             *, reverse: bool = False) -> S:
+             *, reverse: bool = False, unroll: int = 1) -> S:
   """A for-loop combinator that allows read/write semantics in the loop body.
 
   `for_loop` is a higher-order function that enables writing loops that can be
@@ -138,18 +138,24 @@ def for_loop(nsteps, body, init_state):
       not return anything.
     init_state: A Pytree of JAX-compatible values used to initialize the `Ref`s
       that will be passed into the for loop body.
+    unroll: A positive int specifying, in the underlying operation of the
+      `for` primitive, how many iterations to unroll within a single iteration
+      of the loop. Higher values may speed up execution time at the cost of
+      longer compilation time.
   Returns:
     A Pytree of values representing the output of the for loop.
   """
+  if unroll < 1:
+    raise ValueError("`unroll` must be a positive integer.")
   if isinstance(nsteps, int):
     nsteps = [nsteps]
   if len(nsteps) > 1:
     outer_step, *rest_steps = nsteps
     def wrapped_body(i, refs):
       vals = tree_map(lambda ref: state.ref_get(ref, ()), refs)
-      vals = for_loop(rest_steps, partial(body, i), vals)
+      vals = for_loop(rest_steps, partial(body, i), vals, unroll=unroll)
       tree_map(lambda ref, val: state.ref_set(ref, (), val), refs, vals)
-    return for_loop(outer_step, wrapped_body, init_state)
+    return for_loop(outer_step, wrapped_body, init_state, unroll=unroll)
   nsteps, = nsteps
   flat_state, state_tree = tree_flatten(init_state)
   state_avals = map(val_to_ref_aval, flat_state)
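
The new `unroll` argument in context — a minimal usage sketch (not part of the diff) of a running sum over the loop's `Ref` state. It uses the same `state.ref_get`/`state.ref_set` helpers the hunk above uses; the `jax._src` import paths are assumptions based on this file's location.

import jax.numpy as jnp
from jax._src import state
from jax._src.lax.control_flow.for_loop import for_loop

def body(i, refs):
  x_ref, acc_ref = refs
  x = state.ref_get(x_ref, (i,))       # read xs[i]
  acc = state.ref_get(acc_ref, ())     # read the running total
  state.ref_set(acc_ref, (), acc + x)  # write the updated total back

xs = jnp.arange(8.)
_, acc = for_loop(8, body, (xs, jnp.zeros(())), unroll=2)  # acc == 28.0
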
@@ -162,7 +168,8 @@ def wrapped_body(i, refs):
   jaxpr = _hoist_consts_to_refs(jaxpr)
   which_linear = (False,) * (len(consts) + len(flat_state))
   out_flat = for_p.bind(*consts, *flat_state, jaxpr=jaxpr, nsteps=int(nsteps),
-                        reverse=reverse, which_linear=which_linear)
+                        reverse=reverse, which_linear=which_linear,
+                        unroll=unroll)
   # Consts are `Ref`s so they are both inputs and outputs. We remove them from
   # the outputs.
   out_flat = out_flat[len(consts):]
@@ -178,10 +185,10 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
          length: Optional[int] = None,
          reverse: bool = False,
          unroll: int = 1) -> Tuple[Carry, Y]:
-  if unroll != 1:
-    raise NotImplementedError("Unroll not implemented")
   if not callable(f):
     raise TypeError("scan: f argument should be a callable.")
+  if unroll < 1:
+    raise ValueError("`unroll` must be a positive integer.")
   xs_flat, xs_tree = tree_flatten(xs)
 
   try:
Expand Down Expand Up @@ -233,7 +240,8 @@ def for_body(i, refs):
tree_map(lambda c_ref, c: ref_set(c_ref, (), c), carry_refs, carry)
tree_map(lambda y_ref, y: ref_set(y_ref, (i,), y), ys_refs, y)
assert isinstance(length, int)
init, _, ys = for_loop(length, for_body, (init, xs, ys), reverse=reverse)
init, _, ys = for_loop(length, for_body, (init, xs, ys), reverse=reverse,
unroll=unroll)
return init, ys

def _get_ref_state_effects(jaxpr: core.Jaxpr) -> List[Set[StateEffect]]:
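
The two scan hunks above replace the old NotImplementedError for unroll != 1 with validation plus forwarding to for_loop. A hedged usage sketch (not part of the diff), assuming this module-level, Ref-based scan mirrors lax.scan's (f, init, xs) calling convention as its signature suggests:

import jax.numpy as jnp
from jax._src.lax.control_flow.for_loop import scan

def cumsum_step(carry, x):
  carry = carry + x
  return carry, carry  # new carry, per-step output

total, partials = scan(cumsum_step, 0., jnp.arange(4.), unroll=2)
# total == 6.0, partials == [0., 1., 3., 6.]
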
@@ -255,33 +263,50 @@ def _for_abstract_eval(*avals, jaxpr, **__):
 @state.register_discharge_rule(for_p)
 def _for_discharge_rule(in_avals, *args: Any, jaxpr: core.Jaxpr,
                         reverse: bool, which_linear: Sequence[bool],
-                        nsteps: int
+                        nsteps: int, unroll: int
                         ) -> Tuple[Sequence[Optional[Any]], Sequence[Any]]:
   out_vals = for_p.bind(*args, jaxpr=jaxpr, reverse=reverse,
-                        which_linear=which_linear, nsteps=nsteps)
+                        which_linear=which_linear, nsteps=nsteps,
+                        unroll=unroll)
   new_invals = []
   for aval, out_val in zip(in_avals, out_vals):
     new_invals.append(out_val if isinstance(aval, ShapedArrayRef) else None)
   return new_invals, out_vals
 
-def _for_impl(*args, jaxpr, nsteps, reverse, which_linear):
+def _for_impl(*args, jaxpr, nsteps, reverse, which_linear, unroll):
   del which_linear
   discharged_jaxpr, consts = discharge_state(jaxpr, ())
+  def body(i, state):
+    i_ = nsteps - i - 1 if reverse else i
+    return core.eval_jaxpr(discharged_jaxpr, consts, i_, *state)
+  return _for_impl_unrolled(body, nsteps, unroll, *args)
+
+def _for_impl_unrolled(body, nsteps, unroll, *args):
+  remainder = nsteps % unroll
+  i = jnp.int32(0)
+  state = list(args)
+
+  for _ in range(remainder):
+    state = body(i, state)
+    i = i + 1
+
   def cond(carry):
     i, _ = carry
     return i < nsteps
-  def body(carry):
+  def while_body(carry):
     i, state = carry
-    i_ = nsteps - i - 1 if reverse else i
-    next_state = core.eval_jaxpr(discharged_jaxpr, consts, i_, *state)
-    return i + 1, next_state
-  _, state = lax.while_loop(cond, body, (jnp.int32(0), list(args)))
+    for _ in range(unroll):
+      state = body(i, state)
+      i = i + 1
+    return i, state
+  _, state = lax.while_loop(cond, while_body, (i, state))
   return state
 
 mlir.register_lowering(for_p, mlir.lower_fun(_for_impl, multiple_results=True))
 for_p.def_impl(partial(xla.apply_primitive, for_p))
 
 def _for_vmap(axis_size, axis_name, main_type, args, dims, *,
-              jaxpr, nsteps, reverse, which_linear):
+              jaxpr, nsteps, reverse, which_linear, unroll):
   init_batched = [d is not batching.not_mapped for d in dims]
   discharged_jaxpr, body_consts = discharge_state(jaxpr, ())
   batched = init_batched
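
The arithmetic behind `_for_impl_unrolled`: the `remainder = nsteps % unroll` peel runs first, after which the remaining step count is an exact multiple of `unroll`, so each `while_loop` trip can apply the body `unroll` times without overshooting `nsteps`. A plain-Python check of the step accounting (not part of the diff):

def steps_executed(nsteps, unroll):
  remainder = nsteps % unroll
  i = remainder                 # steps peeled before the while_loop
  trips = 0
  while i < nsteps:             # mirrors cond(carry)
    i += unroll                 # one trip applies `unroll` body steps
    trips += 1
  return remainder, trips, remainder + trips * unroll

assert steps_executed(10, 3) == (1, 3, 10)  # 1 peeled + 3 trips of 3
assert steps_executed(9, 3) == (0, 3, 9)    # no peel when unroll divides nsteps
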
@@ -303,11 +328,13 @@ def _for_vmap(axis_size, axis_name, main_type, args, dims, *,
       axis_name=axis_name, main_type=main_type)
   batched_jaxpr, () = batched_jaxpr_.jaxpr, batched_jaxpr_.consts  # TODO consts
   out_flat = for_p.bind(*args, jaxpr=batched_jaxpr, nsteps=nsteps,
-                        reverse=reverse, which_linear=which_linear)
+                        reverse=reverse, which_linear=which_linear,
+                        unroll=unroll)
   return out_flat, [0 if b else batching.not_mapped for b in batched]
 batching.axis_primitive_batchers[for_p] = _for_vmap
 
-def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear):
+def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear,
+             unroll):
   nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
   # We need to find out which `Ref`s have nonzero tangents after running the
   # for loop. Ordinarily we do this with a fixed point on the body jaxpr but
@@ -334,7 +361,7 @@ def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear):
   jvp_which_linear = which_linear + (True,) * len(tangents)
   out_flat = for_p.bind(*primals, *tangents, jaxpr=jvp_jaxpr,
                         nsteps=nsteps, reverse=reverse,
-                        which_linear=jvp_which_linear)
+                        which_linear=jvp_which_linear, unroll=unroll)
   # `out_flat` includes constant inputs into the `for_loop` which are converted
   # into outputs as well. We don't care about these in AD so we throw them out.
   out_primals, out_tangents = split_list(out_flat, [len(primals)])
Expand Down Expand Up @@ -388,7 +415,8 @@ def _loop_invariant_outputs(jaxpr: core.Jaxpr) -> List[bool]:

def _for_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
jaxpr: core.Jaxpr, nsteps: int, reverse: bool,
which_linear: Tuple[bool, ...]) -> List[pe.JaxprTracer]:
which_linear: Tuple[bool, ...],
unroll: int) -> List[pe.JaxprTracer]:
num_inputs = len(tracers)
assert num_inputs == len(jaxpr.invars) - 1
in_unknowns = [not t.pval.is_known() for t in tracers]
Expand Down Expand Up @@ -454,7 +482,8 @@ def _for_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
# necessarily okay for general partial eval.
jaxpr_known_which_linear = (False,) * len(jaxpr_known_args)
out_flat = for_p.bind(*jaxpr_known_args, jaxpr=jaxpr_known, nsteps=nsteps,
reverse=reverse, which_linear=jaxpr_known_which_linear)
reverse=reverse, which_linear=jaxpr_known_which_linear,
unroll=unroll)
known_outputs, residuals = split_list(out_flat, [len(known_tracers)])
residuals = map(trace.new_instantiated_const, residuals)

Expand Down Expand Up @@ -495,7 +524,8 @@ def _for_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
eqn = pe.new_eqn_recipe(unknown_inputs, res_ref_unknown_outputs,
for_p, dict(jaxpr=jaxpr_unknown, nsteps=nsteps,
reverse=reverse,
which_linear=which_linear_unknown),
which_linear=which_linear_unknown,
unroll=unroll),
core.no_effects, source)
for t in res_ref_unknown_outputs: t.recipe = eqn
_, unknown_outputs = split_list(res_ref_unknown_outputs, [num_res])
@@ -504,8 +534,8 @@ def _for_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
 pe.custom_partial_eval_rules[for_p] = _for_partial_eval
 
 def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn):
-  jaxpr, nsteps, reverse, which_linear = split_dict(
-      eqn.params, ["jaxpr", "nsteps", "reverse", "which_linear"])
+  jaxpr, nsteps, reverse, which_linear, unroll = split_dict(
+      eqn.params, ["jaxpr", "nsteps", "reverse", "which_linear", "unroll"])
   num_inputs = len(eqn.invars)
   # We first need to run a fixpoint to determine which of the `Ref`s are unknown
   # after running the for loop. However, the jaxpr has no outputs. Instead, we
Expand Down Expand Up @@ -576,7 +606,8 @@ def known(*known_vals):
jaxpr_known_args = [*known_vals, *empty_res]
jaxpr_known_which_linear = (False,) * len(jaxpr_known_args)
return for_p.bind(*jaxpr_known_args, jaxpr=jaxpr_known, nsteps=nsteps,
reverse=reverse, which_linear=jaxpr_known_which_linear)
reverse=reverse, which_linear=jaxpr_known_which_linear,
unroll=unroll)
call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic(
known, [v.aval for v in known_invars])
call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
@@ -590,7 +621,8 @@ def known(*known_vals):
   which_linear_unknown = (False,) * num_res + tuple(which_linear)
   params_staged = dict(eqn.params, jaxpr=jaxpr_staged, reverse=reverse,
                        nsteps=nsteps,
-                       which_linear=which_linear_unknown)
+                       which_linear=which_linear_unknown,
+                       unroll=unroll)
 
   @lu.wrap_init
   def staged(*res_and_refs):
Expand Down Expand Up @@ -689,7 +721,7 @@ def trans(i, *args):
lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
return jaxpr_trans

def _for_transpose(in_cts, *args, jaxpr, nsteps, reverse, which_linear):
def _for_transpose(in_cts, *args, jaxpr, nsteps, reverse, which_linear, unroll):
# if any in_ct is nonzero, we definitely want it in args_ (and the
# corresponding x in args could be an undefined primal, but doesnt have to be)
# for non-res stuff:
@@ -722,7 +754,8 @@ def _for_transpose(in_cts, *args, jaxpr, nsteps, reverse, which_linear):
   assert len(args_) == len(jaxpr_transpose.invars) - 1
   all_outs = for_p.bind(*args_, jaxpr=jaxpr_transpose, nsteps=nsteps,
                         reverse=not reverse,
-                        which_linear=tuple(which_linear_transpose))
+                        which_linear=tuple(which_linear_transpose),
+                        unroll=unroll)
   ct_outs = [ct if ad.is_undefined_primal(x) else None
              for x, ct in zip(args, all_outs)]
   return ct_outs
1 change: 1 addition & 0 deletions tests/for_loop_test.py
@@ -44,6 +44,7 @@ def inner_body(i, _):
     (jax.jit(for_loop.for_loop, static_argnums=(0, 1)), 'jit_for_loop'),
     (remat_of_for_loop, 'remat_for_loop'),
     (nested_for_loop, 'nested_for_loop'),
+    (partial(for_loop.for_loop, unroll=3), 'unrolled_for_loop'),
 ]


