# Bringing it all together

In this notebook we aim to combine the stratification, mixing and vaccination concepts learnt previously into a single model. The goal here is to take everything that we have learnt previously using "toy" code - with simplified stratifications/functionality loosely representative of Covid transmission - and extend it to a model that we can be satisfied reasonably replicates most of the important dynamics of Covid transmission in the Philippines.

A little like notebook number 04, this is an opportunity to take stock and apply the epidemiological principles we have worked through in the preceding notebooks.

## Standard preliminaries
Before we get into building the model, let's start off with some of our standard (or "boilerplate") code to get everything set up.

In [None]:
# pip install the required packages if running in Colab
try:
    import google.colab
    IN_COLAB = True
    %pip install summerepi2==1.0.4
    %pip install estival==0.1.3

except:
    IN_COLAB = False

In [None]:
# Standard imports, plotting option and constant definition
from datetime import datetime, timedelta
from typing import List, Union
import pandas as pd
import plotly.express as px
import numpy as np
import pickle
from jax import numpy as jnp

from summer2.utils import ref_times_to_dti

from summer2 import CompartmentalModel, Stratification, StrainStratification, Overwrite
from summer2.parameters import Parameter, Function, DerivedOutput, Time, Data

# Use plotly as the backend for pandas .plot() calls
pd.options.plotting.backend = "plotly"

# Reference date for the model: integer model times are counted in days from this date
COVID_BASE_DATE = datetime(2019, 12, 31)

In [None]:
# The data import module lives in a file on AuTuMN github - download it for colab use
# (together with the pickled mixing matrices, age populations and vaccination coverage)
if IN_COLAB:
    !wget https://raw.githubusercontent.com/monash-emu/AuTuMN/master/notebooks/capacity_building/philippines/import_phl_data.py
    !wget https://raw.githubusercontent.com/monash-emu/AuTuMN/master/notebooks/capacity_building/philippines/PHL_matrices.pkl
    !wget https://raw.githubusercontent.com/monash-emu/AuTuMN/master/notebooks/capacity_building/philippines/NCR_age_pops.pkl
    !wget https://raw.githubusercontent.com/monash-emu/AuTuMN/master/notebooks/capacity_building/philippines/NCR_vac_coverage.pkl

import import_phl_data
from import_phl_data import get_population_and_epi_data, get_timeseries_data

# Load the pickled mixing matrices; the model later uses the "all_locations" entry
mixing_matrix = pd.read_pickle("PHL_matrices.pkl", compression='infer')


In [None]:
# Shareable google drive links
PHL_DOH_LINK = "1ULjAmO7dE9YEEI8j7MWeSujRSRPhlvE1"  # sheet 05 daily report
PHL_FASSSTER_LINK = "1Cg_jsjhXsOtsqcMVUSHK6F7y9Ky8VxZL"  # Fassster google drive zip file
# Alternative loader that takes the two links above (currently disabled):
# initial_population, df = get_population_and_epi_data(PHL_DOH_LINK, PHL_FASSSTER_LINK) 
# Load the starting population size and the epidemiological time series
initial_population, df = get_timeseries_data() 

# We define a day zero for the analysis
# NOTE(review): this repeats the COVID_BASE_DATE assignment made in the imports cell
COVID_BASE_DATE = datetime(2019, 12, 31)

In [None]:
# Define a target set of observations to compare against our modelled outputs later
# (the "cases" column of the epidemiological data loaded above)
notifications_target = df["cases"]

In [None]:
# Starting ages of the age strata: 0, 5, ..., 75 (16 values)
age_groups = range(0, 80, 5)

In [None]:
# Load the age-specific population sizes, closing the file handle as soon as it has been read
with open("NCR_age_pops.pkl", "rb") as age_pops_file:
    age_pops = pickle.load(age_pops_file)

# Relabel the index with strings, so that the keys of the population split built
# from this object are strings too
age_pops.index = age_pops.index.map(str)

# Model

## Define a model

