In [None]:
using Flux
using DiffEqFlux
using DifferentialEquations
using Plots
using Printf
using Random

In [None]:
# Minibatch geometry: `batch_size` random initial states per minibatch, each
# compared at the first `batch_time` points of the dense time grid.
batch_size = 20
batch_time = 10

# Reference-trajectory setup: nominal initial state, number of save points,
# and integration window.
u0 = Float32[2.0, 0.0]
datasize = 1000
tspan = (0.0, 25.0)

"""
    trueODEfunc(du, u, p, t)

In-place right-hand side of the reference system: writes `transpose(A) * (u .^ 3)`
into `du`, with `A = [-0.1 2.0; -2.0 -0.1]`, and returns `du`. `p` and `t` are
unused. Accepts a single 2-vector state or a 2×k matrix of states (one per
column), as used by `make_minibatch`.
"""
function trueODEfunc(du, u, p, t)
    A = [-0.1 2.0; -2.0 -0.1]
    cubed = u .^ 3
    du .= transpose(A) * cubed
    return du
end

# Dense time grid and the reference trajectory of the true system from `u0`.
t = range(first(tspan), last(tspan); length=datasize)
prob = ODEProblem(trueODEfunc, u0, tspan)
ode_data = solve(prob, Tsit5(); saveat=t) |> Array

# Neural-ODE model: cube the state (mirroring the `u.^3` in `trueODEfunc`), then a
# 2 → N → 2 MLP with a tanh hidden layer. Seeding right before building the Chain
# makes its random initial weights reproducible.
N = 50
Random.seed!(2)
dudt = Chain(x -> x .^ 3, Dense(2, N, tanh), Dense(N, 2))
ps = Flux.params(dudt)

# `n_ode` integrates the network dynamics from a given initial state. It reads the
# globals `tspan` and `t` at call time, so later cells that reassign them change it.
n_ode = x0 -> neural_ode(dudt, x0, tspan, Tsit5(); saveat=t, reltol=1e-7, abstol=1e-9)
# GPU variant:
#n_ode = u0 -> neural_ode(gpu(dudt), gpu(u0), tspan, Tsit5(), saveat=t, reltol=1e-7, abstol=1e-9)

# Untrained prediction vs. data (first state component).
pred = n_ode(u0)
scatter(t, ode_data[1, :]; label="data")
scatter!(t, Flux.data(pred[1, :]); label="prediction")

In [None]:
# Rebuild the reference trajectory from the nominal initial state over `tspan`
# (the trailing semicolon suppresses cell output).
u0 = Float32[2.0, 0.0]
t = range(first(tspan), last(tspan); length=datasize)
prob = ODEProblem(trueODEfunc, u0, tspan)
ode_data = solve(prob, Tsit5(); saveat=t) |> Array;

In [None]:
"""
    predict_n_ode()

Run the neural ODE `n_ode` from the current global initial state `u0`.
"""
predict_n_ode() = n_ode(u0)

# RMSProp optimizer with learning rate η = 1e-3 (Flux's first positional argument).
# NOTE(review): the training cell below constructs its own optimizer as well.
opt = RMSProp(1e-3)

In [None]:
# Minibatch time grid: the first `batch_time` points of the dense grid `t`.
tt = t[range(1; length=batch_time)]

"""
    make_minibatch(batch_size, batch_time)

Draw `batch_size` random initial states uniformly from [-2, 2) in each component,
integrate the true dynamics over the minibatch grid `tt`, and return a one-element
vector `[(initial_states, snapshots)]`, the data format passed to `Flux.train!`.

NOTE(review): `batch_time` is currently ignored; the time grid is the global `tt`,
which `loss_n_ode` also uses, so the two stay consistent only through that global.
"""
function make_minibatch(batch_size, batch_time)
    starts = rand(2, batch_size) .* 4 .- 2
    sol = solve(ODEProblem(trueODEfunc, starts, extrema(tt)), Tsit5();
                saveat=tt, reltol=1e-7, abstol=1e-9)
    return [(starts, sol.u)]
end

