In [1]:
# pip install the required packages if running in Colab
try:
    import google.colab
    IN_COLAB = True
    %pip install summerepi2==1.0.4
    %pip install estival==0.1.7
except:
    IN_COLAB = False

In [2]:
# Standard imports, plotting option and constant definition
from datetime import datetime, timedelta
from typing import List, Union
import pandas as pd
import plotly.express as px
import numpy as np
import pickle

from jax import numpy as jnp

from summer2.utils import ref_times_to_dti

from summer2 import CompartmentalModel, Stratification, StrainStratification, Overwrite
from summer2.parameters import Parameter, Function, DerivedOutput, Time, Data

# Use plotly for pandas .plot() calls (some later cells switch to matplotlib temporarily)
pd.options.plotting.backend = "plotly"
# Reference date: model times are integer day offsets from this date (also passed as ref_date)
COVID_BASE_DATE = datetime(2019, 12, 31)


In [3]:
# Get a function to access the Malaysia data if running in Colab
if IN_COLAB:
    !wget https://raw.githubusercontent.com/monash-emu/AuTuMN/master/notebooks/capacity_building/malaysia/get_mys_data.py

from notebooks.capacity_building.malaysia import get_mys_data 
# ... and use it to get the actual data
df = get_mys_data.fetch_mys_data()
initial_population = get_mys_data.get_initial_population("Malaysia")
observations = get_mys_data.get_target_observations(df, "Malaysia", "cases")
incidence_target = observations["cases_new"]
incidence_target.name = "incidence"
incidence_target_trimmed = incidence_target.loc["jul 2021": "nov 2021"]
start_date = datetime(2021, 3, 10)
end_date = start_date + timedelta(days=500)
start_date_int = (start_date - COVID_BASE_DATE).days
end_date_int = (end_date - COVID_BASE_DATE).days

# Model

## Define a model

In [4]:
# Compartment names of the base model, before any stratification is applied
unstratified_compartments = ["S", "I", "R"]

In [5]:
def build_unstratified_model() -> CompartmentalModel:
    """
    Build the base S-I-R compartmental model, before any stratification.

    The model runs from start_date_int to end_date_int (integer day offsets from
    COVID_BASE_DATE). The "I" compartment starts with the "infectious_seed" parameter and "S"
    starts with initial_population minus that seed.

    Returns:
        A compartmental model with infection and recovery flows and an "incidence" derived
        output, currently without stratification applied
    """
    seed = Parameter("infectious_seed")
    starting_distribution = {
        "S": initial_population - seed,
        "I": seed,
    }

    sir_model = CompartmentalModel(
        times=(start_date_int, end_date_int),
        compartments=unstratified_compartments,
        infectious_compartments=["I"],
        ref_date=COVID_BASE_DATE,
    )
    sir_model.set_initial_population(distribution=starting_distribution)

    # Susceptible people can get infected, with contact rate given by the "beta" parameter
    sir_model.add_infection_frequency_flow(
        name="infection",
        contact_rate=Parameter("beta"),
        source="S",
        dest="I",
    )

    # Infectious people recover at the reciprocal of the "infectious_period" parameter
    recovery_rate = 1. / Parameter("infectious_period")
    sir_model.add_transition_flow(
        name="recovery",
        fractional_rate=recovery_rate,
        source="I",
        dest="R",
    )

    # Track the infection flow as the "incidence" derived output
    sir_model.request_output_for_flow(name="incidence", flow_name="infection")

    return sir_model

In [6]:
def get_infectiousness_stratification(
    compartments_to_stratify: List[str],
    n_infectiousness_levels: int
) -> Stratification:
    """
    Build the "infectiousness" stratification, whose strata "spreader_0", "spreader_1", ...
    are intended to represent different levels of infectiousness.

    The "infection" flow is adjusted per stratum by the "prop_spreader_<i>" parameters. The
    infectiousness adjustment is 1 for "spreader_0"; each later stratum multiplies the previous
    stratum's value by the "alpha_<i>" parameter.

    Args:
        compartments_to_stratify: Compartments the stratification is applied to
        n_infectiousness_levels: Number of infectiousness strata to create
    Returns:
        A summer stratification object to represent infectiousness levels (not yet applied)
    """
    level_names = [f"spreader_{i}" for i in range(n_infectiousness_levels)]

    infectiousness_strat = Stratification(
        name="infectiousness",
        strata=level_names,
        compartments=compartments_to_stratify,
    )

    # Split the infection flow between the different spreader categories
    infectiousness_strat.set_flow_adjustments(
        "infection",
        {level: Parameter(f"prop_spreader_{i}") for i, level in enumerate(level_names)},
    )

    # Relative infectiousness is a running product: 1 for spreader_0, then multiplied by
    # alpha_i at each subsequent level
    running_infectiousness = 1.
    level_infectiousness = {}
    for i, level in enumerate(level_names):
        if i > 0:
            running_infectiousness *= Parameter(f"alpha_{i}")
        level_infectiousness[level] = running_infectiousness

    infectiousness_strat.add_infectiousness_adjustments(
        "I",
        adjustments=level_infectiousness,
    )

    return infectiousness_strat

