# Quickstart: running JAX on IPU

* Set `XLA_IPU_PLATFORM_DEVICE_COUNT=N` to attach N IPU devices to a program. By default, only one IPU device is attached.
* Set `TF_POPLAR_FLAGS='--use_ipu_model'` to run on `ipu_model` if you don't have an IPU. `ipu_model` is an IPU emulator that runs on the CPU; at most 2 IPU devices can be attached when running on `ipu_model`.

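```python
import os

# Configure the IPU platform before importing jax (next cell).
os.environ['TF_POPLAR_FLAGS'] = '--use_ipu_model'  # emulate IPUs on the CPU
os.environ['XLA_IPU_PLATFORM_DEVICE_COUNT'] = '2'  # attach 2 IPU devices
```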
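
Because `ipu_model` supports at most 2 devices, a small helper can check the requested device count before writing the variables. This is a minimal sketch under that assumption: `configure_ipu_env` is a hypothetical name, not part of JAX or Poplar, and it only sets the two environment variables described above. Like the cell above, it would need to run before `jax` is imported.

```python
import os

# Documented limit: at most 2 IPU devices when running on ipu_model.
IPU_MODEL_MAX_DEVICES = 2


def configure_ipu_env(num_devices=1, use_ipu_model=False, environ=os.environ):
    """Hypothetical helper: write the IPU env vars into `environ` and return it.

    Intended to be called before importing jax, as in the cell above.
    """
    if num_devices < 1:
        raise ValueError("num_devices must be at least 1")
    if use_ipu_model and num_devices > IPU_MODEL_MAX_DEVICES:
        raise ValueError(
            f"ipu_model supports at most {IPU_MODEL_MAX_DEVICES} devices, "
            f"got {num_devices}")
    environ['XLA_IPU_PLATFORM_DEVICE_COUNT'] = str(num_devices)
    if use_ipu_model:
        # Only set the flag when emulation is requested; otherwise leave it untouched.
        environ['TF_POPLAR_FLAGS'] = '--use_ipu_model'
    return environ


# Pass a plain dict to inspect the result without touching the real environment.
print(configure_ipu_env(num_devices=2, use_ipu_model=True, environ={}))
```

Passing `environ={}` makes the helper easy to inspect; calling it with the default writes to `os.environ` directly.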

* JAX automatically selects `ipu` as the default backend; the selection priority is ipu > tpu > gpu > cpu.
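
The priority rule can be pictured as picking the first available platform in a fixed order. The sketch below only illustrates that rule: `pick_default_backend` is hypothetical and is not how JAX implements backend selection internally. In a real session, `jax.default_backend()` reports the platform that was chosen.

```python
# Priority order described above: ipu > tpu > gpu > cpu.
BACKEND_PRIORITY = ("ipu", "tpu", "gpu", "cpu")


def pick_default_backend(available):
    """Hypothetical illustration: return the highest-priority platform in `available`."""
    for name in BACKEND_PRIORITY:
        if name in available:
            return name
    raise RuntimeError(f"none of {BACKEND_PRIORITY} is available")


print(pick_default_backend({"cpu", "ipu"}))  # ipu
print(pick_default_backend({"cpu", "gpu"}))  # gpu
print(pick_default_backend({"cpu"}))         # cpu
```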

```python
import jax

print(f"Platform={jax.default_backend()}")
print(f"Number ipus={jax.device_count()}")
devices = jax.devices()
print(devices)
```

```
Platform=ipu
Number ipus=2
[IpuDevice(id=0), IpuDevice(id=1)]
```

* A demo that runs a simple `jit` function on a single IPU:

```python
import numpy as np
from jax import jit
import jax.numpy as jnp

@jit
def func(x, w, b):
  # w: (3, 2) @ x: (2, 3) -> (3, 3), then add the (3, 3) bias b
  return jnp.matmul(w, x) + b

x = np.random.normal(size=[2, 3])
w = np.random.normal(size=[3, 2])
b = np.random.normal(size=[3, 3])

r = func(x, w, b)
print(f"Result = {r}")
print(f"Platform = {r.platform()}")
print(f"Device = {r.device()}")
```

```
Result = [[ 1.0280436  -0.526331   -0.9054605 ]
 [-1.9880531  -0.03559947  1.6808541 ]
 [-0.85754204 -1.2670149  -0.6839774 ]]
Platform = ipu
Device = ipu:0
```
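
To sanity-check the IPU result, the same expression can be evaluated with NumPy on the host. This is a sketch, not part of the IPU workflow: `reference_func` is a hypothetical name, and the seeded generator here means its numbers will differ from the printout above. The shapes follow the cell: `w` (3, 2) times `x` (2, 3) gives a (3, 3) matrix, to which the (3, 3) bias `b` is added.

```python
import numpy as np


def reference_func(x, w, b):
    """Host-side NumPy version of `func`: matmul(w, x) + b."""
    # w: (3, 2) @ x: (2, 3) -> (3, 3); b: (3, 3) is added elementwise.
    return np.matmul(w, x) + b


rng = np.random.default_rng(0)
x = rng.normal(size=[2, 3])
w = rng.normal(size=[3, 2])
b = rng.normal(size=[3, 3])

expected = reference_func(x, w, b)
print(expected.shape)  # (3, 3)
```

With the notebook's own `r`, `x`, `w` and `b` (not the ones generated here), one could compare the two results using a float32-level tolerance, for example `np.testing.assert_allclose(np.asarray(r), reference_func(x, w, b), rtol=1e-5, atol=1e-5)`. The tolerance is an assumption: JAX computes in float32 by default, while the NumPy inputs are float64.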


* With the `jax.device_put` API, we can place arrays on a specific device. Here is an example that runs the jit function on `ipu:1`:

```python
# Move the inputs to the second IPU; the jitted function then runs there.
x = jax.device_put(x, devices[1])
w = jax.device_put(w, devices[1])
b = jax.device_put(b, devices[1])

r = func(x, w, b)

print(f"Result = {r}")
print(f"Platform = {r.platform()}")
print(f"Device = {r.device()}")
```

```
Result = [[ 1.0280436  -0.526331   -0.9054605 ]
 [-1.9880531  -0.03559947  1.6808541 ]
 [-0.85754204 -1.2670149  -0.6839774 ]]
Platform = ipu
Device = ipu:1
```

* `jit` also supports choosing the backend a function runs on. For example, the function below runs on the `cpu` platform:

```python
from functools import partial  # needed for @partial below

# Compile and run this function on the CPU backend instead of the default ipu.
@partial(jit, backend='cpu')
def func(x, w, b):
  return jnp.matmul(w, x) + b

r = func(x, w, b)
print(f"Result = {r}")
print(f"Platform = {r.platform()}")
print(f"Device = {r.device()}")
```

```
Result = [[ 1.0280436  -0.526331   -0.9054605 ]
 [-1.9880531  -0.03559947  1.6808541 ]
 [-0.85754204 -1.2670149  -0.6839774 ]]
Platform = cpu
Device = cpu:0
```
