In [1]:
import jax

In [8]:
import jax.numpy as jnp
import numpy as np
from jax import vmap
from jax.lax import fori_loop, scan

# Different Ways to Parallelize for Loops

## VMAP

In [5]:
def _square(x):
    """Return ``x`` raised to the second power."""
    return x ** 2


# vmap lifts the per-element function so it is mapped over axis 0 of its input.
mapped_sq = vmap(_square)

In [6]:
# Baseline for comparison: square every element with a plain Python list
# comprehension, which visits the array one element at a time.
lst = np.arange(1000)

%timeit [x**2 for x in lst]

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


63.7 ms ± 5.81 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
%timeit mapped_sq(lst)

419 µs ± 26.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


The function signature is `jax.vmap(fun, in_axes=0, out_axes=0, ...)`.
- `in_axes` specifies which input array axes to map (parallelize) over; when given as a tuple, it has one entry per positional argument of `fun` (an integer axis, or `None` to leave that argument unmapped).
- `out_axes` specifies where the mapped axis should appear in the output.

Examples in Linear Algebra:
1. Dot product (VV), `i,i->` in einsum
2. Matrix-vector product (MV) `ij,j->i` in einsum. In this case, vmap(VV) over the first axis of the matrix (the second axis is summed).
3. Matrix-vector product (MV) `ij,i->j` in einsum. In this case, vmap(VV) over the second axis of the matrix (the first axis is summed).
4. Matrix-matrix product (MM) `ij,jk->ik` in einsum. In this case, vmap(MV) over the second axis of the second matrix.

The `out_axes`: __think about it as the dim you want to keep!__

In [9]:
import numpy as np

In [10]:
# Algebraic examples adapted from the docs.

def vv(x, y):
    """Vector dot product: ([a], [a]) -> []."""
    return jnp.vdot(x, y)


# Two random length-10 vectors used as inputs in the cells that follow.
x = jnp.array(np.random.randn(10))
y = jnp.array(np.random.randn(10))

In [11]:
vv(x,y) # i,i-> in einsum: both vectors share the single summed index

Array(-1.7140954, dtype=float32)

In [16]:
# Matrix-vector product, "ij,j->i" in einsum: ([b,a], [a]) -> [b].
# in_axes=(0, None): map over axis 0 (rows) of the matrix; the vector is not mapped.
# out_axes=0: the mapped axis becomes axis 0 of the result.
mv = vmap(vv, in_axes=(0, None), out_axes=0)
# Same idea for "ij,i->j": ([a,b], [a]) -> [b], mapping over axis 1 (columns) instead.
mv2 = vmap(vv, in_axes=(1, None), out_axes=0)

In [17]:
# Random test matrices: X1 has shape (10, 5), X2 has shape (5, 10).
X1 = jnp.array(np.random.randn(10,5))
X2 = jnp.array(np.random.randn(5,10))

In [18]:
# X2 = (5,10)
# x = (10,)
# output = (5,) : ij,j->i (each of the 5 rows of X2 is dotted with x)
mv(X2,x)

Array([-0.70639527,  1.512434  , -0.78491145,  0.9769071 ,  2.0266433 ],      dtype=float32)

In [20]:
# X1 = (10,5)
# x = (10,)
# output = (5,): ij,i->j (each of the 5 columns of X1 is dotted with x)
mv2(X1, x)

Array([ 0.57710534, -4.5725822 , -0.7548177 , -3.4529448 , -0.93890244],      dtype=float32)

In [39]:
# mv:  ij,j->i
# mm1: ij,jk->ik
# in_axes=(None, 1): the first matrix is passed whole to mv, while the second
# matrix is mapped over its axis 1 (one column, indexed by k, per call).
# out_axes=1: k becomes axis 1 of the output; axis 0 (i) comes from mv.
mm1 = vmap(mv, in_axes=(None, 1), out_axes=1)

In [40]:
# X1 = (10, 5)
# X3 = (5, 8) -- its leading dim matches X1's trailing dim, as ij,jk->ik requires
X3 = jnp.array(np.random.randn(5,8))