In [7]:
def build_full_model(n_infectiousness_levels: int) -> CompartmentalModel:
    """
    Build the base model and apply the infectiousness stratification to all of its compartments.

    Args:
        n_infectiousness_levels: Number of infectiousness strata to create
    Returns:
        The stratified compartmental model
    """
    # Get an unstratified model object
    model = build_unstratified_model()

    # Get and apply the infectiousness stratification
    infectiousness_strat = get_infectiousness_stratification(
        model.compartments,
        n_infectiousness_levels,
    )
    model.stratify_with(infectiousness_strat)

    return model

In [8]:
# Build the model with a single infectiousness stratum ("spreader_0")
model = build_full_model(n_infectiousness_levels=1)

In [9]:
# Show the names of the parameters the model expects values for
model.get_input_parameters()

{'beta', 'infectious_period', 'infectious_seed', 'prop_spreader_0'}

In [10]:
# Baseline values, one per input parameter name reported by the model
parameters = {    
    'beta': .18, 
    'infectious_period': 7., 
    'infectious_seed': 100.,     
    'prop_spreader_0': 1.
}


In [11]:
# Run the model with the baseline parameters
model.run(parameters)

# Put the modelled incidence next to the trimmed reported series and plot both
comparison_df = pd.DataFrame({
    "modelled": model.get_derived_outputs_df()["incidence"],
    "reported": incidence_target_trimmed,
})
comparison_df.plot()

In [12]:
from estival import priors, targets
from estival.calibration.mcmc.adaptive import AdaptiveChain

In [16]:
# Calibration targets for the MCMC.
# NOTE(review): the recorded output of this cell is a TypeError ("Can't instantiate abstract
# class NormalTarget with abstract method get_evaluator"), so mcmc_targets is never defined and
# the later cells that use it cannot run. Presumably a different target class is needed with the
# pinned estival version -- TODO confirm against the estival API.
mcmc_targets = [
    targets.NormalTarget(incidence_target_trimmed, 0.1)  # priors.UniformPrior("notif_dispersion", (200.0, 2000.0)))
]

TypeError: Can't instantiate abstract class NormalTarget with abstract method get_evaluator

In [None]:
# Display the baseline parameter values
parameters

In [None]:
# Prior distributions for the parameters to calibrate.
# NOTE(review): apart from "infectious_period", these names are not among the input parameters
# reported by the model ('beta', 'infectious_period', 'infectious_seed', 'prop_spreader_0').
# They look like they belong to a different model -- confirm which parameters should be
# calibrated here and with what ranges.
mcmc_priors = [
    priors.TruncNormalPrior("contact_rate", 0.05, 0.01, [0.03, 0.07]),
    priors.UniformPrior("latent_period", [2.0, 5.0]),
    priors.UniformPrior("infectious_period", [3.0, 8.0]),
    priors.UniformPrior("cdr", [0.01, 0.2]),
    priors.UniformPrior("omicron_seed_start", [600, 700]),
    priors.UniformPrior("omicron_rel_transmissibility", [2.0, 6.0]),
]

In [None]:
# Make a copy of the parameters to use as init_params
# We can update this in place if we want a different starting point
# (dict.copy() is shallow, so reassigning entries of init_p leaves `parameters` untouched)
init_p = parameters.copy()

In [None]:
# Run a single adaptive MCMC chain.
iterations = 1000
# The sixth argument previously referenced an undefined `config`; build_full_model only accepts
# n_infectiousness_levels, so that is passed instead (presumably forwarded to the model builder
# as keyword arguments). The value 1 matches the model built earlier in the notebook.
mcmc = AdaptiveChain(build_full_model, parameters, mcmc_priors, mcmc_targets, init_p, {"n_infectiousness_levels": 1})
mcmc.run(max_iter=iterations)

In [None]:
import arviz as az
import random

In [None]:
# Some values to adjust to produce the desired outputs
burn_in_prop = 0.2  # Proportion of the chain's iterations treated as burn-in
sample_for_plot = 20  # Number of accepted iterations sampled for re-running and plotting

In [None]:
burn_in = round(burn_in_prop * iterations)  # Find the integer number of burn-in iterations

