Add a simple form of partial evaluation for while_loop. (google#2497)
The issue that I wanted to fix was that when running grad(while_loop),
the error was a cryptic assertion failure (that all primals are known
after linearization, in ad.py:linearize). I could not figure out how to
detect, before that assertion, that we are doing reverse-mode AD of a
while_loop. So, I implemented a simple form of partial evaluation that
allows the primals to be known after linearization, so that the code
proceeds and can then fail gracefully when it tries to transpose the
while.

This is not a proper implementation of partial evaluation. The known
outputs are computed early, properly. But the unknown outputs are
computed by running the *whole* computation again, including the known
parts.

Fixes issue: google#2129
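
A minimal sketch of the resulting behavior, based on the `while` case of the
`testWhileGradError` test added below: reverse-mode differentiation of a
`while_loop` now fails with a clear error, while linearization succeeds.

import jax
from jax import lax

func = lambda x: lax.while_loop(lambda i: i < 5., lambda i: i + 1., x)

jax.linearize(func, 1.)   # linearization works: the primal outputs are known
try:
  jax.grad(func)(1.)      # reverse mode is still unsupported...
except ValueError as e:
  print(e)                # ...but now fails gracefully with the new error message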
gnecula committed Apr 17, 2020
1 parent aeb0d03 commit 7d716b8
Showing 3 changed files with 149 additions and 24 deletions.
51 changes: 32 additions & 19 deletions jax/interpreters/partial_eval.py
@@ -17,7 +17,7 @@
from collections import namedtuple
import contextlib
import threading
from typing import Callable, Dict, Optional, Sequence, Set, Tuple, Union
from typing import Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union
from weakref import ref

import numpy as onp
@@ -84,6 +84,15 @@ def merge_with_known(self, val: core.Value) -> core.Value:
return known if known is not None else val


# We form Jaxprs using `JaxprTrace` for three distinct purposes:
# (1) to stage program representations completely out of the JAX system
# (e.g. for XLA using jit or pmap). In this case we use the
# `StagingJaxprTrace` subclass.
# (2) to build a representation of a function that may require further JAX
# transformations (e.g. in "initial-style" higher-order primitives, like
# for control flow). In this case we use the `JaxprTrace` class itself.
# (3) to linearize a function for reverse-mode AD. In this case we also use
# the `JaxprTrace` class itself.
class JaxprTrace(Trace):
def pure(self, val):
return self.new_const(val)
@@ -133,6 +142,8 @@ def process_primitive(self, primitive, tracers, params):
return self.default_process_primitive(primitive, tracers, params)

def default_process_primitive(self, primitive, tracers, params):
"""By default, if all the input tracers are known, then execute the primitive
and all the ouputs are known. Otherwise, all the outputs are unknown."""
consts = tuple(t.pval.get_known() for t in tracers)
if all(c is not None for c in consts):
return primitive.bind(*consts, **params)
@@ -261,15 +272,9 @@ def todo(x):
return out, todo

def process_custom_jvp_call(self, prim, fun, jvp, tracers):
# We form jaxprs using JaxprTraces for two distinct purposes: to stage
# program representations completely out of the JAX system (e.g. for XLA
# using jit or pmap), and to build a representation of a function that may
# require further JAX transformations (e.g. in "initial-style" higher-order
# primitives, like for control flow). In particular, in the latter case we
# need custom differentiation rules to stick around, but in the former we do
# not. This method call should only be reachable in the former case, and so
# we check that the former case is indicated (with a StagingJaxprTrace) and
# then drop the differentiation rules.
# See comment at top of `JaxprTrace`. This method should be reachable
# only when we stage out, and in that case we drop the custom differentiation
# rules, because we do not need them.
assert self.master.trace_type is StagingJaxprTrace
return fun.call_wrapped(*tracers)

@@ -278,9 +283,9 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
assert self.master.trace_type is StagingJaxprTrace
return fun.call_wrapped(*tracers)

# This subclass is used just for its type tag, which switches the behavior of
# process_call to stage out into the jaxpr any call primitives encountered
# (rather than doing partial evaluation into the call).
# This subclass is used just for its type tag (see the comment at the top of
# `JaxprTrace`). The tag switches the behavior of process_call to stage out into
# the jaxpr any call primitives encountered (rather than doing partial evaluation
# into the call).
class StagingJaxprTrace(JaxprTrace):
pass

@@ -367,14 +372,18 @@ def full_lower(self):
else:
return self


# TODO(necula): this should return a TypedJaxpr
# TODO(necula): remove stage_out; callers should pass trace_type=pe.StagingJaxprTrace instead
def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
instantiate: Union[bool, Sequence[bool]] = False,
stage_out=False, bottom=False) \
stage_out=False, bottom=False,
trace_type: Optional[Type[Trace]] = None) \
-> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
"""Traces a function into a Jaxpr, given PartialVals for inputs.
`trace_type` can be either `StagingJaxprTrace` or `JaxprTrace` (see the
comment at the top of `JaxprTrace`).
Returns (`jaxpr`, `out_pvals`, `consts`).
The `jaxpr` contains only the computation that depends on unknown inputs.
The `out_pvals` are the PartialVal for the outputs. The intermediate
@@ -415,7 +424,7 @@ def fun(ki, ui): # ki will be a known input in this example
out_pvals = [abstract(ConcreteArray(6)), abstract(ShapedArray)] # all are unknown PartialVal
consts = [3, 6] # values for `ka` and `kb` constvars
"""
trace_type = StagingJaxprTrace if stage_out else JaxprTrace
trace_type = trace_type or (StagingJaxprTrace if stage_out else JaxprTrace)
with new_master(trace_type, bottom=bottom) as master:
fun = trace_to_subjaxpr(fun, master, instantiate)
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
@@ -549,7 +558,8 @@ def convert_constvars_jaxpr(jaxpr):
return lifted_jaxpr

def partial_eval_jaxpr(jaxpr: TypedJaxpr, unknowns: Sequence[bool],
instantiate: Union[bool, Sequence[bool]]
instantiate: Union[bool, Sequence[bool]],
trace_type: Optional[Type[core.Trace]]
) -> Tuple[TypedJaxpr, TypedJaxpr, Sequence[bool]]:
"""Specializes a Jaxpr given an indication of which inputs are known.
@@ -586,7 +596,8 @@ def partial_eval_jaxpr(jaxpr: TypedJaxpr, unknowns: Sequence[bool],
def fun(*vals):
pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val)
for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)]
jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate)
jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate,
trace_type=trace_type)
out_pvs_2, out_consts_2 = unzip2(out_pvals_2)
cell.append((out_pvs_2, jaxpr_2, len(consts_2)))
return out_consts_2 + consts_2
@@ -674,7 +685,9 @@ def _remat_partial_eval(trace, _, f, tracers, params):
for var, pv, const in zip(jaxpr.outvars, out_pvs, out_pval_consts1)]
typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
in_unknowns = [t.pval[0] is not None for t in it.chain(env, tracers)]
jaxpr_1, jaxpr_2, out_unknowns = partial_eval_jaxpr(typed_jaxpr, in_unknowns, False)
jaxpr_1, jaxpr_2, out_unknowns = partial_eval_jaxpr(typed_jaxpr, in_unknowns,
instantiate=False,
trace_type=trace.master.trace_type)
num_res = len(jaxpr_1.out_avals) - len(jaxpr_2.out_avals)

# First, we prune the jaxpr to be staged out not to have too many outputs.
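For reference, a hedged sketch of how a caller threads the new `trace_type`
argument through `partial_eval_jaxpr` (mirroring the `_remat_partial_eval` call
above; `typed_jaxpr`, `in_unknowns`, and `trace` stand for values the caller
already has, and `pe` is `jax.interpreters.partial_eval`):

jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr(
    typed_jaxpr,                         # the TypedJaxpr to specialize
    in_unknowns,                         # one bool per input; True means unknown
    instantiate=False,
    trace_type=trace.master.trace_type)  # propagate the trace type of the enclosing trace
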
95 changes: 90 additions & 5 deletions jax/lax/lax_control_flow.py
@@ -401,12 +401,92 @@ def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts,
for nz in nonzeros_out]
return out_carry, out_tangents

def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts: int,
cond_jaxpr: pe.TypedJaxpr, body_nconsts: int,
body_jaxpr: pe.TypedJaxpr) -> Sequence[pe.Tracer]:
"""An implementation of partial evaluation for while.
As long as some of the carry values (and hence outputs) are known and the
output of `cond_jaxpr` is known, we use the known part of the loop body to
compute the known outputs of the `while_loop`. For the unknown outputs we
generate a Jaxpr that runs the whole while, including recomputing the known
parts.
This means that we don't actually save any computation by partial
evaluation if there are unknown outputs.
What this achieves is that we can give a proper error for reverse-mode
differentiation of `while`, because in that use of partial evaluation the
primal inputs are considered "known", and only the tangent computation is
unknown (see issue #2129).
"""
unknowns = [not t.pval.is_known() for t in tracers]
params = dict(cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr,
body_nconsts=body_nconsts, body_jaxpr=body_jaxpr)

cond_consts_uk, body_consts_uk, carry_init_uk = split_list(unknowns, [cond_nconsts, body_nconsts])
# Fixpoint computation of unknown carry. Each iteration promotes
# at least one carry to unknown. We need one last iteration to prepare the jaxpr.
carry_uk = carry_init_uk
for _ in range(1 + len(carry_uk)):
body_jaxpr_known, _, carry_out_uk = pe.partial_eval_jaxpr(
body_jaxpr, body_consts_uk + carry_uk, instantiate=carry_uk,
trace_type=trace.master.trace_type)
if carry_out_uk == carry_uk:
break
else:
carry_uk = _map(operator.or_, carry_uk, carry_out_uk)
else:
assert False, "Fixpoint not reached"

cond_jaxpr_known, _, cond_uk = pe.partial_eval_jaxpr(
cond_jaxpr, cond_consts_uk + carry_uk, instantiate=False,
trace_type=trace.master.trace_type)

if cond_uk[0] or all([not uk for uk in unknowns]) or all(unknowns):
# If conditional is unknown, or all inputs are known, or all are unknown,
# just do the default processing.
return trace.default_process_primitive(while_p, tracers, params)

# Run the known part of the while. Prepare the inputs as constants (if known)
# or as core.unit (if unknown).
in_consts = [ core.unit if uk else t.pval.get_known()
for uk, t in zip(cond_consts_uk + body_consts_uk + carry_uk,
tracers)]
# There should be no residuals for the cond_jaxpr_known
assert 1 == len(cond_jaxpr_known.out_avals)
# We drop the residuals from body_jaxpr_known so that the types of its inputs match
# the types of its outputs; the residuals are at the end
if len(body_jaxpr_known.out_avals) > len(body_jaxpr.out_avals):
# TODO(necula): this is not quite enough; we should drop the residual computations also
body_jaxpr_known.out_avals = body_jaxpr_known.out_avals[:len(body_jaxpr.out_avals)]
body_jaxpr_known.jaxpr.outvars = body_jaxpr_known.jaxpr.outvars[:len(body_jaxpr.out_avals)]
out_known = while_p.bind(
*in_consts,
cond_nconsts=cond_nconsts,
cond_jaxpr=cond_jaxpr_known,
body_nconsts=body_nconsts,
body_jaxpr=body_jaxpr_known)

# Run the whole while_loop to get all the outputs, then merge with known ones
out_all: Sequence[pe.Tracer] = trace.default_process_primitive(while_p, tracers, params)
out_tracers: Sequence[pe.Tracer] = [
out_unknown if uk
else pe.JaxprTracer(trace, pe.PartialVal.known(known), out_unknown.recipe)
for uk, out_unknown, known in zip(carry_uk, out_all, out_known)]

return out_tracers

def _while_transpose_error(*_, **kwargs):
raise ValueError("Reverse-mode differentiation does not work for lax.while_loop. "
"Try using lax.scan, or lax.fori_loop with constant bounds.")

while_p = lax.Primitive('while')
while_p.multiple_results = True
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_abstract_eval(_while_loop_abstract_eval)
ad.primitive_jvps[while_p] = _while_loop_jvp
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
xla.initial_style_translations[while_p] = _while_loop_translation_rule
ad.primitive_transposes[while_p] = _while_transpose_error
batching.primitive_batchers[while_p] = _while_loop_batching_rule
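
As a standalone illustration of the carry-unknowns fixpoint used by
`_while_partial_eval` above (our own sketch, not part of the diff;
`partial_eval_body` stands in for the call to `pe.partial_eval_jaxpr` on the
loop body):

def carry_unknowns_fixpoint(partial_eval_body, carry_uk):
  # Widen the set of unknown carry positions until partially evaluating the
  # body stops producing new unknown outputs; this mirrors the loop above.
  for _ in range(1 + len(carry_uk)):
    carry_out_uk = partial_eval_body(carry_uk)  # one bool per carry output
    if carry_out_uk == carry_uk:
      return carry_uk
    carry_uk = [a or b for a, b in zip(carry_uk, carry_out_uk)]
  raise AssertionError("Fixpoint not reached")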


@@ -571,14 +651,18 @@ def _cond_partial_eval(trace, *tracers, true_jaxpr, false_jaxpr, linear):
params = dict(true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr, linear=linear)
return trace.default_process_primitive(cond_p, tracers, params)

_, _, t_out_uks = pe.partial_eval_jaxpr(true_jaxpr, t_uk, instantiate=False)
_, _, f_out_uks = pe.partial_eval_jaxpr(false_jaxpr, f_uk, instantiate=False)
_, _, t_out_uks = pe.partial_eval_jaxpr(true_jaxpr, t_uk, instantiate=False,
trace_type=trace.master.trace_type)
_, _, f_out_uks = pe.partial_eval_jaxpr(false_jaxpr, f_uk, instantiate=False,
trace_type=trace.master.trace_type)
out_uks = [a or b for a, b in zip(t_out_uks, f_out_uks)]

true_jaxpr_1, true_jaxpr_2, _ = pe.partial_eval_jaxpr(true_jaxpr, t_uk,
instantiate=out_uks)
instantiate=out_uks,
trace_type=trace.master.trace_type)
false_jaxpr_1, false_jaxpr_2, _ = pe.partial_eval_jaxpr(false_jaxpr, f_uk,
instantiate=out_uks)
instantiate=out_uks,
trace_type=trace.master.trace_type)

num_t_res = len(true_jaxpr_1.out_avals) - len(out_uks)
num_f_res = len(false_jaxpr_1.out_avals) - len(out_uks)
@@ -992,7 +1076,8 @@ def _scan_partial_eval(trace, *tracers, forward, length, num_consts, num_carry,
for _ in range(1 + len(carry_uk)):
unknowns = const_uk + carry_uk + xs_uk
jaxpr_1, jaxpr_2, out_uk = pe.partial_eval_jaxpr(
jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys)
jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys,
trace_type=trace.master.trace_type)
carry_uk_out, ys_uk = out_uk[:num_carry], out_uk[num_carry:]
if carry_uk_out == carry_uk:
break
27 changes: 27 additions & 0 deletions tests/lax_control_flow_test.py
@@ -1405,6 +1405,33 @@ def loop(loop_impl, x):

jtu.check_grads(loop_lax, (x,), order=2, modes=["fwd"])

@parameterized.named_parameters(
dict(testcase_name="_loop={}".format(loop), loop=loop)
for loop in ["while", "fori", "fori_inside_cond", "fori_inside_scan"])
def testWhileGradError(self, loop: str = "fori_inside_scan"):
# Check that vjp of loops raises a clear error
if loop == "while":
func = lambda x: lax.while_loop(lambda i: i < 5., lambda i: i + 1., x)
elif loop == "fori":
func = lambda x: lax.fori_loop(x, x + 2., lambda i, c: c, x)
elif loop == "fori_inside_jit":
func = api.jit(lambda x: lax.fori_loop(x, x + 2., lambda i, c: c, x))
elif loop == "fori_inside_cond":
func = lambda x: lax.cond(True, x,
lambda x: lax.fori_loop(x, x + 2., lambda i, c: c, x),
1., lambda x: x)
elif loop == "fori_inside_scan":
func = lambda x: lax.scan(lambda c, x: (lax.fori_loop(x, x + 2., lambda i, c1: c1 * c, x),
None),
x, onp.ones(2))[0]
else:
assert False

with self.assertRaisesRegex(ValueError, "Reverse-mode differentiation does not work for lax.while_loop"):
api.grad(func)(1.)

api.linearize(func, 1.) # Linearization works

def testIssue1316(self):
def f(carry, _):
c, key = carry
