# Batching model evaluations

VEP-style field models have high spatial resolution and are usually deterministic. At the other end of the spectrum are stochastic models with high temporal resolution, such as the Montbrio-Pazo-Roxin or Jansen-Rit models.

## Why batching?

Batching improves performance in several ways:

- Delays make memory access effectively random. Fetching a batch of values per delayed access, instead of a single scalar, amortizes the cost of that random access and can reach 80%+ of memory bandwidth.
- Parallel-in-time evaluation for centered models evaluates several time windows as one parallel batch.
- Parallel simulations run entire simulations across multiple CPU SIMD lanes or GPU threads.

These form a progressive increase in complexity: if we can batch delay evaluation, we can batch time windows, and if we can batch time windows, we can batch entire simulations.
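A minimal sketch of batched delay evaluation, the first step of that progression. It assumes a ring buffer `history` of shape `(H, n, B)` (history length, nodes, batch lanes) with the batch axis last, so each delayed lookup for a given source node and lag returns `B` contiguous values. The function `delayed_input` and this layout are illustrative assumptions, not part of nfjax.

```python
import jax.numpy as jnp


def delayed_input(history, t, delays, weights):
    """Summed delayed coupling input for every target node and batch lane.

    history: (H, n, B) ring buffer of past states, batch axis last
    t:       current step index
    delays:  (n, n) integer lags in steps, delays[i, j] from source j to target i
    weights: (n, n) connection weights
    returns: (n, B)
    """
    H, n, _ = history.shape
    slots = (t - delays) % H          # ring-buffer slot for each (i, j) pair
    src = jnp.arange(n)[None, :]      # source node index, broadcast over targets
    # One irregular (random-access) lookup per (i, j) pair; each fetches
    # B contiguous values, so the cost of the irregular address is shared
    # across all batch lanes instead of being paid per scalar.
    x_delayed = history[slots, src]   # (n, n, B)
    return jnp.einsum("ij,ijb->ib", weights, x_delayed)


# Usage: 4-step history, 2 nodes, 3 batch lanes
history = jnp.arange(24.0).reshape(4, 2, 3)
delays = jnp.array([[0, 1], [2, 3]])
weights = jnp.ones((2, 2))
out = delayed_input(history, 5, delays, weights)
```

Putting the batch axis last is what makes each gathered row contiguous; with the batch axis first, every lookup would touch `B` separate, strided locations.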

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import jax
import jax.numpy as np
%load_ext autoreload
%load_ext memory_profiler
%autoreload 2

import nfjax as nf

** shtns is not available


## Jansen-Rit

Let's take the Jansen-Rit model as an example, with the end goal of fitting EEG-style spectra.

In [2]:
x0 = np.zeros((6, 32))
nf.jr_dfun(x=x0, c=0, p=nf.jr_default_theta).shape

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


(6, 32)

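So `jr_dfun` is already vectorized over the state space: a state of shape `(6, 32)` (six state variables for each of 32 nodes) yields a derivative of the same shape. Let's build a network model simply by composing a coupling function with the Jansen-Rit `dfun`: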

In [13]:
def net_dfun(x, p):
    # global (mean-field) coupling: scaled mean over all nodes of x[1] - x[2]
    c = 1e-2*np.mean(x[1] - x[2])
    return nf.jr_dfun(x, c, p)
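To make the vectorization concrete, here is a self-contained sketch of a Jansen-Rit `dfun` with the same `(x, c, p)` call pattern. It is not nfjax's implementation. The names `jr_dfun_sketch`, `net_dfun_sketch` and `JR_THETA`, the choice to add the coupling `c` to the external input `mu` in the equation driving `x[1]`, and the use of the standard Jansen-Rit parameter values with rates expressed per ms are all assumptions.

```python
import jax.numpy as jnp

# Standard Jansen-Rit parameters, rates expressed per ms (assumption):
# a = 100/s -> 0.1/ms, b = 50/s -> 0.05/ms, e0 = 2.5/s, mu = 220 Hz.
JR_THETA = dict(A=3.25, B=22.0, a=0.1, b=0.05, C=135.0,
                v0=6.0, e0=0.0025, r=0.56, mu=0.22)


def sigm(v, p):
    # sigmoid converting mean membrane potential (mV) to firing rate
    return 2.0 * p["e0"] / (1.0 + jnp.exp(p["r"] * (p["v0"] - v)))


def jr_dfun_sketch(x, c, p):
    # x has shape (6, n): six state variables for each of n nodes
    y0, y1, y2, y3, y4, y5 = x
    A, B, a, b, C = p["A"], p["B"], p["a"], p["b"], p["C"]
    C1, C2, C3, C4 = C, 0.8 * C, 0.25 * C, 0.25 * C
    return jnp.array([
        y3,
        y4,
        y5,
        A * a * sigm(y1 - y2, p) - 2 * a * y3 - a**2 * y0,
        # hypothetical choice: coupling c is added to the external input mu
        A * a * (p["mu"] + C2 * sigm(C1 * y0, p) + c) - 2 * a * y4 - a**2 * y1,
        B * b * C4 * sigm(C3 * y0, p) - 2 * b * y5 - b**2 * y2,
    ])


def net_dfun_sketch(x, p):
    # global coupling, mirroring net_dfun above: mean over nodes of x[1] - x[2]
    c = 1e-2 * jnp.mean(x[1] - x[2])
    return jr_dfun_sketch(x, c, p)


dx = net_dfun_sketch(jnp.zeros((6, 32)), JR_THETA)
print(dx.shape)  # (6, 32)
```

Every operation is elementwise over the node axis, and the coupling reduces to a scalar that broadcasts back across nodes, so the same function works unchanged for any number of nodes.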

Now run a simulation:

In [14]:
step, loop = nf.make_sde(1.0, net_dfun, 1e-1)

key = jax.random.PRNGKey(0)
nt = 1024
nn = 32
x0 = np.zeros((6, nn))
zt = jax.random.normal(key, (nt, ) + x0.shape) * 1e-2

yt = loop(x0, zt, nf.jr_default_theta)
yt.shape

(1024, 6, 32)
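The internals of `nf.make_sde` are not shown here. As an illustration only, this is one way a `(step, loop)` pair could be built: a hypothetical `make_sde_sketch(dt, dfun, sigma)` taking Euler-Maruyama steps, assuming the first argument is the time step and the third a noise scale. `jax.lax.scan` runs the loop, so the loop body is traced and compiled once rather than unrolled over all `nt` steps, and each step's state is stacked into the output, matching the `(nt, 6, nn)` shape above.

```python
import jax
import jax.numpy as jnp


def make_sde_sketch(dt, dfun, sigma):
    """Hypothetical Euler-Maruyama integrator returning (step, loop)."""
    sqrt_dt = dt ** 0.5

    def step(x, z_t, p):
        # drift term plus scaled standard-normal increment z_t
        return x + dt * dfun(x, p) + sigma * sqrt_dt * z_t

    @jax.jit
    def loop(x0, zt, p):
        def body(x, z_t):
            x_next = step(x, z_t, p)
            return x_next, x_next   # (new carry, value recorded per step)
        _, xs = jax.lax.scan(body, x0, zt)
        return xs                   # shape (nt,) + x0.shape

    return step, loop


# Usage with a linear decay dfun and no noise
step, loop = make_sde_sketch(0.1, lambda x, p: -p * x, 0.0)
xs = loop(jnp.ones((6, 32)), jnp.zeros((100, 6, 32)), 1.0)
print(xs.shape)  # (100, 6, 32)
```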

In [None]:
%timeit loop(x0, zt, nf.jr_default_theta).block_until_ready()
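Timing one simulation is the baseline; the third form of batching from the list above runs many simulations at once. Below is a self-contained sketch with `jax.vmap`, using a hypothetical linear `dfun` and an Euler-Maruyama loop written with `jax.lax.scan` as stand-ins for the nfjax model, so the numbers it prints say nothing about Jansen-Rit performance. Because JAX dispatches work asynchronously, the timing calls `block_until_ready()` so that it measures the computation rather than only its dispatch, and the first call is excluded because it includes compilation.

```python
import time

import jax
import jax.numpy as jnp

DT = 0.1  # hypothetical time step


def dfun(x, p):
    # hypothetical stand-in for a model dfun: linear relaxation toward zero
    return -p * x


def loop(x0, zt, p):
    # Euler-Maruyama over the leading (time) axis of zt
    def body(x, z_t):
        x = x + DT * dfun(x, p) + DT ** 0.5 * z_t
        return x, x
    _, xs = jax.lax.scan(body, x0, zt)
    return xs


# Map over the leading axis of zt only: each batch element is an independent
# simulation sharing x0 and p, compiled as one vectorized computation.
batched_loop = jax.jit(jax.vmap(loop, in_axes=(None, 0, None)))

nb, nt, nn = 16, 1024, 32
zt = 1e-2 * jax.random.normal(jax.random.PRNGKey(0), (nb, nt, 6, nn))
x0 = jnp.zeros((6, nn))

ys = batched_loop(x0, zt, 1.0).block_until_ready()   # warm-up, includes compilation
t0 = time.perf_counter()
ys = batched_loop(x0, zt, 1.0).block_until_ready()   # wait for the result
elapsed = time.perf_counter() - t0
print(ys.shape, f"{1e3 * elapsed / nb:.3f} ms per simulation")
```

`vmap` leaves the batch layout to XLA; how well it maps onto SIMD lanes or GPU threads depends on the backend, which is worth measuring with the same `block_until_ready` pattern.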