# DiffEqFlux

Focus = Universal Differential Equations
SciMLSensitivity.jl = the underlying package, see https://docs.sciml.ai/SciMLSensitivity/stable/getting_started/

Lux.jl neural networks are preferred for technical reasons: https://docs.sciml.ai/DiffEqFlux/stable/#Flux.jl-vs-Lux.jl
- https://www.youtube.com/watch?v=5jF-c_DNSkg&ab_channel=TheJuliaProgrammingLanguage 
    - EXPLICIT parameterisation! You specify the trainable and non-trainable parts of the model, so you tell Zygote explicitly what you want to differentiate!
    - Similar to Flux for backend!

Tutorial on Neural ODE: https://docs.sciml.ai/DiffEqFlux/stable/examples/neural_ode/

In [None]:
using DrWatson 
@quickactivate "diff_gleam"
using ComponentArrays, Lux, DiffEqFlux, DifferentialEquations
using Optimization #, OptimizationOptimJL, OptimizationFlux,
using Random, Plots

In [None]:
# Global default random number generator, used later to initialise network weights.
rng = Random.default_rng()

# Initial state, number of samples, and the time window of the reference trajectory.
u0 = Float32[2.0, 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)

# Evenly spaced sampling times covering the whole of `tspan`.
tsteps = range(first(tspan), last(tspan); length = datasize)

"""
    trueODEfunc(du, u, p, t)

In-place right-hand side of the reference ODE that the Neural ODE should learn.

Writes `(u.^3)' * true_A`, transposed back to a column, into `du` and returns `du`.
`p` and `t` are accepted for the ODE-solver calling convention but are not used.
"""
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    cubed = u .^ 3
    # row vector times matrix, then transposed (') back into a column
    du .= (cubed' * true_A)'
    return du
end
# Solve the reference ODE, saving the state at the sampling times `tsteps`.
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
sol_trueode = solve(prob_trueode, Tsit5(); saveat = tsteps)

# Plot the trajectory through the plot recipe of the solution object.
plot(sol_trueode)

So above, we see the real function we want to approximate using the NeuralODE!

In [None]:
ode_data = Array(sol_trueode) # solution values as a matrix (indexed later as ode_data[1, :] against tsteps)

In [None]:
# Model of the right-hand side: cube the state first (prior knowledge about the
# dynamics), then a 2 -> 50 -> 2 dense network with a tanh hidden layer.
dudt2 = Lux.Chain(x -> x .^ 3, Lux.Dense(2, 50, tanh), Lux.Dense(50, 2))

In [None]:
p, st = Lux.setup(rng, dudt2) # initialise random weights
# p = parameters, st = state variables

In [None]:
# NOTE(review): this repeats the `dudt2` definition from the earlier cell verbatim;
# `p` and `st` were already set up from that earlier copy, so this cell looks redundant.
dudt2 = Lux.Chain(x -> x.^3,
                  Lux.Dense(2, 50, tanh),
                  Lux.Dense(50, 2))

In [None]:
# Wrap the network as a Neural ODE over `tspan`, solved with Tsit5 and saved at `tsteps`.
# The default sensitivity algorithm is the adjoint method.
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps)

In [None]:
# Exploration only: list the call signatures available on the NeuralODE object.
methods(prob_neuralode)

In [None]:
"""
    predict_neuralode(p)

Evaluate the Neural ODE from the global `u0` with parameters `p` and the global
state `st`, and return the first element of its result converted to an `Array`.
"""
function predict_neuralode(p)
    sol, _ = prob_neuralode(u0, p, st)
    return Array(sol)
end

"""
    loss_neuralode(p)

Return `(loss, pred)`, where `pred` is `predict_neuralode(p)` and `loss` is the sum
of squared element-wise differences between the global `ode_data` and `pred`.
"""
function loss_neuralode(p)
    pred = predict_neuralode(p)
    residual = ode_data .- pred
    # abs2 = squared absolute value, applied to every element before summing
    return sum(abs2, residual), pred
end

To explain the code above!

In [None]:
test = prob_neuralode(u0,p,st) # forward evaluation of the model with the initial parameters (no differentiation here)

In [None]:
# First element of the result converted to a plain array — the same expression that
# predict_neuralode evaluates.
Array(test[1])

In [None]:
loss_neuralode(p)[1] # the current loss; [1] drops the prediction returned alongside it

In [None]:
# Progress callback for the optimiser: prints the loss `l` and, when `doplot` is true,
# shows the first state component of the data against the prediction `pred`.
# Always returns false; `p` is not used.
callback = function (p, l, pred; doplot = false)
    println(l)
    if !doplot
        return false
    end
    # plot current prediction against data
    fig = scatter(tsteps, ode_data[1, :]; label = "data")
    scatter!(fig, tsteps, pred[1, :]; label = "prediction")
    display(plot(fig))
    return false
end

# Wrap the Lux parameters in a ComponentArray; this is the initial point handed to the
# optimisation problem below.
pinit = ComponentArray(p) #useful for problems with mutable arrays
# Splat the (loss, pred) tuple returned by loss_neuralode into the callback's
# `l` and `pred` arguments.
callback(pinit, loss_neuralode(pinit)...) #3 dots to expand

First Adam is used, then BFGS is used (the code below calls `Optim.BFGS()`)

In [None]:
# AD backend: Zygote, i.e. reverse-mode automatic differentiation.
adtype = Optimization.AutoZygote()

# Optimization.jl requires the objective in the form f(x, p):
#   x = the values being optimised (the network parameters, the old `p`)
#   p = the hyperparameters of the optimisation (unused here)
loss_ft_for_opt = (x, p) -> loss_neuralode(x)
optf = Optimization.OptimizationFunction(loss_ft_for_opt, adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)

