In [None]:
using Pkg; Pkg.activate()
using Bessels, ForwardDiff, ForwardDiffChainRules, ChainRulesCore, FiniteDiff
using Plots; Plots.default(margin = 4*Plots.mm, linewidth = 4, dpi = 150, fmt = :png)

# Spherical Bessel function jₗ(x) with a custom forward-mode chain rule (frule)

In [None]:
"""
    jl(l, x)

Spherical Bessel function of the first kind jₗ(x), evaluated with
`Bessels.sphericalbesselj`.
"""
jl(l, x) = Bessels.sphericalbesselj(l, x)

"""
    jl′(l, x)

Derivative of `jl(l, x)` with respect to `x`, from the recurrence

    (2l+1) jₗ′(x) = l jₗ₋₁(x) − (l+1) jₗ₊₁(x).

For `l == 0` the first term vanishes and the formula reduces to j₀′(x) = −j₁(x).
That case is evaluated directly so `jl(-1, x)` is never computed. Order −1 is
singular at x = 0, so `0 * jl(-1, x)` would not reliably be `0` (e.g. `0 * Inf`
is `NaN`).
"""
function jl′(l, x)
    iszero(l) && return -jl(1, x)
    # Integer coefficients times jₗ values, divided once: keeps the float type of x
    # (e.g. Float32) instead of promoting through `l/(2l+1)` to Float64.
    return (l * jl(l - 1, x) - (l + 1) * jl(l + 1, x)) / (2l + 1)
end

# Forward-mode rule for jl: only x is differentiated. The tangents of the function
# itself and of the order l are discarded, so l acts as a non-differentiable parameter.
function ChainRulesCore.frule((_, _, Δx), ::typeof(jl), l, x)
    y = jl(l, x)              # primal value
    dy = jl′(l, x) * Δx       # pushforward of the x-tangent
    return y, dy
end

# Have ForwardDiff use the frule above whenever jl is called with a Dual argument.
@ForwardDiff_frule jl(l::Integer, x::ForwardDiff.Dual)

# Check the custom rule on a deliberately convoluted composite,
# crazy(x) = sin(7 jₗ(x²)), by comparing three ways of differentiating it.
crazy(l, x) = sin(7 * jl(l, x^2))

# Derivative worked out by hand with the chain rule: cos(7 jₗ(x²)) · 7 jₗ′(x²) · 2x.
function dcrazy_anal(l, x)
    u = x^2
    return cos(7 * jl(l, u)) * 7 * jl′(l, u) * 2 * x
end

# Numerical derivative using FiniteDiff's default difference scheme.
dcrazy_fin(l, x) = FiniteDiff.finite_difference_derivative(Base.Fix1(crazy, l), x)

# Forward-mode AD; the inner jl call on a Dual goes through the @ForwardDiff_frule method.
dcrazy_auto(l, x) = ForwardDiff.derivative(Base.Fix1(crazy, l), x)

l = 5               # order of the spherical Bessel function in the demo
x = 0.0:0.01:6.0    # sample points

p = plot(layout = (2, 1), size = (800, 600))

# Upper panel: the composite function itself.
plot!(p[1], x, crazy.(l, x);
      xlabel = "x", title = "crazy(x) = sin(7jₗ(x²))", label = nothing)

# Lower panel: the three derivative estimates overlaid (they should coincide).
# The first series also carries the panel's x-label and title.
for (deriv, opts) in (
        (dcrazy_anal, (label = "analytical derivative", linestyle = :solid,
                       xlabel = "x", title = "d(crazy(x)) / dx")),
        (dcrazy_fin, (label = "finite differences", linestyle = :dash)),
        (dcrazy_auto, (label = "automatic differentiation", linestyle = :dot)),
    )
    plot!(p[2], x, deriv.(l, x); opts...)
end

p