In [None]:
# Base compartments: susceptible (S), exposed (E), infectious (I), recovered (R)
# and susceptible again once natural immunity has waned (S2)
unstratified_compartments = ["S", "E", "I", "R", "S2"]

In [None]:
def build_unstratified_model(config: dict) -> CompartmentalModel:
    """
    Build the base compartmental model, with the flows and derived outputs needed
    to run it and produce notifications, but without any stratification.

    The starting population is taken from the notebook-level initial_population value:
    "infectious_seed" people start in I and everyone else starts in S.

    Args:
        config: Dictionary providing the "start_time" and "end_time" of the simulation
    Returns:
        A compartmental model currently without stratification applied
    """

    model = CompartmentalModel(
        times=(config["start_time"], config["end_time"]),
        compartments=unstratified_compartments,
        infectious_compartments=["I"],
        ref_date=COVID_BASE_DATE,
    )

    # Everyone starts susceptible, apart from the infectious seed
    seed = Parameter("infectious_seed")
    starting_distribution = {
        "S": initial_population - seed,
        "I": seed,
    }
    model.set_initial_population(distribution=starting_distribution)

    # Infection of never-infected people (S) and reinfection of people whose
    # natural immunity has waned (S2), both governed by the same contact rate
    infection_flows = (
        ("infection", "S"),
        ("reinfection", "S2"),
    )
    for flow_name, origin in infection_flows:
        model.add_infection_frequency_flow(
            name=flow_name,
            contact_rate=Parameter("contact_rate"),
            source=origin,
            dest="E",
        )

    # Exposed people become infectious, infectious people recover,
    # and recovered people lose their natural immunity
    transition_flows = (
        ("progression", "progression_rate", "E", "I"),
        ("recovery", "recovery_rate", "I", "R"),
        ("waning_natural_immunity", "waning_immunity_rate", "R", "S2"),
    )
    for flow_name, rate_name, origin, destination in transition_flows:
        model.add_transition_flow(
            name=flow_name,
            fractional_rate=Parameter(rate_name),
            source=origin,
            dest=destination,
        )

    # Add an infection-specific death flow to the I compartment
    model.add_death_flow(name="infection_death", death_rate=Parameter("death_rate"), source="I")

    # Track the progression flow (E to I) as a derived output
    model.request_output_for_flow("progressions", "progression")

    # Notifications are the progressions multiplied by the "cdr" parameter
    notified = Function(jnp.multiply, [DerivedOutput("progressions"), Parameter("cdr")])
    model.request_function_output("notifications", func=notified)

    return model

In [None]:
def get_age_stratification(
    compartments_to_stratify: List[str],
    strata: List[str],
    matrix: Union[np.ndarray, callable],
) -> Stratification:
    """
    Build the summer stratification object for age, splitting the requested compartments
    into the requested age bands and attaching the mixing matrix to it.

    The population split across the age bands is taken from the notebook-level
    age_pops object (each entry divided by the total).

    Args:
        compartments_to_stratify: List of the compartments to stratify, which should be all the compartments
        strata: The strata to be implemented in the age stratification
        matrix: The mixing matrix we are applying for the age structure
    Returns:
        A summer stratification object to represent age stratification (not yet applied)
    """

    # Shape checks are only possible when the matrix is supplied as an array
    if isinstance(matrix, np.ndarray):
        assert matrix.ndim == 2, "Mixing matrix is not 2-dimensional"

        n_rows, n_cols = matrix.shape
        n_strata = len(strata)
        msg = f"Dimensions of the mixing matrix incorrect: {n_rows}, {n_cols}, {n_strata}"
        assert n_rows == n_cols == n_strata, msg

    # The age groups are simply named by the strata values passed in
    age_strat = Stratification(name="age", strata=strata, compartments=compartments_to_stratify)

    # Distribute the starting population according to each age group's share of the total
    age_proportions = age_pops / age_pops.sum()
    age_strat.set_population_split(age_proportions.to_dict())

    # Attach the mixing matrix to the stratification
    age_strat.set_mixing_matrix(matrix)

    return age_strat

