In [1]:
import sympy
from copy import deepcopy as _d
from mira.metamodel import *
from mira.modeling import Model, Author
from mira.modeling.askenet.petrinet import AskeNetPetriNetModel
import jsonschema
import itertools as itt
from tqdm.auto import tqdm
from collections import defaultdict
import requests
from sympy import IndexedBase, Indexed
from datetime import datetime

In [2]:
# Timestamp (month-day hour:minute) appended to the model name further below.
now = f"{datetime.now():%m-%d %H:%M}"
now

'07-19 11:16'

In [3]:
def associate(project_id, model_id, timeout=60) -> requests.Response:
    """Associate an already-posted model with a Terarium project (staging data service).

    :param project_id: Terarium project id (str or int).
    :param model_id: id returned by the data service when the model was posted.
    :param timeout: seconds to wait for the service; without a timeout ``requests``
        may wait indefinitely for a response.
    :returns: the raw HTTP response (its JSON body is also printed).
    """
    url = f"http://data-service.staging.terarium.ai/projects/{project_id}/assets/models/{model_id}"
    res = requests.post(url, timeout=timeout)
    print(f"associated model {model_id} to project {project_id}: ", res.json())
    return res

def post_terarium(template_model, project_id=None) -> requests.Response:
    """Convert a template model to an AMR Petri net, validate it, and post it to Terarium.

    :param template_model: the MIRA template model to upload.
    :param project_id: optional project id (str/int) or list of ids passed through to
        ``post_terarium_amr`` for association.
    :returns: the response from the model POST.
    """
    # Serialize once and reuse the same JSON for validation and upload.
    amr = AskeNetPetriNetModel(Model(template_model)).to_json()
    sanity_check_amr(amr)
    return post_terarium_amr(amr, project_id=project_id)
    
def post_terarium_amr(amr, project_id=None, timeout=60) -> requests.Response:
    """Post an AMR JSON model to the Terarium staging data service.

    :param amr: the AMR model as a JSON-serializable dict.
    :param project_id: optional project id (str/int), or a list/tuple/set of ids; the
        newly created model is associated with each of them.
    :param timeout: seconds to wait for the model POST.
    :returns: the response from the model POST.
    :raises requests.HTTPError: if association was requested but the POST failed,
        since there is then no model id to associate.
    """
    res = requests.post(
        "http://data-service.staging.terarium.ai/models", json=amr, timeout=timeout
    )
    if isinstance(project_id, (str, int)):
        project_ids = [project_id]
    elif isinstance(project_id, (list, tuple, set)):
        project_ids = list(project_id)
    else:
        project_ids = []
    if project_ids:
        # Fail with a clear HTTP error instead of a KeyError on a missing "id".
        res.raise_for_status()
        model_id = res.json()["id"]
        for pid in project_ids:
            associate(pid, model_id)
    return res

def post_terarium_remote(model_url, project_id=None, timeout=60) -> requests.Response:
    """Download an AMR JSON model from ``model_url`` and post it to Terarium.

    :param model_url: URL serving the AMR JSON.
    :param project_id: optional project id(s), forwarded to ``post_terarium_amr``.
    :param timeout: seconds to wait for the download.
    :returns: the response from the model POST.
    :raises requests.HTTPError: if the download fails, rather than uploading the
        error body as if it were a model.
    """
    download = requests.get(model_url, timeout=timeout)
    download.raise_for_status()
    return post_terarium_amr(download.json(), project_id=project_id)

In [4]:
def sanity_check_tm(tm: "TemplateModel"):
    """Assert basic internal consistency of a template model before export.

    Checks that the model has templates, that every template has a rate law whose
    free symbols are all known concept, parameter or time names, and that every
    concept has an initial condition.

    :raises AssertionError: with a message describing the problem found.
    """
    assert tm.templates, "template model has no templates"
    all_concept_names = set(tm.get_concepts_name_map())
    all_parameter_names = set(tm.parameters)
    time_names = {tm.time.name} if tm.time else set()
    all_symbols = all_concept_names | all_parameter_names | time_names
    for index, template in enumerate(tm.templates):
        assert template.rate_law, f"template {index} has no rate law"
        # presumably rate_law wraps the sympy expression as args[0] — confirm against mira
        for symbol in template.rate_law.args[0].free_symbols:
            assert symbol.name in all_symbols, f"missing symbol: {symbol.name} (template {index})"
    all_initial_names = {init.concept.name for init in tm.initials.values()}
    missing_initials = sorted(all_concept_names - all_initial_names)
    assert not missing_initials, f"concepts without initials: {missing_initials}"

