In [None]:
from jax import lax, jit, numpy as jnp
from jax.scipy.stats import gamma
from jax import numpy as jnp, vmap, grad
from IPython.display import Markdown
import numpy as np
import pandas as pd
pd.options.plotting.backend = 'plotly'

import summer2

from emu_renewal.renew import RenewalModel
from emu_renewal.process import cosine_multicurve, sinterp
from typing import NamedTuple

In [None]:
# Parameter values that previously worked for fitting to the example Malaysia data.
# The same values are fed to both the JAX model and the original RenewalModel below.

# Mean and standard deviation of the generation time distribution.
gen_time_mean = 8.974486528334516
gen_time_sd = 5.444942511395974

# Peak of the seeding curve on the natural scale (the log is taken before it reaches the models).
seed_peak = 13892.00570877276

# Non-mechanistic process values; the models exponentiate the curve interpolated through them.
proc_vals = np.array(
    [
        0.11303378,
        -0.75406457,
        -0.04278638,
        -0.08685292,
        0.34915623,
        0.41983523,
        0.27799167,
        0.47966875,
        0.18586957,
        0.3891201,
        0.22046122,
        0.02434282,
    ]
)

In [None]:
class JaxGammaDens:
    '''Gamma-distributed generation times, discretised into daily probability masses with JAX.'''

    def get_params(self, mean, sd):
        '''Convert a mean and standard deviation into gamma shape/scale parameters.

        Args:
            mean: Requested mean of the distribution
            sd: Requested standard deviation of the distribution

        Returns:
            Dictionary with keys 'a' (shape) and 'scale', ready to be
            unpacked into the jax.scipy.stats.gamma functions
        '''
        # For a gamma distribution mean = a * scale and variance = a * scale ** 2,
        # so the scale is the variance over the mean and the shape follows from it.
        scale = (sd * sd) / mean
        return {'a': mean / scale, 'scale': scale}

    def get_densities(self, window_len, mean, sd):
        '''Get the probability mass falling in each unit interval of the window.

        Args:
            window_len: Number of unit intervals, starting from zero
            mean: Mean of the generation time distribution
            sd: Standard deviation of the generation time distribution

        Returns:
            Array of length window_len whose i-th element is CDF(i + 1) - CDF(i)
        '''
        gamma_params = self.get_params(mean, sd)
        interval_edges = jnp.arange(window_len + 1)
        cumulative = gamma.cdf(interval_edges, **gamma_params)
        return jnp.diff(cumulative)

    def get_description(self):
        '''Describe how the generation time densities are calculated.

        Returns:
            The description in markdown format
        '''
        heading = '\n\n### Generation times\n'
        body = (
            'Generation times for each day are calculated by '
            'first finding the parameters needed to construct '
            'a gamma distribution with mean and standard deviation '
            'equal to those specified by the submitted parameter values. '
            'The integrals of the probability density of this distribution '
            'between consecutive integer values are then calculated for '
            'later combination with the incidence time series. '
        )
        return heading + body

In [None]:
class RenewalState(NamedTuple):
    """State carried from one time step of the renewal model to the next.

    Being a NamedTuple, it can be used directly as the carry of lax.scan.
    """

    # Incidence over the trailing window, with index 0 holding the most recent value.
    incidence: jnp.ndarray
    # Number of people still susceptible (a scalar; becomes a JAX scalar inside the scan).
    suscept: float

