# Building an integration graph

This notebook walks through the steps of building the computational graph for a model's integrator.
The basic idea is to compute all state histories forward by one time step, and then collect all the generated updates into a `scan` subgraph.

This algorithm is implemented in `sinn.models.Model::advance_updates()`. For concreteness, we use the 1-variable SCS model found in the *examples* directory.

In [1]:
import numpy as np
import theano_shim as shim
from collections import OrderedDict

import mackelab_toolbox as mtb
import mackelab_toolbox.typing

In [7]:
# Make code in examples directory importable
import sys, os
from pathlib import Path
import sinn
examples_dir = Path(sinn.__file__).parent.parent/'examples'
sys.path = [str(examples_dir)] + sys.path

In [2]:
shim.load('theano')

In [3]:
from sinn.histories import TimeAxis
from examples import SCS

In [4]:
mtb.typing.freeze_types()

In [5]:
N = 10
params = SCS.Parameters(
    N = N,
    J = np.random.normal(size=(N,N)),
    g = 1
    )

TimeAxis.time_step = np.float64(2**-6)  # Powers of 2 minimize numerical error
time = TimeAxis(min=0, max=10)
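
The comment above can be checked with plain floating-point arithmetic. This exercises Python floats only, not `TimeAxis`. A step of `2**-6` is exactly representable in binary, so accumulating it lands exactly on round times. A decimal step such as `0.1` is not representable, so errors appear both when adding steps and when dividing a time by the step to recover an index.

```python
# 64 steps of 2**-6 land exactly on 1.0: every partial sum k/64 is representable
t_pow2 = 0.0
for _ in range(64):
    t_pow2 += 2**-6

# 10 steps of 0.1 do not: 0.1 has no exact binary representation
t_dec = 0.0
for _ in range(10):
    t_dec += 0.1

# Recovering an index by division is also fragile with a decimal step:
# 0.3 / 0.1 evaluates to 2.9999999999999996, which int() truncates to 2
idx_dec = int(0.3 / 0.1)
# With a power-of-2 step the same computation is exact
idx_pow2 = int((3 * 2**-6) / 2**-6)
```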

In [6]:
scs = SCS(time=time, params=params, initializer=np.ones(N))
scs._curtidx_var = shim.tensor(np.array(1, dtype=scs.tidx_dtype), name='curtidx (model)')
object.__setattr__(scs, '_stoptidx_var',
                   shim.tensor(np.array(3, dtype=scs.tidx_dtype), name='stoptidx (model)'))

First, declare an “anchor” time index – this will serve as the reference time point for all history updates.

In [8]:
anchor_tidx = scs._num_tidx

Construct a substitution dictionary converting each history's current time index to the model's (anchor) time index. This works because we make sure that all unlocked histories are “synchronized”: their current `_num_tidx` values all correspond to the same point in simulation time.

In [9]:
assert scs.histories_are_synchronized()
anchor_tidx_typed = scs.time.Index(anchor_tidx)  # Do only once to keep graph as clean as possible
tidxsubs = {h._num_tidx: anchor_tidx_typed.convert(h.time)
            for h in scs.unlocked_histories}
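
A single anchor index can stand in for every history's `_num_tidx` because, once histories are synchronized, their numerical indices for the same time differ only by a fixed offset. That offset arises, for example, if one history reserves extra bins before the start of the simulation; this is an assumption made for the sketch. Below is a minimal sketch of such a conversion, assuming each axis is described by a start time `t0` and a shared time step `dt`. The `Axis` class and its `convert` method are hypothetical stand-ins, not `sinn`'s `TimeAxis` or `Index.convert`.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Axis:
    """Hypothetical time axis: index i corresponds to time t0 + i*dt."""
    t0: float
    dt: float

    def time(self, idx: int) -> float:
        return self.t0 + idx * self.dt

    def convert(self, idx: int, other: "Axis") -> int:
        """Map an index on this axis to the index of the same time on `other`."""
        assert self.dt == other.dt, "sketch assumes a common time step"
        # With a shared dt, conversion is a constant integer shift
        return idx + round((self.t0 - other.t0) / self.dt)

dt = 2**-6
model_axis = Axis(t0=0.0, dt=dt)
padded_axis = Axis(t0=-dt, dt=dt)   # hypothetical: one extra bin before t=0

anchor = 5
hist_idx = model_axis.convert(anchor, padded_axis)
```

Different numbers, same time point: `hist_idx` is 6 on the padded axis, and converting it back yields the anchor again.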

Build the update dictionary by computing each history forward by one step.

Hidden in this simple loop is the dependency resolution logic, which is implemented entirely within the `History` classes. Moreover, we remain completely agnostic as to _how_ a history updates its data: for example, for `Series`, `_sym_data` is a single array, whereas for `Spiketrain` it is a tuple of three arrays.

In [10]:
for h in scs.unlocked_statehists:
    h(h._num_tidx+1)
anchored_updates = shim.get_updates()
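
The plain-Python sketch below mimics the pattern in this cell under a simplified design, which is an assumption. Calling a history at an index computes any missing value, and in doing so calls the histories it depends on. New values are recorded in a global update dictionary rather than written directly, loosely analogous to what `shim.get_updates()` returns. `ToyHistory`, `_pending_updates` and `get_updates` are hypothetical. `sinn`'s real `History` classes and `theano_shim` build symbolic graphs and differ in detail.

```python
import numpy as np

# Hypothetical global store of pending updates: history name -> {index: value}
_pending_updates = {}

def get_updates():
    """Return a copy of the collected updates (a stand-in for shim.get_updates())."""
    return {k: dict(v) for k, v in _pending_updates.items()}

class ToyHistory:
    def __init__(self, name, n, update_fn):
        self.name = name
        self.data = np.full(n, np.nan)   # NaN marks "not yet computed"
        self.update_fn = update_fn       # f(tidx) -> value; may call other histories

    def __call__(self, tidx):
        # Pending (not yet applied) updates take precedence over stored data
        pending = _pending_updates.get(self.name, {})
        if tidx in pending:
            return pending[tidx]
        if not np.isnan(self.data[tidx]):
            return self.data[tidx]
        # Dependency resolution happens implicitly: update_fn calls other histories
        value = self.update_fn(tidx)
        _pending_updates.setdefault(self.name, {})[tidx] = value
        return value

dt = 2**-6
I = ToyHistory("I", 10, update_fn=lambda t: 1.0)                     # constant input
x = ToyHistory("x", 10, update_fn=lambda t: x(t - 1) + dt * I(t - 1))
x.data[0] = 0.0                                                       # initial condition

for h in [x]:          # only x is a "state" history here
    h(1)
updates = get_updates()
```

Asking only for `x` at index 1 also produced an update for `I` at index 0, and nothing was written into `x.data` yet.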

Replace all the history time indices with the single anchor time index.

In [11]:
anchored_updates = {k: shim.graph.clone(g, replace=tidxsubs)
                    for k,g in anchored_updates.items()}
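
The effect of `shim.graph.clone(g, replace=...)` can be illustrated with a toy expression tree. Every node found in the replacement dictionary is swapped for its substitute, and the rest of the graph is rebuilt unchanged. The tuple-based representation and the `clone`/`evaluate` functions below are a hypothetical illustration, not Theano's implementation. The substitute for `y`'s index (anchor plus a constant offset) stands in for what `anchor_tidx_typed.convert(h.time)` produces.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Var:
    name: str

OPS = {"add": lambda a, b: a + b, "mul": lambda a, b: a * b}

def clone(expr, replace):
    """Rebuild `expr`, swapping any node that is a key of `replace`."""
    if expr in replace:
        return replace[expr]
    if isinstance(expr, tuple):            # (op, *args) node
        op, *args = expr
        return (op, *(clone(a, replace) for a in args))
    return expr                            # constants and unreplaced variables

def evaluate(expr, env):
    if isinstance(expr, Var):
        return env[expr.name]
    if isinstance(expr, tuple):
        op, *args = expr
        return OPS[op](*(evaluate(a, env) for a in args))
    return expr

x_tidx, y_tidx, anchor = Var("x._num_tidx"), Var("y._num_tidx"), Var("anchor")

# Each history's update refers to its own "next" index
updates = {"x": ("add", x_tidx, 1), "y": ("add", y_tidx, 1)}

# Assumption: y's numerical index is one larger than the anchor for the same time
tidxsubs = {x_tidx: anchor, y_tidx: ("add", anchor, 1)}
anchored = {k: clone(g, tidxsubs) for k, g in updates.items()}
```

After substitution, both updates depend on the single `anchor` variable only.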

Now we construct the scan function. The update step simply “unanchors” the updates by replacing the anchor time index with the symbolic index `tidxm1`.

**Important** Theano can get confused when the sequence variable `tidx` appears in more than one graph. To work around this, we track an explicit variable `tidxm1` (equal to `tidx - 1`) instead. The `tidx` value serves only to update `tidxm1`: the step function returns it, and `scan` feeds it back as the next iteration's `tidxm1`. (We could also increment `tidxm1` directly and do away with `tidx` completely, keeping the sequence only as an iteration counter.)

Apart from the time index returned by `onestep` (which becomes the next `tidxm1`), there are no output variables: all the computations are done in the updates.
Even though dictionaries are ordered in Python 3.7+, Theano still emits a warning if we don't use an `OrderedDict`.

In [12]:
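def onestep(tidx, tidxm1):
    # "Unanchor" the updates: the anchor index becomes this step's tidxm1,
    # so each update computes the value at tidxm1 + 1 == tidx
    step_updates = OrderedDict(
                    (k, shim.graph.clone(g, replace={anchor_tidx: tidxm1}))
                    for k,g in anchored_updates.items())
    # Returning tidx makes it the next iteration's tidxm1
    return [tidx], step_updates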
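
Tracing the loop by hand shows why returning `tidx` is enough to keep `tidxm1` equal to `tidx - 1`. `scan` feeds each step's output back in as the next step's `tidxm1`, and the sequence starts one past the initial value. The pure-Python loop below sketches that bookkeeping only, with no Theano involved. The values of `cur` and `stop` are arbitrary.

```python
def onestep(tidx, tidxm1):
    # The real step function also returns the unanchored updates; omitted here
    return tidx

cur, stop = 1, 5
tidxm1 = cur                          # role of outputs_info=[_curtidx_var]
trace = []
for tidx in range(cur + 1, stop):     # role of sequences=[arange(cur+1, stop)]
    trace.append((tidx, tidxm1))
    tidxm1 = onestep(tidx, tidxm1)    # output fed back as the next tidxm1
```

Every recorded pair satisfies `tidx - tidxm1 == 1`.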

The scan iterates over a single sequence of time indices, `arange(_curtidx_var+1, _stoptidx_var)`, defined by the two symbolic placeholder variables `_curtidx_var` and `_stoptidx_var`. These will be the arguments of the final compiled function.

The `tidxm1` variable is initialized to `_curtidx_var` by the `outputs_info` argument.

We discard the outputs returned by `scan` (by assigning to `_`), since everything is in the updates.

In [13]:
_, upds = shim.scan(onestep,
                    sequences = [shim.arange(scs._curtidx_var+1, scs._stoptidx_var,
                                             dtype=scs.tidx_dtype)],
                    outputs_info = [scs._curtidx_var],
                    name = f"scan ({type(scs).__name__})")
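
For intuition about what the symbolic `scan` computes, here is a minimal eager (non-symbolic) emulation. It assumes a step function that receives the current sequence element, the recurrent state and a storage dict, and returns `(new_state, updates)`. Updates are written into the dict, which stands in for Theano shared variables, after every step, so later steps see earlier results. `eager_scan`, the dict-based storage, and the toy dynamics (a forward-Euler step of dh/dt = -h) are all hypothetical and only loosely mirror `theano.scan`.

```python
import numpy as np

def eager_scan(fn, sequence, init, storage):
    """Apply `fn(x, state, storage) -> (new_state, updates)` along `sequence`.

    `updates` maps storage keys to new values; they are written into
    `storage` after each step, so the next step sees them.
    """
    state = init
    outputs = []
    for x in sequence:
        state, updates = fn(x, state, storage)
        storage.update(updates)
        outputs.append(state)
    return outputs, storage

dt = 2**-6
storage = {"h": np.zeros(6)}
storage["h"][0] = 1.0                      # initial condition

def onestep(tidx, tidxm1, storage):
    h = storage["h"].copy()                # treat stored data as immutable
    h[tidx] = h[tidxm1] - dt * h[tidxm1]   # toy dynamics: Euler step of dh/dt = -h
    return tidx, {"h": h}

# As in the notebook, the outputs are discarded; the results live in `storage`
_, storage = eager_scan(onestep, range(1, 6), 0, storage)
```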

**Note** There are still some anchored time index variables (`_num_tidx`) in the resulting `upds`, and I'm not entirely sure why. It seems to be fine to leave them there, though.

Finally, we can compile the update function.

In [14]:
advance = shim.graph.compile(inputs  = [scs._curtidx_var, scs._stoptidx_var],
                             outputs = [],
                             updates = upds)

In [15]:
scs.h._num_data.get_value()[:10]

array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [16]:
advance(scs.cur_tidx, 5)

[]

In [17]:
scs.h._num_data.get_value()[:10]

array([[1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ],
       [0.9622697 , 0.98741163, 1.01000921, 0.92892203, 0.98741345,
        0.99398004, 0.94149194, 1.0097513 , 1.03064802, 1.00111782],
       [0.92578224, 0.97578152, 1.02002321, 0.8601452 , 0.97455717,
        0.98831304, 0.88458863, 1.01943653, 1.06006262, 1.00139186],
       [0.89048657, 0.96513946, 1.03002636, 0.79369481, 0.96145926,
        0.98299558, 0.82930135, 1.02903317, 1.08817829, 1.00083828],
       [0.85633262, 0.95551232, 1.04000517, 0.72959197, 0.94814639,
        0.97802277, 0.77563647, 1.03852011, 1.11493575, 0.99947376],
       [0.82327204, 0.94692321, 1.04994867, 0.66785275, 0.93464434,
        0.97338786, 0.7235942 , 1.04787745, 1.14028322, 0.99731575],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.       