Error: "NamedTuple has no field axes" #22

Open
metanoid opened this issue Jun 13, 2020 · 1 comment

@metanoid

It seems that the keyword-style ComponentArray constructor is not differentiable with Zygote.

Context: I have a two-step loss function. First I do some upfront work to estimate some parameters from the data; then I predict using those parameters together with others. To do this, I'm trying to build a single parameter array that combines both sets of parameters.

Reproducible code sample, adapted from the docs:

using ComponentArrays
using OrdinaryDiffEq
using Plots
using UnPack

using DiffEqFlux: sciml_train
using Flux: glorot_uniform, ADAM
using Optim: LBFGS

u0 = Float32[2.; 0.]
datasize = 30
tspan = (0.0f0, 1.5f0)

dense_layer(in, out) = ComponentArray(W=glorot_uniform(out, in), b=zeros(out))

function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end
t = range(tspan[1], tspan[2], length = datasize)
prob = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob, Tsit5(), saveat = t))


function dudt(u, p, t)
    @unpack L1, L2 = p
    return L2.W * tanh.(L1.W * u.^3 .+ L1.b) .+ L2.b
end

prob = ODEProblem(dudt, u0, tspan)

layers = (L1=dense_layer(2, 50), L2=dense_layer(50, 2))
θ = ComponentArray(u=u0, p=layers)

predict_n_ode(θ) = Array(solve(prob, Tsit5(), u0=θ.u, p=θ.p, saveat=t))

function loss_n_ode(θ)
    other_params = rand(3) # simulates additional work done
    θ2 = ComponentArray(u = θ.u, p = θ.p, other = other_params) # constructor
    pred = predict_n_ode(θ2) # changed
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end
loss_n_ode(θ)

cb = function (θ, loss, pred; doplot=false)
    display(loss)
    if doplot
        # plot current prediction against data
        pl = scatter(t, ode_data[1,:], label = "data")
        scatter!(pl, t, pred[1,:], label = "prediction")
        display(plot(pl))
    end
    return false
end


cb(θ, loss_n_ode(θ)...)

data = Iterators.repeated((), 1000)

res1 = sciml_train(loss_n_ode, θ, ADAM(0.05); cb=cb, maxiters=100)
cb(res1.minimizer, loss_n_ode(res1.minimizer)...; doplot=true)

res2 = sciml_train(loss_n_ode, res1.minimizer, LBFGS(); cb=cb)
cb(res2.minimizer, loss_n_ode(res2.minimizer)...; doplot=true)

Error message:

ERROR: type NamedTuple has no field axes
Stacktrace:
 [1] getproperty at .\Base.jl:33 [inlined]
 [2] getindex at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\set_get.jl:49 [inlined]
 [3] _broadcast_getindex_evalf at .\broadcast.jl:631 [inlined]
 [4] _broadcast_getindex at .\broadcast.jl:604 [inlined]
 [5] (::Base.Broadcast.var"#19#20"{Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(getindex),Tuple{Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}},Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(ComponentArrays.getval),Tuple{Tuple{DataType}}}}}})(::Int64) at .\broadcast.jl:1024
 [6] ntuple at .\ntuple.jl:41 [inlined]
 [7] copy at .\broadcast.jl:1024 [inlined]
 [8] materialize(::Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(getindex),Tuple{Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}},Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(ComponentArrays.getval),Tuple{Tuple{DataType}}}}}) at .\broadcast.jl:820
 [9] #s16#21(::Any, ::Any, ::Any) at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\set_get.jl:74
 [10] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any,N} where N) at .\boot.jl:526
 [11] getproperty at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\set_get.jl:68 [inlined]
 [12] adjoint at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\if_required\zygote.jl:10 [inlined]
 [13] _pullback(::Zygote.Context, ::typeof(ZygoteRules.literal_getproperty), ::ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}}, ::Val{:axes}) at C:\Users\username\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47
 [14] _pullback(::Zygote.Context, ::typeof(getfield), ::ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}}, ::Symbol) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\lib\lib.jl:221
 [15] getaxes at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\set_get.jl:31 [inlined]
 [16] make_idx at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\componentarray.jl:131 [inlined]
 [17] _pullback(::Zygote.Context, ::typeof(ComponentArrays.make_idx), ::Array{Any,1}, ::ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}}, ::UnitRange{Int64}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
 [18] make_idx at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\componentarray.jl:121 [inlined]
 [19] _pullback(::Zygote.Context, ::typeof(ComponentArrays.make_idx), ::Array{Any,1}, ::NamedTuple{(:u, :p, :other),Tuple{SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}},Array{Float64,1}}}, ::Int64) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
 [20] make_carray_args at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\componentarray.jl:111 [inlined]
 [21] _pullback(::Zygote.Context, ::typeof(ComponentArrays.make_carray_args), ::Type{Array{Float64,1}}, ::NamedTuple{(:u, :p, :other),Tuple{SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}},Array{Float64,1}}}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
 [22] make_carray_args at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\componentarray.jl:109 [inlined]
 [23] _pullback(::Zygote.Context, ::typeof(ComponentArrays.make_carray_args), ::Type{Float64}, ::NamedTuple{(:u, :p, :other),Tuple{SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}},Array{Float64,1}}}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
 [24] make_carray_args at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\componentarray.jl:108 [inlined]
 [25] _pullback(::Zygote.Context, ::typeof(ComponentArrays.make_carray_args), ::NamedTuple{(:u, :p, :other),Tuple{SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}},Array{Float64,1}}}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
 [26] ComponentArray at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\componentarray.jl:64 [inlined]
 [27] _pullback(::Zygote.Context, ::Type{ComponentArray}, ::NamedTuple{(:u, :p, :other),Tuple{SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}},Array{Float64,1}}}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
 [28] #ComponentArray#12 at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\componentarray.jl:66 [inlined]
 [29] _pullback(::Zygote.Context, ::ComponentArrays.var"##ComponentArray#12", ::Base.Iterators.Pairs{Symbol,AbstractArray{Float64,1},Tuple{Symbol,Symbol,Symbol},NamedTuple{(:u, :p, :other),Tuple{SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}},Array{Float64,1}}}}, ::Type{ComponentArray}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0 (repeats 2 times)
 [30] loss_n_ode at .\untitled-79d0146585cdbfb1aa13a5142027add7:39 [inlined]
 [31] _pullback(::Zygote.Context, ::typeof(loss_n_ode), ::ComponentArray{Float64,1,Array{Float64,1},Tuple{Axis{(u = 1:2, p = View(3:254, (L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))))}}}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
 [32] adjoint at C:\Users\username\.julia\packages\Zygote\uGBKO\src\lib\lib.jl:179 [inlined]
 [33] _pullback at C:\Users\username\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47 [inlined]
 [34] #24 at C:\Users\username\.julia\packages\DiffEqFlux\7Lfxh\src\train.jl:99 [inlined]
 [35] _pullback(::Zygote.Context, ::DiffEqFlux.var"#24#29"{Tuple{},typeof(loss_n_ode),ComponentArray{Float64,1,Array{Float64,1},Tuple{Axis{(u = 1:2, p = View(3:254, (L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))))}}}}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
 [36] pullback(::Function, ::Zygote.Params) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface.jl:172
 [37] gradient(::Function, ::Zygote.Params) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface.jl:53
 [38] macro expansion at C:\Users\username\.julia\packages\DiffEqFlux\7Lfxh\src\train.jl:98 [inlined]
 [39] macro expansion at C:\Users\username\.julia\packages\ProgressLogging\g8xnW\src\ProgressLogging.jl:328 [inlined]
 [40] (::DiffEqFlux.var"#23#28"{var"#42#44",Int64,Bool,Bool,typeof(loss_n_ode),ComponentArray{Float64,1,Array{Float64,1},Tuple{Axis{(u = 1:2, p = View(3:254, (L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))))}}},Zygote.Params})() at C:\Users\username\.julia\packages\DiffEqFlux\7Lfxh\src\train.jl:43
 [41] maybe_with_logger(::DiffEqFlux.var"#23#28"{var"#42#44",Int64,Bool,Bool,typeof(loss_n_ode),ComponentArray{Float64,1,Array{Float64,1},Tuple{Axis{(u = 1:2, p = View(3:254, (L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))))}}},Zygote.Params}, ::Nothing) at C:\Users\username\.julia\packages\DiffEqBase\Co6yv\src\utils.jl:259
 [42] sciml_train(::Function, ::ComponentArray{Float64,1,Array{Float64,1},Tuple{Axis{(u = 1:2, p = View(3:254, (L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))))}}}, ::ADAM, ::Base.Iterators.Cycle{Tuple{DiffEqFlux.NullData}}; cb::Function, maxiters::Int64, progress::Bool, save_best::Bool) at C:\Users\username\.julia\packages\DiffEqFlux\7Lfxh\src\train.jl:42
 [43] top-level scope at none:0
@jonniedie
Owner

jonniedie commented Jun 13, 2020

Oh interesting. I'll have to give that one some thought.

It's also worth noting that the normal ComponentArray(a=something, b=something_else, ...) style of constructor is pretty slow because it has to recurse through the structure and build up the inner array as it goes. It wasn't really intended to be used in hot loops. The other style of constructor, ComponentArray(data, axes), is both differentiable and very fast, but right now it's kinda hard to add fields this way because the Axis interface is so clunky.
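A minimal sketch of the fast path described above, assuming a current ComponentArrays release. The names `template`, `ax`, `n_other`, and the helper `with_other` are hypothetical and only for illustration. The slow keyword constructor is called once at setup time to obtain the axes; the hot path then only concatenates flat data and reattaches those axes via the `ComponentArray(data, axes...)` form. That this form is differentiable under Zygote is taken from the comment above, not checked by this snippet.

```julia
using ComponentArrays

# Stand-in for θ from the issue: existing parameters with fields u and p.
θ = ComponentArray(u = [1.0, 2.0], p = (a = [0.5], b = [0.1, 0.2]))

# Setup, done once outside the loss: build a template that has the extra
# field using the (slow) keyword constructor, and keep only its axes.
n_other = 3
template = ComponentArray(u = θ.u, p = θ.p, other = zeros(n_other))
ax = getaxes(template)

# Hypothetical hot-path helper: concatenate the flat data in the same field
# order as the template (u, p, other) and reattach the precomputed axes.
function with_other(θ, other_params, ax)
    flat = vcat(ComponentArrays.getdata(θ), other_params)
    return ComponentArray(flat, ax...)
end

other_params = rand(n_other)
θ2 = with_other(θ, other_params, ax)
```

In the issue's `loss_n_ode`, the constructor call would then become something like `θ2 = with_other(θ, other_params, ax)`, with `ax` computed once before training. The design relies on the field order of the template matching the order of the concatenated data.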

I've had it on the back of my mind for a while to add a ComponentArray(θ; other=other_params, ...) style method for quickly creating new ComponentArrays with additional component fields. This gives me good reason to finally implement that.
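For comparison, a sketch of the call shape proposed above, written as a hypothetical helper `with_fields` (not part of the ComponentArrays API). It assumes `propertynames` returns the top-level field names of a `ComponentVector`. Because it still goes through the keyword constructor, it is no faster than the original call and presumably has the same differentiability limits; it only shows how such a method might be used.

```julia
using ComponentArrays

# Hypothetical helper, not a ComponentArrays API: copy every top-level field
# of θ and append the extra fields passed as keyword arguments.
# Assumes propertynames(θ) yields the top-level component names.
function with_fields(θ::ComponentVector; kwargs...)
    existing = (k => getproperty(θ, k) for k in propertynames(θ))
    return ComponentArray(; existing..., kwargs...)
end

θ = ComponentArray(u = [1.0, 2.0], p = (a = [0.5], b = [0.1, 0.2]))
θ2 = with_fields(θ; other = [7.0, 8.0, 9.0])
```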

I try to take Saturdays off, so I'll try to get to it either tomorrow night or Monday.
