# In [None]:
using OrdinaryDiffEq
using SciMLSensitivity
using Enzyme
using LinearAlgebra

"""
    VGParam{T}

van Genuchten–Mualem hydraulic parameters of a single soil layer, as consumed by
`vg_hydraulic`.
"""
struct VGParam{T}
    θ_s::T  # saturated water content: effective saturation S_e = 1 at θ = θ_s
    θ_r::T  # residual water content: S_e = 0 at θ = θ_r
    Ks::T   # saturated conductivity: K → Ks as S_e → 1
    α::T    # head scale: matric head ψ is proportional to 1/α
    n::T    # shape exponent; vg_hydraulic derives m = 1 - 1/n
end

"""
    SoilConfig{T}

Static geometry of a layered soil column (e.g. as built by `make_soil`).
"""
struct SoilConfig{T}
    N::Int              # number of layers
    Δz::Vector{T}       # thickness of each layer (length N); divides the flux difference in the RHS
    Δz_half::Vector{T}  # spacing used for the head gradient between layers i and i+1; cal_Q! reads entries 1:N-1
    depths::Vector{T}   # make_soil fills this with Δz, 2Δz, …, NΔz; not read by cal_Q! or the RHS
end

"""
    SoilCache{T}

Scratch arrays overwritten by `cal_Q!` on every right-hand-side evaluation.
"""
struct SoilCache{T}
    K::Vector{T}       # hydraulic conductivity of each layer (length N)
    K_half::Vector{T}  # mean conductivity between layers i and i+1; cal_Q! writes entries 1:N-1
    ψ::Vector{T}       # matric head of each layer (length N)
    Q::Vector{T}       # fluxes at the N+1 layer faces; Q[1] precedes layer 1, Q[N+1] follows layer N
    sink::Vector{T}    # length N; not used by cal_Q! or the RHS
end

"""
    make_soil(n_layers, dz_val) -> (config, cache)

Build a uniform soil column of `n_layers` layers, each `dz_val` thick, together with a
zero-initialised work cache.

`depths` is `dz_val, 2dz_val, …, n_layers*dz_val`. Because the grid is uniform, the
spacing between adjacent layers (`Δz_half`) equals the layer thickness `Δz`.

An integer `dz_val` is promoted to floating point. Otherwise the cache would be an
integer array and could not store the fractional fluxes written by `cal_Q!`.

Throws `ArgumentError` if `n_layers < 1`, because the flux code indexes `K[N]`.
"""
function make_soil(n_layers::Int, dz_val)
    n_layers >= 1 || throw(ArgumentError("n_layers must be at least 1, got $n_layers"))

    T = float(typeof(dz_val))
    dz = T(dz_val)

    depths = collect(range(dz, length=n_layers, step=dz))
    Δz = fill(dz, n_layers)
    Δz_half = copy(Δz)  # uniform grid: centre-to-centre spacing == thickness

    config = SoilConfig(n_layers, Δz, Δz_half, depths)
    cache = SoilCache(
        zeros(T, n_layers), zeros(T, n_layers), zeros(T, n_layers),
        zeros(T, n_layers + 1), zeros(T, n_layers)
    )
    return config, cache
end

# ==========================================
# 2. Physical equations
# ==========================================

# van Genuchten–Mualem constitutive relations for one layer.
#
# Given water content `θ` and parameters `p`, return the tuple `(K, ψ)`:
#   S_e = clamp((θ - θ_r) / (θ_s - θ_r), 1e-4, 1 - 1e-4)
#   m   = 1 - 1/n
#   ψ   = -(1/α) * (S_e^(-1/m) - 1)^(1/n)
#   K   = Ks * sqrt(S_e) * (1 - (1 - S_e^(1/m))^m)^2
# The clamp keeps S_e strictly inside (0, 1), so both power laws stay finite.
@inline function vg_hydraulic(θ, p::VGParam{T}) where T
    θ_s, θ_r, Ks, α, n = p.θ_s, p.θ_r, p.Ks, p.α, p.n
    se = clamp((θ - θ_r) / (θ_s - θ_r), T(1e-4), T(1.0 - 1e-4))
    m = 1 - 1/n
    ψ = -(1/α) * (se^(-1/m) - 1)^(1/n)
    K = Ks * sqrt(se) * (1 - (1 - se^(1/m))^m)^2
    return (K, ψ)
