Copyright (c) 2023 Graphcore Ltd. All rights reserved.

# Tessellate IPU - Basics of tile mapping on tensors and vertices

The IPU is a highly parallel architecture with 1472 independent IPU-cores (also called IPU tiles) connected by an all-to-all IPU-exchange. Each IPU tile has 6 independent program threads and 639kB of local SRAM available.

**Tessellate IPU** is a library exposing low-level IPU programming primitives in Python, allowing users to take full advantage of the IPU's unique architecture and features. In this tutorial notebook, we present the basics of the Tessellate IPU API, learning how to:

* Shard tensors/arrays between IPU tiles using `tile_put_replicated` and `tile_put_sharded`;
* Map an IPU vertex (i.e. base function) onto sharded tensors using `tile_map`;
* Micro-benchmark Tessellate IPU functions by capturing hardware cycle counts.

## Dependencies and configuration

Install the experimental JAX build for IPU and the Tessellate IPU library.

In [1]:
!pip install -q jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk320 -f https://graphcore-research.github.io/jax-experimental/wheels.html
!pip install -q git+https://github.com/graphcore-research/tessellate-ipu.git@main

In [2]:
import jax

# Uncomment to use IPU model emulator.
# from jax.config import config
# config.FLAGS.jax_ipu_use_model = True
# config.FLAGS.jax_ipu_model_num_tiles = 8

In [3]:
# Check IPU hardware configuration
print(f"Platform={jax.default_backend()}")
print(f"Number of devices={jax.device_count()}")

devices = jax.devices()
print("\n".join([str(d) for d in devices]))
d = jax.devices()[0]

Platform=ipu
Number of devices=4
IpuDevice(id=0, num_tiles=1472, version=ipu2)
IpuDevice(id=1, num_tiles=1472, version=ipu2)
IpuDevice(id=2, num_tiles=1472, version=ipu2)
IpuDevice(id=3, num_tiles=1472, version=ipu2)


## Tile tensor sharding in Tessellate IPU

Prior to doing any compute on IPU tiles, one needs to decide how to shard the data between tiles. The PopTorch, TensorFlow and JAX frameworks rely on the Poplar compiler to automatically decide on the optimal mapping of ML workloads. Tessellate IPU provides two primitives that let the user control this mapping directly:
* `tile_put_sharded`: Shard a tensor over the first axis between a set of tiles;
* `tile_put_replicated`: Replicate a tensor on a set of tiles.

Both methods return a `TileShardedArray` Python object, which wraps a common JAX array and an explicit (static) tile mapping. By convention, a `TileShardedArray` tensor is always sharded over the first axis, and on-tile shards are contiguous in memory.

Here is an example of how to use these two methods.

In [4]:
import jax
import numpy as np

from tessellate_ipu import tile_put_replicated, tile_put_sharded

data = np.random.rand(3, 5).astype(np.float32)


@jax.jit
def compute_fn(data):
    # Shard data on tiles (0, 1, 3)
    t0 = tile_put_sharded(data, (0, 1, 3))
    # Replicate data on tiles (1, 3)
    t1 = tile_put_replicated(data, (1, 3))
    return t0, t1


t0, t1 = compute_fn(data)

In [5]:
# `t0` has the same shape, just sharded between tiles.
print(f"Tensor `t0` shape {t0.shape} and tile mapping {t0.tiles}.")
# `t1` has an additional replication axis.
print(f"Tensor `t1` shape {t1.shape} and tile mapping {t1.tiles}.")

Tensor `t0` shape (3, 5) and tile mapping (0, 1, 3).
Tensor `t1` shape (2, 3, 5) and tile mapping (1, 3).
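
To make the two shapes above concrete, here is a NumPy-only sketch that mimics what each primitive does to the array shape. The helpers `emulate_sharded` and `emulate_replicated` are hypothetical stand-ins written for this illustration. They are not part of the Tessellate IPU API and they ignore device placement entirely. The check that the first axis matches the number of tiles is also an assumption of this sketch.

```python
import numpy as np


def emulate_sharded(data, tiles):
    """Hypothetical stand-in: one first-axis slice per tile, shape unchanged."""
    data = np.asarray(data)
    # Assumption of this sketch: exactly one shard per tile.
    if data.shape[0] != len(tiles):
        raise ValueError("First axis must match the number of tiles.")
    return data, tuple(tiles)


def emulate_replicated(data, tiles):
    """Hypothetical stand-in: one full copy per tile, stacked on a new first axis."""
    data = np.asarray(data)
    return np.stack([data] * len(tiles)), tuple(tiles)


data = np.random.rand(3, 5).astype(np.float32)
s, s_tiles = emulate_sharded(data, (0, 1, 3))
r, r_tiles = emulate_replicated(data, (1, 3))

print(s.shape, s_tiles)  # (3, 5) (0, 1, 3)
print(r.shape, r_tiles)  # (2, 3, 5) (1, 3)
```

In both cases the first axis of the result indexes tiles, which is why the replicated tensor gains an extra leading axis.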


`TileShardedArray` tensors support the basic array API, such as slicing and indexing. Note that an error will be raised if slicing a tensor would result in non-contiguous on-tile shards.

In [6]:
t3 = t0[1:, 2:5]
print(f"Tensor `t3` shape {t3.shape} and tile mapping {t3.tiles}.")
# Extract the underlying tensor/array, or convert to a Numpy array
t3.array, np.array(t3)

Tensor `t3` shape (2, 3) and tile mapping (1, 3).


(DeviceArray([[0.41966587, 0.14291303, 0.10358273],
              [0.1413389 , 0.7546814 , 0.7027907 ]], dtype=float32),
 array([[0.41966587, 0.14291303, 0.10358273],
        [0.1413389 , 0.7546814 , 0.7027907 ]], dtype=float32))

## Tile map a vertex using Tessellate IPU

Once tensors have been sharded across IPU tiles, users can map computation kernels (called IPU vertices) to these arrays. Tessellate IPU supports out of the box (part of) [JAX LAX operations](https://jax.readthedocs.io/en/latest/jax.lax.html) by mapping them to pre-existing optimized vertices from the Graphcore Poplar SDK, so that basic operations can be written in a couple of lines.

In the following example, we write a simple `broadcast_add` using Tessellate IPU. In this broadcast operation, the left-hand term is sharded across a collection of tiles whereas the right-hand term is broadcast (i.e. replicated on all tiles).

In [7]:
import jax
import jax.lax
import numpy as np

from tessellate_ipu import tile_map, tile_put_replicated, tile_put_sharded

lhs_data = np.random.rand(3, 5).astype(np.float32)
rhs_data = np.random.rand(5).astype(np.float32)


@jax.jit
def broadcast_add(lhs, rhs):
    # Tiles to split the workload on.
    tiles = tuple(range(len(lhs)))
    # Shard lhs on tiles (0, 1, ..., N-1), with N = len(lhs)
    lhs = tile_put_sharded(lhs, tiles)
    # Replicate rhs on tiles (0, 1, ..., N-1)
    rhs = tile_put_replicated(rhs, tiles)

    # Map Poplar optimized `add` vertex to the sharded data.
    out = tile_map(jax.lax.add_p, lhs, rhs)
    return out


out = broadcast_add(lhs_data, rhs_data)

print("Output:", out)
print("Expected output:", lhs_data + rhs_data)

Output: TileShardedArray(array=DeviceArray([[0.60848534, 1.0852832 , 0.90640026, 1.0351304 , 1.1440121 ],
             [0.8170862 , 1.2950337 , 1.0239053 , 0.484455  , 1.3905538 ],
             [0.43845683, 0.9908368 , 1.6852813 , 0.5040692 , 0.71705866]],            dtype=float32), tiles=(0, 1, 2))
Expected output: [[0.60848534 1.0852832  0.90640026 1.0351304  1.1440121 ]
 [0.8170862  1.2950337  1.0239053  0.484455   1.3905538 ]
 [0.43845683 0.9908368  1.6852813  0.5040692  0.71705866]]


As seen above, `tile_map` always returns `TileShardedArray` objects, with the tile mapping deduced from the inputs. It also checks that the inputs are `TileShardedArray` instances. Since the `TileShardedArray` class ensures that data is already sharded on IPU tiles in a contiguous form, `tile_map` has no performance overhead (i.e. no implicit on-tile copy or tile exchange).

**Note:** Tessellate IPU will always check that the tile mapping is consistent, and will raise an error if not. As the goal of Tessellate IPU is to provide a way to write performant & efficient algorithms directly in Python, implicit exchange between IPU tiles (or on-tile-copy) is not allowed.

In [8]:
@jax.jit
def broadcast_add_error(lhs, rhs):
    lhs = tile_put_sharded(lhs, range(len(lhs)))
    rhs = tile_put_replicated(rhs, range(1, len(lhs) + 1))
    out = tile_map(jax.lax.add_p, lhs, rhs)
    return out


# Raise `ValueError`: inconsistent tile mapping!
try:
    broadcast_add_error(lhs_data, rhs_data)
except Exception as e:
    print("Tessellate error!", e)

Tessellate error! Inconsistent tile mapping between input arrays: (0, 1, 2) vs (1, 2, 3).
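
The kind of check involved can be sketched in a few lines of plain Python. The function `check_tile_mapping` below is a hypothetical helper written for this illustration and is not the library's implementation. It only reproduces the behaviour seen above: all inputs must share one tile mapping, otherwise a `ValueError` is raised.

```python
def check_tile_mapping(*tile_mappings):
    """Hypothetical helper: return the common tile mapping, or raise if inputs differ."""
    reference = tuple(tile_mappings[0])
    for tiles in tile_mappings[1:]:
        if tuple(tiles) != reference:
            raise ValueError(
                f"Inconsistent tile mapping between input arrays: {reference} vs {tuple(tiles)}."
            )
    return reference


# Consistent mappings are accepted and returned.
print(check_tile_mapping(range(3), range(3)))

# Shifted mappings, as in `broadcast_add_error`, are rejected.
message = None
try:
    check_tile_mapping(range(3), range(1, 4))
except ValueError as e:
    message = str(e)
    print(message)
```

Failing early like this is what makes the absence of implicit tile exchange explicit: the user has to move data deliberately rather than pay for a hidden copy.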


Tessellate code written using standard JAX LAX primitives remains fully compatible with **other backends (CPU, GPU, TPU)**. `tile_put_sharded` is a no-op on other backends and `tile_put_replicated` is a simple `concatenate` of the input tensor. Finally `tile_map` is translated into a standard JAX `vmap` call.

As a consequence, one can run the exact same function on the JAX CPU backend:

In [9]:
# JIT function on CPU backend.
broadcast_add_cpu = jax.jit(broadcast_add, device=jax.devices("cpu")[0])
# Running on CPU.
out_cpu = broadcast_add_cpu(lhs_data, rhs_data)
# Check data & device.
print(f"Output: {out_cpu} on device: {out_cpu.array.device()}")

Output: TileShardedArray(array=DeviceArray([[0.60848534, 1.0852832 , 0.90640026, 1.0351304 , 1.1440121 ],
             [0.8170862 , 1.2950337 , 1.0239053 , 0.484455  , 1.3905538 ],
             [0.43845683, 0.9908368 , 1.6852813 , 0.5040692 , 0.71705866]],            dtype=float32), tiles=(0, 1, 2)) on device: TFRT_CPU_0
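
The fallback semantics described above can be mimicked in NumPy. This sketch is an assumption-laden emulation rather than the library's code path: replication is modelled as stacking one copy of `rhs` per tile, and the `vmap`-style mapping is modelled as a Python loop over the first axis.

```python
import numpy as np

lhs = np.random.rand(3, 5).astype(np.float32)
rhs = np.random.rand(5).astype(np.float32)

# Replication: one copy of `rhs` per tile, stacked on a new first axis -> (3, 5).
rhs_replicated = np.stack([rhs] * len(lhs))

# Mapping: apply the per-tile function to each pair of shards along axis 0,
# which is what a `vmap` over the first axis would compute.
out = np.stack([np.add(l, r) for l, r in zip(lhs, rhs_replicated)])

# The result matches ordinary NumPy broadcasting.
print(out.shape, np.allclose(out, lhs + rhs))
```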


## Micro-benchmarking in Tessellate IPU

When writing performant algorithms, micro-benchmarking is a recommended practice to ensure quick progress and no performance regression. Tessellate IPU is fully compatible with [Graphcore PopVision tools](https://www.graphcore.ai/developer/popvision-tools), but also provides a way to directly measure the IPU hardware cycle count with the Python function `ipu_cycle_count`.

**Note:** Cycle count is not available on the IPU model simulator: `ipu_cycle_count` will always return a zeroed tensor on the latter.

In [11]:
from tessellate_ipu import ipu_cycle_count


@jax.jit
def broadcast_add(lhs, rhs):
    # Tiles to split the workload on.
    tiles = tuple(range(len(lhs)))
    # Shard lhs on tiles (0, 1, ..., N-1), with N = len(lhs)
    lhs = tile_put_sharded(lhs, tiles)
    # Replicate rhs on tiles (0, 1, ..., N-1)
    rhs = tile_put_replicated(rhs, tiles)

    # Cycle count once inputs are sharded.
    lhs, rhs, start = ipu_cycle_count(lhs, rhs)
    # Map Poplar optimized `add` vertex to the sharded data.
    out = tile_map(jax.lax.add_p, lhs, rhs)
    # Cycle count after output is computed.
    out, end = ipu_cycle_count(out)
    return out, start, end


_, start, end = broadcast_add(lhs_data, rhs_data)

print("Start cycle count:", start, start.shape)
print("End cycle count:", end, end.shape)

Start cycle count: TileShardedArray(array=DeviceArray([[391787498,         0],
             [391787499,         0],
             [391787499,         0]], dtype=uint32), tiles=(0, 1, 2)) (3, 2)
End cycle count: TileShardedArray(array=DeviceArray([[391787852,         0],
             [391787853,         0],
             [391787853,         0]], dtype=uint32), tiles=(0, 1, 2)) (3, 2)


The function `ipu_cycle_count` returns raw cycle counts on every IPU tile, directly measured by the hardware. Note that `ipu_cycle_count` takes input arguments and returns them unchanged in order to provide control flow information to the XLA and Poplar compilers (i.e. the cycle counts are measured after these tensors have been computed).


Tessellate provides the raw values returned by IPU C++ intrinsics (https://docs.graphcore.ai/projects/poplar-api/en/2.4.0/ipu_builtins.html#get-count-l-from-csr), which is why the `start` and `end` tensors have the `uint32` datatype. As shown below, the raw cycle count can easily be translated into timing figures. Please note that timings measured this way will differ massively from simple Python benchmarking of the function `broadcast_add`, as the latter also includes any JAX overhead, host to/from IPU communications and IPU tile exchange!

In [16]:
def cycle_count_to_timing(start, end, ipu_device):
    """Convert raw cycle counts into a timing, in seconds."""
    # Lower & upper bounds on cycle count, across all tiles.
    start_min = np.min(start[:, 0])
    end_max = np.max(end[:, 0])
    cycle_count_diff = end_max - start_min
    # Use the device passed as argument, not the global `d`.
    timing = cycle_count_diff / ipu_device.tile_clock_frequency
    return timing


timing = cycle_count_to_timing(start, end, d)
print("Tile map `add` execution time (micro-seconds):", timing * 1e6)

Tile map `add` execution time (micro-seconds): 0.1918918918918919
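
The arithmetic behind this figure can be replayed by hand with the counters printed earlier. In the sketch below, the counter values are copied from the output above, and the clock frequency `TILE_CLOCK_HZ` is an assumption inferred from the printed timing (355 cycles in about 0.19 microseconds corresponds to 1.85 GHz). It is not read from the device.

```python
import numpy as np

# First column of the raw counters, copied from the output above.
start = np.array([391787498, 391787499, 391787499], dtype=np.uint32)
end = np.array([391787852, 391787853, 391787853], dtype=np.uint32)

# Assumed tile clock frequency in Hz, inferred from the printed timing.
TILE_CLOCK_HZ = 1.85e9

# Cycles spent on each tile between the two measurements.
per_tile_cycles = end - start
# Conservative bound used in the notebook: latest end minus earliest start.
total_cycles = int(end.max()) - int(start.min())
timing_us = total_cycles / TILE_CLOCK_HZ * 1e6

print(per_tile_cycles)  # [354 354 354]
print(total_cycles)     # 355
print(round(timing_us, 4))
```

The one-cycle gap between the per-tile figure and the overall bound comes from tile 0 starting one cycle earlier than the other two tiles.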


### Cycle count & IPU tile parallelism

Let's demonstrate IPU tile parallelism in a simple way by using the hardware cycle count: `broadcast_add` performance should be (roughly) independent of the size of the input's first axis, as the workload is split uniformly between tiles. Note that timings can vary slightly depending on the number of tiles used, as tiles are not synchronized when running independent compute workloads.

In [19]:
# Number of tiles to use.
num_tiles_list = [4, 16, 128, 1024]


def ipu_benchmark(num_tiles):
    """IPU benchmarking, splitting the workload over a collection of tiles."""
    # Workload size per tile.
    wsize = 1024
    lhs_data = np.random.rand(num_tiles, wsize).astype(np.float32)
    rhs_data = np.random.rand(wsize).astype(np.float32)

    _, start, end = broadcast_add(lhs_data, rhs_data)
    timing = cycle_count_to_timing(start, end, d)
    return timing


benchmarks = []
try:
    benchmarks = [ipu_benchmark(N) * 1e6 for N in num_tiles_list]
except Exception:
    print("Not working when using the IPU model!")

# (Roughly) constant timing independently of the number of tiles used.
print("Benchmark timing (us):", benchmarks)

Benchmark timing (us): [1.0043243243243243, 1.0075675675675677, 1.0178378378378379, 1.0475675675675675]
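
A quick way to read these numbers is to compare how much the workload grew with how much the timing grew. The sketch below only post-processes the figures printed above (copied verbatim). The variable names are chosen for this illustration.

```python
num_tiles_list = [4, 16, 128, 1024]
wsize = 1024
# Timings in microseconds, copied from the benchmark output above.
timings_us = [1.0043243243243243, 1.0075675675675677, 1.0178378378378379, 1.0475675675675675]

# The total workload grows 256x between the smallest and largest runs...
work_ratio = num_tiles_list[-1] / num_tiles_list[0]
# ...while the measured time grows by only a few percent.
slowdown = max(timings_us) / min(timings_us)

# Aggregate throughput in elements per microsecond, for each run.
throughput = [n * wsize / t for n, t in zip(num_tiles_list, timings_us)]

print(f"Workload ratio: {work_ratio:.0f}x, timing ratio: {slowdown:.3f}x")
print("Throughput (elements/us):", [round(v) for v in throughput])
```

Aggregate throughput therefore scales almost linearly with the number of tiles, which is the expected signature of independent per-tile work.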
