In [None]:
from jax import numpy as jnp

from aust_covid.model import build_model
from inputs.constants import SUPPLEMENT_PATH

from summer2.parameters import Function, Data, Time

from emutools.tex import StandardTexDoc
from aust_covid.inputs import get_base_vacc_data
from aust_covid.vaccination import add_booster_data_to_vacc

In [None]:
from computegraph import jaxify

# NOTE(review): presumably the numpy-compatible module for the active backend
# (jax.numpy or numpy) - confirm against computegraph.jaxify.
fnp = jaxify.get_modules()['numpy']

if jaxify.get_using_jax():
    # Jax only
    from jax import lax

    def piecewise_function(x, breakpoints, functions):
        """Apply the member of ``functions`` selected by where ``x`` falls.

        Args:
            x: Scalar at which to evaluate.
            breakpoints: Array of interval boundaries; it is compared
                elementwise against ``x``.
            functions: Sequence of callables, each taking ``x``. One more
                entry than ``breakpoints`` is needed to cover every interval.

        Returns:
            The result of the selected callable applied to ``x``.
        """
        # Number of breakpoints at or below x == position of the active branch.
        # A single array reduction is used instead of the builtin sum(), which
        # would iterate over the array one element at a time.
        index = (x >= breakpoints).sum()
        # lax.switch performs the branch selection so that the index does not
        # have to be a concrete Python integer.
        return lax.switch(index, functions, x)

else:

    def piecewise_function(x, breakpoints, functions):
        """Apply the member of ``functions`` selected by where ``x`` falls.

        Args:
            x: Scalar at which to evaluate.
            breakpoints: Array of interval boundaries; it is compared
                elementwise against ``x``.
            functions: Sequence of callables, each taking ``x``. One more
                entry than ``breakpoints`` is needed to cover every interval.

        Returns:
            The result of the selected callable applied to ``x``.
        """
        # Number of breakpoints at or below x == position of the active branch.
        index = (x >= breakpoints).sum()
        return functions[index](x)


def piecewise_constant(x, breakpoints, values):
    """Return the entry of ``values`` for the interval that contains ``x``.

    The index is the number of breakpoints that are less than or equal to
    ``x``, so ``values[0]`` is returned when ``x`` is below every breakpoint
    and ``values[len(breakpoints)]`` when ``x`` is at or beyond the last one.

    Args:
        x: Scalar at which to evaluate.
        breakpoints: Array of interval boundaries; it is compared elementwise
            against ``x``.
        values: Indexable collection of outputs; entries ``0`` through
            ``len(breakpoints)`` can be selected.

    Returns:
        The selected element of ``values``.
    """
    # One array reduction rather than the builtin sum(), which would walk the
    # comparison result element by element.
    index = (x >= breakpoints).sum()
    return values[index]

In [None]:
# Load the base vaccination table and pass it straight through the booster helper.
vacc_df = add_booster_data_to_vacc(get_base_vacc_data())
# Keep only the rows that have a value in the booster-proportion column.
vacc_data = vacc_df.loc[:, 'prop boosted in preceding'].dropna()

In [None]:
# Visual check of the booster series. The title is taken from the series name so
# the figure is self-describing, and the trailing semicolon keeps the Axes repr
# out of the cell output.
vacc_data.plot(title=vacc_data.name);

In [None]:
# TeX document object that is handed to the model builder.
app_doc = StandardTexDoc(
    SUPPLEMENT_PATH,
    'supplement',
    "Australia's 2023 Omicron Waves Supplement",
    'austcovid',
)
epi_model = build_model(app_doc)
# The epoch is needed to convert the vaccination dates with datetime_to_number.
epoch = epi_model.get_epoch()

In [None]:
# NOTE(review): `data` repeats the selection already held in `vacc_data` and is
# not referenced in the visible cells; kept in case other cells rely on it.
data = vacc_df['prop boosted in preceding'].dropna()
# Breakpoints: the index of the booster series converted by the model epoch.
bp = Data(jnp.array(list(epoch.datetime_to_number(vacc_data.index))))
# Values: the series padded with a zero on each side.
# NOTE(review): this gives len(bp) + 2 entries, while piecewise_constant can only
# select indices 0..len(bp), so the trailing 0.0 looks unreachable - confirm
# whether the series is meant to fall back to zero after the last date.
vals = Data(jnp.array([0.0, *vacc_data, 0.0]))
Function(piecewise_constant, [Time, bp, vals])