# Write and Document Benchmark Examples for NeuralNetDiffEq.jl

In [2]:
using OrdinaryDiffEq, NeuralNetDiffEq, Plots, Flux

### Example: Lotka-Volterra
The Lotka-Volterra equations, also known as the predator-prey equations, model how the populations of a prey species and a predator species change over time as they interact. They are an example of a system of first-order, nonlinear, coupled differential equations:

$$\frac{dx}{dt}=\alpha x -\beta x y, \; \frac{dy}{dt}=\delta x y-\gamma y$$
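where $x$ and $y$ are the prey and predator populations, and $\frac{dx}{dt}$ and $\frac{dy}{dt}$ are their instantaneous growth rates. The interaction is parameterised by the positive real parameters $\alpha$ (prey growth rate), $\beta$ (predation rate), $\gamma$ (predator death rate) and $\delta$ (predator growth per prey consumed). In the code below they are stored in that order, as `p = [α, β, γ, δ]`.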
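Setting both derivatives to zero gives, besides the trivial point $(0, 0)$, a coexistence equilibrium at $x^* = \gamma/\delta$, $y^* = \alpha/\beta$ (from $x(\alpha - \beta y) = 0$ and $y(\delta x - \gamma) = 0$). The short plain-Julia check below evaluates the right-hand side there for the parameter values used in this benchmark; `lv_rhs`, `p_demo` and `u_eq` are illustrative names, not part of NeuralNetDiffEq.jl.

```julia
# Illustrative Lotka-Volterra right-hand side, with p = [α, β, γ, δ]
lv_rhs(u, p) = [p[1]*u[1] - p[2]*u[1]*u[2],    # dx/dt = αx - βxy
                -p[3]*u[2] + p[4]*u[1]*u[2]]   # dy/dt = δxy - γy

p_demo = [1.5, 1.0, 3.0, 1.0]    # same values as the benchmark below, in Float64
α, β, γ, δ = p_demo
u_eq = [γ / δ, α / β]            # coexistence equilibrium (x*, y*) = (3.0, 1.5)

println(lv_rhs(u_eq, p_demo))    # prints [0.0, 0.0]: both populations are stationary
```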

In [3]:
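# In-place form: writes dx/dt and dy/dt into the preallocated vector du
function f(du,u,p,t)
  du[1] = p[1]*u[1] - p[2]*u[1]*u[2]    # prey:     αx - βxy
  du[2] = -p[3]*u[2] + p[4]*u[1]*u[2]   # predator: δxy - γy
  return nothing
end
# Out-of-place form: returns a new vector [dx/dt, dy/dt]
function f(u,p,t)
  [p[1]*u[1] - p[2]*u[1]*u[2], -p[3]*u[2] + p[4]*u[1]*u[2]]
end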

f (generic function with 2 methods)
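The two methods compute the same thing: the in-place one fills a preallocated `du`, while the out-of-place one returns a fresh vector, which is the form the `prob_oop` problem below is built around. A quick self-contained check that they agree, using the hypothetical names `lv!` and `lv` so that `f` is not overwritten:

```julia
# Hypothetical copies of the two methods of `f`, renamed so they do not replace it
function lv!(du, u, p, t)
    du[1] = p[1]*u[1] - p[2]*u[1]*u[2]    # prey
    du[2] = -p[3]*u[2] + p[4]*u[1]*u[2]   # predator
    return nothing
end
lv(u, p, t) = [p[1]*u[1] - p[2]*u[1]*u[2], -p[3]*u[2] + p[4]*u[1]*u[2]]

p32 = Float32[1.5, 1.0, 3.0, 1.0]
u32 = Float32[1.0, 1.0]

du = similar(u32)                  # preallocated output buffer
lv!(du, u32, p32, 0f0)             # fills du in place
println(du)                        # prints Float32[0.5, -2.0]
println(du == lv(u32, p32, 0f0))   # prints true
```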

In [4]:
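p = Float32[1.5,1.0,3.0,1.0]   # [α, β, γ, δ]
u0 = Float32[1.0,1.0]          # initial prey and predator populations
prob = ODEProblem(f,u0,(0f0,3f0),p)              # in-place form (auto-detected), used for the reference solve
prob_oop = ODEProblem{false}(f,u0,(0f0,3f0),p)   # out-of-place form, passed to NNODE below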

ODEProblem with uType Array{Float32,1} and tType Float32. In-place: false
timespan: (0.0f0, 3.0f0)
u0: Float32[1.0, 1.0]
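Conceptually, `solve(prob, Tsit5())` integrates this problem numerically to produce the reference curve the network is compared against. As a rough, self-contained illustration, the sketch below swaps in classical fixed-step RK4 (Tsit5 is an adaptive 5(4) Runge-Kutta method, so this is not what OrdinaryDiffEq does internally) and checks the result against $V(x, y) = \delta x - \gamma \ln x + \beta y - \alpha \ln y$, a quantity that is exactly conserved along Lotka-Volterra trajectories. The names `rk4`, `lv` and `V` are hypothetical.

```julia
# Out-of-place Lotka-Volterra right-hand side, p = [α, β, γ, δ]
lv(u, p, t) = [p[1]*u[1] - p[2]*u[1]*u[2], -p[3]*u[2] + p[4]*u[1]*u[2]]

# Conserved quantity V = δx - γ ln x + βy - α ln y
V(u, p) = p[4]*u[1] - p[3]*log(u[1]) + p[2]*u[2] - p[1]*log(u[2])

# Classical fixed-step RK4 over tspan with n steps (illustrative stand-in for Tsit5)
function rk4(f, u0, tspan, p; n = 3000)
    t0, t1 = tspan
    h = (t1 - t0) / n
    u = copy(u0)
    ts = [t0]
    us = [copy(u0)]
    for i in 1:n
        t = t0 + (i - 1) * h
        k1 = f(u, p, t)
        k2 = f(u .+ (h / 2) .* k1, p, t + h / 2)
        k3 = f(u .+ (h / 2) .* k2, p, t + h / 2)
        k4 = f(u .+ h .* k3, p, t + h)
        u = u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)   # new array each step
        push!(ts, t0 + i * h)
        push!(us, u)
    end
    return ts, us
end

p_demo = [1.5, 1.0, 3.0, 1.0]
ts, us = rk4(lv, [1.0, 1.0], (0.0, 3.0), p_demo)
println(us[end])                                     # state (x, y) at t = 3
println(abs(V(us[end], p_demo) - V(us[1], p_demo)))  # drift in V; should be tiny
```

Checking a conserved quantity is a cheap sanity test for any integrator on this system, including the neural approximation further down.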

### Defining Model and Optimiser
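After many experiments, the ADAM and NADAM optimisers (with learning rate η = 1e-03) tended to outperform the other optimisers tried (for example Nesterov and AMSGrad, which are left commented out in the cell below), reaching a loss as low as 44.6 within 100 iterations. All methods plateaued at about this level, if they reached it at all. Larger models, with up to 4096 channels, tended to plateau at a loss of about 51 within 100 iterations, while smaller models did not reach this accuracy. MaxPool layers did not seem to add value to the model. The run shown below, with the architecture defined here, plateaus at a loss of about 52.18.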
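For reference, the textbook Adam update keeps exponential moving averages of the gradient and of its square, corrects their bias from the zero initialisation, and steps by $\eta\,\hat m / (\sqrt{\hat v} + \epsilon)$. The scalar sketch below uses the commonly cited defaults $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$; it is an illustration, not Flux's implementation, and `AdamState`, `adam_step!` and `minimise_quadratic` are hypothetical names. NADAM is a variant that adds Nesterov-style momentum to the same scheme.

```julia
# State for one scalar parameter under the textbook Adam update (illustrative, not Flux's code)
mutable struct AdamState
    m::Float64   # moving average of the gradient (first moment)
    v::Float64   # moving average of the squared gradient (second moment)
    t::Int       # step counter, used for bias correction
end
AdamState() = AdamState(0.0, 0.0, 0)

# One Adam step: returns the updated parameter value
function adam_step!(s::AdamState, θ, g; η = 1e-3, β1 = 0.9, β2 = 0.999, ϵ = 1e-8)
    s.t += 1
    s.m = β1 * s.m + (1 - β1) * g
    s.v = β2 * s.v + (1 - β2) * g^2
    mhat = s.m / (1 - β1^s.t)    # bias-corrected first moment
    vhat = s.v / (1 - β2^s.t)    # bias-corrected second moment
    return θ - η * mhat / (sqrt(vhat) + ϵ)
end

# Usage: minimise (θ - 3)^2 starting from θ = 0
function minimise_quadratic(; steps = 1000, η = 0.1)
    s = AdamState()
    θ = 0.0
    for _ in 1:steps
        g = 2 * (θ - 3)                  # gradient of (θ - 3)^2
        θ = adam_step!(s, θ, g; η = η)
    end
    return θ
end

println(minimise_quadratic())    # ends up close to 3
```

At the first step $\hat m / \sqrt{\hat v} \approx \pm 1$, so the parameter moves by roughly η regardless of the gradient's scale; the learning rate passed to `ADAM(1e-03)` therefore sets the typical step size directly.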

In [5]:
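true_sol = solve(prob,Tsit5())   # reference solution from a classical adaptive solver

opt = ADAM(1e-03)   # learning rate 1e-3 (1e-04 was the alternative)
# Other optimisers that were compared:
# opt = NADAM()
# opt = Nesterov()
# opt = AMSGrad()
# chain = Chain(x -> reshape(x, length(x), 1, 1), Conv((1,), 1=>16, relu), Conv((1,), 16=>8, relu), x -> reshape(x, :, size(x, 4)), Dense(8, 10), softmax)


chain = Chain(
    x -> reshape(x, length(x), 1, 1),   # input -> (width, channels = 1, batch = 1) for 1-D convolutions
    MaxPool((1,)),                      # window of 1: effectively an identity layer
    Conv((1,), 1=>16, relu),
    Conv((1,), 16=>16, relu),
    Conv((1,), 16=>32, relu),
    Conv((1,), 32=>64, relu),
    Conv((1,), 64=>256, relu),
    Conv((1,), 256=>256, relu),
    Conv((1,), 256=>1028, relu),
    Conv((1,), 1028=>1028),
    x -> reshape(x, :, size(x, 4)),     # flatten to a (features, 1) matrix; size(x, 4) is 1 for this 3-D array
    Dense(1028, 512, tanh),
    Dense(512, 128, relu),
    Dense(128, 64, tanh),
    Dense(64, 2),                       # one output per state variable (prey, predator)
    softmax)                            # forces the two outputs to be positive and sum to 1

# m = Chain(Conv((1,), 1=>16, relu), Conv((1,), 16=>8, relu), x -> reshape(x, :, size(x, 4)), Dense(16, length(u0)), softmax)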

Chain(#3, MaxPool((1,), pad = (0, 0), stride = (1,)), Conv((1,), 1=>16, relu), Conv((1,), 16=>16, relu), Conv((1,), 16=>32, relu), Conv((1,), 32=>64, relu), Conv((1,), 64=>256, relu), Conv((1,), 256=>256, relu), Conv((1,), 256=>1028, relu), Conv((1,), 1028=>1028), #4, Dense(1028, 512, tanh), Dense(512, 128, relu), Dense(128, 64, tanh), Dense(64, 2), softmax)
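`NNODE` trains the chain so that a trial function built from it satisfies the ODE at a set of time points. A minimal sketch of that idea, assuming the Lagaris-style trial solution $\phi(t) = u_0 + (t - t_0)\,N(t)$, which meets the initial condition for any network $N$, and a loss equal to the sum of squared residuals $\lVert \phi'(t) - f(\phi(t), p, t) \rVert^2$ over the time grid. The helpers `trial` and `residual_loss`, and the central-difference derivative, are assumptions for illustration; NeuralNetDiffEq.jl's internals may differ in detail. The sketch is checked on $u' = -u$, whose exact solution is known.

```julia
# Trial solution φ(t) = u0 + (t - t0) * N(t); φ(t0) = u0 for any N
trial(N, u0, t0, t) = u0 .+ (t - t0) .* N(t)

# Sum of squared ODE residuals ‖dφ/dt - f(φ, p, t)‖² over the time grid ts.
# dφ/dt is approximated with a central difference of step h (an assumption of this sketch).
function residual_loss(N, f, u0, p, ts; h = 1e-5)
    t0 = first(ts)
    loss = 0.0
    for t in ts
        φ  = trial(N, u0, t0, t)
        dφ = (trial(N, u0, t0, t + h) .- trial(N, u0, t0, t - h)) ./ (2h)
        loss += sum(abs2, dφ .- f(φ, p, t))
    end
    return loss
end

# Check on u' = -u, u(0) = 1, whose exact solution is exp(-t)
decay(u, p, t) = -u
N_exact(t) = [t == 0 ? -1.0 : expm1(-t) / t]   # makes φ(t) = exp(-t)
N_flat(t)  = [0.0]                             # φ(t) ≡ 1, not a solution
grid = 0:0.2:3                                 # spacing matches dt = 1/5f0 used below

println(residual_loss(N_exact, decay, [1.0], nothing, grid))   # ≈ 0
println(residual_loss(N_flat,  decay, [1.0], nothing, grid))   # 16.0: residual 1 at each of 16 points
```

In the benchmark, $N$ plays the role of the Flux `chain` evaluated at a time point, and `f` is the out-of-place method of the Lotka-Volterra right-hand side.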

In [None]:
# Train for 100 iterations on time points spaced dt = 0.2 apart, printing the loss as it goes
sol = solve(prob_oop, NeuralNetDiffEq.NNODE(chain, opt), maxiters = 100, verbose = true, dt = 1/5f0)

Current loss is: 402.70865
Current loss is: 134.39284
Current loss is: 74.71057
Current loss is: 57.470463
Current loss is: 53.17737
Current loss is: 52.367645
Current loss is: 52.21946
Current loss is: 52.1955
Current loss is: 52.17527
Current loss is: 52.19066
Current loss is: 52.1906
Current loss is: 52.180058
Current loss is: 52.181484
Current loss is: 52.180534
Current loss is: 52.186043
Current loss is: 52.17969
Current loss is: 52.181576
Current loss is: 52.193066
Current loss is: 52.188477
Current loss is: 52.203316
Current loss is: 52.18
Current loss is: 52.179367
Current loss is: 52.178936
Current loss is: 52.180885
Current loss is: 52.178772
Current loss is: 52.17996
Current loss is: 52.171608
Current loss is: 52.186348
Current loss is: 52.177525
Current loss is: 52.176796
Current loss is: 52.178925
Current loss is: 52.17926
Current loss is: 52.17619
Current loss is: 52.179497
Current loss is: 52.18814
Current loss is: 52.174843
Current loss is: 52.17905
Current loss is: 52.
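The printed loss flattens out at about 52.18 after the first handful of lines. Given a recorded vector of losses, a hypothetical helper such as `has_plateaued` can flag this automatically by checking whether the last few values lie within a small relative band. It is shown here on values copied from the log above; the `window` and `rtol` defaults are arbitrary choices for illustration.

```julia
# Hypothetical helper: true when the last `window + 1` losses all lie
# within a band of width rtol * |latest loss|
function has_plateaued(losses::AbstractVector{<:Real}; window::Int = 5, rtol::Real = 1e-3)
    length(losses) <= window && return false   # not enough history yet
    recent = losses[end-window:end]
    return maximum(recent) - minimum(recent) <= rtol * abs(losses[end])
end

early = [402.70865, 134.39284, 74.71057, 57.470463, 53.17737, 52.367645, 52.21946, 52.1955]
late  = [52.17926, 52.17619, 52.179497, 52.18814, 52.174843, 52.17905]

println(has_plateaued(early))   # prints false: still dropping
println(has_plateaued(late))    # prints true: spread ≈ 0.013, under 0.1% of 52.18
```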

### Plotting

In [None]:
plot(true_sol)   # reference Tsit5 solution (prey and predator)
plot!(sol)       # overlay the NNODE approximation
# savefig("ADAM.png")
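Beyond the visual comparison, a single number can summarise how far the network is from the reference. A minimal sketch with a hypothetical helper `max_abs_error`, which takes two trajectories sampled at the same times, each given as a vector of state vectors:

```julia
# Hypothetical helper: largest absolute component-wise difference between two
# trajectories sampled at the same time points
function max_abs_error(us_ref, us_approx)
    length(us_ref) == length(us_approx) ||
        throw(DimensionMismatch("trajectories have different lengths"))
    return maximum(maximum(abs.(a .- b)) for (a, b) in zip(us_ref, us_approx))
end

us_ref    = [[1.0, 2.0], [3.0, 4.0]]
us_approx = [[1.5, 2.0], [3.0, 3.0]]
println(max_abs_error(us_ref, us_approx))   # prints 1.0
```

For the benchmark, something like `max_abs_error([true_sol(t) for t in sol.t], sol.u)` should work, assuming the NNODE solution exposes `t` and `u` like other DiffEq solutions; `true_sol(t)` uses the Tsit5 solution's dense interpolation.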