end

# Fill the face fluxes `cache.Q` (length N+1) for layer water contents `θ`.
#
# For each layer, the parameter column `p_matrix[1:5, layer]` = (θ_s, θ_r, Ks, α, n)
# becomes a `VGParam`. Passing it to `vg_hydraulic` gives `cache.K[layer]` and
# `cache.ψ[layer]`.
#
# Each interior face uses the arithmetic-mean conductivity
#   K_half[i] = (K[i] + K[i+1]) / 2
# and the flux
#   Q[i+1] = -K_half[i] * ((ψ[i+1] - ψ[i]) / Δz_half[i] + 1)
# The boundary faces are fixed to Q[1] = -1 and Q[N+1] = -K[N].
# `t` is part of the ODE calling convention and is not used here.
#
# NOTE(review): the RHS uses dθ[i] = -(Q[i+1] - Q[i]) / Δz[i], so Q counts as positive
# toward increasing layer index (increasing `depths`). With uniform ψ, the interior flux
# is then -K_half, pointing toward layer 1. Confirm the intended sign of the gravity
# term (+1) and of both boundary fluxes.
function cal_Q!(cache::SoilCache, config::SoilConfig, p_matrix::AbstractMatrix{T}, θ, t) where T
    N = config.N
    K, ψ, K_face, Q = cache.K, cache.ψ, cache.K_half, cache.Q

    # Layer-centred conductivity and matric head.
    @inbounds for layer in 1:N
        vg = VGParam(p_matrix[1, layer], p_matrix[2, layer], p_matrix[3, layer],
                     p_matrix[4, layer], p_matrix[5, layer])
        K[layer], ψ[layer] = vg_hydraulic(θ[layer], vg)
    end

    # Boundary and interior face fluxes.
    Q[1] = -1.0
    @inbounds for face in 2:N
        prev, next = face - 1, face
        K_face[prev] = (K[prev] + K[next]) / 2
        ψ_gradient = (ψ[next] - ψ[prev]) / config.Δz_half[prev]
        Q[face] = -K_face[prev] * (ψ_gradient + 1.0)
    end
    Q[N+1] = -K[N]
    return nothing
end

"""
    RichardsSystem(config, cache)

Callable ODE right-hand side `sys(dθ, θ, p, t)` for the layered soil column. `config`
holds the grid; `cache` holds the work arrays that `cal_Q!` overwrites on each call.
"""
struct RichardsSystem{C, M}
    config::C  # grid description (a SoilConfig); the RHS reads N, Δz and Δz_half from it
    cache::M   # work arrays passed to cal_Q! (a SoilCache)
end

# Return a work cache with element type `T`. The preallocated cache is reused when its
# element type already matches, so ordinary Float64 solves stay allocation-free.
_richards_cache(cache::SoilCache{T}, ::Type{T}, n::Integer) where {T} = cache

# Otherwise allocate a fresh cache of element type `T`. This happens when the solver or
# a sensitivity algorithm evaluates the RHS with ForwardDiff.Dual (or other AD) numbers.
# Writing such values into the Float64 preallocated cache throws
# `MethodError: no method matching Float64(::ForwardDiff.Dual{...})`.
function _richards_cache(::SoilCache, ::Type{T}, n::Integer) where {T}
    return SoilCache(zeros(T, n), zeros(T, n), zeros(T, n), zeros(T, n + 1), zeros(T, n))
end

