## Using param-summer and computegraph in AuTuMN

In [None]:
# We currently have the slightly modified sm_sir model on a branch:

!git checkout input-graph

In [None]:
from autumn.core.project import get_project
from computegraph import ComputeGraph
import pandas as pd
pd.options.plotting.backend = "plotly"

from summer.runner.jax.util import build_model_with_jax

from computegraph.utils import get_nested_graph_dict, expand_nested_dict
from summer.parameters import find_all_parameters

from jax import jit
from jax import numpy as jnp
import numpy as np

In [None]:
p = get_project("sm_sir", "national-capital-region")

In [None]:
parameters = p.param_set.baseline.to_dict()

In [None]:
%time model, input_dict = p.build_model(parameters)

In [None]:
model

In [None]:
# We now have a dict ready to be turned into a ComputeGraph
input_dict

In [None]:
ig = ComputeGraph(input_dict)

In [None]:
ig.draw()

In [None]:
# Inspect the parameters consumed by our input graph
# Note that this specifically refers to Variables, not Parameters
# Some input graphs consume multiple dictionaries,
# even though this one only takes one (parameters)

ig.get_input_variables()

In [None]:
# Compare the variables consumed by the input graph with the parameters
# consumed by the summer model, alongside their current values.
# Distinct loop variable names are used so the project handle `p`
# (returned by get_project earlier) is not overwritten by this cell.
# The flattened parameter dict is computed once rather than per iteration.
flat_params = expand_nested_dict(parameters)

print("Input graph inputs")
for graph_var in ig.get_input_variables():
    print(graph_var, flat_params[graph_var.name])

print("Model inputs")
for model_param in find_all_parameters(model):
    print(model_param, flat_params[model_param.name])

In [None]:
# Summer expects a "flat" parameter dictionary, i.e. in calibration format.
# We will make our input graph use the same format for convenience.
inputrunner = ig.get_callable(nested_params=False)

In [None]:
# Write a filtering function that consumes a parameter dict,
# makes sure it is expanded to calibration format, and
# returns only the parameters needed by the graph.
# This will help us with JAX calls later on.
def get_filtered_params(cg: ComputeGraph, parameters: dict) -> dict:
    """Return only the entries of ``parameters`` that ``cg`` takes as inputs.

    The parameter dict is first expanded with
    ``expand_nested_dict(..., include_parents=True)`` so its keys can be matched
    against the (flat, calibration-format) names of the graph's input variables.

    Args:
        cg: The ComputeGraph whose input variables determine which keys are kept.
        parameters: Parameter dict to filter (e.g. the nested baseline dict).

    Returns:
        A new dict mapping each matching input-variable name to its value.
    """
    expanded = expand_nested_dict(parameters, include_parents=True)
    # Use the graph that was passed in, not a notebook global. Build the name
    # lookup once as a set, so each membership test is O(1).
    graph_inputs = {v.name for v in cg.get_input_variables()}
    return {k: v for k, v in expanded.items() if k in graph_inputs}

In [None]:
get_filtered_params(ig, parameters)

In [None]:
# Build a frozen variant of the input graph.
# NOTE(review): the exact meaning of freeze()'s positional arguments is defined
# by computegraph — here we pass the graph's input variables, the
# "initial_population" key, and the filtered parameter values; confirm against
# the ComputeGraph.freeze docs.
graph_variables = ig.get_input_variables()
frozen_parameters = get_filtered_params(ig, parameters)
frozen_graph = ig.freeze(
    graph_variables,
    ["initial_population"],
    {"parameters": frozen_parameters},
)

In [None]:
run_frozen = frozen_graph.get_callable(False)

In [None]:
run_frozen(parameters = get_filtered_params(ig, parameters))

In [None]:
# The following command fails - why?
jit(run_frozen)(parameters = parameters)

In [None]:
# As shown in the previous notebook, JAX only handles certain datatypes; we use the
# filter function above to ensure we only pass it the parameters it needs.
jit(run_frozen)(parameters=get_filtered_params(ig, parameters))

In [None]:
# Compare this to running the complete graph
inputrunner(parameters = get_filtered_params(ig, parameters))

In [None]:
def run_graph_and_model(parameters, updated_params=None):
    """Evaluate the input graph, then run the summer model on its outputs.

    Args:
        parameters: Parameter dict (e.g. the nested baseline dict). It is
            expanded with ``expand_nested_dict(..., include_parents=True)``
            before use.
        updated_params: Optional flat overrides, such as
            ``{'contact_rate': 0.15}``. They are applied on top of the expanded
            parameters, and their values win on key collisions.

    Nothing is returned. Results are read afterwards from the notebook-global
    ``model`` (e.g. ``model.get_derived_outputs_df()``).
    """
    overrides = updated_params if updated_params else {}
    flat_params = expand_nested_dict(parameters, include_parents=True)
    flat_params = flat_params | overrides
    graph_outputs = inputrunner(parameters=flat_params)
    # Graph outputs are merged last, so they take precedence over any
    # same-named input parameter when the model is run.
    model.run(parameters=flat_params | graph_outputs)

In [None]:
# Baseline run: evaluate the input graph and model with the unmodified parameters.
run_graph_and_model(parameters)

# Keep a copy of the baseline notifications series, because the same `model`
# object is re-run in the next cell. The copy presumably guards against the
# stored series changing on that re-run — confirm whether
# get_derived_outputs_df returns a fresh frame each call.
orig_notifications = model.get_derived_outputs_df()['notifications'].copy()

In [None]:
# Re-run the graph and model with contact_rate overridden to 0.15.
run_graph_and_model(parameters, {'contact_rate': 0.15})

new_notifications = model.get_derived_outputs_df()['notifications']

# Overlay baseline vs modified notifications (rendered with the plotly
# plotting backend configured at the top of the notebook).
pd.DataFrame({'orig': orig_notifications, 'new': new_notifications}).plot()
