# 20 - Neural ODE/SDE

---

# 1. Introduction

In [14]:
versioninfo()

using Pkg
Pkg.instantiate()
Pkg.precompile()

using NPZ, LinearAlgebra, Statistics
using DifferentialEquations, SciMLSensitivity
using Lux, Optimisers, ComponentArrays, Zygote
using Random

using Optimization
using OptimizationOptimisers

Julia Version 1.10.3
Commit 0b4590a5507 (2024-04-30 10:59 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × 11th Gen Intel(R) Core(TM) i7-1165G7 @ 2.80GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, tigerlake)
Threads: 1 default, 0 interactive, 1 GC (on 8 virtual cores)


In [15]:
reducedState = npzread("data/processed/sstReducedStateCOPERNICUS20102019Prepared.npz")
println("Available keys in reducedState:", keys(reducedState))

# Charger les données - NE PAS utiliser 'std' comme nom de variable !
PCsTrain_raw = Float32.(reducedState["PCsTrain"])
tTrain_raw   = Float32.(reducedState["tTrain"])
PCsVal_raw   = Float32.(reducedState["PCsVal"])
tVal_raw     = Float32.(reducedState["tVal"])
std_data     = Float32.(reducedState["std"])

println("\nSize of PCsTrain_raw: ", size(PCsTrain_raw))
println("Size of tTrain_raw: ", size(tTrain_raw))
println("Size of PCsVal_raw: ", size(PCsVal_raw))
println("Size of tVal_raw: ", size(tVal_raw))

# Vérifier et corriger les dimensions si nécessaire
# On veut: (n_modes, n_timesteps)
if size(PCsTrain_raw, 1) == 150 && size(PCsTrain_raw, 2) != 150
    PCsTrain_full = PCsTrain_raw'
    println("Données transposées")
else
    PCsTrain_full = PCsTrain_raw
end

# tTrain doit être un vecteur de temps
if length(tTrain_raw) == size(PCsTrain_full, 1)
    # Créer un vecteur de temps artificiel
    tTrain = Float32.(0:size(PCsTrain_full, 2)-1)
    println("Vecteur de temps créé artificiellement")
else
    tTrain = tTrain_raw
end

println("\nDimensions finales:")
println("PCsTrain_full: ", size(PCsTrain_full), " (modes × temps)")
println("tTrain: ", length(tTrain), " pas de temps")

# IMPORTANT: Split temporel car Val a des dimensions différentes
split_idx = Int(floor(0.8 * size(PCsTrain_full, 2)))
PCsTrain = PCsTrain_full[:, 1:split_idx]
PCsVal = PCsTrain_full[:, split_idx+1:end]

tTrain_split = tTrain[1:split_idx]
tVal_split = tTrain[split_idx+1:end]

println("\nAprès split 80/20:")
println("PCsTrain: ", size(PCsTrain))
println("PCsVal: ", size(PCsVal))
println("tTrain_split: ", length(tTrain_split))
println("tVal_split: ", length(tVal_split))

# Paramètres du modèle
z0 = PCsTrain[:, 1]
tspan = (tTrain[1], tTrain[end])
r = size(PCsTrain, 1)

println("\nNombre de modes (r): ", r)

# Créer le réseau de neurones
nn = Chain(
    Dense(r, 8, tanh),
    Dense(8, r)
)

rng = Random.default_rng()
ps, st = Lux.setup(rng, nn)
ps = ComponentArray(ps)

function f!(dz, z, p, t)
    y, _ = nn(z, p, st)
    @assert length(y) == length(dz) "Dimension mismatch: length(y)=$(length(y)), length(dz)=$(length(dz))"
    dz .= y
end

dz = similar(z0)
y, _ = nn(z0, ps, st)

@show length(z0)
@show length(dz)
@show length(y)

f!(dz, z0, ps, tTrain[1])

prob = ODEProblem(f!, z0, tspan, ps)

Available keys in reducedState:["PCsVal", "tVal", "PCsTrain", "tTrain", "std"]

Size of PCsTrain_raw: (2922, 150)
Size of tTrain_raw: (2922,)
Size of PCsVal_raw: (730, 150)
Size of tVal_raw: (730,)
Vecteur de temps créé artificiellement