# In-place ODE right-hand side.
#
# Sets dθ[i] = -(Q[i+1] - Q[i]) / Δz[i] for every layer i. The face fluxes Q are
# computed by `cal_Q!` from the states `θ` and the parameter matrix `p` (rows θ_s, θ_r,
# Ks, α, n; one column per layer).
#
# The work arrays come from `sys.cache` when its element type equals
# `promote_type(eltype(θ), eltype(p))`. Otherwise a temporary cache of that type is
# allocated, which is what lets automatic differentiation through the solve work.
function (sys::RichardsSystem)(dθ, θ, p, t)
    config = sys.config
    N = config.N
    T = promote_type(eltype(θ), eltype(p))
    cache = _richards_cache(sys.cache, T, N)

    cal_Q!(cache, config, p, θ, t)

    # Flux divergence per layer.
    @inbounds for i in 1:N
        dθ[i] = -(cache.Q[i+1] - cache.Q[i]) / config.Δz[i]
    end
    return nothing
end

# In [None]:
# Problem setup shared with `loss_function`: a column of 10 layers, each 5.0 thick,
# starting from a uniform water content of 0.3 and integrated over t ∈ [0, 2].
n_layers = 10
tspan = (0.0, 2.0)
θ0 = fill(0.3, n_layers)
config, cache = make_soil(n_layers, 5.0)
sys = RichardsSystem(config, cache)

# Scalar objective for the gradient check below.
#
# `x[1]` is used as Ks and `x[2]` as the van Genuchten n for every layer. The other
# parameters are fixed: θ_s = 0.45, θ_r = 0.05, α = 0.01.
# The function solves the Richards system `sys` from `θ0` over `tspan` (globals defined
# above) and returns the sum of all states saved every 0.5 time units.
function loss_function(x)
    T = eltype(x)
    val_Ks = x[1]
    val_n  = x[2]

    # One parameter column [θ_s, θ_r, Ks, α, n], tiled across all `n_layers` layers.
    # Built without `setindex!`, because Zygote cannot differentiate through array
    # mutation ("Mutating arrays is not supported").
    column = [T(0.45), T(0.05), val_Ks, T(0.01), val_n]
    p_matrix = repeat(column, 1, n_layers)

    prob = ODEProblem(sys, θ0, tspan, p_matrix)
    # A specific adjoint method can be requested through `sensealg`, e.g.
    # InterpolatingAdjoint(autojacvec = EnzymeVJP()). The default selection is used here.
    sol = solve(prob, Tsit5(), saveat=0.5, abstol=1e-4, reltol=1e-4)

    # Sum a plain Array instead of calling `sum(sol)`: per the original note, `sum(sol)`
    # goes through the VectorOfArray machinery and crashed Enzyme.
    return sum(Array(sol))
end

# Evaluate the loss once (this also compiles it), then differentiate it with Zygote.
p_in = [8.0, 1.8]
loss_function(p_in)

import Zygote
# `Zygote.gradient` returns one gradient per argument of `loss_function`, i.e. a 1-tuple
# here. Unpacking it into two names would throw a BoundsError, so take the single entry.
dp = @time only(Zygote.gradient(loss_function, p_in))

println("dLoss / dKs = ", dp[1])
println("dLoss / dn  = ", dp[2])

# Alternative kept for reference: reverse-mode Enzyme. The original note suggested
# passing the function as `Const(loss_function)`.
# dp = make_zero(p_in)
# Enzyme.autodiff(
#     set_runtime_activity(Reverse),
#     loss_function,
#     Enzyme.Active,
#     Enzyme.Duplicated(p_in, dp)
# )


#= Captured output of the `Zygote.gradient` call above:

LoadError: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{Float64}, Vector{Float64}, Matrix{Float64}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Matrix{Float64}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, Matrix{Float64}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Matrix{Float64}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}}, false}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Float64}, Float64, 12})
The type `Float64` exists, but no method is defined for this combination of argument types when trying to construct it.

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:265
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:900
  Float64(::UInt128)
   @ Base float.jl:260
  ...
=#
