In [None]:
import pandas as pd
import numpy as np
from jax import numpy as jnp

from summer2 import CompartmentalModel, Stratification, population
from summer2.parameters import Function, Parameter

# Render DataFrame.plot() figures with plotly instead of matplotlib.
pd.set_option("plotting.backend", "plotly")

In [None]:
# Base SIR model; the first argument is presumably the simulated time span
# and "I" is declared as the only infectious compartment.
m = CompartmentalModel([0, 100], ["S", "I", "R"], ["I"])

# Each stratification splits all three compartments. They are applied to the
# model below in the order agegroup -> state -> imm.
age_strat = Stratification("agegroup", ["young", "old"], ["S", "I", "R"])
state_strat = Stratification("state", ["WA", "other"], ["S", "I", "R"])
imm_strat = Stratification("imm", ["vacc", "unvacc"], ["S", "I", "R"])

for strat in (age_strat, state_strat, imm_strat):
    m.stratify_with(strat)


In [None]:
# Population size for each state/agegroup pair. Keys follow the
# "<state>_<agegroup>" pattern that get_init_pop uses to look them up.
state_pop_info = dict(
    WA_young=1_000.0,
    WA_old=2_000.0,
    other_young=10_000.0,
    other_old=30_000.0,
)

# Scale factor for each "<imm>_<agegroup>" stratum, driven by the vacc_*
# model parameters. For each age group, the unvaccinated share is the
# complement (1 - vacc share), so the two immunity strata sum to one.
imm_scale = dict(
    vacc_young=Parameter("vacc_young"),
    vacc_old=Parameter("vacc_old"),
    unvacc_young=1.0 - Parameter("vacc_young"),
    unvacc_old=1.0 - Parameter("vacc_old"),
)

In [None]:
def get_init_pop(imm_scale):
    """Return the initial population vector for the stratified model ``m``.

    Only the susceptible ("S") compartments are filled. Every other
    compartment (all "I" and "R" strata) stays at zero. For each stratum the
    S size is::

        state_pop_info["<state>_<agegroup>"] * imm_scale["<imm>_<agegroup>"]

    Args:
        imm_scale: mapping with "<imm>_<agegroup>" keys (e.g. "vacc_young")
            giving the scale factor for each immunity/age stratum.
            NOTE(review): presumably summer2 passes resolved parameter
            values here rather than Parameter objects -- confirm.

    Returns:
        A jax array with one entry per compartment of ``m``.

    Note:
        Reads the module-level globals ``m`` and ``state_pop_info``.
        NOTE(review): no infectious individuals are seeded, so the "I"
        compartments start empty. Also, float64 is only honoured when JAX
        x64 mode is enabled; otherwise JAX falls back to float32.
    """
    population = jnp.zeros(len(m.compartments), dtype=np.float64)
    strata = m.stratifications

    # All (agegroup, imm, state) combinations: agegroup outermost,
    # state innermost.
    combos = (
        (age, imm, region)
        for age in strata["agegroup"].strata
        for imm in strata["imm"].strata
        for region in strata["state"].strata
    )
    for age, imm, region in combos:
        s_idx = m.query_compartments(
            {"name": "S", "agegroup": age, "imm": imm, "state": region},
            as_idx=True,
        )
        size = state_pop_info[f"{region}_{age}"] * imm_scale[f"{imm}_{age}"]
        # JAX arrays are immutable: .at[...].set() returns an updated copy.
        population = population.at[s_idx].set(size)
    return population

In [None]:
# Wrap get_init_pop in a summer2 graph Function with imm_scale as its input.
# This lets the initial population depend on the vacc_* parameters that are
# supplied later (e.g. to get_initial_population / run).
init_pop_func = Function(get_init_pop, [imm_scale])
m.init_population_with_graphobject(init_pop_func)

In [None]:
# Vaccination coverage per age group, read by the vacc_* Parameters.
parameters = dict(vacc_young=0.2, vacc_old=0.6)

# Display the initial population evaluated at these parameter values.
m.get_initial_population(parameters)

In [None]:
# Run the model with the chosen vaccination coverage, then plot its outputs
# over time. The plotly backend is configured in the first cell, so extra
# keyword arguments are forwarded to plotly express.
m.run(parameters)
m.get_outputs_df().plot(
    title="Model outputs over time",
    # Wide-form plotly express names the values "value" and the columns
    # "variable"; presumably each column is one compartment -- confirm.
    labels={"value": "size", "variable": "compartment"},
)