## ⚖️ Choose A or C:

## A: Emulating multi-device system on CPU

Use this section to initialize a set of virtual devices on CPU if you have no access to a multi-device system.

It also lets you prototype, debug, and test your multi-device code locally before running it on an expensive multi-device system.

Even on Google Colab it helps you prototype faster, because a CPU runtime is quicker to restart.

In [None]:
import os
# Must be set before JAX is imported: XLA then exposes 8 virtual CPU devices.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

In [None]:
import jax
import jax.numpy as jnp

In [None]:
jax.devices()
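The same setup as a self-contained sketch. It assumes a fresh Python process in which JAX has not been imported yet, because the flag is read when JAX initializes its backends. Setting `JAX_PLATFORMS=cpu` is an addition that is not in the original notebook. It keeps JAX on the emulated CPU devices even if a GPU or TPU is also present.

```python
import os

# Both variables are read when JAX initializes its backends, so set them first.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
os.environ["JAX_PLATFORMS"] = "cpu"  # assumption: force the CPU backend

import jax
import jax.numpy as jnp

devices = jax.devices()
print(len(devices), devices[0].platform)

# pmap sends one element of the leading axis to each virtual device.
x = jnp.arange(8.0)
y = jax.pmap(lambda v: v * 2)(x)
print(y)
```

`jax.pmap` maps the leading axis of `x` (size 8) over the 8 virtual devices, so each device doubles exactly one element.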

## (do not use) B: Setting up Colab TPU

#### [This section sets up a Colab TPU runtime, which no longer works with recent JAX versions. Do not use the Colab TPU runtime with this notebook.]

Use this section if you want to use a Colab TPU. Don't forget to change the runtime type via "Runtime" -> "Change runtime type" -> "TPU".

In [None]:
# in order to use TPU you have to run this cell before importing JAX
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

In [2]:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

tpu


In [3]:
import jax
import jax.numpy as jnp

In [4]:
jax.local_devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

## C: Using Cloud TPU and a Local runtime in Colab

Prepare your environment by following these guides:

* Creating a Cloud TPU: https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#tpu-vms

* Setting up Jupyter and connecting to a local runtime: https://research.google.com/colaboratory/local-runtimes.html


In [5]:
!pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html



Looking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html


In [6]:
import jax
import jax.numpy as jnp

In [7]:
jax.local_devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

## Using pjit()

In [8]:
jax.__version__

'0.4.31'

In [9]:
from jax.experimental.pjit import pjit
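In recent JAX releases `pjit` has been merged into `jax.jit`, which accepts the same `in_shardings` and `out_shardings` arguments. The sketch below mirrors the replicated-input `pjit` cell further down, using explicit `NamedSharding` objects. It assumes the 8 emulated CPU devices from option A and uses small illustrative arrays. Each `NamedSharding` carries its own mesh, so no `with Mesh(...)` block is needed.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def dot(v1, v2):
    return jnp.vdot(v1, v2)

# 1D mesh over all devices; the axis name 'devices' matches the notebook.
mesh = Mesh(np.array(jax.devices()), ("devices",))
replicated = NamedSharding(mesh, PartitionSpec())            # full copy on every device
rows_split = NamedSharding(mesh, PartitionSpec("devices"))   # axis 0 split over devices

# A single sharding is a pytree prefix, so it applies to both arguments.
f = jax.jit(jax.vmap(dot), in_shardings=replicated, out_shardings=rows_split)

v1s = jnp.ones((16, 3))
v2s = jnp.arange(48.0).reshape(16, 3)
out = f(v1s, v2s)

print(out.shape, len(out.addressable_shards))
```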

In [10]:
from jax import random

### Old vmap+pmap example

In [20]:
def dot(v1, v2):
  return jnp.vdot(v1, v2)

In [21]:
rng_key = random.PRNGKey(42)

In [22]:
vs = random.normal(rng_key, shape=(20_000_000,3))
v1s = vs[:10_000_000,:]
v2s = vs[10_000_000:,:]

v1s.shape, v2s.shape

((10000000, 3), (10000000, 3))

In [23]:
v1sp = v1s.reshape((4, v1s.shape[0]//4, v1s.shape[1]))
v2sp = v2s.reshape((4, v2s.shape[0]//4, v2s.shape[1]))

v1sp.shape, v2sp.shape

((4, 2500000, 3), (4, 2500000, 3))
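The `4` in the reshape above is hardcoded because the recorded run used a 4-core TPU. On the 8 emulated CPU devices of option A, `pmap` would then use only 4 of them. A minimal sketch of a device-count-agnostic version follows. `split_leading_axis` is a hypothetical helper that is not part of JAX, and it assumes the leading axis divides evenly.

```python
import numpy as np

def split_leading_axis(x, num_devices):
    """Hypothetical helper: reshape (N, ...) to (num_devices, N // num_devices, ...)."""
    n = x.shape[0]
    if n % num_devices != 0:
        raise ValueError(f"leading axis of size {n} is not divisible by {num_devices}")
    return x.reshape((num_devices, n // num_devices) + x.shape[1:])

v = np.zeros((1_000, 3), dtype=np.float32)
print(split_leading_axis(v, 4).shape)   # 4 TPU cores, as in the recorded run
print(split_leading_axis(v, 8).shape)   # 8 emulated CPU devices (option A)
```

In the notebook you would call it as `split_leading_axis(v1s, jax.local_device_count())`.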

In [24]:
dot_parallel = jax.pmap(
    jax.vmap(dot, in_axes=(0,0)),
    in_axes=(0,0)
)

In [25]:
x_pmap = dot_parallel(v1sp,v2sp)

In [26]:
x_pmap.shape

(4, 2500000)

In [27]:
x_pmap = x_pmap.reshape((x_pmap.shape[0]*x_pmap.shape[1]))
x_pmap.shape

(10000000,)
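A small sanity check of the `pmap(vmap(...))` pattern. It is a sketch assuming 8 emulated CPU devices and a reduced problem size (1,000 vector pairs instead of 10 million), and it compares the parallel result with a plain single-device computation of the row-wise dot products.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp
from jax import random

def dot(v1, v2):
    return jnp.vdot(v1, v2)

n_dev = jax.local_device_count()
key1, key2 = random.split(random.PRNGKey(42))
v1s = random.normal(key1, (1_000, 3))
v2s = random.normal(key2, (1_000, 3))

def to_devices(x):
    # (N, 3) -> (n_dev, N // n_dev, 3): pmap maps axis 0 over devices,
    # vmap then maps over the rows held by each device.
    return x.reshape((n_dev, x.shape[0] // n_dev) + x.shape[1:])

x_pmap = jax.pmap(jax.vmap(dot))(to_devices(v1s), to_devices(v2s))
x_pmap = x_pmap.reshape(-1)                # back to (N,)

reference = jnp.sum(v1s * v2s, axis=1)     # plain single-device computation
print(x_pmap.shape, jnp.allclose(x_pmap, reference, atol=1e-5))
```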

### Replacing with pjit and 1D mesh

In [28]:
from jax.sharding import Mesh
from jax.sharding import PartitionSpec

In [29]:
import numpy as np

In [30]:
devices = np.array(jax.devices())
devices

array([TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
       TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
       TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
       TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)],
      dtype=object)
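To see what a 1D mesh does to data, you can place an array on it with `jax.device_put` and inspect the pieces each device holds. This is a sketch assuming 8 emulated CPU devices and a small illustrative array.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("devices",))

x = jnp.arange(32.0).reshape(16, 2)
x_sharded = jax.device_put(x, NamedSharding(mesh, PartitionSpec("devices")))

# Each device holds a contiguous block of rows of x.
for shard in x_sharded.addressable_shards:
    print(shard.device.id, shard.index, shard.data.shape)
```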

A non-vectorized function won't work here: `dot` returns a scalar (rank 0), and output partitioning only works for arrays of rank 1 or higher.

In [31]:
f = pjit(dot,
         in_shardings =None,
         out_shardings =PartitionSpec('devices')
         )

In [32]:
with Mesh(devices, ('devices',)):
  x_pjit=f(v1s,v2s)

ValueError: One of pjit outputs is incompatible with its sharding annotation NamedSharding(mesh=Mesh('devices': 4), spec=PartitionSpec('devices',)): Sharding NamedSharding(mesh=Mesh('devices': 4), spec=PartitionSpec('devices',)) is only valid for values of rank at least 1, but was applied to a value of rank 0. For scalars the PartitionSpec should be P()
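The reason for the rank-0 output is that `jnp.vdot` flattens its inputs. Calling `dot` on two (N, 3) arrays therefore yields a single scalar rather than N dot products. The shape check below uses small illustrative arrays and `jax.eval_shape`, which traces the function without computing anything.

```python
import jax
import jax.numpy as jnp

def dot(v1, v2):
    return jnp.vdot(v1, v2)

v1s = jnp.ones((6, 3))
v2s = jnp.ones((6, 3))

plain = jax.eval_shape(dot, v1s, v2s)               # vdot flattens both (6, 3) inputs
batched = jax.eval_shape(jax.vmap(dot), v1s, v2s)   # one dot product per row
print(plain.shape, batched.shape)
```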

With a vectorized function the output is a rank-1 array, one dot product per row. Here the inputs are replicated on all devices (`in_shardings=None`).

In [33]:
f = pjit(jax.vmap(dot),
         in_shardings=None,
         out_shardings=PartitionSpec('devices')
         )

In [34]:
with Mesh(devices, ('devices',)):
  x_pjit=f(v1s,v2s)

In [35]:
x_pjit.shape

(10000000,)

In [36]:
jax.numpy.all(x_pjit == x_pmap)

Array(True, dtype=bool)

Now the inputs are sharded across all devices along their first axis. A single `PartitionSpec` applies to every positional argument.

In [37]:
f = pjit(jax.vmap(dot),
         in_shardings=PartitionSpec('devices'),
         out_shardings=PartitionSpec('devices')
         )

In [38]:
with Mesh(devices, ('devices',)):
  x_pjit=f(v1s,v2s)

In [39]:
x_pjit.shape

(10000000,)

In [40]:
jax.numpy.all(x_pjit == x_pmap)

Array(True, dtype=bool)

Sharding only the first argument and replicating the second (`None`) works the same way:

In [41]:
f = pjit(jax.vmap(dot),
         in_shardings=(PartitionSpec('devices'), None),
         out_shardings=PartitionSpec('devices')
         )

In [42]:
with Mesh(devices, ('devices',)):
  x_pjit=f(v1s,v2s)

In [43]:
x_pjit.shape

(10000000,)

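Also the same: `PartitionSpec('devices', None)` splits the first axis across devices and leaves the second axis unpartitioned, which is equivalent to `PartitionSpec('devices')`. As a single spec, it again applies to both arguments.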

In [44]:
f = pjit(jax.vmap(dot),
         in_shardings=PartitionSpec('devices', None),
         out_shardings=PartitionSpec('devices')
         )

In [45]:
with Mesh(devices, ('devices',)):
  x_pjit=f(v1s,v2s)

In [46]:
x_pjit.shape

(10000000,)
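One way to convince yourself that `PartitionSpec('devices', None)` and `PartitionSpec('devices')` are equivalent is to place the same array with both specs and compare which slice each device receives. This is a sketch assuming 8 emulated CPU devices and a small illustrative array.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("devices",))
x = jnp.zeros((16, 3))

a = jax.device_put(x, NamedSharding(mesh, PartitionSpec("devices")))
b = jax.device_put(x, NamedSharding(mesh, PartitionSpec("devices", None)))

# Map each device id to the slice of x that this device holds.
slices_a = {s.device.id: s.index for s in a.addressable_shards}
slices_b = {s.device.id: s.index for s in b.addressable_shards}
print(slices_a == slices_b)
```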

Passing more entries in `in_shardings` than the function has positional arguments fails:

In [47]:
f = pjit(jax.vmap(dot),
         in_shardings=(PartitionSpec('devices'), None, None),
         out_shardings=PartitionSpec('devices')
         )

In [48]:
with Mesh(devices, ('devices',)):
  x_pjit=f(v1s,v2s)

ValueError: pjit in_shardings specification must be a tree prefix of the positional arguments tuple passed to the `pjit`-decorated function. In particular, pjit in_shardings must either be a None, a PartitionSpec, or a tuple of length equal to the number of positional arguments. But pjit in_shardings is the wrong length: got a tuple or list of length 3 for an args tuple of length 2.

Finally, sharding both arguments explicitly:

In [49]:
f = pjit(jax.vmap(dot),
         in_shardings=(PartitionSpec('devices'), PartitionSpec('devices')),
         out_shardings=PartitionSpec('devices')
         )

In [50]:
with Mesh(devices, ('devices',)):
  x_pjit=f(v1s,v2s)

In [51]:
x_pjit.shape

(10000000,)

### 2D mesh case

In [52]:
from jax.sharding import PartitionSpec as P   # could be useful to reduce typing
from jax.sharding import Mesh
import numpy as np

In [53]:
rng_key = random.PRNGKey(42)

In [54]:
vs = random.normal(rng_key, shape=(8_000,10_000))
v1s = vs[:4_000,:]
v2s = vs[4_000:,:]

v1s.shape, v2s.shape

((4000, 10000), (4000, 10000))

In [55]:
devices = np.array(jax.devices()).reshape(2, 2)
devices

array([[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
        TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0)],
       [TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
        TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]],
      dtype=object)
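With `P('x', 'y')` on a 2x2 mesh, each device gets one block: rows are split over `'x'` and columns over `'y'`. For the (4000, 10000) inputs this means a (2000, 5000) block per TPU core, which matches the parameter shapes in the HLO further down. The sketch below uses a scaled-down array and assumes 8 emulated CPU devices, of which the first 4 form the mesh.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# 2x2 mesh built from the first 4 devices.
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ("x", "y"))

v = jnp.zeros((8, 12))
v_sharded = jax.device_put(v, NamedSharding(mesh, PartitionSpec("x", "y")))
for shard in v_sharded.addressable_shards:
    print(shard.device.id, shard.index, shard.data.shape)
```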

In [56]:
def dot(v1, v2):
  return jnp.vdot(v1, v2)

In [57]:
f = pjit(jax.vmap(dot),
         in_shardings=(P('x', 'y'), P('x', 'y')),
         out_shardings=P('x')
         )

In [58]:
with Mesh(devices, ('x','y')):
  x_pjit=f(v1s,v2s)

In [59]:
x_pjit.shape

(4000,)
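The contraction axis is split over `'y'`, so each device only sees half of every row. The partial sums then have to be combined across `'y'` to produce the `P('x')` output. The sketch below checks the result against a plain computation and shows that every device ends up with a full `'x'` block of the output. It assumes 8 emulated CPU devices, a 2x2 mesh from the first 4 of them, and small random inputs.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def dot(v1, v2):
    return jnp.vdot(v1, v2)

mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ("x", "y"))
in_spec = NamedSharding(mesh, P("x", "y"))
out_spec = NamedSharding(mesh, P("x"))

f = jax.jit(jax.vmap(dot), in_shardings=(in_spec, in_spec), out_shardings=out_spec)

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
v1s = jax.random.normal(key1, (8, 12))
v2s = jax.random.normal(key2, (8, 12))
out = f(v1s, v2s)

reference = jnp.sum(v1s * v2s, axis=1)
print(out.shape, jnp.allclose(out, reference, atol=1e-5))
# Output is split over 'x' and replicated over 'y': 4 shards of 4 values each.
for shard in out.addressable_shards:
    print(shard.device.id, shard.index, shard.data.shape)
```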

### Looking into HLO

Collective operations are not visible at the jaxpr level:

In [60]:
with Mesh(devices, ('x','y')):
  print(jax.make_jaxpr(f)(v1s,v2s))

{ lambda ; a:f32[4000,10000] b:f32[4000,10000]. let
    c:f32[4000] = pjit[
      name=dot
      in_shardings=(NamedSharding(mesh=Mesh('x': 2, 'y': 2), spec=PartitionSpec('x', 'y')), NamedSharding(mesh=Mesh('x': 2, 'y': 2), spec=PartitionSpec('x', 'y')))
      jaxpr={ lambda ; d:f32[4000,10000] e:f32[4000,10000]. let
          f:f32[4000] = dot_general[
            dimension_numbers=(([1], [1]), ([0], [0]))
            preferred_element_type=float32
          ] d e
        in (f,) }
      out_shardings=(NamedSharding(mesh=Mesh('x': 2, 'y': 2), spec=PartitionSpec('x',)),)
      resource_env=ResourceEnv(mesh=Mesh('x': 2, 'y': 2))
    ] a b
  in (c,) }


They only show up in the compiled HLO (the recorded output below is truncated):

In [61]:
with Mesh(devices, ('x','y')):
   print(f.lower(v1s, v2s).compile().as_text())

HloModule pjit_dot, is_scheduled=true, entry_computation_layout={(f32[2000,5000]{1,0:T(8,128)}, f32[2000,5000]{1,0:T(8,128)})->f32[2000]{0:T(1024)}}, allow_spmd_sharding_propagation_to_parameters={false,false}, num_partitions=4

%scalar_add_computation.clone (scalar_lhs.1: f32[], scalar_rhs.1: f32[]) -> f32[] {
  %scalar_rhs.1 = f32[]{:T(128)} parameter(1)
  %scalar_lhs.1 = f32[]{:T(128)} parameter(0)
  ROOT %add.1 = f32[]{:T(128)} add(f32[]{:T(128)} %scalar_lhs.1, f32[]{:T(128)} %scalar_rhs.1)
}

%scalar_add_computation (scalar_lhs: f32[], scalar_rhs: f32[]) -> f32[] {
  %scalar_rhs = f32[]{:T(128)} parameter(1)
  %scalar_lhs = f32[]{:T(128)} parameter(0)
  ROOT %add = f32[]{:T(128)} add(f32[]{:T(128)} %scalar_lhs, f32[]{:T(128)} %scalar_rhs)
}

%fused_computation (param_0.2: f32[2000,5000], param_1.2: f32[2000,5000]) -> f32[2000] {
  %param_0.2 = f32[2000,5000]{1,0:T(8,128)} parameter(0)
  %param_1.2 = f32[2000,5000]{1,0:T(8,128)} parameter(1)
  %multiply.2 = f32[2000,5000]{1,0:T(8,1
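The recorded HLO above is cut off before any collectives appear. The sketch below reproduces the jaxpr-versus-HLO comparison on emulated CPU devices (assumed: 8 devices, with a 2x2 mesh from the first 4) and searches both texts for common collective op names. We would expect a cross-device reduction such as an all-reduce over `'y'`, but the exact ops XLA emits may differ between backends and versions.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def dot(v1, v2):
    return jnp.vdot(v1, v2)

mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ("x", "y"))
in_spec = NamedSharding(mesh, P("x", "y"))
f = jax.jit(jax.vmap(dot), in_shardings=(in_spec, in_spec),
            out_shardings=NamedSharding(mesh, P("x")))

v1s = jnp.ones((8, 12))
v2s = jnp.ones((8, 12))

jaxpr_text = str(jax.make_jaxpr(f)(v1s, v2s))
hlo_text = f.lower(v1s, v2s).compile().as_text()

collectives = ("all-reduce", "all-gather", "reduce-scatter", "all-to-all", "collective-permute")
print("in jaxpr:", [op for op in collectives if op in jaxpr_text])
print("in compiled HLO:", [op for op in collectives if op in hlo_text])
```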

In [63]:
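# Adapted from https://github.com/google/jax/discussions/11275.
# The original loop over separate HLO modules (commented out below) doesn't work anymore:
# as_text() returns a single string, so this prints the same HLO as the cell above.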
with Mesh(devices, ('x','y')):
    modules = f.lower(v1s, v2s).compile().as_text()
    print(modules)
    #for hlo in modules:
    #    print(hlo.to_string())


## MLP example

### Preparing data

Install these packages if you are working on a freshly created cloud machine:

In [None]:
!pip install tensorflow

In [None]:
!pip install tensorflow_datasets

In [64]:
import jax
import tensorflow as tf
import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

data, info = tfds.load(name="mnist",
                       data_dir=data_dir,
                       as_supervised=True,
                       with_info=True)

data_train = data['train']
data_test  = data['test']



In [65]:
HEIGHT = 28
WIDTH  = 28
CHANNELS = 1
NUM_PIXELS = HEIGHT * WIDTH * CHANNELS
NUM_LABELS = info.features['label'].num_classes
NUM_DEVICES = jax.device_count()
BATCH_SIZE  = 32

In [66]:
def preprocess(img, label):
  """Scale pixel values to float32 in [0, 1]; labels are passed through unchanged."""
  return (tf.cast(img, tf.float32)/255.0), label

train_data = tfds.as_numpy(
    data_train.map(preprocess).batch(NUM_DEVICES*BATCH_SIZE).prefetch(1)
)
test_data  = tfds.as_numpy(
    data_test.map(preprocess).batch(NUM_DEVICES*BATCH_SIZE).prefetch(1)
)

In [67]:
len(train_data)

469
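Where 469 comes from: `tf.data`'s `batch()` keeps the final partial batch by default (`drop_remainder=False`), so the count is rounded up. The sketch below assumes the 4 devices of the recorded run and MNIST's 60,000 training images. As far as we know, `jit`/`pjit` require a dimension sharded with `P('x')` to be divisible by the mesh size, so the size of the last batch matters too.

```python
import math

NUM_TRAIN = 60_000   # MNIST training set size
NUM_DEVICES = 4      # assumption: the 4 TPU cores of the recorded run
BATCH_SIZE = 32      # per-device batch size, as in the notebook

global_batch = NUM_DEVICES * BATCH_SIZE
num_batches = math.ceil(NUM_TRAIN / global_batch)
last_batch = NUM_TRAIN - (num_batches - 1) * global_batch

# The last batch must still split evenly across devices for P('x').
print(global_batch, num_batches, last_batch, last_batch % NUM_DEVICES == 0)
```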

### Preparing MLP

In [68]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, value_and_grad
from jax import random
from jax.nn import swish, logsumexp, one_hot

In [69]:
LAYER_SIZES = [28*28, 512, 10]
PARAM_SCALE = 0.01

In [70]:
def init_network_params(sizes, key=random.PRNGKey(0), scale=1e-2):
  """Initialize all layers for a fully-connected neural network with given sizes"""

  def random_layer_params(m, n, key, scale=1e-2):
    """A helper function to randomly initialize weights and biases of a dense layer"""
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k, scale) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

init_params = init_network_params(LAYER_SIZES, random.PRNGKey(0), scale=PARAM_SCALE)
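The parameters are a list of `(w, b)` tuples, with each weight stored as (out_features, in_features) so that `predict` can compute `jnp.dot(w, activations)`. Because `f_update` later uses `in_shardings=None` for the parameters, every device holds a full copy of them. The shape check below repeats a condensed version of the initializer so that it runs on its own.

```python
import jax
from jax import random

# Condensed copy of init_network_params above, so this block is self-contained.
def init_network_params(sizes, key, scale=1e-2):
    keys = random.split(key, len(sizes))
    params = []
    for m, n, k in zip(sizes[:-1], sizes[1:], keys):
        w_key, b_key = random.split(k)
        params.append((scale * random.normal(w_key, (n, m)),
                       scale * random.normal(b_key, (n,))))
    return params

params = init_network_params([28 * 28, 512, 10], random.PRNGKey(0))
shapes = jax.tree_util.tree_map(lambda a: a.shape, params)
n_params = sum(a.size for a in jax.tree_util.tree_leaves(params))
print(shapes, n_params)
```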

In [71]:
def predict(params, image):
  """Function for per-example predictions."""
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = swish(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits

batched_predict = vmap(predict, in_axes=(None, 0))

### Loss and update functions

In [72]:
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh
import numpy as np

In [73]:
INIT_LR = 1.0
DECAY_RATE = 0.95
DECAY_STEPS = 5
NUM_EPOCHS  = 20

In [74]:
def loss(params, images, targets):
  """Categorical cross entropy loss function."""
  logits = batched_predict(params, images)
  log_preds = logits - jnp.expand_dims(logsumexp(logits, axis=1), 1)  # logsumexp trick https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
  return -jnp.mean(targets*log_preds)

def update(params, x, y, epoch_number):
  loss_value, grads = value_and_grad(loss)(params, x, y)
  lr = INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)
  return [(w - lr * dw, b - lr * db)
          for (w, b), (dw, db) in zip(params, grads)], loss_value
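The learning rate in `update` decays exponentially: it shrinks by a factor of `DECAY_RATE` every `DECAY_STEPS` epochs. Because `epoch_number / DECAY_STEPS` is a true division, the decay is smooth from epoch to epoch rather than stepwise. The standalone sketch below uses the same constants as the notebook.

```python
INIT_LR = 1.0
DECAY_RATE = 0.95
DECAY_STEPS = 5

def learning_rate(epoch_number):
    """Same schedule as in update(): INIT_LR * DECAY_RATE ** (epoch / DECAY_STEPS)."""
    return INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)

for epoch in (0, 1, 5, 10, 19):
    print(epoch, round(learning_rate(epoch), 4))
```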

In [75]:
f_update = pjit(update,
         in_shardings=(None, P('x'), P('x'), None),
         out_shardings=None
         )
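`f_update` is plain data parallelism. The parameters are replicated, the batch is split along its first axis, and the compiler inserts the cross-device reduction of the gradients. The sketch below substitutes a toy linear regression for the MLP, which is an assumption made for brevity, and assumes 8 emulated CPU devices. It checks that the sharded step matches an unsharded `jax.jit` step.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("x",))
replicated = NamedSharding(mesh, P())
batch_split = NamedSharding(mesh, P("x"))

def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)   # toy model, not the notebook's MLP

def update(w, x, y, lr):
    loss_value, grad = jax.value_and_grad(loss)(w, x, y)
    return w - lr * grad, loss_value

# Same pattern as f_update: params replicated, batch split along axis 0.
f_update = jax.jit(update,
                   in_shardings=(replicated, batch_split, batch_split, None),
                   out_shardings=(replicated, replicated))

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (4,))
x = jax.random.normal(key_x, (64, 4))
y = x @ jnp.array([1.0, -2.0, 3.0, 0.5])

w_sharded, loss_sharded = f_update(w, x, y, 0.1)
w_plain, loss_plain = jax.jit(update)(w, x, y, 0.1)   # single-device reference
print(jnp.allclose(w_sharded, w_plain, atol=1e-5),
      jnp.allclose(loss_sharded, loss_plain, atol=1e-5))
```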

### Training loop

In [76]:
devices = np.array(jax.devices())
devices

array([TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
       TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
       TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
       TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)],
      dtype=object)

In [77]:
def batch_accuracy(params, images, targets):
  images = jnp.reshape(images, (len(images), NUM_PIXELS))
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == targets)

f_batch_accuracy = pjit(batch_accuracy,
         in_shardings=(None, P('x'), P('x')),
         out_shardings=None
         )

def accuracy(params, data):
  accs = []
  for images, targets in data:
    accs.append(f_batch_accuracy(params, images, targets))
  return jnp.mean(jnp.array(accs))

In [79]:
import time

params = init_params
with Mesh(devices, ('x',)):
  for epoch in range(NUM_EPOCHS):
    start_time = time.time()
    losses = []
    for x, y in train_data:
      x = jnp.reshape(x, (len(x), NUM_PIXELS))
      y = one_hot(y, NUM_LABELS)
      params, loss_value = f_update(params, x, y, epoch)
      losses.append(jnp.sum(loss_value))
    epoch_time = time.time() - start_time

    train_acc = accuracy(params, train_data)
    test_acc = accuracy(params, test_data)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set loss {}".format(jnp.mean(jnp.array(losses))))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))

Epoch 0 in 2.11 sec
Training set loss 0.0822610855102539
Training set accuracy 0.8977934122085571
Test set accuracy 0.9017009735107422
Epoch 1 in 0.84 sec
Training set loss 0.033943504095077515
Training set accuracy 0.9136738777160645
Test set accuracy 0.915249228477478
Epoch 2 in 0.88 sec
Training set loss 0.030296433717012405
Training set accuracy 0.9209200143814087
Test set accuracy 0.9210838675498962
Epoch 3 in 0.93 sec
Training set loss 0.027934584766626358
Training set accuracy 0.9271888732910156
Test set accuracy 0.9252373576164246
Epoch 4 in 0.89 sec
Training set loss 0.025554420426487923
Training set accuracy 0.9339630603790283
Test set accuracy 0.9323576092720032
Epoch 5 in 0.93 sec
Training set loss 0.023182667791843414
Training set accuracy 0.9401763677597046
Test set accuracy 0.938686728477478
Epoch 6 in 0.97 sec
Training set loss 0.02106645330786705
Training set accuracy 0.9453791379928589
Test set accuracy 0.9431368708610535
Epoch 7 in 0.95 sec
Training set loss 0.019250

For resources that may help you shard the weights as well (not only the data), see:
https://github.com/google/jax/discussions/8649

**!!! Do not forget to shut down your Cloud TPU, or it will keep costing you a lot of money !!!**