In [2]:
import jax.numpy as jnp
import jax.scipy as jsp

import numpy as np
import scipy as sp

from numpy import vectorize
from numpy.testing import assert_array_equal
from jax import vmap
from jax import jit
from jax import config

# Switch JAX to 64-bit floats so its results are float64, like the NumPy/SciPy
# arrays they are compared against in this notebook.
config.update(
    "jax_enable_x64",
    True,
)

In [6]:
# Number of matrices in the batch, and the (rows, cols) shape of each matrix.
batch_size, dim = 10_000, (5, 10)

In [7]:
# Seeded generator so the random batch, and every check that uses it, is
# reproducible under Restart & Run All.
SEED = 0
rng = np.random.default_rng(seed=SEED)

# A stack of `batch_size` matrices, overall shape (batch_size, *dim).
a = rng.uniform(size=(batch_size, *dim))

a.shape

(10000, 5, 10)

In [8]:
# JAX-array copy of the NumPy batch, so it can be fed to JAX functions directly.
a_jax = jnp.asarray(a)

## Test equivalence of SciPy and JAX LU on a single matrix

In [9]:
# Factor one matrix with both libraries and check that the (P, L, U) factors agree.
# Distinct names avoid confusion with the batched `_lu` / `_lu_jax` defined later.
lu_scipy_single = sp.linalg.lu(a[0])
lu_jax_single = jsp.linalg.lu(a[0])

# zip() would silently stop early if one side returned fewer factors.
assert len(lu_scipy_single) == len(lu_jax_single) == 3
for factor_scipy, factor_jax in zip(lu_scipy_single, lu_jax_single):
    # Two floating-point backends need not be bit-identical: compare with a tolerance.
    np.testing.assert_allclose(factor_scipy, factor_jax)

## Without `vmap`: passing the whole batch directly

In [10]:
# Demonstration: the SciPy version this notebook was run with rejects a stacked
# (3-D) input. Catch and show the error so Restart & Run All does not stop here.
try:
    sp.linalg.lu(a)
except ValueError as err:
    print(f"{type(err).__name__}: {err}")

ValueError: expected matrix

In [11]:
# Demonstration: calling jax.scipy.linalg.lu on the stacked (3-D) input raised a
# ValueError in the recorded run. Catch and show it so Restart & Run All continues.
try:
    jsp.linalg.lu(a)
except ValueError as err:
    print(f"{type(err).__name__}: {err}")

ValueError: too many values to unpack (expected 2)

## With `vmap`: mapping the single-matrix LU over the batch axis

In [12]:
# SciPy baseline: factor each matrix of the batch one at a time in Python.
_lu = [sp.linalg.lu(matrix) for matrix in a]

In [13]:
# vmap lifts the single-matrix LU to operate over the leading (batch) axis;
# jit then compiles the batched function.
batched_lu = vmap(jsp.linalg.lu)
jax_lu = jit(batched_lu)

# The first call also triggers compilation, warming up later timings.
_lu_jax = jax_lu(a)

In [14]:
# Compare every matrix in the batch, not just the first. zip(*_lu) regroups the
# per-matrix SciPy results by factor (all P's, all L's, all U's). np.stack turns
# each group into one array that lines up with the batched JAX output.
assert len(_lu_jax) == 3
for factor_scipy, factor_jax in zip(zip(*_lu), _lu_jax):
    # Tolerance-based: two floating-point backends need not be bit-identical.
    np.testing.assert_allclose(np.stack(factor_scipy), factor_jax)

In [15]:
%timeit list(map(sp.linalg.lu, list(a)))

336 ms ± 15.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [16]:
%timeit jax_lu(a)

13 ms ± 451 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Presentation: a minimal batched example

In [7]:
# A batch containing a single 2x2 matrix of ones.
a = jnp.ones(shape=(1, 2, 2))

In [8]:
# Demonstration: the unbatched jax.scipy.linalg.lu raised a ValueError on this
# 3-D input in the recorded run. Catch and show it so Restart & Run All continues.
try:
    jsp.linalg.lu(a)
except ValueError as err:
    print(f"{type(err).__name__}: {err}")

ValueError: too many values to unpack (expected 2)

In [9]:
# Map the single-matrix LU over the leading (batch) axis, then compile it.
batched_lu = vmap(jsp.linalg.lu)
jax_lu = jit(batched_lu)

In [10]:
# Unpack the batched (P, L, U) factors; the tuple on the last line is displayed.
p_demo, l_demo, u_demo = jax_lu(a)
p_demo, l_demo, u_demo

(DeviceArray([[[1., 0.],
               [0., 1.]]], dtype=float64),
 DeviceArray([[[1., 0.],
               [1., 1.]]], dtype=float64),
 DeviceArray([[[1., 1.],
               [0., 0.]]], dtype=float64))