Add derivative for diag and vec #24

Closed
dfdx opened this issue Feb 25, 2019 · 9 comments
dfdx commented Feb 25, 2019

Motivating example:

function pd(x)
    m = x'*x
    v = diag(m)
    v .+ v' .- 2 .* m
end

Yota.grad(x -> sum(pd(x)), rand(2,3)) 
dfdx changed the title from "Add derivative for diag" to "Add derivative for diag and vec" on Feb 26, 2019
improbable-22 commented
I tried to define these as follows:

@diffrule diag(x::Matrix)    x    diagm(0=>ds) # or perhaps Diagonal(ds)
@diffrule vec(x::AbstractArray)    x    reshape(ds, size(x))
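
For reference, the intuition behind these: diag extracts the diagonal, so its pullback should scatter ds back onto the diagonal of an otherwise-zero matrix; vec only reshapes, so its pullback reshapes back. The two diagonal constructors differ only in representation:

using LinearAlgebra
ds = [1.0, 2.0, 3.0]
diagm(0 => ds)    # dense 3×3 matrix with ds on the diagonal, zeros elsewhere
Diagonal(ds)      # same values as a lazy structured type, no dense allocation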

But this doesn't work:

julia> using Yota, LinearAlgebra

julia> Yota.@diffrule LinearAlgebra.diag(x::Matrix) x diagm(0=>ds)

julia> function pd(x)
           m = x'*x
           v = diag(m)
           v .+ v' .- 2 .* m
       end
pd (generic function with 1 method)

julia> Yota.grad(x -> sum(pd(x)), rand(2,3))                                                
┌ Error: Failed to find a derivative for %13 = broadcast(%12, %7, %-1)::Array{Float64,2} at position 2, current state of backpropagation saved to Yota.DEBUG_STATE
└ @ Yota ~/.julia/packages/Yota/f2RBi/src/grad.jl:125
ERROR: BoundsError: attempt to access 26-element Array{Yota.AbstractOp,1} at index [-1]

dfdx commented Feb 27, 2019

Hmm, looks like adjoint (e.g. x') generates additional promote_op calls which break the tracer. I'm working on #23 right now, so I'll take this issue right after that.

By the way, errors like:

ERROR: BoundsError: attempt to access 26-element Array{Yota.AbstractOp,1} at index [-1]

referring to indexing with -1 usually mean that the tape was constructed incorrectly and some operation uses an argument with a non-existent ID. It might be interesting to inspect the tape with:

_, tape = Yota.trace(pd, rand(2,3))

Although in this particular case a bit more debugging seems to be needed.

improbable-22 commented Feb 27, 2019

Thanks for having a look. I tried without the ' by reshaping instead, and at least got a different error:

julia> function pd(x)
           m = x'*x
           v = diag(m)
           vrow = reshape(v,1,:)
           v .+ vrow .- 2 .* m
       end
pd (generic function with 1 method)

julia> Yota.grad(x -> sum(pd(x)), rand(2,3))
┌ Error: Failed to find a derivative for %10 = reshape(%7, %8, %9)::Array{Float64,2} at position 1, current state of backpropagation saved to Yota.DEBUG_STATE
└ @ Yota ~/.julia/packages/Yota/f2RBi/src/grad.jl:125
ERROR: DimensionMismatch("new dimensions (3,) must be consistent with array size 9")

Here's a shorter example of the same error:

julia> Yota.grad((x,y) -> sum(1 .+ x .+ reshape(y,2,3)), rand(2,3), rand(6)) # same shape is fine
(5.034909031105892, GradResult(2))

julia> Yota.grad((x,y) -> sum(1 .+ x .+ reshape(y,1,3)), rand(2,3), rand(3)) 
┌ Error: Failed to find a derivative for %5 = reshape(%2, %3, %4)::Array{Float64,2} at position 1, current state of backpropagation saved to Yota.DEBUG_STATE
└ @ Yota ~/.julia/packages/Yota/f2RBi/src/grad.jl:125
ERROR: DimensionMismatch("new dimensions (3,) must be consistent with array size 6")

dfdx commented Feb 27, 2019

Ok, let's start with something working: on the diag-vec-rules branch the following runs with a couple of warnings:

using Yota
using LinearAlgebra

function pd(x)
    m = transpose(x)*x
    v = diag(m)    
    v .+ transpose(v) .- 2 .* m
end

Yota.grad(x -> sum(pd(x)), rand(2,3))
# ┌ Warning: Gradient %26 has size (3, 3), but original variable %4 has size (3,)
# └ @ Yota ~/work/Yota/src/grad.jl:183
# ┌ Warning: Gradient %23 has size (3, 3), but original variable %5 has size (1, 3)
# └ @ Yota ~/work/Yota/src/grad.jl:183
# (1.1181653013988138, GradResult(1))

These warnings are related to the reshape issues you got. The core reason for them is broadcasting over arrays of different shapes. It's easier to show by example:

x = rand(3, 3)
y = rand(3)
z = x .+ y
L = sum(z)

z has size (3, 3), so dz (the derivative of L w.r.t. z) also has size (3, 3). The derivative rules tell us that dx and dy should, in the general case, also have size (3, 3), and this is correct for x, but not for y!
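
To get a correctly shaped dy, the gradient has to be summed over the dimension along which y was broadcast. A minimal sketch (dz is just ones here, since L = sum(z)):

dz = ones(3, 3)
dx = dz                     # x already has size (3, 3), nothing to fix
dy = vec(sum(dz, dims=2))   # sum over the broadcast dimension -> size (3,)
size(dy) == size(y)         # true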

I thought I had worked around it with a set of very custom rules, but as it turns out it's very easy to step outside their scope. In your last example, the derivative of reshape (which is itself a reshape) tries to fit a derivative of size (2, 3) into a container of size (3,), hence the error. So perhaps we need to think out a more general strategy for such cases, maybe some kind of adaptive_reshape() which sums over the extra dimensions or something. I need to think about it a bit.
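
To sketch the idea (adaptive_reshape is hypothetical, nothing like it exists in Yota yet; it assumes any mismatched dimension of the target has length 1 or is missing, as broadcasting guarantees):

# Hypothetical helper: reshape when the lengths match, otherwise sum out
# the dimensions that broadcasting expanded before reshaping.
function adaptive_reshape(ds, target_size)
    length(ds) == prod(target_size) && return reshape(ds, target_size)
    dims = Tuple(i for i in 1:ndims(ds)
                 if i > length(target_size) || target_size[i] == 1)
    return reshape(sum(ds, dims=dims), target_size)
end

The vec / reshape rules could then pass ds through adaptive_reshape(ds, size(x)) instead of a bare reshape.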

Adjoint (') is a very different issue: I now see tests for it commented out (I even remember thinking "oh, I'll come back to it in a day"... that was 2 months ago :)). For whatever reason adjoint calls some unexpected internal stuff and totally breaks the tape. This seems less acute, so let's concentrate on the reshape / broadcasting issue first.

improbable-22 commented
OK, I'm starting to see. The rule is correct for vectors:

@diffrule broadcast(_fn::typeof(+), x::AbstractVector, y::AbstractMatrix) x sum_dropdims(ds, 2)

but broadcasting also knows that a row vector is not a full matrix, and that information isn't in the types of x and y (at least not always).

sum! understands the same shape rules... could something like this work? (It doesn't yet, for me.)

@diffrule broadcast(::typeof(+), x, y) x sum!(similar(x), ds)
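
The nice property is that sum! infers which dimensions to reduce from the shape of its destination, so one rule could cover the vector, row-vector, and same-shape cases alike. A quick demonstration of that behavior in plain Julia:

ds = ones(3, 3)          # gradient in the broadcast shape
x = rand(3)
sum!(similar(x), ds)     # reduces over dims 2 -> size (3,)
xrow = rand(1, 3)
sum!(similar(xrow), ds)  # reduces over dims 1 -> size (1, 3)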

dfdx commented Feb 28, 2019

Wow, it actually fixes most of the examples! With your rules modified a bit, the following works without any errors or warnings:

julia> using Yota

julia> using LinearAlgebra

julia> function pd(x)
           m = transpose(x)*x
           v = diag(m)    
           v .+ transpose(v) .- 2 .* m
       end
pd (generic function with 1 method)

julia> Yota.grad(x -> sum(pd(x)), rand(2,3))
(1.0071901213953678, GradResult(1))

julia> Yota.grad((x,y) -> sum(1 .+ x .+ reshape(y,2,3)), rand(2,3), rand(6))
(11.271115872899797, GradResult(2))

julia> Yota.grad((x,y) -> sum(1 .+ x .+ reshape(y,1,3)), rand(2,3), rand(3)) 
(10.995366600207689, GradResult(2))

Thanks for this excellent solution.

I'm going to spend some more time on fixing the adjoint example and then merge all the changes.

improbable-22 commented
OK, that's great! Clearly I was failing to overwrite the rules, or something.

I wondered if it would be worth checking size(ds) == size(x), so as not to copy when they are the same size. Then I looked up what Flux does, and it's one step more involved:

https://github.com/FluxML/Flux.jl/blob/master/src/tracker/lib/array.jl#L434-L442
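
Roughly, as I read it (a paraphrase of the linked code, not a verbatim copy): return ds untouched when the sizes already match, otherwise sum only over the dimensions where x has length 1:

# Paraphrase of Flux's unbroadcast, not the actual code.
function unbroadcast(x::AbstractArray, ds::AbstractArray)
    size(x) == size(ds) && return ds   # nothing was broadcast: skip the copy
    # non-broadcast dimensions map to ndims(ds) + 1, which sum ignores
    dims = ntuple(i -> size(x, i) == 1 ? i : ndims(ds) + 1, ndims(ds))
    return reshape(sum(ds, dims=dims), size(x))
end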

dfdx commented Mar 1, 2019

I think Flux partially does the same thing as sum!, but I need to test performance to make sure. Will do in the evening.

dfdx commented Mar 2, 2019

Unsurprisingly, the Flux version works nearly as fast as sum! when size(x) != size(ds) and much faster when they are equal, so I rewrote the rules to use their approach.
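
For anyone who wants to reproduce the comparison, something along these lines (the sizes are illustrative and unbroadcast is the sketch from the previous comment, not the code actually merged):

using BenchmarkTools
ds = ones(1000, 1000)
x = rand(1000)              # shapes differ: both versions have to reduce
@btime sum!(similar($x), $ds)
@btime unbroadcast($x, $ds)
xs = rand(1000, 1000)       # shapes equal: unbroadcast returns ds without copying
@btime sum!(similar($xs), $ds)
@btime unbroadcast($xs, $ds)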

I also fixed adjoint, so now all the examples in this issue work without errors or warnings. Thanks again for your input! Closing this issue since there seems to be nothing else to fix, but feel free to re-open if I missed something.

dfdx closed this as completed Mar 2, 2019