# Bringing it all together

In this notebook we aim to combine the stratification, mixing and vaccination concepts learnt previously into a single model. Until now we have used "toy" code, with simplified stratifications and functionality only loosely representative of Covid transmission. The goal here is to extend this into a model that we can be reasonably satisfied replicates the most important dynamics of Covid transmission in Malaysia.

As with notebook 04, this is an opportunity to take stock and apply the epidemiological principles we have worked through in the preceding notebooks.

## Standard preliminaries
Before we get into building the model, let's start off with some of our standard (or "boilerplate") code to get everything set up.

In [None]:
# pip install the required packages if running in Colab
try:
    import google.colab
    IN_COLAB = True
    %pip install summerepi2==1.0.4
    %pip install estival==0.1.4

except:
    IN_COLAB = False

In [None]:
# Standard imports, plotting option and constant definition
from datetime import datetime, timedelta
import pickle
from typing import Callable, List, Optional, Sequence, Union

import numpy as np
import pandas as pd
import plotly.express as px
from jax import numpy as jnp

from summer2.utils import ref_times_to_dti
from summer2 import CompartmentalModel, Stratification, StrainStratification, Overwrite
from summer2.parameters import Parameter, Function, DerivedOutput, Time, Data

# Use plotly (rather than pandas' default matplotlib) as the backend for pandas .plot() calls
pd.options.plotting.backend = "plotly"

# Reference date for the model clock: model times are counted in days since this date
COVID_BASE_DATE = datetime(year=2019, month=12, day=31)

# Country whose data are loaded and fitted below
region = "Malaysia"

In [None]:
# Get pregenerated data
# In Colab the files are downloaded first; otherwise they are expected in the working directory.
# NOTE(review): unpickling can execute arbitrary code - only load these files from a trusted source.
if IN_COLAB:
    !wget https://raw.githubusercontent.com/monash-emu/AuTuMN/master/notebooks/capacity_building/malaysia/MYS_matrices.pkl
    !wget https://raw.githubusercontent.com/monash-emu/AuTuMN/master/notebooks/capacity_building/malaysia/MYS_age_pops.pkl
    !wget https://raw.githubusercontent.com/monash-emu/AuTuMN/master/notebooks/capacity_building/malaysia/MYS_vac_coverage.pkl

# Collection of mixing matrices (presumably keyed by contact location); the "all_locations"
# entry is the one used for age-structured mixing in the model
mixing_matrix = pd.read_pickle("MYS_matrices.pkl", compression='infer')
age_mixing_matrix = mixing_matrix["all_locations"]
# Uncomment the line below to visualise the all-locations mixing matrix:
# px.imshow(age_mixing_matrix)

In [None]:
# Get a function to access the Malaysia data if running in Colab
# (outside Colab, get_mys_data.py is expected to be importable from the working directory)
if IN_COLAB:
    !wget https://raw.githubusercontent.com/monash-emu/AuTuMN/master/notebooks/capacity_building/malaysia/get_mys_data.py

import get_mys_data

In [None]:
# Lower bounds of the sixteen 5-year age bands used for age stratification: 0, 5, ..., 75
age_groups = range(0, 80)[::5]

In [None]:
# Population of each age group for Malaysia, used to split the starting population by age.
# A context manager closes the file once loaded, unlike a bare open() passed to pickle.load.
# NOTE(review): unpickling can execute arbitrary code - only load this file from a trusted source.
with open("MYS_age_pops.pkl", "rb") as age_pops_file:
    age_pops = pickle.load(age_pops_file)

In [None]:
# ... and use it to get the actual data: the raw data frame, the initial population for the
# region, and the case observations for the region (get_mys_data is the helper imported above;
# its return types are not shown in this notebook)
df = get_mys_data.fetch_mys_data()
initial_population = get_mys_data.get_initial_population(region)
observations = get_mys_data.get_target_observations(df, region, "cases")