In [None]:
def get_strain_stratification(
    compartments_to_stratify: List[str],
) -> Stratification:
    """
    Build the summer strain stratification object, splitting the requested compartments
    into a "delta" and an "omicron" stratum.

    Args:
        compartments_to_stratify: List of the compartments to stratify
    Returns:
        A summer stratification object to represent strain stratification (not yet applied)
    """
    strain_strat = StrainStratification(
        name="strain",
        strata=["delta", "omicron"],
        compartments=compartments_to_stratify,
    )

    # The whole starting population of the stratified compartments is assigned to delta
    strain_strat.set_population_split({"delta": 1., "omicron": 0.})

    # Delta keeps the unadjusted rate; omicron is adjusted by its relative transmissibility
    for flow_name in ("infection", "reinfection"):
        adjustments = {
            "delta": None,
            "omicron": Parameter("omicron_rel_transmissibility"),
        }
        strain_strat.set_flow_adjustments(flow_name, adjustments)

    return strain_strat

In [None]:
# Load the two-dose vaccination coverage data, closing the file handle once it has been read
with open("NCR_vac_coverage.pkl", "rb") as coverage_file:
    full_dose_coverage = pickle.load(coverage_file)

# Plot the coverage as an area chart
full_dose_coverage.plot.area(title="two-dose vaccination coverage")

In [None]:
# To save on calculations a little, let's thin out the data
# (keep every 7th entry of the coverage data)
thinning_interval = 7
thinned_full_coverage = full_dose_coverage[::thinning_interval]

def get_prop_of_remaining_covered(old_prop, new_prop):
    """
    Express an increase in coverage as a proportion of those not previously covered.

    Args:
        old_prop: Coverage proportion at the start of the interval
        new_prop: Coverage proportion at the end of the interval
    Returns:
        The increase in coverage divided by the proportion that was uncovered at the start
    """
    newly_covered = new_prop - old_prop
    previously_uncovered = 1. - old_prop
    return newly_covered / previously_uncovered

# For each pair of consecutive thinned observations, work out the proportion of those
# not covered at the first observation who were covered by the second
interval_prop_unvacc_vaccinated = [
    get_prop_of_remaining_covered(
        thinned_full_coverage.iloc[i],
        thinned_full_coverage.iloc[i + 1],
    ) 
    for i in range(len(thinned_full_coverage) - 1)
]

# Index values (time points) of the thinned coverage observations
coverage_times = thinned_full_coverage.index

# Plot each interval's proportion against the time point that ends the interval
pd.Series(interval_prop_unvacc_vaccinated, index=coverage_times[1:]).plot(
    title="proportion of remaining unvaccinated vaccinated during each interval"
)

In [None]:
def get_rate_from_coverage_and_duration(coverage_increment: float, duration: float) -> float:
    """
    Find the constant per-unit-time rate that, applied over the given duration,
    moves the requested proportion of a population, i.e. the rate r satisfying
    1 - exp(-r * duration) = coverage_increment.

    Args:
        coverage_increment: Proportion covered over the interval, in [0, 1]
            (a value of exactly 1.0 gives an infinite rate)
        duration: Length of the interval, which must be strictly positive
    Returns:
        The rate, in units of one over the units of duration
    Raises:
        AssertionError: If duration is not positive or coverage_increment is outside [0, 1]
    """
    # A zero duration must be rejected as well as a negative one, because we divide by it below
    assert duration > 0.0, f"Duration request is not positive: {duration}"
    assert 0.0 <= coverage_increment <= 1.0, f"Coverage increment not in [0, 1]: {coverage_increment}"
    return -np.log(1.0 - coverage_increment) / duration


# Length of each interval between consecutive coverage time points
interval_lengths = [
    coverage_times[i + 1] - coverage_times[i] 
    for i in range(len(coverage_times) - 1)
]

# Convert each interval's proportion vaccinated into a rate, using that interval's length
vaccination_rates = [
    get_rate_from_coverage_and_duration(i, j) for 
    i, j in zip(interval_prop_unvacc_vaccinated, interval_lengths)
]
# Plot each rate against the time point that ends its interval
pd.Series(vaccination_rates, index=coverage_times[1:]).plot(kind="scatter")

