In [None]:
import numpy as np
import pandas as pd
pd.options.plotting.backend = "plotly"
from summer2 import CompartmentalModel
from summer2.parameters import Parameter, DerivedOutput

In [None]:
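# Workaround: summer2 will not let a model have more than one birth flow, so the helper
# below calls summer2's private entry-flow API directly to add a second (crude) birth flow.
# When David is back, we could easily patch this into a version of summer2.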
from summer2 import CompartmentalModel, flows
from summer2.model import _validate_flowparam
from summer2.adjust import FlowParam
from typing import Dict, Optional

def add_extra_crude_birth_flow(
    model: CompartmentalModel,
    name: str,
    birth_rate: FlowParam,
    dest: str,
    dest_strata: Optional[Dict[str, str]] = None,
    expected_flow_count: Optional[int] = None,
):
    """Attach a crude birth flow to ``model`` without going through summer2's public API.

    summer2's public interface will not add a second birth flow to a model. This helper
    validates the rate parameter and then hands everything to the private
    ``CompartmentalModel._add_entry_flow`` using the ``flows.CrudeBirthFlow`` class.
    The helper itself performs no check for birth flows that already exist.

    Args:
        model: Model that receives the new entry flow (mutated in place).
        name: Name of the new flow.
        birth_rate: Flow parameter for the birth rate; checked by ``_validate_flowparam``.
        dest: Name of the compartment that receives the births.
        dest_strata: Optional ``str -> str`` mapping, forwarded unchanged to
            ``_add_entry_flow`` (presumably a destination strata filter; confirm in summer2).
        expected_flow_count: Optional count, forwarded unchanged to ``_add_entry_flow``.
    """
    _validate_flowparam(birth_rate)
    # Same positional order that _add_entry_flow receives after the flow class.
    entry_flow_args = (name, birth_rate, dest, dest_strata, expected_flow_count)
    model._add_entry_flow(flows.CrudeBirthFlow, *entry_flow_args)

In [None]:
# Base model construction: compartments, simulated time span and starting population.
# The specific values here are arbitrary and purely illustrative.
model_comps = [
    "susceptible",
    "early latent",
    "late latent",
    "infectious",
    "recovered",
]
model_times = [1850.0, 2024.0]
init_pops = {"susceptible": 500.0, "infectious": 500.0}

model = CompartmentalModel(
    times=model_times,
    compartments=model_comps,
    infectious_compartments=["infectious"],
)
model.set_initial_population(init_pops)
# Derived output summing every compartment; the trailing ';' suppresses the cell's display.
model.request_output_for_compartments("total_population", model_comps);

In [None]:
# TB-related flows. These are placeholders with no real epidemiological meaning:
# a death flow out of "infectious", plus a direct "susceptible" -> "infectious" transition.
model.add_death_flow("TB death", Parameter("death rate"), "infectious")
model.add_transition_flow(
    "silly_transition",
    Parameter("silly transition rate"),
    "susceptible",
    "infectious",
)

In [None]:
# Demographic flows, added in this order:
#   1. a death flow applied to every compartment,
#   2. a replacement birth flow into "susceptible",
#   3. a second, crude birth flow into "susceptible" at the population growth rate.
#      summer2 alone would refuse this one, hence the helper defined earlier.
# The growth check cell at the end expects total population ~ N0 * exp(growth rate * t).
model.add_universal_death_flows("population_death", Parameter("population death rate"))
model.add_replacement_birth_flow("replacement_birth", "susceptible")
add_extra_crude_birth_flow(
    model,
    name="extra_birth",
    birth_rate=Parameter("population growth rate"),
    dest="susceptible",
)

In [None]:
# Values for every Parameter(...) name referenced by the flows above, passed to model.run.
# Rates are presumably per unit of model time (years here, given 1850-2024) -- confirm.
params = dict(
    [
        ("silly transition rate", 1.0),
        ("population growth rate", 0.001),
        ("death rate", 0.1),
        ("population death rate", 0.01),
    ]
)

In [None]:
# Solve the model with the parameter values above; results are read back in later cells
# via model.get_outputs_df() and model.get_derived_outputs_df().
model.run(params)

In [None]:
# Stacked area of every compartment's size over time (plotly backend, set in the imports cell).
# In plotly's wide-form plots the y values are labelled "value" and the series "variable";
# "index" covers the x axis when the output index is unnamed.
model.get_outputs_df().plot.area(
    title="Compartment sizes over time",
    labels={"index": "Time", "value": "Population", "variable": "Compartment"},
)

In [None]:
# Sanity check of the double-birth-flow workaround. Assuming replacement births exactly offset
# all deaths, total population should grow only through the extra crude birth flow:
#     N(t) = N0 * exp(growth_rate * (t - t0))
# Elapsed time comes from the model's own output times rather than a hand-built
# np.arange(..., 1.0). The hand-built version silently assumed a timestep of 1.0 and
# only matched the number of output rows by coincidence.
outputs = model.get_derived_outputs_df()
elapsed_time = np.asarray(model.times, dtype=float) - model_times[0]
outputs["check"] = sum(init_pops.values()) * np.exp(
    elapsed_time * params["population growth rate"]
)
outputs.plot(title="Total population vs. analytic exponential growth ('check')")