In [41]:
# (10, 5), (5, 8) -> (10, 8)
mm1(X1,X3).shape

(10, 8)

In [46]:
# X1 = (10,5)
# X4 = (8,10)
# Keep axis 1 of X1 (j) and axis 0 of X4 (k); the shared length-10 axis (i) is summed.
# mm2: ij,ki->jk -- k is placed on axis 1 of the output: (10,5), (8,10) -> (5,8)
mm2 = vmap(mv2, in_axes=(None, 0), out_axes=1)
# mm3: ij,ki->kj -- k is placed on axis 0 of the output: (10,5), (8,10) -> (8,5)
mm3 = vmap(mv2, in_axes=(None, 0), out_axes=0)

In [47]:
# X4 = (8, 10); with X1 = (10, 5), mm2 puts the mapped axis last -> (5, 8)
X4 = jnp.array(np.random.randn(8,10))
mm2(X1,X4).shape

(5, 8)

In [48]:
# Same inputs as mm2, but the mapped axis is placed first -> (8, 5)
mm3(X1,X4).shape

(8, 5)

## lax.scan

This function is useful for loops that update a state (a "carry") at every step. Its semantics are roughly those of the following Python code:

```py
def scan(f, init, xs, length=None):
    """Pure-Python reference for the semantics of ``lax.scan``.

    Args:
        f: step function called as ``f(carry, x)``; returns ``(new_carry, y)``.
        init: initial carry value.
        xs: sequence to iterate over, or ``None`` to run ``length`` steps
            with ``x`` set to ``None``.
        length: number of steps; only needed when ``xs`` is ``None``.

    Returns:
        ``(final_carry, stacked_ys)`` where ``stacked_ys`` is ``np.stack``
        applied to the per-step outputs.
    """
    if xs is None:
        # No inputs to consume: feed ``None`` to ``f`` on every step.
        xs = [None] * length
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)
```

In [56]:
# Input array for the cumulative-sum timing comparisons in this section.
a = np.arange(500)

In [50]:
# NumPy baseline for the cumulative sum.
%timeit np.cumsum(a)

2.62 µs ± 163 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [51]:
from jax import lax

def cumsum(res, el):
    """Scan step: add ``el`` to the running total ``res``.

    Returns the new total twice -- once as the carry and once as the
    per-step output.
    """
    total = res + el
    return total, total
result_init = 0

%timeit lax.scan(cumsum, result_init, a)

143 µs ± 2.31 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [53]:
from jax import jit

In [59]:
@jit
def scanned_cumsum(a):
    """Jit-compiled cumulative sum of ``a`` built on ``lax.scan``.

    Returns the ``(final_carry, per_step_outputs)`` tuple produced by
    ``lax.scan``: the total sum and the array of running sums.
    """
    def step(running, value):
        running = running + value
        return running, running

    return lax.scan(step, 0, a)

In [60]:
%timeit scanned_cumsum(a)

6.84 µs ± 275 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


## lax.fori_loop

This is another structured replacement for a Python for loop: it takes a lower and an upper bound, a body function, and an initial value, and returns the final updated value.

```py
def fori_loop(lower, upper, body_fun, init_val):
    """Pure-Python reference for ``lax.fori_loop``.

    Folds ``body_fun(i, state)`` over ``i`` in ``range(lower, upper)``,
    starting from ``init_val``, and returns the final state.
    """
    state = init_val
    for step in range(lower, upper):
        state = body_fun(step, state)
    return state
```

In [61]:
@jit
def fori_loop_cumsum(a):
    """Jit-compiled sum of the elements of ``a`` using ``lax.fori_loop``.

    Only the final accumulated value is returned, not the running sums.
    """
    def add_element(i, running):
        return running + a[i]

    return lax.fori_loop(0, len(a), add_element, 0)

In [62]:
%timeit fori_loop_cumsum(a)

5.9 µs ± 110 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
