This model is based on the Scenario 1 notebook.

def SEIRHD_Model(y, t, N, beta, r_I_to_R, r_I_to_H, r_E_to_I, r_H_to_R, r_H_to_D, p_I_to_H, p_I_to_R, p_H_to_D, p_H_to_R):
    S, E, I, R, H, D = y

    dSdt = -beta * I * S / N
    dEdt = beta* I * S / N - r_E_to_I * E
    dIdt = r_E_to_I * E - (r_I_to_H * p_I_to_H) * I - (r_I_to_R * p_I_to_R * I)
    dRdt = (r_I_to_R * p_I_to_R * I) + (r_H_to_R * p_H_to_R * H)
    dHdt = (r_I_to_H * p_I_to_H * I) - (r_H_to_D * p_H_to_D * H)  - (r_H_to_R * p_H_to_R * H)
    dDdt = r_H_to_D * p_H_to_D * H
    return dSdt, dEdt, dIdt, dRdt, dHdt, dDdt

In [87]:
import sympy
from copy import deepcopy as _d
from mira.metamodel import *
from mira.modeling import Model
from mira.modeling.amr.petrinet import AMRPetriNetModel
from mira.metamodel.utils import SympyExprStr
# from mira.modeling.viz import GraphicalModel
import jsonschema
import itertools as itt
from tqdm.auto import tqdm
from collections import defaultdict
import requests
from sympy import IndexedBase, Indexed

url = "http://data-service.staging.terarium.ai/models"


def post_terarium(template_model) -> requests.Response:
    """Post a model to terarium."""
    am = AMRPetriNetModel(Model(model_2))
    sanity_check_amr(am.to_json())
    res = requests.post(url, json=am.to_json())
    return res

In [88]:
def sanity_check_tm(tm: TemplateModel):
    assert tm.templates
    all_concept_names = set(tm.get_concepts_name_map())
    all_parameter_names = set(tm.parameters)
    all_symbols = all_concept_names | all_parameter_names | ({tm.time.name} if tm.time else set())
    for template in tm.templates:
        assert template.rate_law
        symbols = template.rate_law.args[0].free_symbols
        for symbol in symbols:
            assert symbol.name in all_symbols, f"missing symbol: {symbol.name}"
    all_initial_names = {init.concept.name for init in tm.initials.values()}
    for concept in all_concept_names:
        assert concept in all_initial_names

In [89]:
def sanity_check_amr(amr_json):
    import requests

    assert "schema" in amr_json
    schema_json = requests.get(amr_json["schema"]).json()
    jsonschema.validate(schema_json, amr_json)

In [90]:
STATUSES = [
    "unvaccinated",
    "vaccinated",
]
AGES = [
    "0-19",
    "20-49",
    "50-64",
    "65",
]
VARIANTS = [
    "wild",
    "delta",
    "omicron",
]

In [91]:
person_units = lambda: Unit(expression=sympy.Symbol("person"))
day_units = lambda: Unit(expression=sympy.Symbol("day"))
per_day_units = lambda: Unit(expression=1 / sympy.Symbol("day"))
dimensionless_units = lambda: Unit(expression=sympy.Integer("1"))
per_day_per_person_units = lambda: Unit(
    expression=1 / (sympy.Symbol("day") * sympy.Symbol("person"))
)

BASE_CONCEPTS = {
    "S": Concept(name="S", units=person_units(), identifiers={"ido": "0000514"}),
    "E": Concept(name="E", units=person_units(), identifiers={"apollosv": "0000154"}),
    "I": Concept(name="I", units=person_units(), identifiers={"ido": "0000511"}),
    "R": Concept(name="R", units=person_units(), identifiers={"ido": "0000592"}),
    "H": Concept(
        name="H",
        units=person_units(),
        identifiers={"ido": "0000511"},
        context={"property": "ncit:C25179"},
    ),
    "D": Concept(name="D", units=person_units(), identifiers={"ncit": "C28554"}),
}


N_val = 19_340_000
E_val = 1
I_val = 4
R_0 = 2.6
gamma_val = 1/5

BASE_PARAMETERS = {
    "gamma":  Parameter(name="gamma", value=gamma_val, units=per_day_units()),
    "eta": Parameter(name="eta", value=0.1, units=dimensionless_units()),
    "mu":  Parameter(name="mu", value=0.003, units=dimensionless_units()),
    "lr":  Parameter(name="lr", value=5, units=day_units()), # average time to recovery (duration of hospital stay if they recover)
    "ld":  Parameter(name="ld", value=9.25, units=day_units()), # average time to recovery (duration of hospital stay if they die)
    "rho": Parameter(name="rho", value=1/2, units=per_day_units()),
    "q": Parameter(name="q", value=R_0 * gamma_val, units=dimensionless_units()), # transmission probability
    "d": Parameter(name="d", value=1/N_val, units=per_day_per_person_units()), # scaled contact rate
    "phi": Parameter(name="phi", value=1.0, units=dimensionless_units()), # host susceptibility
    "chi": Parameter(name="chi", value=1.0, units=dimensionless_units()), # relative transmissibility of variant

}