burnt_results = mcmc.results[burn_in:]  # Get the MCMC results after burn-in
accepted_mcmc = [res for res in burnt_results if res.accept]  # Extract the accepted iterations
# random.sample raises ValueError when asked for more items than are available, so cap the
# sample size at the number of accepted iterations
n_to_sample = min(sample_for_plot, len(accepted_mcmc))
mcmc_sample = random.sample(accepted_mcmc, n_to_sample)  # Choose a sample to run for plotting later

inf_data = mcmc.to_arviz(burn_in)  # Get the post-burn in chain in arviz format

In [None]:
# Re-run the model for each sampled parameter set and collect outputs for plotting.
# NOTE(review): the model built in this notebook only requests an "incidence" derived output,
# so the "notifications" and "prop_ever_infected" lookups below presumably raise KeyError --
# confirm which outputs are intended.
out_df = {}
recovered_df = {}
for i, r in enumerate(mcmc_sample):
    # Overlay the sampled values on a copy of the baseline parameters
    cur_params = parameters.copy()
    cur_params.update(r.parameters)
    model.run(cur_params)
    derived_out = model.get_derived_outputs_df()
    out_df[i] = derived_out["notifications"]
    recovered_df[i] = derived_out["prop_ever_infected"]

In [None]:
# Plot the sampled model runs with matplotlib and overlay the reported series as black dots.
# The reported series defined in this notebook is incidence_target (notifications_target is
# never defined). try/finally restores the plotly backend even if plotting fails.
pd.options.plotting.backend = "matplotlib"
try:
    ax = pd.DataFrame(out_df).plot(style='-', figsize=(15, 6), legend=False)
    # Draw the target on the same axes as the model runs
    incidence_target.loc["jul 2021": "may 2022"].plot(ax=ax, style='.', color="black")
finally:
    pd.options.plotting.backend = "plotly"

In [None]:
# Plot the collected recovered_df series with matplotlib, then switch back to plotly
pd.options.plotting.backend = "matplotlib"
ax = pd.DataFrame(recovered_df).plot(style='-', figsize=(15, 6), legend=False)
pd.options.plotting.backend = "plotly"

In [None]:
# Find the parameter set with the highest log likelihood obtained

best_ll = -np.inf
best_res = None

# Scan every iteration of the single chain (burn-in included), keeping the best seen so far
for r in mcmc.results:
    if r.ll > best_ll:
        best_ll = r.ll
        best_res = r
        
# Display the best parameters; if no iteration had ll above -inf, best_res is still None here
best_res.parameters

In [None]:
# Re-run the model at the maximum-likelihood parameter set and compare it with the data.
max_ll_params = parameters.copy()
max_ll_params.update(best_res.parameters)

model.run(max_ll_params)

# The model's only requested derived output is "incidence" and the reported series defined in
# this notebook is incidence_target_trimmed, matching the baseline comparison earlier on
# (the previous "notifications" / notifications_target names do not exist here)
comparison_df = pd.DataFrame({
    "modelled": model.get_derived_outputs_df()["incidence"],
    "reported": incidence_target_trimmed,
})
comparison_df.plot()

In [None]:
# Tabulate arviz summary statistics for the single chain after burn-in
az.summary(inf_data)

In [None]:
# Trace plots for the single chain (trailing semicolon suppresses the returned object's repr)
az.plot_trace(inf_data, figsize=(16, 19));

In [None]:
# Posterior plots for the single chain
az.plot_posterior(inf_data);

In [None]:
from estival.utils import to_df, to_arviz

In [None]:
# Convert the single chain's results to a DataFrame
caldf = to_df(mcmc)

In [None]:
# Plot log likelihood against iteration for the single chain
px.line(caldf, x="iteration", y="log_likelihood")

In [None]:
from concurrent.futures import ThreadPoolExecutor
from functools import partial

def run_chain(chain_idx, n_iter, rand_init=False, n_infectiousness_levels=1):
    """
    Build and run one adaptive MCMC chain.

    Args:
        chain_idx: Identifier given to the chain (passed as chain_id)
        n_iter: Value passed to the chain's run() as max_iter
        rand_init: If True, replace each calibrated parameter's starting value with one value
            drawn via the prior's rv.rvs; otherwise start from the baseline parameters
        n_infectiousness_levels: Keyword argument for build_full_model. This replaces the
            previous reference to an undefined `config`, which build_full_model does not accept
    Returns:
        The AdaptiveChain after it has been run
    """
    init_p = parameters.copy()
    if rand_init:
        for p in mcmc_priors:
            init_p[p.name] = p.rv.rvs(1)[0]

    # Presumably forwarded to build_full_model as keyword arguments -- confirm against estival
    build_model_kwargs = {"n_infectiousness_levels": n_infectiousness_levels}
    mcmc = AdaptiveChain(
        build_full_model,
        parameters,
        mcmc_priors,
        mcmc_targets,
        init_p,
        build_model_kwargs,
        chain_id=chain_idx,
    )
    mcmc.run(max_iter=n_iter)
    return mcmc



