<a href="https://colab.research.google.com/github/dshipman/colabsnippets/blob/main/vector_borne.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install summerepi2

In [None]:
from summer2 import CompartmentalModel, Stratification
from summer2.parameters import Parameter, Function, CompartmentValues, Time

In [None]:
# Compartment names for both hosts: mosquitoes follow S -> E -> I (no recovered
# state), humans follow S -> E -> I -> R.
comps = [
    # mosquito compartments
    "s_mosq",
    "e_mosq",
    "i_mosq",
    # human compartments
    "s_human",
    "e_human",
    "i_human",
    "r_human",
]

In [None]:
# Build the model over the time window 0..1000 (presumably days, given the
# 365-unit seasonal cycle used for mosquito births). The third argument
# (presumably summer2's infectious-compartment list) is left empty because
# transmission is driven by the custom Function-based transition flows instead.
m = CompartmentalModel(
    [0, 1000.0],  # start and end times
    comps,
    [],
)

In [None]:
# Seed only the human compartments: 10,000 susceptible plus 100 exposed.
# No mosquito compartment is listed here.
m.set_initial_population(dict(s_human=10000.0, e_human=100.0))

In [None]:
# Tag each compartment with its host species so later code can select all
# mosquito or all human compartments via ``query_compartments(tags=[...])``.
for c in m.compartments:
    c.tags.append("mosquito" if "mosq" in c.name else "human")

In [None]:
def infection_m2h(comp_vals, contact_m2h):
    """Return the mosquito-to-human infection rate.

    Used as the rate of the ``s_human -> e_human`` transition flow
    ("infection_m2h"). The value is the total size of the ``i_mosq``
    compartment(s) scaled by ``contact_m2h``.

    Args:
        comp_vals: current compartment values, indexable by the index array
            returned from ``m.query_compartments(..., as_idx=True)``.
        contact_m2h: contact/transmission scaling per infectious mosquito.
    """
    infectious_mosq_idx = m.query_compartments(
        dict(name="i_mosq"), tags=["mosquito"], as_idx=True
    )
    n_infectious_mosq = comp_vals[infectious_mosq_idx].sum()
    return n_infectious_mosq * contact_m2h

def infection_h2m(comp_vals, contact_h2m):
    """Return the human-to-mosquito infection rate.

    Used as the rate of the ``s_mosq -> e_mosq`` transition flow
    ("infection_h2m"). The value is the prevalence of infection among humans
    (``i_human`` over all human-tagged compartments) scaled by ``contact_h2m``.

    Args:
        comp_vals: current compartment values, indexable by the index arrays
            returned from ``m.query_compartments(..., as_idx=True)``.
        contact_h2m: contact/transmission scaling applied to human prevalence.

    Note:
        Divides by the total human population, so it assumes that total is
        non-zero.
    """
    infectious_human_idx = m.query_compartments(
        dict(name="i_human"), tags=["human"], as_idx=True
    )
    human_idx = m.query_compartments(tags=["human"], as_idx=True)
    n_infectious = comp_vals[infectious_human_idx].sum()
    n_humans = comp_vals[human_idx].sum()
    return (n_infectious / n_humans) * contact_h2m


In [None]:
# Cross-species transmission. Each rate is a custom Function of the current
# compartment values and a contact parameter, so human infection depends on
# infectious mosquitoes and mosquito infection depends on human prevalence.
for flow_name, infection_func, contact_param, source, dest in (
    ("infection_m2h", infection_m2h, "contact_m2h", "s_human", "e_human"),
    ("infection_h2m", infection_h2m, "contact_h2m", "s_mosq", "e_mosq"),
):
    m.add_transition_flow(
        flow_name,
        Function(infection_func, [CompartmentValues, Parameter(contact_param)]),
        source,
        dest,
    )

# Natural-history transitions, each at rate 1 / (named period parameter).
for flow_name, period_param, source, dest in (
    ("progression_h", "incubation_period_h", "e_human", "i_human"),
    ("recovery_h", "infectious_period_h", "i_human", "r_human"),
    ("progression_m", "incubation_period_m", "e_mosq", "i_mosq"),
):
    m.add_transition_flow(flow_name, 1.0 / Parameter(period_param), source, dest)

# Every mosquito-tagged compartment loses members at rate 1 / mosquito_lifetime.
# All of these flows share the name "death_m".
for mosq_comp in m.query_compartments(tags=["mosquito"]):
    m.add_death_flow("death_m", 1.0 / Parameter("mosquito_lifetime"), mosq_comp.name)



In [None]:
from jax import numpy as jnp

In [None]:
def mosq_birth(t, amplitude=0.5, period=365.0):
    """Return the seasonal multiplier applied to the mosquito birth rate.

    Computes ``1 + amplitude * sin(2*pi*t / period)``. With the defaults, the
    multiplier oscillates between 0.5 and 1.5 over a 365-unit cycle, matching
    the original hard-coded behaviour. It is called with a single ``Time``
    argument from ``Function(mosq_birth, [Time])``, so the extra keyword
    parameters keep that call working.

    Args:
        t: model time (scalar or JAX array).
        amplitude: relative size of the seasonal swing. Values above 1 make
            the multiplier negative for part of the cycle, which is unlikely
            to be meaningful for a birth rate, so keep it within [0, 1].
        period: length of one seasonal cycle in model time units
            (presumably days).

    Returns:
        The multiplier, as a JAX value.
    """
    # Same operation order as the original expression, so default results are
    # bit-for-bit unchanged.
    return 1.0 + amplitude * jnp.sin(jnp.pi * 2.0 * t / period)

In [None]:
# Mosquito births enter s_mosq at a seasonally varying rate: the mosq_birth
# multiplier evaluated at the current model Time, scaled by mosq_birth_rate.
m.add_importation_flow(
    "birth_m",
    Function(mosq_birth, [Time]) * Parameter("mosq_birth_rate"),
    "s_mosq",
    split_imports=False,
)

In [None]:
# Parameter values keyed by the names used in Parameter(...) calls.
# Periods and lifetimes are durations in model time units (presumably days);
# the corresponding flows use their reciprocals as rates.
parameters = dict(
    contact_m2h=0.00005,       # scales the infectious-mosquito count for human infection
    contact_h2m=0.1,           # scales human prevalence for mosquito infection
    incubation_period_h=21.0,
    infectious_period_h=100.0,
    incubation_period_m=3.0,
    mosquito_lifetime=14.0,
    mosq_birth_rate=50.0,      # baseline births, multiplied by the seasonal factor
)

In [None]:
# Run the simulation. Each Parameter("...") referenced in the flows is
# presumably resolved by name from the ``parameters`` dict.
m.run(parameters)

In [None]:
import pandas as pd

# Make DataFrame.plot() use plotly. This is equivalent to assigning
# pd.options.plotting.backend, and it needs the plotly package to be importable.
pd.set_option("plotting.backend", "plotly")

In [None]:
# Plot the model outputs with pandas' configured plotting backend. The figure
# is the cell's final expression, so the notebook displays it.
outputs_df = m.get_outputs_df()
outputs_df.plot()

In [None]:
# Draw the model's graph (``m.graph``). The result is the cell's final
# expression, so the notebook displays it.
model_graph = m.graph
model_graph.draw()