## ⚖️ Choose A or B:

## A: Emulating multi-device system on CPU

Use this section to initialize a set of virtual devices on CPU if you have no access to a multi-device system.

It can also help you prototype, debug and test your multi-device code locally before running it on the expensive system.

Even in the case of using Google Colab it can help you prototype faster because a CPU runtime is faster to restart.

In [None]:
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

In [None]:
import jax
import jax.numpy as jnp

In [None]:
jax.devices("cpu")

## B: Setting up TPU

!! **First, use instruction from Appendix C or example from Chapter 3 to run Cloud TPU and connect Colab to it.** Then run code below !!

In [None]:
# install if you didn't install it on a Cloud TPU VM
!pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

In [None]:
# need for our later example with SPMD neural net training
!pip install tensorflow

In [None]:
!pip install tensorflow_datasets

In [1]:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

tpu


In [2]:
import jax
import jax.numpy as jnp

In [3]:
jax.local_devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

## Parallelizing a function

In [4]:
def dot(v1, v2):
  return jnp.vdot(v1, v2)

In [5]:
dot(jnp.array([1., 1., 1.]), jnp.array([1., 2., -1]))

Array(2., dtype=float32)

Generate some **large** random arrays

In [141]:
from jax import random

In [142]:
rng_key = random.PRNGKey(42)

In [143]:
vs = random.normal(rng_key, shape=(20_000_000,3))

In [144]:
v1s = vs[:10_000_000,:]
v2s = vs[10_000_000:,:]

In [145]:
v1s.shape, v2s.shape

((10000000, 3), (10000000, 3))

Make a compiled vectorized function as a baseline:

In [146]:
dot_batched = jax.jit(jax.vmap(dot)).lower(v1s,v2s).compile() # you can use AOT-compilation

In [147]:
%timeit x_vmap = dot_batched(v1s, v2s).block_until_ready()

503 μs ± 994 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [148]:
dot_batched = jax.jit(jax.vmap(dot)) # or JIT-compilation

In [149]:
x_vmap = dot_batched(v1s, v2s)

In [150]:
x_vmap.shape

(10000000,)

In [151]:
dot_parallel = jax.pmap(dot)

In [152]:
x_pmap = dot_parallel(v1s,v2s)

ValueError: compiling computation that requires 10000000 logical devices, but only 4 XLA devices are available (num_replicas=10000000)

In [153]:
v1s.shape

(10000000, 3)

