In [None]:
from jax import numpy as jnp
from datetime import timedelta
from typing import List
import pandas as pd
import numpy as np
import scipy
from scipy.optimize import minimize

from summer2.parameters import Function, Data, Time, DerivedOutput

from aust_covid.inputs import get_ifrs
from emutools.tex import StandardTexDoc
from aust_covid.inputs import get_base_vacc_data
from aust_covid.vaccination import add_booster_data_to_vacc
from aust_covid.model import build_model
from aust_covid.tracking import track_immune_prop
from inputs.constants import PROJECT_PATH, SUPPLEMENT_PATH, IMMUNITY_LAG
from emutools.inputs import load_param_info

In [None]:
def calculate_rates_for_interval(
        start_props: pd.core.series.Series, end_props: pd.core.series.Series, delta_t: float, strata: List[str],
        active_flows: dict
) -> dict:
    """
    Calculate the transition rates associated with each inter-stratum flow for a given time interval.

    The system can be described using a system of linear ordinary differential equations:
    X'(t) = M.X(t) ,
    where M is the transition matrix and X is a column vector representing the stratum proportions over time.

    The solution of this equation is X(t) = exp(M.t).X_0,
    where X_0 represents the proportions at the start of the time interval.

    The transition parameters informing M must then verify the following equation:
    X(t_end) = exp(M.delta_t).X_0,
    where t_end represents the end of the time interval.
    The non-negative rates are found numerically by minimising the squared Euclidean norm of the residual
    exp(M.delta_t).X_0 - X(t_end).

    Args:
        start_props: user-requested stratum proportions at the start of the time interval, ordered as in strata
            (any one-dimensional array-like, e.g. a list or pandas Series)
        end_props: user-requested stratum proportions at the end of the time interval, ordered as in strata
        delta_t: width of the time interval
        strata: list of strata
        active_flows: Dictionary listing the flows driving the inter-stratum transitions. Keys are flow names and values
        are length-two sequences (tuples or lists) representing the flows' sources and destinations.
    Returns:
        The estimated transition rates stored in a dictionary using the flow names as keys
        (an empty dictionary if there are no flows).
    Raises:
        ValueError: if start_props or end_props does not have exactly one entry per stratum, or if a flow refers
            to a stratum that is not in strata.
    """
    n_strata = len(strata)
    ordered_flow_names = list(active_flows.keys())

    start_vec = np.asarray(start_props, dtype=float)
    end_vec = np.asarray(end_props, dtype=float)
    if start_vec.shape != (n_strata,) or end_vec.shape != (n_strata,):
        raise ValueError(
            f'Expected {n_strata} start and end proportions, got shapes {start_vec.shape} and {end_vec.shape}'
        )
    if not ordered_flow_names:
        return {}

    # Resolve each flow's (source, destination) strata to matrix indices once, up front.
    # Unpacking the ends (rather than comparing them with a tuple) means list-valued ends such as
    # ['nonimm', 'imm'] are handled the same as tuples - a list never compares equal to a tuple.
    stratum_index = {stratum: i for i, stratum in enumerate(strata)}
    flow_indices = []
    for f_name in ordered_flow_names:
        source, dest = active_flows[f_name]
        if source not in stratum_index or dest not in stratum_index:
            raise ValueError(f"Flow '{f_name}' refers to a stratum not in {strata}: {(source, dest)}")
        flow_indices.append((stratum_index[source], stratum_index[dest]))

    def build_transition_matrix(params):
        """Transition matrix M for the rates in params (ordered as ordered_flow_names)."""
        m = np.zeros((n_strata, n_strata))
        for rate, (i_source, i_dest) in zip(params, flow_indices):
            # The source stratum loses mass at this rate and the destination gains it, so each column sums to
            # zero and the total is conserved. Accumulating lets several flows share the same ends.
            m[i_source, i_source] -= rate
            m[i_dest, i_source] += rate
        return m

    def objective(params):
        """Squared norm of the mismatch between projected and requested end proportions."""
        exp_mt = scipy.linalg.expm(build_transition_matrix(params) * delta_t)
        diff = exp_mt @ start_vec - end_vec
        # The squared norm has the same minimisers as the norm but is smooth at a zero residual, which suits
        # TNC's gradient-based search (gradients are estimated by finite differences as no jac is supplied).
        return float(diff @ diff)

    # Bounds force the rates to be non-negative
    n_params = len(ordered_flow_names)
    bounds = [(0., None)] * n_params
    solution = minimize(objective, x0=np.zeros(n_params), bounds=bounds, method="TNC")

    return {f_name: solution.x[i] for i, f_name in enumerate(ordered_flow_names)}

In [None]:
def piecewise_constant(x, breakpoints, values):
    """
    Evaluate a step function at x.

    Args:
        x: point at which to evaluate the function (scalar, possibly a traced JAX value)
        breakpoints: array of points at which the function steps (assumed sorted ascending)
        values: function values, with one more entry than breakpoints: values[0] applies for
            x < breakpoints[0], values[k] for breakpoints[k - 1] <= x < breakpoints[k], and
            values[-1] for x >= breakpoints[-1]
    Returns:
        The entry of values for the segment containing x.
    """
    # Count the breakpoints at or before x with the array's own (NumPy or JAX) reduction, rather than the
    # builtin sum, which iterates element by element and unrolls into one op per breakpoint when traced.
    index = (x >= breakpoints).sum()
    return values[index]

