# Quickstart: running JAX on IPU

In [1]:
from jax.config import config

# Run on the IPU Model, a software emulator of the IPU, instead of physical
# hardware. Comment out the two flag assignments below to use real IPU hardware.
IPU_MODEL_NUM_TILES = 8  # number of tiles the emulated IPU exposes (shown as `tiles=8` in the device list)
config.FLAGS.jax_ipu_use_model = True
config.FLAGS.jax_ipu_model_num_tiles = IPU_MODEL_NUM_TILES

* JAX automatically selects `ipu` as the default backend when it is available; the backend priority order is ipu > tpu > gpu > cpu.

In [2]:
import jax

# Report the backend JAX picked by default and the devices it exposes.
backend_name = jax.default_backend()
print(f"Platform={backend_name}")
num_devices = jax.device_count()
print(f"Number ipus={num_devices}")

devices = jax.devices()  # kept for the `device_put` example further down
print(devices)

Platform=ipu
Number ipus=1
[IpuDevice(id=0, tiles=8)]


2023-01-31 15:10:39.660557: W external/org_tensorflow/tensorflow/compiler/plugin/poplar/driver/poplar_executor.cc:1782] A version was not supplied when using the IPU Model. Defaulting to 'ipu2'.


* A demo that runs a simple `jit`-compiled function on a single IPU:

In [3]:
import numpy as np
from jax import jit
import jax.numpy as jnp

@jit
def func(x, w, b):
    return jnp.matmul(w, x) + b

x = np.random.normal(size=[2, 3])
w = np.random.normal(size=[3, 2])
b = np.random.normal(size=[3, 3])

r = func(x, w, b)
print(f"Result = {r}")
print(f"Platform = {r.platform()}")
print(f"Device = {r.device()}")

Result = [[ 1.486894    2.635103   -4.4571934 ]
 [ 0.14053345 -1.1268771   0.8805641 ]
 [-1.3035688  -1.5401895  -1.3022151 ]]
Platform = ipu
Device = ipu:0


* With the `jax.device_put` API, we can place arrays on a specific device. Here is an example that runs the jit function on `devices[0]`, i.e. `ipu:0` (the only IPU in this configuration):

In [4]:
# Commit each input to the first device explicitly; the jitted call then
# executes on that device.
target_device = devices[0]
x, w, b = (jax.device_put(arr, target_device) for arr in (x, w, b))

r = func(x, w, b)

print(f"Result = {r}")
print(f"Platform = {r.platform()}")
print(f"Device = {r.device()}")

Result = [[ 1.486894    2.635103   -4.4571934 ]
 [ 0.14053345 -1.1268771   0.8805641 ]
 [-1.3035688  -1.5401895  -1.3022151 ]]
Platform = ipu
Device = ipu:0


* `jit` also lets you choose which backend a function runs on. For example, the function below runs on the `cpu` platform:

In [5]:
from functools import partial

# `partial` fixes jit's `backend` argument, giving a decorator that compiles
# for the CPU backend instead of the default one.
jit_on_cpu = partial(jit, backend='cpu')


# NOTE(review): this rebinds `func`, replacing the IPU version defined earlier;
# re-running the `device_put` cell after this one would call this CPU version.
@jit_on_cpu
def func(x, w, b):
    """Return ``jnp.matmul(w, x) + b``, compiled for the ``cpu`` backend."""
    return jnp.matmul(w, x) + b


r = func(x, w, b)
print(f"Result = {r}")
print(f"Platform = {r.platform()}")
print(f"Device = {r.device()}")

Result = [[ 1.4868938   2.635103   -4.4571934 ]
 [ 0.14053339 -1.1268772   0.8805641 ]
 [-1.3035688  -1.5401895  -1.3022151 ]]
Platform = cpu
Device = TFRT_CPU_0
