diff --git a/jax/api.py b/jax/api.py
index 9b8756eb3536..5178b3ea1135 100644
--- a/jax/api.py
+++ b/jax/api.py
@@ -1831,56 +1831,3 @@ def abstractify(x):
   out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))
   out = [ShapeDtypeStruct(x.shape, x.dtype) for x in out]
   return tree_unflatten(out_tree(), out)
-
-
-def _custom_implicit_solve(solve, tangent_solve):
-  """Define gradients for a function that performs an implicit solve.
-
-  Note: this isn't ready for widespread use yet -- it does not handle closed
-  over values inside solve yet.
-
-  Args:
-    solve: callable that takes two positional arguments, func and params, and
-      returns a solution such that func(params, solution) = 0. In other words,
-      the following is assumed to be true (but not checked):
-        solution = solve(func, params)
-        error = func(solution, params)
-        assert tree_all(tree_map(partial(np.allclose, 0.0), error)
-    tangent_solve: callable that takes two positional arguments, a linear
-      function ``f`` and (possibly nested) array(s) ``y``, and returns a
-      solution ``x`` such that ``f(x)=y``:
-
-      - For scalar ``y``, use ``lambda f, y: y / f(1.0)``.
-      - For vector ``y``, you could use a linear solve with the Jacobian, if
-        dimensionality of ``y`` is not too large:
-        ``lambda f, y: np.linalg.solve(jacobian(f)(y), y)``.
-
-  Returns:
-    Wrapped version of solve with JVP and VJPs defined with respect to
-    ``params`` via implicit differentaion, rather than differntiating through
-    the solve.
-  """
-  @wraps(solve)
-  def wrapper(func, params):
-
-    @custom_transforms
-    def solve_impl(params):
-      return solve(func, params)
-
-    @partial(defjvp_all, solve_impl)
-    def solve_impl_jvp(primals, tangents):
-      # F(u(m), m) = 0      # system of equations in m
-      # ∂_0 F(u(m), m) ∂ u(m) + ∂_1 F(u(m), m) = 0
-      # ∂ u(m) = - (∂_0 F(u*, m))^{-1} ∂_1 F(u*, m)
-      params, = primals
-      grad_params, = tangents
-      solution = solve_impl(params)
-      unchecked_zeros, f_jvp = vjp(func, solution, params)
-      grad_solution = tree_map(
-          lambda x: -x,
-          tangent_solve(lambda p: f_jvp(p)[0], f_jvp(grad_params)[1])
-      )
-      return solution, grad_solution
-
-    return solve_impl(params)
-  return wrapper
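For context on the API change: the removed helper took parameters explicitly (`solve(func, params)`, with `func(solution, params) = 0`) and, per its own docstring, could not handle values closed over inside `solve`. The `lax.root` primitive added below drops the explicit `params` and instead defines gradients with respect to whatever `f` closes over. A minimal sketch of the translation, assuming the `lax.root` API added below (the wrapper name here is hypothetical, not part of this change):

    from jax import lax

    def solve_implicitly(solve, tangent_solve, func, params, guess):
      # Old API: api._custom_implicit_solve(solve, tangent_solve)(func, params),
      # differentiated with respect to params via argnums.
      # New API: capture params in a closure; root() differentiates through it.
      f = lambda x: func(x, params)
      return lax.root(f, guess, solve, tangent_solve)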
diff --git a/jax/lax/lax_control_flow.py b/jax/lax/lax_control_flow.py
index 4558e47d112f..a609950a26f9 100644
--- a/jax/lax/lax_control_flow.py
+++ b/jax/lax/lax_control_flow.py
@@ -1,3 +1,4 @@
+# coding=utf-8
 # Copyright 2019 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -31,7 +32,7 @@
 from jax.lax import lax
 from jax import linear_util as lu
 from jax.abstract_arrays import ShapedArray, raise_to_shaped
-from jax.api_util import flatten_fun_nokwargs
+from jax.api_util import flatten_fun_nokwargs, apply_flat_fun_nokwargs
 from jax.interpreters import ad
 from jax.interpreters import partial_eval as pe
 from jax.interpreters import xla
@@ -42,7 +43,7 @@
 from jax.util import (partial, unzip2, safe_map, safe_zip, split_list,
                       split_dict, cache)
 from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
-                           treedef_children)
+                           treedef_children, tree_map)
 from jax import ad_util
 
 _map = safe_map
@@ -829,3 +830,127 @@ def body(i, dst):
   return fori_loop(0, num, body, dst)
 
 masking.masking_rules[lax.concatenate_p] = _concat_masking_rule
+
+
+def root(f, initial_guess, solve, tangent_solve):
+  """Differentiably solve for a root of a function.
+
+  This is a low-level routine, mostly intended for internal use in JAX.
+  Gradients of root() are defined with respect to variables closed over by
+  the provided function f.
+
+  Args:
+    f: function for which to find a root. Should accept a single argument and
+      return a tree of arrays with the same structure as its input.
+    initial_guess: initial guess for a zero of f.
+    solve: function to solve for a root of f. Should take two positional
+      arguments, f and initial_guess, and return a solution with the same
+      structure as initial_guess such that f(solution) = 0. In other words,
+      the following is assumed to be true (but not checked)::
+
+        solution = solve(f, initial_guess)
+        error = f(solution)
+        assert all(error == 0)
+
+    tangent_solve: function to solve the tangent system. Should take two
+      positional arguments, a linear function ``g`` (the function ``f``
+      linearized at its root) and a tree of array(s) ``y`` with the same
+      structure as initial_guess, and return a solution ``x`` such that
+      ``g(x)=y``:
+
+      - For scalar ``y``, use ``lambda g, y: y / g(1.0)``.
+      - For vector ``y``, you could use a linear solve with the Jacobian, if
+        dimensionality of ``y`` is not too large:
+        ``lambda g, y: np.linalg.solve(jacobian(g)(y), y)``.
+
+  Returns:
+    The result of calling solve(f, initial_guess) with gradients defined via
+    implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
+  """
+  guess_flat, in_args_tree = tree_flatten((initial_guess,))
+  guess_avals = tuple(_map(_abstractify, guess_flat))
+  jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_args_tree, guess_avals)
+  in_tree, = treedef_children(in_args_tree)
+  if in_tree != out_tree:
+    raise TypeError(
+        "f() output pytree structure must match initial_guess, got {} and {}."
+        .format(out_tree, in_tree)
+    )
+  out_flat = root_p.bind(*itertools.chain(consts, guess_flat),
+                         tree=out_tree, num_consts=len(consts),
+                         jaxpr=jaxpr, solve=solve, tangent_solve=tangent_solve)
+  return tree_unflatten(out_tree, out_flat)
+
+
+def _root_abstract_eval(*args, **kwargs):
+  return args[kwargs['num_consts']:]
+
+
+def _root_impl(*args, **kwargs):
+  tree, num_consts, jaxpr, solve, _ = split_dict(
+      kwargs, ['tree', 'num_consts', 'jaxpr', 'solve', 'tangent_solve'])
+
+  f = partial(
+      apply_flat_fun_nokwargs,
+      partial(core.jaxpr_as_fun(jaxpr), *args[:num_consts]),
+      (tree, tree),
+  )
+  initial_guess = tree_unflatten(tree, args[num_consts:])
+  out = solve(f, initial_guess)
+
+  out_flat, out_tree = tree_flatten(out)
+  if out_tree != tree:
+    raise TypeError(
+        "solve() output pytree structure must match initial_guess, got {} and {}"
+        .format(out_tree, tree))
+
+  return out_flat
+
+
+def _root_jvp(
+    primals, tangents, tree, num_consts, jaxpr, solve, tangent_solve):
+  params = primals[:num_consts]
+  solution = tuple(
+      root_p.bind(*primals, tree=tree, num_consts=num_consts,
+                  jaxpr=jaxpr, solve=solve, tangent_solve=tangent_solve)
+  )
+
+  params_dot = tangents[:num_consts]
+
+  # F(u(m), m) = 0      # system of equations in m
+  # ∂_0 F(u(m), m) ∂ u(m) + ∂_1 F(u(m), m) = 0
+  # ∂ u(m) = - (∂_0 F(u*, m))^{-1} ∂_1 F(u*, m)
+  unchecked_zeros, f_jvp = api.linearize(
+      core.jaxpr_as_fun(jaxpr), *(params + solution)
+  )
+
+  params_zeros = tuple(_map(ad_util.zeros_like_jaxval, params))
+  solution_zeros = tuple(_map(ad_util.zeros_like_jaxval, solution))
+
+  f_linearized_at_solution = partial(
+      apply_flat_fun_nokwargs, partial(f_jvp, *params_zeros), (tree, tree),
+  )
+  rhs = tree_unflatten(tree, f_jvp(*(params_dot + solution_zeros)))
+  solution_dot = tree_map(
+      operator.neg, tangent_solve(f_linearized_at_solution, rhs)
+  )
+
+  solution_dot_flat, out_tree = tree_flatten(solution_dot)
+  if out_tree != tree:
+    raise TypeError(
+        "tangent_solve() output pytree structure must match initial_guess, "
+        "got {} and {}".format(out_tree, tree))
+
+  return solution, solution_dot_flat
+
+def _root_batch(args, dims, **params):
+  return batching.batch_fun(lu.wrap_init(_root_impl, params), args, dims)
+
+
+root_p = core.Primitive('root')
+root_p.multiple_results = True
+root_p.def_impl(_root_impl)
+root_p.def_abstract_eval(_root_abstract_eval)
+ad.primitive_jvps[root_p] = _root_jvp
+xla.initial_style_translations[root_p] = xla.lower_fun(_root_impl, initial_style=True)
+batching.primitive_batchers[root_p] = _root_batch
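The JVP rule in `_root_jvp` implements the implicit function theorem sketched in the inline comment. Writing the root condition as F(u(m), m) = 0, where m stands for the closed-over constants and u(m) for the solution, differentiating with respect to m gives (in LaTeX form):

    \partial_0 F(u(m), m)\, \partial u(m) + \partial_1 F(u(m), m) = 0
    \quad\Longrightarrow\quad
    \partial u(m) = -\bigl(\partial_0 F(u^*, m)\bigr)^{-1}\, \partial_1 F(u^*, m)

In the code, `api.linearize` yields `f_jvp`, from which both factors are built: `f_linearized_at_solution` applies ∂_0 F (this is the linear function ``g`` handed to `tangent_solve`, which is responsible for inverting it), `f_jvp(*(params_dot + solution_zeros))` evaluates the right-hand side ∂_1 F · ṁ, and `tree_map(operator.neg, ...)` supplies the leading minus sign.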
"tangent_solve() output pytree structure must match initial_guess, " + "got {} and {}".format(out_tree, tree)) + + return solution, solution_dot_flat + +def _root_batch(args, dims, **params): + return batching.batch_fun(lu.wrap_init(_root_impl, params), args, dims) + + +root_p = core.Primitive('root') +root_p.multiple_results = True +root_p.def_impl(_root_impl) +root_p.def_abstract_eval(_root_abstract_eval) +ad.primitive_jvps[root_p] = _root_jvp +xla.initial_style_translations[root_p] = xla.lower_fun(_root_impl, initial_style=True) +batching.primitive_batchers[root_p] = _root_batch diff --git a/tests/api_test.py b/tests/api_test.py index f5a1d6cc71b5..aee0c346d911 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -920,41 +920,6 @@ def f(x): xla_comp = api.xla_computation(f) xla_comp(np.arange(8)).GetHloText() # doesn't crash - def test_custom_implicit_solve(self): - - def scalar_solve(f, y): - return y / f(1.0) - - def _binary_search(func, params, low=0.0, high=100.0, tolerance=1e-6): - def cond(state): - low, high = state - return high - low > tolerance - - def body(state): - low, high = state - midpoint = 0.5 * (low + high) - update_upper = func(midpoint, params) > 0 - low = np.where(update_upper, low, midpoint) - high = np.where(update_upper, midpoint, high) - return (low, high) - - solution, _ = lax.while_loop(cond, body, (low, high)) - return solution - - binary_search = api._custom_implicit_solve(_binary_search, scalar_solve) - sqrt_cubed = lambda y, x: y ** 2 - x ** 3 - value, grad = api.value_and_grad(binary_search, argnums=1)(sqrt_cubed, 5.0) - self.assertAllClose(value, 5 ** 1.5, check_dtypes=False) - self.assertAllClose(grad, api.grad(pow)(5.0, 1.5), check_dtypes=False) - - def scalar_solve2(f, y): - y_1d = y[np.newaxis] - return np.linalg.solve(api.jacobian(f)(y_1d), y_1d).squeeze() - - binary_search = api._custom_implicit_solve(_binary_search, scalar_solve2) - grad = api.grad(binary_search, argnums=1)(sqrt_cubed, 5.0) - self.assertAllClose(grad, api.grad(pow)(5.0, 1.5), check_dtypes=False) - def test_jit_device(self): device = xb.devices()[-1] x = api.jit(lambda x: x, device=device)(3.) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 20b68317bdc4..10baa9725ef2 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -19,6 +19,7 @@ import collections from functools import partial import itertools +import re from unittest import SkipTest from absl.testing import absltest @@ -992,6 +993,79 @@ def fun(carry, _): api.grad(lambda x: jit_run_scan(x))(0.) 
diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py
index 20b68317bdc4..10baa9725ef2 100644
--- a/tests/lax_control_flow_test.py
+++ b/tests/lax_control_flow_test.py
@@ -19,6 +19,7 @@
 import collections
 from functools import partial
 import itertools
+import re
 from unittest import SkipTest
 
 from absl.testing import absltest
@@ -992,6 +993,79 @@ def fun(carry, _):
     api.grad(lambda x: jit_run_scan(x))(0.)  # doesn't crash
+
+  def test_root_scalar(self):
+
+    def scalar_solve(f, y):
+      return y / f(1.0)
+
+    def binary_search(func, x0, low=0.0, high=100.0, tolerance=1e-6):
+      del x0  # unused
+
+      def cond(state):
+        low, high = state
+        return high - low > tolerance
+
+      def body(state):
+        low, high = state
+        midpoint = 0.5 * (low + high)
+        update_upper = func(midpoint) > 0
+        low = np.where(update_upper, low, midpoint)
+        high = np.where(update_upper, midpoint, high)
+        return (low, high)
+
+      solution, _ = lax.while_loop(cond, body, (low, high))
+      return solution
+
+    def sqrt_cubed(x, tangent_solve=scalar_solve):
+      f = lambda y: y ** 2 - x ** 3
+      return lax.root(f, 0.0, binary_search, tangent_solve)
+
+    value, grad = api.value_and_grad(sqrt_cubed)(5.0)
+    self.assertAllClose(value, 5 ** 1.5, check_dtypes=False)
+    self.assertAllClose(grad, api.grad(pow)(5.0, 1.5), check_dtypes=False)
+
+    jtu.check_grads(sqrt_cubed, (5.0,), order=2, rtol=1e-3)
+
+    inputs = np.array([4.0, 5.0])
+    results = api.vmap(sqrt_cubed)(inputs)
+    self.assertAllClose(results, inputs ** 1.5, check_dtypes=False)
+
+    results = api.jit(sqrt_cubed)(5.0)
+    self.assertAllClose(results, 5.0 ** 1.5, check_dtypes=False)
+
+  def test_root_vector(self):
+
+    def oracle(func, x0):
+      del func  # unused
+      return x0
+
+    def vector_solve(f, y):
+      return np.linalg.solve(api.jacobian(f)(y), y)
+
+    def linear_solve(a, b):
+      f = lambda y: np.dot(a, y) - b
+      x0 = np.linalg.solve(a, b)
+      return lax.root(f, x0, oracle, vector_solve)
+
+    rng = onp.random.RandomState(0)
+    a = rng.randn(2, 2)
+    b = rng.randn(2)
+    jtu.check_grads(linear_solve, (a, b), order=2)
+
+  def test_root_errors(self):
+    with self.assertRaisesRegex(TypeError, re.escape("f() output pytree")):
+      lax.root(lambda x: (x, x), 0.0, lambda f, x: x, lambda f, x: x)
+    with self.assertRaisesRegex(TypeError, re.escape("solve() output pytree")):
+      lax.root(lambda x: x, 0.0, lambda f, x: (x, x), lambda f, x: x)
+
+    def dummy_root_usage(x):
+      f = lambda y: x - y
+      return lax.root(f, 0.0, lambda f, x: x, lambda f, x: (x, x))
+
+    with self.assertRaisesRegex(
+        TypeError, re.escape("tangent_solve() output pytree")):
+      api.jvp(dummy_root_usage, (0.0,), (0.0,))
+
 
 if __name__ == '__main__':
   absltest.main()
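One design point worth noting from these tests: `lax.while_loop` is not reverse-mode differentiable in JAX, so the bisection search in `test_root_scalar` could not be differentiated through directly. Because `root()` defines gradients by implicit differentiation at the solution, the forward solver is free to use non-differentiable machinery (a while_loop, or an external "oracle" as in `test_root_vector`), and `grad`, `jvp`, `vmap`, and `jit` all still compose with the result.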