## Using Cloud TPU and a Local runtime in Colab

Prepare the environment as described in Appendix C, or follow these guides:

* Creating a Cloud TPU VM: https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#tpu-vms

* Setting up Jupyter and connecting Colab to a local runtime: https://research.google.com/colaboratory/local-runtimes.html


In [None]:
!pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

The `rich` package is needed for the sharding visualizations drawn by `jax.debug.visualize_array_sharding`:

In [None]:
!pip install rich

In [1]:
import jax
import jax.numpy as jnp

In [2]:
jax.local_devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

In [3]:
jax.__version__

'0.4.31'
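The rest of this section assumes the four TPU cores listed above. If you want to try the sharding cells without a Cloud TPU, one option (an assumption, not part of the original setup) is to ask XLA to expose several virtual CPU devices. The flag has to be set before JAX is imported and initialized:

```python
import os

# Assumption: a CPU-only machine. Ask XLA to expose 4 virtual host devices.
# The flag is read when JAX initializes its backend, so set it before importing jax.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

import jax

devices = jax.devices()
print(jax.default_backend(), len(devices))
for d in devices:
    print(d.id, d.platform)
```

On a TPU VM the flag only affects the CPU backend, so the TPU cores are still listed.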

## Using Tensor Sharding

In [4]:
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding, NamedSharding

In [5]:
from jax import random

In [6]:
import numpy as np

### Dot example

In [7]:
def dot(v1, v2):
  return jnp.vdot(v1, v2)

In [8]:
rng_key = random.PRNGKey(42)

In [9]:
vs = random.normal(rng_key, shape=(8_000,10_000))
v1s = vs[:4_000,:]
v2s = vs[4_000:,:]

v1s.shape, v2s.shape

((4000, 10000), (4000, 10000))

In [10]:
jax.debug.visualize_array_sharding(v1s)

### Positional sharding

In [11]:
sharding = PositionalSharding(mesh_utils.create_device_mesh((4,1)))

In [12]:
sharding

PositionalSharding([[{TPU 0}]
                    [{TPU 2}]
                    [{TPU 1}]
                    [{TPU 3}]], shape=(4, 1))

In [13]:
v1sp = jax.device_put(v1s, sharding)

In [15]:
type(v1s)

jaxlib.xla_extension.ArrayImpl

In [14]:
type(v1sp)

jaxlib.xla_extension.ArrayImpl

In [16]:
jax.debug.visualize_array_sharding(v1sp)

In [17]:
v2sp = jax.device_put(v2s, sharding)

In [18]:
jax.debug.visualize_array_sharding(v2sp)

Both inputs are now split row-wise across all four devices.
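The same row-wise split can also be written with `NamedSharding`, which more recent JAX releases recommend in place of `PositionalSharding`, and each device's piece can be inspected through `addressable_shards`. A minimal sketch, assuming four devices (simulated on CPU if no TPU is attached), an illustrative mesh axis name `"rows"`, and columns scaled down from 10,000 to 1,000:

```python
import os
# Assumption: no TPU attached; expose 4 virtual CPU devices (must precede `import jax`).
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1D mesh over 4 devices; "rows" is an illustrative axis name.
mesh = Mesh(np.array(jax.devices()[:4]), axis_names=("rows",))
# Split axis 0 across "rows" and keep axis 1 whole: the layout of a (4, 1) PositionalSharding.
row_sharding = NamedSharding(mesh, P("rows", None))

# Scaled down from (4_000, 10_000) to keep the example light.
v1s = jax.random.normal(jax.random.PRNGKey(42), (4_000, 1_000))
v1sp = jax.device_put(v1s, row_sharding)

for shard in v1sp.addressable_shards:
    # Each device holds 1000 consecutive rows and all of the columns.
    print(shard.device, shard.index, shard.data.shape)
```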

In [19]:
d = jax.vmap(dot)(v1sp, v2sp)

In [20]:
d.shape

(4000,)

In [21]:
jax.debug.visualize_array_sharding(d)

In [22]:
%timeit jax.vmap(dot)(v1sp, v2sp).block_until_ready()

570 μs ± 3.61 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [23]:
%timeit jax.vmap(dot)(v1s, v2s).block_until_ready()

738 μs ± 2.69 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [24]:
jax.make_jaxpr(jax.vmap(dot))(v1s, v2s)

{ lambda ; a:f32[4000,10000] b:f32[4000,10000]. let
    c:f32[4000] = dot_general[
      dimension_numbers=(([1], [1]), ([0], [0]))
      preferred_element_type=float32
    ] a b
  in (c,) }

In [29]:
modules = jax.jit(jax.vmap(dot)).lower(v1s, v2s).compile().as_text()
print(modules)

HloModule jit_dot, is_scheduled=true, entry_computation_layout={(f32[4000,10000]{1,0:T(8,128)}, f32[4000,10000]{1,0:T(8,128)})->f32[4000]{0:T(1024)}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}

%scalar_add_computation (scalar_lhs: f32[], scalar_rhs: f32[]) -> f32[] {
  %scalar_rhs = f32[]{:T(128)} parameter(1)
  %scalar_lhs = f32[]{:T(128)} parameter(0)
  ROOT %add = f32[]{:T(128)} add(f32[]{:T(128)} %scalar_lhs, f32[]{:T(128)} %scalar_rhs)
}

%fused_computation (param_0.2: f32[4000,10000], param_1.2: f32[4000,10000]) -> f32[4000] {
  %param_0.2 = f32[4000,10000]{1,0:T(8,128)} parameter(0)
  %param_1.2 = f32[4000,10000]{1,0:T(8,128)} parameter(1)
  %multiply.1 = f32[4000,10000]{1,0:T(8,128)} multiply(f32[4000,10000]{1,0:T(8,128)} %param_0.2, f32[4000,10000]{1,0:T(8,128)} %param_1.2)
  %constant.1 = f32[]{:T(128)} constant(0)
  ROOT %reduce.1 = f32[4000]{0:T(1024)} reduce(f32[4000,10000]{1,0:T(8,128)} %multiply.1, f32[]{:

In [26]:
jax.make_jaxpr(jax.vmap(dot))(v1sp, v2sp)

{ lambda ; a:f32[4000,10000] b:f32[4000,10000]. let
    c:f32[4000] = dot_general[
      dimension_numbers=(([1], [1]), ([0], [0]))
      preferred_element_type=float32
    ] a b
  in (c,) }

In [30]:
modules = jax.jit(jax.vmap(dot)).lower(v1sp, v2sp).compile().as_text()
print(modules)

HloModule jit_dot, is_scheduled=true, entry_computation_layout={(f32[1000,10000]{1,0:T(8,128)}, f32[1000,10000]{1,0:T(8,128)})->f32[1000]{0:T(1024)}}, allow_spmd_sharding_propagation_to_parameters={false,false}, allow_spmd_sharding_propagation_to_output={true}, num_partitions=4

%scalar_add_computation (scalar_lhs: f32[], scalar_rhs: f32[]) -> f32[] {
  %scalar_rhs = f32[]{:T(128)} parameter(1)
  %scalar_lhs = f32[]{:T(128)} parameter(0)
  ROOT %add = f32[]{:T(128)} add(f32[]{:T(128)} %scalar_lhs, f32[]{:T(128)} %scalar_rhs)
}

%fused_computation (param_0.2: f32[1000,10000], param_1.2: f32[1000,10000]) -> f32[1000] {
  %param_0.2 = f32[1000,10000]{1,0:T(8,128)} parameter(0)
  %param_1.2 = f32[1000,10000]{1,0:T(8,128)} parameter(1)
  %multiply.2 = f32[1000,10000]{1,0:T(8,128)} multiply(f32[1000,10000]{1,0:T(8,128)} %param_0.2, f32[1000,10000]{1,0:T(8,128)} %param_1.2)
  %constant.2 = f32[]{:T(128)} constant(0)
  ROOT %reduce.2 = f32[1000]{0:T(1024)} reduce(f32[1000,10000]{1,0:T(8,128)} 

### 2D mesh example


In [31]:
sharding = PositionalSharding(mesh_utils.create_device_mesh((2,2)))

In [32]:
v1sp = jax.device_put(v1s, sharding)
v2sp = jax.device_put(v2s, sharding)

In [33]:
jax.debug.visualize_array_sharding(v1sp)

In [34]:
d = jax.vmap(dot)(v1sp, v2sp)
d.shape

(4000,)

In [35]:
jax.debug.visualize_array_sharding(d)

In [36]:
%timeit jax.vmap(dot)(v1sp, v2sp).block_until_ready()

583 μs ± 6.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [37]:
%timeit jax.vmap(dot)(v1s, v2s).block_until_ready()

741 μs ± 5.48 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


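Looking at the compiled HLO: on the 2x2 mesh each device holds a `[2000, 5000]` block, so the contracted axis is split as well, and the per-device partial dot products have to be summed across devices with an all-reduce operation.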

In [38]:
modules = jax.jit(jax.vmap(dot)).lower(v1sp, v2sp).compile().as_text()
print(modules)

HloModule jit_dot, is_scheduled=true, entry_computation_layout={(f32[2000,5000]{1,0:T(8,128)}, f32[2000,5000]{1,0:T(8,128)})->f32[2000]{0:T(1024)}}, allow_spmd_sharding_propagation_to_parameters={false,false}, allow_spmd_sharding_propagation_to_output={true}, num_partitions=4

%scalar_add_computation.clone (scalar_lhs.1: f32[], scalar_rhs.1: f32[]) -> f32[] {
  %scalar_rhs.1 = f32[]{:T(128)} parameter(1)
  %scalar_lhs.1 = f32[]{:T(128)} parameter(0)
  ROOT %add.1 = f32[]{:T(128)} add(f32[]{:T(128)} %scalar_lhs.1, f32[]{:T(128)} %scalar_rhs.1)
}

%scalar_add_computation (scalar_lhs: f32[], scalar_rhs: f32[]) -> f32[] {
  %scalar_rhs = f32[]{:T(128)} parameter(1)
  %scalar_lhs = f32[]{:T(128)} parameter(0)
  ROOT %add = f32[]{:T(128)} add(f32[]{:T(128)} %scalar_lhs, f32[]{:T(128)} %scalar_rhs)
}

%fused_computation (param_0.2: f32[2000,5000], param_1.2: f32[2000,5000]) -> f32[2000] {
  %param_0.2 = f32[2000,5000]{1,0:T(8,128)} parameter(0)
  %param_1.2 = f32[2000,5000]{1,0:T(8,128)} para
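The collective can be found programmatically by searching the compiled HLO text. A sketch under the same assumptions as before (four devices, possibly simulated on CPU; scaled-down shapes; illustrative axis names `"x"` and `"y"`). Depending on how XLA propagates the output sharding, the reduction may appear as an all-reduce or as a reduce-scatter, so the check accepts either:

```python
import os
# Assumption: no TPU attached; expose 4 virtual CPU devices (must precede `import jax`).
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def dot(v1, v2):
    return jnp.vdot(v1, v2)

# 2x2 mesh: "x" splits the rows, "y" splits the contracted (column) axis.
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ("x", "y"))
sharding = NamedSharding(mesh, P("x", "y"))

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
v1 = jax.random.normal(k1, (400, 1_000))
v2 = jax.random.normal(k2, (400, 1_000))
v1sp = jax.device_put(v1, sharding)
v2sp = jax.device_put(v2, sharding)

hlo = jax.jit(jax.vmap(dot)).lower(v1sp, v2sp).compile().as_text()
# Each device sees only half of every row, so its partial sums must be combined across "y".
has_collective = "all-reduce" in hlo or "reduce-scatter" in hlo
print("collective in HLO:", has_collective)

d = jax.vmap(dot)(v1sp, v2sp)
expected = jnp.sum(v1 * v2, axis=1)  # single-device reference
print(d.shape, bool(jnp.allclose(d, expected, rtol=1e-4, atol=1e-3)))
```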

### Using replication

In [128]:
sharding = PositionalSharding(mesh_utils.create_device_mesh((2,2)))

v1sp = jax.device_put(v1s, sharding.replicate(axis=1))

jax.debug.visualize_array_sharding(v1sp)
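`sharding.replicate(axis=1)` collapses the second mesh axis, so each row block is copied to both devices along that axis instead of being split further. With `NamedSharding` the same layout is written by leaving that mesh axis out of the `PartitionSpec`. A scaled-down sketch, assuming four devices and illustrative axis names `"x"` and `"y"`:

```python
import os
# Assumption: no TPU attached; expose 4 virtual CPU devices (must precede `import jax`).
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ("x", "y"))
# Split rows over "x"; leaving "y" out of the spec replicates along that mesh axis.
row_split_replicated = NamedSharding(mesh, P("x", None))

v1s = jax.random.normal(jax.random.PRNGKey(42), (400, 1_000))
v1sp = jax.device_put(v1s, row_split_replicated)

shards = v1sp.addressable_shards
# Normalize each shard's row slice to (start, stop) so the copies can be compared.
row_blocks = {s.index[0].indices(v1s.shape[0])[:2] for s in shards}
print(len(shards), sorted(row_blocks))  # 4 shards, but only 2 distinct row blocks
print({s.data.shape for s in shards})   # every copy is (200, 1000)
```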

In [129]:
A = random.normal(rng_key, shape=(10000,2000))
B = random.normal(rng_key, shape=(2000,5000))

In [134]:
# Split A by rows and B by columns, so each device computes one block of C = A @ B
Ad = jax.device_put(A, sharding.replicate(1))
Bd = jax.device_put(B, sharding.replicate(0))

In [135]:
jax.debug.visualize_array_sharding(Ad)
jax.debug.visualize_array_sharding(Bd)

In [136]:
Cd = jnp.dot(Ad, Bd)

In [137]:
jax.debug.visualize_array_sharding(Cd)

In [45]:
C = A @ B

In [46]:
jax.numpy.array_equal(C,Cd)

Array(False, dtype=bool)

In [47]:
jax.debug.visualize_array_sharding(C)

In [48]:
C.shape, Cd.shape

((10000, 5000), (10000, 5000))

In [49]:
C[12,3], Cd[12,3]

(Array(43.027626, dtype=float32), Array(43.027634, dtype=float32))

In [50]:
Ca = jnp.dot(A, B)

In [51]:
jax.numpy.array_equal(C,Ca)

Array(True, dtype=bool)

In [52]:
jax.debug.visualize_array_sharding(Ca)

In [53]:
d = (Cd - C)

In [54]:
jnp.max(d), jnp.sum(d)

(Array(6.1035156e-05, dtype=float32), Array(0.01788047, dtype=float32))
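`array_equal` is too strict here: the sharded and single-device matmuls can tile and accumulate the products in a different order, so the float32 results may differ in the last bits. Note also that `jnp.max(d)` and `jnp.sum(d)` use the signed difference, so positive and negative errors can cancel. A sketch of a tolerance-based check on the absolute error, with scaled-down shapes and four devices assumed:

```python
import os
# Assumption: no TPU attached; expose 4 virtual CPU devices (must precede `import jax`).
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ("x", "y"))

kA, kB = jax.random.split(jax.random.PRNGKey(42))
A = jax.random.normal(kA, (1_000, 200))  # scaled down from (10000, 2000)
B = jax.random.normal(kB, (200, 500))    # scaled down from (2000, 5000)

Ad = jax.device_put(A, NamedSharding(mesh, P("x", None)))  # rows split
Bd = jax.device_put(B, NamedSharding(mesh, P(None, "y")))  # columns split

C = A @ B
Cd = jnp.dot(Ad, Bd)

# Use the absolute error: signed differences can cancel in max/sum.
max_abs_err = jnp.max(jnp.abs(Cd - C))
print("max |Cd - C|:", max_abs_err)
print("allclose:", bool(jnp.allclose(C, Cd, rtol=1e-5, atol=1e-4)))
```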

In [55]:
%timeit jnp.dot(Ad, Bd).block_until_ready()

400 μs ± 2.26 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [56]:
%timeit jnp.dot(A, B).block_until_ready()

958 μs ± 4.32 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [57]:
%timeit (A@B).block_until_ready()

956 μs ± 3.77 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [58]:
modules = jax.jit(jnp.dot).lower(A, B).compile().as_text()
print(modules)

HloModule jit_dot, is_scheduled=true, entry_computation_layout={(f32[10000,2000]{0,1:T(8,128)}, f32[2000,5000]{1,0:T(8,128)})->f32[10000,5000]{0,1:T(8,128)}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}

%bitcast_fusion (bf16input: f32[10000,2000]) -> f32[10000,2000] {
  %bf16input = f32[10000,2000]{0,1:T(8,128)S(3)} parameter(0)
  ROOT %bitcast = f32[10000,2000]{0,1:T(8,128)} bitcast(f32[10000,2000]{0,1:T(8,128)S(3)} %bf16input)
}

%bitcast_fusion.1 (bf16input.1: f32[2000,5000]) -> f32[2000,5000] {
  %bf16input.1 = f32[2000,5000]{1,0:T(8,128)} parameter(0)
  ROOT %bitcast.1 = f32[2000,5000]{1,0:T(8,128)} bitcast(f32[2000,5000]{1,0:T(8,128)} %bf16input.1)
}

%fused_computation (param_0: f32[10000,2000], param_1: f32[2000,5000]) -> f32[10000,5000] {
  %param_0 = f32[10000,2000]{0,1:T(8,128)S(3)} parameter(0)
  %fusion.1 = f32[10000,2000]{0,1:T(8,128)} fusion(f32[10000,2000]{0,1:T(8,128)S(3)} %param_0), kind=kLoop, calls=%bi

In [59]:
modules = jax.jit(jnp.dot).lower(Ad, Bd).compile().as_text()
print(modules)

HloModule jit_dot, is_scheduled=true, entry_computation_layout={(f32[5000,2000]{1,0:T(8,128)}, f32[2000,2500]{1,0:T(8,128)})->f32[5000,2500]{1,0:T(8,128)}}, allow_spmd_sharding_propagation_to_parameters={false,false}, allow_spmd_sharding_propagation_to_output={true}, num_partitions=4

%bitcast_fusion (bf16input: f32[5000,2000]) -> f32[5000,2000] {
  %bf16input = f32[5000,2000]{1,0:T(8,128)S(3)} parameter(0)
  ROOT %bitcast = f32[5000,2000]{1,0:T(8,128)} bitcast(f32[5000,2000]{1,0:T(8,128)S(3)} %bf16input)
}

%bitcast_fusion.1 (bf16input.1: f32[2000,2500]) -> f32[2000,2500] {
  %bf16input.1 = f32[2000,2500]{1,0:T(8,128)} parameter(0)
  ROOT %bitcast.1 = f32[2000,2500]{1,0:T(8,128)} bitcast(f32[2000,2500]{1,0:T(8,128)} %bf16input.1)
}

%fused_computation (param_0: f32[5000,2000], param_1: f32[2000,2500]) -> f32[5000,2500] {
  %param_0 = f32[5000,2000]{1,0:T(8,128)S(3)} parameter(0)
  %fusion.1 = f32[5000,2000]{1,0:T(8,128)} fusion(f32[5000,2000]{1,0:T(8,128)S(3)} %param_0), kind=kLoop, c
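With `A` split by rows and `B` split by columns, as in the module above, every device already holds the full rows and columns it needs for its own block of the product, so no collective communication should be required; this is a plausible reason for the speedup over the single-device matmul. A sketch that checks the compiled HLO for collectives and the per-device output blocks, with scaled-down shapes and four devices assumed:

```python
import os
# Assumption: no TPU attached; expose 4 virtual CPU devices (must precede `import jax`).
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ("x", "y"))
rows = NamedSharding(mesh, P("x", None))  # like sharding.replicate(1)
cols = NamedSharding(mesh, P(None, "y"))  # like sharding.replicate(0)

kA, kB = jax.random.split(jax.random.PRNGKey(0))
Ad = jax.device_put(jax.random.normal(kA, (1_000, 200)), rows)
Bd = jax.device_put(jax.random.normal(kB, (200, 500)), cols)

hlo = jax.jit(jnp.dot).lower(Ad, Bd).compile().as_text()
# Each device computes its block from local data only.
print("all-reduce" in hlo, "all-gather" in hlo)

Cd = jnp.dot(Ad, Bd)
print({s.data.shape for s in Cd.addressable_shards})  # one (500, 250) block per device
```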

### Using sharding constraints

In [60]:
from jax import jit
from functools import partial

@partial(jax.jit, static_argnums=2)
def distributed_mul(a, b, sharding):
  ad = jax.lax.with_sharding_constraint(a, sharding.replicate(1))
  bd = jax.lax.with_sharding_constraint(b, sharding.replicate(0))
  return jnp.dot(ad, bd)

In [61]:
sharding = PositionalSharding(mesh_utils.create_device_mesh((2,2)))

In [62]:
jax.debug.visualize_array_sharding(A)
jax.debug.visualize_array_sharding(B)

In [63]:
d = distributed_mul(A, B, sharding)

In [64]:
jax.debug.visualize_array_sharding(d)

In [65]:
d.shape

(10000, 5000)

In [66]:
@jit
def nondistributed_mul(a, b):
  return jnp.dot(a, b)

In [67]:
dn = nondistributed_mul(A, B)

In [68]:
jax.debug.visualize_array_sharding(dn)

In [71]:
%timeit distributed_mul(A, B, sharding).block_until_ready()

3.99 ms ± 33.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [72]:
%timeit nondistributed_mul(A, B).block_until_ready()

953 μs ± 351 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [73]:
modules = distributed_mul.lower(A, B, sharding).compile().as_text()
print(modules)

HloModule jit_distributed_mul, is_scheduled=true, entry_computation_layout={(f32[5000,2000]{1,0:T(8,128)}, f32[2000,2500]{1,0:T(8,128)})->f32[5000,2500]{1,0:T(8,128)}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}, num_partitions=4

%bitcast_fusion (bf16input: f32[5000,2000]) -> f32[5000,2000] {
  %bf16input = f32[5000,2000]{1,0:T(8,128)S(3)} parameter(0)
  ROOT %bitcast = f32[5000,2000]{1,0:T(8,128)} bitcast(f32[5000,2000]{1,0:T(8,128)S(3)} %bf16input)
}

%bitcast_fusion.1 (bf16input.1: f32[2000,2500]) -> f32[2000,2500] {
  %bf16input.1 = f32[2000,2500]{1,0:T(8,128)} parameter(0)
  ROOT %bitcast.1 = f32[2000,2500]{1,0:T(8,128)} bitcast(f32[2000,2500]{1,0:T(8,128)} %bf16input.1)
}

%fused_computation (param_0: f32[5000,2000], param_1: f32[2000,2500]) -> f32[5000,2500] {
  %param_0 = f32[5000,2000]{1,0:T(8,128)S(3)} parameter(0)
  %fusion.1 = f32[5000,2000]{1,0:T(8,128)} fusion(f32[5000,2000]{1,0:T(8,128)S(3)} %param_0), kin
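In this run the constrained version is about four times slower than the plain `jit` version. A likely reason is that `A` and `B` live on a single device while the compiled module above expects per-device inputs of shape `[5000,2000]` and `[2000,2500]`, so every call first has to redistribute the full matrices. One way to check this is to place the inputs once, outside the timed call. A sketch with `NamedSharding` and scaled-down shapes, assuming four devices; the shardings are captured from the enclosing scope rather than passed as a static argument:

```python
import os
# Assumption: no TPU attached; expose 4 virtual CPU devices (must precede `import jax`).
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ("x", "y"))
rows = NamedSharding(mesh, P("x", None))  # like sharding.replicate(1)
cols = NamedSharding(mesh, P(None, "y"))  # like sharding.replicate(0)

@jax.jit
def distributed_mul(a, b):
    a = jax.lax.with_sharding_constraint(a, rows)
    b = jax.lax.with_sharding_constraint(b, cols)
    return jnp.dot(a, b)

kA, kB = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(kA, (1_000, 200))
B = jax.random.normal(kB, (200, 500))

# Pay the redistribution cost once, before any timing loop.
Ad = jax.device_put(A, rows)
Bd = jax.device_put(B, cols)

C = distributed_mul(Ad, Bd).block_until_ready()
print(C.shape, bool(jnp.allclose(C, A @ B, rtol=1e-5, atol=1e-4)))
```

Timing `distributed_mul(Ad, Bd)` instead of `distributed_mul(A, B)` then separates the cost of the matmul from the cost of moving the data.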

### Named sharding

In [74]:
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from jax.sharding import NamedSharding

In [75]:
mesh = Mesh(mesh_utils.create_device_mesh((2,2)), axis_names=('batch', 'features'))
sharding = NamedSharding(mesh, P('batch', 'features'))

In [76]:
v1sp = jax.device_put(v1s, sharding)
v2sp = jax.device_put(v2s, sharding)

jax.debug.visualize_array_sharding(v1sp)

In [77]:
d = jax.vmap(dot)(v1sp, v2sp)
d.shape

(4000,)
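Each `PositionalSharding` operation used earlier corresponds to a `PartitionSpec` on this named mesh. A sketch that prints the per-device block shape for the `(4000, 10000)` inputs via `Sharding.shard_shape`, without allocating any arrays; four devices are assumed, and the mapping in the comments follows the 2x2 examples above:

```python
import os
# Assumption: no TPU attached; expose 4 virtual CPU devices (must precede `import jax`).
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), axis_names=("batch", "features"))
global_shape = (4_000, 10_000)

specs = {
    "split both axes": P("batch", "features"),  # PositionalSharding on the 2x2 mesh
    "split rows only": P("batch", None),        # sharding.replicate(1)
    "split columns only": P(None, "features"),  # sharding.replicate(0)
    "fully replicated": P(),                    # sharding.replicate()
}
shapes = {name: NamedSharding(mesh, spec).shard_shape(global_shape)
          for name, spec in specs.items()}
for name, shape in shapes.items():
    print(f"{name:20s} -> {shape}")
# split both axes      -> (2000, 5000)
# split rows only      -> (2000, 10000)
# split columns only   -> (4000, 5000)
# fully replicated     -> (4000, 10000)
```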

### Device placement policy and errors

Arrays placed on different sets of devices:

In [78]:
sharding_a = PositionalSharding(np.array(jax.devices()[:2]).reshape(2,1))
sharding_b = PositionalSharding(np.array(jax.devices()[2:]).reshape(2,1))

In [79]:
sharding_a

PositionalSharding([[{TPU 0}]
                    [{TPU 1}]], shape=(2, 1))

In [80]:
sharding_b

PositionalSharding([[{TPU 2}]
                    [{TPU 3}]], shape=(2, 1))

In [81]:
v1sp = jax.device_put(v1s, sharding_a)
v2sp = jax.device_put(v2s, sharding_b)

In [82]:
jax.debug.visualize_array_sharding(v1sp)

In [83]:
jax.debug.visualize_array_sharding(v2sp)

In [84]:
d = jax.vmap(dot)(v1sp, v2sp)

ValueError: Received incompatible devices for jitted computation. Got ARG_SHARDING with device ids [0, 1] on platform TPU and ARG_SHARDING with device ids [2, 3] on platform TPU
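JAX refuses to guess where the computation should run, because both inputs are committed to different sets of devices. Explicitly moving one operand onto the other's sharding with `jax.device_put` resolves the conflict. A sketch on four devices, using two hand-built 1D meshes and scaled-down shapes:

```python
import os
# Assumption: no TPU attached; expose 4 virtual CPU devices (must precede `import jax`).
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()[:4]
sharding_a = NamedSharding(Mesh(np.array(devices[:2]), ("x",)), P("x", None))
sharding_b = NamedSharding(Mesh(np.array(devices[2:]), ("x",)), P("x", None))

def dot(v1, v2):
    return jnp.vdot(v1, v2)

batched_dot = jax.jit(jax.vmap(dot))

k1, k2 = jax.random.split(jax.random.PRNGKey(42))
v1sp = jax.device_put(jax.random.normal(k1, (400, 1_000)), sharding_a)
v2sp = jax.device_put(jax.random.normal(k2, (400, 1_000)), sharding_b)

try:
    batched_dot(v1sp, v2sp)
    raised = False
except ValueError as err:
    raised = True
    print("ValueError:", str(err)[:80])

# Fix: move one operand onto the other operand's devices and layout.
v2sp_moved = jax.device_put(v2sp, v1sp.sharding)
d = batched_dot(v1sp, v2sp_moved)
print(d.shape)
```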

The same devices in a different order:

In [85]:
sharding_a = PositionalSharding(np.array(jax.devices()).reshape(4,1))
sharding_b = PositionalSharding(np.array(jax.devices()[::-1]).reshape(4,1))

In [86]:
sharding_a, sharding_b

(PositionalSharding([[{TPU 0}]
                     [{TPU 1}]
                     [{TPU 2}]
                     [{TPU 3}]], shape=(4, 1)),
 PositionalSharding([[{TPU 3}]
                     [{TPU 2}]
                     [{TPU 1}]
                     [{TPU 0}]], shape=(4, 1)))

In [87]:
v1sp = jax.device_put(v1s, sharding_a)
v2sp = jax.device_put(v2s, sharding_b)

In [88]:
jax.debug.visualize_array_sharding(v1sp)

In [89]:
jax.debug.visualize_array_sharding(v2sp)

In [90]:
d = jax.vmap(dot)(v1sp, v2sp)

ValueError: Received incompatible devices for jitted computation. Got ARG_SHARDING with device ids [0, 1, 2, 3] on platform TPU and ARG_SHARDING with device ids [3, 2, 1, 0] on platform TPU

In [91]:
d = jax.vmap(dot)(v1sp, v2s)

In [92]:
jax.debug.visualize_array_sharding(d)
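This last call succeeds because `v2s` was never placed on a device explicitly. JAX treats such *uncommitted* arrays as free to move and transfers `v2s` to match the committed, sharded `v1sp`; arrays become committed when they are placed with `jax.device_put` on a device or sharding, or computed from committed arrays. A small sketch, assuming four devices:

```python
import os
# Assumption: no TPU attached; expose 4 virtual CPU devices (must precede `import jax`).
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]), ("x",))
sharding = NamedSharding(mesh, P("x", None))

k1, k2 = jax.random.split(jax.random.PRNGKey(42))
v1sp = jax.device_put(jax.random.normal(k1, (400, 1_000)), sharding)  # committed, sharded
v2s = jax.random.normal(k2, (400, 1_000))                              # uncommitted

d = jax.jit(jax.vmap(lambda a, b: jnp.vdot(a, b)))(v1sp, v2s)
# The result follows the committed input onto all four devices.
print(d.shape, d.sharding)
```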

## MLP example

### Preparing data

Install these packages if you are working on a freshly created cloud machine:

In [None]:
!pip install tensorflow

In [None]:
!pip install tensorflow_datasets

In [93]:
import jax
import tensorflow as tf
import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

data, info = tfds.load(name="mnist",
                       data_dir=data_dir,
                       as_supervised=True,
                       with_info=True)

data_train = data['train']
data_test  = data['test']


In [94]:
HEIGHT = 28
WIDTH  = 28
CHANNELS = 1
NUM_PIXELS = HEIGHT * WIDTH * CHANNELS
NUM_LABELS = info.features['label'].num_classes
BATCH_SIZE  = 32 # per device; the training set has 60k samples
NUM_DEVICES = jax.device_count()

In [95]:
def preprocess(img, label):
  """Resize and preprocess images."""
  return (tf.cast(img, tf.float32)/255.0), label

train_data = tfds.as_numpy(
    data_train.map(preprocess).batch(NUM_DEVICES*BATCH_SIZE).prefetch(1)
)
test_data  = tfds.as_numpy(
    data_test.map(preprocess).batch(NUM_DEVICES*BATCH_SIZE).prefetch(1)
)

In [96]:
len(train_data)

469
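The 469 batches follow from the batch size: each step consumes `NUM_DEVICES * BATCH_SIZE` images, and the final, partial batch must still divide evenly across the devices if it is to be split along the batch axis. A quick check, assuming the 60,000-image MNIST training split and the four devices listed earlier:

```python
import math

NUM_TRAIN = 60_000  # size of the MNIST training split
NUM_DEVICES = 4     # TPU cores listed by jax.local_devices() above
BATCH_SIZE = 32     # per-device batch size

global_batch = NUM_DEVICES * BATCH_SIZE                    # 128 images per step
num_batches = math.ceil(NUM_TRAIN / global_batch)          # 469
last_batch = NUM_TRAIN - (num_batches - 1) * global_batch  # 96

print(global_batch, num_batches, last_batch, last_batch % NUM_DEVICES == 0)
```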

### Preparing MLP

In [97]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, value_and_grad
from jax import random
from jax.nn import swish, logsumexp, one_hot

In [98]:
LAYER_SIZES = [28*28, 512, 10]
PARAM_SCALE = 0.01

In [99]:
def init_network_params(sizes, key=random.PRNGKey(0), scale=1e-2):
  """Initialize all layers for a fully-connected neural network with given sizes"""

  def random_layer_params(m, n, key, scale=1e-2):
    """A helper function to randomly initialize weights and biases of a dense layer"""
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k, scale) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

init_params = init_network_params(LAYER_SIZES, random.PRNGKey(0), scale=PARAM_SCALE)

In [100]:
def predict(params, image):
  """Function for per-example predictions."""
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = swish(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits

batched_predict = vmap(predict, in_axes=(None, 0))

### Loss and update functions

In [101]:
INIT_LR = 1.0
DECAY_RATE = 0.95
DECAY_STEPS = 5
NUM_EPOCHS  = 20

In [102]:
def loss(params, images, targets):
  """Categorical cross entropy loss function."""
  logits = batched_predict(params, images)
  log_preds = logits - jnp.expand_dims(logsumexp(logits, axis=1), 1) # logsumexp trick https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
  return -jnp.mean(targets*log_preds)

@jit
def update(params, x, y, epoch_number):
  loss_value, grads = value_and_grad(loss)(params, x, y)
  lr = INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)
  return [(w - lr * dw, b - lr * db)
          for (w, b), (dw, db) in zip(params, grads)], loss_value
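Two details of these functions are worth spelling out. `logits - logsumexp(logits)` is the log-softmax, and because `jnp.mean` averages over both the batch and the 10 classes, `loss` equals the usual per-example cross-entropy divided by `NUM_LABELS`, which keeps the printed loss values about ten times smaller. The learning rate decays exponentially, by a factor of `DECAY_RATE` every `DECAY_STEPS` epochs. A small check of both, using `jax.nn.log_softmax` as the reference:

```python
import jax
import jax.numpy as jnp
from jax.nn import logsumexp, one_hot

NUM_LABELS = 10
INIT_LR, DECAY_RATE, DECAY_STEPS = 1.0, 0.95, 5

logits = jax.random.normal(jax.random.PRNGKey(0), (8, NUM_LABELS))
labels = jnp.array([0, 1, 2, 3, 4, 5, 6, 7])
targets = one_hot(labels, NUM_LABELS)

# The notebook's loss: mean over the batch *and* the classes.
log_preds = logits - jnp.expand_dims(logsumexp(logits, axis=1), 1)
notebook_loss = -jnp.mean(targets * log_preds)

# Standard cross-entropy: sum over classes, mean over the batch.
standard_ce = -jnp.mean(jnp.sum(targets * jax.nn.log_softmax(logits), axis=1))
print(notebook_loss * NUM_LABELS, standard_ce)

# Learning-rate schedule used by `update`, one value per epoch.
lrs = [INIT_LR * DECAY_RATE ** (epoch / DECAY_STEPS) for epoch in range(20)]
print(lrs[0], lrs[5], lrs[19])  # 1.0, 0.95, about 0.823
```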

### Training loop

In [103]:
@jit
def batch_accuracy(params, images, targets):
  images = jnp.reshape(images, (len(images), NUM_PIXELS))
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == targets)

def accuracy(params, data):
  accs = []
  for images, targets in data:
    accs.append(batch_accuracy(params, images, targets))
  return jnp.mean(jnp.array(accs))

#### 4-way data parallelism

In [104]:
sharding = PositionalSharding(jax.devices()).reshape(4, 1)
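With `reshape(4, 1)`, `x` and `y` are split four ways along the batch axis while their second axis stays whole, and `sharding.replicate()` gives every device a full copy of the parameters; XLA then adds the cross-device reduction the gradients need. A minimal sketch of one such data-parallel SGD step with `NamedSharding`, using a hypothetical linear softmax classifier in place of the notebook's MLP and assuming four devices, checks that sharding leaves the update unchanged up to float32 rounding:

```python
import os
# Assumption: no TPU attached; expose 4 virtual CPU devices (must precede `import jax`).
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

LR = 0.1
mesh = Mesh(np.array(jax.devices()[:4]), ("batch",))
data_sharding = NamedSharding(mesh, P("batch", None))  # split examples over devices
replicated = NamedSharding(mesh, P())                  # full copy on every device

def loss(params, x, y):
    # Hypothetical linear softmax classifier standing in for the MLP.
    w, b = params
    logits = x @ w + b
    return -jnp.mean(jnp.sum(y * jax.nn.log_softmax(logits), axis=1))

@jax.jit
def sgd_step(params, x, y):
    loss_value, grads = jax.value_and_grad(loss)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return new_params, loss_value

kx, ky, kw = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (128, 784))
y = jax.nn.one_hot(jax.random.randint(ky, (128,), 0, 10), 10)
params = (0.01 * jax.random.normal(kw, (784, 10)), jnp.zeros(10))

ref_params, ref_loss = sgd_step(params, x, y)  # single-device reference

# Data-parallel step: shard the batch, replicate the parameters.
xs = jax.device_put(x, data_sharding)
ys = jax.device_put(y, data_sharding)
ps = jax.device_put(params, replicated)
dp_params, dp_loss = sgd_step(ps, xs, ys)

print(bool(jnp.allclose(ref_loss, dp_loss, atol=1e-5)))
print(bool(jnp.allclose(ref_params[0], dp_params[0], atol=1e-5)))
```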

In [109]:
import time

params = init_params
for epoch in range(NUM_EPOCHS):
  start_time = time.time()
  losses = []
  for x, y in train_data:
    x = jnp.reshape(x, (len(x), NUM_PIXELS))
    y = one_hot(y, NUM_LABELS)
    x = jax.device_put(x, sharding)
    y = jax.device_put(y, sharding)
    params = jax.device_put(params, sharding.replicate())
    params, loss_value = update(params, x, y, epoch)
    losses.append(loss_value)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_data)
  test_acc = accuracy(params, test_data)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set loss {}".format(jnp.mean(jnp.array(losses))))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 0.92 sec
Training set loss 0.0822610855102539
Training set accuracy 0.8977934122085571
Test set accuracy 0.9017009735107422
Epoch 1 in 0.88 sec
Training set loss 0.033943504095077515
Training set accuracy 0.9136738777160645
Test set accuracy 0.915249228477478
Epoch 2 in 0.97 sec
Training set loss 0.030296433717012405
Training set accuracy 0.9209200143814087
Test set accuracy 0.9210838675498962
Epoch 3 in 0.96 sec
Training set loss 0.027934584766626358
Training set accuracy 0.9271888732910156
Test set accuracy 0.9252373576164246
Epoch 4 in 0.97 sec
Training set loss 0.025554420426487923
Training set accuracy 0.9339630603790283
Test set accuracy 0.9323576092720032
Epoch 5 in 0.99 sec
Training set loss 0.023182667791843414
Training set accuracy 0.9401763677597046
Test set accuracy 0.938686728477478
Epoch 6 in 0.95 sec
Training set loss 0.02106645330786705
Training set accuracy 0.9453791379928589
Test set accuracy 0.9431368708610535
Epoch 7 in 0.92 sec
Training set loss 0.019250

In [108]:
len(losses), losses

(469,
 [Array(0.00715537, dtype=float32),
  Array(0.00720292, dtype=float32),
  Array(0.00704907, dtype=float32),
  Array(0.00552397, dtype=float32),
  Array(0.00984609, dtype=float32),
  Array(0.01049436, dtype=float32),
  Array(0.00601297, dtype=float32),
  Array(0.0108844, dtype=float32),
  Array(0.00592147, dtype=float32),
  Array(0.00666778, dtype=float32),
  Array(0.00456971, dtype=float32),
  Array(0.00707005, dtype=float32),
  Array(0.00967919, dtype=float32),
  Array(0.01597867, dtype=float32),
  Array(0.01205531, dtype=float32),
  Array(0.00985306, dtype=float32),
  Array(0.00408849, dtype=float32),
  Array(0.01805894, dtype=float32),
  Array(0.01227613, dtype=float32),
  Array(0.00526123, dtype=float32),
  Array(0.01130037, dtype=float32),
  Array(0.01170993, dtype=float32),
  Array(0.00646852, dtype=float32),
  Array(0.02234733, dtype=float32),
  Array(0.01576995, dtype=float32),
  Array(0.00691734, dtype=float32),
  Array(0.01684645, dtype=float32),
  Array(0.00574006, dty

In [110]:
jax.debug.visualize_array_sharding(params[0][0])

#### 2-way data parallelism, 2-way tensor parallelism

In [162]:
sharding = PositionalSharding(jax.devices()).reshape(2, 2)

In [163]:
LAYER_SIZES = [28*28, 10000, 10000, 10]
PARAM_SCALE = 0.01

def init_network_params(sizes, key=random.PRNGKey(0), scale=1e-2):
  """Initialize all layers for a fully-connected neural network with given sizes"""

  def random_layer_params(m, n, key, scale=1e-2):
    """A helper function to randomly initialize weights and biases of a dense layer"""
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k, scale) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

init_params = init_network_params(LAYER_SIZES, random.PRNGKey(0), scale=PARAM_SCALE)

In [164]:
sharded_params = []
for i, (w, b) in enumerate(init_params):
  print(i, w.shape, b.shape)
  if i < len(init_params) - 1:
    # Large hidden-layer weights: split both axes over the 2x2 device mesh
    w = jax.device_put(w, sharding)
  else:
    # Small output layer (10 x 10000): keep a full copy on every device
    w = jax.device_put(w, sharding.replicate())
  # Biases are small vectors, so replicate them everywhere
  b = jax.device_put(b, sharding.replicate())
  sharded_params.append((w, b))


0 (10000, 784) (10000,)
1 (10000, 10000) (10000,)
2 (10, 10000) (10,)
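With the 2x2 `sharding`, the first weight matrix `(10000, 784)` is split along both axes, so each device stores a `(5000, 392)` block, while the biases and the small output layer are replicated. A scaled-down sketch of one such layer, assuming four devices and a hypothetical `dense` helper that mirrors the per-example computation in `predict`, checks the block shapes and that the sharded result matches the single-device one:

```python
import os
# Assumption: no TPU attached; expose 4 virtual CPU devices (must precede `import jax`).
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ("x", "y"))
w_sharding = NamedSharding(mesh, P("x", "y"))        # like the 2x2 PositionalSharding
replicated = NamedSharding(mesh, P())                # like sharding.replicate()
batch_sharding = NamedSharding(mesh, P("x", None))   # like sharding.replicate(1)

kw, kx = jax.random.split(jax.random.PRNGKey(0))
w = 0.01 * jax.random.normal(kw, (1_000, 784))  # scaled-down hidden layer (n_out, n_in)
b = jnp.zeros(1_000)
x = jax.random.normal(kx, (64, 784))            # batch of flattened images

ws = jax.device_put(w, w_sharding)
bs = jax.device_put(b, replicated)
xs = jax.device_put(x, batch_sharding)
print({s.data.shape for s in ws.addressable_shards})  # {(500, 392)}

@jax.jit
def dense(w, b, x):
    # Same per-example computation as `predict`, batched with vmap.
    return jax.vmap(lambda a: jax.nn.swish(jnp.dot(w, a) + b))(x)

out = dense(ws, bs, xs)
ref = dense(w, b, x)
print(out.shape, bool(jnp.allclose(out, ref, atol=1e-5)))
```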


In [165]:
for (w,b) in init_params:
  jax.debug.visualize_array_sharding(w)
  jax.debug.visualize_array_sharding(b)

In [166]:
for (w,b) in init_params:
  jax.debug.visualize_array_sharding(jax.device_put(w, sharding.replicate()))
  jax.debug.visualize_array_sharding(jax.device_put(b, sharding.replicate()))

In [167]:
for (w,b) in sharded_params:
  jax.debug.visualize_array_sharding(w)
  jax.debug.visualize_array_sharding(b)

In [168]:
import time

params = sharded_params
for epoch in range(NUM_EPOCHS):
  start_time = time.time()
  losses = []
  for x, y in train_data:
    x = jnp.reshape(x, (len(x), NUM_PIXELS))
    y = one_hot(y, NUM_LABELS)
    x = jax.device_put(x, sharding.replicate(1))
    y = jax.device_put(y, sharding.replicate(1))
    #params = jax.device_put(params, sharding.replicate())
    params, loss_value = update(params, x, y, epoch)
    losses.append(jnp.sum(loss_value))
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_data)
  test_acc = accuracy(params, test_data)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set loss {}".format(jnp.mean(jnp.array(losses))))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

  jax.debug.visualize_array_sharding(params[0][0])

Epoch 0 in 1.01 sec
Training set loss 0.06498625129461288
Training set accuracy 0.8998367786407471
Test set accuracy 0.9048655033111572


Epoch 1 in 1.00 sec
Training set loss 0.03310569375753403
Training set accuracy 0.9135960936546326
Test set accuracy 0.9148536324501038


Epoch 2 in 1.01 sec
Training set loss 0.030245665460824966
Training set accuracy 0.9188210964202881
Test set accuracy 0.918710470199585


Epoch 3 in 0.94 sec
Training set loss 0.028722697868943214
Training set accuracy 0.9227911829948425
Test set accuracy 0.9221717119216919


Epoch 4 in 0.98 sec
Training set loss 0.027605729177594185
Training set accuracy 0.9255896806716919
Test set accuracy 0.9243473410606384


Epoch 5 in 1.05 sec
Training set loss 0.026645950973033905
Training set accuracy 0.9283826947212219
Test set accuracy 0.9260284900665283


Epoch 6 in 0.92 sec
Training set loss 0.02573257125914097
Training set accuracy 0.930936872959137
Test set accuracy 0.9283030033111572


Epoch 7 in 1.02 sec
Training set loss 0.024794965982437134
Training set accuracy 0.9339630603790283
Test set accuracy 0.9300830960273743


Epoch 8 in 0.97 sec
Training set loss 0.023775801062583923
Training set accuracy 0.9370447397232056
Test set accuracy 0.9322587251663208


Epoch 9 in 0.99 sec
Training set loss 0.02263575978577137
Training set accuracy 0.9405428767204285
Test set accuracy 0.9359177350997925


Epoch 10 in 1.00 sec
Training set loss 0.021378550678491592
Training set accuracy 0.9442464113235474
Test set accuracy 0.9404668211936951


Epoch 11 in 1.02 sec
Training set loss 0.02007247321307659
Training set accuracy 0.9480610489845276
Test set accuracy 0.9432357549667358


Epoch 12 in 1.00 sec
Training set loss 0.01881973072886467
Training set accuracy 0.9508928656578064
Test set accuracy 0.9471914768218994


Epoch 13 in 0.93 sec
Training set loss 0.01768065057694912
Training set accuracy 0.9535414576530457
Test set accuracy 0.9497627019882202


Epoch 14 in 0.95 sec
Training set loss 0.016660500317811966
Training set accuracy 0.9561067819595337
Test set accuracy 0.9510482549667358


Epoch 15 in 1.09 sec
Training set loss 0.015744229778647423
Training set accuracy 0.9581390023231506
Test set accuracy 0.9555973410606384


Epoch 16 in 0.94 sec
Training set loss 0.014912167564034462
Training set accuracy 0.9604044556617737
Test set accuracy 0.9575752019882202


Epoch 17 in 0.94 sec
Training set loss 0.014147961512207985
Training set accuracy 0.9620869159698486
Test set accuracy 0.9588607549667358


Epoch 18 in 1.00 sec
Training set loss 0.013438239693641663
Training set accuracy 0.9641191363334656
Test set accuracy 0.9602452516555786


Epoch 19 in 1.02 sec
Training set loss 0.012776298448443413
Training set accuracy 0.9654517769813538
Test set accuracy 0.9611353278160095


**!!! Do not forget to shut down your Cloud TPU, or you will keep paying for it!!!**