<a href="https://colab.research.google.com/github/halcy/LearningJAX/blob/main/JAX_Day_1_Evening_Parallelism_Basics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# JAX Day 1 - Evening - Parallelism

Now that we can train an MNIST digits classifier, let's take a step back and see how to do some basic parallelism so we can go faster!

In [None]:
# Catchall "what is this runtime" cell
!nvidia-smi
GPU = !nvidia-smi

# A working driver prints a multi-line table; a failure prints only a short message
GPU = len(GPU) > 3

!vmstat
print("")

import os

if "COLAB_TPU_ADDR" in os.environ:
    from tensorflow.python.profiler import profiler_client
    print("tpu:", os.environ['COLAB_TPU_ADDR'])
    tpu_profile_service_address = os.environ['COLAB_TPU_ADDR'].replace('8470', '8466')
    print(profiler_client.monitor(tpu_profile_service_address, 100, 2).strip())
    TPU = True
else:
    print("tpu: no")
    TPU = False

CPUS = os.cpu_count()
print("\ncpus:", CPUS)

NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver. Make sure that the latest NVIDIA driver is installed and running.

procs -----------memory---------- ---swap-- -----io---- -system-- ------cpu-----
 r  b   swpd   free   buff  cache   si   so    bi    bo   in   cs us sy id wa st
 7  0      0 10594840 102420 2037596    0    0   863    21  178  328  2  1 96  1  0

tpu: 10.103.29.18:8470
Timestamp: 17:06:33
  TPU type: TPU v2
  Utilization of TPU Matrix Units (higher is better): 0.000%

cpus: 2


# Basic setup and data loading

This is mostly the same setup as in the previous notebook, but we now also import xmap and mesh from jax.experimental.maps so we can parallelize computations across multiple TPU devices. We don't actually need MNIST this time.

In [None]:
# Set JAX, haiku and optax up for the TPU
!pip install --upgrade -q jax jaxlib dm-haiku optax tqdm

import requests
import os

if 'TPU_DRIVER_MODE' not in globals():
    url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver_nightly'
    resp = requests.post(url)
    TPU_DRIVER_MODE = 1

# TPU driver as backend for JAX
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']



In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

import jax
import jax.numpy as jnp
from jax import jit, nn

from jax.experimental.maps import xmap, mesh

import haiku as hk
import optax

import tqdm

# Pinky promise: We are now aware xmap is experimental, and will adjust our expectations accordingly
import warnings
warnings.filterwarnings("ignore", message="xmap is an experimental feature and probably has bugs!")

ModuleNotFoundError: ignored

In [None]:
# Generate PRNG state
prng = jax.random.PRNGKey(23)

# Basic data parallelism example
xmap lets us run a function in parallel on all available devices with relative ease. It does this using named axes. Let's see how, and how we can use that to Go Fast!

Note that we're only using named axes in a very simple way here; positional axes would have worked just as well. However, named axes let JAX keep track of how to split a computation along more than just the batch axis, which seems fairly powerful and might be useful later.
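
To make the idea of a named axis a bit more concrete, here is a minimal sketch that uses jax.vmap with its axis_name argument. This is a swapped-in technique for illustration, not the xmap call this notebook uses, and it does not spread work across devices. The function center and the toy data are made up; the point is only that the function body refers to the mapped axis by the name "batch" rather than by position.

```python
import jax
import jax.numpy as jnp

def center(x):
    # x is a single element of the mapped axis; pmean averages over
    # every element along the axis named "batch".
    return x - jax.lax.pmean(x, axis_name="batch")

# Hypothetical toy data: a single axis of length 8
batch = jnp.arange(8.0)

# Name the mapped (first) axis "batch" so center() can refer to it by name
centered = jax.vmap(center, axis_name="batch")(batch)
print(centered)
```

Each element ends up with the mean over the whole "batch" axis subtracted, even though center itself only ever sees one element at a time.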

In [None]:
# A very simple feedforward network
def simple_nn(x):
    for i in range(100):
        lin1 = hk.Linear(100)
        x = lin1(x)
    lin2 = hk.Linear(1)        
    x = lin2(x)
    return x

# Set up the model
in_shape = (100000, 1)
data = np.random.normal(size=in_shape)

model = hk.transform(simple_nn)
params = model.init(prng, data)

# Run the model many times without returning from the TPU, using jax.lax.fori_loop
# Note that without this (i.e. when syncing relatively often between TPU and host), the regular one-device versions
# will often be _faster_ than the multi-device version!
def predict(x):
    x = jax.lax.fori_loop(0, 10000, lambda _, xt: model.apply(params, None, xt), x) # Probably a good idea to do this for your training loop!
    return x

NameError: ignored
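
For reference, the loop primitive used in predict behaves roughly like the plain Python loop below. This is a sketch of the semantics only, and the name fori_loop_reference is made up here. The real jax.lax.fori_loop is traced and compiled, so the whole loop can run on the device without coming back to the host between iterations.

```python
def fori_loop_reference(lower, upper, body_fun, init_val):
    """Pure-Python sketch of what jax.lax.fori_loop(lower, upper, body_fun, init_val) computes."""
    val = init_val
    for i in range(lower, upper):
        # body_fun receives the loop index and the carried value
        val = body_fun(i, val)
    return val

# Toy usage: double a value five times, ignoring the index like predict() does
result = fori_loop_reference(0, 5, lambda _, x: x * 2, 1)
print(result)  # 32
```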

In [None]:
# Run using only one device, with jit()
basic_jitted = jit(predict)

print("jit(), first run:") # Run things twice - first time includes compilation, second does not
data = np.random.normal(size=in_shape)
%time print(np.mean(basic_jitted(data)))

print("\njit(), second run:")
data = np.random.normal(size=in_shape)
%time print(np.mean(basic_jitted(data)))

jit(), first run:
0.0
CPU times: user 19.4 s, sys: 53.5 s, total: 1min 12s
Wall time: 2min 6s

jit(), second run:
0.0
CPU times: user 17 s, sys: 54.3 s, total: 1min 11s
Wall time: 2min 4s


In [None]:
# Now, do the same using xmap - but still only a single device 
# All we've done is name the first axis in the input "batch"
in_axes = ["batch", ...]
out_axes = ["batch", ...]
basic_xmapped = xmap(predict, in_axes, out_axes)

print("\nxmap(), first run:")
data = np.random.normal(size=in_shape)
%time print(np.mean(basic_xmapped(data)))

print("\nxmap(), second run:")
data = np.random.normal(size=in_shape)
%time print(np.mean(basic_xmapped(data)))


xmap(), first run:
0.0
CPU times: user 18.5 s, sys: 54.9 s, total: 1min 13s
Wall time: 2min 6s

xmap(), second run:
0.0
CPU times: user 17 s, sys: 54.3 s, total: 1min 11s
Wall time: 2min 4s


In [None]:
# Now, let's use xmap but also have it run our batch in parallel on multiple TPU devices (cores)
# Note that for this, the axis length must be evenly divisible by the device count!
# Things are now fast: roughly 6x faster in the run below (about 19 s instead of about 2 min 4 s)
parallel_xmapped = xmap(predict, in_axes, out_axes, axis_resources={'batch': 'batch_tpus'})
with mesh(jax.devices(), ('batch_tpus',)):
    print("\nxmap(), parallel, first run:")
    data = np.random.normal(size=in_shape)
    %time print(np.mean(parallel_xmapped(data)))

    print("\nxmap(), parallel, second run:")
    data = np.random.normal(size=in_shape)
    %time print(np.mean(parallel_xmapped(data)))


xmap(), parallel, first run:
0.0
CPU times: user 4 s, sys: 8.89 s, total: 12.9 s
Wall time: 21.4 s

xmap(), parallel, second run:
0.0
CPU times: user 2.7 s, sys: 8.35 s, total: 11 s
Wall time: 19.3 s
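
The divisibility requirement mentioned in the cell above can be checked up front. The helper below is a hypothetical sketch (per_device_batch is not a JAX function), and the device count of 8 is an assumption for illustration; in the notebook the real count would come from len(jax.devices()).

```python
def per_device_batch(axis_len, n_devices):
    """Return how many batch entries each device gets, or raise if the axis cannot be split evenly."""
    if n_devices <= 0:
        raise ValueError("need at least one device")
    if axis_len % n_devices != 0:
        raise ValueError(
            f"axis length {axis_len} is not evenly divisible by {n_devices} devices"
        )
    return axis_len // n_devices

# The batch axis in this notebook has length 100000; 8 devices is an assumed count
print(per_device_batch(100000, 8))  # 12500
```

If the check fails, one simple option is to pad or trim the batch to the nearest multiple of the device count before calling the parallel function.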


# Next up:

Tomorrow, we convert the MNIST network to train in parallel!