In [None]:
from summer2.functions import get_piecewise_scalar_function

In [None]:
# Pad the endpoints with 0.0 - ie this is the vaccination rate outside of the bounds we have computed,
# and therefore what should be returned when requesting data outside this range.
# The padded array holds one more value than there are time points in coverage_times.
full_vacc_rates = np.concatenate([(0.0,), vaccination_rates, (0.0,)])

In [None]:
# Time-varying vaccination rate, built from the coverage time points as breakpoints and the
# padded rates as values (so that 0.0 applies outside the range of the coverage data)
vacc_rate_func = get_piecewise_scalar_function(coverage_times, full_vacc_rates)

In [None]:
def get_vaccine_stratification(
    compartments_to_stratify: List[str],
) -> Stratification:
    """
    Build the summer stratification object for vaccination status, splitting the
    requested compartments into a "vaccinated" and an "unvaccinated" stratum.

    Args:
        compartments_to_stratify: List of the compartments to stratify
    Returns:
        A summer stratification object to represent vaccination stratification (not yet applied)
    """
    vaccine_strat = Stratification(
        name="vaccination",
        strata=["vaccinated", "unvaccinated"],
        compartments=compartments_to_stratify,
    )

    # The whole starting population is unvaccinated
    vaccine_strat.set_population_split({"vaccinated": 0., "unvaccinated": 1.})

    # Each listed flow is scaled by one minus the relevant vaccine efficacy parameter
    # for the vaccinated, and left unadjusted for the unvaccinated
    protected_flows = (
        ("infection_death", "ve_death"),
        ("infection", "ve_infection"),
        ("reinfection", "ve_infection"),
    )
    for flow_name, efficacy_name in protected_flows:
        vaccine_strat.set_flow_adjustments(
            flow_name,
            {
                "unvaccinated": None,
                "vaccinated": 1. - Parameter(efficacy_name),
            },
        )

    return vaccine_strat

In [None]:
# Simulate 500 days starting from 15 May 2021
start_date = datetime(2021, 5, 15)
end_date = start_date + timedelta(days=500)
# Express both dates as integer numbers of days since COVID_BASE_DATE
start_date_int = (start_date - COVID_BASE_DATE).days
end_date_int = (end_date - COVID_BASE_DATE).days

# Time window handed to the model-building functions
config = {
    "start_time": start_date_int,
    "end_time": end_date_int,
}

def build_full_model(config):
    """
    Build the complete model: the unstratified model with the age, vaccination and strain
    stratifications applied, omicron seeding, time-varying vaccination flows and
    outputs tracking the vaccinated and unvaccinated populations.

    Args:
        config: Dictionary providing "start_time" and "end_time", passed on to
            build_unstratified_model
    Returns:
        The fully stratified compartmental model
    """

    # Get an unstratified model object
    model = build_unstratified_model(config)

    base_compartments = model.compartments

    # Get and apply the age stratification
    age_strat = get_age_stratification(
        base_compartments, 
        age_groups, 
        mixing_matrix["all_locations"],
    )
    model.stratify_with(age_strat)

    # Get and apply vaccination stratification
    # (vaccine efficacy is not set here: the stratification refers to the
    # "ve_death" and "ve_infection" parameters)
    vacc_strat = get_vaccine_stratification(
        base_compartments,
    )
    model.stratify_with(vacc_strat)

    # Get and apply the strain stratification
    strain_strat = get_strain_stratification(
        ["E", "I"],
    )
    model.stratify_with(strain_strat)

    # Compose a function returning the bounds of a window starting at "omicron_seed_start",
    # with a fixed length of 10

    # The basic function to get our bounds
    def get_window(start, length):
        return jnp.array((start,start+length))

    # Composed function
    time_window = Function(get_window, [Parameter("omicron_seed_start"), 10.0])

    # Define a piecewise function returning 1.0 during time_window, and 0.0 outside of this
    omicron_seed_func = get_piecewise_scalar_function(time_window, [0.0,1.0,0.0])

    # Import omicron infections into the exposed compartment during the seeding window
    model.add_importation_flow(
        "omicron_seeding",
        omicron_seed_func,
        "E",
        split_imports=True,
        dest_strata={"strain": "omicron"},
    )

    # Move people from unvaccinated to vaccinated within every base compartment,
    # at the time-varying vaccination rate
    for comp in unstratified_compartments:
        model.add_transition_flow(
            name="vaccination",
            fractional_rate=vacc_rate_func,
            source=comp,
            dest=comp,
            source_strata={"vaccination": "unvaccinated"},
            dest_strata={"vaccination": "vaccinated"},
        )

    # Track the size of the vaccinated and unvaccinated populations over time
    model.request_output_for_compartments(
        "vaccinated",
        unstratified_compartments,
        strata={"vaccination": "vaccinated"},
    )
    model.request_output_for_compartments(
        "unvaccinated",
        unstratified_compartments,
        strata={"vaccination": "unvaccinated"},
    )
    
    return model

