<a href="https://colab.research.google.com/github/dshipman/colabsnippets/blob/main/vector_borne.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!pip install summerepi2

In [None]:
from summer2 import CompartmentalModel, Stratification
from summer2.parameters import Parameter, Function, CompartmentValues, Time

In [None]:
#comps = ["s_mosq", "e_mosq", "i_mosq", "s_human", "e_human", "i_human", "r_human"]

In [None]:
# Base model with two unstratified base compartments.
# NOTE(review): positional args are presumably (time window, compartment
# names, infectious compartments); the last list is empty here — confirm
# against the summer2 CompartmentalModel signature.
m = CompartmentalModel(
    [0, 1e3],
    ["human", "mosquito"],
    [],
)

In [None]:
# Seed 10,000 into the base "human" compartment. "mosquito" is not listed,
# so it presumably starts empty (mosquitoes arrive via the birth importation
# flow added later) — confirm summer2's default for unlisted compartments.
m.set_initial_population(
    {"human": 1e4},
)

In [None]:
# NOTE: the compartments arg should accept a single string
# (a one-element list is passed below for now).

# Humans: SEIR states; seed 0.1% of the initial population as exposed.
human_strat = Stratification("h_state", ["S", "E", "I", "R"], ["human"])
human_strat.set_population_split({"S": 0.999, "E": 0.001, "I": 0.0, "R": 0.0})

# Mosquitoes: SEI states (no recovered state); any initial mosquitoes are susceptible.
mosq_strat = Stratification("m_state", ["S", "E", "I"], ["mosquito"])
mosq_strat.set_population_split({"S": 1.0, "E": 0.0, "I": 0.0})

# Apply both stratifications, humans first, then mosquitoes.
m.stratify_with(human_strat)
m.stratify_with(mosq_strat)

In [None]:
def infection_m2h(comp_vals, contact_m2h):
    """Infection rate for humans driven by infectious mosquitoes.

    Looks up (on the global model ``m``) the indices of every compartment
    whose ``m_state`` stratum is "I". It sums their current values and scales
    that total by ``contact_m2h``.

    Note that the absolute count of infectious mosquitoes is used, not a
    proportion.

    Args:
        comp_vals: current compartment values, indexable by the indices that
            ``m.query_compartments(..., as_idx=True)`` returns.
        contact_m2h: scaling factor applied to the infectious-mosquito total.

    Returns:
        ``sum(infectious mosquito compartments) * contact_m2h``.
    """
    infectious_mosquitoes = comp_vals[
        m.query_compartments({"m_state": "I"}, as_idx=True)
    ].sum()
    return infectious_mosquitoes * contact_m2h

def infection_h2m(comp_vals, contact_h2m):
    """Infection rate for mosquitoes driven by human infection prevalence.

    Prevalence is the summed value of human compartments with ``h_state``
    "I", divided by the summed value of all "human" compartments. Both index
    sets are looked up on the global model ``m``. The result is scaled by
    ``contact_h2m``.

    Args:
        comp_vals: current compartment values, indexable by the indices that
            ``m.query_compartments(..., as_idx=True)`` returns.
        contact_h2m: scaling factor applied to human prevalence.

    Returns:
        ``(infectious humans / all humans) * contact_h2m``.
    """
    infectious_idx = m.query_compartments({"h_state": "I"}, as_idx=True)
    human_idx = m.query_compartments({"name": "human"}, as_idx=True)

    infectious_total = comp_vals[infectious_idx].sum()
    human_total = comp_vals[human_idx].sum()
    return infectious_total / human_total * contact_h2m


In [None]:
def _resolve_flow_endpoint(model, query, role):
    """Resolve one flow endpoint query into (base compartment name, strata filter).

    ``query`` is passed to ``model.query_compartments``. The matched
    compartments must share exactly one ``.name``. The returned strata filter
    is a fresh dict containing every key of ``query`` except ``"name"``, so
    ``query`` itself is never modified.

    Raises:
        ValueError: if the query matches zero or several distinct base names.
    """
    names = {c.name for c in model.query_compartments(query)}
    # Explicit raise (not assert) so the check survives `python -O`.
    if len(names) != 1:
        raise ValueError(
            f"{role} filter {query!r} must match exactly one base compartment, "
            f"got {sorted(names)}"
        )
    strata = {key: value for key, value in query.items() if key != "name"}
    return query.get("name", next(iter(names))), strata


def add_transition_flow(model, name, flow_param, source, dest):
    """Add a transition flow whose endpoints are given as compartment query dicts.

    ``source`` and ``dest`` are filters for ``model.query_compartments`` (for
    example ``{"h_state": "S"}`` or ``{"name": "human", "h_state": "S"}``).
    Each must match compartments that share a single base name. That name
    becomes the flow's source or destination compartment. The remaining
    key/value pairs (everything except an optional ``"name"`` key) are passed
    on as the source or destination strata.

    Args:
        model: model providing ``query_compartments`` and ``add_transition_flow``.
        name: name of the flow.
        flow_param: flow rate, passed through unchanged.
        source: query dict selecting the source compartment; not modified.
        dest: query dict selecting the destination compartment; not modified.

    Returns:
        Whatever ``model.add_transition_flow`` returns.

    Raises:
        ValueError: if ``source`` or ``dest`` does not resolve to exactly one
            base compartment name.
    """
    s_name, source_strata = _resolve_flow_endpoint(model, source, "source")
    d_name, dest_strata = _resolve_flow_endpoint(model, dest, "dest")
    return model.add_transition_flow(
        name, flow_param, s_name, d_name, source_strata, dest_strata
    )


