Code inside function with rrule should not run #57

Closed
gaurav-arya opened this issue May 10, 2022 · 2 comments · Fixed by #93

gaurav-arya commented May 10, 2022

MWE:

using Zygote
import ChainRulesCore.rrule
using ChainRulesCore: NoTangent
using AbstractDifferentiation

function myfunc(x)
    println("This should not print if I have an rrule.")
    x
end

rrule(::typeof(myfunc), x) = (x, (y -> (NoTangent(), y)))

println("Zygote run:")
Zygote.gradient(myfunc, 1) # nothing prints
println("AD run:")
AD.derivative(AD.ZygoteBackend(), myfunc, 1) # something prints

The code inside of myfunc should never run. In addition to possible inefficiency, this may lead to incorrect/confusing results for stateful functions or stochastic functions.
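To illustrate the stochastic case, here is a minimal sketch that uses only ChainRulesCore. The function `noisy_scale` and its rule are hypothetical and are not part of the MWE above. The value returned by the `rrule` and its pullback share one random draw, so they stay consistent with each other; a second call to the original function draws again and no longer matches the pullback.

```julia
using ChainRulesCore

# Hypothetical stochastic function: scales x by a fresh random draw.
noisy_scale(x) = randn() * x

# Hand-written rule: the primal value and the pullback share the same draw z.
function ChainRulesCore.rrule(::typeof(noisy_scale), x)
    z = randn()
    pullback(dy) = (NoTangent(), z * dy)
    return z * x, pullback
end

x = 2.0
y, pb = rrule(noisy_scale, x)
_, dx = pb(1.0)

# The value from the rrule is consistent with its own pullback.
println(y == dx * x)

# Re-running the function draws a new z, so this value (almost surely)
# differs from y and does not correspond to the pullback above.
y_again = noisy_scale(x)
println(y_again == y)
```

A backend that takes the pullback from the `rrule` but recomputes the value by calling the function would pair `y_again` with `pb`, which is the kind of mismatch described above.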

@devmotion
Member

My hypothesis (without checking it in detail) is that the backend for the ChainRulesCore integration (also used for Zygote in your example) uses the rrule to define the pullback function (`AD.@primitive function pullback_function(ab::ReverseRuleConfigBackend, f, xs...)`) but drops the primal value completely. The primal therefore has to be recovered from the original function in `value_and_pullback_function` (`function value_and_pullback_function(`), which is used, e.g., by `derivative` (`function derivative(ab::AbstractBackend, f, xs::Number...)`) via `jacobian` (`function AbstractDifferentiation.jacobian($(args...),)`).

I think one "just" has to define value_and_pullback_function in terms of only rrule for these backends. This should avoid the unnecessary call of the original function.
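A rough sketch of that idea, using only ChainRulesCore: the helper name `value_and_pullback_via_rrule` and the call counter are assumptions made for illustration. This is not the package's actual implementation, and the fix that eventually landed may look different. The point is that a single `rrule` call yields both the value and the pullback, so the body of `myfunc` never runs.

```julia
using ChainRulesCore

# Counter that records how often the body of myfunc actually executes.
const CALLS = Ref(0)

function myfunc(x)
    CALLS[] += 1
    return x
end

ChainRulesCore.rrule(::typeof(myfunc), x) = (x, dy -> (NoTangent(), dy))

# Hypothetical helper: take both the primal value and the pullback from
# one rrule call instead of calling f a second time for the value.
function value_and_pullback_via_rrule(f, xs...)
    result = rrule(f, xs...)
    result === nothing && error("no rrule defined for this call")
    value, pb = result
    # Drop the tangent with respect to f itself; keep only the input tangents.
    pullback(dy) = Base.tail(pb(dy))
    return value, pullback
end

value, pullback = value_and_pullback_via_rrule(myfunc, 1.0)
grads = pullback(1.0)

println(value)     # primal value taken from the rrule
println(grads)     # tangents for the inputs only
println(CALLS[])   # number of times the body of myfunc ran
```

The sketch ignores rule configs and functions without a hand-written rule; a real backend would have to fall back to the underlying AD system in those cases.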

@devmotion
Member

Just noticed, there's already #34, #35, #36.
