In [None]:
# If we are running in google colab, pip install the required packages, 
# but do not modify local environments
try:
  import google.colab
  IN_COLAB = True
except:
  IN_COLAB = False

if IN_COLAB:
    !pip install summerepi2==1.0.1a5

    # graphviz is installed already, but need lib too
    !apt install libgraphviz-dev
    !pip install pygraphviz

In [None]:
# Typical imports
import numpy as np
import pandas as pd
# Use plotly as the backend for DataFrame.plot(), so plots below are plotly figures
pd.options.plotting.backend = "plotly"

# Import Jax's numpy implementation
# (model functions such as time-varying rates are written with jnp, not np)
from jax import numpy as jnp

# Also import jax itself - normally you'll never need this directly, but we want to have a look...
import jax


In [None]:
# Now let's have a look at summer2 itself

from summer2 import CompartmentalModel, Stratification

# Import some of our parameterization types;
# These are actually ComputeGraph types - we'll examine what that means in a later section
from summer2.parameters import Parameter, Function, Data

In [None]:
# Let's start with a very simple model.  As you can see, the API is much the same as summer.
# We take our parameters as arguments - which is how we build our models in AuTuMN.
# This is not the "right way" (or the wrong way) - just a demonstration!

def build_model(contact_rate):
    """Build a minimal S/I/R model with one frequency-dependent infection flow (S -> I).

    Args:
        contact_rate: rate for the "infection" flow. This may be a plain number
            (e.g. 0.1) or a graph object such as a Parameter, an expression
            involving Time, or a deferred Function.

    Returns:
        The CompartmentalModel, not yet run. Initial population: 900 S, 100 I.
    """
    compartments = ["S", "I", "R"]
    # NOTE(review): [0, 100] is presumably the simulated time span and ["I"] the
    # infectious compartments; confirm against the summer2 API docs.
    model = CompartmentalModel([0, 100], compartments, ["I"])
    model.set_initial_population({"S": 900, "I": 100})
    model.add_infection_frequency_flow("infection", contact_rate, "S", "I")
    return model

In [None]:
# Typical build/run cycle in summer
# (the contact rate is a plain float, so run() needs no parameter values)

m = build_model(0.1)
m.run()
m.get_outputs_df().plot()

In [None]:
# What if we give it a Parameter object instead? (actually a Variable, we'll cover this soon)
contact_rate = Parameter("contact_rate")

# The model is now parameterised: the value of "contact_rate" is supplied at run time
m = build_model(contact_rate)

In [None]:
# Parameter values are supplied by name in a dict when the model is run
pdict = {
    "contact_rate": 1.0
}

m.run(pdict)
m.get_outputs_df().plot()

## ComputeGraph

In [None]:
# ComputeGraph is our new library that handles building computational structures, and
# representations of parameters, functions, and mappings between them

# Note: this rebinds Function, Data and Parameter, which were previously imported from
# summer2.parameters; here computegraph's `param` is bound to the name Parameter.
from computegraph.types import GraphObject, Variable, Function, Data, param as Parameter
from computegraph import ComputeGraph

### ComputeGraph Types

Everything here is a `GraphObject`.

#### Variables (generalised 'parameters')

In [None]:
# "Parameter" is not really a class - it's just a wrapper around Variable that
# tells ComputeGraph to look up its value in the "parameters" dictionary

x = Parameter("x")
# -> presumably the lookup key ("x") and the dictionary it is read from ("parameters")
x.key, x.source

In [None]:
# Have a look at the 'Parameter' we created earlier
# (a Parameter is a Variable whose source is "parameters")
Variable("contact_rate", "parameters") == contact_rate

In [None]:
# summer2 models keep track of their expected input parameters
# (m was built from the contact_rate Parameter, so it is expected to appear here)
m.get_input_parameters()

In [None]:
# These are assembled in 'ComputeGraph' objects internally
# The simplest ComputeGraph is a single Variable:

ComputeGraph(contact_rate).draw()

In [None]:
# Because GraphObjects are evaluated 'lazily', you can perform what look like computations on them,
# but really they just return another GraphObject

contact_rate * 2.0