In [5]:
def sanity_check_amr(amr_json):
    """Validate an AMR JSON dict against the JSON schema its ``schema`` field points to.

    :raises AssertionError: if the AMR has no ``schema`` field.
    :raises requests.HTTPError: if the schema cannot be downloaded.
    :raises jsonschema.ValidationError: if the AMR does not conform to the schema.
    """
    assert "schema" in amr_json, "AMR JSON has no 'schema' field"
    schema_res = requests.get(amr_json["schema"], timeout=60)
    schema_res.raise_for_status()
    schema_json = schema_res.json()
    # jsonschema.validate takes (instance, schema): the AMR is the instance being checked.
    jsonschema.validate(instance=amr_json, schema=schema_json)

In [6]:
# Stratum labels. Only each label's position is used downstream (compartment and
# parameter names encode the index, e.g. age "3" means the last age group).
STATUSES = [
    "unvaccinated",
    "vaccinated",
]
AGES = [
    "0-19",
    "20-49",
    "50-64",
    "65+",  # open-ended oldest group, matching "(3) for 65+" in the model description
]
VARIANTS = [
    "wild",
    "delta",
    "omicron",
]

In [7]:
# Unit factories: each call constructs a new Unit instance.
def person_units():
    """Return a new Unit for a count of persons."""
    return Unit(expression=sympy.Symbol("person"))


def day_units():
    """Return a new Unit for a duration in days."""
    return Unit(expression=sympy.Symbol("day"))


def per_day_units():
    """Return a new Unit for a rate per day."""
    return Unit(expression=1 / sympy.Symbol("day"))


def dimensionless_units():
    """Return a new Unit for a dimensionless quantity."""
    return Unit(expression=sympy.Integer("1"))


def per_day_per_person_units():
    """Return a new Unit for a rate per day per person."""
    return Unit(expression=1 / (sympy.Symbol("day") * sympy.Symbol("person")))

# Unstratified compartments, all measured in persons and grounded to an ontology
# identifier. H shares I's identifier but adds a "property" context entry.
BASE_CONCEPTS = dict(
    S=Concept(name="S", units=person_units(), identifiers={"ido": "0000514"}),
    E=Concept(name="E", units=person_units(), identifiers={"apollosv": "0000154"}),
    I=Concept(name="I", units=person_units(), identifiers={"ido": "0000511"}),
    R=Concept(name="R", units=person_units(), identifiers={"ido": "0000592"}),
    H=Concept(
        name="H",
        units=person_units(),
        identifiers={"ido": "0000511"},
        context={"property": "ncit:C25179"},
    ),
    D=Concept(name="D", units=person_units(), identifiers={"ncit": "C28554"}),
)


# Scalar inputs shared by the base parameters and initial conditions.
N_val = 19_340_000  # total population; S starts at N_val - (E_val + I_val)
E_val, I_val = 1, 4  # initially exposed / infected persons
R_0 = 2.6  # basic reproduction number, used to derive q
gamma_val = 1 / 5  # per-day rate of leaving I

