In [None]:
import numpy as np
import pandas as pd
from datetime import datetime
from jax import numpy as jnp
import jax
from typing import List

pd.options.plotting.backend = "plotly"

from summer.utils import ref_times_to_dti

from summer2 import CompartmentalModel, Stratification
from summer2.parameters import Parameter, Function, Data

from autumn.projects.sm_sir.australia.northern_territory.project import get_ts_date_indexes
from autumn.core.project import load_timeseries
from autumn.settings.constants import COVID_BASE_DATETIME

In [None]:
# Placeholder data load: the choice of file should be revisited before this is used in earnest,
# but it brings in some time series that could serve as calibration targets.
timeseries_path = "../../../autumn/projects/sm_sir/malaysia/malaysia/timeseries.json"
ts_set = load_timeseries(timeseries_path)

# NOTE(review): this rebinds the COVID_BASE_DATETIME name already imported from
# autumn.settings.constants - confirm whether the local override is intentional.
COVID_BASE_DATETIME = datetime(year=2019, month=12, day=31)

In [None]:
def find_latent_transition_rate(
    base_comps: List[str], 
    latent_parameter: str,
) -> "Function":
    """
    Build the transition rate to apply to each latent compartment.

    The latent period parameter is divided equally between the compartments whose name
    contains "_exposed", and the reciprocal of that per-compartment duration is returned.

    Args:
        base_comps: Names of the model's compartments
        latent_parameter: Name of the model parameter holding the total latent period
    Returns:
        The per-compartment transition rate, as an expression built from the named Parameter
    Raises:
        ValueError: If no compartment name contains "_exposed"
    """
    n_exposed_comps = sum("_exposed" in comp for comp in base_comps)

    # Fail early: with no latent compartments the expression below would divide by zero
    if n_exposed_comps == 0:
        raise ValueError(
            f"No compartment containing '_exposed' found in {list(base_comps)}; cannot derive a latent transition rate"
        )

    latent_period = Parameter(latent_parameter) / n_exposed_comps
    return 1. / latent_period

In [None]:
def build_unstratified_model(settings: dict) -> CompartmentalModel:
    """
    Assemble the base compartmental model, with just enough structure to run and to
    produce a simple notifications output.

    The compartments are chained as
    susceptible -> early_exposed -> late_exposed -> infectious -> recovered,
    with "infectious" as the only infectious compartment.

    Args:
        settings: Fixed model settings; the keys read are "start_time", "end_time" and "cdr"
    Returns:
        The constructed model, with no stratification applied. It references the model parameters
        "total_population", "seed_size", "contact_rate", "latent_period" and "recovery_rate".
    """

    base_comps = ["susceptible", "early_exposed", "late_exposed", "infectious", "recovered"]
    time_span = [settings["start_time"], settings["end_time"]]
    model = CompartmentalModel(time_span, base_comps, ["infectious"], timestep=1.)

    # Starting population: the infectious seed, with the remainder of the population susceptible
    initial_population = {
        "susceptible": Parameter("total_population") - Parameter("seed_size"),
        "infectious": Parameter("seed_size"),
    }
    model.set_initial_population(initial_population)

    # Infection moves people from susceptible into the first latent compartment
    model.add_infection_frequency_flow("infection", Parameter("contact_rate"), "susceptible", "early_exposed")

    # Both latent compartments share the same exit rate; flows are added in a fixed order
    latent_rate = find_latent_transition_rate(base_comps, "latent_period")
    transitions = [
        ("early_progression", latent_rate, "early_exposed", "late_exposed"),
        ("late_progression", latent_rate, "late_exposed", "infectious"),
        ("recovery", Parameter("recovery_rate"), "infectious", "recovered"),
    ]
    for flow_name, rate, source, dest in transitions:
        model.add_transition_flow(flow_name, rate, source, dest)

    # Track the flow into the infectious compartment, then scale it to obtain notifications
    model.request_output_for_flow("progressions", "late_progression")

    def prop_detected(progressions):
        # settings["cdr"] is looked up each time the function is called
        return progressions * settings["cdr"]

    model.request_function_output("notifications", func=prop_detected, sources=["progressions"])

    return model

In [None]:
# Values for the named model parameters, supplied when the model is run
parameters = dict(
    contact_rate=1.2,
    recovery_rate=1.0,
    total_population=3e7,
    seed_size=10.,
    latent_period=1.,
)

# Settings consumed while the model is being built
settings = dict(
    start_time=400.,
    end_time=900.,
    cdr=0.1,
)

# Build the model from the settings, then run it with the parameter values
malaysia_model = build_unstratified_model(settings)
malaysia_model.run(parameters)

In [None]:
# Process the modelled notifications.
# Copy before re-indexing: a plain assignment would only alias `notifs`, so replacing the
# index would also overwrite the index of the series pulled from the model.
notifs = malaysia_model.get_derived_outputs_df()["notifications"]
notifs_dates = notifs.copy()
notifs_dates.index = ref_times_to_dti(COVID_BASE_DATETIME, notifs.index)

# Process the observed notifications
ts_set_dates = get_ts_date_indexes(ts_set, COVID_BASE_DATETIME)

# Collate modelled and observed series into one frame
comparison_outputs = pd.DataFrame(
    {
        "modelled": notifs_dates,
        "observed": ts_set_dates["notifications"],
    }
)

# Plot (the last expression of the cell is displayed as its output)
comparison_outputs.plot()

In [None]:
# Display the model's input parameters (presumably the Parameter names referenced during construction - confirm)
malaysia_model.get_input_parameters()

In [None]:
# Draw the model's `graph` attribute (presumably the model's computational graph - confirm)
malaysia_model.graph.draw()