# summer2 Internals; the when, where and what of it all...

In [None]:
# If we are running in google colab, pip install the required packages, 
# but do not modify local environments
try:
    import google.colab
    # graphviz is installed already, but need lib too
    !apt install libgraphviz-dev
    !pip install pygraphviz

    !pip install summerepi2==1.2.2
except ImportError:
    # Not running in Colab; leave the local environment untouched
    pass



In [None]:
# The following imports are required for the code portions of this notebook
import numpy as np

import summer2

from summer2 import CompartmentalModel
from summer2.parameters import Parameter

# jax is used for just-in-time (JIT) compilation of the models
from jax import jit
# It provides its own implementation of numpy
from jax import numpy as jnp

# summer2's parameterization machinery happens via computegraph; we will discuss this further below
from computegraph import ComputeGraph, Function

## Just what is a model anyway?

In the context of an article or presentation, a (compartmental) model may be considered primarily as a collection of equations.  In practice, any sufficiently complex (ODE) model must be solved numerically rather than analytically, yet implementation specifics are often incompletely described, or elided altogether.

In a Bayesian context, parameters (priors) and data (observations) are intrinsic to the conception of a model, in addition to any internal numerical solution (such as the ODE solvers used in compartmental modelling).

In deep neural networks, which may have billions of parameters, but a fixed (or hyper-parameterizable) structure, behaviour is almost solely determined by weights (parameters), and so they must be included in any meaningful definition of a model.

These competing (and sometimes fuzzy) definitions can often make it difficult to cleanly separate these various conceptual layers in code.  For reasons of performance and clarity, summer2 enforces a strict separation of concerns; this notebook is designed to help users understand the boundaries between these layers, with particular reference to changes from previous versions of summer.

## Model building, model running

Since its inception, summer has had 2 primary goals:
1. Provide a robust, easy-to-understand set of methods for constructing and describing compartmental models.
2. Provide a means of running these models and obtaining numerical outputs, i.e. to automate handling the complexities of numerical integration.


To a large extent, these are extremely complementary goals - by using high level descriptions of structure and behaviour, users can focus on _what_ needs to be computed, not _how_ to compute it, and summer's internal code is free to make appropriate choices around optimization and numerical implementation.

### summer1 approach

In earlier versions of summer (which we will refer to as summer1), fewer distinctions were made between these 2 aspects of modelling.  In cases where we might wish to evaluate many parameter sets (for example in calibration), in order to achieve goal #2 (numerical evaluation), we would also need to repeat the processes of goal #1, as in the following pseudocode

```
def run_model(parameters):
    # The model is rebuilt from scratch for every parameter set
    model = build_model(parameters)
    model.run()
    return model
```

Since we expect the outputs of this process to differ by value (but retain the same structure), there is an implied contract regarding the behaviour of build_model.  In practice, enforcing this contract is frequently complicated by the nature of the "parameters" structure passed in to this function, which incorporates 2 distinct kinds of 'parameter'.  The first kind are 'true' parameters (the kind that are strictly numerical, might be calibrated, and would show up in a "table of parameters").  The second kind is everything else (e.g. switches for certain functionality, configuration of model structure, segments of data, segments of code, metadata... the list goes on).

### summer2 approach

summer2 draws much stronger lines in the sand, and it is considered idiomatic to refer to the inputs of a 'build_model' function as "model configuration" rather than "parameters".  This configuration input concerns itself with the broad second type of data discussed above (although in doing so, it will encode the names and relationships of the ('true') parameters as well).  These functions return a CompartmentalModel object, whose structure is now fixed, but whose 'true' parameters can still be varied.

```
m = build_model(model_config)
m.run(parameters)
```

When referring to parameters in summer2, we are talking about 'true' parameters ('the greeks').  In practice, these are the values that are passed into a run() function - which can only be called after a model is finalized (and therefore has a fixed structure).  When we refer to Parameters (with an upper-case P), we are referring specifically to the computegraph objects that encode the name (key) of a parameter, and are used to describe the connection between values passed in to run() and the internal processes of the model itself.

This provides a clear separation; while a 'build_model' function might best be considered a factory for a family of (hopefully related) model types, a finalized model (CompartmentalModel) has a specified structure, a defined set of parameters (whose values may be varied), and all the code and data required for a full numerical realisation.

The sort of 'build_model' function that we use throughout our examples (and in many of our working models) is idiomatic in summer2, but not fundamental.  While summer1 required build_model at every iteration, in summer2 we might obtain a CompartmentalModel object from a build_model function, we might build it 'inline' in a notebook or script, or our code might be concerned with a pre-built model passed to us by some other function.

We should also note that while one entry point for numerically evaluating a CompartmentalModel is model.run(), most client libraries (e.g. estival) will use the get_runner() method to obtain a customized run function specialized on the parameters and outputs of interest; by "running the model" we generally mean evaluating a specific set of parameters by any of the above means.
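The build-once, run-many pattern described above can be sketched without summer2 at all. The following is a minimal illustration in plain Python, not summer2's implementation: `ToyModel`, this `build_model`, the config keys and the Euler stepping are all hypothetical stand-ins, chosen only to show that the structure is fixed at build time while the parameter values vary per run.

```python
build_calls = 0  # counts how often the (potentially expensive) build step happens


class ToyModel:
    """Hypothetical stand-in for a finalized model: fixed structure, free parameter values."""

    def __init__(self, n_days, initial_population):
        self.n_days = n_days
        self.initial_population = initial_population

    def run(self, parameters):
        # Crude one-day Euler steps of an SIR system; only the values come from `parameters`
        s, i, r = self.initial_population
        n = s + i + r
        beta = parameters["contact_rate"]
        gamma = 1.0 / parameters["recovery_period"]
        for _ in range(self.n_days):
            new_infections = beta * s * i / n
            new_recoveries = gamma * i
            s, i, r = s - new_infections, i + new_infections - new_recoveries, r + new_recoveries
        return {"S": s, "I": i, "R": r}


def build_model(model_config):
    """Consumes configuration (structure, data), not 'true' parameters."""
    global build_calls
    build_calls += 1
    return ToyModel(model_config["n_days"], model_config["initial_population"])


model = build_model({"n_days": 100, "initial_population": (9990.0, 10.0, 0.0)})

# Many parameter sets, one build
parameter_sets = [
    {"contact_rate": 0.2, "recovery_period": 10.0},
    {"contact_rate": 0.3, "recovery_period": 10.0},
    {"contact_rate": 0.3, "recovery_period": 5.0},
]
outputs = [model.run(p) for p in parameter_sets]
```

Here `build_calls` stays at 1 however many parameter sets are evaluated, which is the property that calibration loops benefit from.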

## Build-time vs run-time - how does it work in practice?

As discussed above, there is a clear separation of computational concerns; certain processes happen only at 'runtime', while others may occur at any stage prior.  How does summer2 'know' where to draw these lines?

Consider the following model code;


In [None]:
m = CompartmentalModel([0,100], ["S", "I", "R"], ["I"])
m.set_initial_population({"S": 9990.0, "I": 10.0})

# Add a parameterised infection process
contact_rate = Parameter("contact_rate")
m.add_infection_frequency_flow("infection", contact_rate , "S", "I")

# Add parameterised recovery
recovery_rate = 1.0 / Parameter("recovery_period")
m.add_transition_flow("recovery", recovery_rate, "I", "R")

In [None]:
parameters = {
    "contact_rate": 0.3,
    "recovery_period": 10.0
}

m.run(parameters)
m.get_outputs_df().plot()

While our contact_rate is an unmodified parameter, our recovery flow is parameterized with 1.0 / Parameter("recovery_period").  Since we are dealing with a system of ODEs, all flow parameters must be expressed as rates internally, but there are many instances where it may be more natural to parameterize this as a period; perhaps this is how it would usually be expressed in the literature, or there may be numerical reasons for doing so.

The inputs to run() are the values of the parameters; and all 'true' model parameters are expressed using Parameter objects.  What the model 'sees' can involve (almost - within the limits of jax) arbitrary computations on these Parameters, but by definition, such computations will only happen at runtime (since the specific numerical values are supplied as arguments to run). 

Although the syntax used here is the same as it would be if we were just calculating a value immediately, the object that is returned is not a numerical value, but rather an object that encodes our intentions (and provides a means to compute them later). This is referred to as lazy evaluation (in contrast to 'eager' evaluation that is executed as soon as a function is called).

Such runtime (lazy) transformations of parameters and data are often referred to as being 'in the graph' (for reasons that should seem obvious after the next cell)

In [None]:
m.graph.draw()

When we perform 1.0 / Parameter("recovery_period"), we return a computegraph Function object.  This is a lazy object that contains both the function to be called (the divide operation; a lower-case-f function, i.e. just an ordinary function), as well as a mapping of inputs to pass into this function at runtime.  When we refer to a (capital-F) Function, we are talking about the computegraph object.  (Parameter and Function are both subclasses of GraphObject, so we might refer to GraphObjects more generally, as anything that might appear in a graph.)
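To make the idea of "a function plus a mapping of inputs" concrete, here is a toy re-creation in plain Python. It is only a sketch of the general technique (operator overloading that returns lazy nodes); `ToyGraphObject`, `ToyParameter` and `ToyFunction` are hypothetical names and are not how computegraph is actually implemented.

```python
import operator


class ToyGraphObject:
    """Hypothetical base class: arithmetic on these objects builds new lazy nodes."""

    def __truediv__(self, other):
        return ToyFunction(operator.truediv, [self, other])

    def __rtruediv__(self, other):
        return ToyFunction(operator.truediv, [other, self])

    def __mul__(self, other):
        return ToyFunction(operator.mul, [self, other])

    def __rmul__(self, other):
        return ToyFunction(operator.mul, [other, self])


class ToyParameter(ToyGraphObject):
    """Holds only a key; the value is looked up when the graph is evaluated."""

    def __init__(self, key):
        self.key = key

    def evaluate(self, parameters):
        return parameters[self.key]


class ToyFunction(ToyGraphObject):
    """Holds a plain function and the arguments to feed it later."""

    def __init__(self, func, args):
        self.func = func
        self.args = args

    def evaluate(self, parameters):
        # Resolve lazy arguments first, pass constants through unchanged
        resolved = [
            arg.evaluate(parameters) if isinstance(arg, ToyGraphObject) else arg
            for arg in self.args
        ]
        return self.func(*resolved)


# Nothing is computed here; we only record the intention to divide
recovery_rate = 1.0 / ToyParameter("recovery_period")

# The numerical value only exists once parameter values are supplied
value = recovery_rate.evaluate({"recovery_period": 10.0})
```

The same `recovery_rate` object can be evaluated repeatedly with different parameter values, which is the behaviour the real Function objects provide at model runtime.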

In [None]:
recovery_rate

In [None]:
ComputeGraph(recovery_rate).draw()

When a model is finalized (by running it, or by calling model.finalize()), summer2 will scan all the possibly parameterisable aspects of the model, and if any of them contain GraphObjects, these will be collated into the final graph (model.graph).

These directed acyclic graphs (DAGs) follow a functional dataflow style; each node provides some combination of inputs and output (parameters are output-only, targets are input-only, most nodes will transform their input into an output).  Each node has a fixed mapping of inputs to outputs - it always consumes its data from the same upstream nodes, and always outputs to the same nodes downstream.  They are functionally 'pure' - which is to say they do not encapsulate state, or perform side-effects - they simply return output based on their input (this is required by jax).
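The dataflow style described above can be illustrated with a small evaluator built on the standard library's `graphlib`. This is a sketch under stated assumptions rather than computegraph's own machinery: the node names (`recovery_rate`, `rate_ratio`), the dictionary layout and `evaluate_graph` are all hypothetical.

```python
from graphlib import TopologicalSorter

# Each node: name -> (pure function, names of the upstream nodes it reads from).
# The node names and functions here are hypothetical.
nodes = {
    "recovery_rate": (lambda period: 1.0 / period, ["recovery_period"]),
    "rate_ratio": (
        lambda contact, recovery: contact / recovery,
        ["contact_rate", "recovery_rate"],
    ),
}


def evaluate_graph(nodes, parameters):
    """Evaluate every node once, visiting upstream nodes before downstream ones."""
    dependencies = {name: set(inputs) for name, (_, inputs) in nodes.items()}
    values = dict(parameters)  # parameters are output-only nodes: they start with a value
    for name in TopologicalSorter(dependencies).static_order():
        if name in nodes:
            func, inputs = nodes[name]
            values[name] = func(*(values[key] for key in inputs))
    return values


values = evaluate_graph(nodes, {"contact_rate": 0.5, "recovery_period": 4.0})
```

Because each node is a pure function of its fixed upstream inputs, a single ordering of the nodes is valid for every run. `TopologicalSorter` raises a `CycleError` if the dependencies contain a cycle, which is one way to see why the graph has to be acyclic.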

The lazy-evaluation/graph functionality is provided by the computegraph package, and the optimized JIT code is supplied by jax; the libraries are not intrinsically related, but are always used together in a summer2 context.

GraphObjects know about the following operations

- the standard arithmetic operations: [+, -, /, *, **]

- numpy ufuncs (i.e. most numpy maths functions; min/max/sin/cos etc.  See the numpy documentation for a full discussion of ufuncs)

- indexing

Whenever these operations feature a GraphObject as one (or more) of their arguments, the result will also be a GraphObject.  Computegraph is 'greedy' in this sense.


In [None]:
p = Parameter("x") # Parameter

some_data = np.linspace(0.0,2.0,10) # Plain old numpy array
scaled_data = some_data ** 2.0 # Still a plain old numpy array - we're not 'in the graph' yet

scaled_graph_data = p * scaled_data # Function object (GraphObject); everything is a GraphObject after this
transformed_data = np.sin(scaled_graph_data) # Another Function object - computegraph knows about numpy ufuncs

single_value = transformed_data[5] # Indexing a GraphObject produces another GraphObject

cg = ComputeGraph(single_value)
cg.draw()

In [None]:
# The graph's inputs can be obtained via get_input_variables
in_vars = cg.get_input_variables()
in_vars

In [None]:
# To actually run a graph, use the get_callable method

f = cg.get_callable()

# The required inputs are passed in as a dictionary of parameter values
f(parameters={"x": 1.0})

In [None]:
# A version of the function with all outputs included

f_all = cg.get_callable(output_all=True)
f_all(parameters={"x": 3.0})

Let's inspect a few of the objects used in constructing the graph

In [None]:
# Takes our Parameter (x) and the numpy data as arguments...
scaled_graph_data

In [None]:
# This Function has another Function in its arguments - this is how the graph is constructed
transformed_data

In [None]:
transformed_data.args[0] is scaled_graph_data

These Function objects are exactly the kind that you would construct "manually", like so

In [None]:
def some_custom_function(a, b):
    return a**b / (a-b)

In [None]:
some_custom_function(2.0,3.0)

In [None]:
# Note that we're giving this Function 2 Parameters as arguments, but these could be any combination of GraphObjects,
# or constant scalars or arrays from previous computations

custom_graphobject_func = Function(some_custom_function, [Parameter("x"), Parameter("y")])

cg = ComputeGraph(custom_graphobject_func)
cg.get_callable()(parameters={"x": 2.0, "y": 3.0})

In [None]:
cg.draw()

While you may frequently have cause to compose these kinds of manual functions, it's worth considering whether you actually need to.
There is really no reason for this function to exist;

```
def special_notifications_function(incidence, case_detection_rate):
	return incidence * case_detection_rate

notif_f = Function(special_notifications_function, [incidence, Parameter("cdr")])
```

...when it could be written as;

```
notifications = incidence * Parameter("cdr")
```

It might be preferable to 'squash' multiple nodes into a single function (as in the above some_custom_function), although there may also be instances where complex functions are split out into separate graph nodes - if for example you wish to output some intermediate value of a computation (and perhaps store it as part of a model run). This can be a more expressive way of interacting with model internals than attempting to step through breakpoints in a debugger...

In many cases, these concerns are not unique to summer2 or computegraph, but rather a general consideration about how code should be constructed, and what constitutes a meaningful function. 

## Jax - JIT and functions

Although we use Python code to compose our jax operations, the code that is generated is not in fact Python at all, but rather an intermediate representation consumed by XLA (a domain-specific compiler for Accelerated Linear Algebra).  XLA produces optimized machine code, along with providing a range of other benefits (automatic differentiation, GPU execution, automatic vectorization, and much more).
However, these advantages come at the cost of some specific restrictions.

We have already mentioned that jax has its own version of numpy, and also one of scipy (which are idiomatically imported as jnp and jsp respectively, to differentiate them from the Python versions).  Any computation that is intended to occur 'in the graph' must use the jax equivalents.

You might still see 'plain' Python code inside jax functions - this is perfectly valid, but serves a very different purpose - only the pure jax code will be run at every function call, whereas the Python code is evaluated only while jax traces the function (typically once, when it is first called).
There are some very good and useful reasons why we might want to do this, but if in doubt, it's probably best to assume that you just want the jax version!
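The split between "Python that runs while tracing" and "operations that run on every call" can be imitated with a toy tracer in plain Python. This is a deliberately simplified sketch and not how jax works internally: `ToyTracer`, `toy_sin`, `toy_jit` and `interpret` are hypothetical names, and real jax also retraces for new input shapes and dtypes.

```python
import math

trace_count = 0  # counts how often the Python body of the function actually runs


class ToyTracer:
    """Hypothetical stand-in for a jax tracer: records operations instead of computing them."""

    def __init__(self, expr):
        self.expr = expr

    def __mul__(self, other):
        return ToyTracer(("mul", self.expr, other))


def toy_sin(x):
    # Tracer-aware sin, loosely analogous to jnp.sin
    return ToyTracer(("sin", x.expr))


def interpret(expr, x):
    """Evaluate a recorded expression for a concrete input value."""
    if expr == "x":
        return x
    if expr[0] == "mul":
        return interpret(expr[1], x) * expr[2]
    if expr[0] == "sin":
        return math.sin(interpret(expr[1], x))
    raise ValueError(f"unknown operation: {expr!r}")


def toy_jit(func):
    cache = {}

    def wrapper(x):
        if "program" not in cache:
            # The Python body of func executes here, once, on a tracer
            cache["program"] = func(ToyTracer("x")).expr
        # Every call (including the first) replays the recorded operations
        return interpret(cache["program"], x)

    return wrapper


@toy_jit
def scaled_sin(x):
    global trace_count
    trace_count += 1         # plain Python: runs only while tracing
    return toy_sin(x * 2.0)  # recorded operations: replayed on every call


results = [scaled_sin(v) for v in (0.0, 0.5, 1.0)]
```

After three calls `trace_count` is still 1, while the recorded multiply-then-sin program has been evaluated three times. Calling `math.sin` directly on a `ToyTracer` raises a `TypeError`, which loosely mirrors what happens when plain numpy is applied to a jax tracer.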

In [None]:
# Here we use the jit decorator just to ensure we're running this code "in jax"
# summer2 runs this over its whole model, so you know code must be jittable to work inside summer

# Use the 'plain python' numpy version
@jit
def numpy_sin(x):
    return np.sin(x)


# Use 'jax numpy'
@jit
def jax_sin(x):
    return jnp.sin(x)

In [None]:
numpy_sin(1.0)

Whenever you see a stack trace mentioning Tracer objects and concrete values, this is almost certainly an indicator that you are trying to call pure Python code on objects that are already part of jax's tracing hierarchy. Exactly as with computegraph, once something is "in the tracer", all following calls must use jax code in order to be valid. The final data returned from your (jitted) function can be used as you see fit (even if it's a jax array type rather than an ndarray), but everything happening inside JIT must be jax-only; because summer2 jits the entire model, this means all code in the graph.  Note that this is not an intrinsic limitation of computegraph itself, but rather a design decision within summer2.


In [None]:
jax_sin(1.0)

### A note about numpy ufuncs
You might have noticed we create valid jax-computable Function objects in our graph whenever we call a numpy ufunc on a GraphObject.  Yet didn't we just say these had to use the jax (jnp) versions instead?  And why are we able to use these functions directly, but not others? 

This is a little bit of convenience trickery inside computegraph; use of numpy functions is so common in transformation of model inputs, that we decided to include shorthand functionality rather than forcing users to constantly write full Functions.

While jax attempts to provide a (mostly complete) API compatible implementation of numpy, the ufuncs are a bit different.  Jax's ufunc equivalents are not 'true' ufuncs in that while they will perform the same calculations given numerical inputs, 'true' ufuncs will inspect the objects they are called on to see if they "know about ufuncs".  GraphObjects do - which is how they are able to produce more GraphObjects from these calls.  So we need to use the original (Python) numpy ufuncs while we are constructing the graph, but these are converted internally to their jax equivalents before they are actually run.


In [None]:
sinx = np.sin(Parameter("x"))

# This is a jax PjitFunction, not a python function...
sinx.func

In [None]:
# Example of a class that can be 'inspected' by ufuncs
# GraphObjects contain this kind of logic internally

class CustomUFuncHandler:
    def __array_ufunc__(self, ufunc, method, *inputs, out=None, **kwargs):
        print(f"ufunc {ufunc.__name__} called on our custom object")

In [None]:
x = CustomUFuncHandler()
np.sin(x)

In [None]:
# This will fail... jnp is just trying to compute a numerical value,
# but our Parameter object doesn't have one - it's just a promise that we'll pass one in later on...
jnp.sin(Parameter("x"))

In [None]:
# To construct the sin function manually and use it inside a summer2 model, you would need to use the
# jnp version of sin; plain numpy would fail in the JIT

Function(jnp.sin, [Parameter("x")])

## Jax - Datatypes

For most of our purposes, jax can (more or less) be treated as a subset of Python. With regards to datatypes, some but not all standard datatypes are supported, for example;

- Basic numeric types can be used (float and int).
- Dictionaries, tuples (including namedtuples), and lists are usable (with some restrictions), while sets are not.
- String literals can be used as dictionary keys, but strings cannot be returned or passed into functions.

While it is possible to write adaptor code that will allow you to represent more exotic datatypes within jax, this is generally considered outside the scope of summer2 modelling.  The most immediate implication is that 'extended' container types like those provided by Pandas and XArray will not work 'out of the box', and should be marshalled into appropriate (jax-compatible) datatypes before entering the graph;


In [None]:
import pandas as pd

In [None]:
@jit
def justmul(a, b):
    return a * b

justmul(2.0,3.0)

In [None]:
# This will fail, and give a fairly good error message about why

xs = pd.Series([0.0,1.0,2.0])
justmul(xs, 3.0)

In [None]:
# Convert it to a type jax understands...

justmul(xs.to_numpy(), 3.0)

## Splitting code between buildtime/runtime.  What goes in the graph?

The simple answer is - anything we need to parameterize!

Given this, it might be tempting to just parameterize everything we can - after all we never know when we might want to calibrate something...

There are, however, a few things we definitely _never_ want in the graph.  Consider file or network I/O.  Unless we're writing very basic tests, there's a good chance that most practical models will need to load data from somewhere.  This might be something simple like a local CSV file, or something as complex as an SQL database, or something stored on a remote network.

We don't want all this overhead involved in every iteration, although we certainly need to do it someplace; we may also need to combine data from multiple sources, performing filtering operations or transformations so we end up with a format appropriate for our model.  In summer1, much of this necessarily occurred within the build_model function; which leaves a rather open question about where a 'model' starts and ends - if all this filtering had been performed before we loaded the data from disk, would we still claim it was part of the model?  How about any processes that happened during the original survey that created the dataset in the first place?

It's clear in this case that I/O operations should therefore happen before anything enters the graph - now that we know the rules for the creation and propagation of GraphObjects (and their relationship with jax), it is simple enough to see how this code should be structured; do all the I/O before you get any GraphObjects involved.
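As a rough illustration of "do all the I/O first", the sketch below loads data once at build time and only then defines the parameterized computation. It uses plain Python closures rather than GraphObjects, and everything in it is an assumption made for the example: the CSV content, the column names, `load_cases` and `build_notifications` are hypothetical, and an in-memory string stands in for a file on disk.

```python
import csv
import io

load_calls = 0  # counts how often the I/O step runs

# Hypothetical data; in practice this would be a file, database or network source
CSV_TEXT = "day,cases\n0,10\n1,12\n2,15\n"


def load_cases(source):
    """Build-time step: read and convert the data to plain floats."""
    global load_calls
    load_calls += 1
    reader = csv.DictReader(source)
    return [float(row["cases"]) for row in reader]


def build_notifications(source):
    # All I/O and cleaning happens here, before any parameter is involved
    cases = load_cases(source)

    def run(parameters):
        # Run-time step: pure arithmetic on already-loaded data
        cdr = parameters["cdr"]
        return [c * cdr for c in cases]

    return run


run_notifications = build_notifications(io.StringIO(CSV_TEXT))

half = run_notifications({"cdr": 0.5})
full = run_notifications({"cdr": 1.0})
```

The data is read exactly once, no matter how many parameter sets are evaluated afterwards. In a summer2 model the run-time step would be expressed with GraphObjects instead of a closure, but the ordering of the two phases is the same.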

It is also worth considering that, because of the restrictions jax places on us, we may not be able to rely on many external libraries, instead having to write our own equivalent code.  This can be a lot of effort, and is almost certainly more error-prone than a well-tested library that has existed 'in the wild' for some time.  Just because we might want to calibrate something "one day", doesn't mean we actually need it right now, and until such time, it can run 'outside the graph', and we can rely on the wider Python ecosystem.