In [None]:
# Define a target set of observations to compare against our modelled outputs later:
# the "cases_new" series for notifications and the "deaths_new" series for deaths
# (deaths_target is not used in the cells shown in this notebook)
notifications_target = observations["cases_new"]
deaths_target = get_mys_data.get_target_observations(df, region, "deaths")["deaths_new"]

# Model

## Define a model

In [None]:
# Base compartments: susceptible, exposed, infectious, recovered, and susceptible again after waning
unstratified_compartments = "S E I R S2".split()

In [None]:
def build_unstratified_model(
    config: dict
) -> CompartmentalModel:
    """
    Create the unstratified compartmental model (compartments S, E, I, R and S2), with the
    flows and outputs needed to run it and produce meaningful outputs.

    The whole initial population (the module-level initial_population) starts in S, apart from
    Parameter("infectious_seed") people who start in I. People move S -> E -> I -> R -> S2, and
    people in S2 can be reinfected back to E. A proportion of people leaving I die instead of
    recovering.

    Args:
        config: Model configuration; must contain "start_time" and "end_time", the simulation
            start and end times in days since COVID_BASE_DATE (the model's ref_date)
    Returns:
        A compartmental model currently without stratification applied
    """

    model = CompartmentalModel(
        times=(config["start_time"], config["end_time"]),
        compartments=unstratified_compartments,
        infectious_compartments=["I"],
        ref_date=COVID_BASE_DATE
    )

    infectious_seed = Parameter("infectious_seed")

    # Everyone starts susceptible, apart from the infectious seed
    model.set_initial_population(
        distribution=
        {
            "S": initial_population - infectious_seed, 
            "I": infectious_seed,
        }
    )
    
    # Susceptible people can get infected
    model.add_infection_frequency_flow(
        name="infection", 
        contact_rate=Parameter("contact_rate"), 
        source="S",
        dest="E",
    )
    
    # People whose natural immunity has waned (S2) are reinfected at the same contact rate
    # as the fully susceptible
    model.add_infection_frequency_flow(
        name="reinfection", 
        contact_rate=Parameter("contact_rate"), 
        source="S2", 
        dest="E",
    )
    
    # Exposed people become infectious after an average of latent_period days
    model.add_transition_flow(
        name="progression",
        fractional_rate=1. / Parameter("latent_period"),
        source="E",
        dest="I",
    )
    
    # Proportion of people leaving I who survive
    recover_prop = 1. - Parameter("infection_fatality_prop")

    # Infectious people recover after some time spent infectious. Together with the death
    # flow below, the total rate of leaving I is 1 / infectious_period.
    model.add_transition_flow(
        name="recovery",
        fractional_rate=recover_prop / Parameter("infectious_period"),
        source="I",
        dest="R",
    )
    
    # Recovered people lose their immunity after an average of full_immunity_period days
    model.add_transition_flow(
        name="waning_natural_immunity",
        fractional_rate=1. / Parameter("full_immunity_period"),
        source="R",
        dest="S2",
    )

    # Add an infection-specific death flow to the I compartment, at the complement of the
    # recovery rate above (infection_fatality_prop / infectious_period)
    model.add_death_flow(
        name="infection_death", 
        death_rate=Parameter("infection_fatality_prop") / Parameter("infectious_period"), 
        source="I"
    )
    
    # Only a proportion (cdr) of new progressions to I are detected and notified as cases
    model.request_output_for_flow(
        "progressions",
        "progression",
    )
    model.request_function_output(
        "notifications",
        func=DerivedOutput("progressions") * Parameter("cdr")
    )
    
    # Track proportion ever infected: everyone no longer in S, divided by the population
    # summed over all compartments (so people who have died are excluded from both).
    # The two intermediate outputs are requested with save_results=False.
    model.request_output_for_compartments(
        "n_ever_infected",
        ("E", "I", "R", "S2"),
        save_results=False,
    )
    model.request_output_for_compartments(
        "total_pop",
        unstratified_compartments,
        save_results=False,
    )
    model.request_function_output(
        "prop_ever_infected",
        func=DerivedOutput("n_ever_infected") / DerivedOutput("total_pop"),
    )

    return model

