## ⚖️ Choose A or C:

## A: Emulating multi-device system on CPU

Use this section to initialize a set of virtual devices on CPU if you don't have access to a multi-device system.

It can also help you prototype, debug and test your multi-device code locally before running it on an expensive system.

Even on Google Colab it can help you prototype faster, because a CPU runtime restarts more quickly.

In [1]:
import os
# Must be set before JAX initializes its CPU backend, so run this cell before importing jax
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

In [2]:
import jax
import jax.numpy as jnp

In [3]:
jax.devices()

[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]
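A small sanity check for the start of a fresh CPU runtime. It assumes nothing else in the process has initialized JAX yet: `--xla_force_host_platform_device_count` only affects the host (CPU) platform and is read when JAX initializes that backend. Querying `jax.devices("cpu")` explicitly (rather than `jax.devices()`) is a choice made here so the check still looks at CPU devices if an accelerator happens to be present.

```python
import os

# Set the flag before JAX initializes its CPU backend (simplest: before `import jax`).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

cpu_devices = jax.devices("cpu")
if len(cpu_devices) == 8:
    print("8 virtual CPU devices available")
else:
    # The flag has no effect once the CPU backend is already initialized in this
    # process; restart the runtime and run the XLA_FLAGS cell first.
    print(f"Expected 8 CPU devices, found {len(cpu_devices)}; restart the runtime")
```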

## (do not use) B: Setting up Colab TPU

#### [This section sets up a Colab TPU, which no longer works with recent JAX versions. Do not use the TPU runtime from Colab with this notebook.]

Use this section if you want to use a TPU in Colab (and don't forget to change the runtime type via "Runtime" -> "Change runtime type" -> "TPU").

In [None]:
# in order to use TPU you have to run this cell before importing JAX
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

In [None]:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

tpu


In [None]:
import jax
import jax.numpy as jnp

In [None]:
jax.local_devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

## C: Using Cloud TPU and a Local runtime in Colab

Make preparations according to the following links:

* Creating a Cloud TPU https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#tpu-vms

* Preparing Jupyter and connecting to a local runtime https://research.google.com/colaboratory/local-runtimes.html


In [None]:
!pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Looking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html


In [None]:
import jax
import jax.numpy as jnp

In [None]:
jax.local_devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

## Using pjit()

In [4]:
jax.__version__

'0.4.23'

In [5]:
from jax.experimental.pjit import pjit
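The cells below follow the notebook in using `pjit` from `jax.experimental.pjit`. In JAX 0.4.x, `jax.jit` also accepts `in_shardings`/`out_shardings`, and passing `NamedSharding` objects (which carry their own mesh) removes the need for a `with Mesh(...)` block. A minimal sketch under that assumption, using the same row-wise `dot` as below on small, hypothetical all-ones inputs:

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def dot(v1, v2):
    return jnp.vdot(v1, v2)

# 1D mesh over the (emulated) CPU devices with a single axis named 'devices'.
mesh = Mesh(np.array(jax.devices("cpu")), ("devices",))
rows = NamedSharding(mesh, P("devices"))  # split the leading axis across devices

# NamedSharding knows its mesh, so no `with Mesh(...)` context is required.
f = jax.jit(jax.vmap(dot), in_shardings=(rows, rows), out_shardings=rows)

v1s = jnp.ones((1024, 3))
v2s = jnp.full((1024, 3), 2.0)
out = f(v1s, v2s)  # each row: 1*2 + 1*2 + 1*2 = 6

print(out.shape, len(out.addressable_shards))
```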

In [6]:
from jax import random

### Old vmap+pmap example

In [7]:
def dot(v1, v2):
  return jnp.vdot(v1, v2)

In [8]:
rng_key = random.PRNGKey(42)

In [9]:
vs = random.normal(rng_key, shape=(20_000_000,3))
v1s = vs[:10_000_000,:]
v2s = vs[10_000_000:,:]

v1s.shape, v2s.shape

((10000000, 3), (10000000, 3))

In [10]:
v1sp = v1s.reshape((8, v1s.shape[0]//8, v1s.shape[1]))
v2sp = v2s.reshape((8, v2s.shape[0]//8, v2s.shape[1]))

v1sp.shape, v2sp.shape

((8, 1250000, 3), (8, 1250000, 3))

In [11]:
dot_parallel = jax.pmap(
    jax.vmap(dot, in_axes=(0,0)),
    in_axes=(0,0)
)

In [12]:
x_pmap = dot_parallel(v1sp,v2sp)

In [13]:
x_pmap.shape

(8, 1250000)

In [14]:
x_pmap = x_pmap.reshape((x_pmap.shape[0]*x_pmap.shape[1]))
x_pmap.shape

(10000000,)
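In `pmap(vmap(dot))`, `pmap` maps the leading axis (one slice per device) and `vmap` maps the next axis within each device, so the flattened result is just a row-wise dot product. A sketch that cross-checks this against a single `jnp.einsum` on smaller hypothetical data, and derives the leading axis from `jax.local_device_count()` instead of hard-coding 8:

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
from jax import random

def dot(v1, v2):
    return jnp.vdot(v1, v2)

n_dev = jax.local_device_count()  # 8 with the flag above on a CPU-only host
vs = random.normal(random.PRNGKey(42), shape=(2 * n_dev * 1000, 3))
v1s, v2s = jnp.split(vs, 2)       # two (n_dev * 1000, 3) halves

# Leading axis -> devices (pmap), second axis -> vectorized within a device (vmap).
v1sp = v1s.reshape((n_dev, -1, v1s.shape[1]))
v2sp = v2s.reshape((n_dev, -1, v2s.shape[1]))
x_pmap = jax.pmap(jax.vmap(dot))(v1sp, v2sp).reshape(-1)

# The same row-wise dot products without any parallel transform.
x_ref = jnp.einsum("ij,ij->i", v1s, v2s)
print(x_pmap.shape, bool(jnp.allclose(x_pmap, x_ref, atol=1e-5)))
```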

### Replacing with pjit and 1D mesh

In [15]:
from jax.sharding import Mesh
from jax.sharding import PartitionSpec

In [16]:
import numpy as np

In [17]:
devices = np.array(jax.devices())
devices

array([CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3),
       CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)],
      dtype=object)

A non-vectorized function won't work here: `jnp.vdot` flattens its inputs and returns a scalar (rank 0), while partitioning an output along a mesh axis only works for arrays of rank >= 1.

In [19]:
f = pjit(dot,
         in_shardings=None,
         out_shardings=PartitionSpec('devices')
         )

In [20]:
with Mesh(devices, ('devices',)):
  x_pjit=f(v1s,v2s)

ValueError: One of pjit outputs is incompatible with its sharding annotation NamedSharding(mesh=Mesh('devices': 8), spec=PartitionSpec('devices',)): Sharding NamedSharding(mesh=Mesh('devices': 8), spec=PartitionSpec('devices',)) is only valid for values of rank at least 1, but was applied to a value of rank 0. For scalars the PartitionSpec should be P()
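As the error message suggests, a scalar output needs the empty spec `P()`, which means fully replicated. A minimal sketch of that fix that keeps the non-vectorized `dot`; it uses `jax.jit` with `NamedSharding` instead of `pjit` with a mesh context (an assumption about style, not what the notebook does) and small hypothetical inputs:

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def dot(v1, v2):
    return jnp.vdot(v1, v2)  # flattens both inputs, so the result is a scalar

mesh = Mesh(np.array(jax.devices("cpu")), ("devices",))

f = jax.jit(dot,
            in_shardings=NamedSharding(mesh, P("devices")),  # one spec, applied to both args
            out_shardings=NamedSharding(mesh, P()))          # rank-0 output: replicated

v1s = jnp.ones((64, 3))
v2s = jnp.full((64, 3), 2.0)
total = f(v1s, v2s)  # 64 * 3 elements, each contributing 1 * 2 -> 384.0
print(total.shape, float(total))
```

Here the per-device partial sums have to be combined into one replicated scalar.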

Using a vectorized function that produces a rank-1 output works. Under the mesh context, `in_shardings=None` means the inputs are replicated across all devices.

In [21]:
f = pjit(jax.vmap(dot),
         in_shardings=None,
         out_shardings=PartitionSpec('devices')
         )

In [22]:
with Mesh(devices, ('devices',)):
  x_pjit=f(v1s,v2s)

In [23]:
x_pjit.shape

(10000000,)

In [24]:
jax.numpy.all(x_pjit == x_pmap)

Array(True, dtype=bool)

Here a single `PartitionSpec('devices')` applies to both inputs, so each is split along its first dimension across all devices.

In [25]:
f = pjit(jax.vmap(dot),
         in_shardings=PartitionSpec('devices'),
         out_shardings=PartitionSpec('devices')
         )

In [26]:
with Mesh(devices, ('devices',)):
  x_pjit=f(v1s,v2s)

In [27]:
x_pjit.shape

(10000000,)

In [28]:
jax.numpy.all(x_pjit == x_pmap)

Array(True, dtype=bool)
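To see concretely what "sharded" versus "replicated" means, you can place a small hypothetical array explicitly with `jax.device_put` and inspect `addressable_shards`, where each shard reports its device, the slice of the global array it holds (`index`), and its local data:

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices("cpu")), ("devices",))
x = jnp.arange(16 * 3, dtype=jnp.float32).reshape(16, 3)

sharded = jax.device_put(x, NamedSharding(mesh, P("devices")))  # rows split across devices
replicated = jax.device_put(x, NamedSharding(mesh, P()))        # full copy on every device

for shard in sharded.addressable_shards:
    # `index` is the slice of the global array that this device holds.
    print(shard.device, shard.index, shard.data.shape)

print({s.data.shape for s in replicated.addressable_shards})  # {(16, 3)}
```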

Sharding only the first input and replicating the second (`None`) gives the same result:

In [29]:
f = pjit(jax.vmap(dot),
         in_shardings=(PartitionSpec('devices'), None),
         out_shardings=PartitionSpec('devices')
         )

In [30]:
with Mesh(devices, ('devices',)):
  x_pjit=f(v1s,v2s)

In [31]:
x_pjit.shape

(10000000,)

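`PartitionSpec('devices', None)` explicitly leaves the second dimension unpartitioned, which for these 2D inputs is equivalent to `PartitionSpec('devices')`, so the result is also the same: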

In [32]:
f = pjit(jax.vmap(dot),
         in_shardings=PartitionSpec('devices', None),
         out_shardings=PartitionSpec('devices')
         )

In [33]:
with Mesh(devices, ('devices',)):
  x_pjit=f(v1s,v2s)

In [34]:
x_pjit.shape

(10000000,)

Passing more entries in the `in_shardings` tuple than the function has positional arguments fails:

In [35]:
f = pjit(jax.vmap(dot),
         in_shardings=(PartitionSpec('devices'), None, None),
         out_shardings=PartitionSpec('devices')
         )

In [36]:
with Mesh(devices, ('devices',)):
  x_pjit=f(v1s,v2s)

ValueError: pjit in_shardings specification must be a tree prefix of the positional arguments tuple passed to the `pjit`-decorated function. In particular, pjit in_shardings must either be a None, a PartitionSpec, or a tuple of length equal to the number of positional arguments. But pjit in_shardings is the wrong length: got a tuple or list of length 3 for an args tuple of length 2.
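The "tree prefix" rule in the error also works the other way: one entry per positional argument, where each entry may cover a whole nested structure. A sketch with a hypothetical linear model whose parameters live in a dict; the single `replicated` entry is a prefix of the whole `params` dict and so applies to every leaf inside it (shown with `jax.jit` and `NamedSharding`, which follow the same prefix rule):

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices("cpu")), ("devices",))
replicated = NamedSharding(mesh, P())
rows = NamedSharding(mesh, P("devices"))

def apply(params, batch):
    # Hypothetical linear model: params is a dict, batch is (N, 3).
    return batch @ params["w"] + params["b"]

params = {"w": jnp.ones(3), "b": jnp.zeros(())}
batch = jnp.ones((16, 3))

# Two positional args -> a 2-tuple of shardings; `replicated` covers the whole dict.
f = jax.jit(apply, in_shardings=(replicated, rows), out_shardings=rows)
out = f(params, batch)
print(out.shape)  # (16,)
```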

Finally, sharding both inputs explicitly:

In [37]:
f = pjit(jax.vmap(dot),
         in_shardings=(PartitionSpec('devices'), PartitionSpec('devices')),
         out_shardings=PartitionSpec('devices')
         )

In [38]:
with Mesh(devices, ('devices',)):
  x_pjit=f(v1s,v2s)

In [39]:
x_pjit.shape

(10000000,)

### 2D mesh case

In [40]:
from jax.sharding import PartitionSpec as P   # could be useful to reduce typing
from jax.sharding import Mesh
import numpy as np

In [41]:
rng_key = random.PRNGKey(42)

In [42]:
vs = random.normal(rng_key, shape=(8_000,10_000))
v1s = vs[:4_000,:]
v2s = vs[4_000:,:]

v1s.shape, v2s.shape

((4000, 10000), (4000, 10000))

In [43]:
devices = np.array(jax.devices()).reshape(2, 4)
devices

array([[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2),
        CpuDevice(id=3)],
       [CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6),
        CpuDevice(id=7)]], dtype=object)

In [44]:
def dot(v1, v2):
  return jnp.vdot(v1, v2)

In [54]:
f = pjit(jax.vmap(dot),
         in_shardings=(P('x', 'y'), P('x', 'y')),
         out_shardings=P('x')
         )

In [46]:
with Mesh(devices, ('x','y')):
  x_pjit=f(v1s,v2s)

In [47]:
x_pjit.shape

(4000,)
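With mesh shape `(2, 4)` and `P('x', 'y')`, the first dimension is split over the 2 devices along `'x'` and the second over the 4 devices along `'y'`, so each device holds one block of size (rows/2, cols/4). For the `(4000, 10000)` inputs above that is `(2000, 2500)`, which matches the per-device parameter shapes in the compiled HLO further down. A quick check on a smaller hypothetical array:

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices("cpu")[:8]).reshape(2, 4)
mesh = Mesh(devices, ("x", "y"))

v = jnp.zeros((40, 100))
v_sharded = jax.device_put(v, NamedSharding(mesh, P("x", "y")))

# Every device gets a (40 / 2, 100 / 4) block.
print({s.data.shape for s in v_sharded.addressable_shards})  # {(20, 25)}
```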

### Looking into HLO

Collective ops are not visible at the jaxpr level:

In [48]:
with Mesh(devices, ('x','y')):
  print(jax.make_jaxpr(f)(v1s,v2s))

{ lambda ; a:f32[4000,10000] b:f32[4000,10000]. let
    c:f32[4000] = pjit[
      name=dot
      in_shardings=(GSPMDSharding({devices=[2,4]<=[8]}), GSPMDSharding({devices=[2,4]<=[8]}))
      jaxpr={ lambda ; d:f32[4000,10000] e:f32[4000,10000]. let
          f:f32[4000] = dot_general[
            dimension_numbers=(([1], [1]), ([0], [0]))
            preferred_element_type=float32
          ] d e
        in (f,) }
      out_shardings=(GSPMDSharding({devices=[2,4]<=[8] last_tile_dim_replicate}),)
      resource_env=ResourceEnv(mesh=Mesh('x': 2, 'y': 4), ())
    ] a b
  in (c,) }


But they do appear after compilation, as the `all-reduce` in the compiled HLO shows:

In [65]:
with Mesh(devices, ('x','y')):
   print(f.lower(v1s, v2s).compile().as_text())

HloModule pjit_dot, entry_computation_layout={(f32[2000,2500]{1,0}, f32[2000,2500]{1,0})->f32[2000]{0}}, num_partitions=8

%add.clone (x.1: f32[], y.1: f32[]) -> f32[] {
  %x.1 = f32[] parameter(0)
  %y.1 = f32[] parameter(1)
  ROOT %add.1 = f32[] add(f32[] %x.1, f32[] %y.1)
}

ENTRY %main.6_spmd (param: f32[2000,2500], param.1: f32[2000,2500]) -> f32[2000] {
  %param = f32[2000,2500]{1,0} parameter(0), sharding={devices=[2,4]<=[8]}
  %param.1 = f32[2000,2500]{1,0} parameter(1), sharding={devices=[2,4]<=[8]}
  %dot = f32[2000]{0} dot(f32[2000,2500]{1,0} %param, f32[2000,2500]{1,0} %param.1), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}, metadata={op_name="pjit(dot)/jit(main)/dot_general[dimension_numbers=(((1,), (1,)), ((0,), (0,))) precision=None preferred_element_type=float32]" source_file="<ipython-input-44-ccdf370c8172>" source_line=2}
  ROOT %all-reduce = f32[2000]{0} all-reduce(f32[2000]{0} %dot), channel_id=1, replica_groups={{0,1,2,
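The same observation can be made programmatically by searching the two text forms for collective op names. A sketch on smaller hypothetical inputs (using `jax.jit` with `NamedSharding` rather than `pjit`). Because the contraction runs over the `'y'`-sharded dimension while the output is replicated along `'y'`, per-device partial sums must be combined. The test below deliberately accepts any of several collectives, since the exact op XLA chooses is not guaranteed here:

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def dot(v1, v2):
    return jnp.vdot(v1, v2)

mesh = Mesh(np.array(jax.devices("cpu")[:8]).reshape(2, 4), ("x", "y"))
block = NamedSharding(mesh, P("x", "y"))
f = jax.jit(jax.vmap(dot), in_shardings=(block, block),
            out_shardings=NamedSharding(mesh, P("x")))

v1s = jnp.ones((40, 100))
v2s = jnp.ones((40, 100))

jaxpr_text = str(jax.make_jaxpr(f)(v1s, v2s))     # high-level program, no collectives
hlo_text = f.lower(v1s, v2s).compile().as_text()  # partitioned, optimized HLO

print("all-reduce" in jaxpr_text, "all-reduce" in hlo_text)
```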

In [66]:
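# (doesn't work anymore: `Compiled` objects no longer have `compiler_ir()`; use `.compile().as_text()` as in the previous cell)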
# Thanks to https://github.com/google/jax/discussions/11275
with Mesh(devices, ('x','y')):
    modules = f.lower(v1s, v2s).compile().compiler_ir()
    for hlo in modules:
        print(hlo.to_string())

AttributeError: 'Compiled' object has no attribute 'compiler_ir'

## MLP example

### Preparing data

Install these packages if you are starting from a fresh, empty cloud machine:

In [67]:
!pip install tensorflow



In [68]:
!pip install tensorflow_datasets



In [69]:
import jax
import tensorflow as tf
import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

data, info = tfds.load(name="mnist",
                       data_dir=data_dir,
                       as_supervised=True,
                       with_info=True)

data_train = data['train']
data_test  = data['test']

Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /tmp/tfds/mnist/3.0.1...



Dataset mnist downloaded and prepared to /tmp/tfds/mnist/3.0.1. Subsequent calls will reuse this data.


In [70]:
HEIGHT = 28
WIDTH  = 28
CHANNELS = 1
NUM_PIXELS = HEIGHT * WIDTH * CHANNELS
NUM_LABELS = info.features['label'].num_classes
NUM_DEVICES = jax.device_count()
BATCH_SIZE  = 32

In [71]:
def preprocess(img, label):
  """Convert images to float32 and scale pixel values to [0, 1]."""
  return (tf.cast(img, tf.float32)/255.0), label

train_data = tfds.as_numpy(
    data_train.map(preprocess).batch(NUM_DEVICES*BATCH_SIZE).prefetch(1)
)
test_data  = tfds.as_numpy(
    data_test.map(preprocess).batch(NUM_DEVICES*BATCH_SIZE).prefetch(1)
)
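Each global batch holds `NUM_DEVICES * BATCH_SIZE` examples, and `P('x')` later splits it along the batch dimension. Inputs sharded along a mesh axis are generally expected to divide evenly by that axis's size, so it is worth checking the final, partial batch as well. A plain-Python check, assuming the standard 60,000/10,000 MNIST train/test split and 8 devices:

```python
NUM_DEVICES = 8
BATCH_SIZE = 32
global_batch = NUM_DEVICES * BATCH_SIZE  # 256 examples per step

results = {}
for split, num_examples in (("train", 60_000), ("test", 10_000)):
    num_batches = -(-num_examples // global_batch)                # ceiling division
    last_batch = num_examples - (num_batches - 1) * global_batch  # size of the final batch
    results[split] = (num_batches, last_batch)
    print(split, num_batches, last_batch, last_batch % NUM_DEVICES == 0)
```

The 235 training batches agree with `len(train_data)` above, and both final batches (96 and 16 examples) split evenly over 8 devices.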

In [72]:
len(train_data)

235

### Preparing MLP

In [73]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, value_and_grad
from jax import random
from jax.nn import swish, logsumexp, one_hot

In [74]:
LAYER_SIZES = [28*28, 512, 10]
PARAM_SCALE = 0.01

In [75]:
def init_network_params(sizes, key=random.PRNGKey(0), scale=1e-2):
  """Initialize all layers for a fully-connected neural network with given sizes"""

  def random_layer_params(m, n, key, scale=1e-2):
    """A helper function to randomly initialize weights and biases of a dense layer"""
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k, scale) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

init_params = init_network_params(LAYER_SIZES, random.PRNGKey(0), scale=PARAM_SCALE)

In [76]:
def predict(params, image):
  """Function for per-example predictions."""
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = swish(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits

batched_predict = vmap(predict, in_axes=(None, 0))

### Loss and update functions

In [77]:
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh
import numpy as np

In [78]:
INIT_LR = 1.0
DECAY_RATE = 0.95
DECAY_STEPS = 5
NUM_EPOCHS  = 20

In [79]:
def loss(params, images, targets):
  """Categorical cross entropy loss function."""
  logits = batched_predict(params, images)
  # logsumexp trick https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
  # Without `axis`, logsumexp normalizes over the whole batch rather than per example
  log_preds = logits - logsumexp(logits)
  return -jnp.mean(targets*log_preds)

def update(params, x, y, epoch_number):
  loss_value, grads = value_and_grad(loss)(params, x, y)
  lr = INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)
  return [(w - lr * dw, b - lr * db)
          for (w, b), (dw, db) in zip(params, grads)], loss_value
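`logsumexp(logits)` in `loss` is called without `axis`, so on `(batch, 10)` logits it normalizes over the entire batch rather than over the 10 classes of each example. The comparison below, on a hypothetical batch of 4 random logit vectors, shows the difference; per-example normalization matches `jax.nn.log_softmax(logits, axis=1)`:

```python
import jax
import jax.numpy as jnp
from jax.nn import logsumexp, log_softmax

logits = jax.random.normal(jax.random.PRNGKey(0), (4, 10))  # hypothetical batch

per_example = logits - logsumexp(logits, axis=1, keepdims=True)
whole_batch = logits - logsumexp(logits)  # what `loss` above computes

# Per-example: each row is a log-distribution over the 10 classes.
print(jnp.exp(per_example).sum(axis=1))
# Whole-batch: probabilities sum to 1 over all 4 * 10 entries instead.
print(jnp.exp(whole_batch).sum())
print(bool(jnp.allclose(per_example, log_softmax(logits, axis=1), atol=1e-5)))
```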

In [80]:
f_update = pjit(update,
         in_shardings=(None, P('x'), P('x'), None),
         out_shardings=None
         )
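`f_update` is plain data parallelism: parameters (and the epoch number) are replicated, while images and labels are split along the batch dimension, so XLA has to combine per-device gradient contributions. A minimal sketch that checks a sharded update matches an unsharded one. It uses a hypothetical linear least-squares model instead of the MLP, and `jax.jit` with `NamedSharding` in place of `pjit`:

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def loss(w, x, y):
    # Hypothetical linear least-squares model.
    return jnp.mean((x @ w - y) ** 2)

def update(w, x, y, lr):
    loss_value, grad_w = jax.value_and_grad(loss)(w, x, y)
    return w - lr * grad_w, loss_value

mesh = Mesh(np.array(jax.devices("cpu")), ("x",))
replicated = NamedSharding(mesh, P())
batch_split = NamedSharding(mesh, P("x"))

# Same layout as f_update: params and lr replicated, x and y split along the batch axis.
dp_update = jax.jit(update,
                    in_shardings=(replicated, batch_split, batch_split, replicated),
                    out_shardings=(replicated, replicated))
ref_update = jax.jit(update)  # no sharding annotations

x = jax.random.normal(jax.random.PRNGKey(0), (64, 5))
y = x @ jnp.arange(5.0)
w = jnp.zeros(5)
lr = jnp.array(0.1)

w_dp, l_dp = dp_update(w, x, y, lr)
w_ref, l_ref = ref_update(w, x, y, lr)
print(bool(jnp.allclose(w_dp, w_ref, rtol=1e-5, atol=1e-4)),
      bool(jnp.allclose(l_dp, l_ref, rtol=1e-5, atol=1e-4)))
```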

### Training loop

In [81]:
devices = np.array(jax.devices())
devices

array([CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3),
       CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)],
      dtype=object)

In [83]:
def batch_accuracy(params, images, targets):
  images = jnp.reshape(images, (len(images), NUM_PIXELS))
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == targets)

f_batch_accuracy = pjit(batch_accuracy,
         in_shardings=(None, P('x'), P('x')),
         out_shardings=None
         )

def accuracy(params, data):
  accs = []
  for images, targets in data:
    accs.append(f_batch_accuracy(params, images, targets))
  return jnp.mean(jnp.array(accs))

In [84]:
import time

params = init_params
with Mesh(devices, ('x',)):
  for epoch in range(NUM_EPOCHS):
    start_time = time.time()
    losses = []
    for x, y in train_data:
      x = jnp.reshape(x, (len(x), NUM_PIXELS))
      y = one_hot(y, NUM_LABELS)
      params, loss_value = f_update(params, x, y, epoch)
      losses.append(jnp.sum(loss_value))
    epoch_time = time.time() - start_time

    train_acc = accuracy(params, train_data)
    test_acc = accuracy(params, test_data)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set loss {}".format(jnp.mean(jnp.array(losses))))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))

Epoch 0 in 10.87 sec
Training set loss 0.691484808921814
Training set accuracy 0.8647883534431458
Test set accuracy 0.875
Epoch 1 in 7.85 sec
Training set loss 0.6260023713111877
Training set accuracy 0.8881925940513611
Test set accuracy 0.8935546875
Epoch 2 in 7.83 sec
Training set loss 0.6165075302124023
Training set accuracy 0.8978723287582397
Test set accuracy 0.9009765982627869
Epoch 3 in 7.42 sec
Training set loss 0.6100791692733765
Training set accuracy 0.9065104126930237
Test set accuracy 0.9090820550918579
Epoch 4 in 7.19 sec
Training set loss 0.6044540405273438
Training set accuracy 0.9142231941223145
Test set accuracy 0.917285144329071
Epoch 5 in 7.66 sec
Training set loss 0.5995851755142212
Training set accuracy 0.9210605025291443
Test set accuracy 0.924609363079071
Epoch 6 in 6.53 sec
Training set loss 0.5956109166145325
Training set accuracy 0.9260915517807007
Test set accuracy 0.9291015863418579
Epoch 7 in 6.44 sec
Training set loss 0.5924531817436218
Training set accura

Some resources that might help you shard the weights as well:
https://github.com/google/jax/discussions/8649

**!!! Do not forget to shut down your Cloud TPU, or you'll spend a lot of money on it !!!**