In [None]:
# Build the fully stratified model for the time window defined in config
model = build_full_model(config)

In [None]:
# Display the model's input parameters
# (presumably the names that must be supplied to model.run - confirm against the summer2 docs)
model.get_input_parameters()

In [None]:
# Default value for every named Parameter used when building the model
parameters = {
    # Rates for the base model flows
    "contact_rate": 0.04,
    "progression_rate": 0.3,
    "recovery_rate": 0.2,
    "death_rate": 0.001,
    # Number of people starting in the I compartment
    "infectious_seed": 200.,
    # Multiplier applied to progressions to obtain notifications
    "cdr": 0.05,
    "waning_immunity_rate": 1. / 365.,
    # Vaccine efficacy against death and against infection
    "ve_death": 0.9,
    "ve_infection": 0.3,
    # Adjustment applied to the infection and reinfection flows for omicron
    "omicron_rel_transmissibility": 4.,
    # Model time at which the 10-unit omicron seeding window opens
    "omicron_seed_start": 650.0
}


In [None]:
# Run the model with the default parameters
model.run(parameters)

# Plot modelled notifications alongside the reported cases
comparison_df = pd.DataFrame({
    "modelled": model.get_derived_outputs_df()["notifications"],
    "reported": notifications_target,
})
comparison_df.plot()

In [None]:
from estival import priors, targets
from estival.calibration.mcmc.adaptive import AdaptiveChain

In [None]:
# Restrict the calibration target to July 2021 - March 2022
# (label-based slice; assumes notifications_target has a datetime index - TODO confirm)
notifications_target_trimmed = notifications_target.loc['jul 2021':"mar 2022"]

In [None]:
def get_dispersion_prior(name, data):
    """
    Build a uniform prior for the dispersion (standard deviation) parameter of a target,
    scaled to the largest value in the data.

    Args:
        name: Name of the target; the prior is named "<name>_dispersion"
        data: The observations whose maximum sets the scale of the prior
    Returns:
        A uniform prior running from half to double the central standard deviation
    """
    peak = np.max(data)
    # Central sd for which a ~95% gaussian interval (about 4 sd wide) spans a quarter of the peak
    central_sd = 0.25 * peak / 4.0
    bounds = (central_sd / 2.0, 2.0 * central_sd)

    return priors.UniformPrior(f"{name}_dispersion", bounds)

In [None]:
# Calibration target: a normal target on notifications over the trimmed window,
# with the dispersion given a uniform prior
# NOTE(review): the dispersion prior is scaled from the full (untrimmed) series - confirm this is intended
# NOTE(review): the next cell assigns mcmc_targets again, replacing this definition
mcmc_targets = [
    targets.NormalTarget("notifications", notifications_target_trimmed, get_dispersion_prior("notifications", notifications_target))
]

In [None]:
# Calibration target with a fixed dispersion of 10% of the standard deviation of the full series
# (this assignment replaces the mcmc_targets defined in the previous cell)
mcmc_targets = [
    targets.NormalTarget("notifications", notifications_target_trimmed, np.std(notifications_target) * 0.1)# priors.UniformPrior("notif_dispersion",(200.0,2000.0)))
]