In [None]:
def get_age_stratification(
    compartments_to_stratify: List[str],
    strata: Sequence,
    matrix: Union[np.ndarray, Callable],
    age_populations: Optional[pd.Series] = None,
) -> Stratification:
    """
    Create a summer stratification object that stratifies all of the compartments into
    strata, which are intended to represent age bands according to the user inputs.
    This adapts the model's age stratification to the format of the mixing matrix, so the
    strata should correspond one-to-one with the rows/columns of the matrix.

    Args:
        compartments_to_stratify: List of the compartments to stratify, which should be all the compartments
        strata: The strata to be implemented in the age stratification (here the lower bound of
            each age band, e.g. range(0, 80, 5))
        matrix: The mixing matrix we are applying for the age structure; when it is a numpy
            array it must be square with one row per stratum
        age_populations: Population of each age band, normalised here to give the initial
            population split. Defaults to the notebook-level ``age_pops``. Its index presumably
            has to match the strata - TODO confirm against summer's set_population_split.
    Returns:
        A summer stratification object to represent age stratification (not yet applied)
    """
    if isinstance(matrix, np.ndarray):
        msg = "Mixing matrix is not 2-dimensional"
        assert matrix.ndim == 2, msg

        msg = f"Dimensions of the mixing matrix incorrect: {matrix.shape[0]}, {matrix.shape[1]}, {len(strata)}"
        assert matrix.shape[0] == matrix.shape[1] == len(strata), msg

    # Create the stratification, just naming the age groups by their starting value
    strat = Stratification(
        name="age",
        strata=strata,
        compartments=compartments_to_stratify,
    )

    # Split the starting population in proportion to the age-specific populations. Falling back
    # to the notebook-level age_pops keeps existing three-argument calls behaving as before.
    if age_populations is None:
        age_populations = age_pops
    age_split_props = age_populations / age_populations.sum()
    strat.set_population_split(age_split_props.to_dict())

    # Add the mixing matrix to the stratification
    strat.set_mixing_matrix(matrix)

    return strat

In [None]:
def get_strain_stratification(
    compartments_to_stratify: List[str]
) -> Stratification:
    """
    Create a summer strain stratification with "delta" and "omicron" strata.

    The whole population of the stratified compartments is initially allocated to delta, so
    omicron has to be introduced separately (e.g. by an importation flow). The omicron strata
    of the "infection" and "reinfection" flows are adjusted by
    Parameter("omicron_rel_transmissibility"); delta flows are left unadjusted.

    Args:
        compartments_to_stratify: List of the compartments to stratify
    Returns:
        A summer stratification object to represent strain stratification (not yet applied)
    """
    strata = ["delta", "omicron"]
    strat = StrainStratification(name="strain", strata=strata, compartments=compartments_to_stratify)

    # At the start of the simulation, everyone in the stratified compartments has the delta strain
    strat.set_population_split(
        {
            "delta": 1.,
            "omicron": 0.,
        }
    )

    # Omicron is more transmissible than delta (the reference strain, hence None = no adjustment)
    for infection_flow in ["infection", "reinfection"]:
        strat.set_flow_adjustments(
            infection_flow,
            {
                "delta": None,
                "omicron": Parameter("omicron_rel_transmissibility"),
            },
        )

    return strat

In [None]:
# Two-dose vaccination coverage over time (proportion of the population), loaded from the
# pregenerated file. The context manager closes the file once it has been loaded.
# NOTE(review): unpickling can execute arbitrary code - only load this file from a trusted source.
with open("MYS_vac_coverage.pkl", "rb") as coverage_file:
    full_dose_coverage = pickle.load(coverage_file)
full_dose_coverage.plot.area(title="two-dose vaccination coverage")

