# Variation of Abundances with Nuisance Parameters

A major advantage of LINX is that it makes it easy to calculate primordial abundances that account for the uncertainties on the reaction rates in the BBN network. Here, we will explore the impact of those uncertainties on the predicted primordial abundances.

## Preamble

In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import jax.numpy as jnp
import jax
from jax import jit, vmap
import sys

sys.path.append("../")
import linx.const as const 
from linx.nuclear import NuclearRates
from linx.background import BackgroundModel
from linx.abundances import AbundanceModel

First, we set up our thermodynamics model (see the "background_evolution" notebook for more information about how LINX computes the background thermodynamics) and the abundance model. We'll use the `key_PRIMAT_2023` network, though it's simple to switch to a different network if desired.

In [None]:
thermo_model_DNeff = BackgroundModel()

(
    t_vec_ref, a_vec_ref, rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec, Neff_vec 
) = thermo_model_DNeff(0.)

In [None]:
# comment or uncomment to toggle which reaction network to use

network = 'key_PRIMAT_2023'
# network = 'key_PRIMAT_2018'
# network = 'key_PArthENoPE'
# network = 'key_YOF'
abundance_model = AbundanceModel(NuclearRates(nuclear_net=network))

## Abundances with nuisances

The rate for each nuclear reaction $i$ is $r_i(T) \equiv u^{-1} \langle \sigma v \rangle (T)$, where $u$ is the atomic mass unit and $\langle \sigma v \rangle$ is the thermally averaged product of the cross section and the relative velocity. Rate uncertainties in LINX are captured by taking $r_i$ to be log-normally distributed: $\log \overline{r}_i(T)$ and $\sigma_i(T)$ are the mean and standard deviation of $\log r_i(T)$, so that $\overline{r}_i(T)$ is the median rate. Specifically, $\log r_i (T) = \log \overline{r}_i (T) + q_i \sigma_i (T)$, where $q_i$ is a standard normal random variable.
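As a minimal sketch of this parametrization, assuming natural logarithms and using made-up values of $\overline{r}_i$ and $\sigma_i$ at a single temperature (not LINX's tabulated rates), the snippet below checks that $q_i = 0$ returns the median rate, that $q_i = 1$ multiplies it by $e^{\sigma_i}$, and that drawing $q_i$ from a standard normal gives samples whose median is $\overline{r}_i$. Because the mean of a log-normal is $\overline{r}_i\, e^{\sigma_i^2/2}$, setting $q_i = 0$ corresponds to the median rather than the mean rate.

```python
import numpy as np

def varied_rate(r_median, sigma, q):
    """Log-normal rate: log r = log r_median + q * sigma (natural log assumed)."""
    return np.exp(np.log(r_median) + q * sigma)

# Hypothetical median rate and log-space uncertainty at one temperature
r_median = 2.5e4
sigma = 0.05

# q = 0 reproduces the median; q = +1 is a one-sigma upward shift
r_fid = varied_rate(r_median, sigma, 0.0)
r_up = varied_rate(r_median, sigma, 1.0)

# Sampling q from a standard normal gives log-normally distributed rates
rng = np.random.default_rng(0)
q_samples = rng.standard_normal(100_000)
r_samples = varied_rate(r_median, sigma, q_samples)

print(r_fid, r_up / r_fid, np.median(r_samples) / r_median)
```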

Let's define a wrapper function that takes the factors `eta_fac` and `tau_n_fac`, which rescale the baryon-to-photon ratio $\eta$ and the neutron lifetime $\tau_n$, along with `nuclear_rates_q`, an array of the $q_i$. Passing nonzero entries in `nuclear_rates_q` shifts the corresponding reaction rates away from their median values.

In [None]:
def get_abundance_eta_tau_q(eta_fac, tau_n_fac, nuclear_rates_q):
    # eta_fac and tau_n_fac rescale eta and tau_n; nuclear_rates_q holds one q_i per reaction
    Yn, Yp, Yd, Yt, YHe3, Ya, YLi7, YBe7 = abundance_model(
        jnp.array(rho_g_vec),
        jnp.array(rho_nu_vec),
        jnp.zeros_like(rho_g_vec),  # no new-physics energy density
        jnp.zeros_like(rho_g_vec),  # no new-physics pressure
        t_vec=jnp.array(t_vec_ref),
        a_vec=jnp.array(a_vec_ref),
        eta_fac=jnp.asarray(eta_fac),
        tau_n_fac=jnp.asarray(tau_n_fac),
        nuclear_rates_q=nuclear_rates_q
    )
    # Output order: [Yn, Yp, Yd, Yt, YHe3, Ya, YLi7, YBe7]
    return jnp.array([Yn, Yp, Yd, Yt, YHe3, Ya, YLi7, YBe7])

get_abundance_v = vmap(get_abundance_eta_tau_q, in_axes=(None, None, 0))

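The last line uses `vmap` to vectorize our wrapper function over `nuclear_rates_q` only (`in_axes=(None, None, 0)` keeps `eta_fac` and `tau_n_fac` fixed), so that we can pass in a 2D array whose rows are different `q` arrays and get back one set of abundances per row.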
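To see what `in_axes=(None, None, 0)` does without running the full BBN network, here is a small sketch with a hypothetical stand-in function `toy_abundances` (not part of LINX). The two scalar factors are shared across the batch, while the first axis of `q` is mapped over, and the batched result matches calling the function row by row.

```python
import jax
import jax.numpy as jnp

def toy_abundances(eta_fac, tau_n_fac, q):
    """Hypothetical stand-in for get_abundance_eta_tau_q: returns a length-2
    array that depends on the two scalar factors and on the vector q."""
    return jnp.array([eta_fac * jnp.sum(q), tau_n_fac * jnp.prod(1.0 + q)])

# Map over the first axis of q only; eta_fac and tau_n_fac are not batched (None)
toy_batched = jax.vmap(toy_abundances, in_axes=(None, None, 0))

q_batch = jnp.array([[0.0, 0.0, 0.0],
                     [1.0, 0.0, 0.0],
                     [0.0, -1.0, 0.5]])
out = toy_batched(1.0, 2.0, q_batch)
print(out.shape)  # (3, 2): one output row per row of q_batch
```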

Let's compute our fiducial values of the primordial abundances, with every parameter set to its median value ($q_i = 0$ for all reactions):

In [None]:
num_reactions = len(abundance_model.nuclear_net.reactions)
fiducial = get_abundance_eta_tau_q(1., 1., jnp.zeros(num_reactions))

Next, we compute the abundances when either $\eta$ or $\tau_n$ is shifted up by one sigma:

In [None]:
eta_vary = get_abundance_eta_tau_q(1.006708, 1., jnp.zeros(num_reactions))  # +1 sigma in eta (Planck 2018)
tau_vary = get_abundance_eta_tau_q(1., 1.000682, jnp.zeros(num_reactions))  # +1 sigma in tau_n (PDG)
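Each of these factors appears to be of the form $1 + \sigma_x/\bar{x}$. As a quick arithmetic check, assume the Planck 2018 TT,TE,EE+lowE value $\Omega_b h^2 = 0.02236 \pm 0.00015$ (with $\eta \propto \Omega_b h^2$, so the relative uncertainty carries over) and a neutron lifetime of $879.4 \pm 0.6$ s (the PDG 2020 average). These inputs are our assumption, not values read from LINX, but they reproduce both factors to six decimal places.

```python
# Assumed central values and 1-sigma uncertainties (not taken from LINX)
omega_b, sigma_omega_b = 0.02236, 0.00015   # Planck 2018 Omega_b h^2
tau_n, sigma_tau_n = 879.4, 0.6             # neutron lifetime in seconds

# eta is proportional to Omega_b h^2, so its relative shift is the same
eta_fac_1sigma = 1 + sigma_omega_b / omega_b
tau_n_fac_1sigma = 1 + sigma_tau_n / tau_n

print(round(eta_fac_1sigma, 6), round(tau_n_fac_1sigma, 6))  # 1.006708 1.000682
```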

Finally, we shift each reaction rate up by one sigma ($q_i = 1$), one reaction at a time, and compute the resulting abundances:

In [None]:
reac_arrays = jnp.diag(jnp.ones(num_reactions))
reac_vary = get_abundance_v(1.,1.,reac_arrays)
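Each row of the identity matrix is a `q` array with $q_k = 1$ for a single reaction $k$ and $q_i = 0$ for every other reaction, so the batched call returns one abundance vector per reaction. A small NumPy sketch with a hypothetical reaction count `n_toy`; a downward one-sigma shift of each reaction would simply use the negated identity.

```python
import numpy as np

n_toy = 4  # hypothetical number of reactions, for illustration only

# Row k switches on a +1 sigma shift for reaction k only
q_up = np.diag(np.ones(n_toy))
# Same construction for a -1 sigma shift of each reaction
q_down = -np.eye(n_toy)

for k, q in enumerate(q_up):
    # exactly one nonzero entry, located at index k
    print(k, np.flatnonzero(q), q)
```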

We can stack up the data and print it out in a table:

In [None]:
all_vary = np.vstack((eta_vary,tau_vary,reac_vary))

In [None]:
reac_names = [abundance_model.nuclear_net.reactions[i].name for i in range(num_reactions)]
varied_params = np.concatenate((["eta","tau_n"],reac_names))
abundances = ['D/H','Yp','He3/H','Li7/H']
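Each row of `all_vary` holds the abundances in the order returned by `get_abundance_eta_tau_q`, `[Yn, Yp, Yd, Yt, YHe3, Ya, YLi7, YBe7]`. The tables convert a row to observables by dividing by `Yp` (the proton, i.e. hydrogen, abundance) for number ratios and by taking $Y_P \approx 4\,Y_\alpha$ for the helium mass fraction. The `Yp` column label in the tables therefore refers to this mass fraction, not to the proton abundance. A vectorized sketch of the same conversion, using a hypothetical helper `observables` applied to a made-up abundance row (not a LINX output):

```python
import numpy as np

# Column indices matching get_abundance_eta_tau_q's output order
I_N, I_P, I_D, I_T, I_HE3, I_A, I_LI7, I_BE7 = range(8)

def observables(Y):
    """Map abundance rows of shape (..., 8) to [D/H, Y_P, He3/H, Li7/H].

    Mirrors the notebook's tables: number ratios are taken relative to Yp
    (hydrogen), and the helium mass fraction is approximated as 4 * Ya.
    """
    Y = np.asarray(Y)
    return np.stack([
        Y[..., I_D] / Y[..., I_P],
        4.0 * Y[..., I_A],
        Y[..., I_HE3] / Y[..., I_P],
        Y[..., I_LI7] / Y[..., I_P],
    ], axis=-1)

# Made-up abundance row, for illustration only
Y_example = np.array([0.0, 0.75, 1.95e-5, 0.0, 7.5e-6, 0.0615, 4.0e-10, 0.0])
print(observables(Y_example))
print(observables(np.vstack([Y_example, Y_example])).shape)  # (2, 4)
```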

In [None]:
abundances = ['D/H x 1e5','Yp','He3/H x 1e5','Li7/H x 1e11']

table = []
for i in range(len(varied_params)):
    Y = all_vary[i]
    row = [varied_params[i],
           Y[2] / Y[1] * 1e5,   # D/H x 1e5
           4 * Y[5],            # helium mass fraction Y_P
           Y[4] / Y[1] * 1e5,   # He3/H x 1e5
           Y[6] / Y[1] * 1e11]  # Li7/H x 1e11
    table.append(row)

col_width = 15
decimal_places = 5

# Header uses the same column widths as the rows (reusing "" inside an f-string needs Python 3.12+)
header = f"{'':<{col_width}}" + "".join(f"{name:<{col_width}}" for name in abundances)
print(header)
print("-" * (col_width * 5))
for row in table:
    formatted_row = f"{str(row[0])[:col_width-1]:<{col_width}}" 
    for item in row[1:]:
        formatted_row += f"{item:<{col_width}.{decimal_places}f}"
    print(formatted_row)

In [None]:
abundances = ['% D/H','% Yp','% He3/H','% Li7/H']

def percentage(varied,fiducial):
    return 100*np.abs((varied - fiducial)/((varied + fiducial)/2))

table = []
for i in range(len(varied_params)):
    row = [varied_params[i],
           percentage(all_vary[i][2]/all_vary[i][1],(fiducial[2]/fiducial[1])), 
           percentage(all_vary[i][5],fiducial[5]),
           percentage(all_vary[i][4]/all_vary[i][1],fiducial[4]/fiducial[1]),
           percentage(all_vary[i][6]/all_vary[i][1],fiducial[6]/fiducial[1])]
    table.append(row)

col_width = 15
decimal_places = 5

# Header uses the same column widths as the rows (reusing "" inside an f-string needs Python 3.12+)
header = f"{'':<{col_width}}" + "".join(f"{name:<{col_width}}" for name in abundances)
print(header)
print("-" * (col_width * 5))
for row in table:
    formatted_row = f"{str(row[0])[:col_width-1]:<{col_width}}" 
    for item in row[1:]:
        formatted_row += f"{item:<{col_width}.{decimal_places}f}" 
    print(formatted_row)
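Note that `percentage` divides the absolute difference by the mean of the varied and fiducial values (a symmetric percent difference) rather than by the fiducial value alone. For the small shifts in this table the two conventions agree to leading order, but they diverge for large shifts. A quick comparison with illustrative numbers, where `percent_of_fiducial` is a hypothetical helper added only for contrast:

```python
import numpy as np

def percentage(varied, fiducial):
    # Symmetric percent difference, as defined in the notebook
    return 100 * np.abs((varied - fiducial) / ((varied + fiducial) / 2))

def percent_of_fiducial(varied, fiducial):
    # Conventional percent change relative to the fiducial value (hypothetical helper)
    return 100 * np.abs((varied - fiducial) / fiducial)

fid_value = 2.5e-5
for varied in [2.5125e-5, 2.75e-5, 5.0e-5]:   # +0.5%, +10%, +100%
    print(varied, percentage(varied, fid_value), percent_of_fiducial(varied, fid_value))
```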