# Unstratified parameter definitions; stratified copies are derived from these.
BASE_PARAMETERS = {
    # per-day rate of leaving I (split between R and H by eta)
    "gamma": Parameter(name="gamma", value=gamma_val, units=per_day_units()),
    # per-day vaccination rate; later we want this to be a time-varying function
    "alpha": Parameter(name="alpha", value=10_000 / N_val, units=per_day_units()),
    # fraction of those leaving I who move to H (the rest move to R)
    "eta": Parameter(name="eta", value=0.1, units=dimensionless_units()),
    # fraction of those leaving H who move to D (the rest move to R)
    "mu": Parameter(name="mu", value=0.003, units=dimensionless_units()),
    # average time to recovery (duration of hospital stay if they recover)
    "lr": Parameter(name="lr", value=5, units=day_units()),
    # average time to death (duration of hospital stay if they die)
    "ld": Parameter(name="ld", value=9.25, units=day_units()),
    # per-day rate of moving from E to I
    "rho": Parameter(name="rho", value=1/2, units=per_day_units()),
    # transmission probability
    "q": Parameter(name="q", value=R_0 * gamma_val, units=dimensionless_units()),
    # scaled contact rate
    "d": Parameter(name="d", value=1/N_val, units=per_day_per_person_units()),
    # host susceptibility
    "phi": Parameter(name="phi", value=1.0, units=dimensionless_units()),
    # relative transmissibility of variant
    "chi": Parameter(name="chi", value=1.0, units=dimensionless_units()),
}

# Unstratified initial conditions; stratified initials copy the ``.value`` from here.
BASE_INITIALS = {
    name: Initial(concept=Concept(name=name), value=value)
    for name, value in (
        ("S", N_val - (E_val + I_val)),
        ("E", E_val),
        ("I", I_val),
        ("R", 0),
        ("H", 0),
        ("D", 0),
    )
}

# No observables are defined yet.
observables = dict()

In [8]:
# Order in which stratum indices appear in compartment/parameter names and keys.
ORDER = ["age", "variant", "status"]
STRATA = dict(age=AGES, variant=VARIANTS, status=STATUSES)

# Strata each base compartment is split over: S has no variant, the rest use all three.
stratification_config = {"S": ["age", "status"]}
stratification_config.update((name, list(ORDER)) for name in ("E", "I", "R", "H", "D"))

# Strata each parameter is split over. "d" lists "age" twice: it is indexed by the
# exposed compartment's age and the infecting compartment's age.
param_stratification_config = dict(
    d=["age", "age"],
    eta=["age", "variant", "status"],
    mu=["age", "variant", "status"],
    ld=["age"],
    lr=["age"],
    phi=["status"],
    chi=["variant"],
    alpha=["age"],  # vaccination rate
)


templates = []

In [9]:
# Build every stratified compartment. Keys are (base_name, *stratum_indices) with
# indices ordered by ORDER; names follow the same pattern, e.g. ("E", "3", "2", "0") -> E_3_2_0.
concepts = {}

for base_name, strata_labels in stratification_config.items():
    per_label_choices = [
        [(strata_label, position) for position, _ in enumerate(STRATA[strata_label])]
        for strata_label in strata_labels
    ]
    for combination in itt.product(*per_label_choices):
        context = {strata_label: str(position) for strata_label, position in combination}
        ordered_idx = tuple(context[k] for k in ORDER if k in context)
        stratified = _d(BASE_CONCEPTS[base_name]).with_context(**context, do_rename=False)
        stratified.name = f"{stratified.name}_" + "_".join(ordered_idx)
        concepts[(base_name, *ordered_idx)] = stratified

list(concepts.items())[30:35]

