
Split reverse mode for Tapir #115

@gdalle

We've discussed this on Slack with @willtebbutt, but I wanted to clarify where we stand on split reverse mode, i.e. separating the forward sweep from the reverse sweep.
The idea is to be able to perform multiple reverse sweeps with different seeds after a single forward sweep. A typical example is computing a Jacobian, where there is one seed per basis vector of the output space.
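For concreteness, here is a toy version of that use case in plain Julia, without Tapir: a hand-written split rule for `f(x) = [x[1]^2, x[1] * x[2]]` whose pullback closes over an untouched copy of `x`, so it can be called once per seed. `f_split` and `jacobian_from_pullback` are hypothetical names used only for illustration.

```julia
# Hand-written split reverse mode for f(x) = [x[1]^2, x[1] * x[2]] (illustrative only).
function f_split(x::Vector{Float64})
    y = [x[1]^2, x[1] * x[2]]   # forward sweep, run once
    xs = copy(x)                # values the reverse sweep needs; never mutated afterwards
    # Pullback computes J' * dy with J = [2x1 0; x2 x1]
    pullback(dy) = [2 * xs[1] * dy[1] + xs[2] * dy[2], xs[1] * dy[2]]
    return y, pullback
end

# One reverse sweep per basis vector e_i of the output space gives row i of J.
function jacobian_from_pullback(pullback, m::Int)
    rows = map(1:m) do i
        e = zeros(m)
        e[i] = 1.0
        permutedims(pullback(e))  # 1×n row: e_i' * J
    end
    return reduce(vcat, rows)
end

y, pb = f_split([1.0, 2.0])
J = jacobian_from_pullback(pb, length(y))  # [2.0 0.0; 2.0 1.0]
```

Each pullback call is independent only because nothing it captures is mutated, which is the property I am after with Tapir.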

My question is the following: is that currently possible with Tapir's rrule? IIUC, the answer is no for functions that mutate their argument (JuliaDiff/DifferentiationInterface.jl#142), but what about simple allocating functions?

I took inspiration from https://github.com/withbayes/Tapir.jl/blob/f5e2b90cd17fd3127dd0fd8dfa617bc112275626/src/interface.jl#L9-L15 to write a function that I called `value_and_pullback_split` in DifferentiationInterface:

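```julia
function value_and_pullback_split(f, x)
    rule = build_rrule(f, x)
    # Zero tangents for f and x, passed to the rule alongside the primals
    tf = zero_tangent(f)
    tx = zero_tangent(x)
    # Forward sweep, run once: returns the output CoDual and the pullback pb!!
    out, pb!! = rule(CoDual(f, tf), CoDual(x, tx))
    y = copy(primal(out))
    function pullback(dy)
        # Seed the output tangent with dy (increment!! may update tangent(out) in place)
        dy_righttype = convert(tangent_type(typeof(y)), copy(dy))
        ty = increment!!(tangent(out), dy_righttype)
        # Reverse sweep: gradients are accumulated into the tangents tf and tx
        new_df, new_dx = pb!!(ty, tf, tx)
        return new_dx
    end
    return y, pullback
end
```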

But the behavior of the resulting closure changes from one call to the next.
For some functions, repeated calls with the same seed give different results:

```julia
julia> y, pullback = value_and_pullback_split(copy, [1.0])
([1.0], pullback)

julia> pullback([1.0])
1-element Vector{Float64}:
 1.0

julia> pullback([1.0])
1-element Vector{Float64}:
 3.0

julia> pullback([1.0])
1-element Vector{Float64}:
 6.0
```
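If I read the numbers right, they are consistent with two buffers being incremented in place on every call: `increment!!` accumulating `dy` into `tangent(out)` (so the seed seen by `pb!!` is 1, 2, 3) and `pb!!` accumulating that seed into `tx` (giving 1, 3, 6). The toy model below is a guess at that mechanism, not Tapir's actual code, and `simulate_copy_pullbacks` is a hypothetical name. Zeroing both buffers before each call recovers a constant answer.

```julia
# Toy model (a guess, not Tapir's code) of the two buffers the closure shares across calls:
# `out_tangent` stands in for tangent(out) and `tx` for the tangent of x.
function simulate_copy_pullbacks(ncalls::Int; reset::Bool=false)
    out_tangent = [0.0]
    tx = [0.0]
    results = Float64[]
    for _ in 1:ncalls
        if reset
            fill!(out_tangent, 0.0)  # fresh output seed buffer
            fill!(tx, 0.0)           # fresh input gradient buffer
        end
        dy = [1.0]
        out_tangent .+= dy   # like increment!!: the seed is accumulated in place
        tx .+= out_tangent   # reverse sweep of copy: dx += dy, also in place
        push!(results, tx[1])
    end
    return results
end

simulate_copy_pullbacks(3)               # [1.0, 3.0, 6.0]
simulate_copy_pullbacks(3; reset=true)   # [1.0, 1.0, 1.0]
```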

For others, the second call errors outright:

```julia
julia> y, pullback = value_and_pullback_split(x -> x .^ 2, [1.0])
([1.0], pullback)

julia> pullback([1.0])
1-element Vector{Float64}:
 2.0

julia> pullback([1.0])
ERROR: BoundsError: attempt to access 1-element Vector{Vector{Float64}} at index [0]
Stacktrace:
      internal @ Unknown
 [4] Pullback
   @ ~/.julia/packages/Tapir/BqxEi/src/interpreter/s2s_reverse_mode_ad.jl:632 [inlined]
 [5] (::var"#pullback#10"{Vector{}, Tapir.Pullback{}, CoDual{}, Vector{}, NoTangent})(dy::Vector{Float64})
   @ Main ./REPL[16]:10
Use `err` to retrieve the full stack trace.
Some type information was truncated. Use `show(err)` to see complete types.
```
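My guess for the second failure is that the forward sweep stores intermediate values on a stack-like tape that the reverse sweep consumes, so a second reverse sweep ends up reading position 0. The hypothetical `Tape`, `record!`, `retrieve!` and `square_split` below are a toy model of that mechanism, not Tapir's internals, but they reproduce the same kind of `BoundsError`.

```julia
# Toy stack-like tape (hypothetical, not Tapir's internals).
mutable struct Tape
    data::Vector{Vector{Float64}}
    pos::Int
end
Tape() = Tape(Vector{Vector{Float64}}(), 0)

# Forward sweep: store a value the reverse sweep will need.
function record!(t::Tape, v::Vector{Float64})
    t.pos += 1
    if t.pos > length(t.data)
        push!(t.data, v)
    else
        t.data[t.pos] = v
    end
    return nothing
end

# Reverse sweep: read the latest value and move the position back.
function retrieve!(t::Tape)
    v = t.data[t.pos]  # once the tape is consumed, pos == 0 and this throws BoundsError
    t.pos -= 1
    return v
end

# Split reverse mode for x -> x .^ 2 on top of the tape.
function square_split(x::Vector{Float64})
    tape = Tape()
    record!(tape, copy(x))
    y = x .^ 2
    pullback(dy) = 2 .* retrieve!(tape) .* dy
    return y, pullback
end

y, pb = square_split([1.0])
pb([1.0])  # [2.0]
# Calling pb([1.0]) again reads index 0 and throws
# BoundsError: attempt to access 1-element Vector{Vector{Float64}} at index [0]
```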

What should I copy to make successive pullback calls independent of each other? Probably `out`, `tf` and `tx`?
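One conservative fallback, assuming correctness matters more than speed here, is to re-run the forward sweep for every seed: each reverse sweep then starts from fresh buffers and a fresh tape, at the cost of losing the benefit of split mode. Here is that pattern on a toy single-use rule; `single_use_split` and `reusable_split` are hypothetical names, not DifferentiationInterface or Tapir functions.

```julia
# Hypothetical single-use split rule for x -> x .^ 2: the pullback accumulates
# into a captured buffer and refuses to run twice.
function single_use_split(x::Vector{Float64})
    tx = zero(x)          # gradient buffer, accumulated in place
    consumed = Ref(false)
    function pullback(dy)
        consumed[] && error("pullback already consumed")
        consumed[] = true
        tx .+= 2 .* x .* dy
        return tx
    end
    return x .^ 2, pullback
end

# Reusable wrapper: every call re-runs the forward sweep, so each reverse sweep
# gets its own buffers. Correct, but it costs one forward sweep per seed.
function reusable_split(build, x)
    y, _ = build(x)
    pullback(dy) = last(build(x))(dy)
    return y, pullback
end

y, pb = reusable_split(single_use_split, [1.0])
pb([1.0])  # [2.0]
pb([1.0])  # [2.0] again
```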
