From b2de101be739e43f819870780d9213d81e9d7145 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Fri, 28 Jan 2022 09:48:23 -0800 Subject: [PATCH] require consistent output structure in custom vmap rules ... not always a sequence. --- jax/_src/custom_batching.py | 60 +++++++++++++-------------- tests/api_test.py | 81 +++++++++++++++++++------------------ 2 files changed, 72 insertions(+), 69 deletions(-) diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 41a4512e07a1..6cd7e69dc098 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -14,7 +14,7 @@ import functools import operator -from typing import Callable, Optional, Sequence +from typing import Callable, Optional import jax from jax import core @@ -71,7 +71,8 @@ def __call__(self, *args, **kwargs): out_flat = custom_vmap_p.bind(*consts, *args_flat, call=closed_call, rule=self.vmap_rule, - in_tree=in_tree) + in_tree=in_tree, + out_tree=out_tree()) return tree_unflatten(out_tree(), out_flat) @@ -85,24 +86,21 @@ def rule_name(rule): return getattr(rule, '__name__', '') def call_rule(rule, axis_size, in_batched, args): - outs, out_batched = rule(axis_size, ensure_list(in_batched), *args) - if not isinstance(outs, Sequence): - raise TypeError( - 'custom vmap rule output values must be a sequence, ' - f'rule ({rule_name(rule)}) returned {type(outs)}') - if not isinstance(out_batched, Sequence): - raise TypeError( - 'custom vmap rule output batching specification must be a sequence, ' - f'rule ({rule_name(rule)}) returned {type(out_batched)}') - return ensure_list(outs), ensure_list(out_batched) - -def check_vmap_rule_trees(rule, out_tree, out_batched_tree): + return rule(axis_size, ensure_list(in_batched), *args) + +def check_vmap_rule_trees(rule, original_out_tree, out_tree, out_batched_tree): if out_tree != out_batched_tree: raise ValueError( - 'structure of output values and output batching specification returned ' + 'structure of output value and output batching specification returned ' f'by custom vmap rule ({rule_name(rule)}) do not match.\n' f'Output values: {out_tree}\n' f'Batching spec: {out_batched_tree}') + if out_tree != original_out_tree: + raise ValueError( + f'structure of output returned by custom vmap rule ({rule_name(rule)}) ' + 'does not match that of original custom-vmapped function.\n' + f'Original output: {original_out_tree}\n' + f'Rule output: {out_tree}') # Like batching.bdim_at_front, but doesn't broadcast if not mapped def maybe_bdim_at_front(x, bdim): @@ -127,12 +125,12 @@ def vmap_unrestricted(f: lu.WrappedFun, *args, in_axes, axis_name, axis_size): ### custom_vmap_p rules -def custom_vmap_impl(*args, call, rule, in_tree): - del rule, in_tree +def custom_vmap_impl(*args, call, rule, in_tree, out_tree): + del rule, in_tree, out_tree return core.jaxpr_as_fun(call)(*args) -def custom_vmap_batching(args_flat, dims, *, call, rule, in_tree): +def custom_vmap_batching(args_flat, dims, *, call, rule, in_tree, out_tree): del call axis_size, = {x.shape[d] for x, d in zip(args_flat, dims) if d is not None} args_flat = map(maybe_bdim_at_front, args_flat, dims) @@ -140,10 +138,10 @@ def custom_vmap_batching(args_flat, dims, *, call, rule, in_tree): args = tree_unflatten(in_tree, args_flat) in_batched = tree_unflatten(in_tree, flat_in_batched) - outs, out_batched = call_rule(rule, axis_size, in_batched, args) - flat_outs, tree1 = tree_flatten(outs) + out, out_batched = call_rule(rule, axis_size, in_batched, args) + flat_outs, tree1 = tree_flatten(out) flat_out_batched, tree2 = tree_flatten(out_batched) - check_vmap_rule_trees(rule, tree1, tree2) + check_vmap_rule_trees(rule, out_tree, tree1, tree2) flat_out_dims = [0 if b else not_mapped for b in flat_out_batched] return flat_outs, flat_out_dims @@ -152,7 +150,7 @@ def custom_vmap_abstract_eval(*in_avals, call, **_): return call.out_avals -def custom_vmap_jvp(primals, tangents, *, call, rule, in_tree): +def custom_vmap_jvp(primals, tangents, *, call, rule, in_tree, out_tree): def jvp_of_rule_rule(axis_size, in_batched, primals, tangents): in_batched_ps, in_batched_ts = in_batched @@ -175,16 +173,16 @@ def jvp_of_rule_rule(axis_size, in_batched, primals, tangents): del tree_ps_ts2 def to_jvp(*primals): - outs, out_batched = call_rule(rule, axis_size, mutually_batched, primals) + out, out_batched = call_rule(rule, axis_size, mutually_batched, primals) check_vmap_rule_trees( - rule, tree_structure(outs), tree_structure(out_batched)) + rule, out_tree, tree_structure(out), tree_structure(out_batched)) out_mutually_batched.store(out_batched) - return outs + return out def to_vmap_over_extra_batched_dims(primals, tangents): return jax.jvp(to_jvp, primals, tangents) - to_vmap_over_extra_batched_dims_flat, out_tree = flatten_fun_nokwargs( + to_vmap_over_extra_batched_dims_flat, out_tree2 = flatten_fun_nokwargs( lu.wrap_init(to_vmap_over_extra_batched_dims), tree_ps_ts) @@ -203,9 +201,9 @@ def to_vmap_over_extra_batched_dims(primals, tangents): flat_out_extra_batched_ts = [d is not not_mapped for d in flat_out_axes_t] out_ps, out_ts = tree_unflatten( - out_tree(), [*flat_out_ps, *flat_out_ts]) + out_tree2(), [*flat_out_ps, *flat_out_ts]) out_extra_batched_ps, out_extra_batched_ts = tree_unflatten( - out_tree(), [*flat_out_extra_batched_ps, *flat_out_extra_batched_ts]) + out_tree2(), [*flat_out_extra_batched_ps, *flat_out_extra_batched_ts]) out_batched_ps = tree_map( operator.or_, out_mutually_batched.val, out_extra_batched_ps) @@ -217,9 +215,11 @@ def to_vmap_over_extra_batched_dims(primals, tangents): tangents = map(ad.instantiate_zeros, tangents) jvp_call, _ = ad.jvp_jaxpr(call, [True] * len(primals), True) jvp_in_tree = treedef_tuple((in_tree, in_tree)) + jvp_out_tree = treedef_tuple((out_tree, out_tree)) outs = custom_vmap_p.bind( *primals, *tangents, - call=jvp_call, rule=jvp_of_rule_rule, in_tree=jvp_in_tree) + call=jvp_call, rule=jvp_of_rule_rule, + in_tree=jvp_in_tree, out_tree=jvp_out_tree) assert len(outs) % 2 == 0, len(outs) out_primals, out_tangents = util.split_list(outs, [len(outs) // 2]) return out_primals, out_tangents @@ -265,6 +265,6 @@ def to_map(mapped_args): mapped_args, bcast_args = tree_split(in_batched, list(args)) out = jax.lax.map(to_map, mapped_args) out_batched = tree_map(lambda _: True, out) - return [out], [out_batched] + return out, out_batched return f diff --git a/tests/api_test.py b/tests/api_test.py index 7c8006d91025..f25ce8512a4a 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -6971,9 +6971,10 @@ def f(x): return jnp.sin(x) @f.def_vmap def rule(axis_size, in_batched, xs): - self.assertEqual(in_batched, [True]) + xs_batched, = in_batched + self.assertEqual(xs_batched, True) self.assertEqual(axis_size, xs.shape[0]) - return [jnp.cos(xs)], in_batched + return jnp.cos(xs), xs_batched x, xs = jnp.array(1.), jnp.arange(3) y = f(x) @@ -6981,6 +6982,22 @@ def rule(axis_size, in_batched, xs): ys = api.vmap(f)(xs) self.assertAllClose(ys, jnp.cos(xs)) + def test_rule_multi_output(self): + @api.custom_vmap + def f(x): return jnp.sin(x), jnp.cos(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + return (jnp.cos(xs), jnp.sin(xs)), tuple(in_batched * 2) + + x, xs = jnp.array(1.), jnp.arange(3) + y1, y2 = f(x) + self.assertAllClose(y1, jnp.sin(x)) + self.assertAllClose(y2, jnp.cos(x)) + ys1, ys2 = api.vmap(f)(xs) + self.assertAllClose(ys1, jnp.cos(xs)) + self.assertAllClose(ys2, jnp.sin(xs)) + def test_nary(self): @api.custom_vmap def f(x, y): return jnp.sin(x) + y ** 2. @@ -6991,7 +7008,7 @@ def rule(axis_size, in_batched, xs, ys): self.assertEqual(axis_size, 3) self.assertEqual(axis_size, xs.shape[0]) self.assertEqual(axis_size, ys.shape[0]) - return [jnp.cos(xs) + ys ** 2.], [True] + return jnp.cos(xs) + ys ** 2., True xs, ys = jnp.arange(3), jnp.arange(3) zs = api.vmap(f)(xs, ys) @@ -7029,7 +7046,7 @@ def vector_dot_vmap_rule(axis_size, in_batched, u, v): out = jnp.sum(u * v, axis=1) else: out = u @ v if u_batched else v @ u - return [out], [u_batched or v_batched] + return out, u_batched or v_batched f = vector_dot v = lambda *shape: jnp.ones(shape) @@ -7056,30 +7073,16 @@ def f(x): return jnp.sin(x) @f.def_vmap def rule(axis_size, in_batched, xs): rule_args.append((axis_size, in_batched)) - return [jnp.cos(xs)], in_batched + return jnp.cos(xs), in_batched[0] xs = jnp.arange(3) _ = api.vmap(f)(xs) (axis_size, in_batched), = rule_args self.assertIs(type(axis_size), int) self.assertIs(type(in_batched), list) + self.assertEqual(len(in_batched), 1) - def test_rule_output_signature_any_sequence(self): - @api.custom_vmap - def f(x): return jnp.sin(x) - - Box = collections.namedtuple('Box', 'value') - - @f.def_vmap - def rule(axis_size, in_batched, xs): - # custom vmap machinery should handle any sequence type for either output - return Box(jnp.cos(xs)), tuple(in_batched) - - xs = jnp.arange(3) - ys = api.vmap(f)(xs) - self.assertAllClose(ys, jnp.cos(xs)) - - def test_rule_output_mismatch(self): + def test_rule_output_vs_batching_output_mismatch(self): @api.custom_vmap def f(x): return jnp.sin(x) @@ -7090,23 +7093,23 @@ def test_rule_abc(axis_size, in_batched, xs): xs = jnp.arange(3) self.assertRaisesRegex( ValueError, - 'structure of output values and output batching specification ' + 'structure of output value and output batching specification ' r'returned by custom vmap rule \(test_rule_abc\) do not match.*', lambda: api.vmap(f)(xs)) - def test_rule_output_array(self): + def test_rule_vs_call_output_mismatch(self): @api.custom_vmap def f(x): return jnp.sin(x) @f.def_vmap - def rule(axis_size, in_batched, xs): - # common to overlook the need to box up single output value in a list - return jnp.cos(xs), in_batched + def test_rule_abc2(axis_size, in_batched, xs): + return [jnp.sin(xs)], in_batched xs = jnp.arange(3) self.assertRaisesRegex( - TypeError, - 'custom vmap rule output values must be a sequence.*', + ValueError, + r'structure of output returned by custom vmap rule \(test_rule_abc2\) ' + r'does not match that of original custom-vmapped function.*', lambda: api.vmap(f)(xs)) def test_jvp_basic(self): @@ -7117,7 +7120,7 @@ def f(x): return jnp.sin(x) def rule(axis_size, in_batched, xs): self.assertEqual(axis_size, 3) self.assertEqual(in_batched, [True]) - return [jnp.cos(xs)], in_batched + return jnp.cos(xs), in_batched[0] f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) @@ -7144,7 +7147,7 @@ def f(x, y): return jnp.sin(x) + y def rule(axis_size, in_batched, xs, ys): self.assertEqual(axis_size, 3) self.assertEqual(in_batched, [True, True]) - return [jnp.cos(xs) + ys], [True] + return jnp.cos(xs) + ys, True f_jvp = lambda x, y, tx, ty: api.jvp(f, [x, y], [tx, ty]) @@ -7167,7 +7170,7 @@ def f(x): return jnp.sin(x) def rule(axis_size, in_batched, xs): self.assertEqual(axis_size, 3) self.assertEqual(in_batched, [False]) - return [jnp.cos(xs)], in_batched + return jnp.cos(xs), in_batched[0] f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) @@ -7186,7 +7189,7 @@ def f(x): return jnp.sin(x) def rule(axis_size, in_batched, xs): self.assertEqual(axis_size, 3) self.assertEqual(in_batched, [False]) - return [jnp.cos(xs)], in_batched + return jnp.cos(xs), in_batched[0] x = jnp.arange(3.) + .72 j = api.jacfwd(f)(x) @@ -7200,7 +7203,7 @@ def f(x): return jnp.sin(x) def rule(axis_size, in_batched, xs): self.assertEqual(axis_size, 3) self.assertEqual(in_batched, [False]) - return [jnp.cos(xs)], in_batched + return jnp.cos(xs), in_batched[0] f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) @@ -7223,14 +7226,14 @@ def f_linear(x): return 7. * x @f_linear.def_vmap def linear_rule(axis_size, in_batched, xs): - return [11. * xs], in_batched + return 11. * xs, in_batched[0] @api.custom_vmap def f_nonlinear(x): return jnp.sin(x) @f_nonlinear.def_vmap def nonlinear_rule(axis_size, in_batched, xs): - return [jnp.cos(xs)], in_batched + return jnp.cos(xs), in_batched[0] f_lin_jvp = lambda x, tx: api.jvp(f_linear, [x], [tx]) f_non_jvp = lambda x, tx: api.jvp(f_nonlinear, [x], [tx]) @@ -7267,7 +7270,7 @@ def f(x): return jnp.sin(x) @f.def_vmap def rule(axis_size, in_batched, xs): - return [cos_with_invalid_dataflow_jvp(xs)], in_batched + return cos_with_invalid_dataflow_jvp(xs), in_batched[0] f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) x, txs = jnp.array(1.), 2. + jnp.arange(3.) @@ -7300,7 +7303,7 @@ def rule(axis_size, in_batched, xs): self.assertEqual(in_batched, [in_batched_ref]) sz, = set([z.shape[0] for z in tree_util.tree_leaves(xs)]) self.assertEqual(axis_size, sz) - return [tree_cos(xs)], in_batched + return tree_cos(xs), in_batched[0] y = f(x) self.assertAllClose(y, tree_sin(x)) @@ -7324,7 +7327,7 @@ def rule(axis_size, in_batched, xs): self.assertEqual(in_batched, [in_batched_ref]) sz, = set([z.shape[0] for z in tree_util.tree_leaves(xs)]) self.assertEqual(axis_size, sz) - return [tree_cos(xs)], in_batched + return tree_cos(xs), in_batched[0] y = f(x) self.assertAllClose(y, tree_sin(x)) @@ -7339,7 +7342,7 @@ def f(x): return jnp.sin(x) def rule(axis_size, in_batched, xs): self.assertEqual(in_batched, [True]) self.assertEqual(axis_size, xs.shape[0]) - return [jnp.cos(xs)], in_batched + return jnp.cos(xs), in_batched[0] x, xs = jnp.array(1.), jnp.arange(3) self.assertAllClose(f(x), jit(f)(x))