BASE_INITIALS = {
    "S": Initial(concept=Concept(name="S"), expression=SympyExprStr(sympy.Expr(N_val - (E_val + I_val)))),
    "E": Initial(concept=Concept(name="E"), expression=SympyExprStr(sympy.Expr(E_val))),
    "I": Initial(concept=Concept(name="I"), expression=SympyExprStr(sympy.Expr(I_val))),
    "R": Initial(concept=Concept(name="R"), expression=SympyExprStr(sympy.Expr(0))),
    "H": Initial(concept=Concept(name="H"), expression=SympyExprStr(sympy.Expr(0))),
    "D": Initial(concept=Concept(name="D"), expression=SympyExprStr(sympy.Expr(0))),
}

observables = {}

In [92]:
# make function to do this automatically
(
    S,
    E,
    I,
    R,
    D,
    H,
    q,
    d,  # d eats the N
    rho, # r_E_to_I,
    gamma,
    eta,
    mu,
    lr,
    ld,
    chi,
    phi,
) = sympy.symbols(
    "S E I R D H q d rho gamma eta mu lr ld chi phi"
)

In [93]:
ORDER = ["age", "variant", "status"]
STRATA = {
    "age": AGES, 
    "variant": VARIANTS,
    "status": STATUSES,
}

stratification_config = {
    "S": ["age", "status"],
    "E": ["age", "variant", "status"],
    "I": ["age", "variant", "status"],
    "R": ["age", "variant", "status"],
    "H": ["age", "variant", "status"],
    "D": ["age", "variant", "status"],
}

param_stratification_config = {
    "d": ["age", "age"],
    "eta": ["age", "variant", "status"],
    "mu": ["age", "variant", "status"],
    "ld": ["age"],
    "lr": ["age"],
    "phi": ["status"],
    "chi": ["variant"],
}


templates = []

In [94]:
stratification_config["S"]

['age', 'status']

In [95]:
# Index all p
concepts = {}

for concept, labels in stratification_config.items():
    for keys in itt.product(*(
        zip(itt.repeat(label), enumerate(STRATA[label]))
        for label in labels
    )):
        d = {
            key: str(idx)
            for key, (idx, label) in keys
        }
        idx = tuple(d[k] for k in ORDER if k in d) 
        concept_copy = _d(BASE_CONCEPTS[concept]).with_context(**d, do_rename=False)
        concept_copy.name = f"{concept_copy.name}_" + "_".join(idx)
        concepts[(concept, *idx)] = concept_copy
        
list(concepts.items())[30:35]   

