Autodiff broken in latest release #67

Closed
bgroenks96 opened this issue Mar 4, 2021 · 8 comments

Comments

@bgroenks96

v0.8.20 seems to have broken autodiff code that worked fine in v0.8.19.

MWE (excuse the clutter)

using DiffEqFlux, DiffEqSensitivity, Flux, OrdinaryDiffEq, Zygote, Test #using Plots
using DiffEqBase: get_tmp, dualcache
using ComponentArrays
using Parameters
using ForwardDiff
using ReverseDiff

p = ComponentArray(lvpara=ComponentArray(α=2.2, β=1.0, δ=2.0, γ=0.4), a=1.0, b=1.0)
u0 = ComponentArray(state=ComponentArray(x=1.0,y=1.0))
ax_p = getaxes(p)
ax_u = getaxes(u0)

chunk_size(dual::Type{ForwardDiff.Dual{T,V,N}}) where {T,V,N} = N
select(a::AbstractArray, u::AbstractArray) = a
select(dc::DiffEqBase.DiffCache, u) = get_tmp(dc,u)
select(a::AbstractArray, u::ReverseDiff.TrackedArray) = begin
  x = similar(u,size(a))
  x .= a
  x
end

struct LotkaVolterra{T}
  d::T
end

function (lv::LotkaVolterra)(du,u_,p,t)
  d = select(lv.d,u_)
  u = ComponentArray(u_,ax_u)
  p = ComponentArray(p,ax_p)
  @unpack x,y = u.state
  @unpack lvpara, a, b = p
  @unpack α, β, δ, γ = lvpara
  d[1] = a^2+b^2
  d[2] = b^2-a^2
  du[1] = dx = (α - β*y)x + d[1]
  du[2] = dy = (δ*x - γ)y + d[2]
end
u0 = Array(u0)
p = Array(p)
lv = LotkaVolterra(dualcache(zeros(2),Val{6}))
prob = ODEProblem(lv,u0,(0.0,1.0),p)
function predict_rd(p)
  Array(solve(prob,Tsit5(),p=p,saveat=0.1,reltol=1e-4,sensealg=ForwardDiffSensitivity()))
end
loss_rd(p) = sum(abs2,x-1 for x in predict_rd(p))

opt = ADAM(0.1)
cb = function (p,l,pred)
  display(loss_rd(p))
  #display(plot(solve(remake(prob,p=p),Tsit5(),saveat=0.1),ylim=(0,6)))
end

@time res = DiffEqFlux.sciml_train(loss_rd, p, opt, maxiters=100)

Error message:

MethodError: no method matching ComponentArray(::var"#7#8", ::Array{Float64,2})
Closest candidates are:
  ComponentArray(::Any, !Matched::FlatAxis...) at /home/brian/.julia/packages/ComponentArrays/zHt90/src/componentarray.jl:50
  ComponentArray(::Any, !Matched::Union{FlatAxis, ComponentArrays.NullAxis, Axis{IdxMap}} where IdxMap...) at /home/brian/.julia/packages/ComponentArrays/zHt90/src/componentarray.jl:51
  ComponentArray(::Any, !Matched::Union{FlatAxis, ComponentArrays.NullAxis, Axis{IdxMap}, ShapedAxis{Shape,IdxMap}} where IdxMap where Shape...) at /home/brian/.julia/packages/ComponentArrays/zHt90/src/componentarray.jl:52
  ...
rrule(::UnionAll, ::Function, ::Array{Float64,2}) at chainrulescore.jl:23
chain_rrule at chainrules.jl:89 [inlined]
macro expansion at interface2.jl:0 [inlined]
_pullback(::Zygote.Context, ::Type{Base.Generator}, ::var"#7#8", ::Array{Float64,2}) at interface2.jl:9
loss_rd at lotka_volterra.jl:47 [inlined]
_pullback(::Zygote.Context, ::typeof(loss_rd), ::Array{Float64,1}) at interface2.jl:0
#69 at train.jl:3 [inlined]
_pullback(::Zygote.Context, ::DiffEqFlux.var"#69#70"{typeof(loss_rd)}, ::Array{Float64,1}, ::SciMLBase.NullParameters) at interface2.jl:0
adjoint at lib.jl:188 [inlined]
_pullback at adjoint.jl:57 [inlined]
OptimizationFunction at basic_problems.jl:107 [inlined]
_pullback(::Zygote.Context, ::OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}, ::Array{Float64,1}, ::SciMLBase.NullParameters) at interface2.jl:0
adjoint at lib.jl:188 [inlined]
adjoint(::Zygote.Context, ::typeof(Core._apply_iterate), ::typeof(iterate), ::OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}, ::Tuple{Array{Float64,1},SciMLBase.NullParameters}) at none:0
_pullback at adjoint.jl:57 [inlined]
OptimizationFunction at basic_problems.jl:107 [inlined]
_pullback(::Zygote.Context, ::OptimizationFunction{false,GalacticOptim.AutoZygote,OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},GalacticOptim.var"#146#156"{GalacticOptim.var"#145#155"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#149#159"{GalacticOptim.var"#145#155"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#154#164",Nothing,Nothing,Nothing}, ::Array{Float64,1}, ::SciMLBase.NullParameters) at interface2.jl:0
adjoint at lib.jl:188 [inlined]
_pullback at adjoint.jl:57 [inlined]
#8 at solve.jl:94 [inlined]
_pullback(::Zygote.Context, ::GalacticOptim.var"#8#13"{OptimizationProblem{false,OptimizationFunction{false,GalacticOptim.AutoZygote,OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},GalacticOptim.var"#146#156"{GalacticOptim.var"#145#155"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#149#159"{GalacticOptim.var"#145#155"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#154#164",Nothing,Nothing,Nothing},Array{Float64,1},SciMLBase.NullParameters,Nothing,Nothing,Nothing,Base.Iterators.Pairs{Symbol,Int64,Tuple{Symbol},NamedTuple{(:maxiters,),Tuple{Int64}}}},Array{Float64,1},GalacticOptim.NullData}) at interface2.jl:0
pullback(::Function, ::Params) at interface.jl:167
gradient(::Function, ::Params) at interface.jl:48
__solve(::OptimizationProblem{false,OptimizationFunction{false,GalacticOptim.AutoZygote,OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},GalacticOptim.var"#146#156"{GalacticOptim.var"#145#155"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#149#159"{GalacticOptim.var"#145#155"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#154#164",Nothing,Nothing,Nothing},Array{Float64,1},SciMLBase.NullParameters,Nothing,Nothing,Nothing,Base.Iterators.Pairs{Symbol,Int64,Tuple{Symbol},NamedTuple{(:maxiters,),Tuple{Int64}}}}, ::ADAM, ::Base.Iterators.Cycle{Tuple{GalacticOptim.NullData}}; maxiters::Int64, cb::Function, progress::Bool, save_best::Bool, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at solve.jl:93
__solve at solve.jl:66 [inlined]
__solve at solve.jl:6...
@bgroenks96
Author

Maybe you forgot to add a chain rule for that constructor?

@jonniedie
Owner

Ah, I see what the problem is. I have a rule

ChainRulesCore.rrule(::typeof(ComponentArray), data, axes) = ComponentArray(data, axes), Δ->(getdata(Δ), getaxes(Δ))

because that is how you'd normally define a rule for a function. But since ComponentArray is a type as well as a function, its type is UnionAll. So I was basically overwriting the rrule for every UnionAll, which is why the Base.Generator constructor call in your loss function ended up in my rule. 😬

Unfortunately I'm too busy tonight to look into how I'm supposed to handle this, so the fix probably won't come until tomorrow night. Hopefully it's something easy, but looking through the ChainRules docs, I don't see how to handle it just yet. Maybe it doesn't even need that rule.
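
For readers following along, the dispatch problem described above can be reproduced without any packages. This is only a sketch: `Box`, `rule_typeof`, and `rule_type` are hypothetical names made up for illustration, not part of ComponentArrays or ChainRulesCore.

```julia
# Hypothetical parametric type standing in for ComponentArray.
struct Box{T}
    value::T
end

# Without its parameter, a parametric type is a UnionAll.
@assert typeof(Box) === UnionAll

# Pattern from the buggy rule: ::typeof(Box) is really ::UnionAll.
rule_typeof(::typeof(Box), x) = :box_rule
rule_typeof(::Any, x) = :fallback

# Type selector: only the type Box itself matches.
rule_type(::Type{Box}, x) = :box_rule
rule_type(::Any, x) = :fallback

rule_typeof(Box, 1)             # :box_rule (intended)
rule_typeof(Base.Generator, 1)  # :box_rule (unintended: Base.Generator is also a UnionAll)
rule_type(Box, 1)               # :box_rule
rule_type(Base.Generator, 1)    # :fallback
```

The second call mirrors the stack trace in the report, where the `Base.Generator` constructor was routed into the ComponentArray rule. Note that `Type{Box}` does not match a concrete `Box{Int}`; `Type{<:Box}` would be needed to cover those as well.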

@jonniedie
Owner

Never mind, I think I fixed it: I just deleted that rule. It seems to work now; at least it completes without an error. If you want the fix now without waiting for it to be registered, just do:

julia> m = @which ComponentArrays.ChainRulesCore.rrule(ComponentArray, identity, [1.0]);

julia> Base.delete_method(m);

@jonniedie
Owner

It doesn't return a ComponentArray, though, which I think it should. So I'll have to figure out how to correctly add this method back.

@bgroenks96
Author

I don't quite understand what you mean. The type of the gradient isn't a ComponentArray?

@jonniedie
Owner

The sciml_train part above returns a plain Array. I didn't look much into it, but it seems like it should return a ComponentArray of the same type as p, correct?

@bgroenks96
Author

If you mean res.minimizer, then yes, I would expect that. But it's not a big deal, really. You can just reconstruct it with the axes.
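
A minimal sketch of that reconstruction, using the same `getaxes` and `ComponentArray(data, axes)` calls as the MWE. The field names `a` and `b` and the variable `flat` are placeholders; in the MWE, `flat` would be `res.minimizer` and the axes would be `ax_p`.

```julia
using ComponentArrays

p = ComponentArray(a=1.0, b=2.0)
ax = getaxes(p)    # keep the axes before flattening
flat = Array(p)    # plain vector, as passed to the optimizer in the MWE

# ... optimize over `flat` ...

# Reattach the saved axes to recover named access.
restored = ComponentArray(flat, ax)
restored.a, restored.b
```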

@jonniedie
Owner

Turned out to be a pretty dumb mistake. I should have done Type{ComponentArray} instead of typeof(ComponentArray).
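
A sketch of what a constructor rule dispatched on `Type{...}` might look like. This is not the actual ComponentArrays fix: `Wrapper` is a hypothetical stand-in type, and the `NoTangent` and `Tangent` names assume ChainRulesCore 1.x, which is newer than the release discussed in this thread.

```julia
using ChainRulesCore

# Hypothetical parametric wrapper; not the ComponentArrays type.
struct Wrapper{T}
    data::T
end

# Dispatch on Type{Wrapper}, so only calls to the Wrapper constructor reach this rule.
function ChainRulesCore.rrule(::Type{Wrapper}, data)
    y = Wrapper(data)
    # One tangent per input: NoTangent() for the constructor itself, then one for `data`.
    wrapper_pullback(Δ) = (NoTangent(), Δ.data)
    return y, wrapper_pullback
end

y, back = rrule(Wrapper, [1.0, 2.0])
Δ = Tangent{typeof(y)}(; data=[1.0, 1.0])
tangents = back(Δ)
```

Here `tangents[2]` is the cotangent passed through to `data`, and constructors of other parametric types no longer match the rule.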