class JaxModel:
    """Renewal model of incidence written with JAX so that it can be jit-compiled.

    Incidence at each time step is the sum of a renewal term (past incidence
    weighted by the generation time densities and scaled by the reproduction
    number) and a seeding term that is only intended to act during the run-in.
    """

    def __init__(self, population, window_len, n_times, run_in, n_process_periods, dens_obj):
        """Store the model configuration and precompute the fixed time grids.

        Args:
            population: Total population size, used as the starting number of susceptibles
            window_len: Number of past time steps of incidence retained for the renewal sum
            n_times: Number of time steps to simulate
            run_in: Time at which the seeding curve returns to zero
            n_process_periods: Number of evenly spaced points at which the
                non-mechanistic process values are specified
            dens_obj: Object providing get_densities(window_len, mean, sd)
                and get_description(), e.g. JaxGammaDens
        """
        self.pop = population
        self.n_times = n_times
        self.run_in = run_in
        self.n_process_periods = n_process_periods
        self.window_len = window_len
        # Process values are attached to points evenly spaced from 0 to n_times.
        # NOTE(review): get_scale_data presumably prepares data for cosine_multicurve - confirm.
        self.x_proc_vals = sinterp.get_scale_data(jnp.linspace(0.0, self.n_times, self.n_process_periods))
        self.dens_obj = dens_obj
        self.model_times = jnp.arange(self.n_times)
        # Seeding starts at time 0, peaks half way through the run-in and ends at run_in.
        self.seed_x_vals = [0.0, round(self.run_in * 0.5), self.run_in]

    def seed_func(self, t, seed):
        """Evaluate the seeding curve at time t.

        Args:
            t: Time at which to evaluate the curve
            seed: Log of the peak value of the seeding curve

        Returns:
            The value of the curve interpolated through 0, exp(seed), 0
            at the times held in self.seed_x_vals
        """
        x_vals = sinterp.get_scale_data(jnp.array(self.seed_x_vals))
        y_vals = sinterp.get_scale_data(jnp.array([0.0, jnp.exp(seed), 0.0]))
        return cosine_multicurve(t, x_vals, y_vals)

    def model_func(self, gen_time_mean, gen_time_sd, process_req, seed):
        """Run the renewal model over all model times.

        Args:
            gen_time_mean: Mean of the generation time distribution
            gen_time_sd: Standard deviation of the generation time distribution
            process_req: Non-mechanistic process values in log space
                (presumably one per process period - verify against caller)
            seed: Log of the peak value of the seeding curve

        Returns:
            Array with one row per model time; column 0 is the new incidence
            and column 1 is the susceptible population after that time step
        """
        densities = self.dens_obj.get_densities(self.window_len, gen_time_mean, gen_time_sd)

        # Interpolate the requested values over every model time, then exponentiate
        # so that the requested values live in log space.
        y_proc_vals = sinterp.get_scale_data(process_req)
        process_vals = jnp.exp(vmap(cosine_multicurve, in_axes=(0, None, None))(self.model_times, self.x_proc_vals, y_proc_vals))

        # Start with no past incidence and a fully susceptible population.
        init_state = RenewalState(jnp.zeros(self.window_len), self.pop)

        def state_update(state: RenewalState, t) -> tuple[RenewalState, jnp.ndarray]:
            # Reproduction number: process value scaled by the susceptible proportion.
            r_t = process_vals[t] * state.suscept / self.pop
            renewal = (densities * state.incidence).sum() * r_t
            seed_component = self.seed_func(t, seed)
            total_new_incidence = renewal + seed_component
            # Incidence cannot exceed the number of people still susceptible.
            total_new_incidence = jnp.where(total_new_incidence > state.suscept, state.suscept, total_new_incidence)
            suscept = state.suscept - total_new_incidence
            # Shift the window back by one step and put the newest value at index 0.
            incidence = jnp.zeros_like(state.incidence)
            incidence = incidence.at[1:].set(state.incidence[:-1])
            incidence = incidence.at[0].set(total_new_incidence)
            return RenewalState(incidence, suscept), jnp.array([total_new_incidence, suscept])

        # Only the stacked per-step outputs are needed; the final carry is discarded.
        _, outputs = lax.scan(state_update, init_state, self.model_times)

        return outputs

    def get_description(self):
        """Describe the model structure.

        Returns:
            The description in markdown format, including the description
            provided by the generation time density object
        """
        renew_desc = (
            '\n\n### Renewal process\n'
            'Calculation of the renewal process '
            'consists of multiplying the incidence values for the preceding days '
            'by the reversed generation time distribution values. '
            'This follows a standard formula, '
            'described elsewhere by several groups,[@cori2013; @faria2021] i.e. '
            '$$i_t = R_t\\sum_{\\tau<t} i_\\tau g_{t-\\tau}$$\n'
            '$R_t$ is calculated as the product of the proportion '
            'of the population remaining susceptible '
            'and the non-mechanistic random process '
            'generated external to the renewal model. '
            'The susceptible population is calculated by '
            'subtracting the number of new incident cases from the '
            'running total of susceptibles at each iteration.\n'
        )

        non_mech_desc = (
            '\n\n### Non-mechanistic process\n'
            'The time values corresponding to the submitted process values '
            'are set to be evenly spaced throughout the simulation period. '
            'Next, a continuous function of time was constructed from '
            'the non-mechanistic process series values submitted to the model. '
            'After curve fitting, the sequence of parameter values pertaining to '
            'the non-mechanistic process are exponentiated, '
            'such that parameter exploration for these quantities is '
            'undertaken in the log-transformed space. '
        )

        # The seeding curve passes through zero, the peak and zero at these three times.
        seed_start, seed_mid, seed_end = self.seed_x_vals
        seed_desc = (
            '\n\n### Seeding\n'
            'Seeding was achieved by interpolating using a cosine function. '
            f'The number of seeded cases scaled from zero at time {seed_start} '
            f'up to the peak seed value at time {seed_mid}, '
            f'and then back to zero at time {seed_end}. '
        )

        return renew_desc + non_mech_desc + self.dens_obj.get_description() + seed_desc

In [None]:
# Gamma generation time calculator shared with the JAX model.
dist = JaxGammaDens()

# JAX renewal model; keyword arguments make the meaning of each setting explicit.
j = JaxModel(
    population=33e6,
    window_len=50,
    n_times=276,
    run_in=30,
    n_process_periods=12,
    dens_obj=dist,
)

@jit
def get_inc_result(gen_time_mean, gen_time_sd, proc_vals, seed_peak):
    """Run the JAX renewal model and return its incidence series.

    Args:
        gen_time_mean: Mean of the generation time distribution
        gen_time_sd: Standard deviation of the generation time distribution
        proc_vals: Non-mechanistic process values
        seed_peak: Peak of the seeding curve on the natural scale

    Returns:
        The first column of the model output, i.e. incidence at each model time
    """
    # The model takes the seed on the log scale.
    log_seed = jnp.log(seed_peak)
    model_outputs = j.model_func(gen_time_mean, gen_time_sd, proc_vals, log_seed)
    return model_outputs[:, 0]

In [None]:
# Render the model's markdown description as the rich output of this cell.
Markdown(j.get_description())

In [None]:
# Incidence series from the jit-compiled JAX model, run with the reference parameter values.
incidence_jax = get_inc_result(gen_time_mean, gen_time_sd, proc_vals, seed_peak)

In [None]:
# Run the original (non-JAX) renewal model with the same settings and parameters,
# then plot both incidence series together to check that the two implementations agree.
m = RenewalModel(33e6, 276, 30, 12)
incidence_orig = m.func(
    gen_time_mean,
    gen_time_sd,
    proc_vals,
    np.log(seed_peak),
).incidence

incidence_comparison = pd.DataFrame({'original': incidence_orig, 'jax': incidence_jax})
incidence_comparison.plot()