# computegraph
a brief tour

In [None]:
# We're not really using this class, but importing it does all our required Jax initialization
from summer.runner.jax.runner import JaxRunner

import pandas as pd
import numpy as np
from jax import jit
# Note the convention:
# jax.numpy -> jnp
# numpy -> np
# 
# We add a third item to this list
# 'fnp' (functional numpy)
# This could be either jax or standard numpy,
# but we use it as a switch depending on what code we want to build

from jax import numpy as jnp

from computegraph.types import Variable, local, Function, param
from computegraph.graph import ComputeGraph
from computegraph.utils import get_nested_graph_dict


In [None]:
def getitem(obj, index):
    """Return ``obj[index]``.

    Plain-function form of subscription, so that indexing can be used as the
    callable of a computegraph ``Function`` node (e.g. picking one country's
    array out of a dict, or one statistic out of a stats dict).
    """
    item = obj[index]
    return item


def get_pop_dict(use_jax=False):
    """Build the node dictionary for the toy population compute graph.

    Each value is a computegraph ``Function``. ``local(name)`` refers to another
    node of this same dict by key. ``param(name)`` refers to a value in the
    ``parameters`` input that is supplied when the graph is run (here ``iso``
    and ``pop_scale``).

    Args:
        use_jax: if True, arrays are built and combined with ``jax.numpy``;
            otherwise plain ``numpy`` is used ("fnp" = functional numpy).

    Returns:
        dict mapping node name -> ``Function``. Entries are inserted so that
        each node only depends on nodes that come before it:
        pop_df -> country_pop -> pop_stats -> pop_sum -> norm_pop -> out_pop.
    """
    fnp = jnp if use_jax else np

    def gen_data_dict():
        # Toy per-country arrays; note the countries have different lengths.
        return {
            "AUS": fnp.array((10.0, 30.0)),
            "MYS": fnp.array((5.0, 9.0, 1.4)),
        }

    return {
        # Raw data: a dict of country code -> array (not a DataFrame, despite the key).
        "pop_df": Function(gen_data_dict),
        # Select one country's array using the "iso" parameter.
        "country_pop": Function(getitem, [local("pop_df"), param("iso")]),
        # Summary statistics of the selected array.
        "pop_stats": Function(
            lambda s: {"min": s.min(), "sum": s.sum()}, [local("country_pop")]
        ),
        "pop_sum": Function(getitem, [local("pop_stats"), "sum"]),
        # Normalise so the selected array sums to 1 (assuming a non-zero sum).
        "norm_pop": Function(fnp.divide, [local("country_pop"), local("pop_sum")]),
        # Rescale by the "pop_scale" parameter.
        "out_pop": Function(fnp.multiply, [local("norm_pop"), param("pop_scale")]),
    }

In [None]:
# Build the node dictionary with jax.numpy as the array backend, so that the
# resulting graph can later be compiled with jax.jit (see below)
pop_dict = get_pop_dict(use_jax=True)

In [None]:
# Inspect the raw node dictionary: node name -> Function
pop_dict

In [None]:
# Wrap the node dictionary in a ComputeGraph
cg = ComputeGraph(pop_dict)

In [None]:
# Visualise the graph (presumably renders the node dependency structure)
cg.draw()

In [None]:
# What inputs does the graph take?
# Note that this returns Variables, not just parameter names:
# some graphs consume multiple input dictionaries,
# even though this one only takes one (parameters)
cg.get_input_variables()

In [None]:
# Of course, what we want is to _run_ the graph
# Let's get a function to do this; we will reuse this over different parameter inputs
runner = cg.get_callable()

In [None]:
# Try changing these
parameters = {"iso": "AUS", "pop_scale": 0.5}

# Note how the outputs contain the entire contents of the graph
# (every node's value, not just out_pop)
runner(parameters=parameters)

In [None]:
# What if we don't want to execute everything?

# Specify some inputs as dynamic - ie we will be changing them, and we expect everything
# that depends on them to be recomputed
dynamic_inputs = [param("pop_scale")]

# Specify some outputs we care about
# This will ensure they exist in the output graph, and also make sure we are not
# computing unneeded portions of the graph
targets = ["out_pop"]

# Lastly, we need to supply it with some parameter inputs that will be used
# when calculating the fixed (frozen) part of the graph
# Note the expanded form here - remember that graphs can consume multiple dictionaries
inputs = {"parameters": parameters}

frozen_graph = cg.freeze(dynamic_inputs, targets, inputs)

In [None]:
# Visualise the frozen graph; compare with cg.draw() above
frozen_graph.draw()

In [None]:
# JIT-compile the frozen graph's callable with JAX
fast_runner = jit(frozen_graph.get_callable())

In [None]:
# Hmm, a problem...
# This call is expected to fail. NOTE(review): presumably because "iso" is no
# longer an input of the frozen graph and its string value ("AUS") is not a
# valid JAX type for jit to trace - TODO confirm against the actual error.
fast_runner(parameters=parameters)

In [None]:
# We only need one parameter
# (only the dynamic input, pop_scale, remains an input of the frozen graph)
frozen_graph.get_input_variables()

In [None]:
# Pass only the dynamic parameter the frozen graph requires
fast_runner(parameters={"pop_scale": 0.5})