# SPMD JAX on IPUs: `pjit` quickstart

In [1]:
# Install experimental JAX for IPUs (SDK 3.1) from Github releases.
# `{sys.executable} -m pip` runs pip for the interpreter that is running this kernel,
# so the packages are installed into the kernel's own environment.
# NOTE: if jax was already imported in this kernel, restart the kernel after installing.
import sys
# Remove any existing jax/jaxlib first so that the pinned IPU builds below are used.
!{sys.executable} -m pip uninstall -y jax jaxlib
# Pinned IPU release wheels, resolved from the `-f` wheel index page.
# Per the install log, the jaxlib wheel is built for CPython 3.8 (cp38) on manylinux x86_64.
!{sys.executable} -m pip install jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk310 -f https://graphcore-research.github.io/jax-experimental/wheels.html

Found existing installation: jax 0.3.16+ipu
Uninstalling jax-0.3.16+ipu:
  Successfully uninstalled jax-0.3.16+ipu
Found existing installation: jaxlib 0.3.15+ipu.sdk310
Uninstalling jaxlib-0.3.15+ipu.sdk310:
  Successfully uninstalled jaxlib-0.3.15+ipu.sdk310
Looking in links: https://graphcore-research.github.io/jax-experimental/wheels.html
Collecting jax==0.3.16+ipu
  Using cached https://github.com/graphcore-research/jax-experimental/releases/download/jax-v0.3.16-ipu-beta2-sdk3/jax-0.3.16%2Bipu-py3-none-any.whl (1.2 MB)
Collecting jaxlib==0.3.15+ipu.sdk310
  Using cached https://github.com/graphcore-research/jax-experimental/releases/download/jax-v0.3.16-ipu-beta2-sdk3/jaxlib-0.3.15%2Bipu.sdk310-cp38-none-manylinux2014_x86_64.whl (109.4 MB)
Installing collected packages: jaxlib, jax
Successfully installed jax-0.3.16+ipu jaxlib-0.3.15+ipu.sdk310


In [1]:
import os

from jax.config import config

# Select how many IPUs will be visible; the mesh defined below is 2 x 2 = 4 devices.
config.FLAGS.jax_ipu_device_count = 4

# Alternative: simulate the 4 devices on CPU instead. Uncomment both lines (`os` is
# imported above for this); the XLA flag must take effect before JAX initializes its
# backends, i.e. before the first device query. Cells below that call
# `jax.devices("ipu")` would then need to be adapted to the CPU backend.
# os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'
# config.FLAGS.jax_platforms = "cpu,ipu"

In [2]:
import jax
from jax.experimental import maps
from jax.experimental import PartitionSpec
from jax.experimental.pjit import pjit
import numpy as np

In [8]:
# List the IPU devices JAX can see; their number is set by `jax_ipu_device_count`.
print("IPU devices:")
ipu_devices = jax.devices(backend="ipu")
ipu_devices

IPU devices:


[IpuDevice(id=0, num_tiles=1472, version=ipu2),
 IpuDevice(id=1, num_tiles=1472, version=ipu2),
 IpuDevice(id=2, num_tiles=1472, version=ipu2),
 IpuDevice(id=3, num_tiles=1472, version=ipu2)]

# Basic `pjit` examples

## Mesh definition

In [9]:
mesh_shape = (2, 2)
# Plain Python int (np.prod returns a NumPy integer scalar).
mesh_size = int(np.prod(mesh_shape))
# Fail early with a clear message if the visible IPUs cannot fill the mesh exactly.
assert len(ipu_devices) == mesh_size, (
    f"Mesh {mesh_shape} needs {mesh_size} devices, but {len(ipu_devices)} IPUs are "
    "visible; check `config.FLAGS.jax_ipu_device_count`."
)
mesh_devices = np.asarray(ipu_devices).reshape(mesh_shape)
# 'x', 'y' axis names are used here for simplicity
mesh = maps.Mesh(mesh_devices, ('x', 'y'))
print("IPU device mesh:", mesh)

IPU device mesh: Mesh(array([[0, 1],
       [2, 3]]), ('x', 'y'))


In [10]:
# Demo input for the sharding examples: N * mesh_size consecutive float32 values,
# laid out two per row.
N = 3
input_data = np.arange(N * mesh_size, dtype=np.float32).reshape((-1, 2))
print(f"INPUT:\n {input_data} {input_data.shape}")

INPUT:
 [[ 0.  1.]
 [ 2.  3.]
 [ 4.  5.]
 [ 6.  7.]
 [ 8.  9.]
 [10. 11.]] (6, 2)


## `x` and `y` axes output sharding

In [11]:
# No partition for inputs: data is replicated on every device.
in_axis_resources = None
# Output dim 0 is split over mesh axis 'x' and dim 1 over mesh axis 'y'.
out_axis_resources = PartitionSpec('x', 'y')