In [None]:
# Thin the coverage data to every 7th value, which reduces the number of
# piecewise vaccination-rate intervals (and so the calculations) in the model
thinning_interval = 7
thinned_full_coverage = full_dose_coverage.iloc[::thinning_interval]

def get_prop_of_remaining_covered(old_prop, new_prop):
    """
    Proportion of those not yet covered at the start of an interval who become covered during it.

    Args:
        old_prop: Coverage (proportion of the whole population) at the start of the interval
        new_prop: Coverage at the end of the interval
    Returns:
        (new_prop - old_prop) / (1 - old_prop), or 0.0 if nobody remained uncovered
    """
    remaining = 1. - old_prop
    if remaining <= 0.:
        # Everyone was already covered: nobody is left to vaccinate, and the ratio
        # below would otherwise be a division by zero
        return 0.
    return (new_prop - old_prop) / remaining

# For each interval between consecutive (thinned) coverage values, the proportion of those still
# unvaccinated at the start of the interval who were vaccinated during it.
# Consecutive values are paired positionally from the underlying numpy array. Indexing the Series
# with an integer on a non-integer (date) index relies on a positional fallback that pandas deprecated.
interval_prop_unvacc_vaccinated = [
    get_prop_of_remaining_covered(old_prop, new_prop)
    for old_prop, new_prop in zip(
        thinned_full_coverage.to_numpy()[:-1],
        thinned_full_coverage.to_numpy()[1:],
    )
]

coverage_times = thinned_full_coverage.index

pd.Series(interval_prop_unvacc_vaccinated, index=coverage_times[1:]).plot(
    title="proportion of remaining unvaccinated vaccinated during each interval"
)

In [None]:
def get_rate_from_coverage_and_duration(coverage_increment: float, duration: float) -> float:
    """
    Convert the proportion covered over an interval into a constant transition rate.

    If a constant rate r moves people out of a group for `duration` time units, the proportion
    remaining is exp(-r * duration), so r = -ln(1 - coverage_increment) / duration.

    Args:
        coverage_increment: Proportion of the remaining population covered over the interval, in [0, 1]
        duration: Length of the interval (days in this notebook); must be positive
    Returns:
        The constant rate. numpy returns inf (with a divide-by-zero warning) when coverage_increment is 1.
    Raises:
        AssertionError: If duration is not positive or coverage_increment is outside [0, 1]
    """
    assert duration > 0.0, f"Duration request is not positive: {duration}"
    assert 0.0 <= coverage_increment <= 1.0, f"Coverage increment not in [0, 1]: {coverage_increment}"
    return -np.log(1.0 - coverage_increment) / duration

# Boundaries of the coverage intervals as integer days since COVID_BASE_DATE
vacc_times_int = (coverage_times - COVID_BASE_DATE).days

# Length of each interval between consecutive coverage times
interval_lengths = list(np.diff(vacc_times_int))

# Constant rate over each interval that vaccinates the observed proportion of the remaining unvaccinated
vaccination_rates = [
    get_rate_from_coverage_and_duration(prop_vaccinated, length)
    for prop_vaccinated, length in zip(interval_prop_unvacc_vaccinated, interval_lengths)
]
pd.Series(vaccination_rates, index=coverage_times[1:]).plot(kind="scatter")

In [None]:
def get_vacc_rate_func(end_times, vaccination_rates):
    """
    Build a step function of time returning the vaccination rate in force at that time.

    Args:
        end_times: End time of each interval; vaccination_rates[i] applies until end_times[i]
        vaccination_rates: Rate for each interval
    Returns:
        A function (time, derived_outputs) -> rate. It returns the rate at the position of the
        first entry of end_times (in list order) strictly after `time`, or 0.0 if there is none.
        derived_outputs is accepted but not used.
    """
    def get_vaccination_rate(time, derived_outputs):
        # Position of the first interval that has not yet ended at `time`, if any
        current_idx = next(
            (idx for idx, interval_end in enumerate(end_times) if interval_end > time),
            None,
        )
        if current_idx is None:
            # Every interval has ended, so no further vaccination
            return 0.0
        return vaccination_rates[current_idx]

    return get_vaccination_rate

