-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
No byproduct by default #57
Conversation
The examples here need to be updated. |
Codecov ReportPatch coverage:
Additional details and impacted files@@ Coverage Diff @@
## main #57 +/- ##
==========================================
- Coverage 93.90% 84.31% -9.59%
==========================================
Files 5 5
Lines 82 102 +20
==========================================
+ Hits 77 86 +9
- Misses 5 16 +11
☔ View full report in Codecov by Sentry. |
@@ -9,6 +9,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" | |||
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" | |||
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" | |||
Requires = "ae029012-a4dd-5104-9daa-d747884805df" | |||
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
More modern successor to UnPack.jl, useful to destructure for Julia <1.7
@@ -10,19 +10,19 @@ In the future, we would like to add [Enzyme.jl](https://github.com/EnzymeAD/Enzy | |||
## Higher-dimensional arrays | |||
|
|||
For simplicity, our examples only display functions that eat and spit out vectors. | |||
However, arbitrary array shapes are supported, as long as the forward _and_ conditions callables return similar arrays. | |||
However, arbitrary array shapes are supported, as long as the forward mapping _and_ conditions return similar arrays. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried to unify the terminology as "forward mapping" + "conditions"
Importantly, this forward pass _doesn't need to be differentiable_. | ||
First we define a forward mapping correponding to the function we consider. | ||
It returns the actual output $y(x)$ of the function, and can be thought of as a black box solver. | ||
Importantly, this Julia callable _doesn't need to be differentiable by automatic differentiation packages but the underlying function still needs to be mathematically differentiable_. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I stressed the difference between the Julia callable and the mathematical function
implicit2 = ImplicitFunction(forward, conditions, linear_solver) | ||
manual_linear_solver(A, b) = (Matrix(A) \ b, (solved=true,)) | ||
|
||
implicit_higher_order = ImplicitFunction( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I disambiguated the name
|
||
""" | ||
rrule(rc, implicit, x[; kwargs...]) | ||
rrule(rc, implicit, x; kwargs...) | ||
rrule(rc, implicit, x, Val(return_byproduct); kwargs...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this way of showing the call signature is clearer for users who don't know what a Val
is
``` | ||
This requires solving a linear system `A * J = -B`, where `A ∈ ℝᵈˣᵈ`, `B ∈ ℝᵈˣⁿ` and `J ∈ ℝᵈˣⁿ`. | ||
|
||
# Fields |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried to keep coherent notations throughout:
forward / conditions
are our wrappersf
/c
are the callables provided by the user
src/implicit_function.jl
Outdated
ImplicitFunction(f, c, Val(handle_byproduct); linear_solver=gmres) | ||
|
||
Construct an `ImplicitFunction` by wrapping a forward mapping `f` into a field of type `Forward` and conditions `c` into a field of type `Conditions`, taking into account the value of `handle_byproduct` (which defaults to `false`). | ||
The default linear solver is `Krylov.gmres`, but this can be changed with a keyword argument. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I moved the linear solver to a keyword argument in order to solve the method ambiguity
conditions::CC | ||
linear_solver::LS | ||
|
||
function ImplicitFunction( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is probably better as an inner constructor since it enforces the invariant we want. That way users cannot screw up the handle_byproduct
shenanigans
|
||
The first (default) call signature only returns `y(x)`, while the second returns `(y(x), z(x))`. | ||
The argument `return_byproduct` is independent from the type parameter `handle_byproduct` in `ImplicitFunction`, so any combination is possible. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was confused by this at first, so it's probably best to specify the nuance
@testset verbose = true "Forward" begin | ||
@test ForwardDiff.jacobian(implicit, x) ≈ J | ||
x_and_dx = ForwardDiff.Dual.(x, ((0, 0),)) | ||
for return_byproduct in (true, false) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You added this for reverse, I did the same for forward
Tests on nightly seem to fail because of JET. Maybe it's not worth having it as a test dependency? We can always run it locally before each commit, it's in my base env anyway |
Up to you. Maybe just allow breakage on nightly? |
I tried to do that in the CI file but for some reason it doesn't work. See here: https://discourse.julialang.org/t/ignoring-nightly-failure-for-ci-badge/98028 |
On second thought, adding examples with dummy byproducts is more confusing than anything. I think we should keep the first few examples dead simple, and add one where we use byproducts (eg. for the Jacobian), just like there is one for component arrays. Could you maybe draft it? |
Sure let's do it in another PR. |
This PR fixes #54 by implementing the approach suggested there to support both the cases with and without a byproduct from the
forward
function.