# Sensitivity Approach Comparison

For more details, see https://doi.org/10.1109/HPEC49654.2021.9622796.

## Lotka-Volterra Model

The first test problem is LV, the non-stiff Lotka-Volterra model
$$
\begin{aligned}
\frac{dx}{dt} &= p_1 x - p_2 x y \\
\frac{dy}{dt} &= -p_3 y + x y
\end{aligned}
$$
with initial condition $u_0 = [x(0), y(0)] = [1.0, 1.0]$ and parameters $p = [p_1, p_2, p_3] = [1.5, 1.0, 3.0]$.

In [69]:
using OrdinaryDiffEq

function f(du, u, p, t)
    du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]
    du[2] = dy = -p[3]*u[2] + u[1]*u[2]
end
p = [1.5,1.0,3.0]; tspan = (0.0, 10.0); u0 = [1.0,1.0]; 
prob = ODEProblem(f, u0, tspan, p)
sol = solve(prob, Tsit5(), u0=u0, p=p, abstol=1e-6, reltol=1e-6, tstops=0:0.1:10.0)

nothing #hide

To test the performance of each sensitivity analysis method, we use an $L^2$ loss function sampled at 100 evenly spaced points.
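As an independent reference for the gradients benchmarked below, the following is a minimal sketch that uses only Base Julia. It assumes the quantity being differentiated is the one the cells below actually compute, `sum(solve(...; saveat=0.1))`, that is, the sum of both solution components at $t = 0, 0.1, \dots, 10$. The fixed-step RK4 integrator and the names `lv`, `rk4_step`, `sampled_sum`, and `fd_gradient` are illustrative choices made here, not part of the benchmarked packages.

```julia
# Minimal reference sketch: fixed-step RK4 plus central finite differences.
# This is NOT one of the benchmarked methods; it only serves as a sanity check.

# Lotka-Volterra right-hand side, written out-of-place for simplicity.
lv(u, p) = [p[1]*u[1] - p[2]*u[1]*u[2], -p[3]*u[2] + u[1]*u[2]]

# One classical RK4 step of size h.
function rk4_step(u, p, h)
    k1 = lv(u, p)
    k2 = lv(u .+ (h/2) .* k1, p)
    k3 = lv(u .+ (h/2) .* k2, p)
    k4 = lv(u .+ h .* k3, p)
    return u .+ (h/6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

# Sum of all solution components at t = 0, dt_save, 2*dt_save, ..., tend.
function sampled_sum(u0, p; tend=10.0, dt_save=0.1, substeps=100)
    nsave = round(Int, tend / dt_save)
    h = dt_save / substeps
    u = copy(u0)
    total = sum(u)                 # sample at t = 0
    for _ in 1:nsave
        for _ in 1:substeps
            u = rk4_step(u, p, h)
        end
        total += sum(u)            # sample at the next save point
    end
    return total
end

# Central finite differences with respect to each entry of x.
function fd_gradient(f, x; h_fd=1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h_fd
        xm = copy(x); xm[i] -= h_fd
        g[i] = (f(xp) - f(xm)) / (2 * h_fd)
    end
    return g
end

u0_ref = [1.0, 1.0]
p_ref = [1.5, 1.0, 3.0]
dp_ref  = fd_gradient(q -> sampled_sum(u0_ref, q), p_ref)   # d(loss)/dp
du0_ref = fd_gradient(v -> sampled_sum(v, p_ref), u0_ref)   # d(loss)/du0
```

If the assumption about the loss holds, `dp_ref` and `du0_ref` should land close to the gradients printed by the cells below; small differences are expected because those cells solve the ODE adaptively with tolerances of `1e-6`.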

In [70]:
using BenchmarkTools
import Zygote, DiffEqSensitivity

In [71]:
# Forward Continuous Sensitivity Analysis (CSA)
function sum_of_solution(u0,p)
    _prob = remake(prob,u0=u0,p=p)
    sensealg = DiffEqSensitivity.ForwardSensitivity()
    sum(solve(_prob,Tsit5(),reltol=1e-6,abstol=1e-6,saveat=0.1,sensealg=sensealg))
end
@btime du0_f, dp_f = Zygote.gradient(sum_of_solution,u0,p)

  319.327 μs (1136 allocations: 103.25 KiB)


(NotImplemented(DiffEqSensitivity, #= /home/taylor/.julia/dev/DiffEqSensitivity/src/concrete_solve.jl:285 =#, ForwardSensitivity does not differentiate with respect to u0. Change your sensealg.), [8.304646384317994, -159.4841294506934, 75.20331178751347])

In [72]:
# Reverse Continuous Sensitivity Analysis (CSA)
function sum_of_solution(u0,p)
    _prob = remake(prob,u0=u0,p=p)
    sensealg = DiffEqSensitivity.QuadratureAdjoint(compile=true)
    sum(solve(_prob,Tsit5(),reltol=1e-6,abstol=1e-6,saveat=0.1,sensealg=sensealg))
end
@btime du0_r, dp_r = Zygote.gradient(sum_of_solution,u0,p)


  7.149 ms (134106 allocations: 6.91 MiB)


([-39.126103250526825, -8.787925705566884], [8.307610397522009, -159.4845962237941, 75.20354297006898])

In [73]:
# Forward Discrete Sensitivity Analysis using Automatic Differentiation (DSAAD)
function sum_of_solution(u0,p)
    _prob = remake(prob,u0=u0,p=p)
    sensealg = DiffEqSensitivity.ForwardDiffSensitivity(convert_tspan=false)
    sum(solve(_prob,Tsit5(),reltol=1e-6,abstol=1e-6,saveat=0.1,sensealg=sensealg))
end
@btime du0_fd, dp_fd = Zygote.gradient(sum_of_solution,u0,p)


  589.976 μs (4319 allocations: 588.91 KiB)


([-39.1277375272366, -8.78749543437519], [8.304379835268392, -159.48405298061107, 75.20321406770285])

In [74]:
# Reverse Discrete Sensitivity Analysis using Automatic Differentiation (DSAAD)
function sum_of_solution(u0,p)
    _prob = remake(prob,u0=u0,p=p)
    sensealg = DiffEqSensitivity.ReverseDiffAdjoint()
    sum(solve(_prob,Tsit5(),reltol=1e-6,abstol=1e-6,saveat=0.1,sensealg=sensealg))
end
@btime du0_rd, dp_rd = Zygote.gradient(sum_of_solution,u0,p)

  7.331 ms (171223 allocations: 7.18 MiB)


([-39.127714278024065, -8.78758072915774], [8.304155122753292, -159.48406211600587, 75.20312378752807])

In [75]:
import DiscreteAdjoint

dg(out,u,p,t,i) = out .= 1
ts = 0:0.1:10.0  # same sample times as saveat=0.1 in the cells above

0.0:0.1:10.0

In [76]:
# Discrete Sensitivity Analysis (ForwardDiff VJP)
autojacvec = DiscreteAdjoint.ForwardDiffVJP()
@btime dp_fd, du0_fd = DiscreteAdjoint.discrete_adjoint(sol, dg, ts; autojacvec)

  503.841 μs (8392 allocations: 837.31 KiB)


([8.305284097140321, -159.48427689520915, 75.20347422609628], [-39.127539710215686, -8.787767575876138])

In [77]:
# Discrete Sensitivity Analysis (ReverseDiff VJP, not compiled)
autojacvec = DiscreteAdjoint.ReverseDiffVJP()
@btime dp_rdc, du0_rdc = DiscreteAdjoint.discrete_adjoint(sol, dg, ts; autojacvec)

  5.403 ms (2939 allocations: 190.33 KiB)


([8.30528409714025, -159.4842768952089, 75.20347422609629], [-39.12753971021566, -8.7877675758761])

In [78]:
# Discrete Sensitivity Analysis (ReverseDiff VJP, compiled)
autojacvec = DiscreteAdjoint.ReverseDiffVJP(true)
@btime dp_rdc, du0_rdc = DiscreteAdjoint.discrete_adjoint(sol, dg, ts; autojacvec)

  1.603 ms (4177 allocations: 248.95 KiB)


([8.30528409714025, -159.4842768952089, 75.20347422609629], [-39.12753971021566, -8.7877675758761])

In [79]:
# Discrete Sensitivity Analysis (Zygote VJP)
autojacvec = DiscreteAdjoint.ZygoteVJP()
@btime dp_z, du0_z = DiscreteAdjoint.discrete_adjoint(sol, dg, ts; autojacvec)

  406.756 ms (1458587 allocations: 112.68 MiB)


([8.305284097140694, -159.48427689520915, 75.20347422609632], [-39.1275397102155, -8.78776757587616])

## Brusselator Model

The second model, BRUSS, is the two-dimensional ($N \times N$) Brusselator stiff reaction-diffusion PDE:
$$
\begin{aligned}
\frac{\partial u}{\partial t} &= p_2 + u^2 v - (p_1 + 1) u + p_3 ( \frac{\partial^2 u}{\partial x^2}  + \frac{\partial^2 u}{\partial y^2}) + f(x, y, t) \\
\frac{\partial v}{\partial t} &= p_1 u - u^2 v + p_4 ( \frac{\partial^2 v}{\partial x^2} + \frac{\partial^2 v}{\partial y^2})
\end{aligned}
$$

where

$$
f(x,y,t) = \begin{cases}
5 & \text{if } (x-0.3)^2 + (y-0.6)^2 \leq 0.1^2 \text{ and } t \geq 1.1 \\
0 & \text{else} \\
\end{cases}
$$

with no-flux boundary conditions and $u(0, x, y) = 22(y(1 - y))^{3/2}$ with $v(0, x, y) = 27(x(1 - x))^{3/2}$. This PDE is discretized to a set of $N \times N \times 2$ ODEs using the finite difference method. The parameters are spatially dependent, $p_i = p_i(x, y)$, so each discretized $p_i$ is an $N \times N$ set of values, one per grid point, giving a total of $4 N^2$ parameters. The initial parameter values are uniform in space: $[p_1, p_2, p_3, p_4] = [3.4, 1.0, 10.0, 10.0]$ at every grid point.
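
The finite-difference discretization can be written out explicitly. As a sketch, let $u_{i,j}$ denote the value of $u$ at grid point $(x_i, y_j)$ and $\Delta x$ the uniform grid spacing. This index notation is introduced here for illustration, and $\Delta x$ corresponds to `dx` in the code below. The five-point stencil used in `brusselator_2d_loop` is then

$$
\frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2} \approx \frac{u_{i-1,j} + u_{i+1,j} + u_{i,j-1} + u_{i,j+1} - 4 u_{i,j}}{\Delta x^2}
$$

and likewise for $v$. For the $N = 3$ grid used in the code, this gives $2N^2 = 18$ state variables and $4N^2 = 36$ parameters. In the code, the `limit` helper maps an out-of-range neighbour index to the opposite edge of the grid.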



In [22]:
using OrdinaryDiffEq

N = 3

xyd_brusselator = range(0,stop=1,length=N)

dx = step(xyd_brusselator)

brusselator_f(x, y, t) = (((x-0.3)^2 + (y-0.6)^2) <= 0.1^2) * (t >= 1.1) * 5.

limit(a, N) = a == N+1 ? 1 : a == 0 ? N : a

function brusselator_2d_loop(du, u, p, t)
    lu = LinearIndices((1:N, 1:N, 1:2))
    lp = LinearIndices((1:N, 1:N, 1:4))
    @inbounds for I in CartesianIndices((N, N))
        i, j = Tuple(I)
        x, y = xyd_brusselator[I[1]], xyd_brusselator[I[2]]
        ip1, im1, jp1, jm1 = limit(i+1, N), limit(i-1, N), limit(j+1, N), limit(j-1, N)
        du[lu[i,j,1]] = p[lp[i,j,2]] + u[lu[i,j,1]]^2*u[lu[i,j,2]] - (p[lp[i,j,1]] + 1)*u[lu[i,j,1]] + 
            p[lp[i,j,3]]/dx^2*(u[lu[im1,j,1]] + u[lu[ip1,j,1]] + u[lu[i,jp1,1]] + u[lu[i,jm1,1]] - 4u[lu[i,j,1]]) +
            brusselator_f(x, y, t)
        du[lu[i,j,2]] = p[lp[i,j,1]]*u[lu[i,j,1]] - u[lu[i,j,1]]^2*u[lu[i,j,2]] + 
            p[lp[i,j,4]]/dx^2*(u[lu[im1,j,2]] + u[lu[ip1,j,2]] + u[lu[i,jp1,2]] + u[lu[i,jm1,2]] - 4u[lu[i,j,2]])
    end
end

pt = (3.4, 1., 10., 10.)

function init_brusselator_2d(xyd, pt)
    N = length(xyd)
    u0 = zeros(N*N*2)
    p = zeros(N*N*4)
    ru0 = reshape(u0, N, N, 2)
    rp = reshape(p, N, N, 4)
    for I in CartesianIndices((N, N))
        x = xyd[I[1]]
        y = xyd[I[2]]
        ru0[I,1] = 22*(y*(1-y))^(3/2)
        ru0[I,2] = 27*(x*(1-x))^(3/2)
        rp[I,1] = pt[1]
        rp[I,2] = pt[2]
        rp[I,3] = pt[3]
        rp[I,4] = pt[4]
    end
    return u0, p
end

u0, p = init_brusselator_2d(xyd_brusselator, pt)

prob_ode_brusselator_2d = ODEProblem(brusselator_2d_loop,u0,(0.,10.0),p)

sol = solve(prob_ode_brusselator_2d, Tsit5(), abstol=1e-6, reltol=1e-6, tstops=0:0.5:10.0)

nothing #hide

In [23]:
using BenchmarkTools
import Zygote, DiffEqSensitivity

nothing #hide

In [34]:
# Forward Continuous Sensitivity Analysis (CSA)
function sum_of_solution(u0,p)
    _prob = remake(prob_ode_brusselator_2d,u0=u0,p=p)
    sensealg = DiffEqSensitivity.ForwardSensitivity()
    sum(solve(_prob,Tsit5(),reltol=1e-6,abstol=1e-6,saveat=0.5,sensealg=sensealg))
end
@btime du0_f, dp_f = Zygote.gradient(sum_of_solution, u0, p)

In [None]:
# Reverse Continuous Sensitivity Analysis (CSA)
function sum_of_solution(u0,p)
    _prob = remake(prob_ode_brusselator_2d,u0=u0,p=p)
    sensealg = DiffEqSensitivity.QuadratureAdjoint(compile=true)
    sum(solve(_prob,Tsit5(),reltol=1e-6,abstol=1e-6,saveat=0.5,sensealg=sensealg))
end
@btime du0_r, dp_r = Zygote.gradient(sum_of_solution,u0,p)

([3.2493383313825404, 3.236483196259513, 3.249338331382539, 3.2360045626247884, 3.1834877703404403, 3.236004562624789, 3.2493383313825377, 3.236483196259518, 3.249338331382535, 4.254412329914135, 4.254189725668024, 4.254412329914137, 4.226753234046294, 4.225378110501217, 4.226753234046294, 4.254412329914135, 4.254189725668022, 4.254412329914136], [27.1406986640924, 27.140965058852945, 27.1406986640924, 27.163283248831114, 27.164471539813192, 27.163283248831007, 27.140698664091815, 27.140965058854032, 27.140698664091918, -61.51918344848566  …  0.20384238383267553, 0.3680141741116232, -0.7301725236481096, 0.36801417411162357, 0.3506605507121283, -0.7128115452047703, 0.3506605507121278, 0.36801417411732107, -0.7301725236594652, 0.36801417411732135])

In [None]:
# Forward Discrete Sensitivity Analysis using Automatic Differentiation (DSAAD)
function sum_of_solution(u0,p)
    _prob = remake(prob_ode_brusselator_2d,u0=u0,p=p)
    sensealg = DiffEqSensitivity.ForwardDiffSensitivity(convert_tspan=false)
    sum(solve(_prob,Tsit5(),reltol=1e-6,abstol=1e-6,saveat=0.5,sensealg=sensealg))
end
@btime du0_fd, dp_fd = Zygote.gradient(sum_of_solution,u0,p)


([3.2493386445773043, 3.236482031265129, 3.249338644577303, 3.236003322501268, 3.1834902387297004, 3.23600332250127, 3.2493386445773216, 3.2364820312650386, 3.249338644577321, 4.254412077711769, 4.254189522838437, 4.254412077711767, 4.226753644034689, 4.225378056743302, 4.2267536440346865, 4.2544120849575835, 4.254189528981069, 4.254412084957585], [27.140698697048112, 27.140965448983927, 27.14069869704811, 27.163283615556015, 27.164471239250886, 27.16328361555602, 27.140698697048173, 27.14096544898394, 27.140698697048173, -61.519184782765954  …  0.2038340040244969, 0.36801036718169977, -0.7301649058947175, 0.36801036718169977, 0.3506681135342229, -0.7128267300559522, 0.3506681135342229, 0.36801036718184765, -0.7301649058949959, 0.36801036718184765])

In [None]:
# Reverse Discrete Sensitivity Analysis using Automatic Differentiation (DSAAD)
function sum_of_solution(u0,p)
    _prob = remake(prob_ode_brusselator_2d,u0=u0,p=p)
    sensealg = DiffEqSensitivity.ReverseDiffAdjoint()
    sum(solve(_prob,Tsit5(),reltol=1e-6,abstol=1e-6,saveat=0.5,sensealg=sensealg))
end
@btime du0_rd, dp_rd = Zygote.gradient(sum_of_solution,u0,p)

([3.2493387084272882, 3.2364819320722704, 3.2493387084272882, 3.23600319386994, 3.1834904266203092, 3.2360031938699416, 3.2493387084758205, 3.2364819318728517, 3.2493387084758214, 4.254412066423688, 4.254189573733795, 4.25441206642369, 4.226753742260081, 4.225377949798395, 4.226753742260084, 4.254412066417825, 4.254189573736886, 4.254412066417826], [27.14054580243715, 27.14127124391021, 27.140545802437824, 27.163524558979795, 27.163989388460298, 27.163524558981617, 27.14061070025926, 27.141141454301703, 27.140610700258136, -61.519020266507695  …  0.2038339755958096, 0.36801041868584156, -0.7301651386964594, 0.36801041868584167, 0.3506680157060225, -0.712826313069014, 0.3506680157060228, 0.36801039680719794, -0.7301650399994803, 0.3680103968071976])

In [None]:
import DiscreteAdjoint

dg(out,u,p,t,i) = out .= 1
ts = 0:0.5:10.0

nothing #hide

In [None]:
# Discrete Sensitivity Analysis (ForwardDiff VJP)
autojacvec = DiscreteAdjoint.ForwardDiffVJP()
@btime dp_fd, du0_fd = DiscreteAdjoint.discrete_adjoint(sol, dg, ts; autojacvec)

  672.586 ms (10533302 allocations: 853.00 MiB)


([27.14069629153896, 27.140969918964764, 27.140696291538966, 27.16328812834905, 27.164461801267972, 27.163288128349052, 27.140696291552608, 27.14096991893798, 27.140696291552604, -61.5191821757343  …  0.20383391691934405, 0.36801045046768255, -0.730165156620558, 0.36801045046768255, 0.3506679475935389, -0.7128262673079332, 0.3506679475935389, 0.3680104504660461, -0.7301651566173452, 0.3680104504660461], [3.249338717382117, 3.2364820453010013, 3.249338717382118, 3.2360033265274106, 3.1834903560072565, 3.236003326527411, 3.2493387173821313, 3.236482045300956, 3.249338717382131, 4.254412150386502, 4.254189615920001, 4.254412150386502, 4.2267537529546635, 4.225378090751844, 4.2267537529546635, 4.254412150386499, 4.254189615920004, 4.254412150386499])

In [None]:
# Discrete Sensitivity Analysis (ReverseDiff VJP, not compiled)
autojacvec = DiscreteAdjoint.ReverseDiffVJP()
@btime dp_rdc, du0_rdc = DiscreteAdjoint.discrete_adjoint(sol, dg, ts; autojacvec)

  273.234 ms (21877 allocations: 1.39 MiB)


([27.14069629153895, 27.140969918964853, 27.140696291538934, 27.163288128349087, 27.164461801268217, 27.16328812834907, 27.14069629155256, 27.140969918938097, 27.140696291552565, -61.51918217573834  …  0.20383391691912187, 0.368010450467627, -0.7301651566204524, 0.3680104504676269, 0.3506679475934923, -0.7128262673078052, 0.3506679475934922, 0.36801045046598646, -0.730165156617234, 0.3680104504659865], [3.249338717379528, 3.236482045298386, 3.249338717379528, 3.236003326524797, 3.1834903560045302, 3.2360033265247976, 3.249338717379541, 3.236482045298342, 3.249338717379542, 4.254412150385988, 4.254189615919489, 4.254412150385989, 4.226753752954094, 4.225378090751271, 4.226753752954094, 4.254412150385988, 4.254189615919491, 4.254412150385987])

In [None]:
# Discrete Sensitivity Analysis (ReverseDiff VJP, compiled)
autojacvec = DiscreteAdjoint.ReverseDiffVJP(true)
@btime dp_rdc, du0_rdc = DiscreteAdjoint.discrete_adjoint(sol, dg, ts; autojacvec)

  62.585 ms (37285 allocations: 2.11 MiB)


([27.14069629153895, 27.140969918964853, 27.140696291538934, 27.163288128349087, 27.164461801268217, 27.16328812834907, 27.14069629155256, 27.140969918938097, 27.140696291552565, -61.51918217573834  …  0.20383391691912187, 0.368010450467627, -0.7301651566204524, 0.3680104504676269, 0.3506679475934923, -0.7128262673078052, 0.3506679475934922, 0.36801045046598646, -0.730165156617234, 0.3680104504659865], [3.249338717379528, 3.236482045298386, 3.249338717379528, 3.236003326524797, 3.1834903560045302, 3.2360033265247976, 3.249338717379541, 3.236482045298342, 3.249338717379542, 4.254412150385988, 4.254189615919489, 4.254412150385989, 4.226753752954094, 4.225378090751271, 4.226753752954094, 4.254412150385988, 4.254189615919491, 4.254412150385987])

In [None]:
# Discrete Sensitivity Analysis (Zygote VJP)
autojacvec = DiscreteAdjoint.ZygoteVJP()
@btime dp_z, du0_z = DiscreteAdjoint.discrete_adjoint(sol, dg, ts; autojacvec)

  9.123 s (41224104 allocations: 1.74 GiB)


([27.140696291538866, 27.140969918964775, 27.14069629153888, 27.16328812834898, 27.164461801268228, 27.163288128348963, 27.14069629155246, 27.140969918937977, 27.140696291552498, -61.51918217573783  …  0.20383391691907923, 0.368010450467627, -0.7301651566204443, 0.36801045046762704, 0.35066794759347175, -0.7128262673078025, 0.35066794759347186, 0.3680104504659926, -0.7301651566172276, 0.3680104504659927], [3.2493387173791883, 3.236482045298044, 3.2493387173791883, 3.236003326524452, 3.1834903560041723, 3.236003326524452, 3.249338717379203, 3.2364820452979974, 3.2493387173792025, 4.254412150385942, 4.254189615919442, 4.254412150385941, 4.226753752954039, 4.2253780907512155, 4.226753752954039, 4.2544121503859404, 4.254189615919444, 4.25441215038594])

## Pollution Model

The third model, POLLU, simulates air pollution. It is a stiff nonlinear system of 20 ODEs:
$$
\begin{aligned}
\frac{d u_1}{d t} &= -p_1 u_1 - p_{10} u_{11} u_1 - p_{14} u_1 u_6 - p_{23} u_1 u_4 - p_{24} u_{19} u_1 \\
&\quad + p_2 u_2 u_4 + p_3 u_5 u_2 + p_9 u_{11} u_2 + p_{11} u_{13} + p_{12} u_{10} u_2 + p_{22} u_{19} + p_{25} u_{20} \\
\frac{d u_2}{d t} &= -p_2 u_2 u_4 - p_3 u_5 u_2 - p_9 u_{11} u_2 - p_{12} u_{10} u_2 + p_1 u_1 + p_{21} u_{19} \\
\frac{d u_3}{d t} &= -p_{15} u_3 + p_1 u_1 + p_{17} u_4 + p_{19} u_{16} + p_{22} u_{19} \\
\frac{d u_4}{d t} &= -p_2 u_2 u_4 - p_{16} u_4 - p_{17} u_4 - p_{23} u_1 u_4 + p_{15} u_3 \\
\frac{d u_5}{d t} &= -p_3 u_5 u_2 + p_4 u_7 + p_4 u_7 + p_6 u_7 u_6 + p_7 u_9 + p_{13} u_{14} + p_{20} u_{17} u_6 \\
\frac{d u_6}{d t} &= -p_6 u_7 u_6 - p_8 u_9 u_6 - p_{14} u_1 u_6 - p_{20} u_{17} u_6 + p_3 u_5 u_2 + p_{18} u_{16} + p_{18} u_{16} \\
\frac{d u_7}{d t} &= -p_4 u_7 - p_5 u_7 - p_6 u_7 u_6 + p_{13} u_{14} \\
\frac{d u_8}{d t} &= p_4 u_7 + p_5 u_7 + p_6 u_7 u_6 + p_7 u_9 \\
\frac{d u_9}{d t} &= -p_7 u_9 - p_8 u_9 u_6 \\
\frac{d u_{10}}{d t} &= -p_{12} u_{10} u_2 + p_7 u_9 + p_9 u_{11} u_2 \\
\frac{d u_{11}}{d t} &= -p_9 u_{11} u_2 - p_{10} u_{11} u_1 + p_8 u_9 u_6 + p_{11} u_{13} \\
\frac{d u_{12}}{d t} &= p_9 u_{11} u_2 \\
\frac{d u_{13}}{d t} &= -p_{11} u_{13} + p_{10} u_{11} u_1 \\
\frac{d u_{14}}{d t} &= -p_{13} u_{14} + p_{12} u_{10} u_2 \\
\frac{d u_{15}}{d t} &= p_{14} u_1 u_6 \\
\frac{d u_{16}}{d t} &= -p_{18} u_{16} - p_{19} u_{16} + p_{16} u_4 \\
\frac{d u_{17}}{d t} &= -p_{20} u_{17} u_6 \\
\frac{d u_{18}}{d t} &= p_{20} u_{17} u_6 \\
\frac{d u_{19}}{d t} &= -p_{21} u_{19} - p_{22} u_{19} - p_{24} u_{19} u_1 + p_{23} u_1 u_4 + p_{25} u_{20} \\
\frac{d u_{20}}{d t} &= -p_{25} u_{20} + p_{24} u_{19} u_1
\end{aligned}
$$
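
For orientation, the 20 equations above can be transcribed into an in-place right-hand side of the same `f(du, u, p, t)` form used for the earlier models. This is only a sketch: the name `pollu!` and the intermediate rate vector `r` are choices made here for readability, and no parameter values or initial conditions are assumed.

```julia
# Sketch: in-place POLLU right-hand side, transcribed term by term from the equations above.
# `pollu!` and the helper vector `r` are illustrative names, not taken from any package.
function pollu!(du, u, p, t)
    # r[k] = p[k] times the species that multiply p_k in the equations.
    r = similar(du, 25)
    r[1]  = p[1]*u[1]
    r[2]  = p[2]*u[2]*u[4]
    r[3]  = p[3]*u[5]*u[2]
    r[4]  = p[4]*u[7]
    r[5]  = p[5]*u[7]
    r[6]  = p[6]*u[7]*u[6]
    r[7]  = p[7]*u[9]
    r[8]  = p[8]*u[9]*u[6]
    r[9]  = p[9]*u[11]*u[2]
    r[10] = p[10]*u[11]*u[1]
    r[11] = p[11]*u[13]
    r[12] = p[12]*u[10]*u[2]
    r[13] = p[13]*u[14]
    r[14] = p[14]*u[1]*u[6]
    r[15] = p[15]*u[3]
    r[16] = p[16]*u[4]
    r[17] = p[17]*u[4]
    r[18] = p[18]*u[16]
    r[19] = p[19]*u[16]
    r[20] = p[20]*u[17]*u[6]
    r[21] = p[21]*u[19]
    r[22] = p[22]*u[19]
    r[23] = p[23]*u[1]*u[4]
    r[24] = p[24]*u[19]*u[1]
    r[25] = p[25]*u[20]

    du[1]  = -r[1] - r[10] - r[14] - r[23] - r[24] + r[2] + r[3] + r[9] + r[11] + r[12] + r[22] + r[25]
    du[2]  = -r[2] - r[3] - r[9] - r[12] + r[1] + r[21]
    du[3]  = -r[15] + r[1] + r[17] + r[19] + r[22]
    du[4]  = -r[2] - r[16] - r[17] - r[23] + r[15]
    du[5]  = -r[3] + r[4] + r[4] + r[6] + r[7] + r[13] + r[20]
    du[6]  = -r[6] - r[8] - r[14] - r[20] + r[3] + r[18] + r[18]
    du[7]  = -r[4] - r[5] - r[6] + r[13]
    du[8]  = r[4] + r[5] + r[6] + r[7]
    du[9]  = -r[7] - r[8]
    du[10] = -r[12] + r[7] + r[9]
    du[11] = -r[9] - r[10] + r[8] + r[11]
    du[12] = r[9]
    du[13] = -r[11] + r[10]
    du[14] = -r[13] + r[12]
    du[15] = r[14]
    du[16] = -r[18] - r[19] + r[16]
    du[17] = -r[20]
    du[18] = r[20]
    du[19] = -r[21] - r[22] - r[24] + r[23] + r[25]
    du[20] = -r[25] + r[24]
    return nothing
end
```

Grouping the products into `r` keeps each `du[i]` line visually aligned with the corresponding equation, which makes transcription errors easier to spot.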