In [None]:
import numpy as np
import pandas as pd
pd.options.plotting.backend = "plotly"

from jax import numpy as jnp

# Also import jax itself - normally you'll never need this directly, but we want to have a look...
import jax

from summer2 import CompartmentalModel, Stratification
from summer2.parameters import Parameter, Function, Data

In [None]:
def build_unstratified_model(
    contact_rate,
    recovery_rate,
    *,
    times=None,
    initial_population=None,
) -> CompartmentalModel:
    """
    Create an unstratified susceptible-infectious-recovered compartmental model, with the minimal
    compartmental structure needed to run and produce some sort of meaningful outputs.

    Args:
        contact_rate: Rate for the "infection" frequency-dependent flow from "susceptible" to
            "infectious" (in this notebook a summer2 ``Parameter`` resolved at run time)
        recovery_rate: Fractional rate of the "recovery" transition flow from "infectious" to
            "recovered" (in this notebook a summer2 ``Parameter`` resolved at run time)
        times: Start and end of the simulation period; defaults to ``[0, 100]``
        initial_population: Mapping of compartment name to starting population; defaults to
            990 susceptible and 10 infectious
    Returns:
        A compartmental model currently without stratification applied
    """
    # Resolve defaults here (not in the signature) so callers never share a mutable default object
    if times is None:
        times = [0, 100]
    if initial_population is None:
        initial_population = {
            "susceptible": 990.,
            "infectious": 10.
        }

    model = CompartmentalModel(
        times,
        ["susceptible", "infectious", "recovered"],
        # Compartments listed here are the ones treated as infectious by the model
        ["infectious"]
    )
    model.set_initial_population(initial_population)

    # Susceptible people become infected through frequency-dependent contact
    model.add_infection_frequency_flow(
        "infection",
        contact_rate,
        "susceptible",
        "infectious"
    )

    # Infectious people recover
    model.add_transition_flow(
        name="recovery",
        fractional_rate=recovery_rate,
        source="infectious",
        dest="recovered",
    )

    return model

In [None]:
# Declare the flow rates as named parameters so their values are supplied when the model is run
contact_rate = Parameter("contact_rate")
recovery_rate = Parameter("recovery_rate")
m = build_unstratified_model(contact_rate, recovery_rate)

# Run-time values; each key corresponds to a Parameter name declared above
pdict = {
    "contact_rate": 5.0,
    "recovery_rate": 1.0,
}
m.run(pdict)

# Title and axis labels are forwarded to plotly.express.line by the plotly plotting backend,
# so the figure can be understood on its own when the notebook is skimmed
m.get_outputs_df().plot(
    title="Unstratified model: compartment sizes over time",
    labels={"index": "time", "value": "compartment size", "variable": "compartment"},
)