In [None]:
import numpy as np
import pandas as pd
from jax import numpy as jnp
import jax
from typing import List

pd.options.plotting.backend = "plotly"

from summer2 import CompartmentalModel, Stratification
from summer2.parameters import Parameter, Function, Data

In [None]:
def find_latent_transition_rate(
    base_comps: List[str], 
    latent_parameter: str,
) -> Parameter:
    """
    Build the transition rate for progressing through each sequential exposed compartment.

    The total latent period (the model parameter named ``latent_parameter``) is split evenly
    across every compartment whose name contains ``"_exposed"``. The returned rate is therefore
    ``n_exposed / latent_period``. Applying it to each transition in the exposed chain makes the
    chain as a whole last ``latent_period`` on average.

    Args:
        base_comps: Names of the model's base compartments
        latent_parameter: Name of the model parameter holding the total latent period
    Returns:
        An expression built by arithmetic on ``Parameter(latent_parameter)``, giving the rate
        of transition out of each exposed compartment
    Raises:
        ValueError: If no name in ``base_comps`` contains ``"_exposed"``
    """
    n_exposed_comps = sum("_exposed" in comp for comp in base_comps)
    if n_exposed_comps == 0:
        # Without this guard the latent period would be divided by zero inside the
        # parameter expression instead of failing clearly at model-construction time.
        raise ValueError(
            f"No exposed compartments (names containing '_exposed') found in {base_comps}"
        )
    latent_period = Parameter(latent_parameter) / n_exposed_comps
    return 1. / latent_period

In [None]:
def build_unstratified_model(settings: dict) -> CompartmentalModel:
    """
    Create a compartmental model, with the minimal compartmental structure needed to run and produce some sort of 
    meaningful outputs.

    The structure is SEEIR:
    - susceptible people are infected into an early exposed compartment;
    - they progress through two sequential exposed compartments;
    - they become infectious and then recover.

    Parameter values are not bound here. The model refers to named ``Parameter`` placeholders
    ("total_population", "seed_size", "contact_rate", "latent_period", "recovery_rate") whose
    values must be supplied when the model is run.

    Args:
        settings: Model construction settings. Requires "start_time" and "end_time" (the
            simulation time bounds). Optionally accepts "timestep", which is passed to the model
            as its ``timestep`` argument (default 0.1).
    Returns:
        A compartmental model currently without stratification applied
    """

    base_comps = [
        "susceptible",
        "early_exposed",
        "late_exposed",
        "infectious",
        "recovered",
    ]
    
    model = CompartmentalModel(
        [
            settings["start_time"], 
            settings["end_time"],
        ],
        base_comps,
        # Only the infectious compartment is treated as a source of infection.
        ["infectious"],
        timestep=settings.get("timestep", 0.1),
    )
    
    # Everyone starts susceptible except the seeded infectious cases; compartments not listed
    # here are not given any starting population.
    model.set_initial_population(
        {
            "susceptible": Parameter("total_population") - Parameter("seed_size"), 
            "infectious": Parameter("seed_size"),
        }
    )
    
    model.add_infection_frequency_flow(
        "infection", 
        Parameter("contact_rate"),
        "susceptible", 
        "early_exposed",
    )

    # The same rate is used for both exposed stages, which splits the total latent period evenly
    # across the number of "_exposed" compartments in base_comps.
    latent_rate = find_latent_transition_rate(base_comps, "latent_period")
        
    model.add_transition_flow(
        "early_progression",
        latent_rate,
        "early_exposed",
        "late_exposed",
    )
    
    model.add_transition_flow(
        "late_progression",
        latent_rate,
        "late_exposed",
        "infectious",
    )
    
    model.add_transition_flow(
        "recovery",
        Parameter("recovery_rate"),
        "infectious",
        "recovered",
    )
    
    return model

In [None]:
# Values for the named Parameter placeholders in the model, supplied at run time.
parameters = dict(
    contact_rate=10.0,
    recovery_rate=1.0,
    total_population=1000.,
    seed_size=10.,
    latent_period=1.,
)
# Simulation time bounds consumed by build_unstratified_model.
settings = dict(
    start_time=0.,
    end_time=10.,
)

malaysia_model = build_unstratified_model(settings)
malaysia_model.run(parameters)
model_outputs = malaysia_model.get_outputs_df()
model_outputs.plot()

In [None]:
# Inspect which input parameters the built model reports.
model_input_parameters = malaysia_model.get_input_parameters()
model_input_parameters

In [None]:
# Restrict the model's computation graph to what derives from the latent_period parameter, then draw it.
latent_period_graph = malaysia_model.graph.filter(sources=["parameters.latent_period"])
latent_period_graph.draw()