[(('E', '3', '2', '0'),
  Concept(name='E_3_2_0', display_name=None, description=None, identifiers={'apollosv': '0000154'}, context={'age': '3', 'variant': '2', 'status': '0'}, units=Unit(expression=person))),
 (('E', '3', '2', '1'),
  Concept(name='E_3_2_1', display_name=None, description=None, identifiers={'apollosv': '0000154'}, context={'age': '3', 'variant': '2', 'status': '1'}, units=Unit(expression=person))),
 (('I', '0', '0', '0'),
  Concept(name='I_0_0_0', display_name=None, description=None, identifiers={'ido': '0000511'}, context={'age': '0', 'variant': '0', 'status': '0'}, units=Unit(expression=person))),
 (('I', '0', '0', '1'),
  Concept(name='I_0_0_1', display_name=None, description=None, identifiers={'ido': '0000511'}, context={'age': '0', 'variant': '0', 'status': '1'}, units=Unit(expression=person))),
 (('I', '0', '1', '0'),
  Concept(name='I_0_1_0', display_name=None, description=None, identifiers={'ido': '0000511'}, context={'age': '0', 'variant': '1', 'status': '0'}

In [96]:
initials = {}
for concept in concepts.values():
    orig_key = concept.name.split("_")[0]
    initials[concept.name] =  Initial(
        concept=concept, expression=BASE_INITIALS[orig_key].expression
    )

In [97]:
concept_to_strata = defaultdict(list)
for idx, concept in concepts.items():
    concept_to_strata[idx[0]].append(concept)
    
    
def concept_strata_prod(*variables: str):
    yield from itt.product(*(
        concept_to_strata[variable]
        for variable in variables
    ))

In [98]:
# Index all possible parameters with the name as the 
parameters = {}
for parameter, labels in param_stratification_config.items():
    for keys in itt.product(*(
        zip(itt.repeat(label), enumerate(STRATA[label]))
        for label in labels
    )):
        d = defaultdict(list)
        for key, (idx, label) in keys:
            d[key].append(str(idx))
        d = {k:sorted(v) for k,v in d.items()}
        idx = tuple(itt.chain.from_iterable(d[k] for k in ORDER if k in d))
        p = _d(BASE_PARAMETERS[parameter])
        p.name = f"{p.name}_" + "_".join(idx)
        parameters[(parameter, *idx)] = p
        
parameter_to_strata = defaultdict(list)
for idx, parameter in parameters.items():
    parameter_to_strata[idx[0]].append(parameter)
        
list(parameters.items())[:5]

[(('d', '0', '0'),
  Parameter(name='d_0_0', display_name=None, description=None, identifiers={}, context={}, units=Unit(expression=1/(day*person)), value=5.170630816959669e-08, distribution=None)),
 (('d', '0', '1'),
  Parameter(name='d_0_1', display_name=None, description=None, identifiers={}, context={}, units=Unit(expression=1/(day*person)), value=5.170630816959669e-08, distribution=None)),
 (('d', '0', '2'),
  Parameter(name='d_0_2', display_name=None, description=None, identifiers={}, context={}, units=Unit(expression=1/(day*person)), value=5.170630816959669e-08, distribution=None)),
 (('d', '0', '3'),
  Parameter(name='d_0_3', display_name=None, description=None, identifiers={}, context={}, units=Unit(expression=1/(day*person)), value=5.170630816959669e-08, distribution=None)),
 (('d', '1', '1'),
  Parameter(name='d_1_1', display_name=None, description=None, identifiers={}, context={}, units=Unit(expression=1/(day*person)), value=5.170630816959669e-08, distribution=None))]

In [99]:
def force_of_infection(a, v, i):
    d = IndexedBase('d')
    I = IndexedBase('I')
    exprs = [
        d[a, age_index] * I[age_index, v , status_index]
        for age_index, _ in enumerate(AGES)
        for status_index, _ in enumerate(STATUSES)
    ]
    return (
        q 
        * Indexed("chi", v)
        * Indexed("psi", i)
        * sum(exprs)
    )
    
force_of_infection(a='0-19', v='omicron', i='vaccinated')

q*(I[0, omicron, 0]*d[-19, 0] + I[0, omicron, 1]*d[-19, 0] + I[1, omicron, 0]*d[-19, 1] + I[1, omicron, 1]*d[-19, 1] + I[2, omicron, 0]*d[-19, 2] + I[2, omicron, 1]*d[-19, 2] + I[3, omicron, 0]*d[-19, 3] + I[3, omicron, 1]*d[-19, 3])*chi[omicron]*psi[vaccinated]

In [100]:
def context_idx(concept):
    return tuple(
        concept.context[part]
        for part in stratification_config[concept.name]
    )
context_idx(concepts['S', '0', '0'])

KeyError: 'S_0_0'

In [None]:
def c_symbol(concept):
    return sympy.Symbol(concept.name + "_" + "_".join(idx))

c_symbol(concepts['S', '0', '0'])

In [None]:
def force_of_infection_component(e, i):
    d = IndexedBase('d')
    I = IndexedBase('I')
    return (
        q 
        * sympy.Symbol("chi_" + e.context["status"])
        * sympy.Symbol("psi_" + e.context["variant"])
        * sympy.Symbol("d_" + e.context["age"] + "_" + i.context["age"])
        * c_symbol(i)
    )

In [None]:
templates = []

In [None]:
def not_conserved(c1, c2):
    return not all(
        c1.context[key] == c2.context[key]
        for key in ORDER
        if key in c1.context and key in c2.context 
    )
    
#t1 = ControlledConversion(
#    subject=c["S"], outcome=c["E"], controller=c["I"], rate_law=S * I * q * d
#)
for s, e, i in tqdm(
    concept_strata_prod("S", "E", "I"),
    unit_scale=True,
):
    if not_conserved(s, e):
        continue
    infection = ControlledConversion(
        subject=s, outcome=e, controller=i, 
        rate_law=c_symbol(s) * force_of_infection_component(e, i)
    )
    templates.append(infection)

In [None]:
# do param lookup based on output


# t2 = NaturalConversion(subject=c["E"], outcome=c["I"], rate_law=rho * E)
for e, i in concept_strata_prod("E", "I"):
    if not_conserved(e, i):
        continue
    t2 = NaturalConversion(
        subject=e, outcome=i, 
        rate_law=sympy.Symbol(BASE_PARAMETERS['rho'].name) * c_symbol(e)
    )
    templates.append(t2)

In [None]:
#t3 = NaturalConversion(subject=c["I"], outcome=c["R"], rate_law=gamma * (1 - eta) * I)
for i, r in concept_strata_prod("I", "R"):
    if not_conserved(i, r):
        continue
    t3 = NaturalConversion(
        subject=i, outcome=r, 
        rate_law=(
            sympy.Symbol(BASE_PARAMETERS['gamma'].name)
            * (1 - sympy.Symbol(parameters[('eta', *context_idx(r))].name)) 
            * c_symbol(i)
        )
    )
    templates.append(t3)

In [None]:
#t4 = NaturalConversion(subject=c["I"], outcome=c["H"], rate_law=gamma * eta * I)
for i, h in concept_strata_prod("I", "H"):
    if not_conserved(i, h):
        continue
    t4 = NaturalConversion(
        subject=i, outcome=h, 
        rate_law=(
            sympy.Symbol(BASE_PARAMETERS['gamma'].name)
            * sympy.Symbol(parameters[('eta', *context_idx(h))].name)
            * c_symbol(i)
        )
    )
    templates.append(t4)

In [None]:
#t5 = NaturalConversion(subject=c["H"], outcome=c["R"], rate_law=(1-mu) * H / lr)
for h, r in concept_strata_prod("H", "R"):
    if not_conserved(h, r):
        continue
    t5 = NaturalConversion(
        subject=h, outcome=r, 
        rate_law=(
            (1 - sympy.Symbol(parameters[('mu', *context_idx(r))].name)) 
            * c_symbol(h) / sympy.Symbol(parameters['lr', r.context['age']].name)
        )
    )
    templates.append(t5)

In [None]:
#t6 = NaturalConversion(subject=c["H"], outcome=c["D"], rate_law=mu * H / ld)
for h, d in concept_strata_prod("H", "D"):
    if not_conserved(h, d):
        continue
    t6 = NaturalConversion(
        subject=h, outcome=d, 
        rate_law=(
            sympy.Symbol(parameters[('mu', *context_idx(d))].name)
            * c_symbol(h) 
            / sympy.Symbol(parameters[('ld', d.context['age'])].name)
        )
    )
    templates.append(t6)

In [None]:
model = TemplateModel(
    templates=templates,
    parameters={p.name:p for p in parameters.values()},
    initials=initials,
    time=Time(name="t", units=day_units()),
    observables=observables,
    annotations=Annotations(name="Toby's Great Adventure SEIRHD"),
)
sanity_check_tm(model)
am = AMRPetriNetModel(Model(model))
sanity_check_amr(am.to_json())
am.to_json_file("toby_seirhd.json")

In [None]:
# GraphicalModel.for_jupyter(tm)

# Step 1 - Stratify By Age

In [None]:
model_2 = stratify(
    model_1,
    key="age",
    strata=["0_19", "20_49", "50_64", "65"],
    structure=[],
    directed=False,
    concepts_to_stratify={"S", "E", "I"},
    params_to_stratify={"eta", "mu", "lr", "ld", "d"},
    cartesian_control=True,
)
model_2.annotations.name = "Evaluation Ensemble Baseline - Step 1 - Age Stratified"

sanity_check_tm(model_2)
am = AMRPetriNetModel(Model(model_2))
sanity_check_amr(am.to_json())
am.to_json_file("eval_ensemble_step_1.json")

# Step 2 - Stratify By Vaccine Status

In [None]:
model_2 = stratify(
    model_1,
    key="status",
    strata=["unvaccinated", "vaccinated"],
    structure=[["unvaccinated", "vaccinated"]],
    directed=True,
    concepts_to_stratify={"S", "H", "R", "D"},
    params_to_stratify={"eta", "mu", "lr", "ld", "d"},
    cartesian_control=True,
)
model_2.annotations.name = "Evaluation Ensemble Baseline - Step 1 - Age Stratified"

sanity_check_tm(model_2)
am = AMRPetriNetModel(Model(model_2))
sanity_check_amr(am.to_json())
am.to_json_file("eval_ensemble_step_1.json")

Need to stratify I and E by disease variant.

In [None]:
model_2 = stratify(
    model_1,
    key="variant",
    strata=["wild", "delta", "omicron"],
    structure=[],
    directed=False,
    concepts_to_stratify={"E", "I", "H"},
    params_to_stratify={"eta", "mu", "lr", "ld", "d"},
    cartesian_control=True,
)
model_2.annotations.name = "Evaluation Ensemble Baseline - Step 1 - Age Stratified"

sanity_check_tm(model_2)
am = AMRPetriNetModel(Model(model_2))
sanity_check_amr(am.to_json())
am.to_json_file("eval_ensemble_step_1.json")

In [None]:
url = f"http://data-service.staging.terarium.ai/models/{res.json()['id']}"

res = requests.get(url)
res.json()