# Batching model evaluations

VEP-style field models have high spatial resolution and are usually deterministic. At the other end of the spectrum are stochastic models with high temporal resolution, such as the Montbrió-Pazó-Roxin or Jansen-Rit models.

## Why batching?

Batching helps performance in several ways:

- delays make memory access effectively random; reading a contiguous batch of values at each delayed lookup, instead of a single scalar, amortizes the cost of that random access and can reach 80%+ of memory bandwidth
- parallel-in-time evaluation for centered models evaluates several time windows as one parallel batch
- parallel simulations run entire independent simulations across CPU SIMD lanes or GPU threads

If we can batch time windows, we can batch entire simulations, and if we can batch delay evaluation, we can batch time windows. The three applications are therefore a progressive increase in complexity, from whole simulations up to delay evaluation.
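To make the delay point concrete, here is a minimal sketch of batched delayed coupling, assuming the history is kept in a ring buffer of shape `(n_nodes, horizon, n_batch)` with the batch on the last, unit-stride axis. The names `delayed_coupling`, `history`, `lags` and the ring-buffer layout are hypothetical illustrations, not nfjax's API. Each delayed lookup pulls `n_batch` adjacent values, so the random-access cost is paid once per batch rather than once per scalar.

```python
import jax
import jax.numpy as jnp

def delayed_coupling(weights, lags, history, t):
    """c[i, b] = sum_j weights[i, j] * x_j(t - lags[i, j]) for every batch member b.

    weights: (n_nodes, n_nodes) connection strengths
    lags:    (n_nodes, n_nodes) integer delays in steps, each < horizon
    history: (n_nodes, horizon, n_batch) ring buffer of past states,
             batch on the last (unit-stride) axis (assumed layout)
    t:       current step; x_j(s) is assumed stored at history[j, s % horizon]
    """
    n_nodes, horizon, _ = history.shape
    src = jnp.arange(n_nodes)[None, :]      # source node j for each (i, j) pair
    slot = (t - lags) % horizon             # ring-buffer slot holding x_j(t - lag_ij)
    delayed = history[src, slot]            # (n_nodes, n_nodes, n_batch): one contiguous run per lookup
    return jnp.einsum("ij,ijb->ib", weights, delayed)
```

With `n_batch = 1` this reduces to the scalar gather that the commented-out coupling line in `net_dfun` performs.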

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import jax
import jax.numpy as np  # note: `np` is jax.numpy throughout this notebook
%load_ext autoreload
%load_ext memory_profiler
%autoreload 2

import nfjax as nf


## Jansen-Rit

Let's take the Jansen-Rit model as an example, with the end goal of fitting EEG-style spectra.
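For reference, these are the standard six-variable Jansen-Rit equations (Jansen and Rit, 1995). Whether `nf.jr_dfun` uses exactly this state ordering and parameterization is an assumption, though the signal `x[1] - x[2]` used in the network model matches the usual EEG-like output $y_1 - y_2$:

$$
\begin{aligned}
\dot y_0 &= y_3, & \dot y_3 &= A a\, S(y_1 - y_2) - 2 a y_3 - a^2 y_0,\\
\dot y_1 &= y_4, & \dot y_4 &= A a \left[p + C_2\, S(C_1 y_0)\right] - 2 a y_4 - a^2 y_1,\\
\dot y_2 &= y_5, & \dot y_5 &= B b\, C_4\, S(C_3 y_0) - 2 b y_5 - b^2 y_2,
\end{aligned}
\qquad
S(v) = \frac{2 e_0}{1 + e^{r (v_0 - v)}}
$$

Here $A, B$ are the excitatory and inhibitory synaptic gains, $a, b$ the inverse synaptic time constants, $C_1, \dots, C_4$ the intra-column connectivity constants and $p$ the external input; in network versions the coupling $c$ is commonly added to $p$.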

In [3]:
x0 = np.zeros((6, 32))
nf.jr_dfun(x0, c=0, p=nf.jr_default_theta).shape

(6, 32)

So `jr_dfun` already works on vectors: the state has 6 rows, one per Jansen-Rit state variable, and one column per node. Let's build a network model just by composing a coupling function with the Jansen-Rit dfun:

In [4]:
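nn = 32
ns = np.tile(np.arange(nn), (nn, 1))  # ns[i, j] = j, source-node index of each connection
# split the key so weights, lengths and buffer are independent draws
key_w, key_l, key_b = jax.random.split(jax.random.PRNGKey(42), 3)
weights = jax.random.normal(key_w, (nn, nn))
lengths = jax.random.randint(key_l, (nn, nn), 0, 255)       # delays, in time steps
buffer = jax.random.normal(key_b, (nn, lengths.max() + 1))  # per-node delay history

def net_dfun(x, p):
    w, l, buf, mass_θ = p
    lfp = x[1] - x[2]  # Jansen-Rit EEG-like output
    # delayed coupling (not used yet): c = 1e-2 * np.sum(w * buf[ns, l], axis=1)
    c = 1e-2 * np.dot(w, lfp)  # instantaneous linear coupling
    return nf.jr_dfun(x, c, mass_θ)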

x0 = np.zeros((6, nn))
net_θ = weights, lengths, buffer, nf.jr_default_theta
net_dfun(x0, net_θ).shape

(6, 32)

Now run a simulation:

In [5]:
step, loop = nf.make_sde(1.0, net_dfun, 1e-1)
key = jax.random.PRNGKey(0)
nt = 1024
x0 = np.zeros((6, nn))
zt = jax.random.normal(key, (nt, ) + x0.shape) * 1e-2
yt = loop(x0, zt, net_θ)
yt.shape

(1024, 6, 32)
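`nf.make_sde` is nfjax's own and its internals aren't shown here. For readers without nfjax, a hypothetical stand-in with the same call shape `loop(x0, zt, p)` could integrate with Euler-Maruyama under `jax.lax.scan`; the scheme, the meaning of the `dt` and `sigma` arguments, and the name `make_em_loop` are all assumptions.

```python
import jax
import jax.numpy as jnp

def make_em_loop(dt, dfun, sigma):
    """Hypothetical Euler-Maruyama integrator (assumed semantics, not nfjax's make_sde)."""
    def step(x, z, p):
        # One step: deterministic drift plus pre-drawn noise scaled by sigma * sqrt(dt).
        return x + dt * dfun(x, p) + jnp.sqrt(dt) * sigma * z

    @jax.jit
    def loop(x0, zt, p):
        # Scan over the leading (time) axis of zt, stacking every state
        # into a trajectory of shape (nt,) + x0.shape.
        def body(x, z):
            x_next = step(x, z, p)
            return x_next, x_next
        _, xs = jax.lax.scan(body, x0, zt)
        return xs

    return step, loop
```

With this convention the trajectory has shape `(nt,) + x0.shape`, matching the `(1024, 6, 32)` above, and any extra trailing axes on `x0` carry through as long as `dfun` broadcasts over them.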

In [6]:
%timeit loop(x0, zt, net_θ)

690 µs ± 1.32 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
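JAX dispatches work asynchronously, so a timing loop that doesn't wait for the result can end up measuring mostly dispatch; the batched timings below call `.block_until_ready()` for this reason. A small hypothetical helper (`time_jax` is not part of nfjax or JAX) that warms up once to exclude compilation and then blocks on every call:

```python
import time
import jax
import jax.numpy as jnp

def time_jax(fn, *args, repeats=20):
    """Median wall-clock seconds per call of fn(*args), waiting for device results."""
    jax.block_until_ready(fn(*args))          # warm-up call, triggers any JIT compilation
    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        jax.block_until_ready(fn(*args))      # wait for the async computation to finish
        times.append(time.perf_counter() - t0)
    times.sort()
    return times[len(times) // 2]
```

For example, `time_jax(loop, x0, zt, net_θ)` would time the single simulation above with blocking.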


Now let's batch the loop over initial conditions with `jax.vmap`:

In [7]:
nb = 8
x0b = np.zeros((nb, 6, nn))
jax.vmap(lambda x0: loop(x0, zt, net_θ))(x0b).shape

(8, 1024, 6, 32)

But that puts the batch on the leading axis, so it acts like an outer loop over whole simulations. Can we batch along an inner axis instead?

In [8]:

x0b = np.zeros((6, nn, nb))

jax.vmap(lambda x0: loop(x0, zt, net_θ), 2, 3)(x0b).shape

(1024, 6, 32, 8)
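In `jax.vmap`, `in_axes` picks which input axis is mapped over and `out_axes` where the mapped axis lands in the output, so the two vmapped cells compute the same values and differ only in where the batch axis sits. A toy check, with a hypothetical `f` standing in for `loop`:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Stand-in for a simulation: (6, n) state -> (4, 6, n) "trajectory".
    return jnp.stack([x * k for k in range(4)])

xb = jnp.arange(6 * 5 * 3.0).reshape(6, 5, 3)       # batch of 3 on the last axis
inner = jax.vmap(f, in_axes=2, out_axes=3)(xb)      # (4, 6, 5, 3): batch innermost
outer = jax.vmap(f)(jnp.moveaxis(xb, 2, 0))         # (3, 4, 6, 5): batch outermost
```

The values agree up to a transpose; what changes is memory layout, and with XLA's default row-major layout the last axis is the unit-stride one.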

That's the layout we'd want on a GPU, at least: stride-1 batching, with the batch index as the innermost (contiguous) axis. How does it perform?

In [9]:
x0b = np.zeros((nb, 6, nn))
lb = jax.vmap(lambda x0: loop(x0, zt, net_θ))
lb(x0b)
%timeit lb(x0b).block_until_ready()

x0b = np.zeros((6, nn, nb))
lb = jax.vmap(lambda x0: loop(x0, zt, net_θ), 2, 3)
lb(x0b)
%timeit lb(x0b).block_until_ready()

4.68 ms ± 43.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.72 ms ± 14.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


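Both layouts take about 4.7 ms for 8 simulations, only modestly better than 8 × 690 µs ≈ 5.5 ms for eight single runs. What if we skip vmap entirely? `np.dot(w, lfp)` contracts only the node axis and `jr_dfun` already handles trailing dimensions, so the `(6, nn, nb)` state can go straight into `loop`, with a trailing singleton axis on the noise so that it broadcasts over the batch: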

In [10]:
%timeit loop(x0b,zt.reshape(zt.shape+ (1,)),net_θ).shape

3.81 ms ± 16.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


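Broadcasting ftw: about 3.8 ms for the same 8 simulations, roughly 20% faster than either vmap layout and about 1.45× the throughput of eight separate runs (≈ 5.5 ms), all on a single core. This cell times `.shape`, which JAX can report before the computation finishes, so the figure may be somewhat optimistic compared with the `.block_until_ready()` timings.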
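One caveat: in the broadcast call, and in the vmapped cells that close over the same `zt`, every batch member receives the same noise realization and, here, the same zero initial condition, so all members produce identical trajectories. That's fine for timing, but independent realizations (e.g. for averaging spectra) need noise drawn with the batch axis included. A sketch with a hypothetical toy linear network (`toy_dfun` and `simulate` are illustrative, not nfjax functions):

```python
import jax
import jax.numpy as jnp

def toy_dfun(x, w):
    # Toy linear network; works for x of shape (n_nodes,) or (n_nodes, n_batch)
    # because jnp.dot contracts only the node axis.
    return -x + 0.1 * jnp.dot(w, x)

def simulate(x0, zt, w, dt=0.1):
    # Euler-Maruyama with pre-drawn noise zt, scanned over its leading (time) axis.
    def body(x, z):
        x = x + dt * toy_dfun(x, w) + jnp.sqrt(dt) * z
        return x, x
    return jax.lax.scan(body, x0, zt)[1]

nn, nb, nt = 4, 3, 50
kw, k_shared, k_indep = jax.random.split(jax.random.PRNGKey(0), 3)
w = jax.random.normal(kw, (nn, nn))
x0 = jnp.zeros((nn, nb))
z_shared = jax.random.normal(k_shared, (nt, nn, 1)) * 1e-2  # broadcast: same noise for all members
z_indep = jax.random.normal(k_indep, (nt, nn, nb)) * 1e-2   # independent noise per member
shared = simulate(x0, z_shared, w)  # (nt, nn, nb), identical along the batch axis
indep = simulate(x0, z_indep, w)    # (nt, nn, nb), members differ
```

Drawing the full `(nt, nn, nb)` noise costs more memory than the singleton-axis version but keeps the stride-1 batch layout.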