[(('E', '3', '2', '0'),
  Concept(name='E_3_2_0', display_name=None, description=None, identifiers={'apollosv': '0000154'}, context={'age': '3', 'variant': '2', 'status': '0'}, units=Unit(expression=person))),
 (('E', '3', '2', '1'),
  Concept(name='E_3_2_1', display_name=None, description=None, identifiers={'apollosv': '0000154'}, context={'age': '3', 'variant': '2', 'status': '1'}, units=Unit(expression=person))),
 (('I', '0', '0', '0'),
  Concept(name='I_0_0_0', display_name=None, description=None, identifiers={'ido': '0000511'}, context={'age': '0', 'variant': '0', 'status': '0'}, units=Unit(expression=person))),
 (('I', '0', '0', '1'),
  Concept(name='I_0_0_1', display_name=None, description=None, identifiers={'ido': '0000511'}, context={'age': '0', 'variant': '0', 'status': '1'}, units=Unit(expression=person))),
 (('I', '0', '1', '0'),
  Concept(name='I_0_1_0', display_name=None, description=None, identifiers={'ido': '0000511'}, context={'age': '0', 'variant': '1', 'status': '0'}

In [10]:
# Each stratified compartment starts from its base compartment's unstratified value
# (e.g. every S_<age>_<status> gets the full S initial), keyed by concept name.
# NOTE(review): this replicates the whole base value into every stratum, so totals
# across strata are a multiple of the base — confirm this is intended rather than
# splitting the base value across strata.
initials = {
    concept.name: Initial(
        concept=concept,
        value=BASE_INITIALS[concept.name.split("_")[0]].value,
    )
    for concept in concepts.values()
}

In [11]:
# Group the stratified concepts by base compartment name (e.g. "S" -> [S_0_0, S_0_1, ...]).
concept_to_strata = defaultdict(list)
for idx, concept in concepts.items():
    concept_to_strata[idx[0]].append(concept)


def concept_strata_prod(*variables: str):
    """Yield every combination of stratified concepts for the given base compartments.

    For example ``concept_strata_prod("E", "I")`` yields ``(e, i)`` for every stratified
    E concept paired with every stratified I concept.

    :raises KeyError: if a base name has no stratified concepts. Indexing the
        defaultdict would otherwise insert an empty list and silently yield nothing.
    """
    unknown = [variable for variable in variables if variable not in concept_to_strata]
    if unknown:
        raise KeyError(f"no stratified concepts for base compartment(s): {unknown}")
    yield from itt.product(*(
        concept_to_strata[variable]
        for variable in variables
    ))

In [12]:
# Index all parameters.
# - Unstratified parameters are keyed by their plain name (e.g. "rho").
# - Stratified ones are keyed by (base_name, *stratum_indices). Indices follow ORDER
#   and repeat when a stratum label appears twice: the age-by-age parameter "d"
#   yields keys like ("d", "0", "1") and names like "d_0_1".
parameters = {
    # Copy the base definitions so values/units stay in sync with BASE_PARAMETERS.
    name: _d(BASE_PARAMETERS[name])
    for name in ("rho", "q", "gamma")
}
for parameter, labels in param_stratification_config.items():
    for keys in itt.product(*(
        zip(itt.repeat(label), enumerate(STRATA[label]))
        for label in labels
    )):
        # Collect indices per stratum label; a label listed twice gets two indices.
        d = defaultdict(list)
        for key, (idx, _value) in keys:
            d[key].append(str(idx))
        d = dict(d)
        idx = tuple(itt.chain.from_iterable(d[k] for k in ORDER if k in d))
        p = _d(BASE_PARAMETERS[parameter])
        p.name = f"{p.name}_" + "_".join(idx)
        parameters[(parameter, *idx)] = p

# Group parameters by base name. Unstratified keys are plain strings, so take the
# whole string; indexing it with [0] would file "rho" under "r" and "gamma" under "g".
parameter_to_strata = defaultdict(list)
for idx, parameter in parameters.items():
    base_name = idx if isinstance(idx, str) else idx[0]
    parameter_to_strata[base_name].append(parameter)

In [13]:
def context_idx(concept):
    """Return the concept's stratum indices, ordered as its base compartment was stratified."""
    base_name, *_ = concept.name.split("_")
    strata_labels = stratification_config[base_name]
    return tuple(map(concept.context.__getitem__, strata_labels))
context_idx(concepts['S', '0', '0'])

('0', '0')

In [14]:
def c_symbol(concept):
    """Return the sympy symbol named after ``concept`` (e.g. ``S_0_0``) for use in rate laws."""
    symbol_name = concept.name
    return sympy.Symbol(symbol_name)

c_symbol(concepts['S', '0', '0'])

S_0_0

In [15]:
def force_of_infection_component(e, i):
    """Return the infection-rate factor contributed by infected compartment ``i`` to ``e``.

    The result is ``q * chi_<variant> * phi_<status> * d_<age_e>_<age_i> * I``.
    Variant, status and the first age index come from the exposed compartment ``e``;
    the second age index and the population symbol come from ``i``. Multiplying it
    by the susceptible compartment's symbol gives the full rate law.
    """
    exposed = e.context
    transmission = sympy.Symbol('q')
    # chi is read from e; the infection loop only pairs e with an i of the same variant
    variant_factor = sympy.Symbol("chi_" + exposed["variant"])
    # phi is read from e; the infection loop only pairs s with an e of the same status
    susceptibility = sympy.Symbol("phi_" + exposed["status"])
    contact = sympy.Symbol("d_" + exposed["age"] + "_" + i.context["age"])
    return transmission * variant_factor * susceptibility * contact * c_symbol(i)

In [16]:
# Fresh accumulator for the transitions appended by the cells below.
templates = list()

In [17]:
def not_conserved(c1, c2):
    """Return True if ``c1`` and ``c2`` disagree on any stratum they both carry.

    Only strata from ORDER present in both contexts are compared. For example, an S
    compartment (which has no variant) can still match E compartments of every variant.
    """
    shared_keys = (key for key in ORDER if key in c1.context and key in c2.context)
    return any(c1.context[key] != c2.context[key] for key in shared_keys)
    
# Unstratified reference:
# t1 = ControlledConversion(subject=S, outcome=E, controller=I, rate_law=S * I * q * d)
# Stratified: S and E must share age and status; the controlling I must have E's
# variant but may be of any age/status (age mixing enters through d).
for s, e, i in tqdm(concept_strata_prod("S", "E", "I"), unit_scale=True):
    if not_conserved(s, e) or e.context["variant"] != i.context["variant"]:
        continue
    templates.append(
        ControlledConversion(
            subject=s,
            outcome=e,
            controller=i,
            rate_law=c_symbol(s) * force_of_infection_component(e, i),
        )
    )

0.00it [00:00, ?it/s]

In [18]:
# Stratified parameters in the transitions below are looked up from the outcome compartment.

# Unstratified reference: t2 = NaturalConversion(subject=E, outcome=I, rate_law=rho * E)
# E -> I keeps every stratum; rho itself is unstratified.
templates.extend(
    NaturalConversion(
        subject=e,
        outcome=i,
        rate_law=sympy.Symbol(BASE_PARAMETERS['rho'].name) * c_symbol(e),
    )
    for e, i in concept_strata_prod("E", "I")
    if not not_conserved(e, i)
)

In [19]:
# Unstratified reference: t3 = NaturalConversion(subject=I, outcome=R, rate_law=gamma * (1 - eta) * I)
# The non-hospitalized fraction (1 - eta, stratified like R) moves from I to R.
templates.extend(
    NaturalConversion(
        subject=i,
        outcome=r,
        rate_law=(
            sympy.Symbol(BASE_PARAMETERS['gamma'].name)
            * (1 - sympy.Symbol(parameters[('eta', *context_idx(r))].name))
            * c_symbol(i)
        ),
    )
    for i, r in concept_strata_prod("I", "R")
    if not not_conserved(i, r)
)

In [20]:
# Unstratified reference: t4 = NaturalConversion(subject=I, outcome=H, rate_law=gamma * eta * I)
# The fraction eta (stratified like H) of those leaving I moves to H.
templates.extend(
    NaturalConversion(
        subject=i,
        outcome=h,
        rate_law=(
            sympy.Symbol(BASE_PARAMETERS['gamma'].name)
            * sympy.Symbol(parameters[('eta', *context_idx(h))].name)
            * c_symbol(i)
        ),
    )
    for i, h in concept_strata_prod("I", "H")
    if not not_conserved(i, h)
)

In [21]:
# Unstratified reference: t5 = NaturalConversion(subject=H, outcome=R, rate_law=(1-mu) * H / lr)
# Survivors (fraction 1 - mu, stratified like R) leave H after the age-specific stay lr.
templates.extend(
    NaturalConversion(
        subject=h,
        outcome=r,
        rate_law=(
            (1 - sympy.Symbol(parameters[('mu', *context_idx(r))].name))
            * c_symbol(h) / sympy.Symbol(parameters[('lr', r.context['age'])].name)
        ),
    )
    for h, r in concept_strata_prod("H", "R")
    if not not_conserved(h, r)
)

In [22]:
# Unstratified reference: t6 = NaturalConversion(subject=H, outcome=D, rate_law=mu * H / ld)
# The fraction mu (stratified like D) leaves H for D after the age-specific stay ld.
templates.extend(
    NaturalConversion(
        subject=h,
        outcome=dead,
        rate_law=(
            sympy.Symbol(parameters[('mu', *context_idx(dead))].name)
            * c_symbol(h)
            / sympy.Symbol(parameters[('ld', dead.context['age'])].name)
        ),
    )
    for h, dead in concept_strata_prod("H", "D")
    if not not_conserved(h, dead)
)

## Vaccination

This step adds the conversion of unvaccinated susceptible individuals into vaccinated susceptible individuals. Vaccination changes only the vaccination status, so the subject and outcome compartments must be in the same age stratum.

In [23]:
# Unvaccinated (status '0') -> vaccinated (status '1') susceptibles, with age conserved.
# alpha is stratified by age and has units of 1/day (value 10_000 / N_val), i.e. it is a
# per-capita rate. The flux must therefore scale with the unvaccinated susceptible
# population (alpha * S). A bare alpha would be a constant flux with the wrong units
# that could also drive S negative.
for age_idx in range(len(AGES)):
    age = str(age_idx)
    subject = concepts["S", age, '0']
    outcome = concepts["S", age, '1']
    t7 = NaturalConversion(
        subject=subject,
        outcome=outcome,
        rate_law=(
            sympy.Symbol(parameters['alpha', subject.context['age']].name)
            * c_symbol(subject)
        ),
    )
    templates.append(t7)

In [24]:
# Assemble the stratified template model, check it, and write the AMR to disk.
model = TemplateModel(
    templates=templates,
    parameters={p.name: p for p in parameters.values()},
    initials=initials,
    time=Time(name="t", units=day_units()),
    observables=observables,
    annotations=Annotations(
        name=f"Toby's Great Adventure SEIRHD {now}",
        license="CC0",
        authors=[
            Author(name='Toby Brett'),
            Author(name='Charles Tapley Hoyt')
        ],
        pathogens=["ncbitaxon:2697049"],
        diseases=[
            "doid:0080600",
        ],
        hosts=[
            "ncbitaxon:9606",
        ],
        model_types=[
            "mamo:0000028",
            "mamo:0000046",
        ],
        description="""\
        This model has been stratified by age (4 ways),
        variant (3 ways), and vaccine status (2 ways).

        The naming convention is to have the base
        population type (e.g., S, E, I) followed by underscore-separated indices
        in the order of age, variant, and vaccination status. All compartments
        are stratified by all three except for the S compartment, which
        is only stratified by age and status.

        Because death is stratified, it might be useful to create observables
        which e.g. take the sum over all statuses/ages/variants.

        We used numerical indexes (encoded as strings) for each of the age,
        variants, and statuses.

        - Status: (0) unvaccinated and (1) vaccinated
        - Ages: (0) for 0-19, (1) for 20-49, (2) for 50-64, and (3) for 65+
        - Variants: (0) wild, (1) delta, and (2) omicron
        
        This model also includes the conversion from unvaccinated susceptible
        individuals to vaccinated susceptible individuals (with conserved age).
        The parameter alpha ideally is a time dependent function.
        
        TODO we can model the reverse process (i.e., waning) of vaccinated
        susceptible individuals becoming unvaccinated as a decay of vaccine
        efficacy, which would also be time dependent.
        """,
    ),
)
sanity_check_tm(model)
am = AskeNetPetriNetModel(Model(model))
# Serialize once for validation; the file writer serializes on its own.
amr_json = am.to_json()
sanity_check_amr(amr_json)
am.to_json_file("toby_seirhd.json")

In [25]:
# Upload to Terarium staging and associate with projects 46 and 52.
# Set to False to skip the upload on a full re-run of the notebook.
POST_TO_TERARIUM = True
if POST_TO_TERARIUM:
    model_response = post_terarium(model, project_id=["46", "52"])
    model_response.raise_for_status()
    model_id = model_response.json()["id"]
    # A bare expression inside an `if` is not displayed by Jupyter, so print it.
    print(model_id)

associated model 5b444b71-eb95-4768-8e54-4cc2338e4406 to project 46:  {'id': 747}
associated model 5b444b71-eb95-4768-8e54-4cc2338e4406 to project 52:  {'id': 748}


In [26]:
# associate(model_id=model_id, project_id="46")
# associate(model_id=model_id, project_id="52")