Skip to content

Commit

Permalink
Add discharge rules for scan with mutable arrays. Move mutable array …
Browse files Browse the repository at this point in the history
…tests to separate file.

Co-authored-by: Matt Johnson <mattjj@google.com>
  • Loading branch information
2 people authored and selamw1 committed May 2, 2024
1 parent 8168eeb commit 9e95e10
Show file tree
Hide file tree
Showing 9 changed files with 311 additions and 202 deletions.
6 changes: 4 additions & 2 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2014,13 +2014,15 @@ def mutable_array(init_val):
return mutable_array_p.bind(init_val)
mutable_array_p = Primitive('mutable_array')

class InternalMutableArray(effects.Effect):
class InternalMutableArrayEffect(effects.Effect):
pass
internal_mutable_array_effect = InternalMutableArrayEffect()
effects.control_flow_allowed_effects.add_type(InternalMutableArrayEffect)

@mutable_array_p.def_effectful_abstract_eval
def mutable_array_abstract_eval(init_aval):
from jax._src.state.types import AbstractRef # type: ignore[import]
return AbstractRef(init_aval), {InternalMutableArray}
return AbstractRef(init_aval), {internal_mutable_array_effect}

@mutable_array_p.def_impl
def _mutable_array_impl(init_val):
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -2046,7 +2046,7 @@ def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts,
assert next(out_layouts_, None) is None
else:
inout_aliases = mut = None
if any(isinstance(e, core.InternalMutableArray) for e in closed_jaxpr.effects):
if any(isinstance(e, core.InternalMutableArrayEffect) for e in closed_jaxpr.effects):
closed_jaxpr = _discharge_internal_refs(closed_jaxpr)

return (closed_jaxpr, inout_aliases, mut, in_shardings, in_layouts,
Expand Down
102 changes: 60 additions & 42 deletions jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@
from jax._src.traceback_util import api_boundary
from jax._src.typing import Array
from jax._src.util import (partition_list, safe_map, safe_zip, split_list,
unzip2, weakref_lru_cache, merge_lists)
split_list_checked, unzip2, weakref_lru_cache,
merge_lists)
import numpy as np

from jax._src.lax.control_flow.common import (
Expand Down Expand Up @@ -1201,57 +1202,74 @@ def _scan_pp_rule(eqn, context, settings):
def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts,
num_carry, linear, unroll, reverse, length,
_split_transpose):
jaxpr, consts = jaxpr.jaxpr, jaxpr.consts
# We're shuffling parameters between three signatures for the scan body:
# jaxpr : (n_consts, n_carry, n_xs) -> (n_carry, n_ys)
# discharged : (n_consts, n_carry, n_xs) -> (n_carry, n_ys, n_ref_consts, n_ref_xs)
# wrapped : (n_val_consts, (n_ref_consts, n_carry), (n_val_xs, n_ref_xs))
# -> ((n_ref_consts, n_carry), (n_ys, n_ref_xs))
# where we partition consts and xs between ref and non-ref versions:
# n_consts = (n_val_consts, n_ref_consts)
# n_xs = (n_val_xs, n_ref_xs)

