# Parallelism

In [22]:
import numpy as np

import jax
import jax.numpy as jnp

from jax import jit, random, pmap, vmap, lax

### Numpy baseline

In [2]:
x = np.ones((8, 4096, 4096))
np.matmul(x, x).shape

(8, 4096, 4096)

In [3]:
%timeit -n 5 -r 5 np.matmul(x, x)

1.78 s ± 35.9 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


### Jax acceleration

In [4]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [5]:
x = jnp.ones((8, 4096, 4096))
jnp.matmul(x, x).shape

(8, 4096, 4096)

In [6]:
%timeit -n 5 -r 5 jnp.matmul(x, x).block_until_ready()

20.8 ms ± 102 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)


### PMAP

In [7]:

xd = pmap(lambda x: jnp.ones((4096, 4096)))(np.arange(8))
print(xd.shape)
print(type(xd))

(8, 4096, 4096)
<class 'jaxlib.xla_extension.pmap_lib.ShardedDeviceArray'>


In [10]:
yd = pmap(jnp.matmul)(xd, xd)
print(yd.shape)
print(type(yd))

(8, 4096, 4096)
<class 'jaxlib.xla_extension.pmap_lib.ShardedDeviceArray'>


In [11]:
yd[0]

DeviceArray([[4096., 4096., 4096., ..., 4096., 4096., 4096.],
             [4096., 4096., 4096., ..., 4096., 4096., 4096.],
             [4096., 4096., 4096., ..., 4096., 4096., 4096.],
             ...,
             [4096., 4096., 4096., ..., 4096., 4096., 4096.],
             [4096., 4096., 4096., ..., 4096., 4096., 4096.],
             [4096., 4096., 4096., ..., 4096., 4096., 4096.]],            dtype=float32)

In [20]:
%timeit -r 5 -n 5 pmap(jnp.matmul)(xd, xd).block_until_ready()

4.45 ms ± 357 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)


### Collective operations 

In [38]:
def normalize(x):
    return x / lax.psum(x, axis_name='p')

y_n = pmap(normalize, axis_name='p')(jnp.arange(8))
print(y_n.shape)
print(y_n)
print(y_n.sum())

(8,)
[0.         0.03571429 0.07142857 0.10714287 0.14285715 0.17857143
 0.21428573 0.25      ]
1.0


In [41]:
x = np.ones(8)
x

array([1., 1., 1., 1., 1., 1., 1., 1.])

In [42]:
pmap(lambda x: jax.lax.psum(x, axis_name='i'), axis_name='i')(x)

ShardedDeviceArray([8., 8., 8., 8., 8., 8., 8., 8.], dtype=float32)

### PJIT

In [39]:
from jax.experimental.pjit import pjit, PartitionSpec as P
from jax.experimental.maps import Mesh
from jax.experimental.mesh_utils import create_device_mesh