# Quickstart: running JAX on IPU

In [1]:
# Install experimental JAX for IPUs (SDK 3.1) from Github releases.
import sys
!{sys.executable} -m pip uninstall -y jax jaxlib
!{sys.executable} -m pip install jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk310 -f https://graphcore-research.github.io/jax-experimental/wheels.html

Found existing installation: jax 0.3.16+ipu
Not uninstalling jax at /nethome/paulb/github/jax-experimental-internal, outside environment /nethome/paulb/venvs/3.1.0+1205/3.1.0+1205_poptorch
Can't uninstall 'jax'. No files were found to uninstall.
Found existing installation: jaxlib 0.3.15+ipu.sdk310
Uninstalling jaxlib-0.3.15+ipu.sdk310:
  Successfully uninstalled jaxlib-0.3.15+ipu.sdk310
Collecting jaxlib==0.3.15+ipu.sdk310
  Downloading https://github.com/graphcore-research/jax-experimental/releases/latest/download/jaxlib-0.3.15+ipu.sdk310-cp38-none-manylinux2014_x86_64.whl (109.4 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m109.4/109.4 MB[0m [31m27.1 MB/s[0m eta [36m0:00:00[0mm eta [36m0:00:01[0m[36m0:00:01[0m
Installing collected packages: jaxlib
Successfully installed jaxlib-0.3.15+ipu.sdk310
Collecting jax==0.3.16+ipu
  Downloading https://github.com/graphcore-research/jax-experimental/releases/latest/download/jax-0.3.16+ipu-py3-none-any.w

In [5]:
from jax.config import config

# Uncomment to use IPU model emulator.
# config.FLAGS.jax_ipu_use_model = True
# config.FLAGS.jax_ipu_model_num_tiles = 8

# Select how many IPUs will be visible.
config.FLAGS.jax_ipu_device_count = 4

* JAX will automatically select `ipu` as default backend, the order is ipu > tpu > gpu > cpu.

In [6]:
import jax

print(f"Platform={jax.default_backend()}")
print(f"Number of devices={jax.device_count()}")
devices = jax.devices()
print("\n".join([str(d) for d in devices]))

Platform=ipu
Number of devices=4
IpuDevice(id=0, num_tiles=1472, version=ipu2)
IpuDevice(id=1, num_tiles=1472, version=ipu2)
IpuDevice(id=2, num_tiles=1472, version=ipu2)
IpuDevice(id=3, num_tiles=1472, version=ipu2)


## JAX basics on IPU

* A demo to run a simple jit function on single IPU

In [7]:
import numpy as np
from jax import jit
import jax.numpy as jnp

@jit
def func(x, w, b):
    return jnp.matmul(w, x) + b

x = np.random.normal(size=[2, 3])
w = np.random.normal(size=[3, 2])
b = np.random.normal(size=[3, 3])

r = func(x, w, b)
print(f"Result = {r}")
print(f"Platform = {r.platform()}")
print(f"Device = {r.device()}")

Result = [[-0.03298193  0.1450178  -3.8689601 ]
 [ 0.26485693  0.6824822  -2.7068207 ]
 [-1.5246258   1.0486794  -1.1777194 ]]
Platform = ipu
Device = IpuDevice(id=0, num_tiles=1472, version=ipu2)


* With `jax.device_put` API, we can put variables to certain device. Here is an example to run jit function on `ipu:0`:

In [8]:
x = jax.device_put(x, devices[0])
w = jax.device_put(w, devices[0])
b = jax.device_put(b, devices[0])

r = func(x, w, b)

print(f"Result = {r}")
print(f"Platform = {r.platform()}")
print(f"Device = {r.device()}")

Result = [[-0.03298193  0.1450178  -3.8689601 ]
 [ 0.26485693  0.6824822  -2.7068207 ]
 [-1.5246258   1.0486794  -1.1777194 ]]
Platform = ipu
Device = IpuDevice(id=0, num_tiles=1472, version=ipu2)


* `jit` also support to config which backend the function will be running on, for example, below function will be running on `cpu` platform:

In [9]:
from functools import partial

