# Multi-IPUs in JAX: `pmap` quickstart

The experimental JAX port for IPUs supports running on multiple IPUs, as well as collective operations between them (with some limitations on the topology).

This guide is directly inspired by the official JAX documentation and its multi-device examples.

In [1]:
# Install experimental JAX for IPUs (SDK 3.1) from the GitHub release wheels.
# `sys.executable` is used so that pip runs under the interpreter backing this kernel.
import sys
# Remove any installed jax/jaxlib first so the pinned IPU builds below replace them.
!{sys.executable} -m pip uninstall -y jax jaxlib
# Versions are pinned; `-f` points pip at the page listing the `+ipu` wheels.
!{sys.executable} -m pip install jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk310 -f https://graphcore-research.github.io/jax-experimental/wheels.html

Found existing installation: jax 0.3.16+ipu
Uninstalling jax-0.3.16+ipu:
  Successfully uninstalled jax-0.3.16+ipu
Found existing installation: jaxlib 0.3.15+ipu.sdk310
Uninstalling jaxlib-0.3.15+ipu.sdk310:
  Successfully uninstalled jaxlib-0.3.15+ipu.sdk310


In [2]:
from jax.config import config

# Select how many IPUs will be visible.
# NOTE(review): presumably this has to be set before JAX initialises its backends — confirm.
config.FLAGS.jax_ipu_device_count = 4

# Alternative: simulate `pmap` on CPU devices instead (uncomment the lines below).
# `os` is not imported by the surrounding cells, hence the extra import line.
# import os
# os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'
# config.FLAGS.jax_platforms = "cpu,ipu"

In [3]:
import numpy as np
import jax
import jax.numpy as jnp

In [7]:
# Sanity check: report the default backend and how many devices JAX can see.
print(f"Platform={jax.default_backend()}")
print(f"Number of devices={jax.device_count()}")
# Keep the list of IPU devices: later cells index into it to pick target devices.
ipu_devices = jax.devices("ipu")
# Bare last expression so the notebook displays the device list.
ipu_devices

Platform=ipu
Number of devices=4


[IpuDevice(id=0, num_tiles=1472, version=ipu2),
 IpuDevice(id=1, num_tiles=1472, version=ipu2),
 IpuDevice(id=2, num_tiles=1472, version=ipu2),
 IpuDevice(id=3, num_tiles=1472, version=ipu2)]

# Basic `pmap`: pure map with no communication

In [8]:
from jax import pmap
from jax import lax

In [9]:
# Number of devices; used as the size of the leading (mapped) axis of the example data.
num_devices = len(jax.devices())
# Number of elements per device in the example data.
N = 3

In [10]:
@pmap
def square(x):
    """Return the element-wise square of `x`.

    Wrapped in `pmap`, so the leading axis of the argument is split across
    devices and each device squares its own slice.
    """
    return jnp.power(x, 2)

In [11]:
# One row of N values per device: `pmap` maps over the leading axis.
data = np.arange(N * num_devices, dtype=np.float32).reshape((num_devices, -1))
# First call triggers compilation, which can take a bit of time.
output = square(data)

print("INPUT:", data)
# Also print the result type to show what kind of array `pmap` returns.
print("OUTPUT:", output, type(output))



INPUT: [[ 0.  1.  2.]
 [ 3.  4.  5.]
 [ 6.  7.  8.]
 [ 9. 10. 11.]]
OUTPUT: [[  0.   1.   4.]
 [  9.  16.  25.]
 [ 36.  49.  64.]
 [ 81. 100. 121.]] <class 'jaxlib.xla_extension.pmap_lib.ShardedDeviceArray'>


In [12]:
# Second call is fast, with code already compiled and loaded on IPU devices.
# Note: `output` is fed back in and rebound, so re-running this cell squares the values again.
output = square(output)
print("OUTPUT:", output)

OUTPUT: [[0.0000e+00 1.0000e+00 1.6000e+01]
 [8.1000e+01 2.5600e+02 6.2500e+02]
 [1.2960e+03 2.4010e+03 4.0960e+03]
 [6.5610e+03 1.0000e+04 1.4641e+04]]


In [13]:
# Buffers of output sharded array living on different IPUs.
# List the device holding each buffer of the sharded result.
[b.device() for b in output.device_buffers]

[IpuDevice(id=0, num_tiles=1472, version=ipu2),
 IpuDevice(id=1, num_tiles=1472, version=ipu2),
 IpuDevice(id=2, num_tiles=1472, version=ipu2),
 IpuDevice(id=3, num_tiles=1472, version=ipu2)]

In [14]:
# Can move on the fly to a different IPU.
# Here row 1 of the sharded output is placed on the IPU at index 3 of `ipu_devices`.
b = jax.device_put(output[1], ipu_devices[3])
# Display the array together with the device it now lives on.
b, b.device()

