In [1]:
using Zygote
using OrdinaryDiffEq
using Interpolations

In [2]:
struct CosmoPar{T}
    Ωm::T
    Ωb::T
    h::T
    n_s::T
    σ8::T
end

mutable struct Settings
    nz::Int
    nz_pk::Int
    nk::Int
end

In [13]:
function _dgrowth!(dd, d, cosmo::CosmoPar, a)
    ez = _Ez(cosmo, 1.0/a-1.0)
    dd[1] = d[2] * 1.5 * cosmo.Ωm / (a^2*ez)
    dd[2] = d[1] / (a^3*ez)
end

function _Ez(cosmo::CosmoPar, z)
    E2 = @. (cosmo.Ωm*(1+z)^3+(1-cosmo.Ωm))
    return sqrt.(E2)
end

function cosmology(Wm, s8, settings, z)
    cpar = CosmoPar(Wm, 0.05, 0.67, 0.96, s8)
    # Load settings
    nk = settings.nk
    nz = settings.nz
    nz_pk = settings.nz_pk
    zs_pk = LinRange(0., 3.0, nz_pk)
    zs = LinRange(0., 3.0, nz)
    # Compute linear power spectrum at z=0.
    logk = range(log(0.0001), stop=log(7.0), length=nk)
    ks = exp.(logk)
    dlogk = log(ks[2]/ks[1])

    z_ini = 1000.0
    a_ini = 1.0/(1.0+z_ini)
    ez_ini = _Ez(cpar, z_ini)
    d0 = [a_ini^3*ez_ini, a_ini]
    a_s = reverse(@. 1.0 / (1.0 + zs))
    prob = ODEProblem(_dgrowth!, d0, (a_ini, 1.0), cpar)
    sol = solve(prob, Tsit5(), reltol=1E-6,
                abstol=1E-8, saveat=a_s)
    # OPT: interpolation (see below), ODE method, tolerances
    # Note that sol already includes some kind of interpolation,
    # so it may be possible to optimize this by just using
    # sol directly.
    s = vcat(sol.u'...)
    Dzs = reverse(s[:, 2] / s[end, 2])
    # OPT: interpolation method
    Dzi = LinearInterpolation(zs, Dzs)

    return Dzs[2]
end


cosmology (generic function with 1 method)

In [14]:
settings = Settings(100, 100, 100)

Settings(100, 100, 100)

In [15]:
cosmology(0.3, 0.81, settings, 0.1)

0.9845408965169671

In [None]:
Zygote.gradient((Wm, s8) -> cosmology(Wm, s8, settings, 0.1), 0.3, 0.81)