## ⚖️ Choose A or B:

## A: Emulating multi-device system on CPU

Use this section to initialize a set of virtual devices on CPU if you don't have access to a multi-device system.

It can also help you prototype, debug, and test your multi-device code locally before running it on an expensive multi-device system.

Even on Google Colab, it can help you prototype faster, because a CPU runtime restarts faster than an accelerator runtime.

In [1]:
import os
# Must be set before JAX initializes its backends, so run this cell before `import jax`
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

In [2]:
import jax
import jax.numpy as jnp

In [3]:
jax.devices("cpu")

[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]
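Assigning to `XLA_FLAGS` as above replaces any XLA flags already set in the environment. A minimal sketch that appends the flag instead, assuming it runs in a fresh Python process where JAX has not been imported yet:

```python
import os

# Append the flag rather than overwrite XLA_FLAGS, keeping any flags that are already set
FLAG = "--xla_force_host_platform_device_count=8"
existing = os.environ.get("XLA_FLAGS", "")
if FLAG not in existing:
  os.environ["XLA_FLAGS"] = f"{existing} {FLAG}".strip()

import jax  # imported only after XLA_FLAGS is set

cpu_devices = jax.devices("cpu")
print(len(cpu_devices), "CPU devices")
```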

## B: Setting up TPU

!! **First, use the instructions from Appendix C or the example from Chapter 3 to start a Cloud TPU and connect Colab to it.** Then run the code below !!

In [None]:
# install only if JAX is not already installed on the Cloud TPU VM
!pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

In [None]:
# needed later for the SPMD (data-parallel) neural network training example
!pip install tensorflow

In [None]:
!pip install tensorflow_datasets

In [None]:
import jax
print(jax.default_backend())

tpu


In [None]:
import jax
import jax.numpy as jnp

In [None]:
jax.local_devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

## Parallelizing a function

In [None]:
def dot(v1, v2):
  return jnp.vdot(v1, v2)

In [None]:
dot(jnp.array([1., 1., 1.]), jnp.array([1., 2., -1]))

Array(2., dtype=float32)

Generate some **large** random arrays

In [None]:
from jax import random

In [None]:
rng_key = random.PRNGKey(42)

In [None]:
vs = random.normal(rng_key, shape=(20_000_000,3))

In [None]:
v1s = vs[:10_000_000,:]
v2s = vs[10_000_000:,:]

In [None]:
v1s.shape, v2s.shape

((10000000, 3), (10000000, 3))

Make a compiled vectorized function as a baseline:

In [None]:
dot_batched = jax.jit(jax.vmap(dot)).lower(v1s,v2s).compile() # ahead-of-time (AOT) compilation

In [None]:
%timeit x_vmap = dot_batched(v1s, v2s).block_until_ready()

1.91 ms ± 52.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
dot_batched = jax.jit(jax.vmap(dot)) # or just-in-time (JIT) compilation

In [None]:
x_vmap = dot_batched(v1s, v2s)

In [None]:
x_vmap.shape

(10000000,)

In [None]:
dot_parallel = jax.pmap(dot)

In [None]:
x_pmap = dot_parallel(v1s,v2s)

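ValueError: ignored

The call fails because `pmap()` maps over the leading axis of its inputs, and here that axis has 10,000,000 elements while only 8 devices are available. We need to reshape the arrays so that the leading axis matches the number of devices: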

In [None]:
v1s.shape

(10000000, 3)

In [None]:
v1sp = v1s.reshape((8, v1s.shape[0]//8, v1s.shape[1]))
v2sp = v2s.reshape((8, v2s.shape[0]//8, v2s.shape[1]))    # we use integer division here

In [None]:
v1sp.shape

(8, 1250000, 3)

In [None]:
v1s[0,:], v1sp[0,0,:]

(Array([1.456546  , 1.1449531 , 0.02485494], dtype=float32),
 Array([1.456546  , 1.1449531 , 0.02485494], dtype=float32))

In [None]:
v1s[1250000,:], v1sp[1,0,:]

(Array([ 0.50284576,  0.38336843, -0.35499337], dtype=float32),
 Array([ 0.50284576,  0.38336843, -0.35499337], dtype=float32))

In [None]:
x_pmap = dot_parallel(v1sp,v2sp)

In [None]:
x_pmap.shape

(8,)
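Why a single number per device? Each device receives a `(1250000, 3)` slice, and `jnp.vdot()` flattens its inputs before multiplying, so each device returns the sum of all its dot products. A small sketch of the difference on toy arrays:

```python
import jax
import jax.numpy as jnp

def dot(v1, v2):
  return jnp.vdot(v1, v2)

a = jnp.array([[1., 2., 3.],
               [4., 5., 6.]])
b = jnp.ones((2, 3))

# vdot flattens both (2, 3) arrays into length-6 vectors: 1+2+3+4+5+6
whole = dot(a, b)
# vmap applies dot row by row: [1+2+3, 4+5+6]
per_row = jax.vmap(dot)(a, b)
print(whole, per_row)
```

This is why the next cell wraps `dot` in `jax.vmap()` before passing it to `jax.pmap()`.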

In [None]:
dot_parallel = jax.pmap(jax.vmap(dot))

In [None]:
x_pmap = dot_parallel(v1sp,v2sp)

In [None]:
x_pmap.shape

(8, 1250000)

In [None]:
type(x_pmap)

jaxlib.xla_extension.ArrayImpl

In [None]:
x_pmap = x_pmap.reshape((x_pmap.shape[0]*x_pmap.shape[1]))
x_pmap.shape

(10000000,)

In [None]:
jax.numpy.all(x_pmap == x_vmap)

Array(True, dtype=bool)
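The reshape-before, reshape-after pattern above is easy to get wrong when the device count changes. A minimal sketch with two hypothetical helpers, `shard()` and `unshard()` (not part of JAX), that derive the device axis from `jax.local_device_count()` and therefore also run on a single-device machine:

```python
import jax
import jax.numpy as jnp

def shard(x, n_devices):
  """Hypothetical helper: reshape (N, ...) into (n_devices, N // n_devices, ...)."""
  if x.shape[0] % n_devices != 0:
    raise ValueError(f"leading axis {x.shape[0]} is not divisible by {n_devices}")
  return x.reshape((n_devices, x.shape[0] // n_devices) + x.shape[1:])

def unshard(x):
  """Hypothetical helper: merge the device axis and the per-device axis back together."""
  return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])

def dot(v1, v2):
  return jnp.vdot(v1, v2)

n = jax.local_device_count()
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
v1s = jax.random.normal(k1, (n * 1000, 3))
v2s = jax.random.normal(k2, (n * 1000, 3))

x_vmap = jax.vmap(dot)(v1s, v2s)
x_pmap = unshard(jax.pmap(jax.vmap(dot))(shard(v1s, n), shard(v2s, n)))
print(x_pmap.shape, jnp.allclose(x_pmap, x_vmap))
```

Raising an explicit error on a non-divisible leading axis gives a clearer message than a failed `reshape()`.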

In [None]:
%timeit xp = dot_parallel(v1sp,v2sp).block_until_ready()

9.9 ms ± 49.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
x_pmap

Array([ 1.1981742 , -3.0097814 , -2.539115  , ..., -2.7281013 ,
        0.26364914,  1.9583547 ], dtype=float32)

In [None]:
x_vmap

Array([ 1.1981742 , -3.0097814 , -2.539115  , ..., -2.7281013 ,
        0.26364914,  1.9583547 ], dtype=float32)

A very simple example where `vmap()` and `pmap()` are interchangeable, because the mapped axis has exactly 8 elements, one per device:

In [None]:
def dot(v1, v2):
  return jnp.vdot(v1, v2)

In [None]:
vs = random.normal(rng_key, shape=(16,3))
v1s = vs[:8,:]
v2s = vs[8:,:]

In [None]:
jax.vmap(dot)(v1s,v2s)

Array([ 0.51048726, -0.7174605 , -0.20105815, -0.26437205, -1.3696793 ,
        2.744793  ,  1.7936493 , -1.1743435 ], dtype=float32)

In [None]:
jax.pmap(dot)(v1s,v2s)

Array([ 0.51048726, -0.7174605 , -0.20105815, -0.26437205, -1.3696793 ,
        2.744793  ,  1.7936493 , -1.1743435 ], dtype=float32)

In [None]:
dot_v = jax.jit(jax.vmap(dot))
x = dot_v(v1s,v2s)

In [None]:
dot_pjo = jax.jit(jax.pmap(dot))
x = dot_pjo(v1s,v2s)



In [None]:
dot_pji = jax.pmap(jax.jit(dot))
x = dot_pji(v1s,v2s)

In [None]:
dot_p = jax.pmap(dot)
x = dot_p(v1s,v2s)

In [None]:
%timeit dot_v(v1s,v2s).block_until_ready()

122 µs ± 5.53 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
%timeit dot_pjo(v1s,v2s).block_until_ready()

2.08 ms ± 30.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%timeit dot_pji(v1s,v2s).block_until_ready()

1.51 ms ± 43 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
%timeit dot_p(v1s,v2s).block_until_ready()

1.54 ms ± 51 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
jax.make_jaxpr(dot_v)(v1s,v2s)

{ lambda ; a:f32[8,3] b:f32[8,3]. let
    c:f32[8] = pjit[
      jaxpr={ lambda ; d:f32[8,3] e:f32[8,3]. let
          f:f32[8] = dot_general[dimension_numbers=(([1], [1]), ([0], [0]))] d e
        in (f,) }
      name=dot
    ] a b
  in (c,) }

In [None]:
jax.make_jaxpr(dot_pjo)(v1s,v2s)

{ lambda ; a:f32[8,3] b:f32[8,3]. let
    c:f32[8] = pjit[
      jaxpr={ lambda ; d:f32[8,3] e:f32[8,3]. let
          f:f32[8] = xla_pmap[
            axis_name=<axis 0x7f42a91d2b80>
            axis_size=8
            backend=None
            call_jaxpr={ lambda ; g:f32[3] h:f32[3]. let
                i:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] g
                  h
              in (i,) }
            devices=None
            donated_invars=(False, False)
            global_axis_size=8
            in_axes=(0, 0)
            is_explicit_global_axis_size=False
            name=dot
            out_axes=(0,)
          ] d e
        in (f,) }
      nam

In [None]:
jax.make_jaxpr(dot_pji)(v1s,v2s)

{ lambda ; a:f32[8,3] b:f32[8,3]. let
    c:f32[8] = xla_pmap[
      axis_name=<axis 0x7f42a8f09900>
      axis_size=8
      backend=None
      call_jaxpr={ lambda ; d:f32[3] e:f32[3]. let
          f:f32[] = pjit[
            jaxpr={ lambda ; g:f32[3] h:f32[3]. let
                i:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] g
                  h
              in (i,) }
            name=dot
          ] d e
        in (f,) }
      devices=None
      donated_invars=(False, False)
      global_axis_size=8
      in_axes=(0, 0)
      is_explicit_global_axis_size=False
      name=dot
      out_axes=(0,)
    ] a b
  in (c,) }

In [None]:
jax.make_jaxpr(dot_p)(v1s,v2s)

{ lambda ; a:f32[8,3] b:f32[8,3]. let
    c:f32[8] = xla_pmap[
      axis_name=<axis 0x7f42a91d2b80>
      axis_size=8
      backend=None
      call_jaxpr={ lambda ; d:f32[3] e:f32[3]. let
          f:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] d e
        in (f,) }
      devices=None
      donated_invars=(False, False)
      global_axis_size=8
      in_axes=(0, 0)
      is_explicit_global_axis_size=False
      name=dot
      out_axes=(0,)
    ] a b
  in (c,) }

## Controlling pmap() behavior

### Using in_axes parameter

Using small arrays just for demonstration purposes:

In [None]:
vs = random.normal(rng_key, shape=(16,3))
v1s = vs[:8,:]
v2s = vs[8:,:]

In [None]:
def dot(v1, v2):
  return jnp.vdot(v1, v2)

The default behavior, written out explicitly:

In [None]:
dot_pmapped = jax.pmap(dot, in_axes=(0,0))

In [None]:
dot_pmapped(v1s, v2s)

Array([ 0.51048726, -0.7174605 , -0.20105815, -0.26437205, -1.3696793 ,
        2.744793  ,  1.7936493 , -1.1743435 ], dtype=float32)

What if one of our arrays is transposed (the mapping axis is the second one)?

In [None]:
v1s.T.shape, v2s.shape

((3, 8), (8, 3))

In [None]:
jax.pmap(dot, in_axes=(1,0))(v1s.T, v2s)

Array([ 0.51048726, -0.7174605 , -0.20105815, -0.26437205, -1.3696793 ,
        2.744793  ,  1.7936493 , -1.1743435 ], dtype=float32)

In [None]:
dot_pmapped(v1s.T, v2s)

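ValueError: ignored

This fails because `dot_pmapped` maps over axis 0 of both arguments, and those axes have different sizes (3 and 8).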

In [None]:
v1s.T.shape, v2s.T.shape

((3, 8), (3, 8))

In [None]:
jax.pmap(dot, in_axes=(1,1))(v1s.T, v2s.T)

Array([ 0.51048726, -0.7174605 , -0.20105815, -0.26437205, -1.3696793 ,
        2.744793  ,  1.7936493 , -1.1743435 ], dtype=float32)

In [None]:
dot_pmapped(v1s.T, v2s.T)

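Array([ 0.00782511, -0.31532776,  1.6295187 ], dtype=float32)

No error this time, but the result is different: with both arrays transposed, the first axis has size 3, so `pmap()` runs on 3 devices and computes dot products of the columns of the original arrays (vectors of length 8) instead of the rows.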
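`jax.vmap()` uses the same `in_axes` convention as `jax.pmap()`, so the two behaviors above can be checked on any machine. A sketch with toy `(4, 3)` arrays:

```python
import jax
import jax.numpy as jnp

def dot(v1, v2):
  return jnp.vdot(v1, v2)

v1s = jnp.arange(12.).reshape(4, 3)   # 4 row vectors of length 3
v2s = jnp.ones((4, 3))

rows = jax.vmap(dot)(v1s, v2s)                         # 4 dot products of rows
rows_t = jax.vmap(dot, in_axes=(1, 1))(v1s.T, v2s.T)   # same result, mapping over axis 1
cols = jax.vmap(dot)(v1s.T, v2s.T)                     # default in_axes=0: 3 dot products of columns
print(rows, rows_t, cols)
```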

A more complicated case:

In [None]:
def scaled_dot(v1, v2, koeff):
  return koeff*jnp.vdot(v1, v2)

In [None]:
v1s_ = v1s
v2s_ = v2s.T
k = 1.0

In [None]:
v1s_.shape, v2s_.shape

((8, 3), (3, 8))

In [None]:
scaled_dot_pmapped = jax.pmap(scaled_dot)

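The default `in_axes` do not work here: the arrays are mapped along different axes, and the scalar `k` has no axis to map over: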

In [None]:
scaled_dot_pmapped(v1s_, v2s_, k)

ValueError: ignored

In [None]:
scaled_dot_pmapped = jax.pmap(scaled_dot, in_axes=(0,1,None))

In [None]:
scaled_dot_pmapped(v1s_, v2s_, k)

Array([ 0.51048726, -0.7174605 , -0.20105815, -0.26437205, -1.3696793 ,
        2.744793  ,  1.7936493 , -1.1743435 ], dtype=float32)

Using a more complex parameter structure: `in_axes` can mirror the structure of a nested argument such as a dictionary:

In [None]:
def scaled_dot(data, koeff):
  return koeff*jnp.vdot(data['a'], data['b'])

In [None]:
scaled_dot_pmapped = jax.pmap(scaled_dot, in_axes=({'a':0,'b':1},None))

In [None]:
scaled_dot_pmapped({'a':v1s_, 'b': v2s_}, k)

Array([ 0.51048726, -0.7174605 , -0.20105815, -0.26437205, -1.3696793 ,
        2.744793  ,  1.7936493 , -1.1743435 ], dtype=float32)
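An `in_axes` entry only has to be a prefix of the argument's pytree: a single integer applies to every leaf beneath it. A sketch with `jax.vmap()`, which follows the same convention and runs on one device, reusing the `scaled_dot()` definition above:

```python
import jax
import jax.numpy as jnp

def scaled_dot(data, koeff):
  return koeff*jnp.vdot(data['a'], data['b'])

a = jnp.arange(6.).reshape(2, 3)   # rows [0, 1, 2] and [3, 4, 5]
b = jnp.ones((2, 3))

# A single 0 is applied to both leaves of the dictionary
out_prefix = jax.vmap(scaled_dot, in_axes=(0, None))({'a': a, 'b': b}, 2.0)
# A per-leaf specification, with 'b' stored transposed
out_leaf = jax.vmap(scaled_dot, in_axes=({'a': 0, 'b': 1}, None))({'a': a, 'b': b.T}, 2.0)
print(out_prefix, out_leaf)
```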

### Using out_axes parameter

In [None]:
def scale(v, koeff):
  return koeff*v

In [None]:
scale_pmapped = jax.pmap(scale,
                         in_axes=(0,None),
                         out_axes=1)

In [None]:
res = scale_pmapped(v1s, 2.0)

In [None]:
v1s.shape, res.shape

((8, 3), (3, 8))
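`out_axes=1` places the mapped (device) axis second in the output, which is why `res` has shape `(3, 8)`. A sketch that uses `jax.local_device_count()` instead of a hard-coded 8, so it also runs on a single device:

```python
import jax
import jax.numpy as jnp

def scale(v, koeff):
  return koeff*v

n = jax.local_device_count()
v1s = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)

res_default = jax.pmap(scale, in_axes=(0, None))(v1s, 2.0)             # device axis first
res_moved = jax.pmap(scale, in_axes=(0, None), out_axes=1)(v1s, 2.0)   # device axis second
print(res_default.shape, res_moved.shape)
```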

In [None]:
scale_pmapped = jax.pmap(scale, in_axes=(0,None))

In [None]:
scale_pmapped(v1s, 2.0)

Array([[-0.44077486, -1.7589345 , -1.7889494 ],
       [-0.5077695 , -0.6062154 ,  0.8972974 ],
       [-1.2127566 ,  0.33352432, -0.9126665 ],
       [-0.41174212,  1.8439273 ,  1.8103753 ],
       [ 0.550015  , -2.7737749 , -1.8786148 ],
       [-0.13278721, -2.1072452 ,  2.6899092 ],
       [-0.3347797 ,  2.1231527 , -3.6858454 ],
       [-0.24115095,  2.8006103 , -1.3478842 ]], dtype=float32)

In [None]:
scale_pmapped = jax.pmap(scale,
                         in_axes=(0,None),
                         out_axes=None)

## A more realistic case: larger tensors, mixing vmap and pmap

In [None]:
vs = random.normal(rng_key, shape=(20_000_000,3))
v1s = vs[:10_000_000,:].T
v2s = vs[10_000_000:,:].T

In [None]:
v1s.shape, v2s.shape

((3, 10000000), (3, 10000000))

In [None]:
v1sp = v1s.reshape((v1s.shape[0], 8, v1s.shape[1]//8))
v2sp = v2s.reshape((v2s.shape[0], 8, v2s.shape[1]//8))
v1sp.shape, v2sp.shape

((3, 8, 1250000), (3, 8, 1250000))

In [None]:
dot_parallel = jax.pmap(
    jax.vmap(dot, in_axes=(1,1)),
    in_axes=(1,1)
)

In [None]:
x_pmap = dot_parallel(v1sp,v2sp)

In [None]:
x_pmap.shape

(8, 1250000)

In [None]:
x_pmap = x_pmap.reshape((x_pmap.shape[0]*x_pmap.shape[1]))
x_pmap.shape

(10000000,)

In [None]:
jax.numpy.all(x_pmap == x_vmap)

Array(True, dtype=bool)

In [None]:
x_pmap[:3]

Array([ 1.1981742, -3.0097814, -2.539115 ], dtype=float32)

In [None]:
x_vmap[:3]

Array([ 1.1981742, -3.0097814, -2.539115 ], dtype=float32)
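Here the outer `jax.pmap()` maps axis 1 of the `(3, 8, 1250000)` arrays, so each device sees a `(3, 1250000)` block, and the inner `jax.vmap()` maps axis 1 of that block, so `dot()` receives single length-3 vectors. A scaled-down sketch of the same axis bookkeeping, sized by `jax.local_device_count()`:

```python
import jax
import jax.numpy as jnp

def dot(v1, v2):
  return jnp.vdot(v1, v2)

n = jax.local_device_count()
k1, k2 = jax.random.split(jax.random.PRNGKey(42))
v1s = jax.random.normal(k1, (3, n * 5))   # 3 coordinates x (n * 5) vectors
v2s = jax.random.normal(k2, (3, n * 5))

v1sp = v1s.reshape((3, n, 5))
v2sp = v2s.reshape((3, n, 5))

dot_parallel = jax.pmap(jax.vmap(dot, in_axes=(1, 1)), in_axes=(1, 1))
x_pmap = dot_parallel(v1sp, v2sp)                 # shape (n, 5)
x_ref = jax.vmap(dot, in_axes=(1, 1))(v1s, v2s)   # single-device reference, shape (n * 5,)
print(x_pmap.shape, jnp.allclose(x_pmap.reshape(-1), x_ref, atol=1e-5))
```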

## Using collective ops and the axis_name parameter

### Normalization examples

A small array example

In [None]:
arr = jnp.array(range(8))

In [None]:
arr

Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)

In [None]:
norm = jax.pmap(
    lambda x: x/jax.lax.psum(x, axis_name='p'),
    axis_name='p')

In [None]:
norm(arr)

Array([0.        , 0.03571429, 0.07142857, 0.10714287, 0.14285715,
       0.17857143, 0.21428573, 0.25      ], dtype=float32)

In [None]:
jnp.sum(norm(arr))

Array(1., dtype=float32)
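Collective operations such as `jax.lax.psum()` also work inside `jax.vmap()` when it is given an `axis_name`, which makes it possible to prototype the same normalization on a single device. A sketch:

```python
import jax
import jax.numpy as jnp

arr = jnp.arange(8.)

# psum over the named vmap axis 'i' sums x across all 8 mapped elements
norm = jax.vmap(lambda x: x/jax.lax.psum(x, axis_name='i'), axis_name='i')
narr = norm(arr)
print(narr, jnp.sum(narr))
```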

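A larger array example (more elements than XLA devices): each device receives a whole row, so we sum within the row with `jnp.sum()` before calling `psum()`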

In [None]:
arr = jnp.array(range(200))

In [None]:
arr = arr.reshape(8, 25)
arr.shape

(8, 25)

In [None]:
norm = jax.pmap(
    lambda x: x/jax.lax.psum(jnp.sum(x), axis_name='p'),
    axis_name='p')

In [None]:
narr = norm(arr)
narr.shape

(8, 25)

In [None]:
jnp.sum(narr)

Array(1., dtype=float32)

Using device groups (`axis_index_groups`)

In [None]:
norm = jax.pmap(
    lambda x: x/jax.lax.psum(
        jnp.sum(x),
        axis_name='p',
        axis_index_groups=[[0,1], [2,3], [4,5], [6,7]]
    ),
    axis_name='p')

In [None]:
narr = norm(arr)
narr.shape

(8, 25)

In [None]:
jnp.sum(narr)

Array(4., dtype=float32)

In [None]:
jnp.sum(narr[:2]), jnp.sum(narr[2:4]), jnp.sum(narr[4:6]), jnp.sum(narr[6:])

(Array(1., dtype=float32),
 Array(1., dtype=float32),
 Array(1.0000001, dtype=float32),
 Array(1., dtype=float32))
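With `axis_index_groups=[[0,1], [2,3], [4,5], [6,7]]`, each `psum()` only adds up values from the devices in its own group, so every pair of rows is normalized separately and the total is 4. A plain `jax.numpy` reference computation of the expected result (no collectives involved), which can be handy for checking the `pmap()` output:

```python
import jax.numpy as jnp

arr = jnp.arange(200, dtype=jnp.float32).reshape(8, 25)

# Rows 2g and 2g+1 form group g; each group is normalized by its own sum
groups = arr.reshape(4, 2, 25)
expected = (groups / jnp.sum(groups, axis=(1, 2), keepdims=True)).reshape(8, 25)
print(jnp.sum(expected), [float(jnp.sum(expected[i:i + 2])) for i in range(0, 8, 2)])
```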

### Using nested pmap() and vmap()

In [None]:
arr = jnp.array(range(200))
arr = arr.reshape(8, 25)
arr.shape

(8, 25)

Understanding different axes

In [None]:
f = jax.pmap(
    jax.vmap(
        lambda x: jax.lax.pmax(x, axis_name='p'),    # finding a maximum across the 'p' (pmap) axis
        axis_name='v'
    ),
    axis_name='p')

In [None]:
f(arr)

Array([[175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187,
        188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199],
       [175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187,
        188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199],
       [175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187,
        188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199],
       [175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187,
        188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199],
       [175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187,
        188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199],
       [175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187,
        188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199],
       [175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187,
        188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199],
      

In [None]:
f = jax.pmap(
    jax.vmap(
        lambda x: jax.lax.pmax(x, axis_name=('p','v')),    # Finding global maximum across two axes
        axis_name='v'
    ),
    axis_name='p')

In [None]:
f(arr)

Array([[199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199,
        199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199],
       [199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199,
        199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199],
       [199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199,
        199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199],
       [199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199,
        199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199],
       [199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199,
        199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199],
       [199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199,
        199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199],
       [199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199,
        199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199],
      

Using both axes

In [None]:
f = jax.pmap(
    jax.vmap(
        lambda x: jax.lax.pmax(x, axis_name='v')/jax.lax.pmax(x, axis_name='p'),
        axis_name='v'
    ),
    axis_name='p')

In [None]:
f(arr)

Array([[0.13714285, 0.13636364, 0.1355932 , 0.13483146, 0.1340782 ,
        0.13333334, 0.13259669, 0.13186814, 0.13114755, 0.13043478,
        0.12972972, 0.12903225, 0.12834224, 0.12765957, 0.12698412,
        0.12631579, 0.12565446, 0.125     , 0.12435233, 0.12371133,
        0.12307693, 0.12244899, 0.12182741, 0.12121212, 0.12060301],
       [0.28      , 0.2784091 , 0.27683613, 0.2752809 , 0.273743  ,
        0.27222222, 0.27071825, 0.26923078, 0.2677596 , 0.26630434,
        0.26486486, 0.26344085, 0.2620321 , 0.2606383 , 0.25925925,
        0.25789472, 0.2565445 , 0.25520834, 0.253886  , 0.2525773 ,
        0.25128207, 0.25000003, 0.24873096, 0.24747474, 0.24623115],
       [0.42285714, 0.42045456, 0.41807905, 0.41573033, 0.4134078 ,
        0.41111112, 0.4088398 , 0.4065934 , 0.40437162, 0.4021739 ,
        0.39999998, 0.39784947, 0.39572194, 0.393617  , 0.3915344 ,
        0.38947368, 0.38743457, 0.3854167 , 0.3834197 , 0.3814433 ,
        0.3794872 , 0.37755105, 0.37563452, 0.
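The same axis-name bookkeeping can be reproduced on one device by replacing the outer `jax.pmap()` with a second `jax.vmap()`; both accept `axis_name`, and collectives like `jax.lax.pmax()` work under either. A sketch checking the three results above:

```python
import jax
import jax.numpy as jnp

arr = jnp.arange(200).reshape(8, 25)

def nested(fn):
  # outer map over rows (named 'p'), inner map over columns (named 'v')
  return jax.vmap(jax.vmap(fn, axis_name='v'), axis_name='p')

max_p = nested(lambda x: jax.lax.pmax(x, axis_name='p'))(arr)          # column maxima
max_pv = nested(lambda x: jax.lax.pmax(x, axis_name=('p', 'v')))(arr)  # global maximum
ratio = nested(lambda x: jax.lax.pmax(x, axis_name='v')/jax.lax.pmax(x, axis_name='p'))(arr)
print(max_p[0, :3], max_pv[0, 0], ratio[0, 0])
```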

Global normalization using two axes at once

In [None]:
norm = jax.pmap(
    jax.vmap(
        lambda x: x/jax.lax.psum(x, axis_name=('p','v')),
        axis_name='v'
    ),
    axis_name='p')

In [None]:
narr = norm(arr)
narr.shape

(8, 25)

In [None]:
jnp.sum(narr)

Array(1., dtype=float32)

In [None]:
narr

Array([[0.0000000e+00, 5.0251256e-05, 1.0050251e-04, 1.5075377e-04,
        2.0100502e-04, 2.5125628e-04, 3.0150753e-04, 3.5175879e-04,
        4.0201005e-04, 4.5226130e-04, 5.0251256e-04, 5.5276381e-04,
        6.0301507e-04, 6.5326632e-04, 7.0351758e-04, 7.5376884e-04,
        8.0402009e-04, 8.5427135e-04, 9.0452260e-04, 9.5477386e-04,
        1.0050251e-03, 1.0552764e-03, 1.1055276e-03, 1.1557789e-03,
        1.2060301e-03],
       [1.2562814e-03, 1.3065326e-03, 1.3567839e-03, 1.4070352e-03,
        1.4572864e-03, 1.5075377e-03, 1.5577889e-03, 1.6080402e-03,
        1.6582914e-03, 1.7085427e-03, 1.7587940e-03, 1.8090452e-03,
        1.8592965e-03, 1.9095477e-03, 1.9597989e-03, 2.0100502e-03,
        2.0603016e-03, 2.1105527e-03, 2.1608039e-03, 2.2110553e-03,
        2.2613066e-03, 2.3115578e-03, 2.3618089e-03, 2.4120603e-03,
        2.4623116e-03],
       [2.5125628e-03, 2.5628139e-03, 2.6130653e-03, 2.6633167e-03,
        2.7135678e-03, 2.7638189e-03, 2.8140703e-03, 2.8643217e-03,


Nested pmap (a 2×4 grid that uses all 8 devices)

In [None]:
arr = jnp.array(range(8)).reshape(2,4)
arr

Array([[0, 1, 2, 3],
       [4, 5, 6, 7]], dtype=int32)

In [None]:
n = jax.pmap(
    jax.pmap(
        lambda x: x/jax.lax.psum(x, axis_name=('rows','cols')),
        axis_name='cols'
    ),
    axis_name='rows')

In [None]:
n(arr)

Array([[0.        , 0.03571429, 0.07142857, 0.10714287],
       [0.14285715, 0.17857143, 0.21428573, 0.25      ]], dtype=float32)

In [None]:
jnp.sum(n(arr))

Array(1., dtype=float32)

The same example using the decorator style

In [None]:
from functools import partial

In [None]:
@partial(jax.pmap, axis_name='rows')
@partial(jax.pmap, axis_name='cols')
def n(x):
  return x/jax.lax.psum(x, axis_name=('rows','cols'))

In [None]:
jnp.sum(n(arr))

Array(1., dtype=float32)

## Data-parallel neural network training

### Preparing data

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

data, info = tfds.load(name="mnist",
                       data_dir=data_dir,
                       as_supervised=True,
                       with_info=True)

data_train = data['train']
data_test  = data['test']

Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /tmp/tfds/mnist/3.0.1...
Dl Completed...: 100%| 5/5 [00:00<00:00, 13.64 file/s]
Dataset mnist downloaded and prepared to /tmp/tfds/mnist/3.0.1. Subsequent calls will reuse this data.


In [None]:
HEIGHT = 28
WIDTH  = 28
CHANNELS = 1
NUM_PIXELS = HEIGHT * WIDTH * CHANNELS
NUM_LABELS = info.features['label'].num_classes
NUM_DEVICES = jax.device_count()
BATCH_SIZE  = 32

In [None]:
def preprocess(img, label):
  """Resize and preprocess images."""
  return (tf.cast(img, tf.float32)/255.0), label

train_data = tfds.as_numpy(
    data_train.map(preprocess).batch(NUM_DEVICES*BATCH_SIZE).prefetch(1)
)
test_data  = tfds.as_numpy(
    data_test.map(preprocess).batch(NUM_DEVICES*BATCH_SIZE).prefetch(1)
)

In [None]:
len(train_data)

235

### Preparing MLP

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, value_and_grad
from jax import random
from jax.nn import swish, logsumexp, one_hot

In [None]:
LAYER_SIZES = [28*28, 512, 10]
PARAM_SCALE = 0.01

In [None]:
def init_network_params(sizes, key=random.PRNGKey(0), scale=1e-2):
  """Initialize all layers for a fully-connected neural network with given sizes"""

  def random_layer_params(m, n, key, scale=1e-2):
    """A helper function to randomly initialize weights and biases of a dense layer"""
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k, scale) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

init_params = init_network_params(LAYER_SIZES, random.PRNGKey(0), scale=PARAM_SCALE)

In [None]:
def predict(params, image):
  """Function for per-example predictions."""
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = swish(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits

batched_predict = vmap(predict, in_axes=(None, 0))


### Loss and update functions

In [None]:
INIT_LR = 1.0
DECAY_RATE = 0.95
DECAY_STEPS = 5
NUM_EPOCHS  = 20

In [None]:
from functools import partial

In [None]:
def loss(params, images, targets):
  """Categorical cross entropy loss function."""
  logits = batched_predict(params, images)
  log_preds = logits - logsumexp(logits) # logsumexp trick https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
  return -jnp.mean(targets*log_preds)

@partial(jax.pmap, axis_name='devices', in_axes=(None, 0, 0, None), out_axes=(None,0))
def update(params, x, y, epoch_number):
  loss_value, grads = value_and_grad(loss)(params, x, y)
  grads = [(jax.lax.psum(dw, 'devices'), jax.lax.psum(db, 'devices'))
    for dw, db in grads]
  lr = INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)
  return [(w - lr * dw, b - lr * db)
          for (w, b), (dw, db) in zip(params, grads)], loss_value
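Note that `logsumexp(logits)` without an `axis` argument reduces over the whole `(batch, classes)` array, so the loss above normalizes over the entire batch rather than per example. A small sketch comparing it with the per-example log-softmax that a standard categorical cross-entropy uses (`axis=1, keepdims=True`), on toy logits:

```python
import jax.numpy as jnp
from jax.nn import logsumexp, log_softmax

logits = jnp.array([[1.0, 2.0, 3.0],
                    [0.0, 0.0, 5.0]])

whole_batch = logits - logsumexp(logits)                          # one normalizer for the whole batch
per_example = logits - logsumexp(logits, axis=1, keepdims=True)   # one normalizer per row

# Per-example log-probabilities exponentiate to 1 in every row; whole-batch ones only sum to 1 overall
print(jnp.exp(per_example).sum(axis=1), jnp.exp(whole_batch).sum())
```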

### Section for debugging purposes

In [None]:
train_data_iter = iter(train_data)
x, y = next(train_data_iter)

In [None]:
x.shape, y.shape

((256, 28, 28, 1), (256,))

In [None]:
x = jnp.reshape(x, (NUM_DEVICES, BATCH_SIZE, NUM_PIXELS))
y = jnp.reshape(one_hot(y, NUM_LABELS), (NUM_DEVICES, BATCH_SIZE, NUM_LABELS))
x.shape, y.shape

((8, 32, 784), (8, 32, 10))

In [None]:
updated_params, loss_value = update(init_params, x, y, 0)

In [None]:
loss_value

Array([0.5771865 , 0.5766423 , 0.5766001 , 0.57689124, 0.57701343,
       0.57676095, 0.57668227, 0.5764269 ], dtype=float32)
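`update()` combines gradients with `jax.lax.psum()`, so the applied gradient is the sum of the per-device gradients; when all shards have the same size, that equals `NUM_DEVICES` times the full-batch mean gradient. `jax.lax.pmean()` averages instead. A toy sketch on a hypothetical linear least-squares problem (not the MLP above), sized by `jax.local_device_count()`:

```python
from functools import partial

import jax
import jax.numpy as jnp

def loss(w, x, y):
  # Mean squared error of a linear model (hypothetical toy problem)
  return jnp.mean((x @ w - y) ** 2)

@partial(jax.pmap, axis_name='devices', in_axes=(None, 0, 0))
def combined_grads(w, x, y):
  g = jax.grad(loss)(w, x, y)   # gradient on this device's shard only
  return jax.lax.psum(g, 'devices'), jax.lax.pmean(g, 'devices')

n = jax.local_device_count()
kx, ky = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (n * 16, 3))
y = jax.random.normal(ky, (n * 16,))
w = jnp.zeros(3)

g_sum, g_mean = combined_grads(w, x.reshape(n, 16, 3), y.reshape(n, 16))
g_full = jax.grad(loss)(w, x, y)   # single-device, full-batch gradient
print(jnp.allclose(g_mean[0], g_full, atol=1e-5), jnp.allclose(g_sum[0], n * g_full, atol=1e-5))
```

With `psum()`, the learning rate effectively scales with the device count; `pmean()` keeps the update size independent of it.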

### Training loop

In [None]:
@jit
def batch_accuracy(params, images, targets):
  images = jnp.reshape(images, (len(images), NUM_PIXELS))
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == targets)

def accuracy(params, data):
  accs = []
  for images, targets in data:
    accs.append(batch_accuracy(params, images, targets))
  return jnp.mean(jnp.array(accs))

In [None]:
import time

params = init_params
for epoch in range(NUM_EPOCHS):
  start_time = time.time()
  losses = []
  for x, y in train_data:
    num_elements = len(y)
    x = jnp.reshape(x, (NUM_DEVICES, num_elements//NUM_DEVICES, NUM_PIXELS))
    y = jnp.reshape(one_hot(y, NUM_LABELS), (NUM_DEVICES, num_elements//NUM_DEVICES, NUM_LABELS))
    params, loss_value = update(params, x, y, epoch)
    losses.append(jnp.sum(loss_value))
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_data)
  test_acc = accuracy(params, test_data)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set loss {}".format(jnp.mean(jnp.array(losses))))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 1.61 sec
Training set loss 3.3886349201202393
Training set accuracy 0.9229609370231628
Test set accuracy 0.9217773675918579
Epoch 1 in 1.01 sec
Training set loss 3.0531487464904785
Training set accuracy 0.9479831457138062
Test set accuracy 0.9482421875
Epoch 2 in 1.01 sec
Training set loss 2.983154535293579
Training set accuracy 0.9597517848014832
Test set accuracy 0.9580078125
Epoch 3 in 1.07 sec
Training set loss 2.948392391204834
Training set accuracy 0.9668937921524048
Test set accuracy 0.96484375
Epoch 4 in 1.07 sec
Training set loss 2.928722858428955
Training set accuracy 0.9714649319648743
Test set accuracy 0.9706054925918579
Epoch 5 in 0.95 sec
Training set loss 2.913074254989624
Training set accuracy 0.9739250540733337
Test set accuracy 0.973339855670929
Epoch 6 in 0.98 sec
Training set loss 2.9012343883514404
Training set accuracy 0.9771497845649719
Test set accuracy 0.975292980670929
Epoch 7 in 1.13 sec
Training set loss 2.891705274581909
Training set accuracy 0.9