In [None]:
from summer2 import CompartmentalModel, Stratification
from summer2.parameters import CompartmentValues, Parameter, Time, Function
from summer2.functions import time as stf

from jax import numpy as jnp
import pandas as pd

In [None]:
# Route DataFrame.plot() through plotly so figures render interactively.
pd.set_option("plotting.backend", "plotly")

In [None]:
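# NOTE (summer2 API): the `infectious_compartments` argument of CompartmentalModel
# could (should?) default to None; build_model below has to pass an empty list for it.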

In [None]:
def build_model():
    """Build a vaccination model whose throughput is limited by dose supply.

    The model runs over times [0, 100.0] with three compartments:

    * ``unvacc``     - starts at 1000.0
    * ``vacc``       - starts empty (no initial population is set for it)
    * ``dose_avail`` - starts at 100.0

    Two flows are driven by one shared function of the compartment values:

    * ``vaccination``: transition flow ``unvacc`` -> ``vacc``
    * ``dose_depletion``: death flow out of ``dose_avail``

    The model expects a ``maxdd`` parameter at run time, which caps the
    number vaccinated per unit time.

    Returns:
        The configured CompartmentalModel (not yet run).
    """
    m = CompartmentalModel([0,100.0], ["unvacc", "vacc", "dose_avail"], [])
    m.set_initial_population({"unvacc": 1000.0, "dose_avail": 100.0})

    def per_capita(total, stock):
        """Return ``total / stock``, or 0.0 when ``stock`` is not positive.

        The denominator is replaced *before* dividing (rather than masking the
        quotient afterwards) so that no NaN/inf is ever produced, including in
        any derivative JAX may take through this expression.
        """
        has_stock = stock > 0.0
        safe_stock = jnp.where(has_stock, stock, 1.0)
        return jnp.where(has_stock, total / safe_stock, 0.0)

    def vacc_rates(comp_vals, maxdd = 10.0):
        """Compute the rates for the vaccination and dose depletion flows.

        Args:
            comp_vals: Current compartment values, indexable by the index
                objects returned from ``m.query_compartments(..., as_idx=True)``.
            maxdd: Upper bound on the number vaccinated per unit time.

        Returns:
            dict with ``"vacc_rate"`` (applied to ``unvacc``) and
            ``"dose_rate"`` (applied to ``dose_avail``). Both are the same
            absolute count divided by the respective compartment size
            (presumably summer2 multiplies each rate by its source
            compartment; verify against the summer2 flow documentation).
        """
        # Indices are looked up when this function is called, not captured at
        # build time, so they reflect the model's compartments at that point.
        dose_idx = m.query_compartments({"name": "dose_avail"}, as_idx=True)
        unvacc_idx = m.query_compartments({"name": "unvacc"}, as_idx=True)
        cur_dose_avail = comp_vals[dose_idx].sum()
        cur_unvacc = comp_vals[unvacc_idx].sum()

        # Cannot vaccinate more than the cap, the people left, or the doses left.
        num_to_vacc = jnp.min(jnp.array((maxdd,cur_unvacc, cur_dose_avail)))
        # Small negative compartment values from numerical error must not turn
        # into a negative count, which would flip the sign of the other rate.
        num_to_vacc = jnp.maximum(num_to_vacc, 0.0)

        # Guarded division: an empty compartment gives a zero rate, not 0/0 = NaN.
        vacc_rate = per_capita(num_to_vacc, cur_unvacc)
        dose_rate = per_capita(num_to_vacc, cur_dose_avail)

        return {"vacc_rate": vacc_rate, "dose_rate": dose_rate}

    n2v = Function(vacc_rates, [CompartmentValues, Parameter("maxdd")])

    m.add_transition_flow("vaccination", n2v["vacc_rate"], "unvacc", "vacc")
    m.add_death_flow("dose_depletion", n2v["dose_rate"], "dose_avail")

    return m

In [None]:
# Construct the model once; the run cell below reuses this instance.
m = build_model()

In [None]:
# Run-time parameters: "maxdd" is the cap that vacc_rates applies to the
# number vaccinated per unit time.
p = dict(maxdd=5.0)

# Run the model with these parameters.
m.run(p)

# Plot every output column of the run against time.
m.get_outputs_df().plot()