### Comparison between jax and non-jax models
Quick notebook to verify that jax and non-jax models can produce similar results
under equivalent configurations.
Specifically, if the window length that the jax model uses to look back at
previous incidence values is set to a high value
(relative to the length of the tail of its generation time distribution),
the two models give very similar numerical outputs.

In [None]:
import numpy as np
import pandas as pd
pd.options.plotting.backend = "plotly"  # DataFrame.plot() renders with plotly rather than pandas' default backend

from emu_renewal.distributions import JaxGammaDens
from emu_renewal.renew import RenewalModel, JaxModel

In [None]:
# Parameter values for the Malaysia data, taken from a previous calibration.
gen_time_mean = 8.9744  # mean of the generation time distribution
gen_time_sd = 5.4449  # standard deviation of the generation time distribution

# Twelve process values, fed to both models unchanged.
proc_vals = np.array(
    [
        0.1130, -0.7540, -0.0427, -0.0868,
        0.3491, 0.4198, 0.2779, 0.4796,
        0.1858, 0.389, 0.2204, 0.0243,
    ]
)

# The seed peak is stored on the log scale.
seed_peak = np.log(13892.0057)

In [None]:
window_len = 50

# NOTE(review): the literal 50 passed second below is separate from window_len.
# 326 - 50 == 276, which is the second argument given to RenewalModel, and 12
# matches len(proc_vals); presumably these encode equivalent configurations
# for the two models — confirm against the JaxModel / RenewalModel signatures.
j = JaxModel(
    33e6,
    50,
    326,
    30,
    12,
    JaxGammaDens(),
    window_len,
)
m = RenewalModel(33e6, 276, 30, 12)


def get_inc_result(gen_mean, gen_sd, proc, seed):
    """Evaluate the jax model ``j`` at the given parameter values.

    Thin wrapper: forwards all four arguments, in order, to ``j.func`` and
    returns its result unchanged.
    """
    jax_output = j.func(gen_mean, gen_sd, proc, seed)
    return jax_output

In [None]:
# Both models receive exactly the same parameter values.
model_args = (gen_time_mean, gen_time_sd, proc_vals, seed_peak)
incidence_j = get_inc_result(*model_args).incidence
incidence_m = m.func(*model_args).incidence

In [None]:
# Overlay the two incidence series; the plot call is the cell's final
# expression so the notebook displays the figure.
comparison = pd.DataFrame({"original": incidence_m, "jax": incidence_j})
comparison.plot()