In [None]:
# Uniform priors for the parameters being calibrated; parameters not listed here
# (e.g. "death_rate", "ve_death") have no prior
mcmc_priors = [
    priors.UniformPrior("contact_rate", [0.01,0.1]),
    priors.UniformPrior("progression_rate", [0.2,0.4]),
    priors.UniformPrior("recovery_rate", [0.15,0.25]),
    priors.UniformPrior("cdr", [0.01,0.2]),
    priors.UniformPrior("omicron_seed_start", [600,760]),
    priors.UniformPrior("omicron_rel_transmissibility", [2.0,6.0]),
]

In [None]:
# Display the current default parameters
parameters

In [None]:
# Starting point for the chain: a copy of the default parameters
init_p = parameters.copy()
# Starting value for the dispersion parameter (disabled; only relevant if a dispersion prior is used)
#init_p["notifications_dispersion"] = np.std(notifications_target)

In [None]:
# Set up the adaptive MCMC chain from the model builder, default parameters, priors, targets
# and starting point, with config forwarded as a build argument
# (argument roles inferred from the names passed - confirm against the estival API)
mcmc = AdaptiveChain(build_full_model, parameters, mcmc_priors, mcmc_targets, init_p,{"config": config}, fixed_proposal_steps=500)

In [None]:
# Run the chain, with max_iter set to 10000
mcmc.run(max_iter=10000)

In [None]:
import arviz as az

In [None]:
# Convert the chain results for use with arviz
# (the argument is presumably the number of burn-in iterations to discard - confirm against the estival docs)
inf_data = mcmc.to_arviz(1000)

In [None]:
# Re-run with the default parameters; the trailing comment shows how the parameters
# of a single stored MCMC result could be overlaid instead
model.run(parameters)# | mcmc.results[2316].parameters)

# Plot modelled notifications alongside the reported cases
comparison_df = pd.DataFrame({
    "modelled": model.get_derived_outputs_df()["notifications"],
    "reported": notifications_target,
})
comparison_df.plot()

In [None]:
# Re-run the model for every 100th stored MCMC result from index 2000 onwards,
# collecting the modelled notifications of each run (note that out_df is a dict, keyed by sample number)
out_df = {}
for i,r in enumerate(mcmc.results[2000::100]):
    # Overlay this result's parameters on the defaults (dict union operator, Python 3.9+)
    model.run(parameters | r.parameters)
    out_df[i] = model.get_derived_outputs_df()["notifications"]

In [None]:
# Temporarily switch to matplotlib to overlay the sampled runs and the reported data
pd.options.plotting.backend = "matplotlib"
try:
    ax = pd.DataFrame(out_df).plot(style='-',figsize=(16,12),legend=False)
    # Draw the reported notifications on the same axes as the sampled runs
    notifications_target.loc["jul 2021":].plot(ax=ax, style='.', color="black")
finally:
    # Restore plotly even if plotting fails, so later cells are unaffected
    pd.options.plotting.backend = "plotly"

In [None]:
# Initial values for the search for the stored result with the highest r.ll:
# any finite value exceeds -inf
best_ll = -np.inf
best_res = None

In [None]:
# Scan every stored MCMC result, keeping the one with the highest ll value
# (the strict comparison means the earliest result wins a tie)
for r in mcmc.results:
    if r.ll > best_ll:
        best_ll = r.ll
        best_res = r

In [None]:
# Run the model with the parameters of the best result overlaid on the defaults
model.run(parameters | best_res.parameters)

# Plot modelled notifications alongside the reported cases
comparison_df = pd.DataFrame({
    "modelled": model.get_derived_outputs_df()["notifications"],
    "reported": notifications_target,
})
comparison_df.plot()

In [None]:
# Display the arviz summary table for the converted chain results
az.summary(inf_data)

In [None]:
# Trace plots of the converted chain results
az.plot_trace(inf_data, figsize=(16,19))

In [None]:
# Posterior plots of the converted chain results
az.plot_posterior(inf_data)