"""
    loss_n_ode(u0, u)

Sum-of-squares loss between the neural-ODE prediction started from `u0` and the
reference snapshots `u`, evaluated on the minibatch time grid `tt`.

`u` holds one snapshot per saved time (as returned by `make_minibatch`, i.e.
`solve(...).u`); snapshot `idx` is compared with the prediction slice
`pred[:, :, idx]`.
"""
function loss_n_ode(u0, u)
    pred = neural_ode(dudt, u0, extrema(tt), Tsit5(),
                      saveat=tt, reltol=1e-7, abstol=1e-9)
    # Iterate over the snapshots actually supplied rather than a hard-coded 10,
    # so the loss stays correct when `batch_time` (and hence `tt`) changes.
    return sum(idx -> sum(abs2, pred[:, :, idx] .- u[idx]), eachindex(u))
end

In [None]:
"""
    loss_one_trajectory()

Sum of squared differences between the reference trajectory `ode_data` and the
neural-ODE prediction from the global initial state `u0`.
"""
function loss_one_trajectory()
    residual = ode_data .- n_ode(u0)
    return sum(abs2, residual)
end

"""
Training callback: plot the first state component of the reference trajectory
`ode_data` against the current neural-ODE prediction from `u0`, with the
full-trajectory squared-error loss in the title.
"""
cb = function ()
    cur_pred = Flux.data(predict_n_ode())
    # Reuse the prediction for the loss instead of calling `loss_one_trajectory()`,
    # which would integrate the neural ODE a second time (same value).
    cur_loss = sum(abs2, ode_data .- cur_pred)
    fig = scatter(t, ode_data[1, :], label="data")
    scatter!(fig, t, cur_pred[1, :], label="prediction")
    plot!(fig, title=@sprintf("Loss = %.2e", cur_loss))
    # IJulia only exists inside a Jupyter kernel; skip clearing when run as a script.
    if isdefined(Main, :IJulia)
        Main.IJulia.clear_output(true)
    end
    display(fig)
end


# Train on a freshly sampled minibatch each epoch (one optimizer step per epoch).
# The optimizer is created once, outside the loop, so RMSProp's running average of
# squared gradients persists across epochs; re-creating it every epoch discarded
# that state before every single step.
# To plot less often, create `Flux.throttle(cb, 2)` once here and pass it instead of `cb`.
opt = RMSProp(0.001)
for epoch_idx in 1:1000
    batch_data = make_minibatch(batch_size, batch_time)
    Flux.train!(loss_n_ode, ps, batch_data, opt, cb=cb)
end

## Can we predict accurately from other initial conditions?

In [None]:
# Baseline: the same initial state as the reference trajectory built earlier.
u0 = Float32[2.0, 0.0]

prob2 = ODEProblem(trueODEfunc, u0, tspan)
ode_data = solve(prob2, Tsit5(); saveat=t) |> Array

pred = n_ode(u0)
scatter(t, ode_data[1, :]; label="data")
scatter!(t, Flux.data(pred[1, :]); label="prediction")

In [None]:
# Slightly perturbed initial state.
u0 = Float32[1.9, 0.0]

prob2 = ODEProblem(trueODEfunc, u0, tspan)
ode_data = solve(prob2, Tsit5(); saveat=t) |> Array

pred = n_ode(u0)
scatter(t, ode_data[1, :]; label="data")
scatter!(t, Flux.data(pred[1, :]); label="prediction")

In [None]:
# Substantially different initial state.
u0 = Float32[1.0, 0.0]

prob2 = ODEProblem(trueODEfunc, u0, tspan)
ode_data = solve(prob2, Tsit5(); saveat=t) |> Array

pred = n_ode(u0)
scatter(t, ode_data[1, :]; label="data")
scatter!(t, Flux.data(pred[1, :]); label="prediction")

## Can we extrapolate to times well beyond the short minibatch training window?

In [None]:
# Window (0, 3): far longer than the minibatch grid `tt`, though shorter than the
# original (0, 25). NOTE(review): `n_ode` reads the globals `tspan` and `t` at call
# time, so reassigning them here also changes its horizon (and that of later reruns).
tspan = (0.0, 3.0)
u0 = Float32[2.0, 0.0]

t = range(first(tspan), last(tspan); length=datasize)
prob2 = ODEProblem(trueODEfunc, u0, tspan)
ode_data = solve(prob2, Tsit5(); saveat=t) |> Array

pred = n_ode(u0)
scatter(t, ode_data[1, :]; label="data")
scatter!(t, Flux.data(pred[1, :]); label="prediction")