In [None]:
# --- Cross-species infection ---------------------------------------------
# Human S -> E, rate computed by infection_m2h from infectious mosquitoes.
add_transition_flow(
    m,
    "infection_m2h",
    Function(infection_m2h, [CompartmentValues, Parameter("contact_m2h")]),
    {"h_state": "S"},
    {"h_state": "E"},
)
# Mosquito S -> E, rate computed by infection_h2m from human prevalence.
add_transition_flow(
    m,
    "infection_h2m",
    Function(infection_h2m, [CompartmentValues, Parameter("contact_h2m")]),
    {"m_state": "S"},
    {"m_state": "E"},
)

# --- Human natural history (rates are reciprocals of mean durations) -----
add_transition_flow(
    m,
    "progression_h",
    1.0 / Parameter("incubation_period_h"),
    {"h_state": "E"},
    {"h_state": "I"},
)
add_transition_flow(
    m,
    "recovery_h",
    1.0 / Parameter("infectious_period_h"),
    {"h_state": "I"},
    {"h_state": "R"},
)

# --- Mosquito natural history --------------------------------------------
add_transition_flow(
    m,
    "progression_m",
    1.0 / Parameter("incubation_period_m"),
    {"m_state": "E"},
    {"m_state": "I"},
)

# Mosquito mortality applies to the "mosquito" compartment with no strata filter.
m.add_death_flow(
    "death_m",
    1.0 / Parameter("mosquito_lifetime"),
    "mosquito",
)


In [None]:
from jax import numpy as jnp

In [None]:
def mosq_birth(t, amplitude=0.5, period=365.0):
    """Seasonal multiplier for the mosquito birth rate.

    Returns ``1 + amplitude * sin(2 * pi * t / period)``. With the defaults,
    the result oscillates between 0.5 and 1.5 over a 365-unit cycle, which
    matches the original hard-coded behaviour. ``amplitude`` and ``period``
    are keyword parameters with defaults. Callers that pass only ``t`` (such
    as ``Function(mosq_birth, [Time])``) are therefore unaffected.

    Args:
        t: model time (scalar or array; evaluated with ``jax.numpy``).
        amplitude: relative size of the seasonal swing around 1.0.
        period: length of one seasonal cycle, in model time units.

    Returns:
        The seasonal multiplier at ``t``.
    """
    return 1.0 + amplitude * jnp.sin(jnp.pi * 2.0 * t / period)

In [None]:
# Mosquito births: imported directly into the susceptible mosquito stratum
# (not split across strata). The rate is the seasonal multiplier mosq_birth(t)
# scaled by the mosq_birth_rate parameter.
m.add_importation_flow(
    "birth_m",
    Function(mosq_birth, [Time]) * Parameter("mosq_birth_rate"),
    "mosquito",
    split_imports=False,
    dest_strata={"m_state": "S"},
)

In [None]:
# Display the model's registered flows (bare expression so the notebook shows it).
m.flows

In [None]:
# Values for every Parameter(...) referenced by the model's flows.
# Durations are converted to rates as 1/duration in the flow definitions.
# Time units are presumably days (mosq_birth uses a 365.0 period) — confirm.
parameters = dict(
    contact_m2h=5e-5,          # scales the infectious-mosquito total (human S -> E)
    contact_h2m=0.1,           # scales human prevalence (mosquito S -> E)
    incubation_period_h=21.0,  # mean human E duration
    infectious_period_h=100.0, # mean human I duration
    incubation_period_m=3.0,   # mean mosquito E duration
    mosquito_lifetime=14.0,    # mean mosquito lifetime (death flow)
    mosq_birth_rate=50.0,      # multiplied by the seasonal factor mosq_birth(t)
)

In [None]:
# Run the model with the parameter values defined above.
m.run(parameters)

In [None]:
import pandas as pd
# Route DataFrame.plot() through plotly instead of matplotlib.
pd.set_option("plotting.backend", "plotly")

In [None]:
# Plot the output trajectories of every compartment whose base name is "human".
m.get_outputs_df()[m.query_compartments({"name": "human"})].plot()

In [None]:
# Plot the output trajectories of every compartment whose base name is "mosquito".
m.get_outputs_df()[m.query_compartments({"name": "mosquito"})].plot()

In [None]:
# Inspect the arguments of one node in the model's computation graph.
# NOTE(review): "_var10" looks like an auto-generated node name; it may change
# if flows/parameters are added or reordered — confirm before relying on it.
m.graph["_var10"].args

In [None]:
# Render the model's computation graph.
m.graph.draw()