# Pure-Python vaccination rate function; the next cell replaces it with a summer piecewise version.
# vaccination_rates[i] applies over the interval ending at vacc_times_int[i + 1], so the interval
# *end* times are all boundaries except the first. Passing every boundary would apply each rate one
# interval early and index past the end of vaccination_rates during the final interval.
# NOTE(review): before the first coverage time this still returns vaccination_rates[0], whereas
# the piecewise version returns 0.0 there.
vacc_rate_func = get_vacc_rate_func(vacc_times_int[1:], vaccination_rates)

In [None]:
from summer2.functions import get_piecewise_scalar_function

# Outside the range covered by the computed intervals (before the first and after the last
# coverage time) the vaccination rate is zero, so pad the rates with 0.0 at both ends
full_vacc_rates = np.array([0.0, *vaccination_rates, 0.0])

# Breakpoints of the piecewise function, as days since COVID_BASE_DATE
coverage_time_idx = (coverage_times - COVID_BASE_DATE).days

# Piecewise-constant vaccination rate usable inside the summer model
vacc_rate_func = get_piecewise_scalar_function(coverage_time_idx, full_vacc_rates)

In [None]:
def get_vaccine_stratification(
    compartments_to_stratify: List[str]
) -> Stratification:
    """
    Create a summer stratification into "vaccinated" and "unvaccinated" strata.

    Everyone starts unvaccinated (vaccination flows are added separately). For the vaccinated
    stratum, the "infection" and "reinfection" flows are adjusted by 1 - ve_infection and the
    "infection_death" flow by 1 - ve_death; the unvaccinated stratum is left unadjusted.

    Args:
        compartments_to_stratify: List of the compartments to stratify
    Returns:
        A summer stratification object to represent vaccination stratification (not yet applied)
    """
    strata = ["vaccinated", "unvaccinated"]
    
    # Create the stratification
    vaccine_strat = Stratification(name="vaccination", strata=strata, compartments=compartments_to_stratify)

    # Create our population split dictionary, whose keys match the strata:
    # nobody is vaccinated at the start of the simulation
    pop_split = {
        "vaccinated": 0., 
        "unvaccinated": 1.,
    }

    # Set a population distribution
    vaccine_strat.set_population_split(pop_split)

    # Adjusting the death risk associated with vaccination
    vaccine_strat.set_flow_adjustments(
        "infection_death",
        {
            "unvaccinated": None,
            "vaccinated": 1. - Parameter("ve_death"),
        }
    )
    
    # Susceptibility: vaccination reduces the rate of infection and reinfection
    for infection_flow in ["infection", "reinfection"]:
        vaccine_strat.set_flow_adjustments(
            infection_flow,
            {
                "unvaccinated": None,
                "vaccinated": 1. - Parameter("ve_infection"),
            }
        )

    return vaccine_strat

In [None]:
# Define some helper functions for seeding multiple strains

# Basic function to get windowed bounds for seeding
def get_window(start, length):
    """Return the (start, start + length) bounds of a seeding window as a jax array."""
    end = start + length
    return jnp.array([start, end])

def get_seed_function(start, length):
    """
    Build a time-varying seeding rate: 1.0 between `start` and `start + length`, 0.0 outside
    (exact boundary handling is determined by get_piecewise_scalar_function).

    Args:
        start: Window start time (may be a model Parameter)
        length: Window length
    Returns:
        The object returned by get_piecewise_scalar_function, for use as a flow rate
    """
    # Because this runs 'inside the model' (and may be parameterised), the window bounds must be
    # a summer Function object rather than a value computed eagerly in plain Python
    window = Function(get_window, [start, length])
    rate_before_during_after = [0.0, 1.0, 0.0]
    return get_piecewise_scalar_function(window, rate_before_during_after)


In [None]:
# Simulation window: 500 days starting on 10 March 2021
start_date = datetime(2021, 3, 10)
end_date = start_date + timedelta(days=500)

