In [None]:
import jax
import numpy as np
import jax.numpy as jnp
import equinox as eqx
from pylinger_cosmo import cosmo
from pylinger_pt_single_mode_exposed_param import evolve_one_mode, adiabatic_ics


In [None]:
## Cosmological Parameters (fiducial values used by every cell below)

# --- background / thermal history ---
Tcmb = 2.7255   # CMB temperature today (presumably in Kelvin)
YHe  = 0.248    # helium fraction (presumably by mass)
h    = 0.703    # dimensionless Hubble parameter; the model uses H0 = 100*h

# --- density parameters: dark energy closes the budget (flat universe) ---
Omegam = 0.276
Omegab = 0.0455
OmegaL = 1.0 - Omegam

# --- neutrinos ---
num_massive_neutrinos = 1
mnu = 0.06   # eV
# Massless effective species: the standard count with one unit removed per
# massive neutrino. Adding the massive ones back recovers the standard count.
Neff = 2.046
standard_neutrino_neff = Neff + num_massive_neutrinos

# --- primordial scalar power spectrum ---
A_s = 2.1e-9   # amplitude
n_s = 0.965    # spectral index
k_p = 0.05     # pivot scale (presumably in 1/Mpc)

In [None]:
@eqx.filter_jit
def f_of_Omegam( Omegam, *, kmode=1e-1, tau_start=0.01, tau_max=1.0, rtol=1e-3, atol=1e-4 ):
    """Evolve one linear-perturbation mode and return its state at ``tau_max``.

    The background model is rebuilt from ``Omegam`` on every call. All other
    cosmological inputs come from the module-level configuration cell. Adiabatic
    initial conditions are set at ``tau_start`` for wavenumber ``kmode``, and the
    mode is integrated to ``tau_max``. Because the whole pipeline runs inside this
    function, the result can be differentiated with respect to ``Omegam``.

    Args:
        Omegam: matter density parameter; the only argument meant to be
            differentiated.
        kmode: wavenumber of the single mode that is evolved.
        tau_start: time at which adiabatic initial conditions are imposed
            (presumably conformal time; confirm against pylinger).
        tau_max: time at which the evolved state is returned.
        rtol, atol: relative / absolute tolerances forwarded to
            ``evolve_one_mode``.

    The keyword-only arguments default to the values that were previously
    hard-coded, so ``f_of_Omegam(Omegam)`` behaves exactly as before. Under
    ``eqx.filter_jit``, plain Python floats are treated as static. Changing
    one therefore triggers a recompile rather than being traced.

    Returns:
        The output of ``evolve_one_mode`` for the single output time ``tau_max``.

    Differentiation note: prefer ``jax.jacfwd`` to ``jax.jacrev`` here. The input
    is a single scalar, so forward mode needs just one JVP. Reverse mode also
    fails if the ODE integrator is built on ``lax.while_loop``, which JAX cannot
    reverse-differentiate.

    NOTE(review): ``OmegaL`` is the module-level constant ``1 - Omegam`` evaluated
    at the fiducial value, so it does not follow the ``Omegam`` argument.
    Derivatives w.r.t. ``Omegam`` are therefore not taken at fixed flatness.
    Confirm that this is intended.
    """
    # Background cosmology for this Omegam (everything else fixed at fiducial values).
    cp = cosmo(Omegam=Omegam, Omegab=Omegab, OmegaL=OmegaL, H0=100*h, Tcmb=Tcmb, YHe=YHe, Neff=Neff, Nmnu=num_massive_neutrinos, mnu=mnu )

    # Boltzmann-hierarchy truncation settings.
    lmaxg  = 12
    lmaxgp = 12
    lmaxr  = 17
    lmaxnu = 17
    nqmax  = 15
    # Size of the state vector: 7 base variables, plus one (lmax+1)-long
    # hierarchy per species. The lmaxnu hierarchy is repeated for each of the
    # nqmax momentum bins (layout presumably as expected by pylinger; verify).
    nvar   = 7 + (lmaxg + 1) + (lmaxgp + 1) + (lmaxr + 1) + nqmax * (lmaxnu + 1)

    y0 = adiabatic_ics( omegam=Omegam, omegab=Omegab, tau=tau_start, param=cp.param, kmodes=jnp.array([kmode]), num_k=1, nvar=nvar,
                       lmaxg=lmaxg, lmaxgp=lmaxgp, lmaxr=lmaxr, lmaxnu=lmaxnu, nqmax=nqmax )

    # adiabatic_ics returns a 2-D array (one row per k-mode); with num_k=1 the
    # last row is the only one.
    y1 = evolve_one_mode( y0=y0[-1,:], tau_start=tau_start, tau_max=tau_max, tau_out=jnp.array([tau_max]), omegam=Omegam, omegab=Omegab,
                        param=cp.param, kmode=kmode, lmaxg=lmaxg, lmaxgp=lmaxgp, lmaxr=lmaxr, lmaxnu=lmaxnu, nqmax=nqmax, rtol=rtol, atol=atol )

    return y1

In [None]:
# Derivative of the evolved perturbation state with respect to Omegam.
# Use forward mode. The input is a single scalar, so jacfwd needs one JVP,
# whereas jacrev needs one VJP per output entry. Forward mode also
# differentiates through lax.while_loop, which adaptive ODE integrators often
# use and which JAX cannot reverse-differentiate. The Jacobian has the same
# shape and values as jacrev would give.
dP = jax.jacfwd( f_of_Omegam )

In [None]:
# Jacobian of the evolved state w.r.t. Omegam, evaluated at the fiducial value.
dP_dOmegam = dP( Omegam )
print( dP_dOmegam )