In [None]:
using Flux
using DifferentialEquations
using Plots
using DiffEqFlux

In [None]:
"""
    lotka_volterra(du, u, p, t)

In-place right-hand side of a two-variable Lotka–Volterra system.

Unpacks `u` into two state components `x, y` and `p` into the four coefficients
`α, β, δ, γ`, then writes `α*x - β*x*y` into `du[1]` and `-δ*y + γ*x*y` into `du[2]`.
`t` is accepted but not used. The value stored in `du[2]` is also the return value.
"""
function lotka_volterra(du, u, p, t)
    prey, predator = u
    α, β, δ, γ = p
    du[1] = α * prey - β * prey * predator
    predator_rate = -δ * predator + γ * prey * predator
    du[2] = predator_rate
    return predator_rate
end

# Initial state, integration interval and the coefficient vector (α, β, δ, γ)
# read by `lotka_volterra`, bundled into an ODE problem.
u0 = [1.0, 1.0]
tspan = (0.0, 10.0)
p = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka_volterra, u0, tspan, p)

In [None]:
# Solve `prob` without naming an algorithm (the choice is left to `solve`)
# and plot the resulting solution object.
soln = solve(prob)
plot(soln)

In [None]:
# Solve the same problem through `diffeq_rd` with `Tsit5`, saving every 0.1 time
# units, and plot the first state component against the saved time points.
#
# The result is bound to a single name: `predict_rd` indexes the value returned by
# the same `diffeq_rd` call as `[1, :]`, so it is one solution-like object rather
# than a `(ret, soln)` pair. Destructuring it would bind `soln` to one of its
# elements, which has no `.t` / `.u` fields.
soln = diffeq_rd(p, prob, Tsit5(), saveat=0.1)
plot(soln.t, [u[1] for u in soln.u])

In [None]:
# Rebind `p` to a `param`-wrapped vector of four new coefficient values and collect
# it in the `Flux.Params` container that is later passed to `Flux.train!`.
p = param([2.2, 1.0, 2.0, 0.4])
params = Flux.Params([p])

In [None]:
"""
    predict_rd()

Solve the global `prob` with the current global parameters `p` through `diffeq_rd`
(`Tsit5`, `saveat=0.1`) and return the `[1, :]` slice of the result.
"""
function predict_rd()
    solution = diffeq_rd(p, prob, Tsit5(), saveat=0.1)
    return solution[1, :]
end

"""
    loss_rd()

Sum of squared deviations from 1 over every value returned by `predict_rd()`.
"""
function loss_rd()
    return sum(abs2, value - 1 for value in predict_rd())
end

# Evaluate the loss once so the cell shows its current value.
loss_rd()

In [None]:
# `data` is an iterator that yields the empty tuple 100 times; `opt` is an ADAM
# optimiser constructed with 0.1. Both are passed to `Flux.train!` further down.
data = Iterators.repeated((), 100)
opt = ADAM(0.1)
# Training callback: display the current loss, then re-solve the problem with the
# parameter values obtained from `Flux.data(p)` and display a plot of that solution
# with the y-axis limited to (0, 6).
cb = function ()
    display(loss_rd())
    current_prob = remake(prob, p=Flux.data(p))
    current_soln = solve(current_prob, Tsit5(), saveat=0.1)
    display(plot(current_soln, ylim=(0, 6)))
end

In [None]:
# Invoke the callback once by hand; it displays the current loss and a solution plot.
cb()

In [None]:
# Optimise the parameters in `params` against `loss_rd`, taking one step per element
# of `data` and calling `cb` as the training callback.
Flux.train!(loss_rd, params, data, opt; cb=cb)

# Section 7

In [None]:
# Two-layer network: a 2 -> 50 dense layer with `tanh`, followed by a 50 -> 2 dense layer.
dudt = Chain(Dense(2, 50, tanh), Dense(50, 2))

In [None]:
# Rebind the integration interval and build an anonymous function that runs
# `neural_ode` on its input with `dudt` as the derivative network.
# NOTE(review): the anonymous function is not assigned to a name, so it is only
# the value of this cell — confirm that this is intended.
tspan = (0.0, 25.0)
x -> neural_ode(dudt, x, tspan, Tsit5(), saveat=0.1) # Not on GPU!

In [None]:
# Float32 initial condition, number of sample points, and the time interval used by
# the data generation and neural ODE cells that follow.
u0 = Float32[2.0, 0.0]
datasize = 30
tspan = (0.0, 1.5)

In [None]:
"""
    trueODEfunc(du, u, p, t)

In-place right-hand side used to generate the reference data.

Cubes `u` element-wise, multiplies the adjoint of that vector by the fixed 2×2
matrix `[-0.1 2.0; -2.0 -0.1]`, and writes the adjoint of the product into `du`.
`p` and `t` are accepted but not used. Returns `du`.
"""
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    cubed = u .^ 3
    du .= (cubed' * true_A)'
    return du
end
# `datasize` evenly spaced time points across `tspan`, the ODE problem built on
# `trueODEfunc`, and its `Tsit5` solution saved at those points, converted to an Array.
t = range(tspan[1], tspan[2], length=datasize)
prob = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob, Tsit5(), saveat=t))

In [None]:
using Random
# Seed the global RNG before the layers are constructed, then build the derivative
# network (element-wise cube, 2 -> 50 dense layer with `tanh`, 50 -> 2 dense layer),
# collect its parameters, and wrap `neural_ode` so that it only needs the initial state.
Random.seed!(1)
dudt = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2))
ps = Flux.params(dudt)
n_ode = x -> neural_ode(dudt, x, tspan, Tsit5(), saveat=t, reltol=1e-7, abstol=1e-9)

In [None]:
# Run the neural ODE from `u0` and overlay row 1 of its output on row 1 of the
# reference data.
pred = n_ode(u0)
scatter(t, ode_data[1, :], label="data")
scatter!(t, Flux.data(pred[1, :]), label="prediction")

In [None]:
"""
    predict_n_ode()

Return the output of `n_ode` evaluated at the global initial condition `u0`.
"""
predict_n_ode() = n_ode(u0)
"""
    loss_n_ode()

Sum of squared element-wise differences between `ode_data` and `predict_n_ode()`.
"""
function loss_n_ode()
    residual = ode_data .- predict_n_ode()
    return sum(abs2, residual)
end

# `data` yields the empty tuple 100 times; `opt` is an ADAM optimiser built with 0.1.
data = Iterators.repeated((), 100)
opt = ADAM(0.1)
# Training callback: display the current loss, then display a scatter plot of row 1
# of the reference data with row 1 of the current prediction drawn on top.
cb = function ()
    display(loss_n_ode())
    prediction = Flux.data(predict_n_ode())
    fig = scatter(t, ode_data[1, :], label="data")
    scatter!(fig, t, prediction[1, :], label="prediction")
    display(plot(fig))
end

In [None]:
# Invoke the callback once by hand; it displays the current loss and the
# data-versus-prediction scatter plot.
cb()

In [None]:
# Optimise the network parameters `ps` against `loss_n_ode`, taking one step per
# element of `data` and calling `cb` as the training callback.
Flux.train!(loss_n_ode, ps, data, opt; cb=cb)