# The model clock counts days since COVID_BASE_DATE
start_date_int, end_date_int = [(date - COVID_BASE_DATE).days for date in (start_date, end_date)]

config = dict(start_time=start_date_int, end_time=end_date_int)

def build_full_model(config):
    """
    Build the full model. It is the unstratified model (compartments S, E, I, R, S2) stratified
    by age, by vaccination status and, for E and I only, by strain. Omicron seeding, vaccination
    flows and vaccinated/unvaccinated population outputs are added on top.

    Args:
        config: Dictionary with "start_time" and "end_time" (days since COVID_BASE_DATE),
            passed to build_unstratified_model
    Returns:
        The stratified compartmental model (not yet run)
    """

    # Get an unstratified model object
    model = build_unstratified_model(config)

    # The unstratified compartments, kept for the vaccination stratification below
    base_compartments = model.compartments

    # Get and apply the age stratification, reusing the age bands (age_groups, i.e. 0, 5, ..., 75)
    # and the all-locations mixing matrix defined earlier rather than repeating them here
    age_strat = get_age_stratification(
        model.compartments,
        age_groups,
        age_mixing_matrix,
    )
    model.stratify_with(age_strat)

    # Get and apply the vaccination stratification
    vacc_strat = get_vaccine_stratification(
        base_compartments,
    )
    model.stratify_with(vacc_strat)

    # Get and apply the strain stratification (only the exposed and infectious compartments)
    strain_strat = get_strain_stratification(
        ["E", "I"]
    )
    model.stratify_with(strain_strat)

    # Seed Omicron into E at a rate of 1.0 over a 10-day window starting at omicron_seed_start
    omicron_seed_func = get_seed_function(Parameter("omicron_seed_start"), 10.0)

    model.add_importation_flow(
        "omicron_seeding",
        omicron_seed_func,
        "E",
        split_imports=True,
        dest_strata={"strain": "omicron"},
    )

    # Vaccinate people within each compartment (unvaccinated -> vaccinated stratum) at the
    # time-varying rate given by vacc_rate_func
    for comp in unstratified_compartments:
        model.add_transition_flow(
            name="vaccination",
            fractional_rate=vacc_rate_func,
            source=comp,
            dest=comp,
            source_strata={"vaccination": "unvaccinated"},
            dest_strata={"vaccination": "vaccinated"},
        )

    # Population sizes by vaccination status
    model.request_output_for_compartments(
        "vaccinated",
        unstratified_compartments,
        strata={"vaccination": "vaccinated"},
    )
    model.request_output_for_compartments(
        "unvaccinated",
        unstratified_compartments,
        strata={"vaccination": "unvaccinated"},
    )

    return model

In [None]:
# Build the full model once with the configuration above (it is re-run with different parameters below)
model = build_full_model(config)

In [None]:
# Display the model's input parameters (presumably the names the parameters dictionary below must supply)
model.get_input_parameters()

In [None]:
# Starting parameter values. These also serve as the fixed values for any parameter that is not
# calibrated below. Periods are in days (the model time unit); omicron_seed_start is day 690 after
# COVID_BASE_DATE, i.e. 20 November 2021.
parameters = dict(
    cdr=0.1,
    contact_rate=0.05,
    full_immunity_period=180.0,
    infection_fatality_prop=0.005,
    infectious_period=5.0,
    infectious_seed=200.0,
    latent_period=3.0,
    omicron_rel_transmissibility=3.0,
    omicron_seed_start=690.0,
    ve_death=0.9,
    ve_infection=0.3,
)


In [None]:
# Run the model with the starting parameters and compare modelled notifications against reported cases
model.run(parameters)

comparison_df = pd.DataFrame({
    "modelled": model.get_derived_outputs_df()["notifications"],
    "reported": notifications_target,
})
comparison_df.plot()

In [None]:
from estival import priors, targets
from estival.calibration.mcmc.adaptive import AdaptiveChain

