Skip to content

Commit

Permalink
Add discharge rules for scan/while
Browse files Browse the repository at this point in the history
  • Loading branch information
sharadmv committed Jul 6, 2023
1 parent f08e52f commit c446b42
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 3 deletions.
137 changes: 136 additions & 1 deletion jax/_src/lax/control_flow/loops.py
Expand Up @@ -46,10 +46,12 @@
from jax._src.lax import windowed_reductions
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src import state
from jax._src.state import discharge as state_discharge
from jax._src.numpy.ufuncs import logaddexp
from jax._src.traceback_util import api_boundary
from jax._src.util import (partition_list, safe_map, safe_zip, split_list,
unzip2, weakref_lru_cache)
unzip2, weakref_lru_cache, merge_lists)
import numpy as np

from jax._src.lax.control_flow.common import (
Expand Down Expand Up @@ -1029,6 +1031,61 @@ def _scan_pp_rule(eqn, context, settings):
del printed_params['reverse']
return core._pp_eqn(eqn.replace(params=printed_params), context, settings)

def _scan_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts,
num_carry, linear, unroll, reverse, length):
jaxpr, consts = jaxpr.jaxpr, jaxpr.consts
if consts: raise NotImplementedError
consts, carry, xs = split_list(args, [num_consts, num_carry])
consts_linear, carry_linear, xs_linear = split_list(
linear, [num_consts, num_carry])
consts_avals, carry_avals, xs_avals = split_list(in_avals,
[num_consts, num_carry])
is_ref = [isinstance(a, state.AbstractRef) for a in consts_avals]
remaining_const_avals, in_ref_avals = partition_list(is_ref, consts_avals)
remaining_consts, in_refs = partition_list(is_ref, consts)
remaining_consts_linear, in_refs_linear = partition_list(is_ref, consts_linear)
num_refs = sum(is_ref)
num_extensive_in = len(in_avals) - num_carry - num_consts
num_extensive_out = len(out_avals) - num_carry
num_remaining_consts = num_consts - num_refs
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, ())
if discharged_consts:
raise NotImplementedError("Discharged jaxpr has consts. If you see this, "
"please open an issue at "
"https://github.com/google/jax/issues")
# The discharged jaxpr will have output refs stashed at the end
def wrapped(*refs_and_args):
consts, refs, carry, xs = split_list(refs_and_args, [num_remaining_consts,
num_refs,
num_carry])
consts_with_refs = merge_lists(is_ref, consts, refs)
outs_and_refs = core.eval_jaxpr(discharged_jaxpr, (), *consts_with_refs,
*carry, *xs)
carry, ys, out_refs = split_list(outs_and_refs, [num_carry,
num_extensive_out])
assert len(out_refs) == num_refs
return [*out_refs, *carry, *ys]
new_in_avals = [*remaining_const_avals, *[a.inner_aval for a in in_ref_avals],
*carry_avals,
*[core.mapped_aval(length, 0, a) for a in xs_avals]]
new_jaxpr, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(wrapped),
new_in_avals)
new_linear = (*remaining_consts_linear, *in_refs_linear,
*carry_linear, *xs_linear)
all_out = scan_p.bind(*remaining_consts, *in_refs, *carry, *xs,
jaxpr=core.ClosedJaxpr(new_jaxpr, ()),
length=length,
num_consts=num_remaining_consts,
num_carry=num_refs + num_carry,
unroll=unroll,
reverse=reverse,
linear=new_linear)
refs_out, carry_out, ys_out = split_list(all_out, [num_refs, num_carry])
new_invals = [*merge_lists(is_ref, [None] * num_remaining_consts, refs_out),
*[None] * num_carry, *[None] * num_extensive_in]
assert len(new_invals) == len(in_avals)
return new_invals, [*carry_out, *ys_out]

def scan_bind(*args, **params):
if config.jax_enable_checks:
avals = _map(core.get_aval, args)
Expand All @@ -1054,6 +1111,7 @@ def scan_bind(*args, **params):
pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom
pe.padding_rules[scan_p] = _scan_padding_rule
pe.dce_rules[scan_p] = _scan_dce_rule
state_discharge.register_discharge_rule(scan_p)(_scan_discharge_rule)
# TODO(mattjj,frostig): un-comment this pp rule
# core.pp_eqn_rules[scan_p] = _scan_pp_rule

Expand Down Expand Up @@ -1618,6 +1676,82 @@ def _while_typecheck(_, *in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts,
f'Effects not supported in `while`: {disallowed_effects}')
return body_jaxpr.out_avals, joined_effects