In [None]:
# These are retained and can be built into graphs later
# (Parameter("x") adds a second input variable alongside contact_rate)

ComputeGraph(contact_rate * (1.0 / Parameter("x"))).draw()

In [None]:
# ... of the type used by summer
# (m.graph is the graph held by the summer2 model m)

m.graph.draw()

### Functions
#### The other important GraphObject

In [None]:
# In order to represent Python functions in our graph structure, we need to wrap them appropriately.
# Much like Variables (which look up a dictionary), Functions look up their arguments in a lazy
# fashion - instead of telling them _what_ their arguments are, you tell them how to find them.

def thing(x, y):
    """Return ``1.0 / (x ** y)``; a toy function used to demonstrate wrapping with Function."""
    denominator = x ** y
    return 1.0 / denominator

In [None]:
# Calling the plain Python function directly: 1.0 / (2.0 ** 3.0) = 0.125
thing(2.0,3.0)

In [None]:
# Wrap it: Function takes the callable and a list of its arguments
tfunc = Function(thing, [2.0, 3.0])

In [None]:
# Because this function doesn't have any external sources, we can evaluate it directly
# (the result should match thing(2.0, 3.0) above)
tfunc.evaluate()

In [None]:
ComputeGraph(tfunc).draw()

In [None]:
# Arguments can themselves be GraphObjects built from Parameters
x = Parameter("x")
y = Parameter("y")

a = x+y
# NOTE(review): np.log on a GraphObject presumably yields another lazy GraphObject - confirm
b = np.log(x*y)

tfunc2 = Function(thing, [a, b])

In [None]:
# Another way to represent functions
from computegraph.utils import defer

In [None]:
# A dict can be passed instead of a single GraphObject; presumably its keys name the outputs
ComputeGraph({"t": tfunc2}).draw()

In [None]:
cg = ComputeGraph({"t": tfunc2})
# get_callable() returns a function; input values are passed per source (e.g. parameters=...)
cg_run = cg.get_callable()

In [None]:
# The Variables the graph needs as inputs (expected: x and y)
cg.get_input_variables()

In [None]:
params = {"x": 1.0, "y": 2.0}

# With x=1, y=2: a = 3 and b = log(2), so "t" should be 1.0 / (3 ** log(2))
cg_run(parameters=params)

In [None]:
# Show the documentation for get_callable. help() works in any Python runtime,
# unlike the IPython-only `?` suffix, which is not valid Python syntax.
help(cg.get_callable)

In [None]:
# NOTE(review): output_all=True presumably also returns intermediate node values, not only "t" - confirm
cg.get_callable(output_all=True)(parameters=params)

### Back to summer...
Now that we know a little about Variables, Functions and ComputeGraphs, let's see how summer2 uses them.


In [None]:
from summer2.parameters import Time, CompartmentValues

In [None]:
# Time is another graph input, so this builds a lazy "contact_rate * time" expression
contact_rate * Time

In [None]:
# Model whose infection rate grows linearly with time
m = build_model(contact_rate * Time)
# finalize() is called before inspecting m.graph (presumably it assembles the model's graph)
m.finalize()
m.graph.draw()

In [None]:
# Presumably only the part of the graph that feeds "infection_rate"
m.graph.filter("infection_rate").draw()

In [None]:
# Find the flow(s) tagged "infection" and draw the graph for the first one
iflows = m.query_flows(tags="infection")
m.get_object_graph(iflows[0]).draw()

In [None]:
# Indices (as_idx=True) of the S and R compartments
m.query_compartments({"name": ["S","R"]}, as_idx=True)

In [None]:
m.get_input_parameters()

In [None]:
# Only contact_rate is passed in; Time is presumably supplied by the model itself
m.run({"contact_rate": 0.01})

In [None]:
m.get_outputs_df().plot()

In [None]:
# A more sensible time-varying contact_rate
# In order to run as expected, this must use Jax functions!
# (e.g. jnp.where instead of a Python `if`: presumably summer2 evaluates these functions
# on JAX values, and Python control flow cannot branch on a traced value)

