In [18]:
import numpy as np
import jax.numpy as jnp
from jax import lax

from diffstar.kernels.main_sequence_kernels import (
    _lax_ms_sfh_scalar_kern_scan, 
    _sfr_eff_plaw, 
    MS_BOUNDING_SIGMOID_PDICT,
)
from diffstar.utils import _inverse_sigmoid, _jax_get_dt_array, _sigmoid
from diffstar.kernels.main_sequence_kernels import DEFAULT_MS_PARAMS
from diffmah.defaults import DEFAULT_MAH_PARAMS
from diffstar.defaults import T_TABLE_MIN, TODAY
from diffstar.defaults import FB

# Number of points in the t' integration grid used below.
N_T_TABLE = 20

# log10 of diffstar's TODAY constant; passed to the kernels as `lgt0`.
lgt0 = jnp.log10(TODAY)

t_form = 12.0  # time t at which we compute the value SFR(t)

# Integration table of t' used to compute and sum up contributions to SFR(t)
# from gas parcels accreted at earlier times t' < t. The grid is uniformly
# spaced; a direct-sum implementation using a single dt relies on this.
t_table = jnp.linspace(T_TABLE_MIN, t_form, N_T_TABLE)

# Reference value from the scan-based kernel. Alternative implementations are
# expected to agree with it only to float32 round-off (last digit or so).
_lax_ms_sfh_scalar_kern_scan(t_form, DEFAULT_MAH_PARAMS, DEFAULT_MS_PARAMS, lgt0, FB, t_table)

Array(0.27356064, dtype=float32)

In [37]:
from jax import jit as jjit
from jax import vmap
from diffmah.defaults import MAH_K
from diffmah.individual_halo_assembly import (
    _calc_halo_history,
    _rolling_plaw_vs_logt,
)
from diffstar.kernels.gas_consumption import _gas_conversion_kern


# Map the scalar gas-consumption kernel over its second argument (the grid of
# accretion times t'); the remaining four arguments are shared, unmapped inputs.
_vmap_gas_conversion_kern = vmap(_gas_conversion_kern, in_axes=(None, 0) + (None,) * 3)


@jjit
def _lax_ms_sfh_scalar_kern_sum(t_form, mah_params, ms_params, lgt0, fb, t_table):
    """Main-sequence SFR at ``t_form`` computed as a direct sum over ``t_table``.

    Takes the same arguments as ``_lax_ms_sfh_scalar_kern_scan``. Instead of a
    scan, it evaluates the gas-consumption kernel at every grid point at once
    and performs a single weighted sum.

    Parameters
    ----------
    t_form : float
        Time at which the SFR is evaluated.
    mah_params : sequence of 4 floats
        ``(logmp, logtc, early, late)`` halo mass-assembly parameters. They are
        combined with ``lgt0`` and ``MAH_K`` before being passed to the diffmah
        kernels.
    ms_params : sequence
        Main-sequence parameters. Entries ``[:4]`` parametrize the
        star-formation efficiency (``_sfr_eff_plaw``); entry ``[4]`` is
        ``tau_dep``.
    lgt0 : float
        log10 reference time passed as the first diffmah parameter
        (``log10(TODAY)`` in this notebook).
    fb : float
        Factor converting the halo accretion rate into a gas accretion rate
        (presumably the cosmic baryon fraction).
    t_table : array, shape (n,), n >= 2
        Integration grid of accretion times t'. It must be uniformly spaced,
        because a single ``dt`` taken from the first two points is used as the
        quadrature weight for every point.

    Returns
    -------
    sfr : scalar array
        ``sfr_eff(t_form) * sum_i fb * dMh/dt(t'_i) * kern(t_form, t'_i) * dt``.
    """
    logmp, logtc, early, late = mah_params
    all_mah_params = lgt0, logmp, logtc, MAH_K, early, late

    # Star-formation efficiency depends only on the halo mass at t_form.
    lgt_form = jnp.log10(t_form)
    log_mah_at_tform = _rolling_plaw_vs_logt(lgt_form, *all_mah_params)
    sfr_eff = _sfr_eff_plaw(log_mah_at_tform, *ms_params[:4])

    tau_dep = ms_params[4]
    # Element [3] of the tau_dep bounding-sigmoid entry is used as the cap on tau_dep.
    tau_dep_max = MS_BOUNDING_SIGMOID_PDICT["tau_dep"][3]

    # Instantaneous gas accretion at every t' (the log-mass history is not needed).
    dmhdt_at_tacc, _ = _calc_halo_history(jnp.log10(t_table), *all_mah_params)
    dmgdt_inst = fb * dmhdt_at_tacc

    # Gas-consumption kernel at each t'; a single dt assumes a uniform grid.
    dt = t_table[1] - t_table[0]
    kern = _vmap_gas_conversion_kern(t_form, t_table, dt, tau_dep, tau_dep_max)

    # Discrete convolution: rectangle-rule sum of accretion x kernel over t'.
    dmgas_dt = jnp.sum(dmgdt_inst * kern * dt)
    return dmgas_dt * sfr_eff


# Compare the scan-based kernel against the direct-sum implementation defined
# above. They should agree to float32 round-off. The assertion makes a
# regression fail loudly instead of relying on eyeballing the printed values.
sfr_scan = _lax_ms_sfh_scalar_kern_scan(t_form, DEFAULT_MAH_PARAMS, DEFAULT_MS_PARAMS, lgt0, FB, t_table)
sfr_conv = _lax_ms_sfh_scalar_kern_sum(t_form, DEFAULT_MAH_PARAMS, DEFAULT_MS_PARAMS, lgt0, FB, t_table)
print("scan:", sfr_scan)
print("conv:", sfr_conv)
np.testing.assert_allclose(sfr_conv, sfr_scan, rtol=1e-5)

scan: 0.27356064
conv: 0.27356067
