No byproduct by default #57

Merged
merged 10 commits into from
Jun 1, 2023

Conversation

mohamed82008
Collaborator

This PR fixes #54 by implementing the approach suggested there to support both the cases with and without a byproduct from the forward function.

@mohamed82008
Collaborator Author

The examples here need to be updated.

@codecov

codecov bot commented May 27, 2023

Codecov Report

Patch coverage: 76.92% and project coverage change: -9.59 ⚠️

Comparison is base (ed66786) 93.90% compared to head (40738a1) 84.31%.

Additional details and impacted files
@@            Coverage Diff             @@
##             main      #57      +/-   ##
==========================================
- Coverage   93.90%   84.31%   -9.59%     
==========================================
  Files           5        5              
  Lines          82      102      +20     
==========================================
+ Hits           77       86       +9     
- Misses          5       16      +11     
| Impacted Files | Coverage Δ | |
|---|---|---|
| src/ImplicitDifferentiation.jl | 25.00% <ø> (-50.00%) | ⬇️ |
| src/implicit_function.jl | 65.38% <64.00%> (-34.62%) | ⬇️ |
| ext/ImplicitDifferentiationChainRulesExt.jl | 100.00% <100.00%> (ø) | |
| ext/ImplicitDifferentiationForwardDiffExt.jl | 100.00% <100.00%> (ø) | |


@@ -9,6 +9,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
Member

A more modern successor to UnPack.jl, useful for destructuring on Julia < 1.7.
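
For context, Julia 1.7 introduced built-in property destructuring, which older versions lack. A minimal sketch of the two styles follows, using a hypothetical named tuple `result` that is not taken from the package. The macro form mentioned in the comment assumes the `@unpack a, b = x` syntax familiar from UnPack.jl.

```julia
# Hypothetical result bundling an output `y` and a byproduct `z`.
result = (y=[1.0, 2.0], z=(iterations=3,))

# Julia >= 1.7: built-in property destructuring.
(; y, z) = result

# Any Julia version: what an unpacking macro such as `@unpack y, z = result`
# is assumed to boil down to.
y_old = getproperty(result, :y)
z_old = getproperty(result, :z)
```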

@@ -10,19 +10,19 @@ In the future, we would like to add [Enzyme.jl](https://github.com/EnzymeAD/Enzy
## Higher-dimensional arrays

For simplicity, our examples only display functions that eat and spit out vectors.
-However, arbitrary array shapes are supported, as long as the forward _and_ conditions callables return similar arrays.
+However, arbitrary array shapes are supported, as long as the forward mapping _and_ conditions return similar arrays.
Member

I tried to unify the terminology as "forward mapping" + "conditions"
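
To illustrate the "similar arrays" requirement, here is a small sketch with a toy forward mapping and toy conditions that act elementwise on a matrix. The names `forward` and `conditions` follow the terminology above, but the functions themselves are made up for illustration and do not use the package.

```julia
# Toy forward mapping and conditions acting elementwise on any array shape.
forward(x::AbstractArray) = sqrt.(x)
conditions(x::AbstractArray, y::AbstractArray) = y .^ 2 .- x

x = rand(2, 3) .+ 1.0   # a matrix input instead of a vector
y = forward(x)          # output keeps the matrix shape
c = conditions(x, y)    # conditions have the same shape as the output
```

Both callables return a `2 × 3` array of `Float64`, which is what "similar" is taken to mean here.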

Importantly, this forward pass _doesn't need to be differentiable_.
First we define a forward mapping corresponding to the function we consider.
It returns the actual output $y(x)$ of the function, and can be thought of as a black box solver.
Importantly, this Julia callable _doesn't need to be differentiable by automatic differentiation packages, but the underlying function still needs to be mathematically differentiable_.
Member

I stressed the difference between the Julia callable and the mathematical function
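
A sketch of that distinction, under the assumption that a hand-written iterative solver stands in for the "black box": the callable below writes into a `Float64` buffer and loops until convergence, which dual-number-based AD cannot pass through, yet the function it computes, `y(x) = sqrt.(x)`, is smooth for `x > 0`. Both functions are illustrative and not taken from the package's examples.

```julia
# Babylonian (Heron) iteration for elementwise square roots, valid for x > 0.
function forward(x::AbstractVector{<:Real})
    y = ones(Float64, length(x))   # concrete buffer: cannot hold dual numbers
    for i in eachindex(x, y)
        # Iterate until the relative residual of y^2 = x is small.
        while abs(y[i]^2 - x[i]) > 1e-12 * x[i]
            y[i] = (y[i] + x[i] / y[i]) / 2
        end
    end
    return y
end

# Conditions c(x, y) = 0 that characterize the output of the solver.
conditions(x, y) = y .^ 2 .- x
```

The conditions, unlike the solver, are plain broadcasted arithmetic, so they are the part that an AD package can handle.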

implicit2 = ImplicitFunction(forward, conditions, linear_solver)
manual_linear_solver(A, b) = (Matrix(A) \ b, (solved=true,))

implicit_higher_order = ImplicitFunction(
Member

I disambiguated the name
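
The `manual_linear_solver` in the snippet returns a pair: the solution and a stats object with a `solved` field, mirroring what Krylov.jl solvers return. A self-contained sketch of that contract on a plain dense matrix follows. The name `dense_linear_solver` is hypothetical, and the exact interface the package expects is an assumption inferred from the snippet.

```julia
using LinearAlgebra

# Return the solution together with a stats object exposing `solved`.
# Assumption: `A` is anything `Matrix` can materialize; here a dense matrix.
function dense_linear_solver(A, b)
    x = Matrix(A) \ b
    stats = (solved=true,)
    return x, stats
end

A = [2.0 0.0; 0.0 4.0]
b = [1.0, 2.0]
x, stats = dense_linear_solver(A, b)
```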


"""
-rrule(rc, implicit, x[; kwargs...])
+rrule(rc, implicit, x; kwargs...)
+rrule(rc, implicit, x, Val(return_byproduct); kwargs...)
Member

I think this way of showing the call signature is clearer for users who don't know what a Val is
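
For readers who have not met `Val` before, a toy sketch of the same calling convention may help. `Val(true)` and `Val(false)` have different types, so each case gets its own method and its own fixed return type. The function `toy` is purely illustrative.

```julia
# The flag is lifted into the type domain, so dispatch picks the method.
toy(x) = toy(x, Val(false))                      # default: no byproduct
toy(x, ::Val{false}) = 2 .* x                    # returns y only
toy(x, ::Val{true}) = (2 .* x, (note="extra",))  # returns (y, z)

y = toy([1.0, 2.0])
y2, z = toy([1.0, 2.0], Val(true))
```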

```
This requires solving a linear system `A * J = -B`, where `A ∈ ℝᵈˣᵈ`, `B ∈ ℝᵈˣⁿ` and `J ∈ ℝᵈˣⁿ`.

# Fields
Member

I tried to keep coherent notations throughout:

  • forward / conditions are our wrappers
  • f / c are the callables provided by the user
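
The linear system `A * J = -B` from the docstring can be checked by hand on a small case. The sketch below assumes that `A` and `B` denote the partial Jacobians of the conditions with respect to `y` and `x` (the usual implicit function theorem setup), with `d = 2` outputs and `n = 3` inputs. The conditions are made up for illustration.

```julia
using LinearAlgebra

x = [1.0, 3.0, 12.0]
y = [sqrt(x[1] + x[2]), sqrt(x[2] * x[3])]   # solves c(x, y) = 0; here y = [2, 6]

# Hand-written partial Jacobians of c(x, y) = [y1^2 - x1 - x2, y2^2 - x2 * x3].
A = [2*y[1] 0.0; 0.0 2*y[2]]                 # ∂c/∂y, size d × d
B = [-1.0 -1.0 0.0; 0.0 -x[3] -x[2]]         # ∂c/∂x, size d × n

J = -(A \ B)                                 # solves A * J = -B, size d × n
```

The result agrees with differentiating the closed-form `y(x)` directly, for instance `∂y₂/∂x₃ = x₂ / (2y₂) = 0.25`.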

ImplicitFunction(f, c, Val(handle_byproduct); linear_solver=gmres)

Construct an `ImplicitFunction` by wrapping a forward mapping `f` into a field of type `Forward` and conditions `c` into a field of type `Conditions`, taking into account the value of `handle_byproduct` (which defaults to `false`).
The default linear solver is `Krylov.gmres`, but this can be changed with a keyword argument.
Member

I moved the linear solver to a keyword argument in order to solve the method ambiguity

conditions::CC
linear_solver::LS

function ImplicitFunction(
Member

This is probably better as an inner constructor since it enforces the invariant we want. That way users cannot screw up the handle_byproduct shenanigans
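
A rough sketch of what an inner constructor buys here, using a hypothetical `ToyImplicit` type rather than the package's actual definition: because no outer constructor can call `new`, every instance passes through the check on the flag, and the linear solver is a keyword argument so it cannot be confused with the positional `Val`.

```julia
default_solver(A, b) = A \ b   # placeholder default, not Krylov.gmres

struct ToyImplicit{handle_byproduct,F,C,LS}
    forward::F
    conditions::C
    linear_solver::LS

    # Only this constructor can build the struct, so the invariant always holds.
    function ToyImplicit(
        f::F, c::C, ::Val{handle_byproduct}=Val(false); linear_solver=default_solver
    ) where {F,C,handle_byproduct}
        handle_byproduct isa Bool || throw(ArgumentError("handle_byproduct must be a Bool"))
        LS = typeof(linear_solver)
        return new{handle_byproduct,F,C,LS}(f, c, linear_solver)
    end
end

plain = ToyImplicit(sqrt, (x, y) -> y^2 - x)
with_flag = ToyImplicit(sqrt, (x, y) -> y^2 - x, Val(true); linear_solver=(A, b) -> A \ b)
```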


The first (default) call signature only returns `y(x)`, while the second returns `(y(x), z(x))`.
The argument `return_byproduct` is independent from the type parameter `handle_byproduct` in `ImplicitFunction`, so any combination is possible.
Member

I was confused by this at first, so it's probably best to specify the nuance
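
To make the nuance concrete, here is a miniature of the two independent flags. `ToyForward` and `y_and_z` are hypothetical, and padding the missing byproduct with `0` is an assumption of this sketch. The type parameter describes what the user's callable returns, while the `Val` passed at call time describes what the caller wants back.

```julia
struct ToyForward{handle_byproduct,F}
    f::F
    ToyForward(f::F, ::Val{hb}=Val(false)) where {F,hb} = new{hb,F}(f)
end

# Normalize the user's callable so a (y, z) pair is always available internally.
y_and_z(tf::ToyForward{true}, x) = tf.f(x)        # f already returns (y, z)
y_and_z(tf::ToyForward{false}, x) = (tf.f(x), 0)  # f returns y; pad with a dummy z

# The caller picks what comes out, whatever the callable returned.
(tf::ToyForward)(x) = tf(x, Val(false))
(tf::ToyForward)(x, ::Val{false}) = first(y_and_z(tf, x))
(tf::ToyForward)(x, ::Val{true}) = y_and_z(tf, x)

plain = ToyForward(x -> 2x)
rich = ToyForward(x -> (2x, "solver log"), Val(true))
```

All four combinations are valid: `plain(3)`, `plain(3, Val(true))`, `rich(3)` and `rich(3, Val(true))`.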

@testset verbose = true "Forward" begin
@test ForwardDiff.jacobian(implicit, x) ≈ J
x_and_dx = ForwardDiff.Dual.(x, ((0, 0),))
for return_byproduct in (true, false)
Member

You added this for reverse, I did the same for forward

@gdalle
Member

gdalle commented Jun 1, 2023

Tests on nightly seem to fail because of JET. Maybe it's not worth having it as a test dependency? We can always run it locally before each commit; it's in my base env anyway.

@mohamed82008
Collaborator Author

Up to you. Maybe just allow breakage on nightly?

@gdalle
Member

gdalle commented Jun 1, 2023

> Up to you. Maybe just allow breakage on nightly?

I tried to do that in the CI file but for some reason it doesn't work. See here: https://discourse.julialang.org/t/ignoring-nightly-failure-for-ci-badge/98028

@gdalle
Member

gdalle commented Jun 1, 2023

On second thought, adding examples with dummy byproducts is more confusing than anything. I think we should keep the first few examples dead simple, and add one where we use byproducts (e.g. for the Jacobian), just like there is one for component arrays. Could you maybe draft it?

@mohamed82008
Copy link
Collaborator Author

Sure, let's do it in another PR.

@gdalle gdalle merged commit 70dab9e into main Jun 1, 2023
@gdalle gdalle deleted the mt/no_byproduct_by_default branch June 2, 2023 08:03
Development

Successfully merging this pull request may close these issues.

make additional information z optional
2 participants