## Using Cloud TPU and a Local runtime in Colab

Before running this notebook, prepare the environment by following these guides:

* Create a Cloud TPU VM: https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#tpu-vms

* Set up Jupyter and connect Colab to a local runtime: https://research.google.com/colaboratory/local-runtimes.html


In [None]:
!pip install 'jax[tpu]>=0.2.16' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Looking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html
Collecting jax[tpu]>=0.2.16
  Downloading jax-0.4.3.tar.gz (1.2 MB)
Collecting numpy>=1.20
  Downloading numpy-1.24.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.3 MB)
Collecting opt_einsum
  Downloading opt_einsum-3.3.0-py3-none-any.whl (65 kB)
Collecting scipy>=1.5
  Downloading scipy-1.10.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (34.5 MB)
Collecting jaxlib==0.4.3
  Downloading jaxlib-0.4.3-cp38-cp38-manylinux2014_x86_64.whl (72.2 MB)
Collecting libtpu-nightly==0.1.dev20230207
  Downloading https://storage.googleapis.com/cloud-tpu-tpuvm-ar

The `rich` package is needed for pretty sharding visualization with `jax.debug.visualize_array_sharding`.

In [None]:
!pip install rich

Collecting rich
  Downloading rich-13.3.1-py3-none-any.whl (239 kB)
Collecting markdown-it-py<3.0.0,>=2.1.0
  Downloading markdown_it_py-2.1.0-py3-none-any.whl (84 kB)
Collecting typing-extensions<5.0,>=4.0.0; python_version < "3.9"
  Downloading typing_extensions-4.5.0-py3-none-any.whl (27 kB)
Collecting mdurl~=0.1
  Downloading mdurl-0.1.2-py3-none-any.whl (10.0 kB)
Installing collected packages: mdurl, markdown-it-py, typing-extensions, rich
Successfully installed markdown-it-py-2.1.0 mdurl-0.1.2 rich-13.3.1 typing-extensions-4.5.0


In [None]:
import jax
import jax.numpy as jnp

In [None]:
jax.local_devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [None]:
jax.__version__

'0.4.3'

## Using Tensor Sharding

In [None]:
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding, NamedSharding

In [None]:
from jax import random

In [None]:
import numpy as np

### Dot example

In [None]:
def dot(v1, v2):
  return jnp.vdot(v1, v2)

In [None]:
rng_key = random.PRNGKey(42)

In [None]:
vs = random.normal(rng_key, shape=(2_000_000,100))
v1s = vs[:1_000_000,:]
v2s = vs[1_000_000:,:]

v1s.shape, v2s.shape

((1000000, 100), (1000000, 100))

In [None]:
jax.debug.visualize_array_sharding(v1s)

### Positional sharding

In [None]:
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,1)))

In [None]:
sharding

PositionalSharding([[{TPU 0}]
                    [{TPU 1}]
                    [{TPU 2}]
                    [{TPU 3}]
                    [{TPU 6}]
                    [{TPU 7}]
                    [{TPU 4}]
                    [{TPU 5}]])

In [None]:
v1sp = jax.device_put(v1s, sharding)

In [None]:
type(v1sp)

jaxlib.xla_extension.Array

In [None]:
jax.debug.visualize_array_sharding(v1sp)

In [None]:
v2sp = jax.device_put(v2s, sharding)

In [None]:
jax.debug.visualize_array_sharding(v2sp)

Both inputs are now sharded row-wise across all eight devices.

In [None]:
d = jax.vmap(dot)(v1sp, v2sp)

In [None]:
d.shape

(1000000,)

In [None]:
jax.debug.visualize_array_sharding(d)

In [None]:
%timeit jax.vmap(dot)(v1sp, v2sp).block_until_ready()

1.71 ms ± 64.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
%timeit jax.vmap(dot)(v1s, v2s).block_until_ready()

4.12 ms ± 24.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
jax.make_jaxpr(jax.vmap(dot))(v1s, v2s)

{ lambda ; a:f32[1000000,100] b:f32[1000000,100]. let
    c:f32[1000000] = dot_general[dimension_numbers=(([1], [1]), ([0], [0]))] a b
  in (c,) }

In [None]:
modules = jax.jit(jax.vmap(dot)).lower(v1s, v2s).compile().compiler_ir()
for hlo in modules:
  print(hlo.to_string())

HloModule jit_dot, is_scheduled=true, entry_computation_layout={(f32[1000000,100]{0,1:T(8,128)},f32[1000000,100]{0,1:T(8,128)})->f32[1000000]{0:T(1024)}}, allow_spmd_sharding_propagation_to_output={true}

%scalar_add_computation (scalar_lhs: f32[], scalar_rhs: f32[]) -> f32[] {
  %scalar_rhs = f32[]{:T(256)} parameter(1)
  %scalar_lhs = f32[]{:T(256)} parameter(0)
  ROOT %add = f32[]{:T(256)} add(f32[]{:T(256)} %scalar_lhs, f32[]{:T(256)} %scalar_rhs)
}

%fused_computation (param_0.2: f32[1000000,100], param_1.2: f32[1000000,100]) -> f32[1000000] {
  %param_0.2 = f32[1000000,100]{0,1:T(8,128)} parameter(0)
  %param_1.2 = f32[1000000,100]{0,1:T(8,128)} parameter(1)
  %multiply.1 = f32[1000000,100]{0,1:T(8,128)} multiply(f32[1000000,100]{0,1:T(8,128)} %param_0.2, f32[1000000,100]{0,1:T(8,128)} %param_1.2)
  %constant.1 = f32[]{:T(256)} constant(0)
  ROOT %reduce.1 = f32[1000000]{0:T(1024)} reduce(f32[1000000,100]{0,1:T(8,128)} %multiply.1, f32[]{:T(256)} %constant.1), dimensions={1}, to_

In [None]:
jax.make_jaxpr(jax.vmap(dot))(v1sp, v2sp)

{ lambda ; a:f32[1000000,100] b:f32[1000000,100]. let
    c:f32[1000000] = dot_general[dimension_numbers=(([1], [1]), ([0], [0]))] a b
  in (c,) }

In [None]:
modules = jax.jit(jax.vmap(dot)).lower(v1sp, v2sp).compile().compiler_ir()
for hlo in modules:
  print(hlo.to_string())

HloModule jit_dot, is_scheduled=true, entry_computation_layout={(f32[125000,100]{0,1:T(8,128)},f32[125000,100]{0,1:T(8,128)})->f32[125000]{0:T(1024)}}, allow_spmd_sharding_propagation_to_output={true}

%scalar_add_computation (scalar_lhs: f32[], scalar_rhs: f32[]) -> f32[] {
  %scalar_rhs = f32[]{:T(256)} parameter(1)
  %scalar_lhs = f32[]{:T(256)} parameter(0)
  ROOT %add = f32[]{:T(256)} add(f32[]{:T(256)} %scalar_lhs, f32[]{:T(256)} %scalar_rhs)
}

%fused_computation (param_0.2: f32[125000,100], param_1.2: f32[125000,100]) -> f32[125000] {
  %param_0.2 = f32[125000,100]{0,1:T(8,128)} parameter(0)
  %param_1.2 = f32[125000,100]{0,1:T(8,128)} parameter(1)
  %multiply.2 = f32[125000,100]{0,1:T(8,128)} multiply(f32[125000,100]{0,1:T(8,128)} %param_0.2, f32[125000,100]{0,1:T(8,128)} %param_1.2)
  %constant.2 = f32[]{:T(256)} constant(0)
  ROOT %reduce.2 = f32[125000]{0:T(1024)} reduce(f32[125000,100]{0,1:T(8,128)} %multiply.2, f32[]{:T(256)} %constant.2), dimensions={1}, to_apply=%scalar

In [None]:
sharding = PositionalSharding(mesh_utils.create_device_mesh((4,2)))

In [None]:
v1sp = jax.device_put(v1s, sharding)
v2sp = jax.device_put(v2s, sharding)

In [None]:
jax.debug.visualize_array_sharding(v1sp)

In [None]:
d = jax.vmap(dot)(v1sp, v2sp)
d.shape

(1000000,)

In [None]:
jax.debug.visualize_array_sharding(d)

In [None]:
%timeit jax.vmap(dot)(v1sp, v2sp).block_until_ready()

1.9 ms ± 32.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
%timeit jax.vmap(dot)(v1s, v2s).block_until_ready()

4.07 ms ± 23.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


Looking at the compiled HLO (note the all-reduce operation):

In [None]:
modules = jax.jit(jax.vmap(dot)).lower(v1sp, v2sp).compile().compiler_ir()
for hlo in modules:
  print(hlo.to_string())

HloModule jit_dot, is_scheduled=true, entry_computation_layout={(f32[250000,50]{0,1:T(8,128)},f32[250000,50]{0,1:T(8,128)})->f32[250000]{0:T(1024)}}, allow_spmd_sharding_propagation_to_output={true}

%scalar_add_computation (scalar_lhs: f32[], scalar_rhs: f32[]) -> f32[] {
  %scalar_rhs = f32[]{:T(256)} parameter(1)
  %scalar_lhs = f32[]{:T(256)} parameter(0)
  ROOT %add = f32[]{:T(256)} add(f32[]{:T(256)} %scalar_lhs, f32[]{:T(256)} %scalar_rhs)
}

%fused_computation (param_0.2: f32[250000,50], param_1.2: f32[250000,50]) -> f32[250000] {
  %param_0.2 = f32[250000,50]{0,1:T(8,128)} parameter(0)
  %param_1.2 = f32[250000,50]{0,1:T(8,128)} parameter(1)
  %multiply.2 = f32[250000,50]{0,1:T(8,128)} multiply(f32[250000,50]{0,1:T(8,128)} %param_0.2, f32[250000,50]{0,1:T(8,128)} %param_1.2)
  %constant.2 = f32[]{:T(256)} constant(0)
  ROOT %reduce.2 = f32[250000]{0:T(1024)} reduce(f32[250000,50]{0,1:T(8,128)} %multiply.2, f32[]{:T(256)} %constant.2), dimensions={1}, to_apply=%scalar_add_compu
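
With the (4,2) mesh each device holds a `[250000,50]` block, that is, only half of the feature axis, so it can compute just a partial dot product per row. The partial sums from the two devices that share a row block then have to be added, which is the job of the all-reduce. The following is a plain NumPy sketch of that arithmetic on small, made-up shapes; it involves no devices and is only meant to illustrate the idea, not to reproduce what XLA emits.

```python
import numpy as np

rng = np.random.default_rng(0)
# Small hypothetical stand-ins for v1s and v2s (8 rows instead of 1,000,000).
v1 = rng.normal(size=(8, 100)).astype(np.float32)
v2 = rng.normal(size=(8, 100)).astype(np.float32)

# Split the feature axis into two halves, mimicking the 2-way mesh axis.
v1_halves = np.split(v1, 2, axis=1)
v2_halves = np.split(v2, 2, axis=1)

# Each "device" computes a partial dot product over its 50 features.
partials = [np.sum(a * b, axis=1) for a, b in zip(v1_halves, v2_halves)]

# Summing the partial results is the step an all-reduce (sum) performs.
combined = partials[0] + partials[1]

reference = np.sum(v1 * v2, axis=1)
print(np.allclose(combined, reference, atol=1e-4))
```

With the (8,1) mesh used earlier every device holds complete rows, so no such cross-device sum is needed.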

### Using replication

In [None]:
sharding = PositionalSharding(mesh_utils.create_device_mesh((4,2)))

v1sp = jax.device_put(v1s, sharding.replicate(axis=1))

jax.debug.visualize_array_sharding(v1sp)

In [None]:
A = random.normal(rng_key, shape=(10000,2000))
B = random.normal(rng_key, shape=(2000,5000))

In [None]:
Ad = jax.device_put(A, sharding.reshape(4, 2).replicate(1))
Bd = jax.device_put(B, sharding.reshape(4, 2).replicate(0))

In [None]:
jax.debug.visualize_array_sharding(Ad)
jax.debug.visualize_array_sharding(Bd)

In [None]:
Cd = jnp.dot(Ad, Bd)

In [None]:
jax.debug.visualize_array_sharding(Cd)

In [None]:
C = A @ B

In [None]:
jax.numpy.array_equal(C,Cd)

Array(True, dtype=bool)

In [None]:
jax.debug.visualize_array_sharding(C)

In [None]:
C.shape, Cd.shape

((10000, 5000), (10000, 5000))

In [None]:
C[12,3], Cd[12,3]

(Array(43.027634, dtype=float32), Array(43.027634, dtype=float32))

In [None]:
Ca = jnp.dot(A, B)

In [None]:
jax.numpy.array_equal(C,Ca)

Array(True, dtype=bool)

In [None]:
jax.debug.visualize_array_sharding(Ca)

In [None]:
d = (Cd - C)

In [None]:
jnp.max(d), jnp.sum(d)

(Array(0., dtype=float32), Array(0., dtype=float32))

In [None]:
%timeit jnp.dot(Ad, Bd).block_until_ready()

2.61 ms ± 53.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%timeit jnp.dot(A, B).block_until_ready()

10.6 ms ± 40.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%timeit (A@B).block_until_ready()

10.6 ms ± 40.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
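
One way to read the placement above: `A` is split into 4 row blocks (and replicated across the 2-way mesh axis), while `B` is split into 2 column blocks (and replicated across the 4-way axis). Under that layout, the device at mesh position (i, j) already holds everything it needs to compute block (i, j) of the product. The NumPy sketch below checks this tiling on small, hypothetical matrices; it is an illustration of the arithmetic, not of how JAX schedules the computation.

```python
import numpy as np

rng = np.random.default_rng(0)
# Small hypothetical stand-ins for A (10000x2000) and B (2000x5000).
A = rng.normal(size=(8, 6))
B = rng.normal(size=(6, 4))

a_blocks = np.split(A, 4, axis=0)  # 4 row blocks, one per mesh row
b_blocks = np.split(B, 2, axis=1)  # 2 column blocks, one per mesh column

# Block (i, j) of the result depends only on row block i and column block j.
c_blocks = [[a @ b for b in b_blocks] for a in a_blocks]

# Reassemble the 4x2 grid of blocks into the full product.
C_tiled = np.block(c_blocks)
print(np.allclose(C_tiled, A @ B))
```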


### Using sharding constraints

In [None]:
from jax import jit

@jit
def distributed_mul(a, b):
  ad = jax.lax.with_sharding_constraint(a, sharding.replicate(1))
  bd = jax.lax.with_sharding_constraint(b, sharding.replicate(0))
  return jnp.dot(ad, bd)

In [None]:
sharding = PositionalSharding(mesh_utils.create_device_mesh((4,2)))

In [None]:
jax.debug.visualize_array_sharding(A)
jax.debug.visualize_array_sharding(B)

In [None]:
d = distributed_mul(A, B)

In [None]:
jax.debug.visualize_array_sharding(d)

In [None]:
d.shape

(10000, 5000)

In [None]:
@jit
def nondistributed_mul(a, b):
  return jnp.dot(a, b)

In [None]:
dn = nondistributed_mul(A, B)

In [None]:
jax.debug.visualize_array_sharding(dn)

In [None]:
%timeit distributed_mul(A, B).block_until_ready()

22 ms ± 99.3 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
%timeit nondistributed_mul(A, B).block_until_ready()

10.6 ms ± 16 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Named sharding

In [None]:
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from jax.sharding import NamedSharding

In [None]:
mesh = Mesh(mesh_utils.create_device_mesh((4,2)), axis_names=('batch', 'features'))
sharding = NamedSharding(mesh, P('batch', 'features'))

In [None]:
v1sp = jax.device_put(v1s, sharding)
v2sp = jax.device_put(v2s, sharding)

jax.debug.visualize_array_sharding(v1sp)

In [None]:
d = jax.vmap(dot)(v1sp, v2sp)
d.shape

(1000000,)
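
With named sharding, a `PartitionSpec` entry names the mesh axis that splits the corresponding array dimension, and `None` (or an omitted entry) leaves that dimension unsplit. The sketch below is an assumption-laden variant of the cell above: it puts all available devices on the `'batch'` axis so that it runs with any device count, and the array `x` is a small placeholder rather than `v1s`.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

n = jax.device_count()
# Assumption: an (n, 1) mesh, so the example does not require exactly 8 devices.
mesh = Mesh(np.array(jax.devices()).reshape(n, 1), axis_names=('batch', 'features'))

# Placeholder data whose row count is divisible by the number of devices.
x = np.zeros((8 * n, 100), dtype=np.float32)

row_sharded = NamedSharding(mesh, P('batch', None))  # rows split, columns whole
replicated = NamedSharding(mesh, P())                # full copy on every device

xs = jax.device_put(x, row_sharded)

# Per-device shard shapes implied by each sharding.
print(tuple(row_sharded.shard_shape(x.shape)))
print(tuple(replicated.shard_shape(x.shape)))
```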

### Device placement policy and errors

Different devices:

In [None]:
sharding_a = PositionalSharding(np.array(jax.devices()[:4]).reshape(4,1))
sharding_b = PositionalSharding(np.array(jax.devices()[4:]).reshape(4,1))

In [None]:
sharding_a

PositionalSharding([[{TPU 0}]
                    [{TPU 1}]
                    [{TPU 2}]
                    [{TPU 3}]])

In [None]:
sharding_b

PositionalSharding([[{TPU 4}]
                    [{TPU 5}]
                    [{TPU 6}]
                    [{TPU 7}]])

In [None]:
v1sp = jax.device_put(v1s, sharding_a)
v2sp = jax.device_put(v2s, sharding_b)

In [None]:
jax.debug.visualize_array_sharding(v1sp)

In [None]:
jax.debug.visualize_array_sharding(v2sp)

In [None]:
d = jax.vmap(dot)(v1sp, v2sp)

ValueError: ignored

Different order:

In [None]:
sharding_a = PositionalSharding(np.array(jax.devices()).reshape(8,1))
sharding_b = PositionalSharding(np.array(jax.devices()[::-1]).reshape(8,1))

In [None]:
sharding_a, sharding_b

(PositionalSharding([[{TPU 0}]
                     [{TPU 1}]
                     [{TPU 2}]
                     [{TPU 3}]
                     [{TPU 4}]
                     [{TPU 5}]
                     [{TPU 6}]
                     [{TPU 7}]]),
 PositionalSharding([[{TPU 7}]
                     [{TPU 6}]
                     [{TPU 5}]
                     [{TPU 4}]
                     [{TPU 3}]
                     [{TPU 2}]
                     [{TPU 1}]
                     [{TPU 0}]]))

In [None]:
v1sp = jax.device_put(v1s, sharding_a)
v2sp = jax.device_put(v2s, sharding_b)

In [None]:
jax.debug.visualize_array_sharding(v1sp)

In [None]:
jax.debug.visualize_array_sharding(v2sp)

In [None]:
d = jax.vmap(dot)(v1sp, v2sp)

ValueError: ignored

In [None]:
d = jax.vmap(dot)(v1sp, v2s)

In [None]:
jax.debug.visualize_array_sharding(d)

## MLP example

### Preparing data

Install these packages if you are starting from a fresh Cloud TPU VM:

In [None]:
!pip install tensorflow

Collecting tensorflow
  Downloading tensorflow-2.11.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (588.3 MB)
Collecting absl-py>=1.0.0
  Downloading absl_py-1.4.0-py3-none-any.whl (126 kB)
Collecting astunparse>=1.6.0
  Downloading astunparse-1.6.3-py2.py3-none-any.whl (12 kB)
Collecting flatbuffers>=2.0
  Downloading flatbuffers-23.1.21-py2.py3-none-any.whl (26 kB)
Collecting gast<=0.4.0,>=0.2.1
  Downloading gast-0.4.0-py3-none-any.whl (9.8 kB)
Collecting google-pasta>=0.1.1
  Downloading google_pasta-0.2.0-py3-none-any.whl (57 kB)
Collecting grpcio<2.0,>=1.24.3
  Downloading grpcio-1.51.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.8 MB)
Collecting h5py>=2.9.0
  Downloading h5py-3.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)

In [None]:
!pip install tensorflow_datasets

Collecting tensorflow_datasets
  Downloading tensorflow_datasets-4.8.2-py3-none-any.whl (5.3 MB)
Collecting dill
  Downloading dill-0.3.6-py3-none-any.whl (110 kB)
Collecting dm-tree
  Downloading dm_tree-0.1.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (152 kB)
Collecting etils[enp,epath]>=0.9.0
  Downloading etils-1.0.0-py3-none-any.whl (146 kB)
Collecting promise
  Downloading promise-2.3.tar.gz (19 kB)
Collecting tensorflow-metadata
  Downloading tensorflow_metadata-1.12.0-py3-none-any.whl (52 kB)
Collecting toml
  Downloading toml-0.10.2-py2.py3-none-any.whl (16 kB)
Collecting tqdm
  Downloading tqdm-4.64.1-py2.py3-none-any.whl (78 kB)

In [None]:
import jax
import tensorflow as tf
import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

data, info = tfds.load(name="mnist",
                       data_dir=data_dir,
                       as_supervised=True, 
                       with_info=True)

data_train = data['train']
data_test  = data['test']

2023-02-15 14:39:10.586863: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
2023-02-15 14:39:11.233063: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
2023-02-15 14:39:11.233197: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
  from .autonotebook import tqdm as notebook_tqdm


Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /tmp/tfds/mnist/3.0.1...


Dl Completed...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 26.07 file/s]
2023-02-15 14:39:13.454293: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
2023-02-15 14:39:13.454326: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)


Dataset mnist downloaded and prepared to /tmp/tfds/mnist/3.0.1. Subsequent calls will reuse this data.


In [None]:
HEIGHT = 28
WIDTH  = 28
CHANNELS = 1
NUM_PIXELS = HEIGHT * WIDTH * CHANNELS 
NUM_LABELS = info.features['label'].num_classes
BATCH_SIZE  = 32 # total 60k samples
NUM_DEVICES = jax.device_count()

In [None]:
def preprocess(img, label):
  """Cast images to float32 and scale pixel values to the [0, 1] range."""
  return (tf.cast(img, tf.float32)/255.0), label

train_data = tfds.as_numpy(
    data_train.map(preprocess).batch(NUM_DEVICES*BATCH_SIZE).prefetch(1)
)
test_data  = tfds.as_numpy(
    data_test.map(preprocess).batch(NUM_DEVICES*BATCH_SIZE).prefetch(1)
)

In [None]:
len(train_data)

235
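
The batch count follows from the settings above: every step consumes `NUM_DEVICES * BATCH_SIZE` examples, so the 60k training images are split into `ceil(60000 / 256)` batches. A quick check of that arithmetic, assuming the 8 devices listed by `jax.local_devices()` earlier (the constant `NUM_TRAIN_EXAMPLES` is introduced here only for illustration):

```python
import math

NUM_TRAIN_EXAMPLES = 60_000  # size of the MNIST training split
BATCH_SIZE = 32              # per-device batch size, as in the notebook
NUM_DEVICES = 8              # assumption: the 8 TPU cores shown earlier

global_batch = NUM_DEVICES * BATCH_SIZE
num_batches = math.ceil(NUM_TRAIN_EXAMPLES / global_batch)
last_batch = NUM_TRAIN_EXAMPLES - (num_batches - 1) * global_batch

print(global_batch, num_batches, last_batch)  # 256 235 96
```

The final, smaller batch of 96 examples is still divisible by 8, so it can be split evenly along the batch axis of an 8-way sharding.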

### Preparing MLP

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, value_and_grad
from jax import random
from jax.nn import swish, logsumexp, one_hot

In [None]:
LAYER_SIZES = [28*28, 512, 10]
PARAM_SCALE = 0.01

In [None]:
def init_network_params(sizes, key=random.PRNGKey(0), scale=1e-2):
  """Initialize all layers for a fully-connected neural network with given sizes"""

  def random_layer_params(m, n, key, scale=1e-2):
    """A helper function to randomly initialize weights and biases of a dense layer""" 
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k, scale) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

init_params = init_network_params(LAYER_SIZES, random.PRNGKey(0), scale=PARAM_SCALE)

In [None]:
def predict(params, image):
  """Function for per-example predictions."""
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = swish(outputs)
  
  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits

batched_predict = vmap(predict, in_axes=(None, 0))

### Loss and update functions

In [None]:
INIT_LR = 1.0
DECAY_RATE = 0.95
DECAY_STEPS = 5
NUM_EPOCHS  = 20
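
The `update` function in this notebook turns these constants into an exponentially decaying learning rate, `INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)`. The short sketch below evaluates that formula on its own; the helper name `learning_rate` is hypothetical and used only here.

```python
INIT_LR = 1.0
DECAY_RATE = 0.95
DECAY_STEPS = 5

def learning_rate(epoch_number):
    """Exponential decay: the rate is multiplied by DECAY_RATE every DECAY_STEPS epochs."""
    return INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)

for epoch in (0, 5, 10, 19):
    print(epoch, round(learning_rate(epoch), 4))
```

So the rate starts at 1.0, drops to 0.95 after 5 epochs and to 0.9025 after 10, and stays above 0.8 for the whole 20-epoch run.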

In [None]:
def loss(params, images, targets):
  """Categorical cross entropy loss function."""
  logits = batched_predict(params, images)
  log_preds = logits - logsumexp(logits) # logsumexp trick https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
  return -jnp.mean(targets*log_preds)

@jit
def update(params, x, y, epoch_number):
  loss_value, grads = value_and_grad(loss)(params, x, y)
  lr = INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)
  return [(w - lr * dw, b - lr * db)
          for (w, b), (dw, db) in zip(params, grads)], loss_value
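
Note that `loss` calls `logsumexp(logits)` without an `axis` argument, so the normalization runs over every entry of the batched logits rather than over each example separately; the losses printed in this notebook come from that version. For comparison, here is a sketch of a per-example variant. It is an alternative, not the notebook's method: the name `per_example_cross_entropy` and the toy inputs are hypothetical, and it would produce different loss values from those reported.

```python
import jax.numpy as jnp
from jax.nn import logsumexp, one_hot

def per_example_cross_entropy(logits, targets):
    """Cross entropy with log-softmax computed independently for each row."""
    log_preds = logits - logsumexp(logits, axis=1, keepdims=True)
    # Sum over classes picks the log-probability of the true label, then average over the batch.
    return -jnp.mean(jnp.sum(targets * log_preds, axis=1))

# Toy batch: 2 examples, 3 classes.
logits = jnp.array([[2.0, 0.5, -1.0],
                    [0.1, 0.2, 3.0]])
targets = one_hot(jnp.array([0, 2]), 3)

log_preds = logits - logsumexp(logits, axis=1, keepdims=True)
row_sums = jnp.exp(log_preds).sum(axis=1)  # each row is a probability distribution
print(row_sums)
print(per_example_cross_entropy(logits, targets))
```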

### Training loop

In [None]:
@jit
def batch_accuracy(params, images, targets):
  images = jnp.reshape(images, (len(images), NUM_PIXELS))
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == targets)

def accuracy(params, data):
  accs = []
  for images, targets in data:
    accs.append(batch_accuracy(params, images, targets))
  return jnp.mean(jnp.array(accs))

**Be careful:** this example does not work with the Colab TPU runtime because of a Colab bug (https://github.com/google/jax/issues/8300); it only works with Cloud TPU.

To use a TPU, build your own runtime on Google Cloud Platform (https://cloud.google.com/tpu/docs/jax-pods) and connect to it using Jupyter (https://research.google.com/colaboratory/local-runtimes.html).

#### 8-way data parallelism

In [None]:
sharding = PositionalSharding(jax.devices()).reshape(8, 1)

In [None]:
import time

params = init_params
for epoch in range(NUM_EPOCHS):
  start_time = time.time()
  losses = []
  for x, y in train_data:
    x = jnp.reshape(x, (len(x), NUM_PIXELS))
    y = one_hot(y, NUM_LABELS)
    x = jax.device_put(x, sharding)
    y = jax.device_put(y, sharding)
    params = jax.device_put(params, sharding.replicate())
    params, loss_value = update(params, x, y, epoch)
    losses.append(jnp.sum(loss_value))
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_data)
  test_acc = accuracy(params, test_data)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set loss {}".format(jnp.mean(jnp.array(losses))))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 2.33 sec
Training set loss 0.6914460062980652
Training set accuracy 0.8648548126220703
Test set accuracy 0.874804675579071
Epoch 1 in 1.10 sec
Training set loss 0.6259850859642029
Training set accuracy 0.8882258534431458
Test set accuracy 0.893847644329071
Epoch 2 in 1.11 sec
Training set loss 0.6164931058883667
Training set accuracy 0.8978224396705627
Test set accuracy 0.901074230670929
Epoch 3 in 1.11 sec
Training set loss 0.6100647449493408
Training set accuracy 0.906527042388916
Test set accuracy 0.908984363079071
Epoch 4 in 1.09 sec
Training set loss 0.6044379472732544
Training set accuracy 0.9142231941223145
Test set accuracy 0.917285144329071
Epoch 5 in 1.14 sec
Training set loss 0.5995732545852661
Training set accuracy 0.9211103320121765
Test set accuracy 0.9248046875
Epoch 6 in 1.17 sec
Training set loss 0.5955986976623535
Training set accuracy 0.9261746406555176
Test set accuracy 0.929394543170929
Epoch 7 in 1.12 sec
Training set loss 0.5924407839775085
Training se

In [None]:
jax.debug.visualize_array_sharding(params[0][0])

#### 4-way data parallelism, 2-way tensor parallelism

In [None]:
sharding = PositionalSharding(jax.devices()).reshape(4, 2)

In [None]:
LAYER_SIZES = [28*28, 10000, 10000, 10]
PARAM_SCALE = 0.01

def init_network_params(sizes, key=random.PRNGKey(0), scale=1e-2):
  """Initialize all layers for a fully-connected neural network with given sizes"""

  def random_layer_params(m, n, key, scale=1e-2):
    """A helper function to randomly initialize weights and biases of a dense layer""" 
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k, scale) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

init_params = init_network_params(LAYER_SIZES, random.PRNGKey(0), scale=PARAM_SCALE)

In [None]:
sharded_params = []
for i, (w, b) in enumerate(init_params):
  print(i, w.shape, b.shape)
  if i in (0, 1):
    # Hidden layers: replicated across the 4-way mesh axis, still split 2 ways.
    w = jax.device_put(w, sharding.replicate(0))
    b = jax.device_put(b, sharding.replicate(0))
  elif i == 2:
    # Output layer: fully replicated on every device.
    w = jax.device_put(w, sharding.replicate())
    b = jax.device_put(b, sharding.replicate())
  sharded_params.append((w, b))


0 (10000, 784) (10000,)
1 (10000, 10000) (10000,)
2 (10, 10000) (10,)


In [None]:
for (w,b) in init_params:
  jax.debug.visualize_array_sharding(w)
  jax.debug.visualize_array_sharding(b)

In [None]:
for (w,b) in init_params:
  jax.debug.visualize_array_sharding(jax.device_put(w, sharding.replicate()))
  jax.debug.visualize_array_sharding(jax.device_put(b, sharding.replicate()))

In [None]:
for (w,b) in sharded_params:
  jax.debug.visualize_array_sharding(w)
  jax.debug.visualize_array_sharding(b)

In [None]:
import time

params = sharded_params
for epoch in range(NUM_EPOCHS):
  start_time = time.time()
  losses = []
  for x, y in train_data:
    x = jnp.reshape(x, (len(x), NUM_PIXELS))
    y = one_hot(y, NUM_LABELS)
    x = jax.device_put(x, sharding.replicate(1))
    y = jax.device_put(y, sharding.replicate(1))
    #params = jax.device_put(params, sharding.replicate())
    params, loss_value = update(params, x, y, epoch)
    losses.append(jnp.sum(loss_value))
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_data)
  test_acc = accuracy(params, test_data)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set loss {}".format(jnp.mean(jnp.array(losses))))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

  jax.debug.visualize_array_sharding(params[0][0])

Epoch 0 in 6.76 sec
Training set loss 0.6637930870056152
Training set accuracy 0.8757923245429993
Test set accuracy 0.8863281607627869


Epoch 1 in 4.04 sec
Training set loss 0.6224856376647949
Training set accuracy 0.890198290348053
Test set accuracy 0.8955078125


Epoch 2 in 4.01 sec
Training set loss 0.6166571974754333
Training set accuracy 0.8960161209106445
Test set accuracy 0.900195300579071


Epoch 3 in 3.99 sec
Training set loss 0.6135416626930237
Training set accuracy 0.8999667167663574
Test set accuracy 0.9034180045127869


Epoch 4 in 4.01 sec
Training set loss 0.6112022399902344
Training set accuracy 0.9033576846122742
Test set accuracy 0.907421886920929


Epoch 5 in 4.02 sec
Training set loss 0.6091797351837158
Training set accuracy 0.9063663482666016
Test set accuracy 0.909375011920929


Epoch 6 in 4.03 sec
Training set loss 0.6072937250137329
Training set accuracy 0.9090425372123718
Test set accuracy 0.9117187857627869


Epoch 7 in 4.04 sec
Training set loss 0.6054385304450989
Training set accuracy 0.911436140537262
Test set accuracy 0.9149414300918579


Epoch 8 in 3.99 sec
Training set loss 0.6035288572311401
Training set accuracy 0.9151983261108398
Test set accuracy 0.918652355670929


Epoch 9 in 4.01 sec
Training set loss 0.6015028953552246
Training set accuracy 0.9181182980537415
Test set accuracy 0.921191394329071


Epoch 10 in 3.99 sec
Training set loss 0.5993684530258179
Training set accuracy 0.9215425252914429
Test set accuracy 0.923632800579071


Epoch 11 in 4.00 sec
Training set loss 0.5972490906715393
Training set accuracy 0.9245678186416626
Test set accuracy 0.9267578125


Epoch 12 in 4.01 sec
Training set loss 0.5953171849250793
Training set accuracy 0.9276872873306274
Test set accuracy 0.92822265625


Epoch 13 in 4.04 sec
Training set loss 0.5936307907104492
Training set accuracy 0.9301639795303345
Test set accuracy 0.930371105670929


Epoch 14 in 4.08 sec
Training set loss 0.592153787612915
Training set accuracy 0.932541012763977
Test set accuracy 0.932812511920929


Epoch 15 in 4.00 sec
Training set loss 0.5908187627792358
Training set accuracy 0.9346630573272705
Test set accuracy 0.934374988079071


Epoch 16 in 4.00 sec
Training set loss 0.5895962715148926
Training set accuracy 0.9365580081939697
Test set accuracy 0.9364258050918579


Epoch 17 in 4.05 sec
Training set loss 0.5884602665901184
Training set accuracy 0.9381703734397888
Test set accuracy 0.9383789300918579


Epoch 18 in 4.00 sec
Training set loss 0.58741295337677
Training set accuracy 0.9402149319648743
Test set accuracy 0.9404296875


Epoch 19 in 3.99 sec
Training set loss 0.5864367485046387
Training set accuracy 0.9415946006774902
Test set accuracy 0.941699206829071


**!!! Do not forget to shut down your Cloud TPU when you are done, or it will keep incurring charges !!!**