Dimensions finales:
PCsTrain_full: (2922, 150) (modes × temps)
tTrain: 150 pas de temps

Après split 80/20:
PCsTrain: (2922, 120)
PCsVal: (2922, 30)
tTrain_split: 120
tVal_split: 30

Nombre de modes (r): 2922
length(z0) = 2922
length(dz) = 2922
length(y) = 2922


[38;2;86;182;194mODEProblem[0m with uType [38;2;86;182;194mVector{Float32}[0m and tType [38;2;86;182;194mFloat32[0m. In-place: [38;2;86;182;194mtrue[0m
Non-trivial mass matrix: [38;2;86;182;194mfalse[0m
timespan: (0.0f0, 149.0f0)
u0: 2922-element Vector{Float32}:
 -0.2329068
 -0.21520926
 -0.3013477
 -0.32271194
 -0.43515974
 -0.93329793
 -1.11059
 -1.1857773
 -1.3257517
 -1.4246912
 -1.6363107
 -1.6228559
 -1.6361544
  ⋮
 -1.0148848
 -1.6231638
 -2.0417738
 -2.3244388
 -2.0626087
 -1.2057661
 -1.0863514
 -1.0674461
 -1.0999924
 -1.2391645
 -1.3900986
 -1.134983

In [25]:

Ztrain = PCsTrain
Zval   = PCsVal

Ztrain = transpose(Ztrain)  # size (2922, 150)
Zval = transpose(Zval)      # size (730, 150)

@assert size(Ztrain, 1) == 150
@assert size(Ztrain, 2) == length(tTrain)

z0Train = Ztrain[:, 1]
z0Val   = Zval[:, 1]

tSaveTrain = tTrain              # length = 2922
tspanTrain = (tTrain[1], tTrain[end])

tSaveVal = tVal
tspanVal = (tVal[1], tVal[end])

r = size(Ztrain, 1)   # = 150 modes
T = size(Ztrain, 2)   # = 2922 temps

println("Configuration finale:")
println("  Ztrain: ", size(Ztrain))
println("  Zval: ", size(Zval))
println("  tSaveTrain: ", length(tSaveTrain))
println("  tSaveVal: ", length(tSaveVal))

LoadError: AssertionError: size(Ztrain, 1) == 150

In [17]:
# Créer les templates de problèmes
probTrainTemplate = ODEProblem(f!, z0Train, tspanTrain, nothing)
probValTemplate   = ODEProblem(f!, z0Val,   tspanVal,   nothing)

[38;2;86;182;194mODEProblem[0m with uType [38;2;86;182;194mVector{Float32}[0m and tType [38;2;86;182;194mFloat32[0m. In-place: [38;2;86;182;194mtrue[0m
Non-trivial mass matrix: [38;2;86;182;194mfalse[0m
timespan: (120.0f0, 149.0f0)
u0: 2922-element Vector{Float32}:
  1.9507536
  2.0597448
  1.563616
  1.3165294
  0.3310595
 -0.6211111
 -0.6341623
 -0.68685955
  0.6022188
 -0.17954591
  0.10940971
 -0.2102112
 -1.0187709
  ⋮
 -1.1079515
 -1.7380584
 -4.4323564
 -1.3724302
  2.1139307
  2.8146522
  0.2066283
 -1.5705686
 -1.9603825
 -2.960355
 -1.707478
  0.1362139

In [18]:
# Fonction de prédiction
function predict(ps, probTemplate, tSave)
    prob = remake(probTemplate; p=ps)
    sol = solve(
        prob,
        Tsit5(),
        saveat = tSave,
        sensealg = InterpolatingAdjoint(),
        verbose = false
    )
    return Array(sol)
end

predict (generic function with 1 method)

In [19]:
# Fonction de perte
function loss(ps)
    Ẑ = predict(ps, probTrainTemplate, tSaveTrain)
    @assert size(Ẑ) == size(Ztrain) "Size mismatch: Ẑ=$(size(Ẑ)), Ztrain=$(size(Ztrain))"
    return sum(abs2, Ẑ .- Ztrain) / length(Ztrain)
end

println("Test de la fonction de perte...")
@time loss_val = loss(ps)
println("Loss initiale: ", loss_val)

Test de la fonction de perte...
  0.525573 seconds (1.20 M allocations: 85.202 MiB)
Loss initiale: 4200.4395


In [22]:
# Optimisation
optf = OptimizationFunction(
    (x, p) -> loss(x),
    Optimization.AutoZygote()
)

optprob = OptimizationProblem(optf, ps)

res = Optimization.solve(
    optprob,
    ADAM(1e-3),
    maxiters = 30
)

ps = res.u

[0mComponentVector{Float32}(layer_1 = (weight = Float32[-0.034508582 -0.046859886 … 0.051330145 -0.01808001; 0.012341956 0.009533996 … 0.021751324 -0.030476851; … ; 0.028706724 0.04866908 … -0.057179652 0.015467573; -0.029961145 -0.039770395 … 0.050584935 -0.023534976], bias = Float32[0.015047729, -0.016602535, 0.012948913, -0.011682161, 0.006463981, 0.0022086669, -0.00736034, 0.008705462]), layer_2 = (weight = Float32[-0.048493292 0.5388912 … 0.2159935 -0.21262634; 0.018541625 -0.6017983 … 0.5639847 0.54178035; … ; -0.35061815 0.005476359 … -0.35028172 -0.53401726; 0.069593556 -0.37267753 … 0.4736618 0.02154415], bias = Float32[0.007100861, -0.25438187, 0.1314209, 0.035487477, -0.1633043, -0.053435132, 0.21019304, -0.2807542, 0.09411444, 0.04273678  …  0.13469133, 0.21509212, 0.20696883, -0.15809868, 0.06361971, 0.03007162, 0.051801924, 0.19577952, -0.12647538, 0.2974972]))

In [23]:
# Validation
println("Prédiction sur les données de validation...")
ẐVal = predict(ps, probValTemplate, tSaveVal)

println("Dimensions:")
println("  ẐVal: ", size(ẐVal))
println("  Zval: ", size(Zval))

@assert size(ẐVal) == size(Zval) "Dimension mismatch!"

valError = norm(ẐVal .- Zval) / norm(Zval)
println("\n✓ Relative validation error = ", valError)

err_mode = vec(norm.(eachrow(ẐVal .- Zval)) ./ norm.(eachrow(Zval)))
for i in 1:r
    println("  Mode $i relative error: ", err_mode[i])
end

Prédiction sur les données de validation...
Dimensions:
  ẐVal: (2922, 30)
  Zval: (2922, 30)

✓ Relative validation error = 15.889517
  Mode 1 relative error: 12.5519495
  Mode 2 relative error: 1.1054728
  Mode 3 relative error: 2.726629
  Mode 4 relative error: 10.203815
  Mode 5 relative error: 19.58934
  Mode 6 relative error: 14.780125
  Mode 7 relative error: 9.949944
  Mode 8 relative error: 7.81066
  Mode 9 relative error: 6.791775
  Mode 10 relative error: 6.198534
  Mode 11 relative error: 38.111645
  Mode 12 relative error: 12.376015
  Mode 13 relative error: 5.180577
  Mode 14 relative error: 1.2513075
  Mode 15 relative error: 9.732062
  Mode 16 relative error: 1.4304824
  Mode 17 relative error: 11.485823
  Mode 18 relative error: 14.093328
  Mode 19 relative error: 17.966267
  Mode 20 relative error: 8.874242
  Mode 21 relative error: 1.2367437
  Mode 22 relative error: 1.7746804
  Mode 23 relative error: 4.6578517
  Mode 24 relative error: 6.1886206
  Mode 25 relative 

In [1]:
############################################################
# NEURAL ODE ROM — VERSION STABLE & RAPIDE
############################################################

using NPZ, LinearAlgebra, Random, Statistics
using Lux, ComponentArrays
using OrdinaryDiffEq, SciMLSensitivity
using Optimization, OptimizationOptimisers

############################################################
# 1) LOAD DATA
############################################################

reducedState = npzread("data/processed/sstReducedStateCOPERNICUS20102019Prepared.npz")

PCsTrain_raw = Float32.(reducedState["PCsTrain"])
tTrain_raw   = Float32.(reducedState["tTrain"])

println("Raw PCsTrain size = ", size(PCsTrain_raw))
println("Raw tTrain length = ", length(tTrain_raw))

# Auto-detect orientation → want Z[mode, time]
if size(PCsTrain_raw, 2) == length(tTrain_raw)
    Zfull = PCsTrain_raw
    println("PCsTrain already (modes × time)")
elseif size(PCsTrain_raw, 1) == length(tTrain_raw)
    Zfull = PCsTrain_raw'
    println("PCsTrain transposed to (modes × time)")
else
    error("Cannot infer orientation of PCsTrain")
end

tfull = tTrain_raw

@assert size(Zfull, 2) == length(tfull)

println("Final Zfull size = ", size(Zfull), " (modes × time)")

############################################################
# 2) TEMPORAL SPLIT (80 / 20)
############################################################

T = size(Zfull, 2)
split_idx = Int(floor(0.8 * T))

Ztrain_full = Zfull[:, 1:split_idx]
Zval        = Zfull[:, split_idx+1:end]

tTrain_full = tfull[1:split_idx]
tVal        = tfull[split_idx+1:end]

println("Ztrain_full size = ", size(Ztrain_full))
println("Zval size        = ", size(Zval))

############################################################
# 3) SHORT HORIZON (CRUCIAL POUR LA VITESSE)
############################################################

Tshort = 400   # <<< NE PAS AUGMENTER AU DÉBUT

Ztrain = Ztrain_full[:, 1:Tshort]
tTrain = tTrain_full[1:Tshort]

############################################################
# 4) NORMALISATION PAR MODE
############################################################

std_mode = vec(std(Ztrain; dims=2))      # (n_modes,)

Ztrain .= Ztrain ./ reshape(std_mode, :, 1)
Zval   .= Zval   ./ reshape(std_mode, :, 1)

############################################################
# 5) NEURAL ODE DEFINITION
############################################################

r = size(Ztrain, 1)   # number of modes

z0Train = Ztrain[:, 1]
z0Val   = Zval[:, 1]

tspanTrain = (tTrain[1], tTrain[end])
tspanVal   = (tVal[1], tVal[end])

nn = Chain(
    Dense(r, 32, tanh),
    Dense(32, r)
)

rng = Random.default_rng()
ps, st = Lux.setup(rng, nn)
ps = ComponentArray(ps)

λ = 0.1f0   # coefficient de dissipation (à ajuster)

function f!(dz, z, p, t)
    y, _ = nn(z, p, st)
    @. dz = y - λ * z
end

probTrainTemplate = ODEProblem(f!, z0Train, tspanTrain, nothing)
probValTemplate   = ODEProblem(f!, z0Val,   tspanVal,   nothing)

############################################################
# 6) PREDICTION & LOSS
############################################################

function predict(ps, probTemplate, tSave)
    prob = remake(probTemplate; p=ps)
    sol = solve(
        prob,
        Tsit5(),
        saveat = tSave,
        dense = false,   # <<< IMPORTANT
        sensealg = InterpolatingAdjoint(autodiff=ZygoteVJP()),
        verbose = false
    )
    return Array(sol)
end

function loss(ps)
    Ẑ = predict(ps, probTrainTemplate, tTrain)
    return sum(abs2, Ẑ .- Ztrain) / length(Ztrain)
end

println("Warm-up loss (compile):")
@time println("Initial loss = ", loss(ps))

############################################################
# 7) OPTIMISATION
############################################################

optf = OptimizationFunction((x, p) -> loss(x), Optimization.AutoZygote())
optprob = OptimizationProblem(optf, ps)

res = Optimization.solve(
    optprob,
    ADAM(1e-3),
    maxiters = 20
)

ps = res.u

############################################################
# 8) VALIDATION (LONG HORIZON)
############################################################

println("Validation rollout...")

ẐVal = predict(ps, probValTemplate, tVal)

@assert size(ẐVal) == size(Zval)

valError = norm(ẐVal .- Zval) / norm(Zval)
println("✓ Relative validation error = ", valError)

err_mode = vec(norm.(eachrow(ẐVal .- Zval)) ./ norm.(eachrow(Zval)))

println("Per-mode relative error (first 10 modes):")
for i in 1:min(10, r)
    println("  Mode $i : ", err_mode[i])
end


Raw PCsTrain size = (2922, 150)
Raw tTrain length = 2922
PCsTrain transposed to (modes × time)
Final Zfull size = (150, 2922) (modes × time)
Ztrain_full size = (150, 2337)
Zval size        = (150, 585)
Warm-up loss (compile):
Initial loss = 91.197395
  4.411494 seconds (10.97 M allocations: 726.169 MiB, 12.26% gc time, 99.93% compilation time)


[33m[1m│ [22m[39m
[33m[1m│ [22m[39m1. If this was not the desired behavior overload the dispatch on `m`.
[33m[1m│ [22m[39m
[33m[1m│ [22m[39m2. This might have performance implications. Check which layer was causing this problem using `Lux.Experimental.@debug_mode`.
[33m[1m└ [22m[39m[90m@ ArrayInterfaceReverseDiffExt ~/.julia/packages/LuxCore/kQC9S/ext/ArrayInterfaceReverseDiffExt.jl:9[39m


Validation rollout...
✓ Relative validation error = 9.283604
Per-mode relative error (first 10 modes):
  Mode 1 : 7.985529
  Mode 2 : 8.934202
  Mode 3 : 9.867152
  Mode 4 : 9.028579
  Mode 5 : 5.5766478
  Mode 6 : 8.635152
  Mode 7 : 4.527218
  Mode 8 : 18.321096
  Mode 9 : 5.958862
  Mode 10 : 6.6799955


In [7]:
ẐTrain = predict(ps, probTrainTemplate, tTrain)

trainError = norm(ẐTrain .- Ztrain) / norm(Ztrain)
println("Train relative error (Tshort) = ", trainError)


Train relative error (Tshort) = 6.7249575


In [1]:
############################################################
# NEURAL SDE DISSIPATIF — VERSION STABLE
############################################################

using NPZ, LinearAlgebra, Random, Statistics
using Lux, ComponentArrays
using OrdinaryDiffEq, StochasticDiffEq, SciMLSensitivity
using Optimization, OptimizationOptimisers

############################################################
# 1) LOAD DATA
############################################################

reducedState = npzread("data/processed/sstReducedStateCOPERNICUS20102019Prepared.npz")

PCsTrain_raw = Float32.(reducedState["PCsTrain"])
tTrain_raw   = Float32.(reducedState["tTrain"])

println("Raw PCsTrain size = ", size(PCsTrain_raw))
println("Raw tTrain length = ", length(tTrain_raw))

# Auto-detect orientation → Z[mode, time]
if size(PCsTrain_raw, 2) == length(tTrain_raw)
    Zfull = PCsTrain_raw
    println("PCsTrain already (modes × time)")
elseif size(PCsTrain_raw, 1) == length(tTrain_raw)
    Zfull = PCsTrain_raw'
    println("PCsTrain transposed to (modes × time)")
else
    error("Cannot infer orientation of PCsTrain")
end

tfull = tTrain_raw
@assert size(Zfull, 2) == length(tfull)

println("Final Zfull size = ", size(Zfull), " (modes × time)")

############################################################
# 2) TEMPORAL SPLIT (80 / 20)
############################################################

T = size(Zfull, 2)
split_idx = Int(floor(0.8 * T))

Ztrain_full = Zfull[:, 1:split_idx]
Zval        = Zfull[:, split_idx+1:end]

tTrain_full = tfull[1:split_idx]
tVal        = tfull[split_idx+1:end]

println("Ztrain_full size = ", size(Ztrain_full))
println("Zval size        = ", size(Zval))

############################################################
# 3) SHORT HORIZON (CRUCIAL)
############################################################

Tshort = 200          # <<< IMPORTANT : court horizon
Ztrain = Ztrain_full[:, 1:Tshort]
tTrain = tTrain_full[1:Tshort]

############################################################
# 4) NORMALISATION PAR MODE
############################################################

std_mode = vec(std(Ztrain; dims=2))      # (n_modes,)
Ztrain .= Ztrain ./ reshape(std_mode, :, 1)
Zval   .= Zval   ./ reshape(std_mode, :, 1)

############################################################
# 5) NEURAL SDE DEFINITION
############################################################

r = size(Ztrain, 1)

z0Train = Ztrain[:, 1]
z0Val   = Zval[:, 1]

tspanTrain = (tTrain[1], tTrain[end])
tspanVal   = (tVal[1], tVal[end])

# Neural network for drift
nn = Chain(
    Dense(r, 32, tanh),
    Dense(32, r)
)

rng = Random.default_rng()
ps, st = Lux.setup(rng, nn)
ps = ComponentArray(ps)

λ = 0.1f0  # dissipation

function f!(dz, z, p, t)
    y, _ = nn(z, p, st)
    dz .= y .- λ .* z
end


σ = 0.05f0  # intensité du bruit

function g!(dz, z, p, t)
    @inbounds for i in eachindex(z)
        dz[i] = σ
    end
end


probTrainTemplate = SDEProblem(
    f!,
    g!,
    z0Train,
    tspanTrain,
    nothing
)

probValTemplate = SDEProblem(
    f!,
    g!,
    z0Val,
    tspanVal,
    nothing
)

############################################################
# 6) PREDICTION & LOSStspan = (tTrain[1], tTrain[end])
############################################################

function predict(ps, probTemplate, tSave)
    prob = remake(probTemplate; p = ps)

    sol = solve(
        prob,
        EM(),                      # ← ICI LE FIX
        saveat = tSave,
        sensealg = BacksolveAdjoint(),
        dt = 0.01f0,               # important pour stabilité
        verbose = false
    )

    return Array(sol)
end



function loss(ps)
    Ẑ = predict(ps, probTrainTemplate, tTrain)
    return sum(abs2, Ẑ .- Ztrain) / length(Ztrain)
end

############################################################
# 7) WARM-UP & TRAINING
############################################################

Random.seed!(42)   # IMPORTANT for SDE reproducibility

println("Warm-up loss (compile):")
@time println("Initial loss = ", loss(ps))

optf = OptimizationFunction((x, p) -> loss(x), Optimization.AutoZygote())
optprob = OptimizationProblem(optf, ps)

res = Optimization.solve(
    optprob,
    ADAM(1e-3),
    maxiters = 10
)

ps = res.u

############################################################
# 8) VALIDATION (LONG HORIZON)
############################################################

println("Validation rollout...")

ẐVal = predict(ps, probValTemplate, tVal)
@assert size(ẐVal) == size(Zval)

valError = norm(ẐVal .- Zval) / norm(Zval)
println("✓ Relative validation error = ", valError)

err_mode = vec(norm.(eachrow(ẐVal .- Zval)) ./ norm.(eachrow(Zval)))

println("Per-mode relative error (first 10 modes):")
for i in 1:min(10, r)
    println("  Mode $i : ", err_mode[i])
end


Raw PCsTrain size = (2922, 150)
Raw tTrain length = 2922
PCsTrain transposed to (modes × time)
Final Zfull size = (150, 2922) (modes × time)
Ztrain_full size = (150, 2337)
Zval size        = (150, 585)
Warm-up loss (compile):
Initial loss = 45.74937
  5.469410 seconds (11.40 M allocations: 743.704 MiB, 12.31% gc time, 92.54% compilation time)


In [1]:
using NPZ

# We open the reduced state previously prepared in Python
reducedState = npzread("data/processed/sstReducedStateCOPERNICUS20102019Prepared.npz")

# We reconstruct vars from reduced state

PCs = cat


Dict{String, Array{Float32}} with 5 entries:
  "PCsVal"   => [0.716095 -0.150198 … 1.68743 0.360322; 0.620925 -0.289956 … 1.…
  "tVal"     => [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0  …  720.0, 72…
  "PCsTrain" => [-0.232907 -0.0829777 … -1.66511 0.844336; -0.215209 -0.113246 …
  "tTrain"   => [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0  …  2912.0, 2…
  "std"      => [56.027, 13.5325, 9.16372, 6.82113, 5.76459, 5.05414, 4.66145, …