Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Custom derivative for np.linalg.det (#2809)
* Add vjp and jvp rules for jnp.linalg.det * Add tests for new determinant gradients * Replace index_update with concatenate in cofactor_solve This avoids issues with index_update not having a transpose rule, removing one bug in the way of automatically converting the JVP into a VJP (still need to deal with the np.where). * Changes to cofactor_solve so it can be transposed This allows a single JVP rule to give both forward and backward derivatives * Update det grad tests All tests pass now - however second derivatives still do not work for nonsingular matrices. * Add explanation to docstring for _cofactor_solve * Fixed comment
- Loading branch information
Showing
2 changed files
with
135 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters