# Paper

To run this notebook, start Jupyter from this directory:
```
$ julia

julia> using IJulia
julia> notebook(dir = pwd(), verbose = true)
# then open this notebook and select the Julia (multithreaded) kernel
```

In [None]:
# The active project (Project.toml) should be the one next to this notebook
Base.current_project() # should be in same directory as notebook

In [None]:
using Base.Threads
Threads.nthreads() # should be more than 1 (i.e. the multithreaded kernel was selected)

In [None]:
using Plots
# Shared default styling for all figures in this notebook
Plots.default(
    linewidth = 1.5, grid = false, framestyle = :box,
    tickfontsize = 8, labelfontsize = 8, legendfontsize = 8,
)
# Restrict BLAS to one thread; NOTE(review): presumably to avoid oversubscription with Julia threads — confirm
using LinearAlgebra: BLAS; BLAS.set_num_threads(1)

## Basic usage

In [None]:
using SymBoltz

In [None]:
# Standard ΛCDM model from SymBoltz (components are accessed below as M.γ, M.b, M.c, ...)
M = ΛCDM()

In [None]:
# Fiducial parameter values keyed by the model's symbolic parameters
# (their CLASS counterparts are T_cmb, Omega_b, YHe, N_ur, Omega_cdm, m_ncdm, ln_A_s_1e10, n_s, h)
pars = Dict(
  M.γ.T₀ => 2.7, M.b.Ω₀ => 0.05, M.b.YHe => 0.25,
  M.ν.Neff => 3.0, M.c.Ω₀ => 0.27, M.h.m_eV => 0.02,
  M.I.ln_As1e10 => 3.0, M.I.ns => 0.96, M.g.h => 0.7
)

In [None]:
prob = CosmologyProblem(M, pars)

In [None]:
# A few perturbation wavenumbers; NOTE(review): unitless here, presumably SymBoltz's internal units — confirm
ks = [4, 40, 400, 4000]
sol = solve(prob, ks; verbose = true)

In [None]:
using Plots
# 2×3 overview of the solution versus lg(a):
# top row (background): τ/τ0, 1/ℰ and 1/E; species densities relative to M.G.ρ; ionization fractions and Xe
# bottom row (perturbations, one curve per k in ks): Φ and Ψ; lg|δ| per species; photon multipoles F0, F1, F2
# (the margin settings only tighten the layout)
p = plot(layout=(2, 3), size=(1200, 500), link = :x, xlims = (-7, 0), grid = false, right_margin = -1*Plots.mm)
plot!(p[1], sol, log10(M.g.a), [M.τ/M.τ0, 1/M.g.ℰ, 1/M.g.E]; xlabel = "", xformatter = :none, yticks = 0.0:0.2:1.0, bottom_margin = -5*Plots.mm)
plot!(p[2], sol, log10(M.g.a), [M.b.ρ, M.c.ρ, M.γ.ρ, M.ν.ρ, M.h.ρ, M.Λ.ρ] ./ M.G.ρ; xlabel = "", xformatter = :none, yticks = 0.0:0.2:1.0, legend_position = :topleft, bottom_margin = -5*Plots.mm)
plot!(p[3], sol, log10(M.g.a), [M.b.rec.XHe⁺⁺, M.b.rec.XHe⁺, M.b.rec.XH⁺, M.b.Xe]; xlabel = "", xformatter = :none, yticks = 0.0:0.2:1.2, legend_position = :left, bottom_margin = -5*Plots.mm)
plot!(p[4], sol, log10(M.g.a), [M.g.Φ, M.g.Ψ], ks; bottom_margin=4*Plots.mm) # bottom margin to show labels
plot!(p[5], sol, log10(M.g.a), log10.(abs.([M.b.δ, M.c.δ, M.γ.δ, M.ν.δ, M.h.δ])), ks; ylims = (-3, +5), klabel = false, bottom_margin=4*Plots.mm)
plot!(p[6], sol, log10(M.g.a), [M.γ.F0, M.γ.F[1], M.γ.F[2]], ks; klabel = false, bottom_margin=4*Plots.mm)

In [None]:
savefig(p, "evolution.pdf")

## Modifying models

In [None]:
# Shorthands for the independent variable τ, the wavenumber k and metric quantities
g, τ, k = M.g, M.τ, M.k
a, ℰ, Φ, Ψ = g.a, g.ℰ, g.Φ, g.Ψ
D = Differential(τ)
@parameters w₀ wₐ cₛ² Ω₀ ρ₀
@variables ρ(τ) P(τ) w(τ) cₐ²(τ) δ(τ,k) θ(τ,k) σ(τ,k)
# Dark energy fluid with equation of state w(a) = w₀ + wₐ(1-a)
eqs = [
  w ~ w₀ + wₐ*(1-a)
  ρ₀ ~ 3*Ω₀ / (8*Num(π)) # density at a = 1 (where ρ below reduces to ρ₀)
  ρ ~ ρ₀ * a^(-3(1+w₀+wₐ)) * exp(-3wₐ*(1-a)) # closed-form solution of dlnρ/dlna = -3(1+w) for this w(a)
  P ~ w * ρ
  cₐ² ~ w - 1/(3ℰ) * D(w)/(1+w) # adiabatic sound speed P′/ρ′
  D(δ) ~ 3ℰ*(w-cₛ²)*δ - (1+w) * (
         (1+9(ℰ/k)^2*(cₛ²-cₐ²))*θ + 3*D(Φ))
  D(θ) ~ (3cₛ²-1)*ℰ*θ + k^2*cₛ²*δ/(1+w) + k^2*Ψ
  σ ~ 0 # no shear/anisotropic stress
]
# Initial conditions for the perturbations δ and θ
initialization_eqs = [
  δ ~ -3//2 * (1+w) * Ψ
  θ ~ 1//2 * (k^2*τ) * Ψ
]
X = System(eqs, τ; initialization_eqs, name = :X)

In [None]:
# Replace the cosmological constant component Λ by the fluid X and set its parameters
M = ΛCDM(Λ = X, name = :w0waCDM)
pars[M.X.w₀] = -0.9
pars[M.X.wₐ] = 0.2
pars[M.X.cₛ²] = 1.0
prob = CosmologyProblem(M, pars)

## Computing spectra

In [None]:
using Unitful, UnitfulAstro # for units
# 200 log-spaced wavenumbers between 1e-4 and 1 Mpc⁻¹ (with units)
ks = 10 .^ range(-4, 0, length=200) / u"Mpc"
Ps = spectrum_matter(prob, ks);
ks = ks * u"Mpc" # 1/Mpc (units stripped: plain numbers from here on)
Ps = Ps / u"Mpc^3"; # Mpc^3 (units stripped)

In [None]:
ls = 20:20:2000
# TT, TE and EE spectra as Dₗ = l(l+1)Cₗ/2π in μK² (one column per spectrum), then units stripped
Dls = spectrum_cmb([:TT, :TE, :EE], prob, ls; normalization = :Dl, unit = u"μK", verbose = false)
Dls = Dls / u"(μK)^2"

In [None]:
using CLASS
# Solve the same model with CLASS, with options chosen to mirror the SymBoltz model
lmax = lastindex(M.γ.F) # same multipole truncation as the SymBoltz photon hierarchy
classopts = Dict(
    "output" => "mPk, tCl, pCl",

    "ic" => "ad",
    "modes" => "s",
    "gauge" => "newtonian",

    # metric
    "h" => pars[M.g.h],

    # photons
    "T_cmb" => pars[M.γ.T₀],
    "l_max_g" => lmax,
    "l_max_pol_g" => lmax,

    # baryons
    "Omega_b" => pars[M.b.Ω₀],
    "YHe" => pars[M.b.YHe],
    "recombination" => "recfast",
    "recfast_Hswitch" => 1,
    "recfast_Heswitch" => 6,
    "reio_parametrization" => "reio_camb",

    # cold dark matter
    "Omega_cdm" => pars[M.c.Ω₀],

    # neutrinos (massless species from M.ν, massive species from M.h if present)
    "N_ur" => SymBoltz.have(M, :ν) ? pars[M.ν.Neff] : 0.0,
    "N_ncdm" => SymBoltz.have(M, :h) ? 1 : 0,
    "deg_ncdm" => SymBoltz.have(M, :h) ? prob.bg.ps[M.h.N] : 0,
    "m_ncdm" => SymBoltz.have(M, :h) ? pars[M.h.m_eV] : 0.0,
    "T_ncdm" => SymBoltz.have(M, :h) ? (4/11)^(1/3) : 0.0,
    "l_max_ur" => lmax,
    "l_max_ncdm" => lmax,

    # primordial power spectrum
    "ln_A_s_1e10" => pars[M.I.ln_As1e10],
    "n_s" => pars[M.I.ns],

    # w0wa dark energy
    "Omega_Lambda" => 0.0, # unspecified
    "w0_fld" => SymBoltz.have(M, :X) ? pars[M.X.w₀] : -1.0,
    "wa_fld" => SymBoltz.have(M, :X) ? pars[M.X.wₐ] : 0.0,
    "cs2_fld" => SymBoltz.have(M, :X) ? pars[M.X.cₛ²] : 1.0,
    "use_ppf" => SymBoltz.have(M, :X) ? "no" : "yes", # full w0wa equations

    # curvature
    "Omega_k" => SymBoltz.have(M, :K) ? pars[M.K.Ω₀] : 0.0, # curvature

    # from https://arxiv.org/pdf/2303.09451 / https://arxiv.org/pdf/2405.06047
    "background_Nloga" => 6000, # helps a lot!
    "tol_perturbations_integration" => 1e-6,

    # disable approximations
    "tight_coupling_trigger_tau_c_over_tau_h" => 1e-2, # cannot turn off
    "tight_coupling_trigger_tau_c_over_tau_k" => 1e-3, # cannot turn off
    "radiation_streaming_approximation" => 3, # turns off RSA
    "ur_fluid_approximation" => 3, # turns off UFA
    "ncdm_fluid_approximation" => 3, # turns off NCDM fluid approximation
)
class = CLASSProblem(classopts)
class = solve(class)

In [None]:
# 2×4 grid: top row shows lg P(k) and the TT, TE, EE Dₗ spectra from SymBoltz
p = plot(layout = grid(2, 4, heights=[0.75, 0.25]), size = (1200, 280), left_margin = 5*Plots.mm, bottom_margin = 7*Plots.mm)
plot!(p[1,1], log10.(ks), log10.(Ps); xformatter = :none, ylabel = "lg(P / Mpc³)", label = nothing, xlims = (-4, 0), ylims = (2, 5), color = 1)
plot!(p[1,2], ls, Dls[:, 1]; xformatter = :none, ylabel = "Cₗᵀᵀ l(l+1) / 2π (μK)²", color = 2, label = nothing, xlims = (0, 2000)) #, ribbon = @. ls*(ls+1)/2π * σs[:, 1]/u"(μK)^2")
plot!(p[1,3], ls, Dls[:, 2]; xformatter = :none, ylabel = "Cₗᵀᴱ l(l+1) / 2π (μK)²", label = nothing, xlims = (0, 2000), color = 3)
plot!(p[1,4], ls, Dls[:, 3]; xformatter = :none, ylabel = "Cₗᴱᴱ l(l+1) / 2π (μK)²", label = nothing, xlims = (0, 2000), color = 4)
p

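Compare with CLASS: overplot its spectra (dashed black) on the top row and show the relative errors in the bottom row.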

In [None]:
using DataInterpolations

h = pars[M.g.h] # TODO: plot in Mpc/h instead?
Tγ0 = pars[M.γ.T₀]

# interpolate CLASS to the same k
# interpolate y1(x1) to y2(x2) # TODO: log flag?
interp(y1, x1, x2) = LinearInterpolation(y1, x1)(x2)
# CLASS gives P in (Mpc/h)^3 versus k in h/Mpc: convert to Mpc^3 and 1/Mpc
Ps_class = interp(class["pk"][:,"P (Mpc/h)^3"]/h^3, class["pk"][:,"k (h/Mpc)"]*h, ks)
# CLASS spectra are presumably in units of T_cmb² (dimensionless); scale by (T_cmb in μK)² to get μK²
DlTTs_class = interp(class["cl"][:,"TT"], class["cl"][:,"l"], ls) * (1e6*Tγ0)^2
DlTEs_class = interp(class["cl"][:,"TE"], class["cl"][:,"l"], ls) * (1e6*Tγ0)^2
DlEEs_class = interp(class["cl"][:,"EE"], class["cl"][:,"l"], ls) * (1e6*Tγ0)^2

alpha = 0.3 # for overplotting
color = :black
linestyle = :dash

# Top row: overplot the CLASS results
plot!(p[1,1], log10.(ks), @.(log10(Ps_class)); xformatter = :none, ylabel = "lg(P / Mpc³)", label = nothing, xlims = (-4, 0), ylims = (2, 5), color, alpha, linestyle)
plot!(p[1,2], ls, @. DlTTs_class; xformatter = :none, label = nothing, xlims = (0, 2000), color, alpha, linestyle) #, ribbon = @. ls*(ls+1)/2π * σs[:, 1]/u"(μK)^2")
plot!(p[1,3], ls, @. DlTEs_class; xformatter = :none, label = nothing, xlims = (0, 2000), color, alpha, linestyle) #, ribbon = @. ls*(ls+1)/2π * σs[:, 1]/u"(μK)^2")
plot!(p[1,4], ls, @. DlEEs_class; xformatter = :none, label = nothing, xlims = (0, 2000), color, alpha, linestyle) #, ribbon = @. ls*(ls+1)/2π * σs[:, 1]/u"(μK)^2")

# Bottom row: lg of the relative error |SymBoltz/CLASS - 1|
plot!(p[2,1], log10.(ks), @.(log10(abs(Ps/Ps_class-1))), xlabel = "lg(k / Mpc⁻¹)", ylabel = "lg(rel.err.)", yticks = -5:1:-1, xlims = (-4, 0), ylims = (-5, -2), color = 1, label = nothing, top_margin = -10*Plots.mm)
plot!(p[2,2], ls, @.(log10(abs(Dls[:,1]/DlTTs_class-1))), xlabel = "l", xlims = (0, 2000), ylims = (-5, -2), yticks = -5:1:-2, color = 2, label = nothing, top_margin = -10*Plots.mm)
plot!(p[2,3], ls, @.(log10(abs(Dls[:,2]/DlTEs_class-1))), xlabel = "l", xlims = (0, 2000), ylims = (-5, -2), yticks = -5:1:-2, color = 3, label = nothing, top_margin = -10*Plots.mm)
plot!(p[2,4], ls, @.(log10(abs(Dls[:,3]/DlEEs_class-1))), xlabel = "l", xlims = (0, 2000), ylims = (-5, -2), yticks = -5:1:-2, color = 4, label = nothing, top_margin = -10*Plots.mm, right_margin = 2*Plots.mm)

In [None]:
savefig(p, "spectra.pdf")

## Fisher forecasting

In [None]:
# Parameters varied in the Fisher forecast (their order defines the order of θ)
vary = [
  M.g.h, M.c.Ω₀, M.b.Ω₀, M.b.YHe, M.ν.Neff,
  M.h.m_eV, M.X.w₀, M.X.wₐ, M.I.ln_As1e10, M.I.ns,
]
genprob = parameter_updater(prob, vary) # NOTE(review): presumably maps θ (ordered like vary) to an updated problem — confirm
ls, ls′ = 100:1:1000, 100:25:1000
# SymBoltz TT spectrum Dₗ(θ) at ls; NOTE(review): ls′ is presumably a coarser grid that is computed and interpolated to ls — confirm
Dl(θ) = (println("SymBoltz: ", θ); spectrum_cmb(:TT, genprob(θ), ls, ls′; normalization = :Dl, verbose = true))
θ₀ = map(par -> pars[par], vary) # fiducial parameter values
Dls = Dl(θ₀)

In [None]:
# Names of the corresponding CLASS input parameters
symboltz2class = Dict(
    M.g.h => "h",
    M.c.Ω₀ => "Omega_cdm",
    M.b.Ω₀ => "Omega_b",
    M.b.YHe => "YHe",
    M.ν.Neff => "N_ur",
    M.h.m_eV => "m_ncdm",
    M.I.ln_As1e10 => "ln_A_s_1e10",
    M.I.ns => "n_s",
    M.X.w₀ => "w0_fld",
    M.X.wₐ => "wa_fld",
)
# CLASS TT spectrum for the parameter vector θ (ordered like `vary`), keeping only the
# multipoles that belong to the global `ls`. The fiducial options `classopts` are
# overridden with the CLASS names of the varied parameters (see `symboltz2class`).
function Dl_class(θ)
    println("CLASS:", θ)
    overrides = Dict(symboltz2class[vary[n]] => θ[n] for n in eachindex(θ))
    classsol = solve(CLASSProblem(merge(classopts, overrides)))
    cltable = classsol["cl"]
    keep = [Int(l) in ls for l in cltable[:,"l"]]
    return cltable[:,"TT"][keep]
end
Dls_class = Dl_class(θ₀); # CLASS fiducial TT spectrum at the same ls

In [None]:
# Visual check that the SymBoltz (solid) and CLASS (dashed) fiducial spectra agree
plot(ls, Dls)
plot!(ls, Dls_class; linestyle = :dash)

### Automatic differentiation: all parameters

In [None]:
using ForwardDiff: jacobian
# Forward-mode AD through the SymBoltz computation: one column ∂Dₗ/∂θᵢ per varied parameter
dDl_dθ_ad = jacobian(Dl, θ₀)

### Finite differences: parameter-by-parameter

In [None]:
# Squared L2 distance Σⱼ (a[j] - b[j])² between two arrays with matching indices;
# used as the convergence measure of the adaptive finite-difference searches below.
_fd_sqdist(a, b) = sum((a[j] - b[j])^2 for j in eachindex(a))

# Compute ∂f/∂xᵢ with a central finite difference of total step size Δx,
# i.e. (f(x + Δx/2 eᵢ) - f(x - Δx/2 eᵢ)) / Δx elementwise.
# Only x[i] is mutated, temporarily; it is always restored, even if f throws.
function fd_jacobian(f, x, i, Δx)
    xi0 = x[i]
    try
        x[i] = xi0 + Δx/2
        f₊ = f(x)
        x[i] = xi0 - Δx/2
        f₋ = f(x)
        Δf = f₊ .- f₋
        return Δf ./ Δx
    finally
        x[i] = xi0 # always reset x[i], even if a call to f throws
    end
end

# Compute ∂f/∂xᵢ with finite differences by halving the step size Δx until the squared
# L2 deviation from the reference derivative df_target stops decreasing, and return the
# last estimate before it grew.
# A NaN deviation also counts as "not improving". This guarantees termination: once Δx
# underflows to 0 the estimate is 0/0 = NaN, and a plain `err > err0` test would then
# be false forever.
function fd_jacobian(f, x, i, Δx, df_target::AbstractArray; verbose = true)
    df0 = fd_jacobian(f, x, i, Δx)
    err0 = _fd_sqdist(df0, df_target) # previous
    verbose && println("Δx = $Δx, err = $err0")
    while true
        Δx = Δx / 2
        df = fd_jacobian(f, x, i, Δx)
        err = _fd_sqdist(df, df_target)
        verbose && println("Δx = $Δx, err = $err") # error of the current step
        (isnan(err) || err > err0) && return df0
        df0 = df
        err0 = err
    end
end

# Compute ∂f/∂xᵢ with finite differences, starting from the relative step initrelstep * x[i]
# and dividing the step by `div` as long as the squared L2 difference between successive
# estimates keeps decreasing (i.e. the estimates converge). Returns the last estimate before
# that difference grew; a NaN difference (e.g. after the step underflows to 0) also stops
# the search.
# Throws ArgumentError if the initial step is zero (e.g. x[i] == 0) or if div <= 1, since
# the search could then never make progress.
function fd_jacobian(f, x, i; initrelstep = 0.8, verbose = true, div = 2)
    div > 1 || throw(ArgumentError("div must be > 1 to shrink the step, got $div"))
    Δx1 = initrelstep * x[i] # initial step, relative to x[i]
    iszero(Δx1) && throw(ArgumentError("initial step initrelstep * x[$i] is zero; a relative step needs nonzero x[$i] and initrelstep"))
    df1 = fd_jacobian(f, x, i, Δx1)
    err1 = _fd_sqdist(df1, zeros(size(df1)))
    while true
        Δx2 = Δx1 / div
        df2 = fd_jacobian(f, x, i, Δx2)
        err2 = _fd_sqdist(df1, df2)
        verbose && println("error = $err2")
        (isnan(err2) || err2 > err1) && return df1 # result no longer improving
        df1 = df2
        Δx1 = Δx2
        err1 = err2
    end
end

#f_test(x) = [sin(x[1]+2x[2]+3x[3]), cos(x[1]+2x[2]+3x[3])]
#x = [1.0, 2.0, 3.0]
#f_test(x)
#fd_jacobian(f_test, x, 3)

In [None]:
# Overlay the finite-difference estimates of ∂f/∂xᵢ (plotted against the global `ls`)
# obtained with each step size in Δxs; earlier steps are drawn with thicker lines.
# Extra keyword arguments go to plot! (the label is always set to "Δx = …").
function plot_fd_stepsizes(f, x, i, Δxs; kwargs...)
    fig = plot()
    for (j, Δx) in enumerate(Δxs)
        df = fd_jacobian(f, x, i, Δx)
        plot!(fig, ls, df; palette = :RdYlGn_6, linewidth = 7 - j, kwargs..., label = "Δx = $Δx")
    end
    return fig
end

#i = 1
#steps = [0.1, 0.01, 0.001] * θ₀[i]
#plot_fd_stepsizes(Dl_class, θ₀, i, steps)

In [None]:
# Squared L2 differences between finite-difference derivatives ∂f/∂xᵢ computed with
# consecutive step sizes: entry k compares the estimates for Δxs[k+1] and Δxs[k].
# Extra keyword arguments are accepted but ignored.
function convergence(f, x, i, Δxs; kwargs...)
    dfs = [fd_jacobian(f, x, i, Δx) for Δx in Δxs]
    return map(2:length(dfs)) do k
        sum((dfs[k][n] - dfs[k-1][n])^2 for n in eachindex(dfs[k]))
    end
end

# For each parameter index in `is`, plot how the FD derivative of f converges as the
# relative step size (relsteps .* x[i]) decreases, on log-log axes.
function plot_convergence(f, x, is, relsteps)
    fig = plot(xlabel="lg(relative step size)", ylabel="lg(L₂(∇fᵢ-∇fᵢ₋₁))")
    lgsteps = log10.(relsteps[1:end-1])
    for i in is
        plot!(fig, lgsteps, log10.(convergence(f, x, i, relsteps .* x[i])))
    end
    return fig
end

#relsteps = 1 ./ 10 .^ range(0.5, 3.0, step=0.5)
#plot_convergence(Dl_class, θ₀, eachindex(vary), relsteps)

In [None]:
# CLASS derivatives by central finite differences: one column per varied parameter,
# each with a step of 5% of that parameter's fiducial value θ₀[i].
# Looping over the columns (instead of ten hard-coded lines) keeps this in sync with `vary`.
dDl_dθ_fd = similar(dDl_dθ_ad)
for i in axes(dDl_dθ_fd, 2)
    dDl_dθ_fd[:,i] = fd_jacobian(Dl_class, θ₀, i, 0.05*θ₀[i])
end

### Compare all parameters

In [None]:
is = eachindex(vary)
label_ad = "SymBoltz (AD)"
label_fd = "CLASS (FD)"
θnames = replace.(string.(vary), "₊" => ".") # show the namespace separator ₊ as .
color = permutedims(eachindex(θnames)[is]) # one color per parameter (row vector: one per series)
# The NaN hlines draw nothing; they only create legend entries for the solid (AD) and dashed (FD) styles
p = hline([NaN NaN]; xticks = 100:100:1000, ylims = (-4.5, 1.5), size = (600, 400), color = :black, linestyle = [:solid :dash], xlabel = "l", ylabel = "lg(|∂Cₗ / ∂θᵢ / Cₗ|)", label = [label_ad label_fd], legend_position = :bottomright, legend_columns=2, xlims = extrema(ls), right_margin=2*Plots.mm)
f = x -> log10(abs(x))
# Relative derivatives ∂Dₗ/∂θᵢ / Dₗ: AD in color, FD overplotted as dashed black lines
plot!(p, ls, f.(dDl_dθ_ad[:,is] ./ Dls); color, linestyle = :solid, label = "θᵢ = " .*  permutedims(θnames[is]))
plot!(p, ls, f.(dDl_dθ_fd[:,is] ./ Dls_class); color = :black, alpha, linestyle = :dash, label = nothing)
p

In [None]:
savefig(p, "derivatives.pdf")

In [None]:
# Difference between the relative AD and FD derivatives
plot(ls, dDl_dθ_ad[:,is] ./ Dls .- (dDl_dθ_fd[:,is] ./ Dls_class); ylabel = "difference", label = nothing, size = (600, 200))

### Compute Fisher matrices

In [None]:
using LinearAlgebra
# Fisher matrix of CMB spectrum derivatives:
#   F[i,j] = Σₗ (l + 1/2) (∂Dₗ/∂θᵢ / Dₗ) (∂Dₗ/∂θⱼ / Dₗ),   with (l + 1/2) = (2l+1)/2
# ls: multipoles; Dls: spectrum at ls; dDls_dθ: derivatives with one row per l and one
# column per parameter. If `is` is given, only those parameter columns are used.
function fisher_matrix(ls, Dls, dDls_dθ, is = nothing)
    if !isnothing(is)
        return fisher_matrix(ls, Dls, dDls_dθ[:,is])
    end
    nθ = size(dDls_dθ)[2]
    F = zeros(eltype(Dls), (nθ, nθ))
    for i in 1:nθ, j in 1:i
        # fill the lower triangle and mirror it, since F is symmetric
        F[i,j] = sum((ls[il]+1/2) * dDls_dθ[il,i]/Dls[il] * dDls_dθ[il,j]/Dls[il] for il in eachindex(ls))
        F[j,i] = F[i,j]
    end
    return F
end
is = eachindex(vary)
dDl_dθ1 = copy(dDl_dθ_ad) # SymBoltz derivatives (automatic differentiation)
dDl_dθ2 = copy(dDl_dθ_fd) # CLASS derivatives (finite differences)
# Each Fisher matrix pairs its derivatives with the spectrum from the same code and is named
# after its derivative method: AD uses SymBoltz (Dls), FD uses CLASS (Dls_class).
F_ad = Symmetric(fisher_matrix(ls, Dls, dDl_dθ1, is))
F_fd = Symmetric(fisher_matrix(ls, Dls_class, dDl_dθ2, is))

# Regularize Fisher matrices by adding small constant to diagonal
# (1% of the smallest diagonal entry of either matrix, so both get the same shift)
ϵ = 1e-2 * min(minimum(diag(F_ad)), minimum(diag(F_fd)))
F_fd += ϵ * I
F_ad += ϵ * I

# Parameter covariance matrices (inverse Fisher matrices)
C_fd = inv(F_fd)
C_ad = inv(F_ad)

# Visual check: lg|F| (top row) and lg|C| (bottom row) for FD (left) and AD (right)
p = plot(layout = (2,2))
heatmap!(p[1,1], log10.(abs.(F_fd)); title = "FD")
heatmap!(p[2,1], log10.(abs.(C_fd)))
heatmap!(p[1,2], log10.(abs.(F_ad)); title = "AD")
heatmap!(p[2,2], log10.(abs.(C_ad)))

### Compute covariances and plot contours

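1. For each parameter pair, find the eigenvalues of the marginalized 2×2 covariance sub-matrix. The ellipse semi-axes are their square roots, and the eigenvectors set the ellipse's orientation.
2. Scale the semi-axes by the square root of the χ² quantile with 2 degrees of freedom (because the marginalized distribution is 2D). This gives the 68% and 95% confidence ellipses.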

In [None]:
using Distributions

# Points (xs, ys) on the `conf` confidence ellipse of the 2D Gaussian marginal of
# parameters i (x axis) and j (y axis) with covariance matrix C, centred on μ = (μx, μy).
# N points are returned; the first and last coincide, so the curve is closed.
function ellipse(C, i, j, μ = (0.0, 0.0); conf = 0.68, N = 33)
    σᵢ², σⱼ², σᵢⱼ = C[i,i], C[j,j], C[i,j]
    θ = (atan(2σᵢⱼ, σᵢ²-σⱼ²)) / 2 # angle of the major axis with respect to the x axis
    λmean = (σᵢ²+σⱼ²)/2
    λhalfgap = √((σᵢ²-σⱼ²)^2/4+σᵢⱼ^2)
    a = √(λmean + λhalfgap) # semi-major axis: √(larger eigenvalue of the 2x2 marginal covariance)
    b = √(max(0.0, λmean - λhalfgap)) # semi-minor axis: √(smaller eigenvalue), clamped against round-off

    # Scale the 1σ ellipse so it encloses probability `conf`. For 2 degrees of freedom
    # the χ² CDF is 1 - exp(-x/2), so quantile(Chisq(2), conf) = -2 log(1 - conf).
    c = √(-2 * log1p(-conf))
    a *= c
    b *= c

    μx, μy = μ
    ts = range(0, 2π, length=N)
    xs = μx .+ a*cos(θ)*cos.(ts) - b*sin(θ)*sin.(ts)
    ys = μy .+ a*sin(θ)*cos.(ts) + b*cos(θ)*sin.(ts)
    return xs, ys
end

# Create an (N-1)×(N-1) corner-plot grid for the N×N covariance matrix C and fill it
# with confidence ellipses via plot_ellipses!; all keyword arguments are passed on.
function plot_ellipses(C; kwargs...)
    nparams = size(C)[1]
    fig = plot(layout = (nparams - 1, nparams - 1), size = (550, 550), aspect = 1, margin = -2*Plots.mm)
    return plot_ellipses!(fig, C; kwargs...)
end

# Draw the 68% and 95% confidence ellipses of covariance matrix C into the corner-plot grid p
# (as created by plot_ellipses). Panel p[iy-1, ix] shows parameter ix (x) against iy (y) for
# ix < iy. Panels on or above the diagonal are blanked, and the top-right one carries the
# legend entry `label`.
# Keywords:
#   centers    – ellipse centre per parameter (default: the fiducial values θ₀)
#   paramnames – axis label per parameter (default: θnames)
#   digits     – currently unused
#   kwargs...  – forwarded to the line series (e.g. color, linestyle, linewidth)
function plot_ellipses!(p, C; label = nothing, digits = 2, centers = θ₀, paramnames = θnames, kwargs...)
    N = size(C)[1]
    for i in eachindex(IndexCartesian(), C)
        ix, iy = i[1], i[2]
        if iy == 1 || iy > size(p)[1] + 1 || ix > size(p)[2]
            continue # out of bounds; skip
        end
        subplot = p[iy-1, ix]
        if ix >= iy
            # upper triangular part: hide the axes; the top-right panel holds the legend
            _label = (iy-1, ix) == (1, size(p)[2]) ? label : nothing
            hline!(subplot, [NaN]; framestyle = :none, label = _label, legendfontsize = 10, kwargs...)
        else
            # lower triangular part: centre point and confidence ellipses
            μx = centers[ix]
            μy = centers[iy]
            xlabel = iy == N ? paramnames[ix] : ""
            ylabel = ix == 1 ? paramnames[iy] : ""
            scatter!(subplot, [(μx, μy)]; marker = :cross, color = :black, label = nothing)
            for conf in [0.68, 0.95]
                xs, ys = ellipse(C, ix, iy, (μx, μy); conf)
                plot!(subplot, xs, ys; xlabel, ylabel, label = nothing, labelfontsize = 7, tickfontsize = 7, kwargs...)
            end
        end
    end
    return p
end

# Corner plot: ellipses from C_ad (labelled SymBoltz AD, solid) and C_fd (labelled CLASS FD, dashed)
p = plot_ellipses(C_ad; color = 1, linestyle = :solid, linewidth = 2, label = " $label_ad")
plot_ellipses!(p, C_fd; color = 2, linestyle = :dash, linewidth = 2, label = " $label_fd")

# Hand-picked tick positions per parameter
ticks = Dict(
    M.g.h => 0.55:0.10:0.85,
    M.c.Ω₀ => 0.10:0.10:0.50,
    M.b.Ω₀ => 0.01:0.02:0.09,
    M.b.YHe => 0.15:0.10:0.35,
    M.ν.Neff => 2.4:0.3:3.6,
    M.h.m_eV => -0.50:0.25:0.50,
    M.X.w₀ => -1.2:0.3:0.3,
    M.X.wₐ => -0.5:0.3:0.7,
    M.I.ln_As1e10 => 2.92:0.02:3.08,
    M.I.ns => 0.92:0.04:1.04,
)
# Show tick labels only on the left column and bottom row, and apply the custom ticks
for i in eachindex(IndexCartesian(), C_fd)
    ix, iy = i[1], i[2]
    (iy == 1 || iy > size(p)[1] + 1 || ix > size(p)[2]) && continue # out of bounds; skip
    subplot = p[iy-1, ix]
    if ix > 1
        plot!(subplot, yformatter = :none)
    end
    if iy < size(p)[1]+1
        plot!(subplot, xformatter = :none)
    end
    px = vary[ix]
    py = vary[iy]
    px in keys(ticks) && xticks!(subplot, ticks[px])
    py in keys(ticks) && yticks!(subplot, ticks[py])
end
plot!(p, xrotation = 45, yrotation = 45, foreground_color_legend = nothing)
plot!(p[1,1], left_margin=0*Plots.mm, top_margin=0*Plots.mm)
plot!(p[1, size(C_ad)[1]-1], right_margin=0*Plots.mm)

In [None]:
savefig(p, "forecast.pdf")

## Parameter fitting

In [None]:
using DataFrames, CSV, PDMats

docsdir = joinpath(pkgdir(SymBoltz), "docs")
data = joinpath(docsdir, "Pantheon/lcparam_full_long.txt")
Csyst = joinpath(docsdir, "Pantheon/sys_full_long.txt")

# Read data table
data = CSV.read(data, DataFrame, delim = " ", silencewarnings = true)

# Read covariance matrix of apparent magnitudes (mb)
Csyst = CSV.read(Csyst, DataFrame, header = false) # long vector: first entry N, then the N×N entries
Csyst = collect(reshape(Csyst[2:end, 1], (Int(Csyst[1, 1]), Int(Csyst[1, 1])))) # to matrix
Cstat = Diagonal(data.dmb)^2 # diagonal statistical covariance: squared, assuming dmb holds 1σ errors on mb — TODO confirm
C = Csyst + Cstat

# Sort data and covariance matrix with decreasing redshift
is = sortperm(data, :zcmb, rev = true)
C = C[is, is]
C = PDMat(Symmetric(C)) # efficient sym-pos-def matrix with Cholesky factorization
data = data[is, :]

In [None]:
# Background-only model: RMΛ (components M.r, M.m) with curvature K and analytical w0wa dark energy X,
# using the scale factor a as independent variable
g = SymBoltz.metric()
X = SymBoltz.w0wa(g; analytical = true)
K = SymBoltz.curvature(g)
M = RMΛ(K = K, Λ = X)
M = complete(SymBoltz.background(M); flatten = false)
M = change_independent_variable(M, M.g.a; add_old_diff = true)
# Fixed values: wa = 0 (constant w) and flat space (K.Ω₀ = 0); NOTE(review): NaN presumably marks parameters unused by this fit — confirm
pars_fixed = Dict(M.τ => 0.0, M.r.T₀ => NaN, M.X.cₛ² => NaN, M.X.wa => 0.0, M.r.Ω₀ => 9.3e-5, M.K.Ω₀ => 0.0)
pars_varying = [M.m.Ω₀, M.g.h, M.X.w0] # fitted parameters, in the order of p below

# Luminosity distance at the supernova redshifts as a function of the varying parameters
dL = SymBoltz.distance_luminosity_function(M, pars_fixed, pars_varying, data.zcmb)
# NOTE(review): the last element of dL(p) is dropped — presumably an extra evaluation point, confirm
μ(p) = 5 * log10.(dL(p)[begin:end-1] / (10*SymBoltz.pc)) # distance modulus

# Show example predictions
Mb = -19.3 # absolute supernova brightness (constant since SN-Ia are standard candles)
bgopts = (alg = SymBoltz.Tsit5(), reltol = 1e-5, maxiters = 1e3) # NOTE(review): not referenced elsewhere in this notebook
p0 = [0.3, 0.7, -1.0] # fiducial parameters
μs = μ(p0)
mbs = μs .+ Mb

In [None]:
using Turing, LinearAlgebra

# Turing model for the supernova magnitudes: uniform priors on (h, Ωm0, w0), predicted
# apparent magnitudes μ_pred(p) .+ Mb, and likelihood mbs ~ MvNormal(predictions, C).
@model function supernova(μ_pred, mbs, C; Mb = Mb)
    # Parameter priors
    h ~ Uniform(0.1, 1.0)
    Ωm0 ~ Uniform(0.0, 1.0)
    w0 ~ Uniform(-2.0, 0.0)

    p = [Ωm0, h, w0] # same order as pars_varying
    μs_pred = μ_pred(p)
    if isempty(μs_pred)
        # NOTE(review): an empty prediction presumably signals a failed background solve; reject this sample
        Turing.@addlogprob! -Inf
        return nothing
    end
    mbs_pred = μs_pred .+ Mb
    return mbs ~ MvNormal(mbs_pred, C) # read "measurements sampled from multivariate normal with predictions and covariance matrix"

    # equivalently:
    #Δmb = mbs .- mbs_pred
    #χ² = transpose(Δmb) * invC * Δmb
    #Turing.@addlogprob! -1/2 * χ²
    #return nothing
end

# https://github.com/JuliaStats/Distributions.jl/issues/1964 # TODO: get rid of? PR?
# Notebook-level MvNormal (this is the name the supernova model calls). For a PDMat covariance
# it calls the parametric constructor directly with the promoted element type R and the
# argument types unchanged (μ and Σ are not converted). Any other arguments are forwarded
# to Distributions.MvNormal.
function MvNormal(μ::AbstractVector{<:Real}, Σ::AbstractPDMat{<:Real})
    R = Base.promote_eltype(μ, Σ)
    Distributions.MvNormal{R, typeof(Σ), typeof(μ)}(μ, Σ)
end
function MvNormal(μ, Σ)
    return Distributions.MvNormal(μ, Σ)
end

sn_w0CDM_flat = supernova(μ, data.mb, C) # model conditioned on the observed apparent magnitudes

In [None]:
# NUTS sampling (5000 samples) from the given initial parameter values
chain = sample(sn_w0CDM_flat, NUTS(), 5000; initial_params = (h = 0.5, Ωm0 = 0.5, w0 = -1.0), progress = false)

In [None]:
using CairoMakie, PairPlots
# Corner plot of the posterior: scatter, filled 1σ/2σ contours, and marginal histograms/densities with quantiles
layout = (
    PairPlots.Scatter(),
    PairPlots.Contourf(sigmas = 1:2),
    PairPlots.MarginHist(),
    PairPlots.MarginDensity(color = :black),
    PairPlots.MarginQuantileText(color = :black, font = :regular),
    PairPlots.MarginQuantileLines(),
)
pp = pairplot(chain => layout)

In [None]:
save("constraints.pdf", pp)

## Appendix

In [None]:
using Latexify, LaTeXStrings # the package is LaTeXStrings (plural); "LaTeXString" does not exist

In [None]:
# Render symbolic equations `eqs` as compact LaTeX for the paper appendix.
# Steps: replace ℰ by ℋ and SymBoltz.ϵ by 1; latexify the result; drop "(τ)" arguments and
# \left/\right/\mathtt; turn the align environment into a gathered equation whose rows are
# separated by \qquad; unwrap parentheses around single characters.
# Returns a LaTeXString if `latex` is true, otherwise a plain String.
function process(eqs; fold = false, latex = true)
    substituted = substitute(substitute(eqs, M.g.ℰ => M.g.ℋ; fold), SymBoltz.ϵ => 1; fold)
    tex = replace(string(latexify(substituted)),
        "\\left( \\tau \\right)" => "",
        "\\left" => "",
        "\\right" => "",
        "\\mathtt" => "",
        "\\begin{align}" => "\\begin{equation}\\begin{gathered}",
        "\\end{align}" => "\\end{gathered}\\end{equation}",
        "\\\\" => ", \\qquad",
        "&" => "",
    )
    tex = replace(tex, r"\( (.) \)" => s"\1")
    latex || return tex
    return Latexify.LaTeXString(tex)
end

In [None]:
# Print each component's name, description and equations (as LaTeX via process)
for comp in [M.g, M.G, M.c, M.b, M.γ, M.ν, M.h, M.X]
    println(nameof(comp), ": ", ModelingToolkit.description(comp))
    process(equations(comp)) |> println
end

### Massive neutrino momentum bins

In [None]:
using QuadGK
f₀(x) = 1 / (exp(x) + 1) # Fermi-Dirac-shaped distribution in the dimensionless momentum x (presumably p/T)
Nx = 8
# (nodes, weights) of the momentum quadrature for every bin count from 1 to Nx
xWs = [SymBoltz.momentum_quadrature(f₀, nx) for nx in 1:Nx];

In [None]:
using Printf
# Print two LaTeX table bodies: first the quadrature nodes xs, then the weights Ws.
# Row nx reads "nx & v₁ & … & v_nx" followed by empty cells up to Nx columns, then "\\".
for table in (:nodes, :weights)
    for (nx, (xs, Ws)) in enumerate(xWs)
        vals = table === :nodes ? xs : Ws
        @printf("%d", nx)
        for i in 1:nx
            @printf(" & %.5f", vals[i])
        end
        for _ in nx+1:Nx
            @printf(" &")
        end
        @printf(" \\\\\n")
    end
end