# Model Making Kitchen
Where models are made and stratified.

*This code was taken and adapted from: [here](https://github.com/indralab/mira/blob/hackathon-ensemble/notebooks/evaluation_2023.07/Ensemble%20Evaluation%20Model%201.ipynb), [here](https://github.com/ciemss/pyciemss/blob/283-july-evaluation-scenario-3/notebook/july_evaluation/Scenario3/testbed.ipynb), [here](https://github.com/indralab/mira/blob/main/notebooks/hackathon_2023.07/scenario1-2-wastewater.ipynb), and [here](https://github.com/indralab/mira/blob/main/notebooks/evaluation_2023.07/scenario1.ipynb). Credit to: Charlie Hoyt, Toby Brett, Agustin Kruel, Klas Karis, Ben Gyori, and others.*

### Load dependencies, MIRA modeling tools

In [1]:
import sympy
from copy import deepcopy as _d
from mira.metamodel import *
from mira.modeling import Model #, Author
from mira.modeling.askenet.petrinet import AskeNetPetriNetModel
import jsonschema
import itertools as itt
from tqdm.auto import tqdm
from collections import defaultdict
import requests
from sympy import IndexedBase, Indexed
from datetime import datetime

# now = datetime.now().strftime("%m-%d %H:%M")
# now

### Define AMR sanity check

In [2]:
def sanity_check_amr(amr_json):
    """Validate an ASKEM Model Representation (AMR) against its declared schema.

    The AMR must carry a ``"schema"`` URL. The schema is downloaded from that
    URL and the AMR is validated against it.

    Parameters
    ----------
    amr_json : dict
        AMR as produced by ``AskeNetPetriNetModel.to_json()``.

    Raises
    ------
    AssertionError
        If the AMR has no ``"schema"`` entry.
    requests.HTTPError
        If the schema download returns an error status.
    jsonschema.ValidationError
        If the AMR does not conform to the schema.
    """
    # Imported locally so this cell is self-contained.
    import jsonschema
    import requests

    assert "schema" in amr_json, "AMR has no 'schema' URL to validate against"
    response = requests.get(amr_json["schema"], timeout=30)
    response.raise_for_status()
    schema_json = response.json()
    # jsonschema.validate(instance, schema): the AMR is the instance being
    # checked. The previous positional call had the two swapped and therefore
    # validated the schema document against the AMR.
    jsonschema.validate(instance=amr_json, schema=schema_json)

### Define units

In [3]:
def person_units():
    """Return a new ``Unit`` for head counts: person."""
    return Unit(expression=sympy.Symbol('person'))


def virus_units():
    """Return a new ``Unit``: virus."""
    return Unit(expression=sympy.Symbol('virus'))


def virus_per_gram_units():
    """Return a new ``Unit``: virus / gram."""
    return Unit(expression=sympy.Symbol('virus') / sympy.Symbol('gram'))


def day_units():
    """Return a new ``Unit`` for durations: day."""
    return Unit(expression=sympy.Symbol('day'))


def per_day_units():
    """Return a new ``Unit`` for rates: 1 / day."""
    return Unit(expression=1 / sympy.Symbol('day'))


def dimensionless_units():
    """Return a new ``Unit`` for pure numbers (proportions, probabilities): 1."""
    return Unit(expression=sympy.Integer('1'))


def gram_units():
    """Return a new ``Unit`` for mass: gram."""
    return Unit(expression=sympy.Symbol('gram'))


def per_day_per_person_units():
    """Return a new ``Unit``: 1 / (day * person)."""
    return Unit(expression=1 / (sympy.Symbol('day') * sympy.Symbol('person')))

# Create base-case SEIRHD model
**Model state variables**: Susceptible $(S)$, Exposed $(E)$, Infectious $(I)$, Recovered $(R)$, Hospitalized $(H)$, and Deceased $(D)$.

**Model equations**: \begin{align*} \frac{dS}{dt} &= -\frac{\beta}{N}SI \\
\frac{dE}{dt} &= \frac{\beta}{N}SI - \delta E \\
\frac{dI}{dt} &= \delta E - \gamma I \\
\frac{dR}{dt} &= \gamma (1 - \eta) I + \frac{1 - \mu}{los} H \\
\frac{dH}{dt} &= \gamma \eta I - \frac{1}{los}H \\
\frac{dD}{dt} &= \frac{\mu}{los}H 
\end{align*}

**Parameters**: $N$ total population, $\beta$ transmission rate, $\delta$ rate of progression from exposed to infectious (the inverse of the latency period), $\gamma$ rate at which infectious individuals leave $I$ (to $R$ or $H$), $\eta$ proportion of infectious individuals who are hospitalized, $\mu$ proportion of hospitalized individuals who die (model assumption: all deaths attributed to COVID-19 occur in hospital), $los$ average length of hospital stay.

## Define model state variables, parameters, and initial conditions

In [None]:
# Base SEIRHD compartments, all counted in persons. `identifiers` hold
# ontology CURIEs; H shares I's identifier and is distinguished from it only
# by its `property` context.
BASE_CONCEPTS = dict(
    S=Concept(name="S", identifiers={"ido": "0000514"}, units=person_units()),
    E=Concept(name="E", identifiers={"apollosv": "0000154"}, units=person_units()),
    I=Concept(name="I", identifiers={"ido": "0000511"}, units=person_units()),
    R=Concept(name="R", identifiers={"ido": "0000592"}, units=person_units()),
    H=Concept(
        name="H",
        identifiers={"ido": "0000511"},
        context={"property": "ncit:C25179"},
        units=person_units(),
    ),
    D=Concept(name="D", identifiers={"ncit": "C28554"}, units=person_units()),
)

# Total population (the sum of the four NYS age groups used for the age
# stratification below) and the initial numbers of exposed and infectious.
total_population = 19_340_000.0
E0 = 1.0
I0 = 4.0

# Base-model parameters. beta, delta and gamma are per-day rates; los is a
# duration in days. eta and mu are proportions (they appear as `1 - eta` and
# `1 - mu` in the rate laws), so they are dimensionless, not per-day.
BASE_PARAMETERS = {
    'N': Parameter(name='N', value=total_population, units=person_units()),
    'beta': Parameter(name='beta', value=0.4, units=per_day_units(),
                      distribution=Distribution(type='Uniform1',
                                                parameters={
                                                    'minimum': 0.05,
                                                    'maximum': 0.8
                                                })),
    'delta': Parameter(name='delta', value=0.25, units=per_day_units()),
    'gamma': Parameter(name='gamma', value=0.2, units=per_day_units(),
                       distribution=Distribution(type='Uniform1',
                                                 parameters={
                                                     'minimum': 0.1,
                                                     'maximum': 0.5
                                                 })),
    'eta': Parameter(name='eta', value=0.1, units=dimensionless_units(),
                     distribution=Distribution(type='Uniform1',
                                               parameters={
                                                   'minimum': 0.005,
                                                   'maximum': 0.2
                                               })),
    'mu': Parameter(name='mu', value=0.07, units=dimensionless_units(),
                    distribution=Distribution(type='Uniform1',
                                              parameters={
                                                  'minimum': 0.001,
                                                  'maximum': 0.1
                                              })),
    'los': Parameter(name='los', value=5, units=day_units()),
}

# Everyone who is not initially exposed or infectious starts susceptible.
BASE_INITIALS = {
    "S": Initial(concept=Concept(name="S"), value=total_population - (E0 + I0)),
    "E": Initial(concept=Concept(name="E"), value=E0),
    "I": Initial(concept=Concept(name="I"), value=I0),
    "R": Initial(concept=Concept(name="R"), value=0),
    "H": Initial(concept=Concept(name="H"), value=0),
    "D": Initial(concept=Concept(name="D"), value=0),
}

observables = {}

## Define rate laws

In [None]:
# Sympy symbols for the compartments and parameters used in the rate laws.
(S, E, I, R, H, D,
 N, beta, delta, gamma, eta, mu, los) = sympy.symbols(
    'S E I R H D N beta delta gamma eta mu los'
)

# Infection: S -> E, catalysed by I, at rate beta * S * I / N.
t1 = ControlledConversion(
    subject=BASE_CONCEPTS['S'],
    controller=BASE_CONCEPTS['I'],
    outcome=BASE_CONCEPTS['E'],
    rate_law=beta * S * I / N,
)

# Progression: E -> I at rate delta * E.
t2 = NaturalConversion(
    subject=BASE_CONCEPTS['E'],
    outcome=BASE_CONCEPTS['I'],
    rate_law=delta * E,
)

# Recovery without hospitalization: I -> R at rate gamma * (1 - eta) * I.
t3 = NaturalConversion(
    subject=BASE_CONCEPTS['I'],
    outcome=BASE_CONCEPTS['R'],
    rate_law=gamma * I * (1 - eta),
)

# Hospitalization: I -> H at rate gamma * eta * I.
t4 = NaturalConversion(
    subject=BASE_CONCEPTS['I'],
    outcome=BASE_CONCEPTS['H'],
    rate_law=eta * gamma * I,
)

# Discharge: H -> R at rate (1 - mu) * H / los.
t5 = NaturalConversion(
    subject=BASE_CONCEPTS['H'],
    outcome=BASE_CONCEPTS['R'],
    rate_law=(1 - mu) * H / los,
)

# Death in hospital: H -> D at rate mu * H / los.
t6 = NaturalConversion(
    subject=BASE_CONCEPTS['H'],
    outcome=BASE_CONCEPTS['D'],
    rate_law=mu * H / los,
)

## Produce ASKEM Model Representation

In [None]:
# Assemble the base SEIRHD template model, convert it to an ASKEM Petri-net
# AMR, validate it and write it to disk.
templates = [t1, t2, t3, t4, t5, t6]
tm = TemplateModel(
    templates=templates,
    parameters=BASE_PARAMETERS,
    initials=BASE_INITIALS,
    observables=observables,
    time=Time(name='t', units=day_units()),
    annotations=Annotations(name='SEIRHD base model EE'),
)

am = AskeNetPetriNetModel(Model(tm))
amr_json = am.to_json()
sanity_check_amr(amr_json)
am.to_json_file('SEIRHD_base_model_ee.json')

# Create SEIRHD model with NPI intervention type 1
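Following [Tay et al.](https://www.mdpi.com/1660-4601/18/17/9027), we simulate compulsory social distancing and/or mask wearing by letting the transmission rate $\beta$ vary in time according to $\beta (t) = \kappa m(t)$, where $\displaystyle m(t) = \beta_c + \frac{\beta_s - \beta_c}{1 + e^{-k(t_0 - t)}}.$ For $k > 0$, $m(t) \approx \beta_s$ well before $t_0$ and $m(t) \approx \beta_c$ well after it, so $t_0$ sets when the intervention takes effect and $k$ how abruptly.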

In [None]:
# Parameters of the NPI type 1 model: beta is replaced by kappa * m(t), where
# m(t) switches logistically between beta_s and beta_c around day t_0 with
# steepness k. eta and mu are proportions and therefore dimensionless.
parameters_alt_1 = {
    'N': Parameter(name='N', value=total_population, units=person_units()),
    'kappa': Parameter(name='kappa', value=0.4, units=per_day_units(),
                       distribution=Distribution(type='Uniform1',
                                                 parameters={
                                                     'minimum': 0.05,
                                                     'maximum': 0.8
                                                 })),
    'delta': Parameter(name='delta', value=0.25, units=per_day_units()),
    'gamma': Parameter(name='gamma', value=0.2, units=per_day_units(),
                       distribution=Distribution(type='Uniform1',
                                                 parameters={
                                                     'minimum': 0.1,
                                                     'maximum': 0.5
                                                 })),
    'eta': Parameter(name='eta', value=0.1, units=dimensionless_units(),
                     distribution=Distribution(type='Uniform1',
                                               parameters={
                                                   'minimum': 0.005,
                                                   'maximum': 0.2
                                               })),
    'mu': Parameter(name='mu', value=0.07, units=dimensionless_units(),
                    distribution=Distribution(type='Uniform1',
                                              parameters={
                                                  'minimum': 0.001,
                                                  'maximum': 0.1
                                              })),
    'los': Parameter(name='los', value=5, units=day_units()),
    'beta_s': Parameter(name='beta_s', value=1, units=per_day_units()),
    # Fixed-beta_c alternative:
    # 'beta_c': Parameter(name='beta_c', value=0.4, units=per_day_units()),
    # The name must match both the dict key and the `beta_c` symbol used in
    # the rate law; it was previously 'beta', which matched neither.
    'beta_c': Parameter(name='beta_c', value=0.4, units=per_day_units(),
                        distribution=Distribution(type='Uniform1',
                                                  parameters={
                                                      'minimum': 0.1,
                                                      'maximum': 0.8
                                                  })),
    't_0': Parameter(name='t_0', value=89, units=day_units()),
    'k': Parameter(name='k', value=5.0, units=dimensionless_units()),
    # Varying-k alternative:
    # 'k': Parameter(name='k', value=10.0, units=dimensionless_units(),
    #                distribution=Distribution(type='Uniform1',
    #                                          parameters={
    #                                              'minimum': 1.0,
    #                                              'maximum': 20.0
    #                                          })),
}

# Symbols for the NPI type 1 time-varying transmission rate.
beta_s, beta_c, t_0, kappa, k, t = sympy.symbols('beta_s beta_c t_0 kappa k t')

# Logistic switch m(t): for k > 0 it is ~beta_s well before t_0 and ~beta_c
# well after t_0.
m_1 = beta_c + (beta_s - beta_c) / (1 + sympy.exp(-k*(t_0-t)))
beta_time_varying = kappa*m_1

# Same as the base infection template t1, with beta replaced by kappa * m(t).
t1_alt_1 = ControlledConversion(
    subject=BASE_CONCEPTS['S'],
    controller=BASE_CONCEPTS['I'],
    outcome=BASE_CONCEPTS['E'],
    rate_law=beta_time_varying * S * I / N,
)

# NPI type 1 model: the base model with t1 swapped for the time-varying
# infection template t1_alt_1; the remaining templates are shared.
tm_mask_1 = TemplateModel(
    templates=[t1_alt_1, t2, t3, t4, t5, t6],
    parameters=parameters_alt_1,
    initials=BASE_INITIALS,
    observables=observables,
    time=Time(name='t', units=day_units()),
    # Keep this name in sync with the output filename below. Alternative
    # scenarios (previously commented out):
    #   'SEIRHD model with NPI type 1'           -> SEIRHD_npi1_ee.json
    #   'SEIRHD model with NPI type 1, k varies' -> SEIRHD_npi1_k_varying_ee.json
    annotations=Annotations(name='SEIRHD model with NPI type 1, beta_c varies'),
)

am = AskeNetPetriNetModel(Model(tm_mask_1))
amr_json = am.to_json()
sanity_check_amr(amr_json)
am.to_json_file('SEIRHD_npi1_beta_c_varying_ee.json')

# Stratify models into four age groups
Age groups: a0 (0 - 19), a1 (20 - 49), a2 (50 - 64), and a3 (65+)

In [None]:
# Age strata a0 (0-19), a1 (20-49), a2 (50-64), a3 (65+). Concept and
# parameter names use the stratum position ("0".."3"), not these labels.
AGES = ["a0", "a1", "a2", "a3"]

# A single stratification axis: age.
ORDER = ["age"]
STRATA = dict(age=AGES)

# Every compartment is split by age ...
stratification_config = {compartment: ["age"] for compartment in "SEIRHD"}

# ... and so are the age-dependent clinical parameters.
param_stratification_config = {name: ["age"] for name in ("eta", "mu", "los")}

templates = list()

## Stratify concepts

In [None]:
# Build one copy of each base concept per stratum. Concepts are keyed by
# (base name, *stratum indices in ORDER) and renamed e.g. "S" -> "S_0"; the
# stratum index (position in STRATA[axis], as a string) goes into the context.
concepts = {}
for concept, labels in stratification_config.items():
    index_ranges = [range(len(STRATA[label])) for label in labels]
    for combo in itt.product(*index_ranges):
        d = {label: str(position) for label, position in zip(labels, combo)}
        idx = tuple(d[axis] for axis in ORDER if axis in d)
        stratified = _d(BASE_CONCEPTS[concept]).with_context(**d, do_rename=False)
        stratified.name = f"{stratified.name}_{'_'.join(idx)}"
        concepts[(concept, *idx)] = stratified

# list(concepts.items())

## Stratify initial conditions

In [None]:
# Initial conditions for the four age groups. Susceptibles are seeded with the
# NYS population of each age group minus that group's share of the initial E
# and I. E0 and I0 are spread evenly across age groups, and all other
# compartments start empty.
# NYS population by age groups: a0 (0 - 19), a1 (20 - 49), a2 (50 - 64), and a3 (65+)
# Source: https://www.health.ny.gov/statistics/vital_statistics/2016/table01.htm
new_york_state_population_by_age = [4597368.0, 7893559.0, 3878809.0, 2970264.0]
assert len(new_york_state_population_by_age) == len(AGES)
assert sum(new_york_state_population_by_age) == total_population

num_strata = len(AGES)

# Values are looked up by (compartment, age index) rather than by position in
# a hand-written 24-element list, so they cannot drift out of step with the
# iteration order of `concepts`.
initials = {}
for key, concept in concepts.items():
    base_name = key[0]
    age_index = int(concept.context["age"])
    if base_name == "S":
        value = new_york_state_population_by_age[age_index] - (E0 + I0) / num_strata
    elif base_name == "E":
        value = E0 / num_strata
    elif base_name == "I":
        value = I0 / num_strata
    else:
        value = 0.0
    initials[concept.name] = Initial(concept=concept, value=value)

# initials

In [None]:
# Group the stratified concepts by their base compartment name ("S", "E", ...).
concept_to_strata = defaultdict(list)
for (base_name, *strata_indices), stratified in concepts.items():
    concept_to_strata[base_name].append(stratified)


def concept_strata_prod(*variables: str):
    """Yield every combination of strata, one concept per requested compartment.

    ``concept_strata_prod("S", "E")`` yields ``(S_x, E_y)`` for every pair of
    S and E strata, in the order they were added to ``concept_to_strata``.
    """
    pools = [concept_to_strata[variable] for variable in variables]
    yield from itt.product(*pools)

# concepts

In [None]:
# Parameters of the age-stratified NPI type 1 model. Unstratified parameters
# are listed by name; eta, mu and los are then copied once per age group from
# `parameters_alt_1` and stored under tuple keys such as ("eta", "0").
parameters = {
    'N': Parameter(name='N', value=total_population, units=person_units()),
    'kappa': Parameter(name='kappa', value=0.4, units=per_day_units(),
                       distribution=Distribution(type='Uniform1',
                                                 parameters={
                                                     'minimum': 0.05,
                                                     'maximum': 0.8
                                                 })),
    'delta': Parameter(name='delta', value=0.25, units=per_day_units()),
    'gamma': Parameter(name='gamma', value=0.2, units=per_day_units(),
                       distribution=Distribution(type='Uniform1',
                                                 parameters={
                                                     'minimum': 0.1,
                                                     'maximum': 0.5
                                                 })),
    'beta_s': Parameter(name='beta_s', value=1, units=per_day_units()),
    'beta_c': Parameter(name='beta_c', value=0.4, units=per_day_units()),
    # Varying-beta_c alternative:
    # 'beta_c': Parameter(name='beta_c', value=0.4, units=per_day_units(),
    #                     distribution=Distribution(type='Uniform1',
    #                                               parameters={
    #                                                   'minimum': 0.1,
    #                                                   'maximum': 0.8
    #                                               })),
    't_0': Parameter(name='t_0', value=89, units=day_units()),
    'k': Parameter(name='k', value=5.0, units=dimensionless_units()),
    # Varying-k alternative:
    # 'k': Parameter(name='k', value=10.0, units=dimensionless_units(),
    #                distribution=Distribution(type='Uniform1',
    #                                          parameters={
    #                                              'minimum': 1.0,
    #                                              'maximum': 20.0
    #                                          })),
}

for parameter, labels in param_stratification_config.items():
    for keys in itt.product(
        *(zip(itt.repeat(label), enumerate(STRATA[label])) for label in labels)
    ):
        d = defaultdict(list)
        for key, (idx, label) in keys:
            d[key].append(str(idx))
        d = dict(d)
        # Stratum indices in ORDER, e.g. ("0",) for the first age group.
        idx = tuple(itt.chain.from_iterable(d[k] for k in ORDER if k in d))
        p = _d(parameters_alt_1[parameter])
        p.name = f"{p.name}_" + "_".join(idx)
        parameters[(parameter, *idx)] = p

# Group parameters by base name. Unstratified entries are keyed by the plain
# name string, and indexing a string key would take its first character
# ('kappa'[0] == 'k'), so only tuple keys are indexed.
parameter_to_strata = defaultdict(list)
for key, parameter in parameters.items():
    base_name = key[0] if isinstance(key, tuple) else key
    parameter_to_strata[base_name].append(parameter)

display(parameters)

In [None]:
def context_idx(concept):
    """Return the concept's stratum indices, ordered as in its stratification config.

    The base compartment is the part of ``concept.name`` before the first
    underscore (``"S_0"`` -> ``"S"``).
    """
    base_name = concept.name.partition("_")[0]
    axes = stratification_config[base_name]
    return tuple([concept.context[axis] for axis in axes])

context_idx(concepts[("S", "0")])

In [None]:
def c_symbol(concept):
    """Return the sympy symbol named after the (stratified) concept, e.g. ``S_0``."""
    symbol_name = concept.name
    return sympy.Symbol(symbol_name)

c_symbol(concepts[("S", "0")])

In [None]:
def force_of_infection_component(e, i):
    """Per-susceptible infection rate contributed by the infectious stratum ``i``.

    Evaluates ``kappa * m(t) * I_i / N`` with the NPI type 1 logistic switch
    ``m(t) = beta_c + (beta_s - beta_c) / (1 + exp(-k (t_0 - t)))``.
    ``e`` is unused here (the later variant-stratified version reads its context).
    """
    sym = sympy.Symbol
    switch = sym("beta_c") + (sym("beta_s") - sym("beta_c")) / (
        1 + sympy.exp(-sym("k") * (sym("t_0") - sym("t")))
    )
    return sym("kappa") * switch * c_symbol(i) / sym("N")

In [None]:
templates = list()  # fresh template list for the age-stratified model

In [None]:
def not_conserved(c1, c2):
    """Return True if ``c1`` and ``c2`` disagree on a shared stratification axis.

    Only axes from ORDER that appear in *both* contexts are compared, so an
    axis present on just one of the concepts never counts as a mismatch.
    """
    for key in ORDER:
        if key not in c1.context or key not in c2.context:
            continue
        if c1.context[key] != c2.context[key]:
            return True
    return False


# Infection: S_a -> E_a (same age group a), catalysed by I_b for every age
# group b; the per-template rate is S_a * force_of_infection_component(E_a, I_b).
for s, e, i in tqdm(concept_strata_prod("S", "E", "I"), unit_scale=True):
    if not_conserved(s, e):
        continue
    templates.append(
        ControlledConversion(
            subject=s,
            outcome=e,
            controller=i,
            rate_law=c_symbol(s) * force_of_infection_component(e, i),
        )
    )

In [None]:
# E -> I progression within each age group at rate delta * E.
# The template is bound to `progression`, not `t2`, so this loop does not
# overwrite the module-level base-model template t2 used by the NPI cell.
for e, i in concept_strata_prod("E", "I"):
    if not_conserved(e, i):
        continue
    progression = NaturalConversion(
        subject=e,
        outcome=i,
        rate_law=sympy.Symbol(parameters["delta"].name) * c_symbol(e),
    )
    templates.append(progression)

In [None]:
# I -> R recovery within each age group at rate gamma * (1 - eta_a) * I, with
# the age-specific eta of the outcome's age group. Bound to `recovery`, not
# `t3`, so the base-model template t3 is not overwritten.
for i, r in concept_strata_prod("I", "R"):
    if not_conserved(i, r):
        continue
    recovery = NaturalConversion(
        subject=i,
        outcome=r,
        rate_law=(
            sympy.Symbol(parameters["gamma"].name)
            * (1 - sympy.Symbol(parameters[("eta", *context_idx(r))].name))
            * c_symbol(i)
        ),
    )
    templates.append(recovery)

In [None]:
# I -> H hospitalization within each age group at rate gamma * eta_a * I.
# Bound to `hospitalization`, not `t4`, so the base-model template t4 is not
# overwritten.
for i, h in concept_strata_prod("I", "H"):
    if not_conserved(i, h):
        continue
    hospitalization = NaturalConversion(
        subject=i,
        outcome=h,
        rate_law=(
            sympy.Symbol(parameters["gamma"].name)
            * sympy.Symbol(parameters[("eta", *context_idx(h))].name)
            * c_symbol(i)
        ),
    )
    templates.append(hospitalization)

In [None]:
# H -> R discharge within each age group at rate (1 - mu_a) * H / los_a, with
# age-specific mu and length of stay los. Bound to `discharge`, not `t5`, so
# the base-model template t5 is not overwritten.
for h, r in concept_strata_prod("H", "R"):
    if not_conserved(h, r):
        continue
    discharge = NaturalConversion(
        subject=h,
        outcome=r,
        rate_law=(
            (1 - sympy.Symbol(parameters[("mu", *context_idx(r))].name))
            * c_symbol(h)
            / sympy.Symbol(parameters[("los", r.context["age"])].name)
        ),
    )
    templates.append(discharge)

In [None]:
# H -> D death in hospital within each age group at rate mu_a * H / los_a.
# Bound to `death`, not `t6`, so the base-model template t6 is not overwritten.
for h, d in concept_strata_prod("H", "D"):
    if not_conserved(h, d):
        continue
    death = NaturalConversion(
        subject=h,
        outcome=d,
        rate_law=(
            sympy.Symbol(parameters[("mu", *context_idx(d))].name)
            * c_symbol(h)
            / sympy.Symbol(parameters[("los", d.context["age"])].name)
        ),
    )
    templates.append(death)

In [None]:
# Assemble, validate and export the age-stratified NPI type 1 model.
model = TemplateModel(
    templates=templates,
    parameters={p.name: p for p in parameters.values()},
    initials=initials,
    time=Time(name="t", units=day_units()),
    observables=observables,
    annotations=Annotations(
        name="SEIRHD NPI Type 1 Age-Stratified",
        # A plain string: the previous indented triple-quoted literal leaked
        # leading spaces and a trailing newline into the exported description.
        description=(
            "This model has been stratified into four age groups: "
            "(0) for 0-19, (1) for 20-49, (2) for 50-64, and (3) for 65+"
        ),
    ),
)
am = AskeNetPetriNetModel(Model(model))
sanity_check_amr(am.to_json())
am.to_json_file("SEIRHD_npi1_age_stratified_v1.json")

# Stratify by age, vaccination status, and variant

In [92]:
# Stratum labels. Downstream code refers to strata by position ("0", "1", ...),
# not by these strings. "65" stands for the 65+ age group.
STATUSES = ["unvaccinated", "vaccinated"]
AGES = ["0-19", "20-49", "50-64", "65"]
VARIANTS = ["wild", "delta", "omicron"]

In [93]:
# Base compartments for the age x variant x vaccination model; identical to
# the SEIRHD base concepts above. All are counted in persons, and H differs
# from I only by its `property` context.
BASE_CONCEPTS = dict(
    S=Concept(name="S", identifiers={"ido": "0000514"}, units=person_units()),
    E=Concept(name="E", identifiers={"apollosv": "0000154"}, units=person_units()),
    I=Concept(name="I", identifiers={"ido": "0000511"}, units=person_units()),
    R=Concept(name="R", identifiers={"ido": "0000592"}, units=person_units()),
    H=Concept(
        name="H",
        identifiers={"ido": "0000511"},
        context={"property": "ncit:C25179"},
        units=person_units(),
    ),
    D=Concept(name="D", identifiers={"ncit": "C28554"}, units=person_units()),
)


# Total population and initial seed sizes.
N_val = 19_340_000
E_val = 1
I_val = 4
R_0 = 2.6
gamma_val = 1 / 5

BASE_PARAMETERS = {
    "gamma": Parameter(name="gamma", value=gamma_val, units=per_day_units()),
    # Vaccination rate; constant for now, but later we want this to be a
    # time-varying function.
    "alpha": Parameter(
        name="alpha", value=10_000 / N_val, units=per_day_units(),
        distribution=Distribution(type='Uniform1',
                                  parameters={
                                      'minimum': 10 / N_val,
                                      'maximum': 12_000 / N_val
                                  })),
    "eta": Parameter(name="eta", value=0.1, units=dimensionless_units()),
    "mu": Parameter(name="mu", value=0.003, units=dimensionless_units(),
                    distribution=Distribution(type='Uniform1',
                                              parameters={
                                                  'minimum': 0.001,
                                                  'maximum': 0.06
                                              })),
    # Average time to recovery (duration of hospital stay if they recover).
    "lr": Parameter(name="lr", value=5, units=day_units()),
    # Average time to death (duration of hospital stay if they die).
    "ld": Parameter(name="ld", value=9.25, units=day_units()),
    # Rate of waning immunity after recovery (1 / 365 per day, i.e. immunity
    # lasts about a year). Previously named "ld", which clashed with the
    # hospital-stay parameter, and given units of days although it is a rate.
    "limm": Parameter(name="limm", value=1 / 365, units=per_day_units()),
    "rho": Parameter(name="rho", value=1 / 2, units=per_day_units()),
    # Transmission probability.
    "q": Parameter(
        name="q", value=R_0 * gamma_val, units=dimensionless_units(),
        distribution=Distribution(type='Uniform1',
                                  parameters={
                                      'minimum': 0.5,
                                      'maximum': 0.6
                                  })),
    # Scaled contact rate.
    "d": Parameter(
        name="d", value=1 / N_val, units=per_day_per_person_units()
    ),
    # Host susceptibility.
    "phi": Parameter(
        name="phi", value=1.0, units=dimensionless_units(),
        distribution=Distribution(type='Uniform1',
                                  parameters={
                                      'minimum': 0.05,
                                      'maximum': 1.0
                                  })),
    # Relative transmissibility of variant.
    "chi": Parameter(
        name="chi", value=1.0, units=dimensionless_units(),
        distribution=Distribution(type='Uniform1',
                                  parameters={
                                      'minimum': 1.0,
                                      'maximum': 2.5
                                  })),
}

BASE_INITIALS = {
    "S": Initial(concept=Concept(name="S"), value=N_val - (E_val + I_val)),
    "E": Initial(concept=Concept(name="E"), value=E_val),
    "I": Initial(concept=Concept(name="I"), value=I_val),
    "R": Initial(concept=Concept(name="R"), value=0),
    "H": Initial(concept=Concept(name="H"), value=0),
    "D": Initial(concept=Concept(name="D"), value=0),
}

observables = {}

In [94]:
# Stratification axes, in the order used for index tuples and concept names
# (e.g. "E_<age>_<variant>_<status>").
ORDER = ["age", "variant", "status"]
STRATA = dict(age=AGES, variant=VARIANTS, status=STATUSES)

# Axes per compartment: S has no variant axis; every other compartment is
# split by age, variant and vaccination status.
stratification_config = {"S": ["age", "status"]}
stratification_config.update(
    {compartment: ["age", "variant", "status"] for compartment in "EIRHD"}
)

# Axes per parameter. "d" is stratified by age twice, giving one contact
# parameter per (age of exposed, age of infectious) pair, e.g. d_0_1.
param_stratification_config = {
    "d": ["age", "age"],
    "eta": ["age", "variant", "status"],
    "mu": ["age", "variant", "status"],
    "ld": ["age"],
    "lr": ["age"],
    "phi": ["status"],
    "chi": ["variant"],
    "alpha": ["age"],  # vaccination rate
}

templates = list()

In [95]:
# Index every stratified concept by (base name, *stratum indices in ORDER),
# e.g. ("E", "3", "2", "0") -> concept "E_3_2_0" with context
# {"age": "3", "variant": "2", "status": "0"}.
concepts = {}
for concept, labels in stratification_config.items():
    for combo in itt.product(*(range(len(STRATA[label])) for label in labels)):
        d = dict(zip(labels, map(str, combo)))
        idx = tuple(d[axis] for axis in ORDER if axis in d)
        concept_copy = _d(BASE_CONCEPTS[concept]).with_context(**d, do_rename=False)
        concept_copy.name = f"{concept_copy.name}_{'_'.join(idx)}"
        concepts[(concept, *idx)] = concept_copy

list(concepts.items())[30:35]

[(('E', '3', '2', '0'),
  Concept(name='E_3_2_0', display_name=None, description=None, identifiers={'apollosv': '0000154'}, context={'age': '3', 'variant': '2', 'status': '0'}, units=Unit(expression=person))),
 (('E', '3', '2', '1'),
  Concept(name='E_3_2_1', display_name=None, description=None, identifiers={'apollosv': '0000154'}, context={'age': '3', 'variant': '2', 'status': '1'}, units=Unit(expression=person))),
 (('I', '0', '0', '0'),
  Concept(name='I_0_0_0', display_name=None, description=None, identifiers={'ido': '0000511'}, context={'age': '0', 'variant': '0', 'status': '0'}, units=Unit(expression=person))),
 (('I', '0', '0', '1'),
  Concept(name='I_0_0_1', display_name=None, description=None, identifiers={'ido': '0000511'}, context={'age': '0', 'variant': '0', 'status': '1'}, units=Unit(expression=person))),
 (('I', '0', '1', '0'),
  Concept(name='I_0_1_0', display_name=None, description=None, identifiers={'ido': '0000511'}, context={'age': '0', 'variant': '1', 'status': '0'}

In [96]:
# Seed every stratum with an equal share of its base initial condition.
# Previously each stratum received the *full* base value, so e.g. each of the
# 8 S strata (4 ages x 2 statuses) held the whole population and the model's
# total population was multiplied by the number of strata.
# TODO: replace the even split with real age / vaccination / variant shares.
strata_counts = defaultdict(int)
for key in concepts:
    strata_counts[key[0]] += 1

initials = {}
for key, concept in concepts.items():
    base_name = key[0]
    initials[concept.name] = Initial(
        concept=concept,
        value=BASE_INITIALS[base_name].value / strata_counts[base_name],
    )

In [97]:
# Map each base compartment name to its list of stratified concepts.
concept_to_strata = defaultdict(list)
for key in concepts:
    concept_to_strata[key[0]].append(concepts[key])


def concept_strata_prod(*variables: str):
    """Yield the Cartesian product of the strata of the given compartments.

    ``concept_strata_prod("E", "I")`` yields every ``(E_x, I_y)`` pair; callers
    filter out pairs whose shared strata disagree.
    """
    pools = [concept_to_strata[v] for v in variables]
    for combination in itt.product(*pools):
        yield combination

In [98]:
# Index all possible parameters: unstratified ones by their plain name, and
# stratified ones by (base name, *stratum indices in ORDER), e.g.
# ("d", "0", "1") -> parameter "d_0_1".
parameters = {
    "rho": Parameter(name="rho", value=1 / 2, units=per_day_units()),
    "q": Parameter(
        name="q", value=R_0 * gamma_val, units=dimensionless_units()
    ),  # transmission probability
    "gamma": Parameter(name="gamma", value=gamma_val, units=per_day_units()),
    "limm": Parameter(name="limm", value=1 / 365, units=per_day_units())
}
for parameter, labels in param_stratification_config.items():
    for keys in itt.product(
        *(zip(itt.repeat(label), enumerate(STRATA[label])) for label in labels)
    ):
        d = defaultdict(list)
        for key, (idx, label) in keys:
            d[key].append(str(idx))
        d = dict(d)
        # Repeated axes (d's ["age", "age"]) contribute one index each.
        idx = tuple(itt.chain.from_iterable(d[k] for k in ORDER if k in d))
        p = _d(BASE_PARAMETERS[parameter])
        p.name = f"{p.name}_" + "_".join(idx)
        parameters[(parameter, *idx)] = p

# Group parameters by base name. Unstratified entries are keyed by the plain
# name string, and indexing a string key would take its first character
# ('limm'[0] == 'l'), so only tuple keys are indexed.
parameter_to_strata = defaultdict(list)
for key, parameter in parameters.items():
    base_name = key[0] if isinstance(key, tuple) else key
    parameter_to_strata[base_name].append(parameter)

# parameters.values()

In [99]:
def context_idx(concept):
    """Return the stratum indices of ``concept`` in the order of its stratification axes.

    For example, ``S_0_1`` (axes age, status) gives ``("0", "1")``.
    """
    base_name = concept.name.split("_", 1)[0]
    return tuple([concept.context[axis] for axis in stratification_config[base_name]])


context_idx(concepts[("S", "0", "0")])

('0', '0')

In [100]:
def c_symbol(concept):
    """Return the sympy symbol whose name is the stratified concept name, e.g. ``S_0_0``."""
    name = concept.name
    symbol = sympy.Symbol(name)
    return symbol


c_symbol(concepts[("S", "0", "0")])

S_0_0

In [101]:
def force_of_infection_component(e, i):
    """Per-susceptible infection rate from infectious stratum ``i`` into ``e``.

    Returns ``q * chi_<variant> * phi_<status> * d_<age of e>_<age of i> * I_i``.
    The variant is read from ``e``, which is valid because the caller only
    pairs ``e`` and ``i`` with matching variants. The status is read from
    ``e``, which shares it with the susceptible subject because the caller
    filters with ``not_conserved``.
    """
    variant = e.context["variant"]
    status = e.context["status"]
    contact = "d_{}_{}".format(e.context["age"], i.context["age"])
    return (
        sympy.Symbol("q")
        * sympy.Symbol(f"chi_{variant}")
        * sympy.Symbol(f"phi_{status}")
        * sympy.Symbol(contact)
        * c_symbol(i)
    )

In [102]:
templates = list()  # fresh template list for the age x variant x status model

In [103]:
def not_conserved(c1, c2):
    """Return True if the two concepts disagree on any ORDER axis they both carry."""
    shared = [key for key in ORDER if key in c1.context and key in c2.context]
    return any(c1.context[key] != c2.context[key] for key in shared)


# Infection: S -> E catalysed by I. Subject and outcome must share age and
# vaccination status; the infector must carry the same variant as the outcome.
for s, e, i in tqdm(concept_strata_prod("S", "E", "I"), unit_scale=True):
    same_strata = not not_conserved(s, e)
    same_variant = e.context["variant"] == i.context["variant"]
    if same_strata and same_variant:
        templates.append(
            ControlledConversion(
                subject=s,
                outcome=e,
                controller=i,
                rate_law=c_symbol(s) * force_of_infection_component(e, i),
            )
        )

0.00it [00:00, ?it/s]

In [104]:
# E -> I progression within each (age, variant, status) stratum at rate rho * E.
# Bound to `progression`, not `t2`, so the module-level base-model template
# t2 is not overwritten.
for e, i in concept_strata_prod("E", "I"):
    if not_conserved(e, i):
        continue
    progression = NaturalConversion(
        subject=e,
        outcome=i,
        rate_law=sympy.Symbol(BASE_PARAMETERS["rho"].name) * c_symbol(e),
    )
    templates.append(progression)

In [105]:
# I -> R recovery at rate gamma * (1 - eta_s) * I, where eta_s is the
# hospitalization proportion of the outcome's (age, variant, status) stratum.
# Bound to `recovery`, not `t3`, so the base-model template t3 is not overwritten.
for i, r in concept_strata_prod("I", "R"):
    if not_conserved(i, r):
        continue
    recovery = NaturalConversion(
        subject=i,
        outcome=r,
        rate_law=(
            sympy.Symbol(BASE_PARAMETERS["gamma"].name)
            * (1 - sympy.Symbol(parameters[("eta", *context_idx(r))].name))
            * c_symbol(i)
        ),
    )
    templates.append(recovery)

In [106]:
# I -> H hospitalization at rate gamma * eta_s * I for each stratum s.
# Bound to `hospitalization`, not `t4`, so the base-model template t4 is not
# overwritten.
for i, h in concept_strata_prod("I", "H"):
    if not_conserved(i, h):
        continue
    hospitalization = NaturalConversion(
        subject=i,
        outcome=h,
        rate_law=(
            sympy.Symbol(BASE_PARAMETERS["gamma"].name)
            * sympy.Symbol(parameters[("eta", *context_idx(h))].name)
            * c_symbol(i)
        ),
    )
    templates.append(hospitalization)

In [107]:
# H -> R discharge at rate (1 - mu_s) * H / lr_a, with a stratum-specific
# mu and an age-specific recovery stay lr. Bound to `discharge`, not `t5`, so
# the base-model template t5 is not overwritten.
for h, r in concept_strata_prod("H", "R"):
    if not_conserved(h, r):
        continue
    discharge = NaturalConversion(
        subject=h,
        outcome=r,
        rate_law=(
            (1 - sympy.Symbol(parameters[("mu", *context_idx(r))].name))
            * c_symbol(h)
            / sympy.Symbol(parameters[("lr", r.context["age"])].name)
        ),
    )
    templates.append(discharge)

In [108]:
# H -> D death in hospital at rate mu_s * H / ld_a, with a stratum-specific mu
# and an age-specific time-to-death ld. Bound to `death`, not `t6`, so the
# base-model template t6 is not overwritten.
for h, d in concept_strata_prod("H", "D"):
    if not_conserved(h, d):
        continue
    death = NaturalConversion(
        subject=h,
        outcome=d,
        rate_law=(
            sympy.Symbol(parameters[("mu", *context_idx(d))].name)
            * c_symbol(h)
            / sympy.Symbol(parameters[("ld", d.context["age"])].name)
        ),
    )
    templates.append(death)

In [109]:
# Vaccination: unvaccinated S (status "0") -> vaccinated S (status "1") within
# each age group at the age-specific per-day rate alpha_a. The rate law
# includes the source compartment (alpha_a * S, mass action). Previously it
# was alpha_a alone, a constant flow of about 5e-4 persons per day given
# alpha's default of 10_000 / N_val per day.
for age_index, age in enumerate(AGES):
    unvaccinated = concepts["S", str(age_index), "0"]
    vaccinated = concepts["S", str(age_index), "1"]
    vaccination = NaturalConversion(
        subject=unvaccinated,
        outcome=vaccinated,
        rate_law=(
            sympy.Symbol(parameters["alpha", unvaccinated.context["age"]].name)
            * c_symbol(unvaccinated)
        ),
    )
    templates.append(vaccination)

In [113]:
# Waning immunity: R (any variant) -> S of the same age group and vaccination
# status at rate limm * R. The rate law previously used `c_symbol(r)`, where
# `r` was a stale loop variable left over from an earlier cell. That made every
# one of these templates proportional to the same unrelated R compartment; it
# now uses each template's own subject.
for status_index, status in enumerate(STATUSES):
    for variant_index, variant in enumerate(VARIANTS):
        for age_index, age in enumerate(AGES):
            recovered = concepts[
                "R", str(age_index), str(variant_index), str(status_index)
            ]
            susceptible = concepts["S", str(age_index), str(status_index)]
            waning = NaturalConversion(
                subject=recovered,
                outcome=susceptible,
                rate_law=(
                    sympy.Symbol(parameters["limm"].name)
                    * c_symbol(recovered)
                ),
            )
            templates.append(waning)

In [114]:
# Assemble the age x variant x vaccination-status model with reinfection,
# validate its AMR and export it. Parameters are re-keyed by their name.
model_parameters = {p.name: p for p in parameters.values()}
model = TemplateModel(
    templates=templates,
    parameters=model_parameters,
    initials=initials,
    observables=observables,
    time=Time(name="t", units=day_units()),
    annotations=Annotations(
        name="Age Vacc Variant Model with Reinfection",
        description=(
            "This model has been stratified by age (4 ways), variant (3 ways), "
            "and vaccine status (2 ways)."
        ),
    ),
)

am = AskeNetPetriNetModel(Model(model))
amr_json = am.to_json()
sanity_check_amr(amr_json)
am.to_json_file("age_vacc_var_reinfection_v1.json")

In [115]:
# Inspect the exported model's parameters, keyed by parameter name.
model.parameters

{'rho': Parameter(name='rho', display_name=None, description=None, identifiers={}, context={}, units=Unit(expression=1/day), value=0.5, distribution=None),
 'q': Parameter(name='q', display_name=None, description=None, identifiers={}, context={}, units=Unit(expression=1), value=0.52, distribution=None),
 'gamma': Parameter(name='gamma', display_name=None, description=None, identifiers={}, context={}, units=Unit(expression=1/day), value=0.2, distribution=None),
 'limm': Parameter(name='limm', display_name=None, description=None, identifiers={}, context={}, units=Unit(expression=1/day), value=0.0027397260273972603, distribution=None),
 'd_0_0': Parameter(name='d_0_0', display_name=None, description=None, identifiers={}, context={}, units=Unit(expression=1/(day*person)), value=5.170630816959669e-08, distribution=None),
 'd_0_1': Parameter(name='d_0_1', display_name=None, description=None, identifiers={}, context={}, units=Unit(expression=1/(day*person)), value=5.170630816959669e-08, distr

# Post-processing of data files

In [8]:
# Reformat the ensemble quantile output with pyciemss's `cdc_reformatcsv` and
# save it. The import was commented out, so `cdc_reformatcsv` was undefined on
# a fresh kernel; it is imported here so the cell runs on its own.
from pyciemss.utils.interface_utils import cdc_reformatcsv

ensemble_results_dir = "../../notebook/ensemble_eval_sa/ensemble_results"
q_ensemble_data = cdc_reformatcsv(
    filename=f"{ensemble_results_dir}/partII_ensemble_of_one_quantiles_day0_Jun032021.csv",
    solution_string_mapping={"I": "cases", "H": "hospitalizations", "D": "deaths"},
    forecast_start_date="2021-07-15",
    location="New York State",
    drop_column_names=["timepoint_id", "number_days", "inc_cum", "output", "Forecast_Backcast"],
)
q_ensemble_data.to_csv(
    f"{ensemble_results_dir}/partII_ensemble_of_one_forecast_quantiles_Jul152021.csv"
)