Skip to content

Commit

Permalink
improve and add to closure_convert testing
Browse files Browse the repository at this point in the history
* Test closure conversion with mixed values in the closure, one
  participating in AD and the other not.
* Simplify the basic closure_convert test and give its intermediates
  more descriptive names.
  • Loading branch information
froystig committed Jul 27, 2021
1 parent 258ae44 commit 52f0cbe
Showing 1 changed file with 63 additions and 21 deletions.
84 changes: 63 additions & 21 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5211,40 +5211,82 @@ def f(x):
self.assertAllClose(api.grad(api.grad(f))(3.), 1., check_dtypes=False)

def test_closure_convert(self):
def minimize(objective_fn, x0):
converted_fn, aux_args = api.closure_convert(objective_fn, x0)
return _minimize(converted_fn, x0, *aux_args)
def cos_after(fn, x):
converted_fn, aux_args = api.closure_convert(fn, x)
self.assertLessEqual(len(aux_args), 1)
return _cos_after(converted_fn, x, *aux_args)

@partial(api.custom_vjp, nondiff_argnums=(0,))
def _minimize(objective_fn, x0, *args):
_ = objective_fn(x0, *args)
return jnp.cos(x0)
def _cos_after(fn, x, *args):
return jnp.cos(fn(x, *args))

def fwd(objective_fn, x0, *args):
y = _minimize(objective_fn, x0, *args)
return y, (y, args)
def fwd(fn, x, *args):
y = _cos_after(fn, x, *args)
return y, (x, args)

def rev(objective_fn, res, g):
y, args = res
x0_bar = 17. * y
def rev(fn, res, g):
x, args = res
x_bar = 17. * x
args_bars = [42. * a for a in args]
return (x0_bar, *args_bars)
return (x_bar, *args_bars)

_minimize.defvjp(fwd, rev)
_cos_after.defvjp(fwd, rev)

def obj(c, x):
def dist(c, x):
return jnp.sum((x - c) ** 2.)

def solve(c, x):
def closure(x):
return obj(c, x)
return jnp.sum(minimize(closure, x))
return dist(c, x)
return cos_after(closure, x)

c, x = jnp.ones(2), jnp.zeros(2)
self.assertAllClose(solve(c, x), 2.0, check_dtypes=False)
c, x = 2. * jnp.ones(2), jnp.ones(2)
expected = jnp.cos(dist(c, x))
self.assertAllClose(solve(c, x), expected, check_dtypes=False)
g_c, g_x = api.grad(solve, argnums=(0, 1))(c, x)
self.assertAllClose(g_c, 42. * jnp.ones(2), check_dtypes=False)
self.assertAllClose(g_x, 17. * jnp.ones(2), check_dtypes=False)
self.assertAllClose(g_c, 42. * c, check_dtypes=False)
self.assertAllClose(g_x, 17. * x, check_dtypes=False)

def test_closure_convert_mixed_consts(self):
# Like test_closure_convert, but close over values that
# participate in AD as well as values that do not.
# See https://github.com/google/jax/issues/6415

def cos_after(fn, x):
converted_fn, aux_args = api.closure_convert(fn, x)
self.assertLessEqual(len(aux_args), 1)
return _cos_after(converted_fn, x, *aux_args)

@partial(api.custom_vjp, nondiff_argnums=(0,))
def _cos_after(fn, x, *args):
return jnp.cos(fn(x, *args))

def fwd(fn, x, *args):
y = _cos_after(fn, x, *args)
return y, (x, args)

def rev(fn, res, g):
x, args = res
x_bar = 17. * x
args_bars = [42. * a for a in args]
return (x_bar, *args_bars)

_cos_after.defvjp(fwd, rev)

def dist(c, s, x):
return jnp.sum(s * (x - c) ** 2.)

def solve(c, s, x):
def closure(x):
return dist(c, s, x)
return cos_after(closure, x)

c, s, x = 2. * jnp.ones(2), 3. * jnp.ones(2), jnp.ones(2)
expected = jnp.cos(dist(c, s, x))
self.assertAllClose(solve(c, s, x), expected, check_dtypes=False)
g_c, g_x = api.grad(solve, argnums=(0, 2))(c, s, x)
self.assertAllClose(g_c, 42. * c, check_dtypes=False)
self.assertAllClose(g_x, 17. * x, check_dtypes=False)


class CustomTransposeTest(jtu.JaxTestCase):
Expand Down

0 comments on commit 52f0cbe

Please sign in to comment.