@partial(jit, backend='cpu')
def func(x, w, b):
    return jnp.matmul(w, x) + b

r = func(x, w, b)
print(f"Result = {r}")
print(f"Platform = {r.platform()}")
print(f"Device = {r.device()}")

Result = [[-0.03298193  0.1450178  -3.8689601 ]
 [ 0.26485693  0.6824822  -2.7068207 ]
 [-1.5246258   1.0486792  -1.1777194 ]]
Platform = cpu
Device = TFRT_CPU_0


## JAX Pseudo Random Numbers generation

Reproducible random numbers across platforms using JAX ThreeFry PRNG.

In [10]:
def random_fn(seed: int):
    key = jax.random.PRNGKey(seed)
    k1, k2 = jax.random.split(key)
    return k2, jax.random.uniform(k1, (3,))

random_fn_cpu = jax.jit(random_fn, backend="cpu")
random_fn_ipu = jax.jit(random_fn, backend="ipu")

print("CPU PRNG:", random_fn_cpu(42))
print("IPU PRNG:", random_fn_ipu(42))

CPU PRNG: (DeviceArray([255383827, 267815257], dtype=uint32), DeviceArray([0.7367313 , 0.83174706, 0.91349196], dtype=float32))
IPU PRNG: (DeviceArray([255383827, 267815257], dtype=uint32), DeviceArray([0.7367313 , 0.83174706, 0.91349196], dtype=float32))


## JAX asynchronous dispatch on IPUs

JAX IPU supports synchronous dispatch, allowing simple and efficient implementation of:
* Inference and training pipeline (see MNIST examples);
* Pipelining between multiple IPUs;

In [22]:
@partial(jit, backend='ipu')
def compute_fn(x, w):
    return jnp.matmul(w, x)

In [24]:
x = np.random.normal(size=[1024, 1024]).astype(np.float32)
# First run to compile jitted function, and load it on IPU.
compute_fn(x, x)

DeviceArray([[-36.361366  ,  -3.042955  , -12.644773  , ...,
              -40.753418  ,   4.5766077 , -24.013191  ],
             [ 59.276512  , -42.203323  , -34.57003   , ...,
               18.124418  , -36.59385   , -11.677992  ],
             [ -1.694912  ,  25.681198  ,  52.33486   , ...,
               27.405115  , -28.903233  ,  10.453694  ],
             ...,
             [ 77.24507   ,  -8.220667  , -22.550009  , ...,
               -6.22768   ,  22.722406  , -64.64627   ],
             [ 74.65662   , -10.569157  , -21.336151  , ...,
              -24.757969  ,  61.907845  ,   6.3100595 ],
             [  5.3183784 ,   0.59175587,  14.068951  , ...,
               -5.0356216 , -22.174866  ,  14.702527  ]], dtype=float32)

In [32]:
# No blocking: benchmarking only dispatch time.
%time w = compute_fn(x, x)
w.block_until_ready()

CPU times: user 190 µs, sys: 52 µs, total: 242 µs
Wall time: 325 µs


DeviceArray([[-36.361366  ,  -3.042955  , -12.644773  , ...,
              -40.753418  ,   4.5766077 , -24.013191  ],
             [ 59.276512  , -42.203323  , -34.57003   , ...,
               18.124418  , -36.59385   , -11.677992  ],
             [ -1.694912  ,  25.681198  ,  52.33486   , ...,
               27.405115  , -28.903233  ,  10.453694  ],
             ...,
             [ 77.24507   ,  -8.220667  , -22.550009  , ...,
               -6.22768   ,  22.722406  , -64.64627   ],
             [ 74.65662   , -10.569157  , -21.336151  , ...,
              -24.757969  ,  61.907845  ,   6.3100595 ],
             [  5.3183784 ,   0.59175587,  14.068951  , ...,
               -5.0356216 , -22.174866  ,  14.702527  ]], dtype=float32)

In [36]:
# Blocking: benchmarking properly the matmul.
%time w = compute_fn(x, x).block_until_ready()

CPU times: user 7.16 ms, sys: 0 ns, total: 7.16 ms
Wall time: 6.35 ms