def _while_discharge_rule(in_avals, out_avals, *args, cond_jaxpr, body_jaxpr,
cond_nconsts, body_nconsts):
# TODO(sharadmv): enable supporting state effects in the cond
if any(isinstance(eff, state.RefEffect) for eff in cond_jaxpr.effects):
raise NotImplementedError
cond_consts, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts])
cond_consts_avals, body_consts_avals, carry_avals = split_list(in_avals,
[cond_nconsts,
body_nconsts])
# There shouldn't be any `Ref`s in the `cond` (because of our check above).
assert not any(isinstance(aval, state.AbstractRef) for aval in cond_consts_avals)
is_ref = [isinstance(aval, state.AbstractRef) for aval in body_consts_avals]
remaining_body_consts, refs = partition_list(is_ref, body_consts)
remaining_body_const_avals, ref_avals = partition_list(is_ref,
body_consts_avals)
num_refs = sum(is_ref)
num_remaining_consts = body_nconsts - num_refs
num_carry = len(in_avals) - body_nconsts - cond_nconsts
body_jaxpr, body_jaxpr_consts = body_jaxpr.jaxpr, body_jaxpr.consts
cond_jaxpr, cond_jaxpr_consts = cond_jaxpr.jaxpr, cond_jaxpr.consts
if body_jaxpr_consts:
raise NotImplementedError("Body jaxpr has consts. If you see this error, "
"please open an issue at "
"https://github.com/google/jax/issues")
# body_jaxpr has the signature (*body_consts, *carry) -> carry.
# Some of these body_consts are actually `Ref`s so when we discharge
# them, they also turn into outputs, effectively turning those consts into
# carries. However this doesn't fit the expected signature for the body_jaxpr.
# Therefore we need to rewrite the jaxpr to shuffle around the `Ref`s so that
# they are part of the carry.
discharged_body_jaxpr, discharged_consts = state_discharge.discharge_state(
body_jaxpr, ())
if discharged_consts: raise NotImplementedError

def new_body(*consts_refs_carry):
consts, refs, carry = split_list(
consts_refs_carry, [num_remaining_consts, num_refs])
consts_and_refs = merge_lists(is_ref, consts, refs)
carry_refs = core.eval_jaxpr(discharged_body_jaxpr, (), *consts_and_refs,
*carry)
carry, refs_out = split_list(carry_refs, [num_carry])
return [*refs_out, *carry]
new_body_jaxpr, _, new_body_consts = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(new_body), [*remaining_body_const_avals, *[a.inner_aval for a
in ref_avals],
*carry_avals])
if new_body_consts: raise NotImplementedError

# Since some `Ref`s that were previously consts are now carries, we need to
# deal with them (i.e. ignore them) in the `cond`, so we need to rewrite the
# cond_jaxpr as well.
def new_cond(*consts_refs_carry):
consts, refs, carry = split_list(
consts_refs_carry, [cond_nconsts, num_refs])
del refs # We don't use them here!
return core.eval_jaxpr(cond_jaxpr, cond_jaxpr_consts, *consts, *carry)
new_cond_jaxpr, _, new_cond_consts = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(new_cond), [*cond_consts_avals,
*[a.inner_aval for a in ref_avals],
*carry_avals])
if new_cond_consts: raise NotImplementedError

out = while_p.bind(*cond_consts, *remaining_body_consts, *refs, *carry,
body_jaxpr=core.ClosedJaxpr(new_body_jaxpr, ()),
cond_jaxpr=core.ClosedJaxpr(new_cond_jaxpr, ()),
body_nconsts=num_remaining_consts,
cond_nconsts=cond_nconsts)
refs_out, carry_out = split_list(out, [num_refs])
updated_body_consts = merge_lists(is_ref, [None] * num_remaining_consts,
refs_out)
invals_out = [
*[None] * cond_nconsts,
*updated_body_consts,
*[None] * num_carry]
return invals_out, carry_out

while_p = core.AxisPrimitive('while')
while_p.multiple_results = True
while_p.def_impl(partial(dispatch.apply_primitive, while_p))
Expand All @@ -1631,6 +1765,7 @@ def _while_typecheck(_, *in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts,
pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom
mlir.register_lowering(while_p, _while_lowering)
core.custom_typechecks[while_p] = _while_typecheck
state_discharge.register_discharge_rule(while_p)(_while_discharge_rule)


def _pred_bcast_select_hlo(ctx,
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/state/discharge.py
Expand Up @@ -26,12 +26,12 @@
from jax._src import ad_util
from jax._src import core
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src import source_info_util
from jax._src import tree_util
from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax
from jax._src.lax import slicing as lax_slicing
from jax._src.state.types import AbstractRef, RefEffect
Expand Down
33 changes: 33 additions & 0 deletions tests/state_test.py
Expand Up @@ -1170,6 +1170,39 @@ def false_fun():
with self.assertRaises(NotImplementedError):
jax.grad(f)(3.)

def test_while_with_state_in_body(self):
def f(x, y, z):
@run_state
def body(x_ref):
def cond(i):
return i < y
def body(i):
x_ref[...] += z
return i + 1
lax.while_loop(cond, body, 0)
return body(x)
jaxpr = jax.make_jaxpr(f)(0, 5, 2).jaxpr
self.assertEmpty(jaxpr.effects)
self.assertAllClose(jax.jit(f)(0, 5, 2), 10)
self.assertAllClose(jax.jit(f)(1, 2, 3), 7)

def test_scan_with_state_in_body(self):
def f(x, w, y, zs):
@run_state
def body(refs):
x_ref, w_ref = refs
def body(y, z):
x_ref[...] += y
w_ref[...] += z
return y + 1, ()
lax.scan(body, y, zs)
return body((x, w))
zs = jnp.arange(5)
jaxpr = jax.make_jaxpr(f)(0, 1, 5, zs).jaxpr
self.assertEmpty(jaxpr.effects)
self.assertAllClose(jax.jit(f)(0, 1, 5, zs), (35, 11))
self.assertAllClose(jax.jit(f)(1, 1, 2, zs), (21, 11))

class GeneralRefTest(jtu.JaxTestCase):

def test_unshaped_ref(self):
Expand Down

0 comments on commit c446b42

Please sign in to comment.