f = pjit(
    # Simple unary op to run
    lambda x: jax.lax.integer_pow(x, 2),
    in_axis_resources=in_axis_resources,
    out_axis_resources=out_axis_resources)

# Run under the mesh defined above. `mesh` is itself a context manager, so there is
# no need to build a second, identical `maps.Mesh`.
with mesh:
    output_data = f(input_data)



In [12]:
# Check the distributed result against NumPy before displaying it.
np.testing.assert_allclose(np.asarray(output_data), input_data ** 2)
# One block per device: rows split over 'x', columns split over 'y'.
assert len(output_data.device_buffers) == mesh_size
assert all(
    buf.shape == (input_data.shape[0] // mesh_shape[0], input_data.shape[1] // mesh_shape[1])
    for buf in output_data.device_buffers
)
# Full output, and how it is sharded between devices
output_data, output_data.device_buffers

(ShardedDeviceArray([[  0.,   1.],
                     [  4.,   9.],
                     [ 16.,  25.],
                     [ 36.,  49.],
                     [ 64.,  81.],
                     [100., 121.]], dtype=float32),
 [DeviceArray([[ 0.],
               [ 4.],
               [16.]], dtype=float32),
  DeviceArray([[ 1.],
               [ 9.],
               [25.]], dtype=float32),
  DeviceArray([[ 36.],
               [ 64.],
               [100.]], dtype=float32),
  DeviceArray([[ 49.],
               [ 81.],
               [121.]], dtype=float32)])

## `x` axis output sharding

In [13]:
# No partition for inputs: data is replicated on every device.
in_axis_resources = None
# Output dim 0 is split over mesh axis 'x'; dim 1 is not partitioned.
out_axis_resources = PartitionSpec('x', None)

f = pjit(
    # Simple unary op to run
    lambda x: jax.lax.integer_pow(x, 2),
    in_axis_resources=in_axis_resources,
    out_axis_resources=out_axis_resources)

# Run under the mesh defined above. `mesh` is itself a context manager, so there is
# no need to build a second, identical `maps.Mesh`.
with mesh:
    output_data = f(input_data)



In [14]:
# Check the distributed result against NumPy before displaying it.
np.testing.assert_allclose(np.asarray(output_data), input_data ** 2)
# Rows are split over 'x'; every shard keeps all columns (replicated along 'y').
assert all(
    buf.shape == (input_data.shape[0] // mesh_shape[0], input_data.shape[1])
    for buf in output_data.device_buffers
)
# Output: sharded along X axis, replicated along Y.
output_data, output_data.device_buffers

(ShardedDeviceArray([[  0.,   1.],
                     [  4.,   9.],
                     [ 16.,  25.],
                     [ 36.,  49.],
                     [ 64.,  81.],
                     [100., 121.]], dtype=float32),
 [DeviceArray([[ 0.,  1.],
               [ 4.,  9.],
               [16., 25.]], dtype=float32),
  DeviceArray([[ 0.,  1.],
               [ 4.,  9.],
               [16., 25.]], dtype=float32),
  DeviceArray([[ 36.,  49.],
               [ 64.,  81.],
               [100., 121.]], dtype=float32),
  DeviceArray([[ 36.,  49.],
               [ 64.,  81.],
               [100., 121.]], dtype=float32)])

## `y` axis output sharding

In [15]:
# No partition for inputs: data is replicated on every device.
in_axis_resources = None
# Output dim 0 (rows) is split over mesh axis 'y'; dim 1 is not partitioned.
out_axis_resources = PartitionSpec('y', None)

f = pjit(
    # Simple unary op to run
    lambda x: jax.lax.integer_pow(x, 2),
    in_axis_resources=in_axis_resources,
    out_axis_resources=out_axis_resources)

# Run under the mesh defined above. `mesh` is itself a context manager, so there is
# no need to build a second, identical `maps.Mesh`.
with mesh:
    output_data = f(input_data)



In [16]:
# Check the distributed result against NumPy before displaying it.
np.testing.assert_allclose(np.asarray(output_data), input_data ** 2)
# Rows are split over 'y'; every shard keeps all columns (replicated along 'x').
assert all(
    buf.shape == (input_data.shape[0] // mesh_shape[1], input_data.shape[1])
    for buf in output_data.device_buffers
)
# Output: sharded along Y axis, replicated along X.
output_data, output_data.device_buffers

(ShardedDeviceArray([[  0.,   1.],
                     [  4.,   9.],
                     [ 16.,  25.],
                     [ 36.,  49.],
                     [ 64.,  81.],
                     [100., 121.]], dtype=float32),
 [DeviceArray([[ 0.,  1.],
               [ 4.,  9.],
               [16., 25.]], dtype=float32),
  DeviceArray([[ 36.,  49.],
               [ 64.,  81.],
               [100., 121.]], dtype=float32),
  DeviceArray([[ 0.,  1.],
               [ 4.,  9.],
               [16., 25.]], dtype=float32),
  DeviceArray([[ 36.,  49.],
               [ 64.,  81.],
               [100., 121.]], dtype=float32)])

## Combined (`x`, `y`) axes output sharding

In [17]:
# Dim 0 is split over both mesh axes at once ('x' * 'y' = mesh_size shards), so the
# number of rows must be divisible by mesh_size. Use 2 rows of 2 values per device,
# which holds for any mesh size.
input_data = np.arange(4 * mesh_size, dtype=np.float32).reshape(2 * mesh_size, 2)

# No partition for inputs: data is replicated on every device.
in_axis_resources = None
# Output dim 0 is sharded over the combined ('x', 'y') axes; dim 1 is not partitioned.
out_axis_resources = PartitionSpec(('x', 'y'), None)

f = pjit(
    # Simple unary op to run
    lambda x: jax.lax.integer_pow(x, 2),
    in_axis_resources=in_axis_resources,
    out_axis_resources=out_axis_resources)

# Run under the mesh defined above. `mesh` is itself a context manager, so there is
# no need to build a second, identical `maps.Mesh`.
with mesh:
    output_data = f(input_data)



In [18]:
# Check the distributed result against NumPy before displaying it.
np.testing.assert_allclose(np.asarray(output_data), input_data ** 2)
# Rows are split over all mesh_size devices; each shard keeps all columns.
assert all(
    buf.shape == (input_data.shape[0] // mesh_size, input_data.shape[1])
    for buf in output_data.device_buffers
)
output_data, output_data.device_buffers

(ShardedDeviceArray([[  0.,   1.],
                     [  4.,   9.],
                     [ 16.,  25.],
                     [ 36.,  49.],
                     [ 64.,  81.],
                     [100., 121.],
                     [144., 169.],
                     [196., 225.]], dtype=float32),
 [DeviceArray([[0., 1.],
               [4., 9.]], dtype=float32),
  DeviceArray([[16., 25.],
               [36., 49.]], dtype=float32),
  DeviceArray([[ 64.,  81.],
               [100., 121.]], dtype=float32),
  DeviceArray([[144., 169.],
               [196., 225.]], dtype=float32)])

# Matmul `pjit` example

In [19]:
# Small matmul sizes for a quick run (a larger configuration: M, N, K = 128, 64, 256).
# The (M, K) output is sharded with PartitionSpec('x', 'y'), so M must be divisible by
# the 'x' mesh axis size and K by the 'y' mesh axis size.
M, N, K = 12, 8, 16

# Seeded generator so the inputs (and therefore the result) are reproducible across runs.
rng = np.random.default_rng(0)
lhs = rng.random((M, N), dtype=np.float32)
rhs = rng.random((N, K), dtype=np.float32)

In [20]:
def compute_fn(lhs, rhs):
    """Return the matrix product `lhs @ rhs`; traced and partitioned by `pjit` below."""
    return lhs @ rhs


# Inputs are not partitioned: both operands are replicated on every device.
# TODO: shard the inputs as well instead of replicating them.
in_axis_resources = None
# Output rows are split over mesh axis 'x' and output columns over mesh axis 'y'.
out_axis_resources = PartitionSpec('x', 'y')

f = pjit(compute_fn, in_axis_resources=in_axis_resources, out_axis_resources=out_axis_resources)

In [21]:
# Sends data to accelerators based on partition_spec
with maps.Mesh(mesh.devices, mesh.axis_names):
    output = f(lhs, rhs)



In [22]:
# Check the sharded matmul against NumPy. The tolerance is loose because the device's
# matmul precision/accumulation order may differ from NumPy's (TODO confirm for IPU).
np.testing.assert_allclose(np.asarray(output), lhs @ rhs, rtol=1e-3)
output

ShardedDeviceArray([[1.5954874 , 1.568143  , 2.1878989 , 2.2802362 ,
                     1.9145454 , 1.326968  , 2.38342   , 1.4342963 ,
                     1.8710132 , 2.2760766 , 1.2393537 , 1.806525  ,
                     1.6943976 , 1.8497736 , 2.0054781 , 1.9822028 ],
                    [2.5935414 , 2.467551  , 3.4812937 , 3.4287057 ,
                     2.6669626 , 2.2495036 , 3.5671082 , 2.6938055 ,
                     2.4568684 , 3.3633797 , 2.2307868 , 3.063488  ,
                     2.4313595 , 2.6677954 , 2.6706545 , 3.0482678 ],
                    [1.2307926 , 1.468817  , 1.5824169 , 1.6121603 ,
                     1.6673343 , 0.95006704, 2.0006866 , 1.3617406 ,
                     1.1706672 , 1.3099413 , 1.0319011 , 1.400886  ,
                     1.2033337 , 1.3988451 , 0.9909877 , 0.90187824],
                    [1.8417573 , 1.8773335 , 1.8307042 , 2.980454  ,
                     2.2116432 , 2.26642   , 2.7127883 , 2.1408713 ,
                     2.8655107 