In [1]:
%pylab inline
import jax
import jax.numpy as np
%load_ext autoreload
%autoreload 2

import vbjax

%pylab is deprecated, use %matplotlib inline and import the required libraries.
Populating the interactive namespace from numpy and matplotlib


## Montbrio-Pazo-Roxin

Let's take the MPR model as an example, with the end goal of fitting EEG-style spectra.

In [2]:
x0 = np.zeros((2, 32))
vbjax.mpr_dfun(x0, c=[1., 0.], p=vbjax.mpr_default_theta).shape

(2, 32)
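A (2, 32) state comes back as (2, 32) because the MPR equations act elementwise on each node. For reference, the mean-field equations (Montbrio, Pazo & Roxin 2015) are τṙ = Δ/(πτ) + 2rV and τV̇ = V² + η + Jτr + I − (πτr)². Below is a minimal sketch of them in `jax.numpy`. It is not vbjax's implementation: the name `mpr_sketch`, the `Theta` container and its parameter values are illustrative assumptions, and the input current `c` is a single channel here.

```python
import collections
import jax.numpy as jnp

# Hypothetical parameter set: names and values are illustrative and not
# necessarily those of vbjax.mpr_default_theta.
Theta = collections.namedtuple("Theta", "tau Delta J eta I")
theta = Theta(tau=1.0, Delta=1.0, J=15.0, eta=-5.0, I=0.0)


def mpr_sketch(x, c, p):
    """MPR mean-field equations; x stacks (r, V) on axis 0, c is an input current."""
    r, V = x  # splits axis 0; any trailing node/batch axes are carried along
    dr = (p.Delta / (jnp.pi * p.tau) + 2.0 * r * V) / p.tau
    dV = (V**2 + p.eta + p.J * p.tau * r + p.I + c - (jnp.pi * p.tau * r) ** 2) / p.tau
    return jnp.stack([dr, dV])


print(mpr_sketch(jnp.zeros((2, 32)), 0.0, theta).shape)     # (2, 32): 32 nodes
print(mpr_sketch(jnp.zeros((2, 32, 8)), 0.0, theta).shape)  # (2, 32, 8): nodes x batch
```

Nothing in the function refers to the number of nodes, so any trailing axes ride along through broadcasting. The notebook exploits this later for batching.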

So our dfun automatically handles vector-valued state: here 2 state variables (r, V) for each of 32 nodes. Let's build a network model just by composing a coupling function with the MPR dfun.

In [3]:
nn = 32
ns = np.tile(np.arange(nn),(nn,1))
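key = jax.random.PRNGKey(42)
# split the key so weights, lengths and buffer are drawn independently
k_w, k_l, k_b = jax.random.split(key, 3)
weights = jax.random.normal(k_w, (nn, nn))
lengths = jax.random.randint(k_l, (nn, nn), 0, 255)
buffer = jax.random.normal(k_b, (nn, lengths.max()+1))

def net_dfun(x, p):
    w, l, buf, mass_θ = p
    # the MPR state is (r, V), so there is no x[2]; couple through the firing rate r
    r = x[0]
    # delayed variant (not used here): c = 1e-2 * np.sum(w*buf[ns,l], axis=1)
    c = 1e-2 * np.dot(w, r)
    # pass rate and voltage coupling channels, as in the c=[1., 0.] call above
    return vbjax.mpr_dfun(x, np.array([c, np.zeros_like(c)]), mass_θ)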

x0 = np.zeros((2, nn))
net_θ = weights, lengths, buffer, vbjax.mpr_default_theta
net_dfun(x0, net_θ).shape

(2, 32)
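The commented-out delayed coupling in `net_dfun` indexes the history buffer as `[ns, l]`. This works because the two integer index arrays are paired elementwise. Here is a small sketch of that gather, assuming `buf[j, d]` holds node `j`'s activity `d` steps ago (in the notebook the buffer is just random placeholder data) and that `w[i, j]` is the weight from node `j` to node `i`, matching `np.dot(w, r)`. The sizes `nn` and `max_lag` are illustrative.

```python
import jax
import jax.numpy as jnp

nn, max_lag = 4, 10
k_w, k_l, k_b = jax.random.split(jax.random.PRNGKey(0), 3)
w = jax.random.normal(k_w, (nn, nn))                     # w[i, j]: weight from node j to node i
lengths = jax.random.randint(k_l, (nn, nn), 0, max_lag)  # lengths[i, j]: lag (in steps) from j to i
buf = jax.random.normal(k_b, (nn, max_lag))              # buf[j, d]: node j's activity d steps ago (assumed)
ns = jnp.tile(jnp.arange(nn), (nn, 1))                   # ns[i, j] == j

# Integer-array indexing pairs ns and lengths elementwise:
# delayed[i, j] = buf[j, lengths[i, j]]
delayed = buf[ns, lengths]

# Same thing as an explicit double loop, for comparison
expected = jnp.array([[buf[j, lengths[i, j]] for j in range(nn)] for i in range(nn)])
print(jnp.allclose(delayed, expected))  # True

c = jnp.sum(w * delayed, axis=1)  # (nn,): delayed, weighted input to each target node
```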

Now run a stochastic simulation with `vbjax.make_sde`:

In [4]:
step, loop = vbjax.make_sde(1.0, net_dfun, 1e-1)
key = jax.random.PRNGKey(0)
nt = 1024
x0 = np.zeros((2, nn))
zt = jax.random.normal(key, (nt, ) + x0.shape) * 1e-2
yt = loop(x0, zt, net_θ)
yt.shape

(1024, 2, 32)
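The notebook doesn't show how `make_sde` integrates. As a rough mental model of the `step, loop` pair it returns, here is a minimal sketch that assumes an Euler-Maruyama scheme driven by `jax.lax.scan`. The scheme, the name `make_em_loop` and the scalar `sigma` noise scale are assumptions, and vbjax may well use a different method (e.g. Heun). It does reproduce the output layout above: one state slice per noise sample, so `(nt,) + x0.shape`.

```python
import jax
import jax.numpy as jnp


def make_em_loop(dt, dfun, sigma):
    """Hypothetical Euler-Maruyama integrator (not vbjax's make_sde)."""
    sqrt_dt = jnp.sqrt(dt)

    def step(x, z_t, p):
        # deterministic drift plus additive noise scaled by sigma * sqrt(dt)
        return x + dt * dfun(x, p) + sqrt_dt * sigma * z_t

    @jax.jit
    def loop(x0, zt, p):
        def body(x, z_t):
            x_next = step(x, z_t, p)
            return x_next, x_next  # carry the state and also record it
        _, traj = jax.lax.scan(body, x0, zt)
        return traj  # shape (nt,) + x0.shape, one slice per noise sample

    return step, loop


# Usage with a toy linear dfun standing in for net_dfun: dx/dt = -p * x
decay = lambda x, p: -p * x
step, loop = make_em_loop(0.1, decay, 0.0)  # sigma = 0 makes the result deterministic
zt = jax.random.normal(jax.random.PRNGKey(0), (100, 2, 4))
traj = loop(jnp.ones((2, 4)), zt, 1.0)
print(traj.shape)  # (100, 2, 4)
```

With `sigma = 0` each step multiplies the state by `1 - 0.1 = 0.9`, so `traj[k]` equals `0.9 ** (k + 1)`. This makes the sketch easy to check.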

In [5]:
%timeit loop(x0, zt, net_θ)

47.1 µs ± 2.65 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
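JAX dispatches work asynchronously. A call like `loop(x0, zt, net_θ)` can therefore return before the device has finished, so `%timeit` on it may partly measure dispatch rather than compute. A fairer pattern is to warm up once so compilation isn't timed, then call `.block_until_ready()` on every result. Below is a sketch using a hypothetical helper `time_jax`. In IPython the equivalent one-liner is `%timeit loop(x0, zt, net_θ).block_until_ready()`.

```python
import time

import jax
import jax.numpy as jnp


def time_jax(fn, *args, n=100):
    """Mean seconds per call of fn(*args), excluding the first (compiling) call."""
    fn(*args).block_until_ready()  # warm-up: trace + compile, not timed
    t0 = time.perf_counter()
    for _ in range(n):
        fn(*args).block_until_ready()  # wait for the result, not just the dispatch
    return (time.perf_counter() - t0) / n


f = jax.jit(lambda x: jnp.tanh(x) @ x.T)
x = jnp.ones((256, 256))
print(f"{time_jax(f, x) * 1e6:.1f} µs per call")
```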


Now let's batch the simulation over initial conditions with `jax.vmap`:

In [6]:
nb = 8
x0b = np.zeros((nb, 2, nn))
jax.vmap(lambda x0: loop(x0, zt, net_θ))(x0b).shape

(8, 1024, 2, 32)

but that's essentially an outer loop: with its default axes, vmap puts the batch dimension first, and JAX won't move it inward for us. Can we batch an inner dimension instead?

In [7]:
x0b = np.zeros((2, nn, nb))

jax.vmap(lambda x0: loop(x0, zt, net_θ), 2, 3)(x0b).shape

(1024, 2, 32, 8)
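The positional arguments `2, 3` above are `vmap`'s `in_axes` and `out_axes`. They take the batch from axis 2 of the input and place it at axis 3 of the output. The sketch below uses a toy scan loop (an assumption standing in for `loop`) to show that both layouts hold the same numbers and differ only in where the batch axis sits.

```python
import jax
import jax.numpy as jnp


def loop(x0, zt):
    """Toy scan loop standing in for the SDE loop: returns (nt,) + x0.shape."""
    def body(x, z):
        x = 0.9 * x + z
        return x, x
    return jax.lax.scan(body, x0, zt)[1]


nt, nn, nb = 16, 32, 8
zt = jax.random.normal(jax.random.PRNGKey(0), (nt, 2, nn))
x0 = jax.random.normal(jax.random.PRNGKey(1), (nb, 2, nn))  # batch-leading layout

outer = jax.vmap(lambda x: loop(x, zt))(x0)                # (nb, nt, 2, nn)
x0_inner = jnp.moveaxis(x0, 0, -1)                         # (2, nn, nb)
inner = jax.vmap(lambda x: loop(x, zt), in_axes=2, out_axes=3)(x0_inner)  # (nt, 2, nn, nb)

print(outer.shape, inner.shape)  # (8, 16, 2, 32) (16, 2, 32, 8)
# same values, only the position of the batch axis differs
print(jnp.allclose(jnp.moveaxis(outer, 0, -1), inner))  # True
```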

That's the layout we'd want on a GPU, at least: the batch is the last axis, so neighboring batch elements are contiguous in memory (stride 1). How does performance compare?
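In the notebook `np` is `jax.numpy`. Plain NumPy (imported here as `onp`) exposes byte strides, which makes the layout difference concrete for the two float32 shapes used above. JAX arrays don't expose strides, and XLA is free to choose its own device layout, so treat this only as an illustration of the logical row-major layout.

```python
import numpy as onp  # plain NumPy, to inspect memory strides

x_outer = onp.zeros((8, 2, 32), dtype=onp.float32)  # (nb, 2, nn)
x_inner = onp.zeros((2, 32, 8), dtype=onp.float32)  # (2, nn, nb)

print(x_outer.strides)  # (256, 128, 4): batch neighbors are 256 bytes apart
print(x_inner.strides)  # (1024, 32, 4): batch neighbors are adjacent (4 bytes)
```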

In [8]:
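x0b = np.zeros((nb, 2, nn))                          # batch on the leading axis
lb = jax.vmap(lambda x0: loop(x0, zt, net_θ))
lb(x0b)                                              # warm-up call, keeps one-time tracing/compilation out of the timing
%timeit lb(x0b).block_until_ready()

x0b = np.zeros((2, nn, nb))                          # batch on the trailing axis
lb = jax.vmap(lambda x0: loop(x0, zt, net_θ), 2, 3)
lb(x0b)                                              # warm-up call
%timeit lb(x0b).block_until_ready()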

1.92 ms ± 53.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1.85 ms ± 38.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


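What if we skip `vmap` altogether? The MPR equations are elementwise and the coupling is a matrix product over the node axis, so a trailing batch axis on the state (`x0b` is still `(2, nn, nb)` here) should simply broadcast through: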

In [9]:
%timeit loop(x0b, zt.reshape(zt.shape + (1,)), net_θ).shape

698 µs ± 8.96 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


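Broadcasting ftw: 698 µs vs. roughly 1.9 ms for either vmap layout. That's the batching speed-up we're looking for, and it's single core. Two caveats. First, `zt` only gets a trailing axis of length 1, so all `nb` batch members share one noise realization. Second, this cell times `loop(...).shape` rather than `.block_until_ready()`, and with JAX's asynchronous dispatch it may not capture the full computation. Re-timing with `.block_until_ready()` would make the comparison with the vmap timings like-for-like.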
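To check that the broadcasting shortcut computes the same thing as the inner-axis vmap, here is a self-contained sketch with a toy coupled network. The relaxation dynamics, the `1e-2` coupling scale and the forward-Euler update are assumptions standing in for `net_dfun` and the vbjax loop. The noise gets a trailing length-1 axis exactly as in the cell above.

```python
import jax
import jax.numpy as jnp


def loop(x0, zt, w):
    """Toy coupled network loop: x has shape (2, nn) or (2, nn, nb)."""
    def body(x, z):
        c = 1e-2 * jnp.dot(w, x[0])  # (nn, nn) @ (nn,) or (nn, nn) @ (nn, nb)
        x = x + 0.1 * (-x + c) + z   # c broadcasts against x's leading axis
        return x, x
    return jax.lax.scan(body, x0, zt)[1]


nt, nn, nb = 64, 32, 8
kw, kz, kx = jax.random.split(jax.random.PRNGKey(0), 3)
w = jax.random.normal(kw, (nn, nn))
zt = 1e-2 * jax.random.normal(kz, (nt, 2, nn))
x0b = jax.random.normal(kx, (2, nn, nb))  # batch on the last axis

via_vmap = jax.vmap(lambda x: loop(x, zt, w), 2, 3)(x0b)  # (nt, 2, nn, nb)
via_bcast = loop(x0b, zt[..., None], w)                   # noise (nt, 2, nn, 1) broadcasts over nb
print(jnp.allclose(via_vmap, via_bcast, atol=1e-5))       # True
```

For independent noise per batch member, draw `zt` with shape `(nt, 2, nn, nb)` instead. The carry shape then stays the same, and only the noise array grows by a factor of `nb`.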