In [2]:
using Flux, DiffEqFlux, Plots, DifferentialEquations, Random, Optim, Interact
Random.seed!(1);

# Neural Differential Equations in Julia
> Exploring the [Flux.jl](https://github.com/FluxML/Flux.jl) and [DiffEqFlux.jl](https://github.com/JuliaDiffEq/DiffEqFlux.jl) packages. 


## Warm-Up: Using Flux for Linear Regression

[Flux](https://julialang.org/blog/2018/12/ml-language-compiler/): "...typical frameworks are all-encompassing monoliths in hundreds of thousands of lines of C++, Flux is only a thousand lines of straightforward Julia code. Simply take one package for gradients (Zygote.jl), one package for GPU support (CuArrays.jl), sprinkle with some light convenience functions, bake for fifteen minutes and out pops a fully-featured ML stack."

**Problem:** Given data $(x_i,y_i)_{i=0}^m$, we want to fit the affine model $x \mapsto Wx + b$ by (approximately) solving the least-squares problem

$$ \min_{W,b} \sum_{i=0}^m \| W x_i + b - y_i \|_2^2. $$
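Because the model is linear in $(W, b)$, this least-squares problem also has a closed-form solution, which gives a useful reference point for the gradient-descent result. The following is a minimal sketch in plain Julia (standard library only). The true slope and intercept `W_true = 0.7` and `b_true = 0.3` are hypothetical, and the demo variables are named so that they do not overwrite the notebook's `x`, `y`, `W` and `b`:

```julia
using LinearAlgebra, Random

# local random number generator, so the notebook's global seed is left untouched
rng = MersenneTwister(0)

# hypothetical ground truth, chosen only for this illustration
W_true, b_true = 0.7, 0.3
x_demo = rand(rng, 30)
y_demo = W_true .* x_demo .+ b_true .+ 0.01 .* randn(rng, 30)

# design matrix: the first column multiplies W, the column of ones multiplies b
A = [x_demo ones(length(x_demo))]

# for a tall matrix A, `A \ y` returns the least-squares solution of A * [W, b] ≈ y
W_ls, b_ls = A \ y_demo

println("W ≈ $W_ls (true $W_true), b ≈ $b_ls (true $b_true)")
```

For a tall matrix, `\` solves the least-squares problem through a (pivoted) QR factorization rather than by forming the normal equations $A^\top A\,\theta = A^\top y$ explicitly, which is numerically more robust.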

In [None]:
# model
W = rand(1)
b = rand(1)

# prediction
pred(x) = W.*x .+ b

# loss
loss(x, y) = sum(abs2, (y .- pred(x)))

# data
samples = 30
noise = 0.1

x = rand(samples)
y = rand(1).*x .+ rand(1) .+ noise.*rand(samples) 

# initial loss
println(loss(x, y))

# plot
scatter(x, y, label="data")
dx = range(0, 1; length=100)
# plot! adds to the current plot (by convention, functions that modify their arguments or state end in !)
plot!(dx, pred(dx), label="prediction")

**Idea:** To improve the prediction, we can compute the gradients of the loss with respect to W and b and perform gradient descent.

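In contrast to TensorFlow or PyTorch, which record the operations of the forward pass in a graph of their own, Flux's `gradient` uses Zygote.jl, which differentiates ordinary Julia code by source-to-source transformation: the *computational graph* is Julia's own syntax, and the derivative code is just-in-time compiled like any other Julia function.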
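For this model the gradients can also be derived by hand. With residuals $r_i = y_i - W x_i - b$ and $L(W,b) = \sum_i r_i^2$,

$$ \frac{\partial L}{\partial W} = -2\sum_i r_i x_i, \qquad \frac{\partial L}{\partial b} = -2\sum_i r_i. $$

The plain-Julia sketch below (no Flux; the data and the names `loss_Wb`, `grad_manual`, `grad_fd` and `descend` are hypothetical) checks these formulas against central finite differences. It then runs the same update rule, with the same learning rate and number of steps, as the training loop in the next cell:

```julia
using Random

rng = MersenneTwister(1)

# hypothetical data in the same spirit as the notebook's (slope 0.5, intercept 0.2)
xg = rand(rng, 30)
yg = 0.5 .* xg .+ 0.2 .+ 0.1 .* rand(rng, 30)

# scalar version of the notebook's loss: sum of squared residuals
loss_Wb(W, b) = sum(abs2, yg .- (W .* xg .+ b))

# hand-derived gradient: ∂L/∂W = -2 Σ rᵢ xᵢ, ∂L/∂b = -2 Σ rᵢ
function grad_manual(W, b)
    r = yg .- (W .* xg .+ b)
    return (-2 * sum(r .* xg), -2 * sum(r))
end

# central finite differences, used only as an independent check
function grad_fd(W, b; h=1e-6)
    dW = (loss_Wb(W + h, b) - loss_Wb(W - h, b)) / (2h)
    db = (loss_Wb(W, b + h) - loss_Wb(W, b - h)) / (2h)
    return (dW, db)
end

# plain gradient descent with the notebook's learning rate and number of steps
function descend(W, b; lr=0.01, steps=100)
    for _ in 1:steps
        gW, gb = grad_manual(W, b)
        W -= lr * gW
        b -= lr * gb
    end
    return W, b
end

g_manual = grad_manual(0.0, 0.0)
g_fd = grad_fd(0.0, 0.0)
W_gd, b_gd = descend(0.0, 0.0)
println("manual gradient: $g_manual, finite differences: $g_fd")
println("loss: $(loss_Wb(0.0, 0.0)) -> $(loss_Wb(W_gd, b_gd))")
```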

In [None]:
# gradient steps
steps = 100
# learning rate
lr = 0.01

# gradient descent
for i=1:steps
  gs = gradient(() -> loss(x, y), params(W, b))
  W .-= lr .* gs[W]
  b .-= lr .* gs[b]
  if i%20==0
    println("Step: $i Loss: $(loss(x, y))")
  end
end

# plot
scatter(x, y, label="data")
dx = range(0, 1; length=100)
plot!(dx, pred(dx), label="prediction")

## Neural Differential Equations using DiffEqFlux

[DiffEqFlux](https://julialang.org/blog/2019/01/fluxdiffeq/): "Layers have traditionally been simple functions like matrix multiply, but in the spirit of differentiable programming people are increasingly experimenting with much more complex functions, such as ray tracers and physics engines. Turns out that differential equations solvers fit this framework, too."


**Problem:** Given data $(t_i, u(t_i))_{i=0}^m$ of the solution to an *unknown* ODE

$$ u'(t) = f(u), \quad u(t_0) = u_0 $$

**Goal:**  Train a neural network model $\mathcal{N}_\Phi$ (with learnable parameters $\Phi$) to approximately recover $f$, i.e. learn the underlying ODE from data.

**Idea:** Numerically solve the *neural* ODE 

$$ \tilde{u}_\Phi'(t) = \mathcal{N}_{\Phi}(\tilde{u}_\Phi), \quad \tilde{u}_\Phi(t_0) = u_0 $$

at times $(t_i)_{i=0}^m$ with a package that can compute the gradient of the error
$$\sum_{i=0}^m \big( \tilde{u}_\Phi(t_i)-u(t_i)\big)^2$$

with respect to $\Phi$, so that we can minimize the error with first-order (gradient-based) optimization.
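DiffEqFlux can compute this gradient in several ways (for example with adjoint sensitivity methods). The sketch below illustrates a different, simpler technique, *forward sensitivity analysis*, and is not a description of how DiffEqFlux is configured in this notebook. For a hypothetical one-parameter model $u' = \theta \sin u$ (with $\theta = 2$ it matches the true dynamics used below), the sensitivity $s = \partial u/\partial\theta$ satisfies

$$ s' = \sin u + \theta \cos(u)\, s, \qquad s(t_0) = 0, $$

and the error $E(\theta) = \sum_i \big(u_\theta(t_i) - d_i\big)^2$ against data $d_i$ has derivative $E'(\theta) = 2\sum_i \big(u_\theta(t_i) - d_i\big)\, s(t_i)$. The sketch integrates $u$ and $s$ together with a hand-written RK4 step and compares $E'(\theta)$ with a finite difference:

```julia
# one classical Runge-Kutta (RK4) step for dz/dt = rhs(z)
function rk4_step(rhs, z, dt)
    k1 = rhs(z)
    k2 = rhs(z .+ (dt / 2) .* k1)
    k3 = rhs(z .+ (dt / 2) .* k2)
    k4 = rhs(z .+ dt .* k3)
    return z .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

# Integrate u' = θ sin(u) together with its sensitivity s = ∂u/∂θ,
# which obeys s' = sin(u) + θ cos(u) s with s(0) = 0.
# With the default arguments, returns u and s at the times t = 1, 2, ..., 15.
function solve_with_sensitivity(θ, u0; dt=1e-3, nsteps=15_000, save_every=1_000)
    rhs(z) = [θ * sin(z[1]), sin(z[1]) + θ * cos(z[1]) * z[2]]
    z = [u0, 0.0]
    us, ss = Float64[], Float64[]
    for n in 1:nsteps
        z = rk4_step(rhs, z, dt)
        if n % save_every == 0
            push!(us, z[1])
            push!(ss, z[2])
        end
    end
    return us, ss
end

# squared-error loss against data and its derivative from the sensitivities
function loss_and_grad(θ, data, u0)
    us, ss = solve_with_sensitivity(θ, u0)
    r = us .- data
    return sum(abs2, r), 2 * sum(r .* ss)
end

# hypothetical "data" from θ = 2 (the true dynamics), model evaluated at θ = 1.5
data_demo, _ = solve_with_sensitivity(2.0, 2.0)
L_demo, dL_sens = loss_and_grad(1.5, data_demo, 2.0)

# central finite difference of the loss as an independent check
h = 1e-5
dL_fd = (loss_and_grad(1.5 + h, data_demo, 2.0)[1] -
         loss_and_grad(1.5 - h, data_demo, 2.0)[1]) / (2h)
println("loss = $L_demo, sensitivity gradient = $dL_sens, finite difference = $dL_fd")
```

Forward sensitivities cost one extra ODE per parameter, which is cheap here but not for a network with thousands of weights; that trade-off is the usual motivation for adjoint methods.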

### Underlying (Unknown) Dynamics

In [None]:
# initial condition (2.0f0 is the Float32 literal for 2.0; 2.0 or 2.0e0 would be Float64)
u0 = Float32[2.0f0] 

# number of samples
datasize = 100

# time horizon
tspan = (0.0f0, 15f0)

# uniformly distributed time points in tspan, sorted in increasing order
t = sort(tspan[1] .+ rand(Float32, datasize) .* (tspan[2] - tspan[1]))

# true du/dt
f(u,p,t) = 2*sin.(u)

# underlying true ODE
ode = ODEProblem(f, u0, tspan) 

# solution of the true ODE at time-points t
sol = solve(ode, Tsit5(), saveat=t)

# data
t, u = sol.t, Array(sol)

# plot the solution 
scatter(t, vcat(u...), label="data")
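In this notebook the "unknown" right-hand side $f(u) = 2\sin u$ is simple enough to solve by separation of variables. Since $\frac{d}{du}\ln\tan(u/2) = 1/\sin u$, integrating $du/\sin u = 2\,dt$ from $t_0 = 0$ gives

$$ \tan\big(u(t)/2\big) = \tan(u_0/2)\, e^{2t} \quad\Longrightarrow\quad u(t) = 2\arctan\big(\tan(u_0/2)\, e^{2t}\big), \qquad 0 < u_0 < \pi. $$

For $u_0 = 2$ the solution therefore increases monotonically towards the equilibrium $u = \pi$, which is why the data level off. The sketch below compares this formula with a hand-written fixed-step RK4 integrator. The integrator is a hypothetical stand-in for `Tsit5`, not a reproduction of it:

```julia
# closed-form solution of u' = 2 sin(u), u(0) = u0, valid for 0 < u0 < π
u_exact(t; u0=2.0) = 2 * atan(tan(u0 / 2) * exp(2t))

# fixed-step RK4 for a scalar autonomous ODE u' = rhs(u) on [0, T]
function rk4_solve(rhs, u0, T; n=10_000)
    dt = T / n
    u = u0
    for _ in 1:n
        k1 = rhs(u)
        k2 = rhs(u + dt / 2 * k1)
        k3 = rhs(u + dt / 2 * k2)
        k4 = rhs(u + dt * k3)
        u += dt / 6 * (k1 + 2k2 + 2k3 + k4)
    end
    return u
end

f_true(u) = 2 * sin(u)
for T in (0.5, 1.0, 5.0, 15.0)
    println("t = $T: RK4 = $(rk4_solve(f_true, 2.0, T)), closed form = $(u_exact(T))")
end
```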

### Neural Network Model

In [None]:
# neural network model
model = Chain(Dense(1,50,relu), Dense(50,100,relu), Dense(100,1))

# ODE solver for the neural network model
n_ode = NeuralODE(model, tspan, Tsit5(), saveat=t)

# prediction for the given initial condition as a plain array (1 × number of time points)
ũ(Φ) = Array(n_ode(u0, Φ))

# plot of the data and the (untrained) neural ODE prediction
scatter(t, vcat(u...), label="data")
scatter!(t, vcat(ũ(n_ode.p)...), label="prediction") 
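The parameters $\Phi$ (`n_ode.p`) are the flattened weights and biases of the three `Dense` layers, so there should be $(1\cdot 50 + 50) + (50\cdot 100 + 100) + (100\cdot 1 + 1) = 5301$ of them. The plain-Julia sketch below spells out what `Chain(Dense(1,50,relu), Dense(50,100,relu), Dense(100,1))` computes and counts the parameters. The names `DenseSketch`, `rectify` and `network` are hypothetical, and the initialization is arbitrary rather than Flux's default:

```julia
using Random

rectify(x) = max(zero(x), x)   # same formula as relu

# a dense layer y = W*x .+ b with W of size (out × in)
struct DenseSketch
    W::Matrix{Float32}
    b::Vector{Float32}
end

# hypothetical initialization (small Gaussian weights, zero bias), not Flux's default
DenseSketch(rng, nin, nout) =
    DenseSketch(0.1f0 .* randn(rng, Float32, nout, nin), zeros(Float32, nout))

# N_Φ(u): rectify after the first two layers, identity after the last one
function network(layers, u)
    h = u
    for (k, layer) in enumerate(layers)
        h = layer.W * h .+ layer.b
        if k < length(layers)
            h = rectify.(h)
        end
    end
    return h
end

rng = MersenneTwister(0)
layers = [DenseSketch(rng, 1, 50), DenseSketch(rng, 50, 100), DenseSketch(rng, 100, 1)]

nparams = sum(length(layer.W) + length(layer.b) for layer in layers)
println("number of parameters: $nparams")
println("N_Φ(2) = ", network(layers, Float32[2.0f0]))
```

In the notebook, `length(n_ode.p)` should report the same count.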

### Optimization

In [None]:
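# loss: sum of squared errors between data and prediction
# (the prediction is returned as well so that the callback can plot it)
function loss(Φ)
    ũ_pred = ũ(Φ)
    l = sum(abs2, u .- ũ_pred)
    return l, ũ_pred
end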

# callback function to observe training
cb = function (p, l, pred; doplot=false) 
  println("Loss: $l")
  if doplot
    pl = scatter(t,vcat(u...),label="data")
    scatter!(pl,t,vcat(pred...),label="prediction")
    display(plot(pl))
  end
  return false
end

# optimize with ADAM
res1 = DiffEqFlux.sciml_train(loss, n_ode.p, ADAM(0.01), cb=cb, maxiters=50)
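`ADAM(0.01)` is Flux's Adam optimizer with learning rate $\eta = 0.01$. Adam keeps exponential moving averages of the gradient and of its elementwise square, and rescales each step by them. The decay rates $\beta_1 = 0.9$ and $\beta_2 = 0.999$ below are the usual defaults. A minimal plain-Julia sketch of the standard update, applied to a hypothetical toy quadratic rather than the neural ODE loss (Flux's own implementation may differ in small details):

```julia
# Adam applied to a parameter vector θ0, given a function grad_f(θ) returning the gradient
function adam_sketch(grad_f, θ0; η=0.01, β1=0.9, β2=0.999, ϵ=1e-8, iters=1000)
    θ = copy(θ0)
    m = zero(θ)   # moving average of gradients (first moment)
    v = zero(θ)   # moving average of squared gradients (second moment)
    for t in 1:iters
        g = grad_f(θ)
        m .= β1 .* m .+ (1 - β1) .* g
        v .= β2 .* v .+ (1 - β2) .* g .^ 2
        m_hat = m ./ (1 - β1^t)   # bias corrections for the zero initialization
        v_hat = v ./ (1 - β2^t)
        θ .-= η .* m_hat ./ (sqrt.(v_hat) .+ ϵ)
    end
    return θ
end

# toy problem with hypothetical target values: f(θ) = Σ (θ - target)²
target = [1.0, -2.0, 0.5]
grad_quadratic(θ) = 2 .* (θ .- target)
θ_adam = adam_sketch(grad_quadratic, zeros(3); iters=2000)
println(θ_adam)
```

Dividing by $\sqrt{\hat v}$ gives every parameter its own effective step size, so early steps have length of roughly $\eta$ regardless of the gradient's scale.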

In [None]:
# plot
cb(res1.minimizer, loss(res1.minimizer)...;doplot=true);

In [None]:
# optimize with LBFGS
res2 = DiffEqFlux.sciml_train(loss, res1.minimizer, LBFGS(), cb=cb)
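`LBFGS()` is the limited-memory BFGS method from Optim.jl. Running Adam first and L-BFGS afterwards is a common pattern: Adam makes robust progress from a poor initial guess, while L-BFGS, a quasi-Newton method that estimates curvature from the last few steps and gradient changes, tends to converge much faster once the iterate is close to a minimum. Below is a minimal sketch of the L-BFGS two-loop recursion with a simple step-halving (Armijo) line search, tested on a hypothetical convex quadratic. Optim.jl uses its own line search and defaults, which this sketch does not reproduce:

```julia
using LinearAlgebra

# L-BFGS with memory m and a backtracking (Armijo) line search.
# f returns the objective value, g the gradient.
function lbfgs_sketch(f, g, x0; m=5, iters=100, tol=1e-8)
    x = copy(x0)
    S = Vector{Vector{Float64}}()   # recent steps s_k = x_{k+1} - x_k
    Y = Vector{Vector{Float64}}()   # recent gradient changes y_k = g_{k+1} - g_k
    gx = g(x)
    for _ in 1:iters
        norm(gx) < tol && break
        # two-loop recursion: r = H * gx, H being the implicit inverse-Hessian estimate
        q = copy(gx)
        α = zeros(length(S))
        for i in length(S):-1:1
            α[i] = dot(S[i], q) / dot(Y[i], S[i])
            q .-= α[i] .* Y[i]
        end
        γ = isempty(S) ? 1.0 : dot(S[end], Y[end]) / dot(Y[end], Y[end])
        r = γ .* q
        for i in eachindex(S)
            β = dot(Y[i], r) / dot(Y[i], S[i])
            r .+= (α[i] - β) .* S[i]
        end
        d = -r   # quasi-Newton search direction
        # halve the step until the Armijo sufficient-decrease condition holds
        t, fx = 1.0, f(x)
        while f(x .+ t .* d) > fx + 1e-4 * t * dot(gx, d)
            t /= 2
        end
        xnew = x .+ t .* d
        gnew = g(xnew)
        s, y = xnew .- x, gnew .- gx
        if dot(s, y) > 1e-12   # store only pairs with positive curvature
            push!(S, s)
            push!(Y, y)
            if length(S) > m
                popfirst!(S)
                popfirst!(Y)
            end
        end
        x, gx = xnew, gnew
    end
    return x
end

# hypothetical test problem: f(x) = ½ xᵀAx - bᵀx with A symmetric positive definite
Aq = [4.0 1.0 0.0; 1.0 3.0 0.5; 0.0 0.5 2.0]
bq = [1.0, -2.0, 0.5]
fq(x) = 0.5 * dot(x, Aq * x) - dot(bq, x)
gq(x) = Aq * x .- bq
x_lbfgs = lbfgs_sketch(fq, gq, zeros(3))
println(x_lbfgs, " vs direct solve ", Aq \ bq)
```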

In [None]:
# plot
cb(res2.minimizer, loss(res2.minimizer)...;doplot=true);

### Extrapolate

In [None]:
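# compare the true and the learned dynamics for other initial conditions
@manipulate for u0_val = 0:0.2:4
    u0_new = Float32[u0_val]
    ode_new = ODEProblem(f, u0_new, tspan)
    u_true = Array(solve(ode_new, Tsit5(), saveat=t))
    u_pred = Array(n_ode(u0_new, res2.minimizer))
    pl = scatter(t, vec(u_true), label="data")
    scatter!(pl, t, vec(u_pred), label="prediction")
    pl  # @manipulate displays the value of the last expression
end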
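It helps to read the slider range $u_0 \in [0, 4]$ against the true dynamics. $f(u) = 2\sin u$ vanishes at $u = k\pi$. Since $f'(u) = 2\cos u$ is negative at $u = \pi$ and positive at $u = 0$, $u = \pi$ is stable and $u = 0$ is unstable. Every $u_0 \in (0, 4]$ is therefore attracted to $\pi$, while the training trajectory (starting at $u_0 = 2$) only visits states between roughly 2 and $\pi$. Predictions for $u_0 > \pi$ or $u_0 < 2$ are genuine extrapolation. The plain-Julia sketch below, with a hypothetical fixed-step RK4 integrator rather than `Tsit5`, checks where each slider value ends up at $t = 15$:

```julia
# fixed-step RK4 for the true dynamics u' = 2 sin(u) from t = 0 to t = T
function final_state(u_init; T=15.0, n=15_000)
    rhs(u) = 2 * sin(u)
    dt = T / n
    u = u_init
    for _ in 1:n
        k1 = rhs(u)
        k2 = rhs(u + dt / 2 * k1)
        k3 = rhs(u + dt / 2 * k2)
        k4 = rhs(u + dt * k3)
        u += dt / 6 * (k1 + 2k2 + 2k3 + k4)
    end
    return u
end

# stability of the equilibria 0 and π from the sign of f'(u) = 2 cos(u)
for (name, ueq) in (("0", 0.0), ("π", Float64(π)))
    println("u = $name: ", 2 * cos(ueq) < 0 ? "stable" : "unstable")
end

# where each slider value ends up at t = 15
for u_init in 0:0.2:4
    println("u0 = $u_init -> u(15) ≈ $(round(final_state(u_init); digits=6))")
end
```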