In [1]:
import OrdinaryDiffEq as ODE
import SciMLSensitivity as SMS
import Zygote
import Enzyme
import Enzyme: make_zero

"""
    fiip(du, u, p, t)

In-place right-hand side of a two-state ODE in Lotka–Volterra form.

Writes the time derivative of the state `u` into `du`:

    du[1] =  p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]

# Arguments
- `du`: output buffer; entries 1 and 2 are overwritten.
- `u`: current state; entries 1 and 2 are read.
- `p`: parameters; entries 1 through 4 are read.
- `t`: time; accepted for the solver call signature but not used.

# Returns
`nothing`. The result is delivered only through mutation of `du`.
"""
function fiip(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
    # Return `nothing` explicitly so the last assignment's value does not leak out.
    return nothing
end

# Parameter vector handed to `fiip` as `p` (it reads entries 1 through 4).
p = [1.5, 1.0, 3.0, 1.0]
# Initial values of the two state components.
u0 = [1.0, 1.0]
# ODE problem on the time span (0.0, 10.0), solved with Tsit5 and saved every 0.1.
prob = ODE.ODEProblem(fiip, u0, (0.0, 10.0), p)
sol = ODE.solve(prob, ODE.Tsit5(); saveat = 0.1)


retcode: Success
Interpolation: 1st order linear
t: 101-element Vector{Float64}:
  0.0
  0.1
  0.2
  0.3
  0.4
  0.5
  0.6
  0.7
  0.8
  0.9
  1.0
  1.1
  1.2
  ⋮
  8.9
  9.0
  9.1
  9.2
  9.3
  9.4
  9.5
  9.6
  9.7
  9.8
  9.9
 10.0
u: 101-element Vector{Vector{Float64}}:
 [1.0, 1.0]
 [1.0610780673356448, 0.8210842775886172]
 [1.1440276717257591, 0.6790526689784506]
 [1.2491712125724797, 0.5668931465840906]
 [1.377644570563644, 0.47881295137947383]
 [1.5312308177480016, 0.41015646708682607]
 [1.7122697558187607, 0.35726544879975486]
 [1.9235782758300233, 0.31734720616177275]
 [2.1683910896990164, 0.2883888437879505]
 [2.45025066714094, 0.26905370939633366]
 [2.7728223025931613, 0.2587244160559188]
 [3.139732989414826, 0.25749669323911817]
 [3.5539013554938403, 0.26645005590015974]
 ⋮
 [4.346171302478965, 4.204698054263225]
 [3.246586345347383, 4.546928376347062]
 [2.3956662257437724, 4.457765582199836]
 [1.817282293373577, 4.064946558412839]
 [1.4427612812619388, 3.5397375370727238]


In [4]:
# Reduce the saved solution `sol` to a single number with `sum`; the loss functions
# defined later in this file return the same kind of reduction.
sum(sol)

450.64377243968494

In [None]:
"""
    loss(u0, p, prob; kw...)

Return `sum` of the solution obtained by re-solving `prob` with initial state `u0` and
parameters `p`.

The problem is rebuilt with `ODE.remake`, then solved with `ODE.Tsit5()` and
`saveat=0.1`. Any keyword arguments in `kw` are forwarded to `ODE.solve`.
"""
function loss(u0, p, prob; kw...)
    trial = ODE.remake(prob; u0 = u0, p = p)
    return sum(ODE.solve(trial, ODE.Tsit5(); saveat = 0.1, kw...))
end

In [15]:
# Solve `prob` again, this time without `saveat`; this rebinds the global `sol`.
sol = ODE.solve(prob, ODE.Tsit5())

# Two-argument loss: reads the global `prob`, overrides its `u0` and `p` through
# keyword arguments to `ODE.solve`, and sums the solution saved every 0.1.
function loss(u0, p)
    return sum(ODE.solve(prob, ODE.Tsit5(); u0 = u0, p = p, saveat = 0.1))
end

# Time the Zygote gradient of `loss` with respect to `u0` and `p`.
@time du0, dp = Zygote.gradient(loss, u0, p)

  0.037141 seconds (105.90 k allocations: 5.972 MiB, 83.29% compilation time: 100% of which was recompilation)


([-39.49430990220676, -8.631888332498164], [7.349038361940163, -159.31079867552987, 74.93924771004801, -339.327238254502])

In [19]:
# Zeroed buffers that Enzyme fills with the gradients.
du0 = make_zero(u0)  # holds the gradient with respect to u0
dp = make_zero(p)    # holds the gradient with respect to p

# Active: marks the function's return value as "active", i.e. we differentiate this
# scalar return value, with an initial derivative of 1.0.
@time Enzyme.autodiff(
    Enzyme.Reverse, loss, Enzyme.Active,
    Enzyme.Duplicated(u0, du0), Enzyme.Duplicated(p, dp), Enzyme.Const(prob),
)

# Inspect the results.
println("Gradient w.r.t u0: ", du0)
println("Gradient w.r.t p:  ", dp)

  0.000634 seconds (4.04 k allocations: 246.906 KiB)
Gradient w.r.t u0: [-39.49430990220676, -8.631888332498164]
Gradient w.r.t p:  [7.349038361940163, -159.31079867552987, 74.93924771004801, -339.327238254502]


## 2 InterpolatingAdjoint

## 2.1 Zygote

In [37]:
# Sensitivity algorithm used by `loss_adjoint_zygote`: `InterpolatingAdjoint` with
# `ZygoteVJP` selected for `autojacvec`.
# Alternative backend: SMS.InterpolatingAdjoint(autojacvec=SMS.EnzymeVJP())
sens_alg = SMS.InterpolatingAdjoint(; autojacvec = SMS.ZygoteVJP())

# Loss that reads the globals `prob` and `sens_alg`: solves with the given `u0` and `p`,
# saves every 0.1, passes `sens_alg` as `sensealg`, and sums the solution.
function loss_adjoint_zygote(u0, p)
    return sum(
        ODE.solve(prob, ODE.Tsit5(); u0 = u0, p = p, saveat = 0.1, sensealg = sens_alg),
    )
end

loss_adjoint_zygote (generic function with 1 method)

In [42]:
# Time the Zygote gradient of `loss_adjoint_zygote` with respect to `u0` and `p`;
# the two gradients are bound to `du0` and `dp`.
@time du0, dp = Zygote.gradient(loss_adjoint_zygote, u0, p)

  0.027435 seconds (318.95 k allocations: 15.350 MiB)


([-38.55581458857458, -8.845868879971489], [9.074401626113048, -159.68322950050717, 75.14436266093547, -338.48149414140573])

## 2.2 Enzyme

In [30]:
"""
    loss_adjoint_enzyme(u0, p, prob)

Return `sum` of the solution obtained by re-solving `prob` with initial state `u0` and
parameters `p`.

The problem is rebuilt with `ODE.remake` and solved with `ODE.Tsit5()` at `saveat=0.1`.
The solve is given `SMS.InterpolatingAdjoint(autojacvec=SMS.EnzymeVJP())` as `sensealg`.
"""
function loss_adjoint_enzyme(u0, p, prob)
    adjoint_alg = SMS.InterpolatingAdjoint(; autojacvec = SMS.EnzymeVJP())
    trial = ODE.remake(prob; u0 = u0, p = p)
    trajectory = ODE.solve(trial, ODE.Tsit5(); saveat = 0.1, sensealg = adjoint_alg)
    return sum(trajectory)
end

loss_adjoint_enzyme (generic function with 1 method)

In [36]:
# Zeroed buffers that Enzyme fills with the gradients.
du0 = make_zero(u0)  # holds the gradient with respect to u0
dp = make_zero(p)    # holds the gradient with respect to p

# Run Enzyme automatic differentiation on `loss_adjoint_enzyme`.
@time Enzyme.autodiff(
    Enzyme.Reverse, loss_adjoint_enzyme, Enzyme.Active,
    Enzyme.Duplicated(u0, du0), Enzyme.Duplicated(p, dp), Enzyme.Const(prob),
)

println("Gradient w.r.t u0: ", du0)
println("Gradient w.r.t p:  ", dp)

  0.000650 seconds (3.52 k allocations: 274.289 KiB)
Gradient w.r.t u0: [-38.555814588574385, -8.845868879971427]
Gradient w.r.t p:  [9.074401626113332, -159.68322950050703, 75.14436266093549, -338.4814941414056]
