Tapir.value_and_pullback!! and mutation #113

Closed
gdalle opened this issue Apr 3, 2024 · 8 comments

Comments

@gdalle

gdalle commented Apr 3, 2024

As discovered in gdalle/DifferentiationInterface.jl#126, Tapir seems to struggle with mutating functions z = f!(y, x) (where z is ignored). I extracted the MWE from DifferentiationInterface.jl.

Setup:

using LinearAlgebra
using ForwardDiff: ForwardDiff
using Tapir: CoDual, NoTangent, build_rrule, value_and_pullback!!, tangent_type, zero_codual

function forwarddiff_value_and_pullback!!(f!; y, dx, x, dy)
    J = ForwardDiff.jacobian(f!, y, x)
    return y, transpose(J) * dy
end

function tapir_value_and_pullback!!(f!; y, dx, x, dy, verbose=false)
    rrule = build_rrule(f!, y, x)
    dy_righttype = convert(tangent_type(typeof(y)), dy)
    dx_righttype = convert(tangent_type(typeof(x)), dx)
    dx_righttype .= zero(eltype(dx_righttype))
    verbose && @info "before: primals" x y
    verbose && @info "before: tangents" dx_righttype dy_righttype
    new_z, (new_df!, new_dy, new_dx) = value_and_pullback!!(
        rrule,
        NoTangent(),
        zero_codual(f!),
        CoDual(y, dy_righttype),
        CoDual(x, dx_righttype),
    )
    verbose && @info "after: primals" x y
    verbose && @info "after: tangents" dx_righttype dy_righttype
    verbose && @info "after: Tapir output" new_z new_dx new_dy
    return y, new_dx
end

function f!(y, x)
    y .= x .^ 2
    return 0  # to get around the `copy(nothing)` issue for now
end

Results:

julia> forwarddiff_value_and_pullback!!(f!; y  = zeros(2), dx = zeros(2), x  = float.(1:2), dy = float.(3:4))
([1.0, 4.0], [6.0, 16.0])

julia> tapir_value_and_pullback!!(f!; y  = zeros(2), dx = zeros(2), x  = float.(1:2), dy = float.(3:4))
([0.0, 0.0], [0.0, 0.0])

There might be something wrong in the way I used Tapir though.
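For reference, the ForwardDiff result can be checked by hand. A short derivation, assuming only the definition of f! above (y .= x .^ 2) and the inputs x = (1, 2), dy = (3, 4):

$$
y = f(x) = \begin{pmatrix} x_1^2 \\ x_2^2 \end{pmatrix}, \qquad
J = \frac{\partial y}{\partial x} = \begin{pmatrix} 2x_1 & 0 \\ 0 & 2x_2 \end{pmatrix}, \qquad
J^{\top} \mathrm{d}y = \begin{pmatrix} 2 \cdot 1 \cdot 3 \\ 2 \cdot 2 \cdot 4 \end{pmatrix} = \begin{pmatrix} 6 \\ 16 \end{pmatrix}
$$

So ([1.0, 4.0], [6.0, 16.0]) is the expected answer, and the all-zero Tapir output is the one that needs explaining.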

willtebbutt changed the title from "Tapir and mutation" to "Tapir.value_and_pullback!! and mutation" on Apr 3, 2024
@willtebbutt
Member

willtebbutt commented Apr 3, 2024

Eurgh, yeah, this looks like another interface issue. This is what I get for publishing an interface without working through a proper range of examples with it.

If you tweak your above example to make use of the rule directly, something like the following:

using Tapir: Tapir  # brings the module name into scope, needed for `Tapir.tangent` below

function tapir_rrule!!(f!; y, dx, x, dy, verbose=false)
    rrule = build_rrule(f!, y, x)
    dy_righttype = convert(tangent_type(typeof(y)), dy)
    dx_righttype = convert(tangent_type(typeof(x)), dx)
    dx_righttype .= zero(eltype(dx_righttype))
    @show dx_righttype === dx
    verbose && @info "before: primals" x y
    verbose && @info "before: tangents" dx_righttype dy_righttype
    # new_y, (new_df!, new_dy, new_dx) = value_and_pullback!!(
    #     rrule,
    #     NoTangent(),
    #     zero_codual(f!),
    #     CoDual(y, dy_righttype),
    #     CoDual(x, dx_righttype),
    # )

    new_y, pb!! = rrule(
        zero_codual(f!),
        CoDual(y, dy_righttype),
        CoDual(x, dx_righttype),
    )

    display(dy_righttype)
    display(dx_righttype)

    _, new_dy, new_dx =  pb!!(Tapir.tangent(new_y), NoTangent(), dy_righttype, dx_righttype)

    verbose && @info "after: primals" x y
    verbose && @info "after: tangents" dx_righttype dy_righttype
    verbose && @info "after: Tapir output" new_y new_dx new_dy
    return y, new_dx
end

you'll get the following:

julia> tapir_rrule!!(f!; y  = zeros(2), dx = zeros(2), x  = float.(1:2), dy = float.(3:4), verbose=true)
dx_righttype === dx = true
┌ Info: before: primals
│   x =
│    2-element Vector{Float64}:
│     1.0
│     2.0
│   y =
│    2-element Vector{Float64}:
│     0.0
│     0.0
┌ Info: before: tangents
│   dx_righttype =
│    2-element Vector{Float64}:
│     0.0
│     0.0
│   dy_righttype =
│    2-element Vector{Float64}:
│     3.0
│     4.0
2-element Vector{Float64}:
 0.0
 0.0
2-element Vector{Float64}:
 0.0
 0.0
┌ Info: after: primals
│   x =
│    2-element Vector{Float64}:
│     1.0
│     2.0
│   y =
│    2-element Vector{Float64}:
│     0.0
│     0.0
┌ Info: after: tangents
│   dx_righttype =
│    2-element Vector{Float64}:
│     0.0
│     0.0
│   dy_righttype =
│    2-element Vector{Float64}:
│     3.0
│     4.0
┌ Info: after: Tapir output
│   new_y = CoDual{Int64, NoTangent}(0, NoTangent())
│   new_dx =
│    2-element Vector{Float64}:
│     0.0
│     0.0
│   new_dy =
│    2-element Vector{Float64}:
│     3.0
│     4.0
([0.0, 0.0], [0.0, 0.0])

So, by the end of both the forwards- and reverse-passes, y has been reverted to its initial value. That is actually the correct thing for AD to be doing, but it's not the correct thing for value_and_pullback!! to be doing.

Similarly, dy is [0.0, 0.0] in the middle of execution (after the forwards-pass), which is also intended behaviour, and it then gets reverted to its initial value, [3.0, 4.0], at the end of execution, which is, again, intended behaviour.

Conversely, if we do the following:

using Tapir: Tapir, zero_tangent  # neither name is imported in the setup above

function tapir_rrule!!_correct(f!; y, dx, x, dy, verbose=false)
    rrule = build_rrule(f!, y, x)
    yc = copy(y)
    xc = copy(x)
    dy_righttype = zero_tangent(y)
    dx_righttype = zero_tangent(x)

    verbose && @info "before: primals" x y
    verbose && @info "before: tangents" dx_righttype dy_righttype

    new_y, pb!! = rrule(
        zero_codual(f!),
        CoDual(yc, dy_righttype),
        CoDual(xc, dx_righttype),
    )

    display(dy_righttype)
    display(dx_righttype)

    # Copy state into originals,
    y .= yc
    x .= xc

    # Load tangents into AD memory.
    dy_righttype .= dy
    dx_righttype .= dx

    _, new_dy, new_dx =  pb!!(Tapir.tangent(new_y), NoTangent(), dy_righttype, dx_righttype)

    verbose && @info "after: primals" x y
    verbose && @info "after: tangents" dx_righttype dy_righttype
    verbose && @info "after: Tapir output" new_y new_dx new_dy
    return y, new_dx
end

we see

julia> tapir_rrule!!_correct(f!; y  = zeros(2), dx = zeros(2), x  = float.(1:2), dy = float.(3:4), verbose=true)
┌ Info: before: primals
│   x =
│    2-element Vector{Float64}:
│     1.0
│     2.0
│   y =
│    2-element Vector{Float64}:
│     0.0
│     0.0
┌ Info: before: tangents
│   dx_righttype =
│    2-element Vector{Float64}:
│     0.0
│     0.0
│   dy_righttype =
│    2-element Vector{Float64}:
│     0.0
│     0.0
2-element Vector{Float64}:
 0.0
 0.0
2-element Vector{Float64}:
 0.0
 0.0
┌ Info: after: primals
│   x =
│    2-element Vector{Float64}:
│     1.0
│     2.0
│   y =
│    2-element Vector{Float64}:
│     1.0
│     4.0
┌ Info: after: tangents
│   dx_righttype =
│    2-element Vector{Float64}:
│     6.0
│     16.0
│   dy_righttype =
│    2-element Vector{Float64}:
│     0.0
│     0.0
┌ Info: after: Tapir output
│   new_y = CoDual{Int64, NoTangent}(0, NoTangent())
│   new_dx =
│    2-element Vector{Float64}:
│   new_dy =
│    2-element Vector{Float64}:
([1.0, 4.0], [6.0, 16.0])

The lesson for value_and_pullback!! is basically that I need to be taking more copies of things, which is a bit of a shame, but is definitely necessary.

I'll have a go at this today. Clearly there are lots of opportunities for subtle bugs in this interface, so I really appreciate you finding some of them @gdalle!

@gdalle
Author

gdalle commented Apr 3, 2024

Sorry, I got confused in the MWE with the function output itself; let's call it z. I corrected it above.

@gdalle
Author

gdalle commented Apr 3, 2024

Why is it the intended behavior for the AD to leave everything the way it was found?

@willtebbutt
Member

Why is it the intended behavior for the AD to leave everything the way it was found?

There are two things here.

Primals: the reason that the primals must be left the way they were found is that you need to be able to assume that, upon entering the reverse-pass of an rrule, everything is as you left it at the end of the forwards-pass. (If you weren't to make this assumption, you'd have to make copies of the bits of all of the inputs that you need on the reverse-pass in case they're modified later on, which would often be quite computationally demanding). You can guarantee that this assumption holds by requiring each rrule to undo any changes it makes to primal state. This property holds for derived rules by induction.

(Co)tangents: to see why this makes sense, note that x and y effectively represent two separate things mathematically, at different points in the execution of value_and_pullback!!. Prior to executing the forwards-pass, they represent "x0" and "y0". Once you've executed the forwards-pass they represent "x1" and "y1". After running the reverse-pass, they represent "x0" and "y0" again. The point here is that, because our variables have state, to get something you can make sense of mathematically you need to add a time-index to all of the variables that you are handling.

This raises the following question: are dy and dx the cotangents associated to "y0" / "x0", or to "y1" / "x1"? Tapir takes the view that they're associated to "y0" / "x0". So, if you:

  1. run the forwards-pass with non-zero tangents associated to y and x,
  2. increment the tangents associated to y and x at the end of the forwards-pass, and
  3. run the reverse-pass,

you should expect to see that the tangents to y and x have been incremented by some amount. In this case, they've been incremented by zero because the cotangents associated to "x0" and "y0" were both zero.

Does this help at all?
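To make the time-indexing concrete, here is a hand-written toy model of a stateful rule for y .= x .^ 2. It is only a sketch: `toy_rule!` is a hypothetical name, it uses plain Julia rather than Tapir's API, and it is written to mimic the behaviour seen in the logs above rather than to describe how Tapir is implemented.

```julia
# Toy model (NOT Tapir's implementation) of a reverse-mode rule for `y .= x .^ 2`.
# `toy_rule!` is a hypothetical helper written for this illustration only.
function toy_rule!(y::Vector{Float64}, dy::Vector{Float64},
                   x::Vector{Float64}, dx::Vector{Float64})
    y0 = copy(y)      # primal state to restore on the reverse-pass
    dy0 = copy(dy)    # cotangent of "y0", restored on the reverse-pass
    y .= x .^ 2       # forwards-pass: `y` now holds "y1"
    dy .= 0.0         # the cotangent of "y1" starts from zero
    function pullback!()
        dx .+= 2 .* x .* dy   # increment `dx` by J' * (cotangent of "y1")
        y .= y0               # undo the mutation of the primal
        dy .= dy0             # hand back the cotangent of the overwritten "y0"
        return dx
    end
    return pullback!
end

x, y = [1.0, 2.0], zeros(2)
dx, dy = zeros(2), [3.0, 4.0]

# Seed `dy` BEFORE the forwards-pass: the seed belongs to "y0", so `dx` gets nothing.
pb! = toy_rule!(y, dy, x, dx)
pb!()
naive = (copy(y), copy(dx), copy(dy))

# Seed `dy` AFTER the forwards-pass: the seed now belongs to "y1".
dy .= 0.0
pb! = toy_rule!(y, dy, x, dx)
y1 = copy(y)          # read "y1" before the reverse-pass restores `y`
dy .= [3.0, 4.0]
pb!()
seeded = (y1, copy(dx), copy(dy))
```

In this toy, `naive` comes out as ([0.0, 0.0], [0.0, 0.0], [3.0, 4.0]) and `seeded` as ([1.0, 4.0], [6.0, 16.0], [0.0, 0.0]), which lines up with the two logs earlier in the thread.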

@gdalle
Author

gdalle commented Apr 3, 2024

It helps a little bit, some parts are still obscure to me but it's okay.

My "mutating case" in DifferentiationInterface is still conceptually a function x -> y, even though we express it as f!(y, x). In other words, it modifies y but not x: such is the convention used by ForwardDiff and ReverseDiff, and it's generic enough to do lots of stuff.

So when I run DI.value_and_pullback!!(f!, y, dx, backend, x, dy), here is the behavior I would like to achieve:

  • the provided input x doesn't change: x1 = x0
  • the provided output y must be set to y1 (y0 can be discarded, it doesn't matter)
  • the provided tangent dx must be set to the VJP (\partial y1 / \partial x0)' dy1 (not incremented, but it's enough to zero it myself beforehand)
  • the provided cotangent dy corresponds to dy1, which is where we might have an issue?
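Written as equations, with index 0 for the state before the call and index 1 for the state after it, the bullets above amount to roughly the following. This is only a restatement of the list, where f denotes the function x -> y that f! implements:

$$
x_1 = x_0, \qquad
y_1 = f(x_0), \qquad
\mathrm{d}x \leftarrow \left( \frac{\partial y_1}{\partial x_0} \right)^{\top} \mathrm{d}y_1
$$

The open question in the last bullet is that the user-supplied dy is meant as the cotangent of y_1, not of y_0.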

@willtebbutt
Member

Ahhh cool. Yeah, we'll probably have to write that case specifically using the rrule interface directly and doing some copying, as I don't think I'll support this directly from Tapir. Happy to work with you on this!

@gdalle
Author

gdalle commented Apr 3, 2024

Cool! Could you maybe write me a prototype manual value_and_pullback!! based on what I described above, so that I can insert it in gdalle/DifferentiationInterface.jl#126? It will probably be quicker if you do it than if I go by trial and error with the splitting into forward pass and pullback, and the seeding of the tangents.

@willtebbutt
Member

willtebbutt commented Apr 4, 2024

Sure! Something along the lines of the following ought to do it:

using Tapir:
    build_rrule, tangent_type, set_to_zero!!, zero_codual, CoDual, tangent, increment!!,
    NoTangent

function tapir_value_and_pullback!!(f!; y, dx, x, dy, verbose=false)

    # Build the rule. Obviously, this should really be done separately and re-used, as
    # you've previously pointed out.
    rrule = build_rrule(f!, y, x)

    # I'm just going to assert that the tangents are of the correct type. Possibly this is
    # the best strategy generally, but maybe not?
    @assert tangent_type(typeof(y)) == typeof(dy)
    @assert tangent_type(typeof(x)) == typeof(dx)

    # Assert that f! has no differentiable parameters.
    @assert tangent_type(typeof(f!)) == NoTangent

    # We want the VJP, not VJP + dx, so I'm going to zero-out `dx`. `set_to_zero!!` has
    # the advantage that it will also reset any immutable components of `dx` to zero.
    dx = set_to_zero!!(dx)

    # Per the above discussion, we want `dy` to correspond to the cotangent of `y` _after_
    # running the forwards-pass, so I'm going to take a copy, and zero-out the original.
    dy_1 = copy(dy)
    dy = set_to_zero!!(dy)

    # Mutate a copy of `y`, so that we can run the reverse-pass later on.
    y_copy = copy(y)

    # Run the forwards-pass.
    out, pb!! = rrule(zero_codual(f!), CoDual(y_copy, dy), CoDual(x, dx))

    # Verify that the output is non-differentiable.
    @assert tangent(out) == NoTangent()

    # Set the cotangent of `y` to be equal to the requested value.
    dy = increment!!(dy, dy_1)

    # Copy the post-forwards-pass state of `y_copy` into `y` before the reverse-pass
    # reverts `y_copy` to its initial value.
    y = copy!(y, y_copy)

    # Run the reverse-pass.
    _, _, new_dx = pb!!(NoTangent(), NoTangent(), dy, dx)

    return y, new_dx
end

Let me know if this helps!

I think the main way that this could fall over is if copy is not defined for y or dy, or if copy! is not defined for y.
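One way to surface that failure mode early is a pre-flight check before building the rule. A minimal sketch in plain Julia, assuming nothing about Tapir: `supports_copying` and `Wrapped` are hypothetical names, and `applicable` only tells you that a method exists, not that calling it will succeed.

```julia
# Hypothetical pre-flight check; `supports_copying` is not part of Tapir or
# DifferentiationInterface.
supports_copying(y, dy) =
    applicable(copy, y) && applicable(copy, dy) && applicable(copy!, y, y)

# Hypothetical user type for which no `copy` method has been defined.
struct Wrapped
    data::Vector{Float64}
end

ok_arrays = supports_copying(zeros(2), zeros(2))            # plain vectors
ok_wrapped = supports_copying(Wrapped(zeros(2)), zeros(2))  # custom type, no `copy`
```

Here `ok_arrays` is true and `ok_wrapped` is false, so a wrapper could throw an informative error up front instead of failing with a MethodError partway through the forwards-pass.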

3 participants