&nbsp;

# 20 - Neural ODE/SDE

---

&nbsp;

# 1. Introduction

In [1]:
versioninfo()

using Pkg
Pkg.instantiate()
Pkg.precompile()

using NPZ, LinearAlgebra, Statistics
using DifferentialEquations, SciMLSensitivity
using Lux, Optimisers, ComponentArrays, Zygote
using Random

using Optimization
using OptimizationOptimisers

Julia Version 1.10.3
Commit 0b4590a5507 (2024-04-30 10:59 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × 11th Gen Intel(R) Core(TM) i7-1165G7 @ 2.80GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, tigerlake)
Threads: 1 default, 0 interactive, 1 GC (on 8 virtual cores)


In [2]:
reducedState = npzread("data/processed/sstReducedStateCOPERNICUS20102019Prepared.npz")
println("Available keys in reducedState:", keys(reducedState))

PCsTrain = Float32.(reducedState["PCsTrain"])
tTrain   = Float32.(reducedState["tTrain"])
PCsVal = Float32.(reducedState["PCsVal"])
tVal   = Float32.(reducedState["tVal"])
std_data = Float32.(reducedState["std"])  # Renommé pour éviter conflit

t = vcat(tTrain, tVal)

z0 = PCsTrain[:, 1]
tspan = (t[1], t[end])

r = size(PCsTrain, 1)

nn = Chain(
    Dense(r, 64, tanh),
    Dense(64, 64, tanh),
    Dense(64, r)
)

rng = Random.default_rng()
ps, st = Lux.setup(rng, nn)
ps = ComponentArray(ps)

function f!(dz, z, p, t)
    y, _ = nn(z, p, st)
    @assert length(y) == length(dz)
    dz .= y
end

dz = similar(z0)
y  = nn(z0, ps, st)

@show length(z0)
@show length(dz)
@show length(y)

f!(dz, z0, ps, tTrain[1])

prob = ODEProblem(f!, z0, tspan, ps)

Available keys in reducedState:["PCsVal", "tVal", "PCsTrain", "tTrain", "std"]
length(z0) = 2922
length(dz) = 2922
length(y) = 2


[38;2;86;182;194mODEProblem[0m with uType [38;2;86;182;194mVector{Float32}[0m and tType [38;2;86;182;194mFloat32[0m. In-place: [38;2;86;182;194mtrue[0m
Non-trivial mass matrix: [38;2;86;182;194mfalse[0m
timespan: (0.0f0, 729.0f0)
u0: 2922-element Vector{Float32}:
 -0.2329068
 -0.21520926
 -0.3013477
 -0.32271194
 -0.43515974
 -0.93329793
 -1.11059
 -1.1857773
 -1.3257517
 -1.4246912
 -1.6363107
 -1.6228559
 -1.6361544
  ⋮
 -1.0148848
 -1.6231638
 -2.0417738
 -2.3244388
 -2.0626087
 -1.2057661
 -1.0863514
 -1.0674461
 -1.0999924
 -1.2391645
 -1.3900986
 -1.134983

In [3]:
# Préparer les données d'entraînement et de validation
Ztrain = PCsTrain
Zval   = PCsVal

z0Train = Ztrain[:, 1]
z0Val   = Zval[:, 1]

tspanTrain = (tTrain[1], tTrain[end])
tspanVal   = (tVal[1], tVal[end])

Ttrain = size(Ztrain, 2)
Tval   = size(Zval, 2)

# Convertir range en Array pour permettre l'indexation
tSaveTrain = Float32.(collect(range(tspanTrain[1], tspanTrain[2], length=Ttrain)))
tSaveVal   = Float32.(collect(range(tspanVal[1],   tspanVal[2],   length=Tval)))

T = size(Ztrain, 2)        # = 150
Tshort = min(100, T)       # par ex.

ZtrainS = Ztrain[:, 1:Tshort]
tSaveTrainS = tSaveTrain[1:Tshort]
z0TrainS = ZtrainS[:, 1]

std_mode = std(ZtrainS; dims=2)      # (2922, 1)
std_mode = vec(std_mode)             # (2922,)

ZtrainN = ZtrainS ./ reshape(std_mode, :, 1)
ZvalN   = Zval   ./ reshape(std_mode, :, 1)

LoadError: DimensionMismatch: arrays could not be broadcast to a common size; got a dimension with lengths 730 and 2922

In [None]:
# Créer les templates de problèmes
probTrainTemplate = ODEProblem(f!, z0Train, tspanTrain, nothing)
probValTemplate   = ODEProblem(f!, z0Val,   tspanVal,   nothing)

In [None]:
# Fonction de prédiction - CORRECTION: ZygoteVJP au lieu de ZygoteVJP()
function predict(ps, probTemplate, tSave)
    prob = remake(probTemplate; p=ps)
    sol = solve(
        prob,
        Tsit5(),
        saveat = tSave,
        sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP),  # ← CORRECTION ICI
        verbose = false
    )
    return Array(sol)
end

In [None]:
# Fonction de perte
function loss(ps)
    Ẑ = predict(ps, probTrainTemplate, tSaveTrain)
    @assert size(Ẑ) == size(Ztrain)
    return sum(abs2, Ẑ .- Ztrain) / length(Ztrain)
end

@time loss(ps)

In [None]:
# Optimisation
optf = OptimizationFunction(
    (x, p) -> loss(x),
    Optimization.AutoZygote()
)

optprob = OptimizationProblem(optf, ps)

res = Optimization.solve(
    optprob,
    ADAM(1e-3),
    maxiters = 200
)

ps = res.u

In [None]:
# Validation
ẐVal = predict(ps, probValTemplate, tSaveVal)

@assert size(ẐVal) == size(Zval)

valError = norm(ẐVal .- Zval) / norm(Zval)
println("relative validation error = ", valError)