Description
We've discussed this on Slack with @willtebbutt, but I wanted to confirm where we stand on split reverse mode, i.e. separating the forward sweep from the reverse sweep.
The idea is to be able to perform multiple reverse sweeps with different seeds after a single forward sweep. A typical example is computing a Jacobian, where there is one seed per basis vector of the output space.
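To make the pattern concrete, here is a minimal sketch that does not use Tapir at all. The forward sweep runs once and returns a pullback closure, and the Jacobian is then assembled by calling that closure once per output basis vector. The pullback of `x -> x .^ 2` is written by hand, and the names `split_square` and `jacobian_by_rows` are hypothetical; this illustrates split reverse mode in general, not Tapir's API.

```julia
# Hypothetical example: forward sweep for f(x) = x .^ 2, returning the value and a
# pullback closure. The closure only reads the captured primal, so it can be
# called repeatedly with different seeds.
function split_square(x::AbstractVector{<:Real})
    y = x .^ 2                      # forward sweep, executed once
    xs = copy(x)                    # primal values needed by the reverse sweep
    pullback(dy) = 2 .* xs .* dy    # vector-Jacobian product J' * dy
    return y, pullback
end

# Build the Jacobian row by row: row i is the pullback of e_i,
# the i-th basis vector of the output space.
function jacobian_by_rows(split_f, x)
    y, pullback = split_f(x)        # one forward sweep
    J = zeros(eltype(y), length(y), length(x))
    for i in eachindex(y)
        e_i = zeros(eltype(y), length(y))
        e_i[i] = one(eltype(y))
        J[i, :] = pullback(e_i)     # one reverse sweep per seed
    end
    return J
end

J = jacobian_by_rows(split_square, [1.0, 2.0, 3.0])
```

Because the closure holds no mutable tangent state, every call is independent of the previous ones, which is exactly the property asked about below.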
My question is the following: is that currently possible with Tapir's rrule? IIUC, the answer is no for functions that mutate their argument (JuliaDiff/DifferentiationInterface.jl#142), but what about simple allocating functions?
I took inspiration from https://github.com/withbayes/Tapir.jl/blob/f5e2b90cd17fd3127dd0fd8dfa617bc112275626/src/interface.jl#L9-L15 to try to write what I called `value_and_pullback_split` in DifferentiationInterface:
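```julia
using Tapir: build_rrule, zero_tangent, CoDual, primal, tangent, tangent_type, increment!!

function value_and_pullback_split(f, x)
    rule = build_rrule(f, x)
    # zero tangents for the function and its argument
    tf = zero_tangent(f)
    tx = zero_tangent(x)
    # forward sweep, run only once
    out, pb!! = rule(CoDual(f, tf), CoDual(x, tx))
    y = copy(primal(out))
    function pullback(dy)
        # convert the seed to the tangent type of y
        dy_righttype = convert(tangent_type(typeof(y)), copy(dy))
        # seed the output tangent with dy
        ty = increment!!(tangent(out), dy_righttype)
        # reverse sweep, reusing the same pb!!, tf and tx on every call
        new_df, new_dx = pb!!(ty, tf, tx)
        return new_dx
    end
    return y, pullback
end
```

However, the resulting closure behaves differently on each call.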
For some functions, repeated calls return different results:
```julia
julia> y, pullback = value_and_pullback_split(copy, [1.0])
([1.0], pullback)

julia> pullback([1.0])
1-element Vector{Float64}:
 1.0

julia> pullback([1.0])
1-element Vector{Float64}:
 3.0

julia> pullback([1.0])
1-element Vector{Float64}:
 6.0
```

For others it downright errors:
```julia
julia> y, pullback = value_and_pullback_split(x -> x .^ 2, [1.0])
([1.0], pullback)

julia> pullback([1.0])
1-element Vector{Float64}:
 2.0

julia> pullback([1.0])
ERROR: BoundsError: attempt to access 1-element Vector{Vector{Float64}} at index [0]
Stacktrace:
  ⋮ internal @ Unknown
 [4] Pullback
   @ ~/.julia/packages/Tapir/BqxEi/src/interpreter/s2s_reverse_mode_ad.jl:632 [inlined]
 [5] (::var"#pullback#10"{Vector{…}, Tapir.Pullback{…}, CoDual{…}, Vector{…}, NoTangent})(dy::Vector{Float64})
   @ Main ./REPL[16]:10
Use `err` to retrieve the full stack trace.
Some type information was truncated. Use `show(err)` to see complete types.
```

What should I copy so that each pullback call is independent of the previous ones? Probably `out`, `tf` and `tx`?
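For what it's worth, the 1.0, 3.0, 6.0 sequence in the `copy` example is consistent with a toy model in which both the output tangent and `tx` are incremented in place and never reset between calls. The sketch below encodes that assumption; `ToyState` and `toy_copy_pullback!` are hypothetical names, and this is a model of the observed numbers, not Tapir's actual implementation. It also shows that starting each call from freshly zeroed buffers makes the results repeatable.

```julia
# Hypothetical toy model (not Tapir's implementation): the output tangent `ty` and
# the input tangent `tx` are shared buffers that each pullback call increments in place.
struct ToyState
    ty::Vector{Float64}   # tangent of the output, shared across calls
    tx::Vector{Float64}   # tangent of the input, shared across calls
end

# Reverse sweep for y = copy(x): seed ty with dy, then accumulate ty into tx.
function toy_copy_pullback!(s::ToyState, dy::Vector{Float64})
    s.ty .+= dy      # seeding adds to whatever the output tangent already holds
    s.tx .+= s.ty    # the reverse pass of copy accumulates into tx
    return copy(s.tx)
end

# Shared buffers: results accumulate across calls.
shared = ToyState([0.0], [0.0])
accumulated = [toy_copy_pullback!(shared, [1.0])[1] for _ in 1:3]  # 1.0, 3.0, 6.0

# Fresh zeroed buffers for every call: each call gives the same answer.
fresh = [toy_copy_pullback!(ToyState([0.0], [0.0]), [1.0])[1] for _ in 1:3]  # 1.0, 1.0, 1.0
```

Under this model, the second call sees an output tangent of 2.0 rather than 1.0, and `tx` already holds 1.0 from the first call, which gives 3.0. The `x .^ 2` error suggests there is additional internal state, such as a stack of saved values, that is consumed by the first reverse sweep, and this toy model does not capture that.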