In [None]:
# Calibrate only against notifications from July 2021 to the end of May 2022
# (.loc partial-string date slicing includes both end months; assumes a DatetimeIndex)
notifications_target_trimmed = notifications_target.loc["jul 2021": "may 2022"]

In [None]:
# Normal likelihood for modelled notifications over the trimmed window. The fixed standard deviation
# is 10% of the standard deviation of the *full* (untrimmed) notifications series.
# NOTE(review): confirm the untrimmed series is intended for the SD.
# An alternative is to calibrate the dispersion instead, e.g. with
# priors.UniformPrior("notif_dispersion", (200.0, 2000.0)).
mcmc_targets = [
    targets.NormalTarget("notifications", notifications_target_trimmed, np.std(notifications_target) * 0.1)
]

In [None]:
# Display the current parameter values, for reference when setting the priors below
parameters

In [None]:
# Uniform priors on the calibrated parameters, as {parameter name: [lower bound, upper bound]}
mcmc_priors = [
    priors.UniformPrior(param_name, bounds)
    for param_name, bounds in {
        "contact_rate": [0.01, 0.1],
        "latent_period": [0.2, 5.0],
        "infectious_period": [3.0, 8.0],
        "cdr": [0.01, 0.2],
        "omicron_seed_start": [650, 750],
        "omicron_rel_transmissibility": [2.0, 6.0],
    }.items()
]

In [None]:
# Starting point for the chain: a shallow copy of the parameters, so it can be edited
# in place (for a different starting point) without changing `parameters` itself
init_p = dict(parameters)

In [None]:
# Run a single adaptive MCMC chain from init_p. The dict argument presumably supplies keyword
# arguments for build_full_model - confirm against estival's AdaptiveChain.
iterations = 1000
mcmc = AdaptiveChain(build_full_model, parameters, mcmc_priors, mcmc_targets, init_p, {"config": config}, fixed_proposal_steps=500)
mcmc.run(max_iter=iterations)

In [None]:
import arviz as az
import random

In [None]:
# Settings for post-processing: the fraction of iterations discarded as burn-in,
# and how many accepted iterations to re-run for plotting
burn_in_prop, sample_for_plot = 0.2, 20

In [None]:
burn_in = round(burn_in_prop * iterations)  # Find the integer number of burn-in iterations

burnt_results = mcmc.results[burn_in:]  # Get the MCMC results after burn-in
accepted_mcmc = [res for res in burnt_results if res.accept]  # Extract the accepted iterations

# Choose a sample of accepted iterations to run for plotting later. random.sample raises
# ValueError if asked for more items than exist, so cap the sample at the number accepted.
mcmc_sample = random.sample(accepted_mcmc, min(sample_for_plot, len(accepted_mcmc)))

inf_data = mcmc.to_arviz(burn_in)  # Get the post-burn in chain in arviz format

In [None]:
# Re-run the model for each sampled accepted iteration (fixed parameters overridden by the
# iteration's calibrated values). Store the notifications and the proportion ever infected,
# keyed by sample number. Despite its name, recovered_df holds prop_ever_infected.
out_df = {}
recovered_df = {}
for i, r in enumerate(mcmc_sample):
    cur_params = parameters.copy()
    cur_params.update(r.parameters)
    model.run(cur_params)
    derived_out = model.get_derived_outputs_df()
    out_df[i] = derived_out["notifications"]
    recovered_df[i] = derived_out["prop_ever_infected"]

In [None]:
# Overlay the sampled model runs (lines) on the reported notifications (black dots), using
# matplotlib for this plot. option_context restores the previous (plotly) backend afterwards,
# even if plotting raises; a manual set-then-reset would leave matplotlib active in that case.
with pd.option_context("plotting.backend", "matplotlib"):
    ax = pd.DataFrame(out_df).plot(style='-', figsize=(15, 6), legend=False)
    notifications_target.loc["aug 2021": "mar 2022"].plot(style='.', color="black", ax=ax)

