In [None]:
from jax.scipy.stats import gamma
from jax import jit
from jax import numpy as jnp
from IPython.display import Markdown
import numpy as np
import pandas as pd
pd.options.plotting.backend = 'plotly'

import summer2

from emu_renewal.renew import RenewalModel, JaxModel

In [None]:
# Parameter values that previously worked for fitting to the example Malaysia data.

# Generation-time distribution: mean and standard deviation
gen_time_mean = 8.974486528334516
gen_time_sd = 5.444942511395974

# Seed value on the natural scale (both models below receive its log)
seed_peak = 13892.00570877276

# Variable process values (12 entries)
proc_vals = np.array(
    [
        0.11303378, -0.75406457, -0.04278638, -0.08685292,
        0.34915623, 0.41983523, 0.27799167, 0.47966875,
        0.18586957, 0.3891201, 0.22046122, 0.02434282,
    ]
)

In [None]:
class JaxGammaDens:
    """Gamma-distributed generation-time densities computed with JAX.

    The distribution is specified by its mean and standard deviation and is
    discretised into probability masses over consecutive unit intervals.
    """

    def get_params(self, mean, sd):
        """Convert a mean and standard deviation into gamma parameters.

        For a gamma distribution with shape ``a`` and scale ``theta``,
        mean = a * theta and variance = a * theta**2, hence
        theta = sd**2 / mean and a = mean / theta.

        Args:
            mean: Mean of the distribution.
            sd: Standard deviation of the distribution.

        Returns:
            Dict with keys ``'a'`` (shape) and ``'scale'``, ready to be
            unpacked into ``jax.scipy.stats.gamma`` functions.
        """
        variance = sd * sd
        theta = variance / mean
        return {'a': mean / theta, 'scale': theta}

    def get_densities(self, window_len, mean, sd):
        """Probability mass of the gamma distribution on each unit interval.

        Args:
            window_len: Number of intervals [0, 1], [1, 2], ... to evaluate.
            mean: Mean of the distribution.
            sd: Standard deviation of the distribution.

        Returns:
            JAX array of length ``window_len`` whose i-th entry is
            CDF(i + 1) - CDF(i). Mass beyond ``window_len`` is not included,
            and the entries are not renormalised to sum to one.
        """
        edges = jnp.arange(window_len + 1)
        cdf_at_edges = gamma.cdf(edges, **self.get_params(mean, sd))
        return jnp.diff(cdf_at_edges)

    def get_description(self):
        """Describe how the generation-time densities are computed.

        Returns:
            The description in markdown format.
        """
        heading = '\n\n### Generation times\n'
        body = (
            'Generation times for each day are calculated by '
            'first finding the parameters needed to construct '
            'a gamma distribution with mean and standard deviation '
            'equal to those specified by the submitted parameter values. '
            'The integrals of the probability density of this distribution '
            'between consecutive integer values are then calculated for '
            'later combination with the incidence time series. '
        )
        return heading + body

In [None]:
dist = JaxGammaDens()
# NOTE(review): the positional configuration arguments (33e6, 50, 276, 30, 12)
# are unlabelled here; confirm their meaning against JaxModel's signature.
j = JaxModel(33e6, 50, 276, 30, 12, dist)


# @jit  -- left disabled as in the original. NOTE(review): confirm that
#          j.model_func is jit-compatible before re-enabling (`jit` is imported).
def get_inc_result(gen_mean, gen_sd, proc, seed):
    """Run the JAX renewal model for a single parameter set.

    Args:
        gen_mean: Mean of the generation-time distribution.
        gen_sd: Standard deviation of the generation-time distribution.
        proc: Variable process values.
        seed: Seed value on the natural scale; it is log-transformed here
            before being handed to ``j.model_func``.

    Returns:
        The object returned by ``j.model_func`` (this notebook reads its
        ``.incidence`` attribute).
    """
    log_seed = jnp.log(seed)
    return j.model_func(gen_mean, gen_sd, proc, log_seed)

In [None]:
# Markdown(j.get_description())

In [None]:
# Run the JAX implementation with the previously fitted parameter values.
results_jax = get_inc_result(
    gen_mean=gen_time_mean,
    gen_sd=gen_time_sd,
    proc=proc_vals,
    seed=seed_peak,
)

In [None]:
# Pull the incidence series out of the JAX model output for the comparison below.
incidence_jax = results_jax.incidence

In [None]:
# Run the original (non-JAX) renewal model with the same parameters and overlay
# its incidence on the JAX result.
# NOTE(review): these configuration arguments mirror those passed to JaxModel
# in the cell above (which additionally takes 50); keep the two in sync.
m = RenewalModel(33e6, 276, 30, 12)
incidence_orig = m.func(gen_time_mean, gen_time_sd, proc_vals, np.log(seed_peak)).incidence

# Convert the JAX array to a host NumPy array explicitly so pandas/plotly do not
# rely on implicit array conversion (older pandas would fall back to list()).
incidence_comparison = pd.DataFrame(
    {
        'original': incidence_orig,
        'jax': np.asarray(incidence_jax),
    }
)

# Summarise the discrepancy numerically instead of relying on the plot alone.
max_abs_diff = (incidence_comparison['original'] - incidence_comparison['jax']).abs().max()

incidence_comparison.plot(
    title=f'Incidence: original vs JAX renewal model (max |difference| = {max_abs_diff:.3g})',
    labels={'index': 'Time index', 'value': 'Incidence', 'variable': 'Implementation'},
)