(DeviceArray([ 81., 256., 625.], dtype=float32),
 IpuDevice(id=3, num_tiles=1472, version=ipu2))

In [15]:
# Slice the first row of the sharded output; per the result shown below this is a plain DeviceArray.
output[0:1]

DeviceArray([[ 0.,  1., 16.]], dtype=float32)

# Collective communication operations

Collective operations in JAX on IPU are implemented using the Graphcore GCL library (https://docs.graphcore.ai/projects/poplar-user-guide/en/latest/gcl.html). As with TPUs, some restrictions on the IPU mesh topology apply.

## Single `pmap` reduction across all devices

In [16]:
from functools import partial

@partial(pmap, axis_name="i")
def normalize(x):
    """Divide each device's slice by the element-wise sum over all slices.

    `lax.psum` adds `x` across every device taking part in the mapped axis
    "i", so the returned slices sum to one element-wise.
    """
    total = lax.psum(x, "i")
    return x / total

In [17]:
# First call: compile & load on IPU devices.
# Only the first two rows of `data` are passed, so the mapped axis has size two here.
output = normalize(data[0:2])
output



ShardedDeviceArray([[0.        , 0.2       , 0.2857143 ],
                    [1.        , 0.8       , 0.71428573]], dtype=float32)

In [18]:
# Proper normalization across IPUs!
# Summing the per-device rows back together should give one in every column.
np.sum(np.asarray(output), axis=0)

array([1., 1., 1.], dtype=float32)

## `pmap` reduction across different replica groups

A typical use case for multiple `pmap` axes is combining data parallelism with tensor parallelism.

In [19]:
@partial(pmap, axis_name='rows')
@partial(pmap, axis_name='cols')
def f(x):
    """Normalize `x` by its sum over each mapped axis and over both axes.

    The outer `pmap` names its axis 'rows' and the inner one 'cols'.

    Returns:
        A 3-tuple: `x` divided by its sum over 'rows', over 'cols', and over
        ('rows', 'cols') together.
    """
    sum_over_rows = lax.psum(x, 'rows')
    sum_over_cols = lax.psum(x, 'cols')
    sum_over_both = lax.psum(x, ('rows', 'cols'))
    return x / sum_over_rows, x / sum_over_cols, x / sum_over_both

# 2x2 input: one mapped axis per nested `pmap` of `f`.
x = np.arange(4., dtype=np.float32).reshape((2, 2))
outputs = f(x)

# `repr` keeps the array type visible for each element of the returned tuple.
print("OUTPUTS:", repr(outputs))



OUTPUTS: (ShardedDeviceArray([[0.  , 0.25],
                    [1.  , 0.75]], dtype=float32), ShardedDeviceArray([[0. , 1. ],
                    [0.4, 0.6]], dtype=float32), ShardedDeviceArray([[0.        , 0.16666667],
                    [0.33333334, 0.5       ]], dtype=float32))


# Manual data sharding

In [23]:
# Split the IPUs into two groups: the first two devices and the remaining ones.
sub_devices0, sub_devices1 = ipu_devices[:2], ipu_devices[2:]

# Place the rows of `data` explicitly: one row per device of the matching group.
indata0 = jax.device_put_sharded(list(data[:2]), sub_devices0)
indata1 = jax.device_put_sharded(list(data[2:]), sub_devices1)

In [24]:
def normalize_fn(x):
    """Divide `x` by its sum across the mapped axis named 'i'."""
    total = lax.psum(x, axis_name='i')
    return x / total

def normalize_fn2(x):
    """Divide `x` by its mean across the mapped axis named 'i'."""
    mean = lax.pmean(x, axis_name='i')
    return x / mean

# Build two independent `pmap`s, each restricted to its own device group:
# group 0 divides by the cross-device sum, group 1 by the cross-device mean.
normalize0 = pmap(normalize_fn, axis_name='i', devices=sub_devices0)
normalize1 = pmap(normalize_fn2, axis_name='i', devices=sub_devices1)

In [25]:
# Run each pmapped function on the data placed on its own device group.
out0 = normalize0(indata0)
out1 = normalize1(indata1)

# Display both results together.
out0, out1



(ShardedDeviceArray([[0.        , 0.2       , 0.2857143 ],
                     [1.        , 0.8       , 0.71428573]], dtype=float32),
 ShardedDeviceArray([[0.8       , 0.8235294 , 0.84210527],
                     [1.2       , 1.1764706 , 1.1578947 ]], dtype=float32))

In [26]:
# List the devices holding the buffers of each result, one list per device group.
[b.device() for b in out0.device_buffers], [b.device() for b in out1.device_buffers]

([IpuDevice(id=0, num_tiles=1472, version=ipu2),
  IpuDevice(id=1, num_tiles=1472, version=ipu2)],
 [IpuDevice(id=2, num_tiles=1472, version=ipu2),
  IpuDevice(id=3, num_tiles=1472, version=ipu2)])

## Collective `permute` between IPUs

Collective permutes between neighbouring IPUs are potentially useful for implementing a pipeline on a Transformer model.

In [21]:
# NOTE(review): `xla_bridge` is not referenced by any cell shown here — candidate for removal.
from jax._src.lib import xla_bridge
# Total number of devices; the board further down is split into this many slices.
device_count = jax.device_count()

def send_right(x, axis_name):
    """Shift `x` one position to the right along the ring formed by `axis_name`.

    Index `i` of the mapped axis sends its value to index `(i + 1) % n`, so
    every participant receives the value of its left neighbour, wrapping
    around at the ends.

    Args:
        x: per-device value inside a named-axis context such as `pmap`.
        axis_name: name of the mapped axis to permute over.

    Returns:
        The value received from the left neighbour.
    """
    # The ring size is taken from the mapped axis itself rather than from a
    # module-level device count, so the permutation stays valid when the
    # function is mapped over only a subset of the devices. `psum` of a Python
    # constant is evaluated at trace time and yields a concrete integer.
    axis_size = int(lax.psum(1, axis_name))
    right_perm = [(i, (i + 1) % axis_size) for i in range(axis_size)]
    return lax.ppermute(x, perm=right_perm, axis_name=axis_name)

def send_left(x, axis_name):
    """Shift `x` one position to the left along the ring formed by `axis_name`.

    Index `(i + 1) % n` of the mapped axis sends its value to index `i`, so
    every participant receives the value of its right neighbour, wrapping
    around at the ends.

    Args:
        x: per-device value inside a named-axis context such as `pmap`.
        axis_name: name of the mapped axis to permute over.

    Returns:
        The value received from the right neighbour.
    """
    # The ring size is taken from the mapped axis itself rather than from a
    # module-level device count, so the permutation stays valid when the
    # function is mapped over only a subset of the devices. `psum` of a Python
    # constant is evaluated at trace time and yields a concrete integer.
    axis_size = int(lax.psum(1, axis_name))
    left_perm = [((i + 1) % axis_size, i) for i in range(axis_size)]
    return lax.ppermute(x, perm=left_perm, axis_name=axis_name)

def update_board(board):
    """Compute the next generation of a 1-D cellular automaton.

    Each output cell is `left XOR (center OR right)` of three neighbouring
    input cells (the update known as Rule 30). The two boundary cells of
    `board` only serve as neighbours, so the result is two cells shorter.
    """
    center_or_right = lax.bitwise_or(board[1:-1], board[2:])
    return lax.bitwise_xor(board[:-2], center_or_right)

@partial(pmap, axis_name='i')
def step(board_slice):
    """Advance one device's slice of the board by one generation.

    The first and last cells of the slice are exchanged with the neighbouring
    devices along axis 'i', the received cells are attached to both ends of
    the slice, and the update rule is applied to the enlarged slice.
    """
    first_cell = board_slice[:1]
    last_cell = board_slice[-1:]
    # Receive the right neighbour's first cell and the left neighbour's last cell.
    right_halo = send_left(first_cell, 'i')
    left_halo = send_right(last_cell, 'i')
    padded_slice = jnp.concatenate([left_halo, board_slice, right_halo])
    return update_board(padded_slice)

def print_board(board):
    """Print the flattened board on one line: '*' for truthy cells, ' ' otherwise."""
    cells = np.asarray(board).ravel()
    line = ''.join(['*' if cell else ' ' for cell in cells])
    print(line)

In [22]:
# Start from a 40-cell board with a single live cell in the middle.
board = np.zeros(40, dtype=bool)
board[board.shape[0] // 2] = True
# One contiguous slice of the board per device (the leading axis is the mapped one).
reshaped_board = board.reshape((device_count, -1))

print_board(reshaped_board)
# Advance 20 generations, printing the whole board after each one.
for _ in range(20):
    reshaped_board = step(reshaped_board)
    print_board(reshaped_board)

                    *                   




                   ***                  
                  **  *                 
                 ** ****                
                **  *   *               
               ** **** ***              
              **  *    *  *             
             ** ****  ******            
            **  *   ***     *           
           ** **** **  *   ***          
          **  *    * **** **  *         
         ** ****  ** *    * ****        
        **  *   ***  **  ** *   *       
       ** **** **  *** ***  ** ***      
      **  *    * ***   *  ***  *  *     
     ** ****  ** *  * *****  *******    
    **  *   ***  **** *    ***      *   
   ** **** **  ***    **  **  *    ***  
  **  *    * ***  *  ** *** ****  **  * 
 ** ****  ** *  ******  *   *   *** ****
 *  *   ***  ****     **** *** **   *   