In [None]:
# Proportion of the population ever infected in each sampled run, plotted with matplotlib.
# option_context restores the previous plotting backend afterwards, even if plotting raises.
with pd.option_context("plotting.backend", "matplotlib"):
    ax = pd.DataFrame(recovered_df).plot(style='-', figsize=(15, 6), legend=False)

In [None]:
# Find the parameter set with the highest log likelihood obtained
# (strict > keeps the earliest of any tied maxima and skips NaN log likelihoods)

best_ll = -np.inf
best_res = None

for r in mcmc.results:
    if r.ll > best_ll:
        best_ll = r.ll
        best_res = r

# With no results, or none with a log likelihood above -inf, fail with a clear message
# rather than an AttributeError on None
if best_res is None:
    raise RuntimeError("No MCMC iteration had a log likelihood greater than -inf")

best_res.parameters

In [None]:
# Re-run the model with the maximum-likelihood parameters (fixed values overridden by the best
# calibrated values) and compare modelled notifications against reported cases
max_ll_params = parameters.copy()
max_ll_params.update(best_res.parameters)

model.run(max_ll_params)

comparison_df = pd.DataFrame({
    "modelled": model.get_derived_outputs_df()["notifications"],
    "reported": notifications_target,
})
comparison_df.plot()

In [None]:
# Summary statistics of the post-burn-in samples of the single chain
az.summary(inf_data)

In [None]:
# Trace plots of the post-burn-in samples of the single chain
az.plot_trace(inf_data, figsize=(16, 19));

In [None]:
# Posterior distributions of the calibrated parameters from the single chain
az.plot_posterior(inf_data);

In [None]:
from concurrent.futures import ThreadPoolExecutor
from functools import partial

def run_chain(chain_idx, n_iter, rand_init=False, fixed_proposal_steps=500):
    """
    Run one adaptive MCMC chain for the full model.

    Args:
        chain_idx: Identifier passed to the chain as chain_id
        n_iter: Maximum number of iterations to run
        rand_init: If True, start each calibrated parameter from a random draw from its prior
            rather than from its value in `parameters`
        fixed_proposal_steps: Passed through to AdaptiveChain (default 500, as in the single-chain run)
    Returns:
        The AdaptiveChain object after running
    """
    init_p = parameters.copy()
    if rand_init:
        for p in mcmc_priors:
            # One draw from the prior (p.rv presumably a frozen scipy distribution - verify)
            init_p[p.name] = p.rv.rvs(1)[0]

    mcmc = AdaptiveChain(
        build_full_model,
        parameters,
        mcmc_priors,
        mcmc_targets,
        init_p,
        {"config": config},
        chain_id=chain_idx,
        fixed_proposal_steps=fixed_proposal_steps,
    )
    mcmc.run(max_iter=n_iter)
    return mcmc



In [None]:
# Run several chains concurrently, each from a random starting point drawn from the priors, with
# the same number of iterations as the single chain above.
# NOTE(review): threads only speed this up where the model computation releases the GIL.
n_chains = 4

run_func = partial(run_chain, rand_init=True, n_iter=iterations)

with ThreadPoolExecutor(n_chains) as p:
    results = p.map(run_func, range(n_chains))
    # Collect the finished chains; Executor.map re-raises any exception from a chain here
    multichain_results = list(results)

In [None]:
from estival.utils import to_arviz

# Convert the chains to arviz format, discarding the same burn-in as for the single chain
# (burn_in rather than a repeated literal, so both analyses stay in sync)
inf_multichain = to_arviz(multichain_results, burn_in)


In [None]:
# Summary statistics across all chains (post-burn-in)
az.summary(inf_multichain)

In [None]:
# Trace plots for all chains, with compact=False and a legend to distinguish the chains
az.plot_trace(inf_multichain, compact=False, legend=True, figsize=(16, 19));

In [None]:
# Posterior distributions of the calibrated parameters pooled across chains
az.plot_posterior(inf_multichain);