In [None]:
def tv_contact(t, contact_param, time_start):
    """Step-function contact rate: 0.0 while t <= time_start, contact_param afterwards.

    Uses jax.numpy (jnp.where) rather than Python branching, so it also works
    elementwise on arrays.
    """
    has_started = t > time_start
    return jnp.where(has_started, contact_param, 0.0)

In [None]:
# defer(tv_contact)(...) wraps tv_contact as a graph function whose arguments are graph objects
cr_func = defer(tv_contact)(Time, Parameter("contact_rate"), Parameter("contact_start"))

In [None]:
m = build_model(cr_func)
m.finalize()

In [None]:
# Expected to list both contact_rate and contact_start
m.get_input_parameters()

In [None]:
params = {
    "contact_rate": 1.0,
    "contact_start": 30.0
}

In [None]:
# Notice the slightly longer time on first run
# (presumably one-off JAX compilation; later runs reuse the compiled model)
m.run(parameters=params)

In [None]:
# Same model, different parameter values - no rebuild needed
params = {
    "contact_rate": 1.0,
    "contact_start": 10.0
}

m.run(parameters=params)
m.get_outputs_df().plot()

In [None]:
# Let's have a look at a slightly more complex model

def build_strat_model(contact_rate, mixing_matrix):
    """Build the simple S/I/R model, then stratify all compartments by age.

    Args:
        contact_rate: passed straight through to build_model.
        mixing_matrix: mixing matrix for the "age" stratification (strata
            "young" and "old"). It may be a constant Data object or a
            time-varying graph expression.

    Returns:
        The stratified CompartmentalModel (not yet finalized or run).
    """
    age_strat = Stratification("age", ["young", "old"], ["S", "I", "R"])
    age_strat.set_mixing_matrix(mixing_matrix)

    model = build_model(contact_rate)
    model.stratify_with(age_strat)
    return model

In [None]:
# Constant 2x2 age mixing matrix, wrapped in Data so it can be used in the graph
# NOTE(review): rows/columns presumably follow the strata order ("young", "old") - confirm
mm = Data(jnp.array([
    [1.1,0.8],
    [0.7, 1.2]
]))

mstrat = build_strat_model(contact_rate, mm)
mstrat.finalize()

In [None]:
# mm is constant Data rather than a Parameter, so only contact_rate is expected here
mstrat.get_input_parameters()

In [None]:
mstrat.graph.draw()

In [None]:
# Same base mixing matrix as in the earlier cell, redefined here with identical values
mm = Data(jnp.array([
    [1.1,0.8],
    [0.7, 1.2]
]))

def schools_closed(t, time_start, time_end):
    """School-closure indicator: 1.0 when time_start < t < time_end, otherwise 0.0.

    Both bounds are exclusive. Written with jax.numpy so it works elementwise on arrays.
    """
    in_window = (t > time_start) & (t < time_end)
    return jnp.where(in_window, 1.0, 0.0)

# Matrix reflecting change in contacts for school closures:
# subtracts 1.0 from entry [0, 0] (presumably young-young contacts) while schools are closed
school_close_mod_matrix = Data(
    jnp.array(
        [
            [-1.0, 0.0],
            [0.0, 0.0],
        ]
    )
)

# schools_closed(...) is 1.0 inside the closure window and 0.0 outside, so the
# modifier is only added to the base matrix mm during the closure
sc_mm = mm + (
    defer(schools_closed)(Time, Parameter("sc_start"), Parameter("sc_end"))
    * school_close_mod_matrix
)

mstrat = build_strat_model(contact_rate, sc_mm)
mstrat.finalize()

In [None]:
mstrat.graph.draw()

In [None]:
# Presumably just the part of the graph that feeds the (now time-varying) mixing matrix
mstrat.graph.filter("mixing_matrix").draw()

In [None]:
# sc_start and sc_end are expected to be listed alongside contact_rate
mstrat.get_input_parameters()

In [None]:
# Run the stratified model with schools closed for 25 < t < 50
params = {
    "contact_rate": 0.05,
    # Floats for both bounds, consistent with the other time parameters;
    # both are compared against Time inside schools_closed
    "sc_start": 25.0,
    "sc_end": 50.0,
}

mstrat.run(parameters=params)
mstrat.get_outputs_df().plot()