This optimisation is based on the library called [Optimization.jl](https://docs.sciml.ai/Optimization/stable/getting_started/)
- Defining optimisation problems: https://docs.sciml.ai/Optimization/stable/API/optimization_problem/
- Defining optimisation functions: https://docs.sciml.ai/Optimization/stable/API/optimization_function/

In [None]:
using OptimizationFlux # needed for Adam

# First training stage: Adam with learning rate 0.05, at most 300 iterations.
result_neuralode = Optimization.solve(
    optprob, Adam(0.05); callback = callback, maxiters = 300
)

Retrain with BFGS

In [None]:
using OptimizationOptimJL

# Second training stage: restart from the Adam result and refine with BFGS.
optprob2 = remake(optprob; u0 = result_neuralode.u)
results_neuralode2 = Optimization.solve(
    optprob2,
    Optim.BFGS();
    callback = callback,
    allow_f_increases = false, # stop near minimum
)

In [None]:
# Print the final loss and plot the fitted prediction against the data.
callback(results_neuralode2.u, loss_neuralode(results_neuralode2.u)...; doplot=true)

In [None]:
# Export this notebook (diffeqflux.ipynb) to the script diffeqflux.jl.
using NBInclude
nbexport("diffeqflux.jl", "diffeqflux.ipynb")

## Experiment: try it with a ModelingToolkit interface

Use this as example to recreate: https://github.com/SciML/ModelingToolkit.jl/issues/1271

In [None]:
# ModelingToolkit supplies the symbolic macros and system types used in the cells below.
using ModelingToolkit

In [None]:
# Experiment: a second-order system  x'' = NN(x)  written as two first-order equations
# in ModelingToolkit, with a Lux network supplying the right-hand side.
tspan = (0.0f0, 8.0f0) # typo fixed: was `tpsan`, which left this time span unused
ann = Lux.Chain(
    Lux.Dense(1, 6),
    Lux.Dense(6, 6, tanh),
    Lux.Dense(6, 1)
)
# Initialise random weights (θ) and layer states (st).
# NOTE(review): this rebinds the global `st` that `predict_neuralode` reads, so the
# Neural ODE cells should be re-run from `Lux.setup(rng, dudt2)` before reuse.
θ, st = Lux.setup(rng, ann)
# NOTE(review): `N` is not used in this cell; if θ is a nested NamedTuple, `length`
# counts its top-level entries rather than scalar parameters — confirm before relying on it.
N = length(θ)
@variables t, x(t), xx(t)
# Register the network call so it can appear inside symbolic equations.
# NOTE(review): untested whether registering `Lux.apply` with these arguments works.
@register_symbolic Lux.apply(ann, x, θ, st)
@parameters p
D = Differential(t)
eqs = [
    D(x) ~ xx
    D(xx) ~ Lux.apply(ann, x, θ, st)[1]
]
# Typo fixed: the constructor is `ODESystem` (capital S), not `ODEsystem`.
sys = ODESystem(eqs, t, [x, xx], [p])
# Unfinished sketch kept for reference:
# function an(t,p)
#     return ann(t,p)[1]^3

In [None]:
# Exploration only: list the call signatures available on `ann`.
methods(ann)

In [None]:
# Reference example following the ModelingToolkit issue linked above: a second-order
# system whose acceleration term is a neural network registered as a symbolic function.
# NOTE(review): `FastChain`, `FastDense`, `initial_params` and `@register` presumably
# belong to older DiffEqFlux / ModelingToolkit releases — confirm they exist in the
# package versions used by this project.
# NOTE(review): this cell rebinds the globals `tspan`, `ann`, `θ`, `N` and `u0`.
# `predict_neuralode` reads the global `u0`, so the Neural ODE cells must be re-run
# from the top before they are used again.
using DiffEqFlux, DifferentialEquations, Plots, Statistics,ModelingToolkit
tspan = (0.0f0,8.0f0)
ann = FastChain(FastDense(1,6,tanh), FastDense(6,6,tanh), FastDense(6,1))
θ = initial_params(ann)
N = length(θ)
# First component of the network output at (t, p), cubed.
function an(t,p)
    return ann(t,p)[1]^3
end
# Make `an` usable inside symbolic equations.
@register an(t,p)
@variables t,x(t),xx(t)
@parameters p[1:N]
D = Differential(t)
# x' = xx,  xx' = an(t, p): the network drives the second derivative of x.
eqs=[ 
    D(x)~ xx
    D(xx) ~ an(t,p[1:N])
]
sys = ODESystem(eqs,t,[x,xx],p)
# Initial conditions as a symbolic map.
u0 =[
    x => -4.f0
    xx => 0.f0
]
# Map each symbolic parameter p[i] to the corresponding initial network weight.
paras = [p[i]=>θ[i] for i in 1:N]
prob = ODEProblem(structural_simplify(sys),u0,tspan,paras)
sol = solve(prob)


In [None]:
# Exploration only: list the call signatures available on `ann`.
methods(ann)

In [None]:
# Exploration only: evaluate the network on the single input 10.
# NOTE(review): if the FastChain cell above was run last, `ann` and `θ` are no longer
# the Lux model and parameters created by `Lux.setup` — confirm which cell defined them.
Lux.apply(ann, [10], θ, st)

In [None]:
# Exploration only: list the methods defined for `Lux.apply`.
methods(Lux.apply)

https://docs.sciml.ai/ModelingToolkit/stable/systems/ODESystem/#ODESystem