In [None]:
# Import this first, because it handles all our jax init functions etc
import summer2

In [None]:
import jax
from jax import lax, jit, numpy as jnp
import numpy as np
import pandas as pd

In [None]:
from emu_renewal.renew import RenewalModel

In [None]:
# Set some global values - these get consumed by our jax model function later.
# NOTE: model_func below is @jit-compiled and reads these as Python globals, so
# their values are captured as constants at trace time; re-define model_func
# (re-run its cell) after changing any of them.

population = 50000.0
n_times = 300
run_in = 10
n_process_periods = 4


# example parameter values for validation
gen_time_mean = 8.5
gen_time_sd = 3.0
proc_vals = np.array([0.1,0.2,0.3,0.4])
seed_peak = 10.0

# model_func builds n_process_periods knot times for the process curve and pairs
# them with these values, so the lengths must agree
assert len(proc_vals) == n_process_periods, "proc_vals needs one value per process period"

In [None]:
# Reference implementation from emu_renewal, built from the same global settings
# so its output can be compared against the hand-written jax model further down
m = RenewalModel(population, n_times, run_in, n_process_periods)

In [None]:
%%timeit
_ = m.func(gen_time_mean,gen_time_sd,proc_vals,seed_peak)

In [None]:
# Run the reference model once with the example parameters and keep its incidence
incidence_orig = m.func(gen_time_mean,gen_time_sd,proc_vals,seed_peak).incidence
# Trailing ';' suppresses the Axes repr so only the figure is shown
pd.Series(incidence_orig).plot(title="RenewalModel incidence", xlabel="time step", ylabel="incidence");

In [None]:
from typing import NamedTuple

# Maximum window length for gamma densities (ie width of sliding window of incidence history)
WLEN = 30

# Our state class, used by lax.scan
class RenewalState(NamedTuple):
    """Carry state threaded through ``lax.scan`` by the renewal model."""

    # Sliding window of the last WLEN incidence values, most recent first
    incidence: jnp.ndarray
    # Remaining susceptible population (a Python float initially, a JAX scalar once scan has run a step)
    suscept: float


In [None]:
from jax.scipy.stats.gamma import cdf
from jax import numpy as jnp, vmap, grad

In [None]:
# jax gamma density function

def get_densities(window_len, mean, sd):
    """Discretised gamma generation-interval weights.

    Converts a (mean, sd) parameterisation into gamma shape/scale and returns
    the probability mass falling in each unit interval (k, k+1] for
    k = 0 .. window_len - 1, as a length-``window_len`` JAX array.
    Mass beyond ``window_len`` is simply dropped (no renormalisation).
    """
    gamma_scale = (sd * sd) / mean
    gamma_shape = mean / gamma_scale
    bin_edges = jnp.arange(window_len + 1)
    cumulative = cdf(bin_edges, a=gamma_shape, scale=gamma_scale)
    # Successive CDF differences give the mass inside each unit bin
    return cumulative[1:] - cumulative[:-1]

In [None]:
from emu_renewal.process import cosine_multicurve, sinterp

In [None]:
def seed_func(t, seed):
    """Seeding curve evaluated at time ``t``.

    Built with ``cosine_multicurve`` over three knots, (0, 0), (run_in/2, seed)
    and (run_in, 0), after passing both knot axes through
    ``sinterp.get_scale_data``. Presumably the curve peaks at ``seed`` mid
    run-in; behaviour outside [0, run_in] is up to cosine_multicurve (confirm).
    """
    knot_times = jnp.array([0.0, run_in * 0.5, run_in])
    knot_heights = jnp.array([0.0, seed, 0.0])
    return cosine_multicurve(
        t,
        sinterp.get_scale_data(knot_times),
        sinterp.get_scale_data(knot_heights),
    )

@jit
def model_func(gen_time_mean, gen_time_sd, process_req, seed):
    """JAX re-implementation of the renewal model, stepped with ``lax.scan``.

    Args:
        gen_time_mean: Mean of the gamma generation-time distribution.
        gen_time_sd: Standard deviation of the generation-time distribution.
        process_req: Process-curve knot values. The interpolated curve is
            exponentiated before use, so these act on a log scale. Should have
            n_process_periods entries to match the knot times built below.
        seed: Peak value passed to ``seed_func``.

    Returns:
        Array of shape (n_times, 2): column 0 is new incidence at each time
        step, column 1 is the remaining susceptible population after that step.

    Note: population, n_times, n_process_periods and WLEN (and run_in, via
    seed_func) are read as globals and captured as constants when jit traces.
    """
    times = jnp.arange(n_times)
    densities = get_densities(WLEN, gen_time_mean, gen_time_sd)

    init_state = RenewalState(jnp.zeros(WLEN), population)

    # Time-varying transmission multiplier: exp of a cosine-interpolated curve
    xvals = sinterp.get_scale_data(jnp.linspace(0.0, n_times, n_process_periods))
    yvals = sinterp.get_scale_data(process_req)
    process_vals = jnp.exp(vmap(cosine_multicurve, in_axes=(0, None, None))(times, xvals, yvals))

    # Seeding depends only on t and seed, so evaluate it for every step up front
    # rather than rebuilding its interpolation data inside each scan iteration
    seed_vals = vmap(seed_func, in_axes=(0, None))(times, seed)

    def state_update(state: RenewalState, t) -> tuple[RenewalState, jnp.ndarray]:
        r_t = process_vals[t] * state.suscept / population
        # densities[k] weights the incidence from k+1 steps ago (window is most-recent first)
        renewal = (densities * state.incidence).sum() * r_t
        total_new_incidence = renewal + seed_vals[t]
        # Cannot infect more people than remain susceptible
        total_new_incidence = jnp.where(total_new_incidence > state.suscept, state.suscept, total_new_incidence)
        suscept = state.suscept - total_new_incidence
        # Shift the history window right by one; the newest value goes at the front
        incidence = jnp.roll(state.incidence, 1).at[0].set(total_new_incidence)
        return RenewalState(incidence, suscept), jnp.array([total_new_incidence, suscept])

    _final_state, outputs = lax.scan(state_update, init_state, times)

    return outputs

In [None]:
# model_func returns an (n_times, 2) array: column 0 is new incidence, column 1 is susceptibles
incidence_jax = model_func(gen_time_mean, gen_time_sd, proc_vals, seed_peak)[:,0]
# Overlay on the reference RenewalModel run so any divergence is visible at a glance
incidence_comparison = pd.DataFrame({
    "RenewalModel": pd.Series(incidence_orig),
    "model_func (jax)": pd.Series(incidence_jax),
})
incidence_comparison.plot(title="Incidence: reference vs jax reimplementation", xlabel="time step", ylabel="incidence");