## ⚖️ Choose A or C:

## A: Emulating a multi-device system on CPU

Use this section to initialize a set of virtual devices on CPU if you have no access to a multi-device system.

It can also help you prototype, debug and test your multi-device code locally before running it on an expensive system.

Even on Google Colab, it can speed up prototyping because a CPU runtime restarts faster.

In [None]:
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

In [None]:
import jax
import jax.numpy as jnp

In [None]:
jax.devices()



[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]
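
As a quick sanity check that the emulated devices actually execute work, the following minimal sketch runs a trivial function on every local device. It assumes the flag is set before JAX initializes its backends (otherwise only one CPU device is visible), and it adapts to whatever device count is actually available.

```python
import os
# Must be set before JAX initializes its backends, otherwise it has no effect.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp
import numpy as np

n = jax.local_device_count()

# pmap maps over the leading axis, placing one slice on each device,
# so the leading axis length must not exceed the device count.
doubled = jax.pmap(lambda x: x * 2)(jnp.arange(n))
expected = 2 * np.arange(n)

print(n, doubled)
```

If `n` prints as 1 on a CPU-only machine, JAX was most likely imported before the environment variable was set; restart the runtime and run the cells in order.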

## (do not use) B: Setting up Colab TPU

#### [This section sets up a Colab TPU, which no longer works with recent JAX versions. Do not use the Colab TPU runtime with this notebook.]

Use this section if you want to use Google Cloud TPU (and don't forget to change the runtime type in "Runtime" -> "Change runtime type" -> "TPU").

In [None]:
# in order to use TPU you have to run this cell before importing JAX
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

In [None]:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

tpu


In [None]:
import jax
import jax.numpy as jnp

In [None]:
jax.local_devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

## C: Using Cloud TPU and a Local runtime in Colab

Make preparations according to the following links:

* Creating a Cloud TPU https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#tpu-vms

* Preparing Jupyter and connecting to a local runtime https://research.google.com/colaboratory/local-runtimes.html


In [1]:
!pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Looking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html


In [2]:
import jax
import jax.numpy as jnp

In [3]:
jax.local_devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

## Using pjit()

In [4]:
jax.__version__

'0.4.13'

In [5]:
from jax.experimental.pjit import pjit

In [6]:
from jax import random

### Old vmap+pmap example

In [7]:
def dot(v1, v2):
  return jnp.vdot(v1, v2)

In [8]:
rng_key = random.PRNGKey(42)

In [9]:
vs = random.normal(rng_key, shape=(20_000_000,3))
v1s = vs[:10_000_000,:]
v2s = vs[10_000_000:,:]

v1s.shape, v2s.shape

((10000000, 3), (10000000, 3))

In [10]:
v1sp = v1s.reshape((8, v1s.shape[0]//8, v1s.shape[1]))
v2sp = v2s.reshape((8, v2s.shape[0]//8, v2s.shape[1]))

v1sp.shape, v2sp.shape

((8, 1250000, 3), (8, 1250000, 3))

In [11]:
dot_parallel = jax.pmap(
    jax.vmap(dot, in_axes=(0,0)),
    in_axes=(0,0)
)

In [12]:
x_pmap = dot_parallel(v1sp,v2sp)

In [13]:
x_pmap.shape

(8, 1250000)

In [14]:
x_pmap = x_pmap.reshape((x_pmap.shape[0]*x_pmap.shape[1]))
x_pmap.shape

(10000000,)

### Replacing with pjit and 1D mesh

In [15]:
from jax.sharding import Mesh
from jax.sharding import PartitionSpec

In [16]:
import numpy as np

In [17]:
devices = np.array(jax.devices())
devices

array([TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
       TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
       TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
       TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
       TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
       TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
       TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
       TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)],
      dtype=object)

A non-vectorized function won't work here: `dot` returns a scalar (rank 0), and output partitioning works only for tensors of rank >= 1.

In [18]:
f = pjit(dot,
         in_axis_resources=None,
         out_axis_resources=PartitionSpec('devices')
         )

In [19]:
with Mesh(devices, ('devices',)):
  x_pjit=f(v1s,v2s)

ValueError: ignored
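
The failure above comes down to output rank. A minimal sketch on small arrays (the array sizes are chosen only for illustration) shows that `jnp.vdot` flattens both inputs and returns a scalar, while the `vmap`-ed version returns one value per row, which gives the output an axis that can be partitioned.

```python
import jax
import jax.numpy as jnp

def dot(v1, v2):
    return jnp.vdot(v1, v2)

a = jnp.ones((4, 3))
b = jnp.ones((4, 3))

# vdot flattens its arguments, so the result has no axes left to shard.
scalar_out = dot(a, b)

# vmap maps over the leading axis, producing one dot product per row.
per_row_out = jax.vmap(dot)(a, b)

print(scalar_out.shape, per_row_out.shape)
```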

Using a vectorized function that produces a rank-1 output. With `in_axis_resources=None`, the input is replicated across all the devices.

In [20]:
f = pjit(jax.vmap(dot),
         in_axis_resources=None,
         out_axis_resources=PartitionSpec('devices')
         )

In [21]:
with Mesh(devices, ('devices',)):
  x_pjit=f(v1s,v2s)

In [22]:
x_pjit.shape

(10000000,)

In [23]:
jax.numpy.all(x_pjit == x_pmap)

Array(True, dtype=bool)

Now the input is sharded across all the devices. A single `PartitionSpec` is applied to both arguments.

In [24]:
f = pjit(jax.vmap(dot),
         in_axis_resources=PartitionSpec('devices'),
         out_axis_resources=PartitionSpec('devices')
         )

In [25]:
with Mesh(devices, ('devices',)):
  x_pjit=f(v1s,v2s)

In [26]:
x_pjit.shape

(10000000,)

In [27]:
jax.numpy.all(x_pjit == x_pmap)

Array(True, dtype=bool)

The result is the same when only the first argument is sharded and the second one is replicated:

In [28]:
f = pjit(jax.vmap(dot),
         in_axis_resources=(PartitionSpec('devices'), None),
         out_axis_resources=PartitionSpec('devices')
         )

In [29]:
with Mesh(devices, ('devices',)):
  x_pjit=f(v1s,v2s)

In [30]:
x_pjit.shape

(10000000,)

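The result is also the same with `PartitionSpec('devices', None)`, which shards the first array dimension and leaves the second one unsharded: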

In [31]:
f = pjit(jax.vmap(dot),
         in_axis_resources=PartitionSpec('devices', None),
         out_axis_resources=PartitionSpec('devices')
         )

In [32]:
with Mesh(devices, ('devices',)):
  x_pjit=f(v1s,v2s)

In [33]:
x_pjit.shape

(10000000,)
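
To see why `PartitionSpec('devices')` and `PartitionSpec('devices', None)` behave alike for a 2D array, it can help to look at the per-device shards directly. The sketch below is not part of the original notebook: it uses `jax.device_put` with `NamedSharding` (rather than `pjit`) and a small array whose size is an arbitrary choice for illustration.

```python
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices, ('devices',))

# 4 rows per device, 3 columns.
x = jnp.zeros((n * 4, 3))

# Trailing dimensions that are not mentioned in the spec stay unsharded,
# so both specs split only the first dimension.
short = jax.device_put(x, NamedSharding(mesh, PartitionSpec('devices')))
full = jax.device_put(x, NamedSharding(mesh, PartitionSpec('devices', None)))

short_shapes = [s.data.shape for s in short.addressable_shards]
full_shapes = [s.data.shape for s in full.addressable_shards]

print(short_shapes == full_shapes, short_shapes[0])
```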

Passing more entries in `in_axis_resources` than the function has parameters raises an error:

In [34]:
f = pjit(jax.vmap(dot),
         in_axis_resources=(PartitionSpec('devices'), None, None),
         out_axis_resources=PartitionSpec('devices')
         )

In [35]:
with Mesh(devices, ('devices',)):
  x_pjit=f(v1s,v2s)

ValueError: ignored

Finally, sharding both input parameters explicitly:

In [36]:
f = pjit(jax.vmap(dot),
         in_axis_resources=(PartitionSpec('devices'), PartitionSpec('devices')),
         out_axis_resources=PartitionSpec('devices')
         )

In [37]:
with Mesh(devices, ('devices',)):
  x_pjit=f(v1s,v2s)

In [38]:
x_pjit.shape

(10000000,)
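
The same computation can also be expressed with `jax.jit` and explicit `NamedSharding` objects instead of `pjit` with `in_axis_resources`. This is a sketch of an alternative spelling, not what the notebook ran; the array size is reduced for illustration, and the result is checked against a plain unsharded computation.

```python
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def dot(v1, v2):
    return jnp.vdot(v1, v2)

devices = np.array(jax.devices())
mesh = Mesh(devices, ('devices',))
row_sharded = NamedSharding(mesh, P('devices'))

# Both inputs and the output are split along their first axis.
f = jax.jit(jax.vmap(dot),
            in_shardings=(row_sharded, row_sharded),
            out_shardings=row_sharded)

n_rows = devices.size * 1000
vs = jax.random.normal(jax.random.PRNGKey(42), (2 * n_rows, 3))
v1s, v2s = vs[:n_rows], vs[n_rows:]

x_sharded = f(v1s, v2s)
x_reference = jnp.sum(v1s * v2s, axis=1)  # unsharded row-wise dot products

print(x_sharded.shape)
```

Because the shardings carry their own mesh, no `with Mesh(...)` context is needed when calling `f` in this form.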

### 2D mesh case

In [39]:
from jax.sharding import PartitionSpec as P   # could be useful to reduce typing
from jax.sharding import Mesh
import numpy as np

In [40]:
rng_key = random.PRNGKey(42)

In [41]:
vs = random.normal(rng_key, shape=(8_000,10_000))
v1s = vs[:4_000,:]
v2s = vs[4_000:,:]

v1s.shape, v2s.shape

((4000, 10000), (4000, 10000))

In [42]:
devices = np.array(jax.devices()).reshape(2, 4)
devices

array([[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
        TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
        TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
        TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1)],
       [TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
        TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
        TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
        TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]],
      dtype=object)

In [43]:
def dot(v1, v2):
  return jnp.vdot(v1, v2)

In [44]:
f = pjit(jax.vmap(dot),
         in_axis_resources=(P('x', 'y'), P('x', 'y')),
         out_axis_resources=P('x')
         )

In [45]:
with Mesh(devices, ('x','y')):
  x_pjit=f(v1s,v2s)

In [46]:
x_pjit.shape

(4000,)
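
With a 2D mesh, `P('x', 'y')` splits the first array dimension over mesh axis `x` and the second over mesh axis `y`. The sketch below inspects the resulting shard shapes on a small array; it uses `jax.device_put` with `NamedSharding`, which is an assumption of this example rather than the notebook's `pjit` call, and it falls back to a `1 x n` mesh when the device count is odd.

```python
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n = devices.size
rows = 2 if n % 2 == 0 else 1
cols = n // rows
mesh = Mesh(devices.reshape(rows, cols), ('x', 'y'))

# Each device should end up with a (20, 25) tile.
x = jnp.zeros((rows * 20, cols * 25))
sharded = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))

shard_shapes = {s.data.shape for s in sharded.addressable_shards}
print(mesh.shape, shard_shapes)
```

On the 2 x 4 mesh used in the notebook, the same rule turns each `(4000, 10000)` input into `(2000, 2500)` tiles per device.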

### Looking into HLO

You can't see any collective ops at the jaxpr level:

In [47]:
with Mesh(devices, ('x','y')):
  print(jax.make_jaxpr(f)(v1s,v2s))

{ lambda ; a:f32[4000,10000] b:f32[4000,10000]. let
    c:f32[4000] = pjit[
      in_shardings=(GSPMDSharding({devices=[2,4]0,1,2,3,4,5,6,7}), GSPMDSharding({devices=[2,4]0,1,2,3,4,5,6,7}))
      jaxpr={ lambda ; d:f32[4000,10000] e:f32[4000,10000]. let
          f:f32[4000] = dot_general[dimension_numbers=(([1], [1]), ([0], [0]))] d
            e
        in (f,) }
      name=dot
      out_shardings=(GSPMDSharding({devices=[2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}),)
      resource_env=ResourceEnv(Mesh(device_ids=array([[0, 1, 2, 3],
       [4, 5, 6, 7]]), axis_names=('x', 'y')), ())
    ] a b
  in (c,) }


But you can see them after compilation:

In [48]:
# Thanks to https://github.com/google/jax/discussions/11275
with Mesh(devices, ('x','y')):
    modules = f.lower(v1s, v2s).compile().compiler_ir()
    for hlo in modules:
        print(hlo.to_string())

HloModule pjit_dot, is_scheduled=true, entry_computation_layout={(f32[2000,2500]{1,0:T(8,128)}, f32[2000,2500]{1,0:T(8,128)})->f32[2000]{0:T(1024)}}

%scalar_add_computation (scalar_lhs: f32[], scalar_rhs: f32[]) -> f32[] {
  %scalar_rhs = f32[]{:T(256)} parameter(1)
  %scalar_lhs = f32[]{:T(256)} parameter(0)
  ROOT %add = f32[]{:T(256)} add(f32[]{:T(256)} %scalar_lhs, f32[]{:T(256)} %scalar_rhs)
}

%fused_computation (param_0.2: f32[2000,2500], param_1.2: f32[2000,2500]) -> f32[2000] {
  %param_0.2 = f32[2000,2500]{1,0:T(8,128)} parameter(0)
  %param_1.2 = f32[2000,2500]{1,0:T(8,128)} parameter(1)
  %multiply.2 = f32[2000,2500]{1,0:T(8,128)} multiply(f32[2000,2500]{1,0:T(8,128)} %param_0.2, f32[2000,2500]{1,0:T(8,128)} %param_1.2)
  %constant.2 = f32[]{:T(256)} constant(0)
  ROOT %reduce.2 = f32[2000]{0:T(1024)} reduce(f32[2000,2500]{1,0:T(8,128)} %multiply.2, f32[]{:T(256)} %constant.2), dimensions={1}, to_apply=%scalar_add_computation, metadata={op_name="pjit(dot)/jit(main)/dot_gen
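
Since the compiled HLO dump is long, it can be convenient to search it for collective ops instead of reading it in full. The sketch below uses `compile().as_text()` on a `jax.jit`-wrapped function, which is an alternative to the `compiler_ir()` call above and not what the notebook ran. Which collectives appear, if any, depends on the backend and compiler version, so the code only filters lines and makes no claim about the result.

```python
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def dot(v1, v2):
    return jnp.vdot(v1, v2)

devices = np.array(jax.devices())
n = devices.size
rows = 2 if n % 2 == 0 else 1
cols = n // rows
mesh = Mesh(devices.reshape(rows, cols), ('x', 'y'))

xy = NamedSharding(mesh, P('x', 'y'))
f = jax.jit(jax.vmap(dot),
            in_shardings=(xy, xy),
            out_shardings=NamedSharding(mesh, P('x')))

v1s = jnp.ones((rows * 20, cols * 25))
v2s = jnp.ones((rows * 20, cols * 25))

# Optimized, per-device HLO as plain text.
hlo_text = f.lower(v1s, v2s).compile().as_text()

# Keep only the lines that mention common collective ops.
collective_lines = [line for line in hlo_text.splitlines()
                    if 'all-reduce' in line or 'reduce-scatter' in line]
print(len(collective_lines))
```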

## MLP example

### Preparing data

Install these packages if you are starting from a new, empty cloud machine.

In [49]:
!pip install tensorflow



In [50]:
!pip install tensorflow_datasets



In [51]:
import jax
import tensorflow as tf
import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

data, info = tfds.load(name="mnist",
                       data_dir=data_dir,
                       as_supervised=True,
                       with_info=True)

data_train = data['train']
data_test  = data['test']



In [52]:
HEIGHT = 28
WIDTH  = 28
CHANNELS = 1
NUM_PIXELS = HEIGHT * WIDTH * CHANNELS
NUM_LABELS = info.features['label'].num_classes
NUM_DEVICES = jax.device_count()
BATCH_SIZE  = 32

In [53]:
def preprocess(img, label):
  """Scale pixel values to the [0, 1] range."""
  return (tf.cast(img, tf.float32)/255.0), label

train_data = tfds.as_numpy(
    data_train.map(preprocess).batch(NUM_DEVICES*BATCH_SIZE).prefetch(1)
)
test_data  = tfds.as_numpy(
    data_test.map(preprocess).batch(NUM_DEVICES*BATCH_SIZE).prefetch(1)
)

In [54]:
len(train_data)

235
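
The batch count follows from simple arithmetic, and it is worth checking that the last, smaller batch still divides evenly across the devices, since the batch axis is later sharded over the mesh. The sketch below assumes 8 devices, as in the outputs above, and the standard MNIST training split of 60,000 examples.

```python
import math

NUM_EXAMPLES = 60_000   # size of the MNIST training split
NUM_DEVICES = 8         # assumed here; the notebook uses jax.device_count()
BATCH_SIZE = 32         # per-device batch size

global_batch = NUM_DEVICES * BATCH_SIZE
num_batches = math.ceil(NUM_EXAMPLES / global_batch)

# The final batch holds whatever is left after the full batches.
last_batch = NUM_EXAMPLES - (num_batches - 1) * global_batch
per_device_in_last_batch = last_batch // NUM_DEVICES

print(global_batch, num_batches, last_batch, last_batch % NUM_DEVICES)
```

With these numbers the last batch has 96 examples, 12 per device, so every batch can be split evenly over 8 devices.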

### Preparing MLP

In [55]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, value_and_grad
from jax import random
from jax.nn import swish, logsumexp, one_hot

In [56]:
LAYER_SIZES = [28*28, 512, 10]
PARAM_SCALE = 0.01

In [57]:
def init_network_params(sizes, key=random.PRNGKey(0), scale=1e-2):
  """Initialize all layers for a fully-connected neural network with given sizes"""

  def random_layer_params(m, n, key, scale=1e-2):
    """A helper function to randomly initialize weights and biases of a dense layer"""
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k, scale) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

init_params = init_network_params(LAYER_SIZES, random.PRNGKey(0), scale=PARAM_SCALE)

In [58]:
def predict(params, image):
  """Function for per-example predictions."""
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = swish(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits

batched_predict = vmap(predict, in_axes=(None, 0))
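
The `in_axes=(None, 0)` argument means "do not map over the first argument, map over axis 0 of the second". A minimal sketch with a single weight matrix instead of the full parameter list (the `affine` helper and the array sizes are hypothetical, chosen for illustration):

```python
import jax
import jax.numpy as jnp

def affine(w, x):
    # Per-example computation: x is a single vector.
    return jnp.dot(w, x)

# w is shared by all examples; xs is split along its leading (batch) axis.
batched_affine = jax.vmap(affine, in_axes=(None, 0))

w = jnp.ones((2, 3))
xs = jnp.arange(12.0).reshape(4, 3)

out = batched_affine(w, xs)
reference = xs @ w.T  # the same result written as one matrix product

print(out.shape)
```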

### Loss and update functions

In [59]:
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh
import numpy as np

In [60]:
INIT_LR = 1.0
DECAY_RATE = 0.95
DECAY_STEPS = 5
NUM_EPOCHS  = 20

In [61]:
def loss(params, images, targets):
  """Categorical cross entropy loss function."""
  logits = batched_predict(params, images)
  log_preds = logits - logsumexp(logits) # logsumexp trick https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
  return -jnp.mean(targets*log_preds)

def update(params, x, y, epoch_number):
  loss_value, grads = value_and_grad(loss)(params, x, y)
  lr = INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)
  return [(w - lr * dw, b - lr * db)
          for (w, b), (dw, db) in zip(params, grads)], loss_value
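
One detail of the loss is worth a closer look: `logsumexp(logits)` without an `axis` argument reduces over the whole batch, so the resulting probabilities sum to 1 across the batch rather than per example. The sketch below contrasts that with a per-example normalization; the per-example variant is an assumption about what a standard cross-entropy would use, not what the notebook ran, and the logits are made up for illustration.

```python
import jax.numpy as jnp
from jax.nn import logsumexp, log_softmax

logits = jnp.array([[1.0, 2.0, 3.0],
                    [0.0, 0.0, 5.0]])

# As written in the loss above: one normalizer for the entire batch.
global_log_preds = logits - logsumexp(logits)

# Per-example variant: one normalizer per row.
per_example_log_preds = logits - logsumexp(logits, axis=1, keepdims=True)

global_total = jnp.exp(global_log_preds).sum()
per_example_row_sums = jnp.exp(per_example_log_preds).sum(axis=1)

print(global_total, per_example_row_sums)
```

The per-example variant matches `jax.nn.log_softmax(logits, axis=1)`. Switching to it would change the reported loss values, so the numbers in the training log below correspond to the original batch-wide form.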

In [62]:
f_update = pjit(update,
         in_axis_resources=(None, P('x'), P('x'), None),
         out_axis_resources=None
         )
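
The sharding pattern here is plain data parallelism: parameters are replicated, the batch is split along its first axis, and the outputs are replicated again. A minimal sketch of the same pattern on a toy linear model, written with `jax.jit` and `NamedSharding` instead of `pjit` (an assumption of this example), checks that the gradient computed from a sharded batch matches the unsharded one:

```python
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('x',))
replicated = NamedSharding(mesh, P())        # full copy on every device
batch_sharded = NamedSharding(mesh, P('x'))  # split along the batch axis

def loss(w, x, y):
    # Toy mean-squared-error loss for a linear model.
    return jnp.mean((x @ w - y) ** 2)

grad_fn = jax.jit(jax.grad(loss),
                  in_shardings=(replicated, batch_sharded, batch_sharded),
                  out_shardings=replicated)

n = jax.device_count()
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (n * 4, 3))
y = jax.random.normal(key_y, (n * 4,))
w = jnp.zeros((3,))

g_sharded = grad_fn(w, x, y)
g_reference = jax.grad(loss)(w, x, y)

print(g_sharded.shape)
```

The compiler inserts whatever cross-device communication is needed to combine the per-device contributions, which is why the user code contains no explicit collective.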

### Training loop

In [63]:
devices = np.array(jax.devices())
devices

array([TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
       TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
       TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
       TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
       TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
       TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
       TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
       TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)],
      dtype=object)

In [64]:
def batch_accuracy(params, images, targets):
  images = jnp.reshape(images, (len(images), NUM_PIXELS))
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == targets)

f_batch_accuracy = pjit(batch_accuracy,
         in_axis_resources=(None, P('x'), P('x')),
         out_axis_resources=None
         )

def accuracy(params, data):
  accs = []
  for images, targets in data:
    accs.append(f_batch_accuracy(params, images, targets))
  return jnp.mean(jnp.array(accs))

In [65]:
import time

params = init_params
with Mesh(devices, ('x',)):
  for epoch in range(NUM_EPOCHS):
    start_time = time.time()
    losses = []
    for x, y in train_data:
      x = jnp.reshape(x, (len(x), NUM_PIXELS))
      y = one_hot(y, NUM_LABELS)
      params, loss_value = f_update(params, x, y, epoch)
      losses.append(jnp.sum(loss_value))
    epoch_time = time.time() - start_time

    train_acc = accuracy(params, train_data)
    test_acc = accuracy(params, test_data)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set loss {}".format(jnp.mean(jnp.array(losses))))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))

Epoch 0 in 1.88 sec
Training set loss 0.6914460062980652
Training set accuracy 0.8648548126220703
Test set accuracy 0.874804675579071
Epoch 1 in 1.02 sec
Training set loss 0.6259850859642029
Training set accuracy 0.8882258534431458
Test set accuracy 0.893847644329071
Epoch 2 in 0.99 sec
Training set loss 0.6164931058883667
Training set accuracy 0.8978224396705627
Test set accuracy 0.901074230670929
Epoch 3 in 0.97 sec
Training set loss 0.6100647449493408
Training set accuracy 0.906527042388916
Test set accuracy 0.908984363079071
Epoch 4 in 1.00 sec
Training set loss 0.6044379472732544
Training set accuracy 0.9142231941223145
Test set accuracy 0.917285144329071
Epoch 5 in 1.02 sec
Training set loss 0.5995732545852661
Training set accuracy 0.9211103320121765
Test set accuracy 0.9248046875
Epoch 6 in 0.96 sec
Training set loss 0.5955986976623535
Training set accuracy 0.9261746406555176
Test set accuracy 0.929394543170929
Epoch 7 in 0.98 sec
Training set loss 0.5924407839775085
Training se

Some resources that might help you achieve weight sharding as well:
https://github.com/google/jax/discussions/8649

**!!! Do not forget to shut down your Cloud TPU, or you'll spend a lot of money on it !!!**