Skip to content

Commit

Permalink
Add scan and while rule for jax.experimental.callback transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
sharadmv committed Mar 10, 2021
1 parent 23b82b9 commit ddaef19
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 3 deletions.
79 changes: 76 additions & 3 deletions jax/experimental/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools as it

from typing import Any, Callable, Dict, Sequence, Union

import jax.numpy as jnp

from jax import core
from jax.core import Trace, Tracer
from jax.core import Trace, Tracer, jaxpr_as_fun
from jax import lax
from jax import linear_util as lu
from jax._src.util import partial, safe_map, wraps
from jax._src.util import partial, safe_map, wraps, split_list
from jax._src.lax import control_flow as lcf

import inspect
from jax.api_util import flatten_fun_nokwargs
from jax.tree_util import tree_flatten, tree_unflatten
from jax.tree_util import tree_flatten, tree_unflatten, tree_structure, tree_leaves, tree_map

map = safe_map

Expand Down Expand Up @@ -143,6 +147,8 @@ def sublift(self, val):
return CallbackTracer(self, val.val)

def process_primitive(self, primitive, tracers, params):
if primitive in custom_callback_rules:
return custom_callback_rules[primitive](self, *tracers, **params)
vals_in = [t.val for t in tracers]
vals_out = self.main.callback(primitive, vals_in, params) # type: ignore
if primitive.multiple_results:
Expand All @@ -169,3 +175,70 @@ def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
# TODO(sharadmv): don't drop the custom derivative rule
del primitive, fwd, bwd, out_trees # Unused.
return fun.call_wrapped(*tracers)


custom_callback_rules: Dict[Any, Any] = {}

def _scan_callback_rule(trace, *tracers, reverse, length, num_consts, num_carry,
jaxpr, linear, unroll):
const_tracers, carry_tracers, xs_tracers = split_list(tracers, [num_consts, num_carry])
carry_avals, xs_avals = tree_map(lambda x: x.aval, (carry_tracers, xs_tracers))
const_vals, carry_vals, xs_vals = tree_map(lambda x: x.val, (const_tracers, carry_tracers, xs_tracers))

x_tracers = [t[0] for t in xs_tracers]
x_avals = [t.aval for t in x_tracers]

body_fun = jaxpr_as_fun(jaxpr)

def new_body(carry, x):
flat_args = tree_leaves((carry, x))
out = body_fun(*(const_vals + flat_args))
out_carry, y = split_list(out, [num_carry])
return out_carry, y
main = trace.main
new_body = callback_transform(new_body, main.callback, strip_calls=main.strip_calls) # type: ignore
in_tree = tree_structure(tuple(carry_avals + xs_avals))
new_jaxpr, new_consts, _ = lcf._initial_style_jaxpr(
new_body, in_tree, tuple(carry_avals + x_avals))
vals = tuple(it.chain(new_consts, carry_vals, xs_vals))
out_vals = lax.scan_p.bind(*vals, reverse=reverse, length=length,
num_consts=len(new_consts), num_carry=num_carry,
jaxpr=new_jaxpr, linear=linear, unroll=unroll)
return safe_map(trace.pure, out_vals)

custom_callback_rules[lax.scan_p] = _scan_callback_rule


def _while_callback_rule(trace, *tracers, cond_jaxpr, body_jaxpr,
cond_nconsts, body_nconsts):
cond_const_tracers, body_const_tracers, init_tracers = split_list(
tracers, [cond_nconsts, body_nconsts])
init_avals = safe_map(lambda x: x.aval, init_tracers)
cond_const_vals, body_const_vals, init_vals = tree_map(
lambda x: x.val, (cond_const_tracers, body_const_tracers, init_tracers))

body_fun = jaxpr_as_fun(body_jaxpr)
cond_fun = jaxpr_as_fun(cond_jaxpr)

def cond(*carry):
return cond_fun(*it.chain(cond_const_vals, carry))

def body(*carry):
return body_fun(*it.chain(body_const_vals, carry))

main = trace.main
new_cond = callback_transform(cond, main.callback, strip_calls=main.strip_calls) # type: ignore
new_body = callback_transform(body, main.callback, strip_calls=main.strip_calls) # type: ignore
in_tree = tree_structure(init_avals)

new_cond_jaxpr, new_cond_consts, _ = lcf._initial_style_jaxpr(new_cond, in_tree, tuple(init_avals))
new_body_jaxpr, new_body_consts, _ = lcf._initial_style_jaxpr(new_body, in_tree, tuple(init_avals))
out = lcf.while_p.bind(
*it.chain(new_cond_consts, new_body_consts, init_vals),
cond_nconsts=len(new_cond_consts),
body_nconsts=len(new_body_consts),
cond_jaxpr=new_cond_jaxpr,
body_jaxpr=new_body_jaxpr)
return safe_map(trace.pure, out)

custom_callback_rules[lax.while_p] = _while_callback_rule
57 changes: 57 additions & 0 deletions tests/callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,63 @@ def f(x):
rewrite(f, {})(x),
jnp.array([2.0, 4.0]))

def testRewriteThroughScan(self):
def f(xs):
def body(carry, x):
carry = carry * 2.
return carry, x - 2.
return lax.scan(body, 1., xs)

xs = jnp.arange(4.)
carry, ys = f(xs)
self.assertAllClose(carry, 16.)
self.assertAllClose(ys, jnp.arange(4.) - 2.)

rewrites = {
lax.mul_p: lambda x, y: x + y,
lax.sub_p: lambda x, y: x / y
}
carry, ys = rewrite(f, rewrites)(xs)
self.assertAllClose(carry, 1. + 8.)
self.assertAllClose(ys, jnp.arange(4.) / 2.)


def testRewriteThroughWhile(self):
def f(x):
def cond(x):
return x < 5
def body(x):
return x + 1
return lax.while_loop(cond, body, x)

x = 0
self.assertAllClose(f(x), 5)

rewrites = {
lax.add_p: lambda x, y: x + y + 100,
}
self.assertAllClose(rewrite(f, rewrites)(x), 101)

rewrites = {
lax.lt_p: lambda x, y: x < y + 5
}
self.assertAllClose(rewrite(f, rewrites)(x), 10)


def testRewriteThroughForLoop(self):
def f(x):
def body(i, x):
return x * i
return lax.fori_loop(1, 5, body, x)

x = 1
self.assertAllClose(f(x), 24)

rewrites = {
lax.mul_p: lambda x, y: x + y
}
self.assertAllClose(rewrite(f, rewrites)(x), 11)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit ddaef19

Please sign in to comment.