# avals from jaxpr (i.e. rank-reduced) rather than from caller
jaxpr, in_avals, out_avals, consts = jaxpr.jaxpr, jaxpr.in_avals, jaxpr.out_avals, jaxpr.consts
if consts: raise NotImplementedError
consts, carry, xs = split_list(args, [num_consts, num_carry])
consts_linear, carry_linear, xs_linear = split_list(
linear, [num_consts, num_carry])
consts_avals, carry_avals, xs_avals = split_list(in_avals,
[num_consts, num_carry])
is_ref = [isinstance(a, state.AbstractRef) for a in consts_avals]
remaining_const_avals, in_ref_avals = partition_list(is_ref, consts_avals)
remaining_consts, in_refs = partition_list(is_ref, consts)
remaining_consts_linear, in_refs_linear = partition_list(is_ref, consts_linear)
num_refs = sum(is_ref)
num_extensive_in = len(in_avals) - num_carry - num_consts
num_extensive_out = len(out_avals) - num_carry
num_remaining_consts = num_consts - num_refs
n_consts = num_consts
n_carry = num_carry
n_xs = len(in_avals) - n_consts - n_carry
n_ys = len(out_avals) - n_carry
consts_avals, carry_avals, xs_avals = split_list_checked(in_avals,
[n_consts, n_carry, n_xs])
is_ref_const = [isinstance(a, state.AbstractRef) for a in consts_avals]
assert not any(isinstance(a, state.AbstractRef) for a in carry_avals)
is_ref_xs = [isinstance(a, state.AbstractRef) for a in xs_avals]
n_ref_consts = sum(is_ref_const)
n_val_consts = n_consts - n_ref_consts
n_ref_xs = sum(is_ref_xs)
n_val_xs = n_xs - n_ref_xs
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, ())
if discharged_consts:
raise NotImplementedError("Discharged jaxpr has consts. If you see this, "
"please open an issue at "
"https://github.com/google/jax/issues")
# The discharged jaxpr will have output refs stashed at the end
def wrapped(*refs_and_args):
consts, refs, carry, xs = split_list(refs_and_args, [num_remaining_consts,
num_refs,
num_carry])
consts_with_refs = merge_lists(is_ref, consts, refs)
outs_and_refs = core.eval_jaxpr(discharged_jaxpr, (), *consts_with_refs,
*carry, *xs)
carry, ys, out_refs = split_list(outs_and_refs, [num_carry,
num_extensive_out])
assert len(out_refs) == num_refs
return [*out_refs, *carry, *ys]
new_in_avals = [*remaining_const_avals, *[a.inner_aval for a in in_ref_avals],
*carry_avals,
*[core.mapped_aval(length, 0, a) for a in xs_avals]]
new_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(wrapped), new_in_avals)
new_linear = (*remaining_consts_linear, *in_refs_linear,
*carry_linear, *xs_linear)
all_out = scan_p.bind(*remaining_consts, *in_refs, *carry, *xs,
def wrapped(*wrapped_args):
val_consts, ref_consts_in, carry_in, val_xs, ref_xs_in = split_list_checked(wrapped_args,
[n_val_consts, n_ref_consts, n_carry, n_val_xs, n_ref_xs])
consts = merge_lists(is_ref_const, val_consts, ref_consts_in)
xs = merge_lists(is_ref_xs, val_xs, ref_xs_in)
outs = core.eval_jaxpr(discharged_jaxpr, (), *consts, *carry_in, *xs)
carry_out, ys, ref_consts_out, ref_xs_out = split_list_checked(outs,
[n_carry, n_ys, n_ref_consts, n_ref_xs])
return [*ref_consts_out, *carry_out, *ys, *ref_xs_out]

def arrange_jaxpr_args_for_wrapped(args):
consts, carry_in, xs = split_list_checked(args, [n_consts, n_carry, n_xs])
val_consts, ref_consts_in = partition_list(is_ref_const, consts)
val_xs, ref_xs_in = partition_list(is_ref_xs, xs)
return *val_consts, *ref_consts_in, *carry_in, *val_xs, *ref_xs_in

args_for_wrapped = arrange_jaxpr_args_for_wrapped(args)
linear_for_wrapped = arrange_jaxpr_args_for_wrapped(linear)
avals_for_wrapped = arrange_jaxpr_args_for_wrapped(in_avals)
avals_for_wrapped_no_refs = [aval.inner_aval if isinstance(aval, state.AbstractRef) else aval
for aval in avals_for_wrapped]
new_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(wrapped), avals_for_wrapped_no_refs)
all_out = scan_p.bind(*args_for_wrapped,
jaxpr=core.ClosedJaxpr(new_jaxpr, ()),
length=length,
num_consts=num_remaining_consts,
num_carry=num_refs + num_carry,
num_consts=n_val_consts,
num_carry=n_ref_consts + n_carry,
unroll=unroll,
reverse=reverse,
linear=new_linear, _split_transpose=_split_transpose)
refs_out, carry_out, ys_out = split_list(all_out, [num_refs, num_carry])
new_invals = [*merge_lists(is_ref, [None] * num_remaining_consts, refs_out),
*[None] * num_carry, *[None] * num_extensive_in]
assert len(new_invals) == len(in_avals)
return new_invals, [*carry_out, *ys_out]
linear=linear_for_wrapped, _split_transpose=_split_transpose)
ref_consts_out, carry_out, ys, ref_xs_out = split_list_checked(all_out,
[n_ref_consts, n_carry, n_ys, n_ref_xs])
refs_out_matching_in_avals = [
*merge_lists(is_ref_const, [None] * n_val_consts, ref_consts_out),
*[None] * n_carry,
*merge_lists(is_ref_xs, [None] * n_val_xs, ref_xs_out)]
assert len(refs_out_matching_in_avals) == len(in_avals)
return refs_out_matching_in_avals, [*carry_out, *ys]

