# Quickstart: running JAX on IPU

In [1]:
# Install experimental JAX for IPUs (SDK 3.1) from Github releases.
!pip uninstall -y jax jaxlib
!pip install jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk310 -f https://graphcore-research.github.io/jax-experimental/wheels.html

Found existing installation: jax 0.3.16+ipu
Uninstalling jax-0.3.16+ipu:
  Successfully uninstalled jax-0.3.16+ipu
Found existing installation: jaxlib 0.3.15+ipu.sdk310
Uninstalling jaxlib-0.3.15+ipu.sdk310:
  Successfully uninstalled jaxlib-0.3.15+ipu.sdk310
[0mLooking in links: https://graphcore-research.github.io/jax-experimental/wheels.html
Collecting jax==0.3.16+ipu
  Downloading https://github.com/graphcore-research/jax-experimental/releases/download/jax-v0.3.16-ipu-beta2-sdk3/jax-0.3.16%2Bipu-py3-none-any.whl (1.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m32.6 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hCollecting jaxlib==0.3.15+ipu.sdk310
  Downloading https://github.com/graphcore-research/jax-experimental/releases/download/jax-v0.3.16-ipu-beta2-sdk3/jaxlib-0.3.15%2Bipu.sdk310-cp38-none-manylinux2014_x86_64.whl (109.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m109.4/109.4 MB[0m [31m40.3 MB/s[0m eta [36m0

Set the number of IPUs (or use IPUModel emulator).

In [20]:
# Enable IPU compilation cache
import os
os.environ['TF_POPLAR_FLAGS'] = """
  --executable_cache_path=/tmp/ipu-ef-cache
  --show_progress_bar=true
"""

import jax
from jax.config import config

# Uncomment to use IPU model emulator.
# config.FLAGS.jax_ipu_use_model = True
# config.FLAGS.jax_ipu_model_num_tiles = 8

# Select how many IPUs will be visible.
config.update('jax_ipu_device_count', 2)


print(f"Platform={jax.default_backend()}")
print(f"Number of devices={jax.device_count()}")
devices = jax.devices()
print("\n".join([str(d) for d in devices]))

Platform=ipu
Number of devices=2
IpuDevice(id=0, num_tiles=1472, version=ipu2)
IpuDevice(id=1, num_tiles=1472, version=ipu2)


## JAX basics on IPU

Let's run a simple jit function on single IPU

In [22]:
import numpy as np
from jax import jit
import jax.numpy as jnp

@jit
def func(x, w, b):
    return jnp.matmul(w, x) + b

x = np.random.normal(size=[2, 3])
w = np.random.normal(size=[3, 2])
b = np.random.normal(size=[3, 3])

r = func(x, w, b)
print(f"Result:\n{r}")
print(f"Result.platform = {r.platform()}")
print(f"Result.device = {r.device()}")



Result:
[[-0.6316366   0.23908725 -0.57826674]
 [ 0.7896802  -0.05535984  1.5372299 ]
 [-5.2348304  -2.7256932   2.8524191 ]]
Result.platform = ipu
Result.device = IpuDevice(id=0, num_tiles=1472, version=ipu2)


With the `jax.device_put` API, we can put variables onto specific devices. Here is an example to run the jit function on `ipu:1`:

In [23]:
x = jax.device_put(x, devices[1])
w = jax.device_put(w, devices[1])
b = jax.device_put(b, devices[1])

r = func(x, w, b)

print(f"Result:\n{r}")
print(f"Result.platform = {r.platform()}")
print(f"Result.device = {r.device()}")



Result:
[[-0.6316366   0.23908725 -0.57826674]
 [ 0.7896802  -0.05535984  1.5372299 ]
 [-5.2348304  -2.7256932   2.8524191 ]]
Result.platform = ipu
Result.device = IpuDevice(id=1, num_tiles=1472, version=ipu2)


`jit` also allows us to choose the backend the function will be running on.  For example, the below function will be running on `cpu`.

In [24]:
from functools import partial

@partial(jit, backend='cpu')
def func(x, w, b):
    return jnp.matmul(w, x) + b

r = func(x, w, b)

print(f"Result:\n{r}")
print(f"Result.platform = {r.platform()}")
print(f"Result.device = {r.device()}")



Result:
[[-0.6316366   0.23908725 -0.57826674]
 [ 0.7896801  -0.05535996  1.5372299 ]
 [-5.2348304  -2.7256935   2.8524191 ]]
Result.platform = cpu
Result.device = TFRT_CPU_0


## JAX Pseudo Random Numbers generation

Reproducible random numbers across platforms using JAX ThreeFry PRNG.  We will run this on both CPU and IPU.

This is a relatively complex workload for IPU, so the first time through it will take a few seconds to compile.
Let's switch on logging to see the compilation in action.

In [18]:
config.update('jax_log_compiles', True)

In [25]:
def random_fn(seed: int):
    key = jax.random.PRNGKey(seed)
    k1, k2 = jax.random.split(key)
    return k2, jax.random.uniform(k1, (4,))

random_fn_cpu = jax.jit(random_fn, backend="cpu")
print("CPU PRNG:", random_fn_cpu(42))

random_fn_ipu = jax.jit(random_fn, backend="ipu")
print("IPU PRNG:", random_fn_ipu(42))



CPU PRNG: (DeviceArray([255383827, 267815257], dtype=uint32), DeviceArray([0.7367313 , 0.92771065, 0.91349196, 0.3181516 ], dtype=float32))




IPU PRNG: (DeviceArray([255383827, 267815257], dtype=uint32), DeviceArray([0.7367313 , 0.92771065, 0.91349196, 0.3181516 ], dtype=float32))


If you run the same cell again, JAX compilation will be triggered by the redefinition of `random_fn`, but the IPU compilation will hit the cache, and be super quick.

## JAX asynchronous dispatch on IPUs

JAX IPU supports synchronous dispatch, allowing simple and efficient implementation of:
* Inference and training pipeline (see MNIST examples);
* Pipelining between multiple IPUs;

In [22]:
@partial(jit, backend='ipu')
def compute_fn(x, w):
    return jnp.matmul(w, x)

In [24]:
x = np.random.normal(size=[1024, 1024]).astype(np.float32)
# First run to compile jitted function, and load it on IPU.
compute_fn(x, x)

DeviceArray([[-36.361366  ,  -3.042955  , -12.644773  , ...,
              -40.753418  ,   4.5766077 , -24.013191  ],
             [ 59.276512  , -42.203323  , -34.57003   , ...,
               18.124418  , -36.59385   , -11.677992  ],
             [ -1.694912  ,  25.681198  ,  52.33486   , ...,
               27.405115  , -28.903233  ,  10.453694  ],
             ...,
             [ 77.24507   ,  -8.220667  , -22.550009  , ...,
               -6.22768   ,  22.722406  , -64.64627   ],
             [ 74.65662   , -10.569157  , -21.336151  , ...,
              -24.757969  ,  61.907845  ,   6.3100595 ],
             [  5.3183784 ,   0.59175587,  14.068951  , ...,
               -5.0356216 , -22.174866  ,  14.702527  ]], dtype=float32)

In [32]:
# No blocking: benchmarking only dispatch time.
%time w = compute_fn(x, x)
w.block_until_ready()

CPU times: user 190 µs, sys: 52 µs, total: 242 µs
Wall time: 325 µs


DeviceArray([[-36.361366  ,  -3.042955  , -12.644773  , ...,
              -40.753418  ,   4.5766077 , -24.013191  ],
             [ 59.276512  , -42.203323  , -34.57003   , ...,
               18.124418  , -36.59385   , -11.677992  ],
             [ -1.694912  ,  25.681198  ,  52.33486   , ...,
               27.405115  , -28.903233  ,  10.453694  ],
             ...,
             [ 77.24507   ,  -8.220667  , -22.550009  , ...,
               -6.22768   ,  22.722406  , -64.64627   ],
             [ 74.65662   , -10.569157  , -21.336151  , ...,
              -24.757969  ,  61.907845  ,   6.3100595 ],
             [  5.3183784 ,   0.59175587,  14.068951  , ...,
               -5.0356216 , -22.174866  ,  14.702527  ]], dtype=float32)

In [36]:
# Blocking: benchmarking properly the matmul.
%time w = compute_fn(x, x).block_until_ready()

CPU times: user 7.16 ms, sys: 0 ns, total: 7.16 ms
Wall time: 6.35 ms
