In [None]:
from summer2 import CompartmentalModel, Stratification
from summer2.parameters import CompartmentValues, Parameter, Time, Function
from summer2.functions import time as stf

from jax import numpy as jnp
import pandas as pd
import numpy as np

In [None]:
pd.options.plotting.backend="plotly"

In [None]:
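# TODO(summer2): `infectious_compartments` could (should?) default to None in `CompartmentalModel`; this model has no infection dynamics, so `build_model` passes an empty list.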

In [None]:
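# Vaccination regimen: initially only age group C is eligible; after the switch point, everyone is eligible (and all groups vaccinate at the same per-capita rate).
# The switch point was originally a fixed date; `build_model` now switches once 80% of group C is vaccinated.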

In [None]:
def get_category_indexer(m, query):
    """Look up compartment indices for several category queries at once.

    Args:
        m: model exposing ``query_compartments(q, as_idx=True)``.
        query: iterable of query dicts, one per category.

    Returns:
        ``np.ndarray`` whose i-th row holds the compartment indices matched by
        the i-th query. Every query must match the same number of compartments
        for the result to be a regular 2-D array.
    """
    index_rows = []
    for category in query:
        index_rows.append(m.query_compartments(category, as_idx=True))
    return np.array(index_rows)

In [None]:
def build_model():
    """Build an age-stratified vaccination model with a finite dose supply.

    Compartments:
      * ``unvacc`` / ``vacc`` -- stratified by age group (A, B, C);
      * ``dose_avail`` -- stock of doses (not age-stratified), depleted as doses are given.

    Dose allocation (evaluated at run time by ``calc_vacc_rates``):
      * until ``switch_prop`` of the priority group (C) is vaccinated, only C is eligible;
      * afterwards every group is eligible and vaccinates at the same per-capita rate,
        so doses are shared in proportion to each group's unvaccinated population.
    At most ``maxdd`` doses (model parameter) are given per unit time, and never more
    than the eligible unvaccinated population or the remaining dose stock.

    Returns:
        The stratified, not-yet-run ``CompartmentalModel``.
    """
    age_groups = ["A", "B", "C"]
    priority_group = "C"
    switch_prop = 0.8  # proportion of the priority group vaccinated before everyone is eligible
    priority_idx = age_groups.index(priority_group)

    m = CompartmentalModel([0, 100.0], ["unvacc", "vacc", "dose_avail"], [])
    m.set_initial_population({"unvacc": 1000.0, "dose_avail": 500.0})

    # Per-capita eligibility weights (ordered as age_groups) before and after the switch.
    priority_only = jnp.array([1.0 if a == priority_group else 0.0 for a in age_groups])
    everyone = jnp.ones(len(age_groups))

    def safe_ratio(num, den):
        """Return num / den, or 0.0 wherever den <= 0.

        The inner ``jnp.where`` keeps the discarded branch finite, so neither the
        value nor its gradient turns into NaN when ``den`` reaches zero.
        """
        ok = den > 0.0
        return jnp.where(ok, num / jnp.where(ok, den, 1.0), 0.0)

    def calc_vacc_rates(time, comp_vals, maxdd):
        """Per-age vaccination rates and the dose-depletion rate for the current state.

        ``time`` is unused (kept for the Function signature; a date-based switch
        would need it). Returns ``vacc_rates`` (one per-capita rate per age group,
        ordered as ``age_groups``) and ``dose_rate`` (depletion rate of ``dose_avail``).
        """
        # Compartment queries stay inside this function: the age strata are only
        # created by ``stratify_with`` below, after this closure is defined.
        dose_idx = m.query_compartments({"name": "dose_avail"}, as_idx=True)
        unvacc_idx = get_category_indexer(m, [{"age": a, "name": "unvacc"} for a in age_groups])
        priority_pop_idx = m.query_compartments({"age": priority_group}, as_idx=True)

        cur_dose_avail = comp_vals[dose_idx].sum()
        # Unvaccinated population per age group (summed over any further strata)
        cur_unvacc = comp_vals[unvacc_idx].sum(axis=1)

        # Switch strategy the moment switch_prop of the priority group is vaccinated
        priority_pop = comp_vals[priority_pop_idx].sum()
        prop_priority_vacc = 1.0 - safe_ratio(cur_unvacc[priority_idx], priority_pop)
        eligible = jnp.where(prop_priority_vacc < switch_prop, priority_only, everyone)

        # Doses per unit time: capped by maxdd, the eligible unvaccinated and the stock;
        # clamped at 0 in case the solver drives a compartment slightly negative.
        cur_unvacc_eligible = (eligible * cur_unvacc).sum()
        num_to_vacc = jnp.maximum(
            jnp.min(jnp.array((maxdd, cur_unvacc_eligible, cur_dose_avail))), 0.0
        )

        # Equal per-capita rate across eligible groups, so group i receives
        # num_to_vacc * cur_unvacc[i] / cur_unvacc_eligible doses.
        return {
            "vacc_rates": eligible * safe_ratio(num_to_vacc, cur_unvacc_eligible),
            "dose_rate": safe_ratio(num_to_vacc, cur_dose_avail),
        }

    n2v = Function(calc_vacc_rates, [Time, CompartmentValues, Parameter("maxdd")])

    # Base rate 1.0; the per-age rates come from the flow adjustments below
    # (assumes summer2 adjustments multiply the base rate).
    m.add_transition_flow("vaccination", 1.0, "unvacc", "vacc")
    # Assuming death-flow outflow = rate * compartment size, dose_avail falls by num_to_vacc.
    m.add_death_flow("dose_depletion", n2v["dose_rate"], "dose_avail")

    age_strat = Stratification("age", age_groups, ["unvacc", "vacc"])
    age_strat.set_population_split({"A": 0.2, "B": 0.5, "C": 0.3})
    # Index vacc_rates in the same order used to build unvacc_idx.
    age_strat.set_flow_adjustments(
        "vaccination", {a: n2v["vacc_rates"][i] for i, a in enumerate(age_groups)}
    )
    m.stratify_with(age_strat)

    # Further stratifications (e.g. location over unvacc/vacc) can be added here;
    # cur_unvacc should sum over them as long as every age query matches equally
    # many compartments.

    return m

In [None]:
# Build a fresh, not-yet-run model; `m` is run and plotted in the next cell.
m = build_model()


In [None]:
# Run with a cap of 10 doses per unit time, then plot every compartment over time.
p = {
    "maxdd": 10.0
}
m.run(p)

m.get_outputs_df().plot(
    title="Vaccination model: compartment sizes over time",
    labels={"index": "time", "value": "size", "variable": "compartment"},
)