## ⚖️ Choose A or B:

## A: Emulating multi-device system on CPU

Use this section to initialize a set of virtual devices on CPU if you have no access to a multi-device system.

It can also help you prototype, debug, and test your multi-device code locally before running it on an expensive multi-device system.

Even in the case of using Google Colab it can help you prototype faster because a CPU runtime is faster to restart.

In [None]:
import os
# Make XLA expose 8 virtual host (CPU) devices. This assignment replaces any
# XLA_FLAGS value already present in the environment.
# NOTE(review): presumably this has to run before JAX initializes its backend
# (jax is only imported in the next cell) — confirm.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

In [None]:
import jax
import jax.numpy as jnp

In [None]:
jax.devices()

## B: Setting up TPU

Make the preparations described in Appendix C or in the Chapter 3 example.

In [None]:
!pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

In [1]:
# Report which platform JAX picked as its default backend (e.g. 'tpu').
# `jax.default_backend()` is the public API for this query; the internal
# `jax.lib.xla_bridge` module used previously is deprecated in later JAX
# releases.
import jax
print(jax.default_backend())

tpu


In [2]:
import jax
import jax.numpy as jnp

In [3]:
jax.local_devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

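## Be sure you use JAX version >= 0.4.11 and < 0.4.31 (`xmap` was removed from JAX in 0.4.31, so the `jax.experimental.maps` import below fails on that version and newer)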

In [4]:
jax.__version__

'0.4.31'

## Introducing named axes

In [5]:
from jax.experimental.maps import xmap

ModuleNotFoundError: No module named 'jax.experimental.maps'

: 

In [None]:
from jax import random

### Replacing vmap+pmap

In [None]:
def dot(v1, v2):
  """Return the dot product of `v1` and `v2`, computed with `jnp.vdot`."""
  product = jnp.vdot(v1, v2)
  return product

In [None]:
rng_key = random.PRNGKey(42)

In [None]:
# Sample 20M random 3-component vectors in a single call, then split the rows
# in half to obtain two sets of 10M vectors each.
vs = random.normal(rng_key, shape=(20_000_000,3))
# Transpose so that each array is laid out as (3, 10_000_000):
# component axis first, vector index second.
v1s = vs[:10_000_000,:].T
v2s = vs[10_000_000:,:].T

# Display both shapes as the cell output.
v1s.shape, v2s.shape

In [None]:
# Split the vector axis into (device, per-device batch) so that pmap can map
# over the device axis. The device count is queried instead of being
# hard-coded to 8, so the same cell works on the emulated 8-CPU setup and on
# hosts with a different number of local devices (e.g. 4 TPU cores).
# The number of vectors must be divisible by the device count.
n_devices = jax.local_device_count()
v1sp = v1s.reshape((v1s.shape[0], n_devices, v1s.shape[1]//n_devices))
v2sp = v2s.reshape((v2s.shape[0], n_devices, v2s.shape[1]//n_devices))

# Display both shapes as the cell output.
v1sp.shape, v2sp.shape

In [None]:
# Two-level mapping, intended for 3-D inputs laid out as
# (component, device, batch): the outer pmap maps over axis 1 of each
# argument, and the inner vmap maps over axis 1 of what remains, so `dot`
# sees one pair of vectors at a time.
dot_parallel = jax.pmap(
    jax.vmap(dot, in_axes=(1,1)),
    in_axes=(1,1)
)

In [None]:
x_pmap = dot_parallel(v1sp,v2sp)

In [None]:
x_pmap.shape

In [None]:
x_pmap = x_pmap.reshape((x_pmap.shape[0]*x_pmap.shape[1]))
x_pmap.shape

In [None]:
# xmap version of the same computation: axes 1 and 2 of both inputs are given
# the names 'device' and 'batch', and `dot` is applied over the remaining
# unnamed axis. out_axes puts 'device' first and 'batch' second in the result.
f = xmap(dot,
         in_axes=(
             {1:'device', 2:'batch'},
             {1:'device', 2:'batch'}
         ),
         out_axes=['device', 'batch', ...]
)

In [None]:
x_xmap=f(v1sp,v2sp)

In [None]:
x_xmap.shape

In [None]:
x_xmap = x_xmap.reshape((x_xmap.shape[0]*x_xmap.shape[1]))
x_xmap.shape

In [None]:
jax.numpy.all(x_xmap == x_pmap)

Einsum for comparison

In [None]:
import numpy as np

In [None]:
dots = np.einsum("ib,ib->b", v1s, v2s)

In [None]:
dots.shape

In [None]:
jax.numpy.all(x_xmap == dots)

Changing order of output axes

In [None]:
# Same named-axis mapping as before; only out_axes differs, so the result has
# the 'batch' axis first and the 'device' axis second.
f = xmap(dot,
         in_axes=(
             {1:'device', 2:'batch'},
             {1:'device', 2:'batch'}
         ),
         out_axes=['batch', 'device', ...]
)

In [None]:
x_xmap=f(v1sp,v2sp)

In [None]:
x_xmap.shape

## Using broadcasting

In [None]:
# Use a separate subkey for each draw: sampling twice from the same key would
# derive both arrays from the same random bits.
image_key, filters_key = random.split(rng_key)
image = random.normal(image_key, shape=(480,640,3)) # RGB image 640x480px
filters = random.normal(filters_key, shape=(5,3,3))   # 5 matrix filters of size 3x3

In [None]:
from jax.scipy.signal import convolve2d

In [None]:
def apply_filter(channel, kernel):
  """Convolve one 2-D `channel` with one 2-D `kernel` using mode "same"."""
  filtered = convolve2d(channel, kernel, mode="same")
  return filtered

In [None]:
# Name axis 2 of the image 'channel' and axis 0 of the filter stack 'filter';
# `apply_filter` then works on the remaining unnamed axes of each argument.
# In the output, 'filter' becomes axis 0 and 'channel' becomes axis 3.
apply_filters_to_image = xmap(apply_filter,
         in_axes=(
             {2:'channel'},
             {0:'filter'}
         ),
         out_axes={0:'filter', 3: 'channel'}
)

In [None]:
res = apply_filters_to_image(image, filters)

In [None]:
res.shape # (filters, h, w, channels)

## Using reductions

In [None]:
# Reduction over a named axis: both input axes are named ('row', 'col'), the
# lambda sums over 'row', and only 'col' remains in the output.
f = xmap(
     lambda x: jnp.sum(x, axis=['row']),
     in_axes=['row', 'col'],
     out_axes=['col']
  )

In [None]:
C = jnp.array([
    [1,2,3],
    [4,5,6],
    [7,8,9]
])

In [None]:
f(C)

## Using collectives

In [None]:
arr = jnp.array(range(8)).reshape(2,4)
arr

In [None]:
# Nested pmap: the outer level is named 'rows', the inner one 'cols'. Each
# element is divided by a psum taken over both named axes.
# NOTE(review): presumably every pmap level needs as many devices as its
# mapped axis is long, so this only suits small arrays — confirm.
n_pmap = jax.pmap(
    jax.pmap(
        lambda x: x/jax.lax.psum(x, axis_name=('rows','cols')),
        axis_name='cols'
    ),
    axis_name='rows')

In [None]:
jnp.sum(n_pmap(arr))

In [None]:
n_pmap(arr)

In [None]:
# xmap counterpart of the nested pmap: the first two axes are named 'rows'
# and 'cols' on input and output, and each element is divided by the psum
# over both named axes.
n_xmap = xmap(
    lambda x: x/jax.lax.psum(x, axis_name=('rows','cols')),
    in_axes=['rows', 'cols', ...],
    out_axes=['rows', 'cols', ...]
)

In [None]:
jnp.sum(n_xmap(arr))

In [None]:
n_xmap(arr)

In [None]:
arr = jnp.array(range(10000)).reshape(100,100)

In [None]:
n_pmap(arr)

In [None]:
n_xmap(arr)

## Using meshes

In [None]:
from jax.sharding import Mesh

In [None]:
import numpy as np

`jnp.array` cannot be built from device objects, so the following call fails; use NumPy's `np.array` instead, as in the cell after it:

In [None]:
jnp.array(jax.devices()).reshape(4, 2)

In [None]:
# Arrange all available devices in a 2-D grid for the mesh. The layout is
# derived from the device count rather than hard-coded to (4, 2): two columns
# when the count is even (8 devices -> (4, 2), 4 devices -> (2, 2)), otherwise
# a single column.
all_devices = np.array(jax.devices())
mesh_cols = 2 if all_devices.size % 2 == 0 else 1
devices = all_devices.reshape(-1, mesh_cols)
devices

In [None]:
# Build a mesh with axes 'x' and 'y' over the device grid. Inside the mesh
# context, axis_resources ties the named array axes to the mesh axes:
# 'rows' -> 'x' and 'cols' -> 'y'.
with Mesh(devices, ('x', 'y')):
  n_xmap = xmap(
    lambda x: x/jax.lax.psum(x, axis_name=('rows','cols')),
    in_axes=['rows', 'cols', ...],
    out_axes=['rows', 'cols', ...],
    axis_resources={'rows': 'x', 'cols': 'y'}
  )

  # The call happens while the mesh context is still active.
  res = n_xmap(arr)

In [None]:
type(res), res.shape

### Simplifying initial xmap example (getting rid of reshaping)

In [None]:
def dot(v1, v2):
  """Compute `jnp.vdot(v1, v2)` for a single pair of vectors."""
  result = jnp.vdot(v1, v2)
  return result

In [None]:
# Recreate the PRNG key with the fixed seed 42.
rng_key = random.PRNGKey(42)

# Sample 20M random 3-component vectors, split the rows in half, and
# transpose each half to shape (3, 10_000_000).
vs = random.normal(rng_key, shape=(20_000_000,3))
v1s = vs[:10_000_000,:].T
v2s = vs[10_000_000:,:].T

# Display both shapes as the cell output.
v1s.shape, v2s.shape

In [None]:
# A 1-D mesh over all devices with a single axis called 'device'. The axis
# names are written as a proper one-element tuple: ('device') without the
# trailing comma is just the string 'device'.
with Mesh(np.array(jax.devices()), ('device',)):
  # No manual reshaping of the inputs: only 'batch' (axis 1) is named, and
  # axis_resources maps it onto the mesh axis 'device'.
  f = xmap(dot,
         in_axes=(
             {1:'batch'},
             {1:'batch'}
         ),
         out_axes=['batch', ...],
         axis_resources={'batch': 'device'}
  )
  x_xmap=f(v1s,v2s)

In [None]:
x_xmap.shape

In [None]:
jax.numpy.all(x_xmap == x_pmap)

## Neural network example with xmap() [NOT WORKING]

### Preparing data

In [None]:
!pip install tensorflow

In [None]:
!pip install tensorflow_datasets

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

# Directory handed to TFDS as its data_dir.
data_dir = '/tmp/tfds'

# Load the "mnist" dataset together with its metadata object (`info`).
# NOTE(review): as_supervised=True presumably yields (image, label) pairs,
# which matches the (img, label) signature of `preprocess` — confirm.
data, info = tfds.load(name="mnist",
                       data_dir=data_dir,
                       as_supervised=True,
                       with_info=True)

# Pick out the two splits used in the rest of the notebook.
data_train = data['train']
data_test  = data['test']

In [None]:
# Image geometry; NUM_PIXELS (28 * 28 * 1) is the length of a flattened image.
HEIGHT = 28
WIDTH  = 28
CHANNELS = 1
NUM_PIXELS = HEIGHT * WIDTH * CHANNELS
# Number of classes, read from the dataset metadata.
NUM_LABELS = info.features['label'].num_classes
# The .batch() calls on the datasets use NUM_DEVICES * BATCH_SIZE examples.
NUM_DEVICES = jax.device_count()
BATCH_SIZE  = 32

In [None]:
def preprocess(img, label):
  """Cast the image to float32 and divide it by 255; return the label unchanged.

  No resizing is performed here.
  """
  return (tf.cast(img, tf.float32)/255.0), label

# For each split: apply `preprocess`, group examples into batches of
# NUM_DEVICES*BATCH_SIZE, prefetch one batch, and wrap the result with
# tfds.as_numpy.
train_data = tfds.as_numpy(
    data_train.map(preprocess).batch(NUM_DEVICES*BATCH_SIZE).prefetch(1)
)
test_data  = tfds.as_numpy(
    data_test.map(preprocess).batch(NUM_DEVICES*BATCH_SIZE).prefetch(1)
)

In [None]:
len(train_data)

### Preparing MLP

Potentially useful links:

- my question https://github.com/google/jax/discussions/13861
- translating simplified SPMD MLP to xmap (https://github.com/google/jax/issues/7167). It doesn't work because logsumexp uses pmax, for which no differentiation rule is implemented
- some code for MLP with bias term and transformer blocks (https://gist.github.com/mattjj/ba9b24df446a90902d7b41aeb0766a99). Only xmap for loss, not xmap for diff(loss).
- lax.pdot() documentation is actually missing (https://github.com/google/jax/pull/5020) (https://github.com/google/jax/discussions/13851)


In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, value_and_grad
from jax import random
from jax.nn import swish, logsumexp, one_hot

In [None]:
# Layer widths: 28*28 inputs, 512 hidden units, 10 outputs.
LAYER_SIZES = [28*28, 512, 10]
# Axis names used for the named-axis (xmap) version of the network, listed in
# the same order as LAYER_SIZES.
AXES_NAMES  = ['inputs', 'hidden', 'classes']
# Scale passed to init_network_params for the initial random parameters.
PARAM_SCALE = 0.01

In [None]:
def init_network_params(sizes, key=random.PRNGKey(0), scale=1e-2):
  """Build randomly initialised (weights, biases) pairs for a dense network.

  Args:
    sizes: layer widths; consecutive entries give each layer's (in, out) dims.
    key: PRNG key that is split into per-layer keys.
    scale: multiplier applied to the normal samples.

  Returns:
    A list with one ``(w, b)`` tuple per layer, where ``w`` has shape
    ``(m, n)`` and ``b`` has shape ``(n,)``.
  """

  def make_dense_layer(m, n, layer_key):
    """Sample the scaled weight matrix and bias vector of one layer."""
    w_key, b_key = random.split(layer_key)
    print(f'Generating layer params: w={(m,n)} b={(n,)}')
    weights = scale * random.normal(w_key, (m, n))
    biases = scale * random.normal(b_key, (n,))
    return weights, biases

  # One key per entry of `sizes`; zip stops after the last (in, out) pair.
  layer_keys = random.split(key, len(sizes))
  layers = []
  for m, n, layer_key in zip(sizes[:-1], sizes[1:], layer_keys):
    layers.append(make_dense_layer(m, n, layer_key))
  return layers

init_params = init_network_params(LAYER_SIZES, random.PRNGKey(0), scale=PARAM_SCALE)

In [None]:
def predict(params, image):
  """Compute the logits for one example.

  Every layer except the last is applied as ``swish(pdot(x, w, axis) + b)``,
  pairing the layers with the leading entries of ``AXES_NAMES``. The last
  layer is contracted over ``AXES_NAMES[-2]`` and has no non-linearity.
  """
  x = image
  for layer, contracted_axis in zip(params[:-1], AXES_NAMES):
    w, b = layer
    x = swish(jax.lax.pdot(x, w, contracted_axis) + b)

  final_w, final_b = params[-1]
  return jax.lax.pdot(x, final_w, AXES_NAMES[-2]) + final_b

### Loss and update functions

In [None]:
INIT_LR = 1.0
DECAY_RATE = 0.95
DECAY_STEPS = 5
NUM_EPOCHS  = 20

In [None]:
def loss(params, images, labels):
  """Categorical cross entropy loss function.

  Written in terms of named axes: the reductions run over the class axis
  (``AXES_NAMES[-1]``) and the final mean over the axis named "batch".
  NOTE(review): the notebook marks this whole section as not working.
  """
  logits = predict(params, images)
  # Log-softmax: subtract the log-sum-exp taken along the class axis.
  log_preds = logits - logsumexp(logits, AXES_NAMES[-1])
  # psum of the constant 1 over the class axis — presumably yields the size
  # of that axis, i.e. the number of classes; confirm.
  num_classes = jax.lax.psum(1, AXES_NAMES[-1])
  targets = one_hot(labels, num_classes, axis=AXES_NAMES[-1])
  # Sum target * log-probability over the class axis, then negate the mean
  # over the "batch" axis.
  losses = jax.lax.psum(targets*log_preds, AXES_NAMES[-1])
  return -jax.lax.pmean(losses, "batch")

In [None]:
def update(params, x, y, epoch_number):
  """Take one gradient-descent step and return (new_params, loss_value).

  The learning rate starts at INIT_LR and is multiplied by DECAY_RATE once
  per DECAY_STEPS epochs (applied as a fractional power of `epoch_number`).
  """
  lr = INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)
  loss_value, grads = value_and_grad(loss)(params, x, y)

  new_params = []
  for (w, b), (dw, db) in zip(params, grads):
    new_params.append((w - lr * dw, b - lr * db))
  return new_params, loss_value

In [None]:
# Wrap `update` with xmap, naming the axes of every argument:
#   params - layer 1 weights ('inputs', 'hidden') and bias ('hidden'),
#            layer 2 weights ('hidden', 'classes') and bias ('classes');
#   x      - axis 0 is 'batch', axis 1 is 'inputs';
#   y      - axis 0 is 'batch';
#   epoch  - no named axes.
# The outputs are the updated params (same axis names as on input) and the
# loss value, which carries no named axes.
update_named = xmap(update,
                  in_axes=[
                      [
                          ({0: 'inputs', 1: 'hidden'}, {0: 'hidden'}),
                          ({0: 'hidden', 1:'classes'}, {0:'classes'})
                      ],
                      {0: 'batch',  1: 'inputs'},
                      {0: 'batch'},
                      {}
                  ],
                  out_axes=(
                      ([
                        (['inputs', 'hidden', ...], ['hidden', ...]),
                        (['hidden', 'classes', ...], ['classes', ...])
                      ],
                      {})
                  ),
                  )

### Section for debugging purposes

In [None]:
train_data_iter = iter(train_data)
x, y = next(train_data_iter)

In [None]:
x.shape, y.shape

In [None]:
x = jnp.reshape(x, (NUM_DEVICES*BATCH_SIZE, NUM_PIXELS))
#y = jnp.reshape(one_hot(y, NUM_LABELS), (NUM_DEVICES*BATCH_SIZE, NUM_LABELS))
x.shape, y.shape

In [None]:
updated_params, loss_value = update_named(init_params, x, y, 0)

In [None]:
loss_value

In [None]:
?jax.lax.pdot

### Training loop

In [None]:
@jit
def batch_accuracy(params, images, targets):
  """Return the fraction of correctly classified images in one batch.

  Args:
    params: network parameters, shared by every example in the batch.
    images: batch of images; flattened here to (len(images), NUM_PIXELS).
    targets: class labels that are compared against the argmax of the
      predictions, one per image.

  NOTE(review): `predict` contracts over named axes via pdot; whether it can
  run under a plain vmap still needs to be confirmed.
  """
  images = jnp.reshape(images, (len(images), NUM_PIXELS))
  # Map only over the image batch axis. Without in_axes, vmap would also try
  # to map over the leading axis of every parameter array instead of sharing
  # the parameters across examples.
  batched_predict = vmap(predict, in_axes=(None, 0))
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == targets)

def accuracy(params, data):
  """Average `batch_accuracy` over every (images, targets) batch in `data`.

  Each batch contributes equally to the mean, whatever its size.
  """
  per_batch = [batch_accuracy(params, images, targets)
               for images, targets in data]
  return jnp.mean(jnp.array(per_batch))

In [None]:
import time

# Start training from the freshly initialised parameters.
params = init_params

for epoch in range(NUM_EPOCHS):
  start_time = time.time()
  losses = []
  for x, y in train_data:
    # The batch size is read from the labels rather than assumed.
    num_elements = len(y)
    # Flatten every image to a vector of NUM_PIXELS values.
    x = jnp.reshape(x, (num_elements, NUM_PIXELS))
    #y = jnp.reshape(one_hot(y, NUM_LABELS), (NUM_DEVICES, num_elements//NUM_DEVICES, NUM_LABELS))
    # The epoch index is passed on to `update_named` as its last argument.
    params, loss_value = update_named(params, x, y, epoch)
    losses.append(loss_value)
  # Only the update loop above is included in the measured time.
  epoch_time = time.time() - start_time

  # Accuracy evaluation is currently disabled (see the commented lines).
  #train_acc = accuracy(params, train_data)
  #test_acc = accuracy(params, test_data)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  # Report the mean of the per-batch losses collected during this epoch.
  print("Training set loss {}".format(jnp.mean(jnp.array(losses))))
  #print("Training set accuracy {}".format(train_acc))
  #print("Test set accuracy {}".format(test_acc))