# Model Making Kitchen
Where models are made and stratified.

*This code was adapted from: [here](https://github.com/indralab/mira/blob/hackathon-ensemble/notebooks/evaluation_2023.07/Ensemble%20Evaluation%20Model%201.ipynb), [here](https://github.com/ciemss/pyciemss/blob/283-july-evaluation-scenario-3/notebook/july_evaluation/Scenario3/testbed.ipynb), [here](https://github.com/indralab/mira/blob/main/notebooks/hackathon_2023.07/scenario1-2-wastewater.ipynb), and [here](https://github.com/indralab/mira/blob/main/notebooks/evaluation_2023.07/scenario1.ipynb). Credit to: Charlie Hoyt, Toby Brett, Agustin Kruel, Klas Karis, Ben Gyori, and others.*

### Load dependencies, MIRA modeling tools

In [30]:
import sympy
from copy import deepcopy as _d
from mira.metamodel import *
# from mira.modeling import Model, Author
from mira.modeling.askenet.petrinet import AskeNetPetriNetModel
import jsonschema
import itertools as itt
from tqdm.auto import tqdm
from collections import defaultdict
import requests
from sympy import IndexedBase, Indexed
from datetime import datetime

# now = datetime.now().strftime("%m-%d %H:%M")
# now

### Define AMR sanity check

In [5]:
def sanity_check_amr(amr_json):
    import requests

    assert "schema" in amr_json
    schema_json = requests.get(amr_json["schema"]).json()
    jsonschema.validate(schema_json, amr_json)

### Define units

In [32]:
person_units = lambda: Unit(expression=sympy.Symbol('person'))
virus_units = lambda: Unit(expression=sympy.Symbol('virus'))
virus_per_gram_units = lambda: Unit(expression=sympy.Symbol('virus')/sympy.Symbol('gram'))
day_units = lambda: Unit(expression=sympy.Symbol('day'))
per_day_units = lambda: Unit(expression=1/sympy.Symbol('day'))
dimensionless_units = lambda: Unit(expression=sympy.Integer('1'))
gram_units = lambda: Unit(expression=sympy.Symbol('gram'))
per_day_per_person_units = lambda: Unit(expression=1/(sympy.Symbol('day')*sympy.Symbol('person')))

In [6]:
# STATUSES = [
#     "unvaccinated",
#     "vaccinated",
# ]
# AGES = [
#     "0-19",
#     "20-49",
#     "50-64",
#     "65",
# ]
# VARIANTS = [
#     "wild",
#     "delta",
#     "omicron",
# ]

# Define base-case SEIRHD model state variables, parameters, and initial conditions
**Model state variables**: Susceptible $(S)$, Exposed $(E)$, Infectious $(I)$, Recovered $(R)$, Hospitalized $(H)$, and Deceased $(D)$.

**Model equations**: \begin{align*} \frac{dS}{dt} &= -\frac{\beta}{N}SI \\
\frac{dE}{dt} &= \frac{\beta}{N}SI - \delta E \\
\frac{dI}{dt} &= \delta E - \gamma I \\
\frac{dR}{dt} &= \gamma (1 - \eta - \mu nh) I + \frac{1 - \mu h}{los} H \\
\frac{dH}{dt} &= \gamma \eta I -\frac{1}{los}H \\
\frac{dD}{dt} &= \gamma \mu nh I -\frac{\mu h}{los}H 
\end{align*}

**Parameters**: $N$ total population, $\beta$ transmission rate, $\delta$ latency period, $\gamma$ recovery rate, $\eta$ hospitalization rate given infected, $\mu h$ death rate of hospitalized individuals, $\mu nh$ death rate of infectious individuals (never hospitalized), $los$ average length of hospital stay.

In [33]:
BASE_CONCEPTS = {
    "S": Concept(
        name="S", units=person_units(), identifiers={"ido": "0000514"}
    ),
    "E": Concept(
        name="E", units=person_units(), identifiers={"apollosv": "0000154"}
    ),
    "I": Concept(
        name="I", units=person_units(), identifiers={"ido": "0000511"}
    ),
    "R": Concept(
        name="R", units=person_units(), identifiers={"ido": "0000592"}
    ),
    "H": Concept(
        name="H",
        units=person_units(),
        identifiers={"ido": "0000511"},
        context={"property": "ncit:C25179"},
    ),
    "D": Concept(
        name="D", units=person_units(), identifiers={"ncit": "C28554"}
    ),
}

total_population = 19_340_000
E0 = 1
I0 = 4

BASE_PARAMETERS = {
    'N': Parameter(name='N', value=total_population, units=person_units()),
    'beta': Parameter(name='beta', value=0.4, units=per_day_units(),
                       distribution=Distribution(type='Uniform1',
                                                 parameters={
                                                     'minimum': 0.05,
                                                     'maximum': 0.8
                                                 })),
    'delta': Parameter(name='delta', value=0.25, units=per_day_units()),
    'gamma': Parameter(name='gamma', value=0.2, units=per_day_units(),
                       distribution=Distribution(type='Uniform1',
                                                 parameters={
                                                     'minimum': 0.1,
                                                     'maximum': 0.5
                                                 })),
    'eta': Parameter(name='eta', value=0.1, units=per_day_units(),
                       distribution=Distribution(type='Uniform1',
                                                 parameters={
                                                     'minimum': 0.005,
                                                     'maximum': 0.2
                                                 })),
    'muh': Parameter(name='muh', value=0.05, units=per_day_units()),
    'munh': Parameter(name='munh', value=0.01, units=per_day_units(),
                       distribution=Distribution(type='Uniform1',
                                                 parameters={
                                                     'minimum': 0.001,
                                                     'maximum': 0.09
                                                 })),
    'los': Parameter(name='los', value=5, units=day_units()),
}

BASE_INITIALS = {
    "S": Initial(concept=Concept(name="S"), value=total_population - (E0 + I0)),
    "E": Initial(concept=Concept(name="E"), value=E0),
    "I": Initial(concept=Concept(name="I"), value=I0),
    "R": Initial(concept=Concept(name="R"), value=0),
    "H": Initial(concept=Concept(name="H"), value=0),
    "D": Initial(concept=Concept(name="D"), value=0),
}

observables = {}

# Stratify model into four age groups
Age groups: group "a0" is 0-19, group "a1" is 20-49, group "a2" is 50-64, and group "a3" is 65+

In [34]:
AGES = ["a0", "a1", "a2", "a4"]
ORDER = ["age"]
STRATA = {
    "age": AGES,
}

stratification_config = {
    "S": ["age"],
    "E": ["age"],
    "I": ["age"],
    "R": ["age"],
    "H": ["age"],
    "D": ["age"],
}

param_stratification_config = {
    "eta": ["age"],
    "muh": ["age"],
    "munh": ["age"],
    "los": ["age"],
}

templates = []

## Stratify concepts

In [36]:
# Index all p
concepts = {}

for concept, labels in stratification_config.items():
    for keys in itt.product(
        *(zip(itt.repeat(label), enumerate(STRATA[label])) for label in labels)
    ):
        d = {key: str(idx) for key, (idx, label) in keys}
        idx = tuple(d[k] for k in ORDER if k in d)
        concept_copy = _d(BASE_CONCEPTS[concept]).with_context(
            **d, do_rename=False
        )
        concept_copy.name = f"{concept_copy.name}_" + "_".join(idx)
        concepts[(concept, *idx)] = concept_copy

list(concepts.items())

[(('S', '0'),
  Concept(name='S_0', display_name=None, description=None, identifiers={'ido': '0000514'}, context={'age': '0'}, units=Unit(expression=person))),
 (('S', '1'),
  Concept(name='S_1', display_name=None, description=None, identifiers={'ido': '0000514'}, context={'age': '1'}, units=Unit(expression=person))),
 (('S', '2'),
  Concept(name='S_2', display_name=None, description=None, identifiers={'ido': '0000514'}, context={'age': '2'}, units=Unit(expression=person))),
 (('S', '3'),
  Concept(name='S_3', display_name=None, description=None, identifiers={'ido': '0000514'}, context={'age': '3'}, units=Unit(expression=person))),
 (('E', '0'),
  Concept(name='E_0', display_name=None, description=None, identifiers={'apollosv': '0000154'}, context={'age': '0'}, units=Unit(expression=person))),
 (('E', '1'),
  Concept(name='E_1', display_name=None, description=None, identifiers={'apollosv': '0000154'}, context={'age': '1'}, units=Unit(expression=person))),
 (('E', '2'),
  Concept(name='

In [12]:
initials = {}
for concept in concepts.values():
    orig_key = concept.name.split("_")[0]
    initials[concept.name] = Initial(
        concept=concept, value=BASE_INITIALS[orig_key].value
    )

In [13]:
concept_to_strata = defaultdict(list)
for idx, concept in concepts.items():
    concept_to_strata[idx[0]].append(concept)


def concept_strata_prod(*variables: str):
    yield from itt.product(
        *(concept_to_strata[variable] for variable in variables)
    )

In [14]:
# Index all possible parameters with the name as the
parameters = {
    "rho": Parameter(name="rho", value=1 / 2, units=per_day_units()),
    "q": Parameter(
        name="q", value=R_0 * gamma_val, units=dimensionless_units()
    ),  # transmission probability
    "gamma": Parameter(name="gamma", value=gamma_val, units=per_day_units()),
}
for parameter, labels in param_stratification_config.items():
    for keys in itt.product(
        *(zip(itt.repeat(label), enumerate(STRATA[label])) for label in labels)
    ):
        d = defaultdict(list)
        for key, (idx, label) in keys:
            d[key].append(str(idx))
        d = dict(d)
        idx = tuple(itt.chain.from_iterable(d[k] for k in ORDER if k in d))
        p = _d(BASE_PARAMETERS[parameter])
        p.name = f"{p.name}_" + "_".join(idx)
        parameters[(parameter, *idx)] = p

parameter_to_strata = defaultdict(list)
for idx, parameter in parameters.items():
    parameter_to_strata[idx[0]].append(parameter)

# parameters.values()

In [15]:
def context_idx(concept):
    return tuple(
        concept.context[part]
        for part in stratification_config[concept.name.split("_")[0]]
    )


context_idx(concepts["S", "0", "0"])

('0', '0')

In [16]:
def c_symbol(concept):
    return sympy.Symbol(concept.name)


c_symbol(concepts["S", "0", "0"])

S_0_0

In [17]:
def force_of_infection_component(e, i):
    return (
        sympy.Symbol("q")
        * sympy.Symbol(
            "chi_" + e.context["variant"]
        )  # implicitly i and e have the same
        * sympy.Symbol(
            "phi_" + e.context["status"]
        )  # implicitly s and e have the same
        * sympy.Symbol("d_" + e.context["age"] + "_" + i.context["age"])
        * c_symbol(i)
    )

In [18]:
templates = []

In [19]:
def not_conserved(c1, c2):
    return not all(
        c1.context[key] == c2.context[key]
        for key in ORDER
        if key in c1.context and key in c2.context
    )


# t1 = ControlledConversion(
#    subject=c["S"], outcome=c["E"], controller=c["I"], rate_law=S * I * q * d
# )
for s, e, i in tqdm(
    concept_strata_prod("S", "E", "I"),
    unit_scale=True,
):
    if not_conserved(s, e):
        continue
    if e.context["variant"] != i.context["variant"]:
        continue
    infection = ControlledConversion(
        subject=s,
        outcome=e,
        controller=i,
        rate_law=c_symbol(s) * force_of_infection_component(e, i),
    )
    templates.append(infection)

0.00it [00:00, ?it/s]

In [20]:
# do param lookup based on output


# t2 = NaturalConversion(subject=c["E"], outcome=c["I"], rate_law=rho * E)
for e, i in concept_strata_prod("E", "I"):
    if not_conserved(e, i):
        continue
    t2 = NaturalConversion(
        subject=e,
        outcome=i,
        rate_law=sympy.Symbol(BASE_PARAMETERS["rho"].name) * c_symbol(e),
    )
    templates.append(t2)

In [21]:
# t3 = NaturalConversion(subject=c["I"], outcome=c["R"], rate_law=gamma * (1 - eta) * I)
for i, r in concept_strata_prod("I", "R"):
    if not_conserved(i, r):
        continue
    t3 = NaturalConversion(
        subject=i,
        outcome=r,
        rate_law=(
            sympy.Symbol(BASE_PARAMETERS["gamma"].name)
            * (1 - sympy.Symbol(parameters[("eta", *context_idx(r))].name))
            * c_symbol(i)
        ),
    )
    templates.append(t3)

In [22]:
# t4 = NaturalConversion(subject=c["I"], outcome=c["H"], rate_law=gamma * eta * I)
for i, h in concept_strata_prod("I", "H"):
    if not_conserved(i, h):
        continue
    t4 = NaturalConversion(
        subject=i,
        outcome=h,
        rate_law=(
            sympy.Symbol(BASE_PARAMETERS["gamma"].name)
            * sympy.Symbol(parameters[("eta", *context_idx(h))].name)
            * c_symbol(i)
        ),
    )
    templates.append(t4)

In [23]:
# t5 = NaturalConversion(subject=c["H"], outcome=c["R"], rate_law=(1-mu) * H / lr)
for h, r in concept_strata_prod("H", "R"):
    if not_conserved(h, r):
        continue
    t5 = NaturalConversion(
        subject=h,
        outcome=r,
        rate_law=(
            (1 - sympy.Symbol(parameters[("mu", *context_idx(r))].name))
            * c_symbol(h)
            / sympy.Symbol(parameters["lr", r.context["age"]].name)
        ),
    )
    templates.append(t5)

In [24]:
# t6 = NaturalConversion(subject=c["H"], outcome=c["D"], rate_law=mu * H / ld)
for h, d in concept_strata_prod("H", "D"):
    if not_conserved(h, d):
        continue
    t6 = NaturalConversion(
        subject=h,
        outcome=d,
        rate_law=(
            sympy.Symbol(parameters[("mu", *context_idx(d))].name)
            * c_symbol(h)
            / sympy.Symbol(parameters[("ld", d.context["age"])].name)
        ),
    )
    templates.append(t6)

In [25]:
#Vaccination
#This is the conversion between susceptible individuals who aren't vaccinated to ones that are. This means that we need to make sure that age is the same for the individuals.

for i, age in enumerate(AGES):
    subject = concepts["S", str(i), "0"]
    outcome = concepts["S", str(i), "1"]
    t7 = NaturalConversion(
        subject=subject,
        outcome=outcome,
        rate_law=(
            sympy.Symbol(parameters["alpha", subject.context["age"]].name)
        ),
    )
    templates.append(t7)

In [29]:
model = TemplateModel(
    templates=templates,
    parameters={p.name: p for p in parameters.values()},
    initials=initials,
    time=Time(name="t", units=day_units()),
    observables=observables,
    annotations=Annotations(
        name=f"Toby's Great Adventure SEIRHD {now}",
        # Metadata breaks Terarium
        #         license="CC0",
        #         authors=[
        #           Author(name='Toby Brett'),
        #           Author(name='Charles Tapley Hoyt')
        #         ],
        #         pathogens=["ncbitaxon:2697049"],
        #         diseases=[
        #             "doid:0080600",
        #         ],
        #         hosts=[
        #             "ncbitaxon:9606",
        #         ],
        #         model_types=[
        #             "mamo:0000028",
        #             "mamo:0000046",
        #         ],
        description="""\
        This model has been stratified by age (4 ways),
        variant (3 ways), and vaccine status (2 ways).

        The naming convention is to have the base
        population type (e.g., S, E, I) followed by underscores
        in the order of age, variant, disase status. All compartments
        are stratified by all three except for the S compartment, which
        is only stratified by age and status.

        Because death is stratified, it might be useful to create observables
        which e.g. take the sum over all statuses/ages/variants.

        We used numerical indexes (encoded as strings) for each of the age,
        variants, and statuses.

        - Status: (0) unvaccinated and (1) vaccinated
        - Ages: (0) for 0-19, (1) for 20-49, (2) for 50-64, and (3) for 65+
        - Variants: (0) wild, (1) delta, and (2) omicron
        
        This model also includes the conversion from unvaccinated susceptible
        individuals to vaccinated susceptible individuals (with conserved age).
        The parameter alpha ideally is a time dependent function.
        
        TODO we can model the reverse process (i.e., waning) of vaccinated
        susceptible individuals becoming unvaccinated as a decay of vaccine
        efficacy, which would also be time dependant.
        """,
    ),
)
# sanity_check_tm(model)
am = AskeNetPetriNetModel(Model(model))
sanity_check_amr(am.to_json())
am.to_json_file("toby_seirhd.json")

In [2]:
person_units = lambda: Unit(expression=sympy.Symbol('person'))
virus_units = lambda: Unit(expression=sympy.Symbol('virus'))
virus_per_gram_units = lambda: Unit(expression=sympy.Symbol('virus')/sympy.Symbol('gram'))
day_units = lambda: Unit(expression=sympy.Symbol('day'))
per_day_units = lambda: Unit(expression=1/sympy.Symbol('day'))
dimensionless_units = lambda: Unit(expression=sympy.Integer('1'))
gram_units = lambda: Unit(expression=sympy.Symbol('gram'))
per_day_per_person_units = lambda: Unit(expression=1/(sympy.Symbol('day')*sympy.Symbol('person')))

In [None]:




# See Table 1 of the paper
c = {
    'S': Concept(name='S', units=person_units(), identifiers={'ido': '0000514'}),
    'E': Concept(name='E', units=person_units(), identifiers={'apollosv': '0000154'}),
    'I': Concept(name='I', units=person_units(), identifiers={'ido': '0000511'}),
    'V': Concept(name='V', units=person_units(), identifiers={'vido': '0001331'}),
}


parameters = {
    'gamma': Parameter(name='gamma', value=0.08, units=per_day_units()),
    'delta': Parameter(name='delta', value=1/8, units=per_day_units()),
    'alpha': Parameter(name='alpha', value=500, units=gram_units(),
                       distribution=Distribution(type='Uniform1',
                                                 parameters={
                                                     'minimum': 51,
                                                     'maximum': 796
                                                 })),
    'lambda': Parameter(name='lambda', value=9.66e-8, units=per_day_per_person_units()),
    'beta': Parameter(name='beta', value=4.49e7, units=virus_per_gram_units()),
    'k': Parameter(name='k', value=1/3, units=per_day_units()),
}

initials = {
    'S': Initial(concept=Concept(name='S'), value=2_300_000),
    'E': Initial(concept=Concept(name='E'), value=1000),
    'I': Initial(concept=Concept(name='I'), value=0),
    'V': Initial(concept=Concept(name='V'), value=0),
}

S, E, I, V, gamma, delta, alpha, lmbd, beta, k = \
    sympy.symbols('S E I V gamma delta alpha lambda beta k')

t1 = ControlledConversion(subject=c['S'],
                          outcome=c['E'],
                          controller=c['I'],
                          rate_law=S*I*lmbd)
t2 = NaturalConversion(subject=c['E'],
                       outcome=c['I'],
                       rate_law=k*E)
t3 = NaturalDegradation(subject=c['I'],
                        rate_law=delta*I)
t4 = ControlledProduction(outcome=c['V'],
                          controller=c['I'],
                          rate_law=alpha*beta*(1-gamma)*I)
templates = [t1, t2, t3, t4]
observables = {}
tm = TemplateModel(
    templates=templates,
    parameters=parameters,
    initials=initials,
    time=Time(name='t', units=day_units()),
    observables=observables,
    annotations=Annotations(name='Scenario 3 base model'))

In [None]:
# Add uncertainty
parameters = {
    'gamma': Parameter(name='gamma', value=0.08, units=per_day_units(),
                       distribution=Distribution(type='Uniform1',
                                                 parameters={
                                                     'minimum': 0.06,
                                                     'maximum': 0.09
                                                 })),
    'delta': Parameter(name='delta', value=1/8, units=per_day_units()),
    'alpha': Parameter(name='alpha', value=500, units=gram_units(),
                       distribution=Distribution(type='Uniform1',
                                                 parameters={
                                                     'minimum': 51,
                                                     'maximum': 796
                                                 })),
    'lambda': Parameter(name='lambda', value=9.66e-8, units=per_day_per_person_units(),
                       distribution=Distribution(type='Uniform1',
                                                 parameters={
                                                     'minimum': 6.66e-8,
                                                     'maximum': 12.66e-8
                                                 })),
    'beta': Parameter(name='beta', value=4.49e7, units=virus_per_gram_units()),
    'k': Parameter(name='k', value=1/3, units=per_day_units()),
}

# E here is E(0) -> (10, 100, 1000, 5000)
initials = {
    'S': Initial(concept=Concept(name='S'), value=2_300_000),
    'E': Initial(concept=Concept(name='E'), value=1000),
    'I': Initial(concept=Concept(name='I'), value=0),
    'V': Initial(concept=Concept(name='V'), value=0),
}

tm = TemplateModel(
    templates=templates,
    parameters=parameters,
    initials=initials,
    time=Time(name='t', units=day_units()),
    observables=observables,
    annotations=Annotations(name='Scenario 3 base model'))

In [None]:
AskeNetPetriNetModel(Model(tm)).to_json_file('my_model.json')

In [None]:
STATUSES = [
    "unvaccinated",
    "vaccinated",
]
AGES = [
    "0-19",
    "20-49",
    "50-64",
    "65",
]
VARIANTS = [
    "wild",
    "delta",
    "omicron",
]

In [None]:
BASE_CONCEPTS = {
    "S": Concept(name="S", units=person_units(), identifiers={"ido": "0000514"}),
    "E": Concept(name="E", units=person_units(), identifiers={"apollosv": "0000154"}),
    "I": Concept(name="I", units=person_units(), identifiers={"ido": "0000511"}),
    "R": Concept(name="R", units=person_units(), identifiers={"ido": "0000592"}),
    "H": Concept(
        name="H",
        units=person_units(),
        identifiers={"ido": "0000511"},
        context={"property": "ncit:C25179"},
    ),
    "D": Concept(name="D", units=person_units(), identifiers={"ncit": "C28554"}),
}


N_val = 19_340_000
E_val = 1
I_val = 4
R_0 = 2.6
gamma_val = 1/5

BASE_PARAMETERS = {
    "gamma":  Parameter(name="gamma", value=gamma_val, units=per_day_units()),
    "eta": Parameter(name="eta", value=0.1, units=dimensionless_units()),
    "mu":  Parameter(name="mu", value=0.003, units=dimensionless_units()),
    "lr":  Parameter(name="lr", value=5, units=day_units()), # average time to recovery (duration of hospital stay if they recover)
    "ld":  Parameter(name="ld", value=9.25, units=day_units()), # average time to recovery (duration of hospital stay if they die)
    "rho": Parameter(name="rho", value=1/2, units=per_day_units()),
    "q": Parameter(name="q", value=R_0 * gamma_val, units=dimensionless_units()), # transmission probability
    "d": Parameter(name="d", value=1/N_val, units=per_day_per_person_units()), # scaled contact rate
    "phi": Parameter(name="phi", value=1.0, units=dimensionless_units()), # host susceptibility
    "chi": Parameter(name="chi", value=1.0, units=dimensionless_units()), # relative transmissibility of variant

}

BASE_INITIALS = {
    "S": Initial(concept=Concept(name="S"), value=N_val - (E_val + I_val)),
    "E": Initial(concept=Concept(name="E"), value=E_val),
    "I": Initial(concept=Concept(name="I"), value=I_val),
    "R": Initial(concept=Concept(name="R"), value=0),
    "H": Initial(concept=Concept(name="H"), value=0),
    "D": Initial(concept=Concept(name="D"), value=0),
}

observables = {}

In [None]:
# make function to do this automatically
(
    S,
    E,
    I,
    R,
    D,
    H,
    q,
    d,  # d eats the N
    rho, # r_E_to_I,
    gamma,
    eta,
    mu,
    lr,
    ld,
    chi,
    phi,
) = sympy.symbols(
    "S E I R D H q d rho gamma eta mu lr ld chi phi"
)

In [None]:
ORDER = ["age", "variant", "status"]
STRATA = {
    "age": AGES, 
    "variant": VARIANTS,
    "status": STATUSES,
}

stratification_config = {
    "S": ["age", "status"],
    "E": ["age", "variant", "status"],
    "I": ["age", "variant", "status"],
    "R": ["age", "variant", "status"],
    "H": ["age", "variant", "status"],
    "D": ["age", "variant", "status"],
}

param_stratification_config = {
    "d": ["age", "age"],
    "eta": ["age", "variant", "status"],
    "mu": ["age", "variant", "status"],
    "ld": ["age"],
    "lr": ["age"],
    "phi": ["status"],
    "chi": ["variant"],
}


templates = []

In [None]:
stratification_config["S"]

In [None]:
# Index all p
concepts = {}

for concept, labels in stratification_config.items():
    for keys in itt.product(*(
        zip(itt.repeat(label), enumerate(STRATA[label]))
        for label in labels
    )):
        d = {
            key: str(idx)
            for key, (idx, label) in keys
        }
        idx = tuple(d[k] for k in ORDER if k in d) 
        concept_copy = _d(BASE_CONCEPTS[concept]).with_context(**d, do_rename=False)
        concept_copy.name = f"{concept_copy.name}_" + "_".join(idx)
        concepts[(concept, *idx)] = concept_copy
        
list(concepts.items())[30:35]   

In [None]:
initials = {}
for concept in concepts.values():
    orig_key = concept.name.split("_")[0]
    initials[concept.name] =  Initial(
        concept=concept, value=BASE_INITIALS[orig_key].value
    )

In [None]:
concept_to_strata = defaultdict(list)
for idx, concept in concepts.items():
    concept_to_strata[idx[0]].append(concept)
    
    
def concept_strata_prod(*variables: str):
    yield from itt.product(*(
        concept_to_strata[variable]
        for variable in variables
    ))

In [None]:
# Index all possible parameters with the name as the 
parameters = {}
for parameter, labels in param_stratification_config.items():
    for keys in itt.product(*(
        zip(itt.repeat(label), enumerate(STRATA[label]))
        for label in labels
    )):
        d = defaultdict(list)
        for key, (idx, label) in keys:
            d[key].append(str(idx))
        d = {k:sorted(v) for k,v in d.items()}
        idx = tuple(itt.chain.from_iterable(d[k] for k in ORDER if k in d))
        p = _d(BASE_PARAMETERS[parameter])
        p.name = f"{p.name}_" + "_".join(idx)
        parameters[(parameter, *idx)] = p
        
parameter_to_strata = defaultdict(list)
for idx, parameter in parameters.items():
    parameter_to_strata[idx[0]].append(parameter)
        
list(parameters.items())[:5]

In [None]:
def force_of_infection(a, v, i):
    d = IndexedBase('d')
    I = IndexedBase('I')
    exprs = [
        d[a, age_index] * I[age_index, v , status_index]
        for age_index, _ in enumerate(AGES)
        for status_index, _ in enumerate(STATUSES)
    ]
    return (
        q 
        * Indexed("chi", v)
        * Indexed("psi", i)
        * sum(exprs)
    )
    
force_of_infection(a='0-19', v='omicron', i='vaccinated')

In [None]:
def context_idx(concept):
    return tuple(
        concept.context[part]
        for part in stratification_config[concept.name]
    )
context_idx(concepts['S', '0', '0'])

In [None]:
def c_symbol(concept):
    return sympy.Symbol(concept.name + "_" + "_".join(idx))

c_symbol(concepts['S', '0', '0'])

In [None]:
def force_of_infection_component(e, i):
    d = IndexedBase('d')
    I = IndexedBase('I')
    return (
        q 
        * sympy.Symbol("chi_" + e.context["status"])
        * sympy.Symbol("psi_" + e.context["variant"])
        * sympy.Symbol("d_" + e.context["age"] + "_" + i.context["age"])
        * c_symbol(i)
    )