<a href="https://colab.research.google.com/github/dshipman/colabsnippets/blob/main/vector_borne.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install summerepi2

Collecting summerepi2
  Downloading summerepi2-1.3.6-py3-none-any.whl (79 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/79.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.7/79.7 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting computegraph==0.4.5 (from summerepi2)
  Downloading computegraph-0.4.5-py3-none-any.whl (18 kB)
Collecting jax>=0.4.24 (from summerepi2)
  Downloading jax-0.4.30-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m14.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jaxlib>=0.4.24 (from summerepi2)
  Downloading jaxlib-0.4.30-cp310-cp310-manylinux2014_x86_64.whl (79.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.6/79.6 MB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: jaxlib, jax, computegraph, summerepi2
  Attempting uninstall: jaxlib
    Found existing 

In [7]:
from summer2 import CompartmentalModel, Stratification
from summer2.parameters import Parameter, Function, CompartmentValues, Time

In [4]:
# Compartment names: S/E/I stages for mosquitoes followed by S/E/I/R stages for humans.
comps = [f"{stage}_mosq" for stage in ("s", "e", "i")] + [
    f"{stage}_human" for stage in ("s", "e", "i", "r")
]

In [206]:
# Simulation window [0, 1000.0] over the compartments listed in `comps`.
# No infectious compartments are declared — presumably because the cross-species
# infection rates are supplied by custom Function flows later on.
m = CompartmentalModel(
    times=[0, 1000.0],
    compartments=comps,
    infectious_compartments=[],
)

In [207]:
# Seed the human population with a small exposed group. Mosquito compartments are not
# listed here (presumably they start empty and fill via the birth_m importation flow).
m.set_initial_population(
    distribution={
        "s_human": 10000.0,
        "e_human": 100.0,
    }
)

In [208]:
# Tag every compartment with its host species ("mosquito" if the name contains "mosq",
# otherwise "human") so later cells can select compartments by tag.
# The membership check makes the cell idempotent: re-running it does not append
# duplicate tags to compartments that are already tagged.
for comp in m.compartments:
  species = "mosquito" if "mosq" in comp.name else "human"
  if species not in comp.tags:
    comp.tags.append(species)

In [209]:
def infection_m2h(comp_vals, contact_m2h):
  """Rate for the mosquito-to-human infection flow (s_human -> e_human).

  Args:
    comp_vals: current compartment values (supplied via CompartmentValues),
      indexed with the positions returned by ``m.query_compartments``.
    contact_m2h: mosquito-to-human contact parameter.

  Returns:
    The summed size of the infectious-mosquito compartment(s) times contact_m2h.
  """
  # NOTE: relies on the module-level model `m` to resolve compartment positions.
  infectious_mosq_positions = m.query_compartments(
      dict(name="i_mosq"), tags=["mosquito"], as_idx=True
  )
  n_infectious_mosq = comp_vals[infectious_mosq_positions].sum()
  return n_infectious_mosq * contact_m2h

def infection_h2m(comp_vals, contact_h2m):
  """Rate for the human-to-mosquito infection flow (s_mosq -> e_mosq).

  Args:
    comp_vals: current compartment values (supplied via CompartmentValues),
      indexed with the positions returned by ``m.query_compartments``.
    contact_h2m: human-to-mosquito contact parameter.

  Returns:
    Human prevalence (infectious humans / all humans) times contact_h2m.
    NOTE(review): divides by the total human population; this would be
    non-finite if that total were ever zero.
  """
  # NOTE: relies on the module-level model `m` to resolve compartment positions.
  infectious_positions = m.query_compartments(dict(name="i_human"), tags=["human"], as_idx=True)
  human_positions = m.query_compartments(tags=["human"], as_idx=True)
  n_infectious = comp_vals[infectious_positions].sum()
  n_humans = comp_vals[human_positions].sum()
  return (n_infectious / n_humans) * contact_h2m


In [210]:
# Cross-species infection: rates come from the custom infection_m2h / infection_h2m
# functions, evaluated on the current compartment values.
m.add_transition_flow(
    "infection_m2h",
    Function(infection_m2h, [CompartmentValues, Parameter("contact_m2h")]),
    "s_human",
    "e_human",
)
m.add_transition_flow(
    "infection_h2m",
    Function(infection_h2m, [CompartmentValues, Parameter("contact_h2m")]),
    "s_mosq",
    "e_mosq",
)

# Stage exits whose rate is 1 / (duration parameter): (flow name, parameter, source, dest).
for flow_name, duration_param, source, dest in (
    ("progression_h", "incubation_period_h", "e_human", "i_human"),
    ("recovery_h", "infectious_period_h", "i_human", "r_human"),
    ("progression_m", "incubation_period_m", "e_mosq", "i_mosq"),
):
  m.add_transition_flow(flow_name, 1.0 / Parameter(duration_param), source, dest)

# Every mosquito-tagged compartment loses individuals at rate 1 / mosquito_lifetime.
for mosq_comp in m.query_compartments(tags=["mosquito"]):
  m.add_death_flow("death_m", 1.0 / Parameter("mosquito_lifetime"), mosq_comp.name)



In [211]:
from jax import numpy as jnp

In [212]:
def mosq_birth(t, amplitude=0.5, period=365.0):
  """Seasonal multiplier applied to the mosquito birth (importation) rate.

  Args:
    t: model time (scalar or jax array; supplied via ``Time`` in the flow definition).
    amplitude: relative size of the sinusoidal swing around 1.0. Defaults to 0.5,
      the value previously hard-coded.
    period: length of one seasonal cycle in model time units. Defaults to 365.0
      (presumably days — confirm against the model's time unit).

  Returns:
    ``1 + amplitude * sin(2*pi*t / period)``; with the defaults this oscillates
    between 0.5 and 1.5.
  """
  # Operand order matches the original expression so default results are bit-identical.
  return 1.0 + amplitude * jnp.sin(jnp.pi * 2.0 * t / period)

In [213]:
# Mosquito recruitment into s_mosq: the time-varying mosq_birth multiplier scaled by
# the mosq_birth_rate parameter. split_imports=False — presumably the full amount goes
# to s_mosq rather than being divided across strata (confirm in summer2 docs).
mosq_import_rate = Function(mosq_birth, [Time]) * Parameter("mosq_birth_rate")
m.add_importation_flow("birth_m", mosq_import_rate, "s_mosq", split_imports=False)

In [219]:
# Parameter values for m.run. Durations are in model time units (presumably days,
# given the 365-unit seasonal cycle in mosq_birth — confirm).
parameters = dict(
    contact_m2h=0.00005,       # multiplies the infectious-mosquito count (s_human -> e_human)
    contact_h2m=0.1,           # multiplies human prevalence (s_mosq -> e_mosq)
    incubation_period_h=21.0,  # e_human -> i_human rate is 1 / this
    infectious_period_h=100.0, # i_human -> r_human rate is 1 / this
    incubation_period_m=3.0,   # e_mosq -> i_mosq rate is 1 / this
    mosquito_lifetime=14.0,    # mosquito death rate is 1 / this
    mosq_birth_rate=50.0,      # scales the seasonal mosq_birth multiplier
)

In [220]:
# Run the model with the values in the `parameters` dict; results are read back below
# via m.get_outputs_df().
m.run(parameters)

In [221]:
import pandas as pd

# Route DataFrame.plot() through plotly instead of the default matplotlib backend.
pd.set_option("plotting.backend", "plotly")

In [222]:
# Plot every compartment's trajectory from the model outputs (plotly backend).
# Title and axis labels make the figure readable on its own; label keys that do not
# appear in the frame are ignored by plotly.
m.get_outputs_df().plot(
    title="Vector-borne model: compartment sizes over time",
    labels={"index": "time", "value": "compartment size", "variable": "compartment"},
)

In [218]:
# Render the model's computation graph (m.graph) — presumably a visualisation of how
# parameters and Functions feed into the flow rates; confirm against summer2 docs.
m.graph.draw()