Skip to content

Commit

Permalink
custom batching jvp tests
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed Jan 6, 2022
1 parent 0ab93a0 commit ad7c7d6
Showing 1 changed file with 174 additions and 0 deletions.
174 changes: 174 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6460,6 +6460,180 @@ def rule(axis_size, in_batched, xs):
'custom vmap rule output values must be a sequence.*',
lambda: api.vmap(f)(xs))

def test_jvp_basic(self):
@api.custom_vmap
def f(x): return jnp.sin(x)

@f.def_vmap
def rule(axis_size, in_batched, xs):
self.assertEqual(axis_size, 3)
self.assertEqual(in_batched, [True])
return [jnp.cos(xs)], in_batched

f_jvp = lambda x, tx: api.jvp(f, [x], [tx])

x, tx = jnp.array(1.), jnp.array(2.)
xs, txs = jnp.arange(3.), jnp.arange(3.) * 2.

y, ty = f_jvp(x, tx)
self.assertAllClose(y, jnp.sin(x))
self.assertAllClose(ty, jnp.cos(x) * tx)

ys, tys = api.vmap(f_jvp)(xs, txs)
self.assertAllClose(ys, jnp.cos(xs))
self.assertAllClose(tys, -jnp.sin(xs) * txs)

ys, tys = api.jvp(api.vmap(f), [xs], [txs])
self.assertAllClose(ys, jnp.cos(xs))
self.assertAllClose(tys, -jnp.sin(xs) * txs)

def test_jvp_nary(self):
@api.custom_vmap
def f(x, y): return jnp.sin(x) + y

@f.def_vmap
def rule(axis_size, in_batched, xs, ys):
self.assertEqual(axis_size, 3)
self.assertEqual(in_batched, [True, True])
return [jnp.cos(xs) + ys], [True]

f_jvp = lambda x, y, tx, ty: api.jvp(f, [x, y], [tx, ty])

x, y, tx, ty = jnp.arange(4.)
xs, ys, txs, tys = 4. + jnp.arange(3. * 4).reshape((4, 3))

zs, tzs = api.vmap(f_jvp)(xs, ys, txs, tys)
self.assertAllClose(zs, jnp.cos(xs) + ys)
self.assertAllClose(tzs, -jnp.sin(xs) * txs + tys)

zs, tzs = api.jvp(api.vmap(f), [xs, ys], [txs, tys])
self.assertAllClose(zs, jnp.cos(xs) + ys)
self.assertAllClose(tzs, -jnp.sin(xs) * txs + tys)

def test_jvp_extra_batched_tangents(self):
@api.custom_vmap
def f(x): return jnp.sin(x)

@f.def_vmap
def rule(axis_size, in_batched, xs):
self.assertEqual(axis_size, 3)
self.assertEqual(in_batched, [False])
return [jnp.cos(xs)], in_batched

f_jvp = lambda x, tx: api.jvp(f, [x], [tx])

x, txs = jnp.array(1.), 2. + jnp.arange(3.)
y, tys = api.vmap(f_jvp, in_axes=(None, 0), out_axes=(None, 0))(x, txs)
self.assertAllClose(y, jnp.cos(x))
self.assertAllClose(tys, -jnp.sin(x) * txs)

def test_jacfwd(self):
# jacfwd is another way to exercise extra-batched tangents

@api.custom_vmap
def f(x): return jnp.sin(x)

@f.def_vmap
def rule(axis_size, in_batched, xs):
self.assertEqual(axis_size, 3)
self.assertEqual(in_batched, [False])
return [jnp.cos(xs)], in_batched

x = jnp.arange(3.) + .72
j = api.jacfwd(f)(x)
self.assertAllClose(j, -jnp.diag(jnp.sin(x)))

def test_jvp_extra_batched_primals(self):
@api.custom_vmap
def f(x): return jnp.sin(x)

@f.def_vmap
def rule(axis_size, in_batched, xs):
self.assertEqual(axis_size, 3)
self.assertEqual(in_batched, [False])
return [jnp.cos(xs)], in_batched

f_jvp = lambda x, tx: api.jvp(f, [x], [tx])

xs, tx = jnp.arange(3.), jnp.array(4.)
ys, tys = api.vmap(f_jvp, in_axes=(0, None))(xs, tx)
self.assertAllClose(ys, jnp.cos(xs))
self.assertAllClose(tys, -jnp.sin(xs) * tx)

def test_jvp_extra_batched_primals_with_linear_vmap_rule(self):
# When a function is linear, its Jacobian is constant. JAX's JVP
# of linear functions takes advantage of this: when mapping over a
# batch of primals relative to a fixed (i.e. symbolically
# replicated) tangent, output tangents remain replicated as well
# (i.e. JAX will not broadcast them). This is true in general, and
# this test checks that vmapped JVPs continue to behave this way
# when custom_vmap is involved and the custom vmap rule is linear.

@api.custom_vmap
def f_linear(x): return 7. * x

@f_linear.def_vmap
def linear_rule(axis_size, in_batched, xs):
return [11. * xs], in_batched

@api.custom_vmap
def f_nonlinear(x): return jnp.sin(x)

@f_nonlinear.def_vmap
def nonlinear_rule(axis_size, in_batched, xs):
return [jnp.cos(xs)], in_batched

f_lin_jvp = lambda x, tx: api.jvp(f_linear, [x], [tx])
f_non_jvp = lambda x, tx: api.jvp(f_nonlinear, [x], [tx])
xs, tx = jnp.arange(3.), jnp.array(4.)

# doesn't err
_ = api.vmap(f_lin_jvp, in_axes=(0, None), out_axes=(0, None))(xs, tx)

# does err
self.assertRaisesRegex(
ValueError, 'vmap has mapped output but out_axes is None',
lambda: api.vmap(
f_non_jvp, in_axes=(0, None), out_axes=(0, None))(xs, tx))

def test_jvp_dataflow_violation(self):
# The jvp-of-custom-vmap machinery should not assume the standard
# dataflow constraint on the JVP of the custom vmap rule (primal
# outputs independent of tangent inputs). Both jvp and vmap are
# "forward" transformations under which, at present, we don't
# enforce the JVP dependence diagram. Because output primals can
# depend on input tangents, extra-batched input tangents can
# create batched output primals, as this test checks.

@api.custom_jvp
def cos_with_invalid_dataflow_jvp(x): return jnp.cos(x)

@cos_with_invalid_dataflow_jvp.defjvp
def invalid_dataflow_jvp(x, tx):
[x], [tx] = x, tx
return jnp.cos(x * tx), tx

@api.custom_vmap
def f(x): return jnp.sin(x)

@f.def_vmap
def rule(axis_size, in_batched, xs):
return [cos_with_invalid_dataflow_jvp(xs)], in_batched

f_jvp = lambda x, tx: api.jvp(f, [x], [tx])
x, txs = jnp.array(1.), 2. + jnp.arange(3.)

# doesn't err
ys, tys = api.vmap(f_jvp, in_axes=(None, 0))(x, txs)
self.assertAllClose(ys, jnp.cos(x * txs))
self.assertAllClose(tys, txs)

# does err
self.assertRaisesRegex(
ValueError, 'vmap has mapped output but out_axes is None',
lambda: api.vmap(
f_jvp, in_axes=(None, 0), out_axes=(None, 0))(x, txs))


class InvertibleADTest(jtu.JaxTestCase):

Expand Down

0 comments on commit ad7c7d6

Please sign in to comment.