Custom derivative for np.linalg.det (#2809)
* Add vjp and jvp rules for jnp.linalg.det

* Add tests for new determinant gradients

* Replace index_update with concatenate in cofactor_solve

This avoids issues with index_update not having a transpose rule, removing one obstacle to automatically converting the JVP into a VJP (the np.where still needs to be dealt with).

* Changes to cofactor_solve so it can be transposed

This allows a single JVP rule to give both forward and backward derivatives (a minimal sketch of this pattern follows below).

* Update det grad tests

All tests pass now; however, second derivatives still do not work for singular matrices.

* Add explanation to docstring for _cofactor_solve

* Fixed comment
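
The point about a single JVP rule giving both derivative modes relies on JAX's `custom_jvp` machinery: reverse mode is obtained by linearizing the custom JVP rule and transposing its linear part, which requires the tangent computation to use only primitives with transpose rules (hence replacing `index_update` and the concern about `np.where`). Below is a minimal sketch of that pattern, assuming only the public `jax.custom_jvp` API; the function `scaled_sum` is purely illustrative and is not part of this PR.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def scaled_sum(x):
  return 2.0 * jnp.sum(x)

@scaled_sum.defjvp
def scaled_sum_jvp(primals, tangents):
  x, = primals
  t, = tangents
  # The tangent is expressed only with transposable primitives (multiply, sum),
  # so JAX can linearize this rule and transpose it to obtain reverse mode too.
  return scaled_sum(x), 2.0 * jnp.sum(t)

x = jnp.arange(3.0)
print(jax.jvp(scaled_sum, (x,), (jnp.ones_like(x),)))  # forward-mode derivative
print(jax.grad(scaled_sum)(x))                         # reverse mode via transposition
```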
dpfau committed Apr 25, 2020
1 parent 4e020cc commit 02b3fc5
Showing 2 changed files with 135 additions and 0 deletions.
100 changes: 100 additions & 0 deletions jax/numpy/linalg.py
@@ -149,12 +149,112 @@ def _slogdet_jvp(primals, tangents):
  return (sign, ans), (sign_dot, ans_dot)


def _cofactor_solve(a, b):
  """Equivalent to det(a)*solve(a, b) for nonsingular matrices.

  Intermediate function used for the jvp and vjp of det.
  This function borrows heavily from jax.numpy.linalg.solve and
  jax.numpy.linalg.slogdet to compute the gradient of the determinant
  in a way that is well defined even for low rank matrices.

  This function handles two different cases:
  * rank(a) == n or n-1
  * rank(a) < n-1

  For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
  Rather than computing det(a)*solve(a, b), which would return NaN, we work
  directly with the LU decomposition. If a = p @ l @ u, then

    det(a)*solve(a, b) =
    prod(diag(u)) * u^-1 @ l^-1 @ p^-1 @ b =
    prod(diag(u)) * triangular_solve(u, solve(p @ l, b))

  If a is rank n-1, then the lower right corner of u will be zero and the
  triangular_solve will fail.
  Let x = solve(p @ l, b) and y = det(a)*solve(a, b). Then

    y_{nn} =
    x_{nn} / u_{nn} * prod_{i=1...n}(u_{ii}) =
    x_{nn} * prod_{i=1...n-1}(u_{ii})

  So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
  we can avoid the triangular_solve failing.
  To correctly compute the rest of x_{ii} for i != n, we simply multiply
  x_{ii} by det(a) for all i != n, which will be zero if rank(a) == n-1.

  For the second case, the matrix is checked to see whether `solve` returns
  NaN or Inf, and a matrix of zeros is returned instead, since the gradient
  of the determinant of a matrix with rank less than n-1 is 0.
  This will still return the correct value for rank n-1 matrices, as the check
  is applied *after* the lower right corner of u has been updated.

  Args:
    a: A square matrix or batch of matrices, possibly singular.
    b: A matrix, or batch of matrices, of the same dimension as a.

  Returns:
    det(a) and cofactor(a)^T*b, aka adjugate(a)*b.
  """
  a = _promote_arg_dtypes(np.asarray(a))
  b = _promote_arg_dtypes(np.asarray(b))
  a_shape = np.shape(a)
  b_shape = np.shape(b)
  a_ndims = len(a_shape)
  b_ndims = len(b_shape)
  if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
          and b_shape[-2:] == a_shape[-2:]):
    msg = ("The arguments to _cofactor_solve must have shapes "
           "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
    raise ValueError(msg.format(a_shape, b_shape))
  if a_shape[-1] == 1:
    return a[0, 0], b
  # lu contains u in the upper triangular matrix and l in the strict lower
  # triangular matrix.
  # The diagonal of l is set to ones without loss of generality.
  lu, pivots = lax_linalg.lu(a)
  dtype = lax.dtype(a)
  batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
  x = np.broadcast_to(b, batch_dims + b.shape[-2:])
  lu = np.broadcast_to(lu, batch_dims + lu.shape[-2:])
  # Compute (partial) determinant, ignoring last diagonal of LU
  diag = np.diagonal(lu, axis1=-2, axis2=-1)
  parity = np.count_nonzero(pivots != np.arange(a_shape[-1]), axis=-1)
  sign = np.array(-2 * (parity % 2) + 1, dtype=dtype)
  # partial_det[:, -1] contains the full determinant and
  # partial_det[:, -2] contains det(u) / u_{nn}.
  partial_det = np.cumprod(diag, axis=-1) * sign[..., None]
  lu = ops.index_update(lu, ops.index[..., -1, -1], 1.0 / partial_det[..., -2])
  permutation = lax_linalg.lu_pivots_to_permutation(pivots, a_shape[-1])
  permutation = np.broadcast_to(permutation, batch_dims + (a_shape[-1],))
  iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims + (1,)))
  # filter out any matrices that are not full rank
  d = np.ones(x.shape[:-1], x.dtype)
  d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
  d = np.any(np.logical_or(np.isnan(d), np.isinf(d)), axis=-1)
  d = np.tile(d[..., None, None], d.ndim*(1,) + x.shape[-2:])
  x = np.where(d, np.zeros_like(x), x)  # first filter
  x = x[iotas[:-1] + (permutation, slice(None))]
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
                                  unit_diagonal=True)
  x = np.concatenate((x[..., :-1, :] * partial_det[..., -1, None, None],
                      x[..., -1:, :]), axis=-2)
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
  x = np.where(d, np.zeros_like(x), x)  # second filter

  return partial_det[..., -1], x
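
Not part of the diff, but the identity stated in the docstring can be sanity-checked for a nonsingular matrix with public `jax.numpy` calls, since adjugate(a) equals det(a) * inv(a) whenever a is invertible. A small illustrative sketch under that assumption:

```python
import jax.numpy as jnp

a = jnp.array([[2.0, 1.0],
               [1.0, 3.0]])   # nonsingular
b = jnp.eye(2)

det_a = jnp.linalg.det(a)
adjugate_a = det_a * jnp.linalg.inv(a)  # valid only because a is invertible

# With the helper above one would expect, roughly:
#   d, x = _cofactor_solve(a, b)
#   d should match det_a, and x should match adjugate_a @ b
print(det_a)
print(adjugate_a @ b)
```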


@custom_jvp
@_wraps(onp.linalg.det)
def det(a):
  sign, logdet = slogdet(a)
  return sign * np.exp(logdet)


@det.defjvp
def _det_jvp(primals, tangents):
  x, = primals
  g, = tangents
  y, z = _cofactor_solve(x, g)
  return y, np.trace(z, axis1=-1, axis2=-2)
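
The tangent returned above is Jacobi's formula, d det(a)[g] = tr(adjugate(a) @ g), so the reverse-mode gradient of `det` is the transposed cofactor matrix: det(a) * inv(a).T for nonsingular a, and a finite (possibly nonzero) matrix even when a is singular. An illustrative usage sketch, not taken from the diff:

```python
import jax
import jax.numpy as jnp

a = jnp.array([[3.0, 1.0],
               [2.0, 4.0]])                      # nonsingular
grad_det = jax.grad(jnp.linalg.det)(a)
print(jnp.allclose(grad_det, jnp.linalg.det(a) * jnp.linalg.inv(a).T))

s = jnp.array([[1.0, 2.0],
               [2.0, 4.0]])                      # rank 1, det == 0
print(jax.grad(jnp.linalg.det)(s))               # should stay finite with the cofactor-based rule
```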


@_wraps(onp.linalg.eig)
def eig(a):
  a = _promote_arg_dtypes(np.asarray(a))
35 changes: 35 additions & 0 deletions tests/linalg_test.py
@@ -109,6 +109,41 @@ def testDet(self, n, dtype, rng_factory):
  def testDetOfSingularMatrix(self):
    x = np.array([[-1., 3./2], [2./3, -1.]], dtype=onp.float32)
    self.assertAllClose(onp.float32(0), jsp.linalg.det(x), check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory}
      for shape in [(1, 1), (3, 3), (2, 4, 4)]
      for dtype in float_types
      for rng_factory in [jtu.rand_default]))
  @jtu.skip_on_devices("tpu")
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def testDetGrad(self, shape, dtype, rng_factory):
    rng = rng_factory()
    _skip_if_unsupported_type(dtype)
    a = rng(shape, dtype)
    jtu.check_grads(np.linalg.det, (a,), 2, atol=1e-1, rtol=1e-1)
    # make sure there are no NaNs when a matrix is zero
    if len(shape) == 2:
      jtu.check_grads(
          np.linalg.det, (np.zeros_like(a),), 1, atol=1e-1, rtol=1e-1)
    else:
      a[0] = 0
      jtu.check_grads(np.linalg.det, (a,), 1, atol=1e-1, rtol=1e-1)

  def testDetGradOfSingularMatrix(self):
    # Rank 2 matrix with nonzero gradient
    a = np.array([[ 50, -30,  45],
                  [-30,  90, -81],
                  [ 45, -81,  81]], dtype=np.float32)
    # Rank 1 matrix with zero gradient
    b = np.array([[ 36, -42,  18],
                  [-42,  49, -21],
                  [ 18, -21,   9]], dtype=np.float32)
    jtu.check_grads(np.linalg.det, (a,), 1, atol=1e-1, rtol=1e-1)
    jtu.check_grads(np.linalg.det, (b,), 1, atol=1e-1, rtol=1e-1)
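
The two fixtures differ in rank: `a` has rank 2 (rank n-1 for n = 3), so its cofactor matrix, and hence the determinant gradient, is nonzero, while `b` has rank 1 (rank < n-1), so both are zero. A small NumPy sketch, separate from the test file, that makes this concrete:

```python
import numpy as onp

a = onp.array([[ 50, -30,  45],
               [-30,  90, -81],
               [ 45, -81,  81]], dtype=onp.float32)
b = onp.array([[ 36, -42,  18],
               [-42,  49, -21],
               [ 18, -21,   9]], dtype=onp.float32)

print(onp.linalg.matrix_rank(a), onp.linalg.matrix_rank(b))  # expected: 2 1

def adjugate_3x3(m):
  """Transposed cofactor matrix via explicit 2x2 minors (sanity check only)."""
  c = onp.empty_like(m)
  for i in range(3):
    for j in range(3):
      minor = onp.delete(onp.delete(m, i, axis=0), j, axis=1)
      c[i, j] = (-1) ** (i + j) * onp.linalg.det(minor)
  return c.T

print(adjugate_3x3(a))  # nonzero -> nonzero gradient of det at a
print(adjugate_3x3(b))  # all zeros -> zero gradient of det at b
```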

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