In [None]:
# Base vaccination data, extended with booster coverage
vacc_df = add_booster_data_to_vacc(get_base_vacc_data())

# Proportion boosted in the preceding period, with missing values dropped and dates shifted forward
# by IMMUNITY_LAG days; where a shifted date repeats, only its first record is kept.
vacc_data = vacc_df['prop boosted in preceding'].dropna()
vacc_data = vacc_data.set_axis(vacc_data.index + timedelta(days=IMMUNITY_LAG))
vacc_data = vacc_data.loc[~vacc_data.index.duplicated(keep='first')]

In [None]:
app_doc = StandardTexDoc(SUPPLEMENT_PATH, 'supplement', "Australia's 2023 Omicron Waves Supplement", 'austcovid')
param_info = load_param_info(PROJECT_PATH / 'inputs' / 'parameters.yml')

# Update a copy of the values and assign it back, rather than calling .update() on param_info['value']
# directly: that is chained in-place modification, which has no effect on param_info under pandas Copy-on-Write.
param_values = param_info['value'].copy()
param_values.update(get_ifrs(app_doc))
param_info['value'] = param_values
parameters = param_info['value'].to_dict()

# Seed 'imm_prop' with the first (lag-shifted) data point. .iloc makes the positional intent explicit:
# integer keys passed to [] on a DatetimeIndex Series are a deprecated positional fallback.
parameters['imm_prop'] = vacc_data.iloc[0]

epi_model = build_model(app_doc)
epoch = epi_model.get_epoch()
track_immune_prop(epi_model)

In [None]:
# Fit constant vaccination and waning rates for each interval between consecutive data dates, so that a
# two-stratum immunity system moves from the observed immune proportion at the start of the interval
# to the observed proportion at its end.
IMMUNITY_STRATA = ['imm', 'nonimm']
# (source, destination) tuples, as documented by calculate_rates_for_interval
IMMUNITY_FLOWS = {'vaccination': ('nonimm', 'imm'), 'waning': ('imm', 'nonimm')}

rates = []
for date, next_date in zip(vacc_data.index[:-1], vacc_data.index[1:]):
    start_prop = vacc_data.loc[date]
    end_prop = vacc_data.loc[next_date]
    interval_rates = calculate_rates_for_interval(
        [start_prop, 1.0 - start_prop],
        [end_prop, 1.0 - end_prop],
        (next_date - date).days,
        IMMUNITY_STRATA,
        IMMUNITY_FLOWS,
    )
    rates.append(interval_rates)

# Build the frame once from the collected records (float dtype, indexed by interval start date)
# instead of enlarging an empty object-dtype frame row by row.
rates_df = pd.DataFrame(rates, index=vacc_data.index[:-1], columns=['vaccination', 'waning'])

In [None]:
# One list of fitted rates per process, in interval order
vacc_rates, wane_rates = (
    [interval_rates[process_name] for interval_rates in rates]
    for process_name in ('vaccination', 'waning')
)

In [None]:

# Model-time breakpoints at which the fitted rates change: one per (lag-shifted) data date
time_vals = Data(jnp.array([*epoch.datetime_to_number(vacc_data.index)]))

# piecewise_constant selects values[k], where k is the number of breakpoints at or before the current time,
# so each values array needs exactly one more entry than there are breakpoints. There is one fitted rate per
# interval between consecutive dates, so padding a zero on either side gives that length. It also makes the
# rate zero before the first data date and from the last one onwards. Check the invariant explicitly because
# JAX clamps out-of-range indices instead of raising.
assert len(rates_df) == len(vacc_data.index) - 1, 'Expected one fitted rate per interval between data dates'

functions = {}
for process in ['vaccination', 'waning']:
    vals = Data(jnp.array((0.0, *rates_df[process].astype(float), 0.0)))
    functions[process] = Function(piecewise_constant, [Time, time_vals, vals])

In [None]:
# Unique compartment names in first-seen order. Iterating a set of strings would make the order in which
# flows are added depend on PYTHONHASHSEED, and so vary between kernel sessions.
base_comp_names = list(dict.fromkeys(c.name for c in epi_model.compartments))

# (flow name, source immunity stratum, destination immunity stratum)
immunity_transitions = [
    ('vaccination', 'nonimm', 'imm'),
    ('waning', 'imm', 'nonimm'),
]

# Each transition moves people between immunity strata within the same compartment
for comp in base_comp_names:
    for flow_name, source_stratum, dest_stratum in immunity_transitions:
        epi_model.add_transition_flow(
            flow_name,
            functions[flow_name],
            source=comp,
            dest=comp,
            source_strata={'immunity': source_stratum},
            dest_strata={'immunity': dest_stratum},
        )

In [None]:
from plotly import graph_objects as go

epi_model.run(parameters=parameters)

# Request the plotly backend explicitly: the trace and layout calls below need a plotly Figure, whereas
# pandas' default plotting backend (matplotlib) would return an Axes.
fig = epi_model.get_derived_outputs_df()[['prop_imm', 'prop_nonimm']].plot.area(backend='plotly')
fig.data[0].line.width = 0
# Overlay the (lag-shifted) input data the rates were fitted to
fig.add_trace(go.Scatter(x=vacc_data.index, y=vacc_data, name='input data', line={'color': 'black', 'dash': 'dash'}))
fig.update_layout(xaxis_range=epoch.index_to_dti([epi_model.times[0], epi_model.times[-1]]), yaxis_range=[0.0, 1.0])