def scan_bind(*args, **params):
if config.enable_checks.value:
Expand Down
6 changes: 3 additions & 3 deletions jax/_src/state/discharge.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,10 @@ def _eval_jaxpr_discharge_state(
for eqn in jaxpr.eqns:
if eqn.primitive is core.mutable_array_p:
[invar], [outvar] = eqn.invars, eqn.outvars
init_val = env.read(invar)
env.write(outvar, init_val)
ans = env.read(invar)
refs_to_discharge.add(id(outvar.aval))
elif any(id(v.aval) in refs_to_discharge for v in eqn.invars):
elif (any(id(v.aval) in refs_to_discharge for v in eqn.invars)
or core.internal_mutable_array_effect in eqn.effects ):
if eqn.primitive not in _discharge_rules:
raise NotImplementedError("No state discharge rule implemented for "
f"primitive: {eqn.primitive}")
Expand Down
9 changes: 9 additions & 0 deletions jax/_src/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,15 @@ def split_list(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]:
lists.append(args)
return lists

def split_list_checked(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]:
  """Split `args` into consecutive chunks with the lengths given by `ns`.

  The sizes must account for every element of `args` exactly: no trailing
  remainder chunk is produced, so the result always has `len(ns)` entries.

  Args:
    args: the sequence to split; it is copied into a list first.
    ns: the length of each chunk, in order. Every entry must be non-negative.

  Returns:
    A list of `len(ns)` lists; the i-th list holds the next `ns[i]` elements
    of `args`.

  Raises:
    AssertionError: if any size is negative or the sizes do not sum to
      `len(args)`.
  """
  args = list(args)
  # A negative size can still satisfy the sum check below while slicing
  # relative to the end of the list, which would silently yield wrong chunks.
  assert all(n >= 0 for n in ns), f"chunk sizes must be non-negative: {ns}"
  assert sum(ns) == len(args), (
      f"chunk sizes {ns} sum to {sum(ns)}, expected {len(args)}")
  lists = []
  start = 0
  for n in ns:
    # Walk an offset instead of re-slicing the tail on every iteration.
    lists.append(args[start:start + n])
    start += n
  return lists

def partition_list(bs: Sequence[bool], l: Sequence[T]) -> tuple[list[T], list[T]]:
assert len(bs) == len(l)
lists = [], [] # type: ignore
Expand Down
1 change: 1 addition & 0 deletions jax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
safe_zip as safe_zip,
split_dict as split_dict,
split_list as split_list,
split_list_checked as split_list_checked,
split_merge as split_merge,
subvals as subvals,
toposort as toposort,
Expand Down
5 changes: 5 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1289,6 +1289,11 @@ jax_test(
deps = py_deps("hypothesis"),
)

# Mutable-array tests, kept in their own source file (mutable_array_test.py).
jax_test(
name = "mutable_array_test",
srcs = ["mutable_array_test.py"],
)

jax_test(
name = "for_loop_test",
srcs = ["for_loop_test.py"],
Expand Down

0 comments on commit 9e95e10

Please sign in to comment.