Custom derivative for np.linalg.det #2809

Merged: 7 commits, merged Apr 25, 2020
100 changes: 100 additions & 0 deletions jax/numpy/linalg.py
@@ -148,12 +148,112 @@ def _slogdet_jvp(primals, tangents):
  return (sign, ans), (sign_dot, ans_dot)


def _cofactor_solve(a, b):
  """Equivalent to det(a)*solve(a, b) for nonsingular matrices.

  Intermediate function used for jvp and vjp of det.
  This function borrows heavily from jax.numpy.linalg.solve and
  jax.numpy.linalg.slogdet to compute the gradient of the determinant
  in a way that is well defined even for low rank matrices.

  This function handles two different cases:
  * rank(a) == n or n-1
  * rank(a) < n-1

  For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
  Rather than computing det(a)*solve(a, b), which would return NaN, we work
  directly with the LU decomposition. If a = p @ l @ u, then
  det(a)*solve(a, b) =
    prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b =
    prod(diag(u)) * triangular_solve(u, solve(p @ l, b))
  If a is rank n-1, then the lower right corner of u will be zero and the
  triangular_solve will fail.
  Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
  Then y_{nn} =
    x_{nn} / u_{nn} * prod_{i=1...n}(u_{ii}) =
    x_{nn} * prod_{i=1...n-1}(u_{ii})
  So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
  we can avoid the triangular_solve failing.
  To correctly compute the rest of x_{ii} for i != n, we simply multiply
  x_{ii} by det(a) for all i != n, which will be zero if rank(a) = n-1.

  For the second case, a check is done on the matrix to see if `solve`
  returns NaN or Inf, and gives a matrix of zeros as a result, as the
  gradient of the determinant of a matrix with rank less than n-1 is 0.
  This will still return the correct value for rank n-1 matrices, as the check
  is applied *after* the lower right corner of u has been updated.

  Args:
    a: A square matrix or batch of matrices, possibly singular.
    b: A matrix, or batch of matrices of the same dimension as a.

  Returns:
    det(a) and cofactor(a)^T*b, aka adjugate(a)*b
Member: do you have a reference on the method you use here for calculating the adjugate?

Contributor Author: Nope. Derived it myself. I could add a short description to the docstring?

Contributor Author: Done.
"""
a = _promote_arg_dtypes(np.asarray(a))
b = _promote_arg_dtypes(np.asarray(b))
a_shape = np.shape(a)
b_shape = np.shape(b)
a_ndims = len(a_shape)
b_ndims = len(b_shape)
if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
and b_shape[-2:] == a_shape[-2:]):
msg = ("The arguments to _cofactor_solve must have shapes "
"a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
raise ValueError(msg.format(a_shape, b_shape))
if a_shape[-1] == 1:
return a[0, 0], b
# lu contains u in the upper triangular matrix and l in the strict lower
# triangular matrix.
# The diagonal of l is set to ones without loss of generality.
lu, pivots = lax_linalg.lu(a)
dtype = lax.dtype(a)
batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
x = np.broadcast_to(b, batch_dims + b.shape[-2:])
lu = np.broadcast_to(lu, batch_dims + lu.shape[-2:])
Comment on lines +211 to +213

Member: You probably already have this working, but I'll note that you could probably simplify this considerably if you make use of jnp.vectorize() to add batch dimensions with vmap instead of writing them by hand.

Contributor Author: This is taken pretty much line-by-line from an older version of jnp.linalg.solve. Is that what is used there now?

Contributor Author: Actually, you'd previously mentioned some changes to solve related to custom gradients; it may be worth taking a look at the new version to see if we can port some of those over. This also still doesn't work with gradients-of-gradients, so perhaps some of the tricks from custom_solve could be useful there.

Contributor Author: From a quick look, it seems pretty complicated! I'll keep this in mind for the future, but as you said, this is working now.

Member: Right, we're using vectorize inside solve now, in particular for the _lu_solve_core helper function:

@partial(vectorize, excluded={3}, signature='(n,n),(n),(n,k)->(n,k)')

I'll have to think a little bit more about the custom_linear_solve thing...

Contributor Author: I'm not familiar with jnp.vectorize. Is it really just a matter of wrapping the function call with that decorator and removing the 3 lines you highlighted?

Member: That's roughly my understanding. We can iterate on that in follow-up PRs though!
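As an illustrative aside (not from the PR or the thread above), jnp.vectorize with a gufunc-style signature lets a helper be written for a single matrix and then mapped automatically over any leading batch dimensions; matvec below is a hypothetical example function, not a JAX API:

import functools
import jax.numpy as jnp

# Hypothetical helper, written only for a single (n, n) matrix and (n,) vector;
# jnp.vectorize broadcasts it over any leading batch dimensions.
@functools.partial(jnp.vectorize, signature='(n,n),(n)->(n)')
def matvec(a, x):
  return a @ x

batched_a = jnp.ones((5, 3, 3))
batched_x = jnp.ones((5, 3))
print(matvec(batched_a, batched_x).shape)  # (5, 3)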

  # Compute (partial) determinant, ignoring last diagonal of LU
  diag = np.diagonal(lu, axis1=-2, axis2=-1)
  parity = np.count_nonzero(pivots != np.arange(a_shape[-1]), axis=-1)
  sign = np.array(-2 * (parity % 2) + 1, dtype=dtype)
  # partial_det[:, -1] contains the full determinant and
  # partial_det[:, -2] contains det(u) / u_{nn}.
  partial_det = np.cumprod(diag, axis=-1) * sign[..., None]
  lu = ops.index_update(lu, ops.index[..., -1, -1], 1.0 / partial_det[..., -2])
  permutation = lax_linalg.lu_pivots_to_permutation(pivots, a_shape[-1])
  permutation = np.broadcast_to(permutation, batch_dims + (a_shape[-1],))
  iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims + (1,)))
  # filter out any matrices that are not full rank
  d = np.ones(x.shape[:-1], x.dtype)
  d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
  d = np.any(np.logical_or(np.isnan(d), np.isinf(d)), axis=-1)
  d = np.tile(d[..., None, None], d.ndim*(1,) + x.shape[-2:])
  x = np.where(d, np.zeros_like(x), x)  # first filter
  x = x[iotas[:-1] + (permutation, slice(None))]
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
                                  unit_diagonal=True)
  x = np.concatenate((x[..., :-1, :] * partial_det[..., -1, None, None],
                      x[..., -1:, :]), axis=-2)
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
  x = np.where(d, np.zeros_like(x), x)  # second filter

  return partial_det[..., -1], x
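As a quick numerical check of the identity the docstring relies on, det(a)*solve(a, b) == prod(diag(u)) * triangular_solve(u, solve(p @ l, b)) up to the sign of the permutation (which the code above folds into partial_det via sign), here is a small NumPy/SciPy sketch that is not part of the PR and assumes a nonsingular a:

import numpy as onp
import scipy.linalg

rng = onp.random.RandomState(0)
a = rng.randn(4, 4)  # nonsingular with probability 1
b = rng.randn(4, 4)

p, l, u = scipy.linalg.lu(a)        # a == p @ l @ u, with l unit lower triangular
sign = onp.sign(onp.linalg.det(p))  # parity of the permutation, +1 or -1
lhs = onp.linalg.det(a) * onp.linalg.solve(a, b)
x = onp.linalg.solve(p @ l, b)
rhs = sign * onp.prod(onp.diag(u)) * scipy.linalg.solve_triangular(u, x, lower=False)
assert onp.allclose(lhs, rhs)

For a rank n-1 matrix this naive computation produces NaN, which is exactly what the replacement of u's lower-right corner in the function above avoids.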


@custom_jvp
@_wraps(onp.linalg.det)
def det(a):
  sign, logdet = slogdet(a)
  return sign * np.exp(logdet)


@det.defjvp
def _det_jvp(primals, tangents):
  x, = primals
  g, = tangents
  y, z = _cofactor_solve(x, g)
  return y, np.trace(z, axis1=-1, axis2=-2)
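The JVP above implements Jacobi's formula, d det(A) = tr(adjugate(A) dA), with the adjugate-times-tangent product supplied by _cofactor_solve. A minimal sketch (not part of the diff, using current-style jnp imports) checking the resulting gradient against the nonsingular-case form det(A) * inv(A)^T:

import jax
import jax.numpy as jnp
import numpy as onp

a = jnp.asarray(onp.random.RandomState(0).randn(3, 3))

# For nonsingular A, adjugate(A) = det(A) * inv(A), so
# grad(det)(A) = adjugate(A)^T = det(A) * inv(A)^T.
g = jax.grad(jnp.linalg.det)(a)
expected = jnp.linalg.det(a) * jnp.linalg.inv(a).T
assert jnp.allclose(g, expected, atol=1e-4)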


@_wraps(onp.linalg.eig)
def eig(a):
  a = _promote_arg_dtypes(np.asarray(a))
35 changes: 35 additions & 0 deletions tests/linalg_test.py
@@ -109,6 +109,41 @@ def testDet(self, n, dtype, rng_factory):
  def testDetOfSingularMatrix(self):
    x = np.array([[-1., 3./2], [2./3, -1.]], dtype=onp.float32)
    self.assertAllClose(onp.float32(0), jsp.linalg.det(x), check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory}
      for shape in [(1, 1), (3, 3), (2, 4, 4)]
      for dtype in float_types
      for rng_factory in [jtu.rand_default]))
  @jtu.skip_on_devices("tpu")
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def testDetGrad(self, shape, dtype, rng_factory):
    rng = rng_factory()
    _skip_if_unsupported_type(dtype)
    a = rng(shape, dtype)
    jtu.check_grads(np.linalg.det, (a,), 2, atol=1e-1, rtol=1e-1)
    # make sure there are no NaNs when a matrix is zero
    if len(shape) == 2:
      jtu.check_grads(
          np.linalg.det, (np.zeros_like(a),), 1, atol=1e-1, rtol=1e-1)
    else:
      a[0] = 0
      jtu.check_grads(np.linalg.det, (a,), 1, atol=1e-1, rtol=1e-1)

  def testDetGradOfSingularMatrix(self):
    # Rank 2 matrix with nonzero gradient
    a = np.array([[ 50, -30,  45],
                  [-30,  90, -81],
                  [ 45, -81,  81]], dtype=np.float32)
    # Rank 1 matrix with zero gradient
    b = np.array([[ 36, -42,  18],
                  [-42,  49, -21],
                  [ 18, -21,   9]], dtype=np.float32)
    jtu.check_grads(np.linalg.det, (a,), 1, atol=1e-1, rtol=1e-1)
    jtu.check_grads(np.linalg.det, (b,), 1, atol=1e-1, rtol=1e-1)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":