# Quickstart: running JAX on IPU

In [None]:
# Install experimental JAX for IPUs
import sys
!{sys.executable} -m pip uninstall -y jax jaxlib
!{sys.executable} -m pip install https://github.com/graphcore-research/jax-mk2-experimental/releases/latest/download/jaxlib-0.3.15-cp38-none-manylinux2014_x86_64.whl
!{sys.executable} -m pip install https://github.com/graphcore-research/jax-mk2-experimental/releases/latest/download/jax-0.3.16-py3-none-any.whl

In [None]:
# `config` gives access to JAX's runtime flags, including the IPU-specific ones below.
from jax.config import config
# Uncomment to use IPU model emulator.
# config.FLAGS.jax_ipu_use_model = True
# NOTE(review): presumably sets how many tiles the emulated IPU model has — confirm.
# config.FLAGS.jax_ipu_model_num_tiles = 8

* When an IPU is available, JAX automatically selects `ipu` as the default backend; the backend priority order is ipu > tpu > gpu > cpu.

In [None]:
import jax

# Report the backend JAX picked by default and the devices it exposes.
# `devices` is reused by the device-placement example further down.
summary = (
    f"Platform={jax.default_backend()}\n"
    f"Number of devices={jax.device_count()}"
)
print(summary)
devices = jax.devices()
print(devices)

* A demo that runs a simple `jit`-compiled function on a single IPU:

In [None]:
import numpy as np
from jax import jit
import jax.numpy as jnp

@jit
def func(x, w, b):
    return jnp.matmul(w, x) + b

x = np.random.normal(size=[2, 3])
w = np.random.normal(size=[3, 2])
b = np.random.normal(size=[3, 3])

r = func(x, w, b)
print(f"Result = {r}")
print(f"Platform = {r.platform()}")
print(f"Device = {r.device()}")

* With the `jax.device_put` API, we can place arrays on a specific device. Here is an example that runs the jitted function on `ipu:0`:

In [None]:
# Move all three inputs onto the first device in a single call; device_put maps
# over the tuple (a pytree) and returns a tuple of placed arrays.
x, w, b = jax.device_put((x, w, b), devices[0])

r = func(x, w, b)

print(f"Result = {r}")
print(f"Platform = {r.platform()}")
print(f"Device = {r.device()}")

* `jit` can also be configured to choose which backend the function runs on. For example, the function below runs on the `cpu` platform:

In [None]:
from functools import partial

# `partial(jit, backend='cpu')` pins compilation to the CPU backend.
# Named `func_cpu` (not `func`) so it does not shadow the IPU-jitted `func`
# defined earlier; otherwise re-running the device_put cell would silently run on CPU.
@partial(jit, backend='cpu')
def func_cpu(x, w, b):
    """Same affine map as ``func`` (``w @ x + b``), compiled for the CPU backend."""
    return jnp.matmul(w, x) + b

# x, w and b were placed on devices[0] in the previous cell; this call runs on CPU.
r = func_cpu(x, w, b)
print(f"Result = {r}")
print(f"Platform = {r.platform()}")
print(f"Device = {r.device()}")

### JAX pseudo-random number generation

JAX's ThreeFry PRNG produces reproducible random numbers across platforms: the same seed gives the same values on CPU and IPU.

In [None]:
def random_fn(seed: int):
    key = jax.random.PRNGKey(seed)
    k1, k2 = jax.random.split(key)
    return k2, jax.random.uniform(k1, (3,))

random_fn_cpu = jax.jit(random_fn, backend="cpu")
random_fn_ipu = jax.jit(random_fn, backend="ipu")

print("CPU PRNG:", random_fn_cpu(42))
print("IPU PRNG:", random_fn_ipu(42))