In [None]:
# Run several chains concurrently, one worker thread per chain
n_chains = 4

# Fix the per-chain settings so that only the chain index varies between calls.
# With rand_init=False every chain starts from the same baseline parameters.
run_func = partial(run_chain, rand_init=False,n_iter=2000)

with ThreadPoolExecutor(n_chains) as p:
    results = p.map(run_func,range(n_chains))

# Collect the mapped results into a list of finished chains
multichain_results = list(results)

In [None]:
def get_chained_var(mcdf, variable):
    """
    Arrange one variable of a multi-chain results table as one column per chain.

    Args:
        mcdf: DataFrame with "chain" and "iteration" columns plus the requested variable
        variable: Name of the column to extract
    Returns:
        DataFrame with one column per chain (in order of first appearance in mcdf), indexed by
        iteration. Values come from a pivot table, so duplicated (chain, iteration) rows are
        aggregated with the pivot table's default function.
    """
    by_chain_and_iteration = mcdf.pivot_table(index=["chain", "iteration"])
    variable_series = by_chain_and_iteration[variable]

    per_chain = {}
    for chain_id in mcdf["chain"].unique():
        # Selecting on the first index level leaves a series indexed by iteration
        per_chain[chain_id] = variable_series.loc[chain_id]

    return pd.DataFrame(per_chain)

In [None]:
# Convert all chains to one DataFrame (full_trace=True) and plot log likelihood per chain
mcdf = to_df(multichain_results, full_trace=True)
get_chained_var(mcdf, "log_likelihood").plot()

In [None]:
# Burnin matters - the convergence statistics can often be improved by discarding more data...
# Convert the chains to arviz format with burnin=500 and tabulate the summary statistics
inf_multichain = to_arviz(multichain_results, burnin=500)
az.summary(inf_multichain)

In [None]:
# Trace plots for the multi-chain results, with a legend and non-compact layout
az.plot_trace(inf_multichain, compact=False, legend=True, trace_kwargs={"alpha": 0.9}, figsize=(16, 19));

In [None]:
# Posterior plots for the multi-chain results
az.plot_posterior(inf_multichain);

In [None]:
# Pairwise scatter of the calibrated parameters, coloured by log posterior.
# The dimension names are read from the single-chain `mcmc` object's priors.
px.scatter_matrix(mcdf, dimensions=[p.name for p in mcmc.priors], color="log_posterior", height=1000)

In [None]:
# Log posterior against infectious_period, coloured by chain
px.scatter(mcdf, x="infectious_period", y="log_posterior", color="chain")

In [None]:
def resume_chain(mcmc, max_iter: int):
    """
    Call run() again on an existing chain object and return that same object.

    Args:
        mcmc: Chain object exposing run(max_iter=...)
        max_iter: Value passed through to run() as max_iter
    Returns:
        The chain object that was passed in
    """
    # NOTE(review): whether max_iter counts additional or total iterations is not visible
    # here -- confirm against the chain implementation
    mcmc.run(max_iter=max_iter)
    return mcmc

# Continue every existing chain concurrently, calling run(max_iter=2000) on each
resume_func = partial(resume_chain,max_iter=2000)

with ThreadPoolExecutor(n_chains) as p:
    results = p.map(resume_func,multichain_results)

# Replace the list of chains with the resumed chain objects
multichain_results = list(results)

In [None]:
# Rebuild the full-trace table from the resumed chains and plot log likelihood per chain
mcdf = to_df(multichain_results, full_trace=True)
get_chained_var(mcdf, "log_likelihood").plot()

In [None]:
# Rebuild the results table with burnin=1500 (no full_trace this time); this overwrites mcdf
mcdf = to_df(multichain_results, burnin=1500)

In [None]:
# Pairwise scatter of the calibrated parameters after the longer burn-in, coloured by
# log posterior (dimension names again come from the single-chain `mcmc` object's priors)
px.scatter_matrix(mcdf, dimensions=[p.name for p in mcmc.priors], color="log_posterior", height=1000)

In [None]:
# Log posterior against contact_rate, coloured by chain
px.scatter(mcdf, x="contact_rate", y="log_posterior", color="chain")

In [None]:
# Convert the resumed chains to arviz format with burnin=1500 and tabulate summary statistics
inf_multichain = to_arviz(multichain_results, burnin=1500)
az.summary(inf_multichain)

In [None]:
# Trace plots for the resumed multi-chain results
az.plot_trace(inf_multichain, compact=False, legend=True, trace_kwargs={"alpha": 0.9}, figsize=(16, 19));