Skip to content

Commit

Permalink
Add sound loop invariance detection
Browse files Browse the repository at this point in the history
  • Loading branch information
sharadmv committed Sep 8, 2022
1 parent 2ccd724 commit 6967c7e
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 48 deletions.
88 changes: 41 additions & 47 deletions jax/_src/lax/control_flow/for_loop.py
Expand Up @@ -15,7 +15,7 @@
from functools import partial
import operator

from typing import Any, Callable, Generic, List, Optional, Sequence, Tuple, TypeVar
from typing import Any, Callable, Generic, List, Optional, Sequence, Set, Tuple, TypeVar

from jax import core
from jax import lax
Expand Down Expand Up @@ -54,6 +54,7 @@ class Ref(Generic[T]): pass
ReadEffect = state.ReadEffect
WriteEffect = state.WriteEffect
AccumEffect = state.AccumEffect
StateEffect = state.StateEffect
ShapedArrayRef = state.ShapedArrayRef
ref_set = state.ref_set
ref_get = state.ref_get
Expand Down Expand Up @@ -222,6 +223,11 @@ def for_body(i, refs):
init, _, ys = for_loop(length, for_body, (init, xs, ys), reverse=reverse)
return init, ys

def _get_ref_state_effects(jaxpr: core.Jaxpr) -> List[Set[StateEffect]]:
all_effects = jaxpr.effects
return [{eff for eff in all_effects
if isinstance(eff, (ReadEffect, WriteEffect, AccumEffect))
and eff.ref_aval is v.aval} for v in jaxpr.invars]

@for_p.def_abstract_eval
def _for_abstract_eval(*avals, jaxpr, **__):
Expand Down Expand Up @@ -317,6 +323,38 @@ def _partial_eval_jaxpr_custom(jaxpr, in_unknowns, policy):

_save_everything = lambda *_, **__: True

def _is_read_only(ref_effects: Set[StateEffect]) -> bool:
assert len(ref_effects) > 0
if len(ref_effects) > 1:
# Means we must have a write or accum effect so not read-only
return False
eff, = ref_effects
return isinstance(eff, ReadEffect)

def _loop_invariant_outputs(jaxpr: core.Jaxpr) -> List[bool]:
# Get effects for each of the jaxpr inputs and remove the loop index.
ref_effects = _get_ref_state_effects(jaxpr)[1:]
# We first assume that *read-only `Ref`s* are loop-invariant. We can safely do
# this because the only way something can be loop-varying is if we write to it
# at some point. It's *possible* that read-write `Ref`s are loop-invariant but
# we conservatively assume they aren't.
loop_invar_refs = [_is_read_only(effs) if effs else True
for effs in ref_effects]
loop_var_refs = map(operator.not_, loop_invar_refs)

# We'd like to detect if the outputs of the jaxpr are loop-invariant. An
# output is loop-invariant if it is downstream of only loop-invariant values
# (seeded by the read-only `Ref`s). If at any point, a loop-varying value
# interacts with a loop-invariant value, we produce a loop-varying value. We
# can use `partial_eval` to perform this analysis by treating loop-varying
# values as "unknown" and loop-invariant values as "known", since when a known
# and unknown value interact, they produce an unknown value.
loop_var_inputs = [True, *loop_var_refs]
_, _, loop_var_outputs, _, _, = _partial_eval_jaxpr_custom(
jaxpr, loop_var_inputs, _save_everything)
return map(operator.not_, loop_var_outputs)


def _for_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
jaxpr: core.Jaxpr, nsteps: int, reverse: bool,
which_linear: Tuple[bool, ...]) -> List[pe.JaxprTracer]:
Expand Down Expand Up @@ -371,29 +409,7 @@ def _for_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
# dependent on the loop index. If a residual is not dependent on the loop
# index, we don't need add an extra loop dimension we're reading from when we
# convert it from an output into a write.

# In order to detect which residuals are loop-invariant, we need to run a
# fixpoint. This is because the residual could be dependent on a `Ref` that
# changes each iteration of the loop so we need to first detect which `Ref`s
# are loop-varying. We can do this by discharging the state from the jaxpr and
# running partial_eval with initially only the loop-index being loop-varying.
# The fixpoint will eventually propagate the loop-varying-ness over the
# inputs/outputs and we will converge.
loop_var_res = [False] * len(jaxpr_known_resout.outvars)
loop_var_refs = [False] * (len(jaxpr_known_resout.invars) - 1)
discharged_jaxpr_known_resout = core.ClosedJaxpr(
*discharge_state(jaxpr_known_resout, ()))
for _ in range(len(discharged_jaxpr_known_resout.jaxpr.invars)):
(_, _, loop_var_outputs, _) = pe.partial_eval_jaxpr_nounits(
discharged_jaxpr_known_resout, [True] + loop_var_refs, False)
loop_var_res, loop_var_refs_ = split_list(
loop_var_outputs, [len(loop_var_res)])
if loop_var_refs == loop_var_refs_:
break
loop_var_refs = map(operator.or_, loop_var_refs, loop_var_refs_)
# Now that the fixpoint is complete, we know which residuals are
# loop-invariant.
loop_invar_res = map(operator.not_, loop_var_res)
loop_invar_res = _loop_invariant_outputs(jaxpr_known_resout)

