In [1]:
import jax
import numpy as np
import jax.numpy as jnp
import equinox as eqx
import diffrax
jax.config.update("jax_enable_x64", True)
# jax.config.update('jax_platform_name', 'cpu')
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_disable_jit", True)
from pylinger_background import evolve_background
from pylinger_perturbations import evolve_one_mode


In [2]:
## Cosmological Parameters
Tcmb = 2.7255
YHe = 0.248
Omegam = 0.276
Omegab = 0.0455
OmegaL = 1.0-Omegam
num_massive_neutrinos = 1
mnu=0.06 #eV
Neff=2.046 # -1 if massive neutrino present
standard_neutrino_neff=Neff+num_massive_neutrinos
h = 0.703
A_s = 2.1e-9
n_s = 0.965
k_p = 0.05

In [3]:
# @eqx.filter_jit
def f_of_Omegam( args ):
    param = {}
    param['Omegam'] = args[0]
    param['Omegab'] = args[1]
    param['OmegaL'] = OmegaL
    param['Omegak'] = 0.0
    param['A_s'] = A_s
    param['n_s'] = n_s
    param['H0'] = 100*h
    param['Tcmb'] = Tcmb
    param['YHe'] = YHe
    param['Neff'] = Neff
    param['Nmnu'] = num_massive_neutrinos
    param['mnu'] = mnu
    param = evolve_background(param=param)
    
    k = 1e-2

    # Compute Perturbations
    lmaxg  = 12
    lmaxgp = 12
    lmaxr  = 17
    lmaxnu = 17
    nqmax  = 15

    rtol   = 1e-3
    atol   = 1e-4

    y = evolve_one_mode( tau_max=1.0, tau_out=jnp.array([1.0]), 
                      param=param, kmode=k, lmaxg=lmaxg, lmaxgp=lmaxgp, lmaxr=lmaxr, 
                      lmaxnu=lmaxnu, nqmax=nqmax, rtol=rtol, atol=atol )
    
    return y

In [4]:
# f_of_Omegam(0.273)

In [5]:
dP = jax.jacfwd( f_of_Omegam )
print( dP( jnp.array([Omegam, Omegab]) ) )

[[[ 1.37446120e-08  5.25311904e-30]
  [ 4.57971641e-13 -2.76063220e-18]
  [ 3.11128386e-05 -3.80690732e-10]
  [-2.12000171e-07  1.27647855e-13]
  [ 4.43773433e-24 -3.16322136e-25]
  [-2.11999816e-07  3.13073662e-12]
  [-1.77244928e-12 -1.50107967e-11]
  [-2.82666421e-07  4.17431572e-12]
  [-1.77244883e-12 -1.50107979e-11]
  [ 0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.000000