In [154]:
v1sp = v1s.reshape((2, v1s.shape[0]//2, v1s.shape[1]))
v2sp = v2s.reshape((2, v2s.shape[0]//2, v2s.shape[1]))    # we use integer division here

In [155]:
v1sp.shape

(2, 5000000, 3)

In [156]:
v1s[0,:], v1sp[0,0,:]

(Array([1.456546  , 1.1449531 , 0.02485494], dtype=float32),
 Array([1.456546  , 1.1449531 , 0.02485494], dtype=float32))

In [157]:
v1s[1250000,:], v1sp[1,0,:]

(Array([ 0.50284576,  0.38336843, -0.35499337], dtype=float32),
 Array([-0.3869151 ,  0.31480923, -0.11581508], dtype=float32))

In [158]:
x_pmap = dot_parallel(v1sp,v2sp)

In [159]:
x_pmap.shape

(2,)

In [160]:
dot_parallel = jax.pmap(jax.vmap(dot))

In [161]:
x_pmap = dot_parallel(v1sp,v2sp)

In [162]:
x_pmap.shape

(2, 5000000)

In [29]:
type(x_pmap)

jaxlib.xla_extension.ArrayImpl

In [30]:
x_pmap = x_pmap.reshape((x_pmap.shape[0]*x_pmap.shape[1]))
x_pmap.shape

(10000000,)

In [31]:
jax.numpy.all(x_pmap == x_vmap)

Array(True, dtype=bool)

In [32]:
%timeit xp = dot_parallel(v1sp,v2sp).block_until_ready()

5.68 ms ± 3.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [33]:
x_pmap

Array([ 1.1981742 , -3.0097814 , -2.539115  , ..., -2.7281013 ,
        0.26364914,  1.9583547 ], dtype=float32)

In [34]:
x_vmap

Array([ 1.1981742 , -3.0097814 , -2.539115  , ..., -2.7281013 ,
        0.26364914,  1.9583547 ], dtype=float32)

A very simple example when you can easily switch vmap and pmap:

In [35]:
def dot(v1, v2):
  return jnp.vdot(v1, v2)

In [39]:
vs = random.normal(rng_key, shape=(8,3))
v1s = vs[:4,:]
v2s = vs[4:,:]

In [40]:
jax.vmap(dot)(v1s,v2s)

Array([ 1.4808662 ,  0.2834839 ,  0.28449908, -0.8924192 ], dtype=float32)

In [41]:
jax.pmap(dot)(v1s,v2s)

Array([ 1.4808662 ,  0.2834839 ,  0.28449908, -0.8924192 ], dtype=float32)

In [42]:
dot_v = jax.jit(jax.vmap(dot))
x = dot_v(v1s,v2s)

In [43]:
dot_pjo = jax.jit(jax.pmap(dot))
x = dot_pjo(v1s,v2s)



In [44]:
dot_pji = jax.pmap(jax.jit(dot))
x = dot_pji(v1s,v2s)

In [45]:
dot_p = jax.pmap(dot)
x = dot_p(v1s,v2s)

In [46]:
%timeit dot_v(v1s,v2s).block_until_ready()

116 μs ± 446 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [47]:
%timeit dot_pjo(v1s,v2s).block_until_ready()

487 μs ± 10.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [48]:
%timeit dot_pji(v1s,v2s).block_until_ready()

841 μs ± 3.65 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [49]:
%timeit dot_p(v1s,v2s).block_until_ready()

827 μs ± 9.88 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [50]:
jax.make_jaxpr(dot_v)(v1s,v2s)

{ lambda ; a:f32[4,3] b:f32[4,3]. let
    c:f32[4] = pjit[
      name=dot
      jaxpr={ lambda ; d:f32[4,3] e:f32[4,3]. let
          f:f32[4] = dot_general[
            dimension_numbers=(([1], [1]), ([0], [0]))
            preferred_element_type=float32
          ] d e
        in (f,) }
    ] a b
  in (c,) }

In [51]:
jax.make_jaxpr(dot_pjo)(v1s,v2s)

{ lambda ; a:f32[4,3] b:f32[4,3]. let
    c:f32[4] = pjit[
      name=dot
      jaxpr={ lambda ; d:f32[4,3] e:f32[4,3]. let
          f:f32[4] = xla_pmap[
            axis_name=<axis 0x7fc329dead40>
            axis_size=4
            backend=None
            call_jaxpr={ lambda ; g:f32[3] h:f32[3]. let
                i:f32[] = dot_general[
                  dimension_numbers=(([0], [0]), ([], []))
                  preferred_element_type=float32
                ] g h
              in (i,) }
            devices=None
            donated_invars=(False, False)
            global_axis_size=4
            in_axes=(0, 0)
            is_explicit_global_axis_size=False
            name=dot
            out_axes=(0,)
          ] d e
        in (f,) }
    ] a b
  in (c,) }

In [52]:
jax.make_jaxpr(dot_pji)(v1s,v2s)

{ lambda ; a:f32[4,3] b:f32[4,3]. let
    c:f32[4] = xla_pmap[
      axis_name=<axis 0x7fd8586c3d30>
      axis_size=4
      backend=None
      call_jaxpr={ lambda ; d:f32[3] e:f32[3]. let
          f:f32[] = pjit[
            name=dot
            jaxpr={ lambda ; g:f32[3] h:f32[3]. let
                i:f32[] = dot_general[
                  dimension_numbers=(([0], [0]), ([], []))
                  preferred_element_type=float32
                ] g h
              in (i,) }
          ] d e
        in (f,) }
      devices=None
      donated_invars=(False, False)
      global_axis_size=4
      in_axes=(0, 0)
      is_explicit_global_axis_size=False
      name=dot
      out_axes=(0,)
    ] a b
  in (c,) }

In [53]:
jax.make_jaxpr(dot_p)(v1s,v2s)

{ lambda ; a:f32[4,3] b:f32[4,3]. let
    c:f32[4] = xla_pmap[
      axis_name=<axis 0x7fc329dead40>
      axis_size=4
      backend=None
      call_jaxpr={ lambda ; d:f32[3] e:f32[3]. let
          f:f32[] = dot_general[
            dimension_numbers=(([0], [0]), ([], []))
            preferred_element_type=float32
          ] d e
        in (f,) }
      devices=None
      donated_invars=(False, False)
      global_axis_size=4
      in_axes=(0, 0)
      is_explicit_global_axis_size=False
      name=dot
      out_axes=(0,)
    ] a b
  in (c,) }

## Controlling pmap() behavior

### Using in_axes parameter

Using a small array just for demonstration purposes

In [54]:
vs = random.normal(rng_key, shape=(8,3))
v1s = vs[:4,:]
v2s = vs[4:,:]

In [55]:
def dot(v1, v2):
  return jnp.vdot(v1, v2)

A default value:

In [56]:
dot_pmapped = jax.pmap(dot, in_axes=(0,0))

In [57]:
dot_pmapped(v1s, v2s)

Array([ 1.4808662 ,  0.2834839 ,  0.28449908, -0.8924192 ], dtype=float32)

What if one of our arrays is transposed (the mapping axis is the second one)?

In [58]:
v1s.T.shape, v2s.shape

((3, 4), (4, 3))

In [59]:
jax.pmap(dot, in_axes=(1,0))(v1s.T, v2s)

Array([ 1.4808662 ,  0.2834839 ,  0.28449908, -0.8924192 ], dtype=float32)

In [60]:
dot_pmapped(v1s.T, v2s)

ValueError: pmap got inconsistent sizes for array axes to be mapped:
  * one axis had size 3: axis 0 of argument v1 of type float32[3,4];
  * one axis had size 4: axis 0 of argument v2 of type float32[4,3]

In [61]:
v1s.T.shape, v2s.T.shape

((3, 4), (3, 4))

In [62]:
jax.pmap(dot, in_axes=(1,1))(v1s.T, v2s.T)

Array([ 1.4808662 ,  0.2834839 ,  0.28449908, -0.8924192 ], dtype=float32)

In [63]:
dot_pmapped(v1s.T, v2s.T)

Array([-3.3310347,  3.7464314,  0.7410331], dtype=float32)

A more complicated case:

In [64]:
def scaled_dot(v1, v2, koeff):
  return koeff*jnp.vdot(v1, v2)

In [65]:
v1s_ = v1s
v2s_ = v2s.T
k = 1.0

In [66]:
v1s_.shape, v2s_.shape

((4, 3), (3, 4))

In [67]:
scaled_dot_pmapped = jax.pmap(scaled_dot)

Default values do not work:

In [68]:
scaled_dot_pmapped(v1s_, v2s_, k)

ValueError: pmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())

In [69]:
scaled_dot_pmapped = jax.pmap(scaled_dot, in_axes=(0,1,None))

In [70]:
scaled_dot_pmapped(v1s_, v2s_, k)

Array([ 1.4808662 ,  0.2834839 ,  0.28449908, -0.8924192 ], dtype=float32)

Using more complex parameter structure:

In [71]:
def scaled_dot(data, koeff):
  return koeff*jnp.vdot(data['a'], data['b'])

In [72]:
scaled_dot_pmapped = jax.pmap(scaled_dot, in_axes=({'a':0,'b':1},None))

In [73]:
scaled_dot_pmapped({'a':v1s_, 'b': v2s_}, k)

Array([ 1.4808662 ,  0.2834839 ,  0.28449908, -0.8924192 ], dtype=float32)

### Using out_axes parameter

In [74]:
def scale(v, koeff):
  return koeff*v

In [75]:
scale_pmapped = jax.pmap(scale,
                         in_axes=(0,None),
                         out_axes=(1))

In [76]:
res = scale_pmapped(v1s, 2.0)

In [77]:
v1s.shape, res.shape

((4, 3), (3, 4))

In [78]:
scale_pmapped = jax.pmap(scale, in_axes=(0,None))

In [79]:
scale_pmapped(v1s, 2.0)

Array([[ 1.1957493 ,  2.100962  ,  4.260094  ],
       [-2.7670264 , -3.6707642 , -0.10890777],
       [-1.8921121 ,  0.99184775,  0.52138793],
       [ 1.7561268 ,  2.1729617 ,  2.6528773 ]], dtype=float32)

In [80]:
scale_pmapped = jax.pmap(scale,
                         in_axes=(0,None),
                         out_axes=None)

## A more real-life case with larger tensors and mixing vmap and pmap

In [81]:
vs = random.normal(rng_key, shape=(20_000_000,3))
v1s = vs[:10_000_000,:].T
v2s = vs[10_000_000:,:].T

In [82]:
v1s.shape, v2s.shape

((3, 10000000), (3, 10000000))

In [83]:
v1sp = v1s.reshape((v1s.shape[0], 4, v1s.shape[1]//4))
v2sp = v2s.reshape((v2s.shape[0], 4, v2s.shape[1]//4))
v1sp.shape, v2sp.shape

((3, 4, 2500000), (3, 4, 2500000))

In [84]:
dot_parallel = jax.pmap(
    jax.vmap(dot, in_axes=(1,1)),
    in_axes=(1,1)
)

In [85]:
x_pmap = dot_parallel(v1sp,v2sp)

In [86]:
x_pmap.shape

(4, 2500000)

In [87]:
x_pmap = x_pmap.reshape((x_pmap.shape[0]*x_pmap.shape[1]))
x_pmap.shape

(10000000,)

In [88]:
jax.numpy.all(x_pmap == x_vmap)

Array(True, dtype=bool)

In [89]:
x_pmap[:3]

Array([ 1.1981742, -3.0097814, -2.539115 ], dtype=float32)

In [90]:
x_vmap[:3]

Array([ 1.1981742, -3.0097814, -2.539115 ], dtype=float32)

## Using collective ops and the axis_name parameter

### Normalization examples

A small array example

In [95]:
arr = jnp.array(range(4))

In [96]:
arr

Array([0, 1, 2, 3], dtype=int32)

In [97]:
norm = jax.pmap(
    lambda x: x/jax.lax.psum(x, axis_name='p'),
    axis_name='p')

In [98]:
norm(arr)

Array([0.        , 0.16666667, 0.33333334, 0.5       ], dtype=float32)

In [99]:
jnp.sum(norm(arr))

Array(1., dtype=float32)

A large array (more elements than XLA devices) example

In [100]:
arr = jnp.array(range(200))

In [101]:
arr = arr.reshape(4, 50)
arr.shape

(4, 50)

In [102]:
norm = jax.pmap(
    lambda x: x/jax.lax.psum(jnp.sum(x), axis_name='p'),
    axis_name='p')

In [103]:
narr = norm(arr)
narr.shape

(4, 50)

In [104]:
jnp.sum(narr)

Array(1., dtype=float32)

Using groups

In [107]:
norm = jax.pmap(
    lambda x: x/jax.lax.psum(
        jnp.sum(x),
        axis_name='p',
        axis_index_groups=[[0,1], [2,3]]
    ),
    axis_name='p')

In [108]:
narr = norm(arr)
narr.shape

(4, 50)

In [109]:
jnp.sum(narr)

Array(2.0000002, dtype=float32)

In [110]:
jnp.sum(narr[:2]), jnp.sum(narr[2:4])

(Array(1., dtype=float32), Array(1., dtype=float32))

### Using nested pmap() and vmap()

In [111]:
arr = jnp.array(range(200))
arr = arr.reshape(4, 50)
arr.shape

(4, 50)

Understanding different axes

In [112]:
f = jax.pmap(
    jax.vmap(
        lambda x: jax.lax.pmax(x, axis_name='p'),    # finding a maximum across pmax axis
        axis_name='v'
    ),
    axis_name='p')

In [113]:
f(arr)

Array([[150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162,
        163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
        176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188,
        189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199],
       [150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162,
        163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
        176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188,
        189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199],
       [150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162,
        163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
        176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188,
        189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199],
       [150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162,
        163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
     

In [120]:
f = jax.pmap(
    jax.vmap(
        lambda x: jax.lax.pmax(x, axis_name=('p', 'v')),    # Finding global maximum across two axes
        axis_name='v'
    ),
    axis_name='p')

In [121]:
f(arr)

Array([[199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199,
        199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199,
        199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199,
        199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199],
       [199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199,
        199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199,
        199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199,
        199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199],
       [199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199,
        199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199,
        199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199,
        199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199],
       [199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199,
        199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199,
     

Using both axes

In [116]:
f = jax.pmap(
    jax.vmap(
        lambda x: jax.lax.pmax(x, axis_name='v')/jax.lax.pmax(x, axis_name='p'),
        axis_name='v'
    ),
    axis_name='p')

In [117]:
f(arr)

Array([[0.32666668, 0.3245033 , 0.3223684 , 0.32026145, 0.31818178,
        0.31612903, 0.31410256, 0.3121019 , 0.31012657, 0.3081761 ,
        0.30625   , 0.3043478 , 0.30246916, 0.30061352, 0.29878047,
        0.29696968, 0.2951807 , 0.2934132 , 0.2916667 , 0.28994083,
        0.2882353 , 0.28654972, 0.2848837 , 0.28323698, 0.2816092 ,
        0.28      , 0.2784091 , 0.27683613, 0.2752809 , 0.273743  ,
        0.27222222, 0.27071825, 0.26923078, 0.2677596 , 0.26630434,
        0.26486486, 0.26344085, 0.2620321 , 0.2606383 , 0.25925925,
        0.25789472, 0.2565445 , 0.25520834, 0.253886  , 0.2525773 ,
        0.25128207, 0.25000003, 0.24873096, 0.24747474, 0.24623115],
       [0.66      , 0.65562916, 0.6513158 , 0.64705884, 0.6428571 ,
        0.63870966, 0.6346154 , 0.6305733 , 0.62658226, 0.6226415 ,
        0.61875004, 0.6149068 , 0.61111116, 0.607362  , 0.6036585 ,
        0.59999996, 0.59638554, 0.5928144 , 0.58928573, 0.5857988 ,
        0.58235294, 0.57894737, 0.5755814 , 0.5

Global normalization using two axes at once

In [122]:
norm = jax.pmap(
    jax.vmap(
        lambda x: x/jax.lax.psum(x, axis_name=('p','v')),
        axis_name='v'
    ),
    axis_name='p')

In [123]:
narr = norm(arr)
narr.shape

(4, 50)

In [124]:
jnp.sum(narr)

Array(1., dtype=float32)

In [125]:
narr

Array([[0.0000000e+00, 5.0251256e-05, 1.0050251e-04, 1.5075377e-04,
        2.0100502e-04, 2.5125628e-04, 3.0150753e-04, 3.5175879e-04,
        4.0201005e-04, 4.5226130e-04, 5.0251256e-04, 5.5276381e-04,
        6.0301507e-04, 6.5326632e-04, 7.0351758e-04, 7.5376884e-04,
        8.0402009e-04, 8.5427135e-04, 9.0452260e-04, 9.5477386e-04,
        1.0050251e-03, 1.0552764e-03, 1.1055276e-03, 1.1557789e-03,
        1.2060301e-03, 1.2562814e-03, 1.3065326e-03, 1.3567839e-03,
        1.4070352e-03, 1.4572864e-03, 1.5075377e-03, 1.5577889e-03,
        1.6080402e-03, 1.6582914e-03, 1.7085427e-03, 1.7587940e-03,
        1.8090452e-03, 1.8592965e-03, 1.9095477e-03, 1.9597989e-03,
        2.0100502e-03, 2.0603016e-03, 2.1105527e-03, 2.1608039e-03,
        2.2110553e-03, 2.2613066e-03, 2.3115578e-03, 2.3618089e-03,
        2.4120603e-03, 2.4623116e-03],
       [2.5125628e-03, 2.5628139e-03, 2.6130653e-03, 2.6633167e-03,
        2.7135678e-03, 2.7638189e-03, 2.8140703e-03, 2.8643217e-03,
        2

Nested pmap

In [127]:
arr = jnp.array(range(4)).reshape(2,2)
arr

Array([[0, 1],
       [2, 3]], dtype=int32)

In [128]:
n = jax.pmap(
    jax.pmap(
        lambda x: x/jax.lax.psum(x, axis_name=('rows','cols')),
        axis_name='cols'
    ),
    axis_name='rows')

In [129]:
n(arr)

Array([[0.        , 0.16666667],
       [0.33333334, 0.5       ]], dtype=float32)

In [130]:
jnp.sum(n(arr))

Array(1., dtype=float32)

The same example using the decorator style

In [163]:
from functools import partial

In [164]:
@partial(jax.pmap, axis_name='rows')
@partial(jax.pmap, axis_name='cols')
def n(x):
  return x/jax.lax.psum(x, axis_name=('rows','cols'))

In [165]:
jnp.sum(n(arr))

Array(1., dtype=float32)

## Data-parallel neural network training

### Preparing data

In [166]:
import tensorflow as tf
import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

data, info = tfds.load(name="mnist",
                       data_dir=data_dir,
                       as_supervised=True,
                       with_info=True)

data_train = data['train']
data_test  = data['test']

2024-08-11 23:43:56.928815: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-11 23:43:56.950628: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-11 23:43:56.957397: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [167]:
HEIGHT = 28
WIDTH  = 28
CHANNELS = 1
NUM_PIXELS = HEIGHT * WIDTH * CHANNELS
NUM_LABELS = info.features['label'].num_classes
NUM_DEVICES = jax.device_count()
BATCH_SIZE  = 32

In [168]:
def preprocess(img, label):
  """Resize and preprocess images."""
  return (tf.cast(img, tf.float32)/255.0), label

train_data = tfds.as_numpy(
    data_train.map(preprocess).batch(NUM_DEVICES*BATCH_SIZE).prefetch(1)
)
test_data  = tfds.as_numpy(
    data_test.map(preprocess).batch(NUM_DEVICES*BATCH_SIZE).prefetch(1)
)

In [169]:
len(train_data)

469

### Preparing MLP

In [170]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, value_and_grad
from jax import random
from jax.nn import swish, logsumexp, one_hot

In [171]:
LAYER_SIZES = [28*28, 512, 10]
PARAM_SCALE = 0.01

In [172]:
def init_network_params(sizes, key=random.PRNGKey(0), scale=1e-2):
  """Initialize all layers for a fully-connected neural network with given sizes"""

  def random_layer_params(m, n, key, scale=1e-2):
    """A helper function to randomly initialize weights and biases of a dense layer"""
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k, scale) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

init_params = init_network_params(LAYER_SIZES, random.PRNGKey(0), scale=PARAM_SCALE)

In [173]:
def predict(params, image):
  """Function for per-example predictions."""
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = swish(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits

batched_predict = vmap(predict, in_axes=(None, 0))


### Loss and update functions

In [174]:
INIT_LR = 1.0
DECAY_RATE = 0.95
DECAY_STEPS = 5
NUM_EPOCHS  = 20

In [175]:
from functools import partial

In [None]:
def loss(params, images, targets):
  """Categorical cross entropy loss function."""
  logits = batched_predict(params, images)
  log_preds = logits - jnp.expand_dims(logsumexp(logits, axis=1), 1)  # logsumexp trick https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
  return -jnp.mean(targets*log_preds)

@partial(jax.pmap, axis_name='devices', in_axes=(None, 0, 0, None), out_axes=(None,0))
def update(params, x, y, epoch_number):
  loss_value, grads = value_and_grad(loss)(params, x, y)
  grads = [(jax.lax.psum(dw, 'devices'), jax.lax.psum(db, 'devices'))
    for dw, db in grads]
  lr = INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)
  return [(w - lr * dw, b - lr * db)
          for (w, b), (dw, db) in zip(params, grads)], loss_value

### Section for debugging purposes

In [None]:
train_data_iter = iter(train_data)
x, y = next(train_data_iter)

In [None]:
x.shape, y.shape

In [None]:
x = jnp.reshape(x, (NUM_DEVICES, BATCH_SIZE, NUM_PIXELS))
y = jnp.reshape(one_hot(y, NUM_LABELS), (NUM_DEVICES, BATCH_SIZE, NUM_LABELS))
x.shape, y.shape

In [None]:
updated_params, loss_value = update(init_params, x, y, 0)

In [None]:
loss_value

### Training loop

In [None]:
@jit
def batch_accuracy(params, images, targets):
  images = jnp.reshape(images, (len(images), NUM_PIXELS))
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == targets)

def accuracy(params, data):
  accs = []
  for images, targets in data:
    accs.append(batch_accuracy(params, images, targets))
  return jnp.mean(jnp.array(accs))

In [None]:
import time

params = init_params
for epoch in range(NUM_EPOCHS):
  start_time = time.time()
  losses = []
  for x, y in train_data:
    num_elements = len(y)
    x = jnp.reshape(x, (NUM_DEVICES, num_elements//NUM_DEVICES, NUM_PIXELS))
    y = jnp.reshape(one_hot(y, NUM_LABELS), (NUM_DEVICES, num_elements//NUM_DEVICES, NUM_LABELS))
    params, loss_value = update(params, x, y, epoch)
    losses.append(jnp.sum(loss_value))
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_data)
  test_acc = accuracy(params, test_data)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set loss {}".format(jnp.mean(jnp.array(losses))))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))