jaxpr_known, res_avals = _convert_outputs_to_writes(nsteps,
jaxpr_known_resout,
Expand Down Expand Up @@ -504,29 +520,7 @@ def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn):
# dependent on the loop index. If a residual is not dependent on the loop
# index, we don't need add an extra loop dimension we're reading from when we
# convert it from an output into a write.

# In order to detect which residuals are loop-invariant, we need to run a
# fixpoint. This is because the residual could be dependent on a `Ref` that
# changes each iteration of the loop so we need to first detect which `Ref`s
# are loop-varying. We can do this by discharging the state from the jaxpr and
# running partial_eval with initially only the loop-index being loop-varying.
# The fixpoint will eventually propagate the loop-varying-ness over the
# inputs/outputs and we will converge.
loop_var_res = [False] * len(jaxpr_known_resout.outvars)
loop_var_refs = [False] * (len(jaxpr_known_resout.invars) - 1)
discharged_jaxpr_known_resout = core.ClosedJaxpr(
*discharge_state(jaxpr_known_resout, ()))
for _ in range(len(discharged_jaxpr_known_resout.jaxpr.invars)):
(_, _, loop_var_outputs, _) = pe.partial_eval_jaxpr_nounits(
discharged_jaxpr_known_resout, [True] + loop_var_refs, False)
loop_var_res, loop_var_refs_ = split_list(
loop_var_outputs, [len(loop_var_res)])
if loop_var_refs == loop_var_refs_:
break
loop_var_refs = map(operator.or_, loop_var_refs, loop_var_refs_)
# Now that the fixpoint is complete, we know which residuals are
# loop-invariant.
loop_invar_res = map(operator.not_, loop_var_res)
loop_invar_res = _loop_invariant_outputs(jaxpr_known_resout)

jaxpr_known, res_avals = _convert_outputs_to_writes(nsteps,
jaxpr_known_resout,
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/state/__init__.py
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Module for state."""
from jax._src.state.types import (ShapedArrayRef, ReadEffect, WriteEffect,
AccumEffect)
AccumEffect, StateEffect)
from jax._src.state.primitives import (ref_get, ref_set, ref_swap,
ref_addupdate, get_p, swap_p,
addupdate_p)
Expand Down
2 changes: 2 additions & 0 deletions jax/_src/state/types.py
Expand Up @@ -69,6 +69,8 @@ class AccumEffect(RefEffect):
def __str__(self):
return f"Accum<{self.ref_aval}>"

StateEffect = Union[ReadEffect, WriteEffect, AccumEffect]

# ## `Ref`s

# We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs.
Expand Down
23 changes: 23 additions & 0 deletions tests/lax_control_flow_test.py
Expand Up @@ -2853,6 +2853,29 @@ def f(a, b):
np.testing.assert_allclose(actual_tangents[0], expected_tangents[0])
np.testing.assert_allclose(actual_tangents[1], expected_tangents[1])

def body2(_, refs):
# Here we use `i_ref` as a loop counter
a_ref, b_ref, c_ref, i_ref = refs
i = i_ref[()]
a = a_ref[i]
b = b_ref[()]
x = jnp.sin(a)
b_ref[()] = jnp.sin(b * x)
c_ref[i] = x * b
i_ref[()] = i + 1

def g(a, b):
c = jnp.zeros_like(a)
_, b, c, _ = for_impl(5, body2, (a, b, c, 0))
return b, c
a = jnp.arange(5.) + 1.
b = 1.
_, g_lin = jax.linearize(f, a, b)
expected_tangents = g_lin(a, b)
_, actual_tangents = jax.jvp(g, (a, b), (a, b))
np.testing.assert_allclose(actual_tangents[0], expected_tangents[0])
np.testing.assert_allclose(actual_tangents[1], expected_tangents[1])

@parameterized.named_parameters(
{"testcase_name": "_f={}_nsteps={}_impl={}".format(
for_body_name, nsteps, impl_name),
Expand Down

0 comments on commit 6967c7e

Please sign in to comment.