Copyright (c) 2023 Graphcore Ltd. All rights reserved.

# TessellateIPU - Basics of Tile Mapping on Tensors and Vertices

The IPU is a highly parallel AI accelerator with 1,472 independent cores (also called IPU tiles) connected with an all-to-all IPU-exchange.
Each IPU tile has six independent program threads and 640 KB of local SRAM.

**TessellateIPU** is a library exposing low-level IPU programming primitives in Python, allowing users to take full advantage of the IPU's unique architecture and features. In this tutorial notebook, we present the basics of the TessellateIPU API, showing how to:

* Shard tensors across IPU tiles using `tile_put_replicated` and `tile_put_sharded`;
* Map an IPU vertex (computational kernel) over a sharded tensor using `tile_map`;
* Micro-benchmark TessellateIPU functions by capturing hardware cycle counts.

**Note:** This notebook can be run on IPU hardware (e.g. using [Paperspace Gradient](https://www.paperspace.com/graphcore)) or on the IPU model simulator (e.g. on a local laptop). Each cell only takes a couple of seconds to execute (excluding the initial TessellateIPU library compilation).

[![Run on Gradient](https://assets.paperspace.io/img/gradient-badge.svg)](https://console.paperspace.com/github/graphcore-research/tessellate-ipu?container=graphcore%2Fpytorch-jupyter%3A3.2.0-ubuntu-20.04&machine=Free-IPU-POD4&file=%2Fnotebooks%2F01-tessellate-ipu-tile-api-basics.ipynb)


## Dependencies and configuration

Install the experimental IPU build of JAX and the TessellateIPU library.

In [2]:
%pip install jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk320 -f https://graphcore-research.github.io/jax-experimental/wheels.html
%pip install git+https://github.com/graphcore-research/tessellate-ipu.git@main


In [3]:
from jax.config import config

USE_IPU_MODEL = False
if USE_IPU_MODEL or config.FLAGS.jax_ipu_use_model:
    print("Using IPU model")
    config.FLAGS.jax_ipu_use_model = True
    config.FLAGS.jax_ipu_model_num_tiles = 8
    USE_IPU_MODEL = True

# Set to true to see when JAX recompiles - see README for other flags
config.update("jax_log_compiles", False)

import jax

# Check IPU hardware configuration
assert jax.default_backend() == "ipu"
print(f"Platform={jax.default_backend()}")
print(f"Number of devices={jax.device_count()}")

devices = jax.devices()
print("\n".join([str(d) for d in devices]))
device0 = jax.devices()[0]

Platform=ipu
Number of devices=4
IpuDevice(id=0, num_tiles=1472, version=ipu2)
IpuDevice(id=1, num_tiles=1472, version=ipu2)
IpuDevice(id=2, num_tiles=1472, version=ipu2)
IpuDevice(id=3, num_tiles=1472, version=ipu2)


In [4]:
# TessellateIPU module initial compilation may take a couple of minutes...
import tessellate_ipu

## Tile tensor sharding in TessellateIPU

Before doing any compute on IPU tiles, you need to decide how to shard the data across tiles. In normal usage, our IPU frameworks (PyTorch, TensorFlow, and JAX) rely on the Poplar compiler to decide automatically how ML workloads are mapped onto tiles.

TessellateIPU provides two primitives to allow the user to control this mapping directly:
* `tile_put_sharded`: Shards a tensor over the first axis between a set of tiles.
* `tile_put_replicated`: Replicates a tensor on a set of tiles.

Both methods return a `TileShardedArray` Python object, which wraps a standard JAX array together with an explicit (static) tile mapping.
A `TileShardedArray` tensor is always sharded over the first axis, and on-tile shards are contiguous in memory.
Reshaping or permuting the indices of such a tensor will use the all-to-all IPU exchange to efficiently rearrange the data across tiles. 
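The following host-side NumPy sketch shows which block of data each tile would hold under the two layouts. The helper `emulate_tile_layout` is hypothetical and is not part of TessellateIPU. It assumes, as in the examples in this notebook, that a sharded array has exactly one first-axis entry per tile.

```python
import numpy as np


def emulate_tile_layout(data, tiles, replicated=False):
    """Hypothetical host-side emulation of a tile mapping (not the TessellateIPU API).

    Returns the logical shape and a dict {tile_id: on-tile block}, each block
    stored contiguously in memory.
    """
    tiles = tuple(tiles)
    if replicated:
        # Replicated layout: every tile gets a full copy, so the logical
        # array gains a leading axis of size len(tiles).
        logical = np.stack([data] * len(tiles))
    else:
        # Sharded layout: split over the first axis, one slice per tile.
        if data.shape[0] != len(tiles):
            raise ValueError(
                f"First axis has size {data.shape[0]} but {len(tiles)} tiles were given."
            )
        logical = data
    blocks = {tile: np.ascontiguousarray(logical[i]) for i, tile in enumerate(tiles)}
    return logical.shape, blocks


example = np.random.rand(3, 5).astype(np.float32)
sharded_shape, sharded = emulate_tile_layout(example, (0, 1, 3))
replicated_shape, replicated = emulate_tile_layout(example, (1, 3), replicated=True)
print(sharded_shape, {t: b.shape for t, b in sharded.items()})
print(replicated_shape, {t: b.shape for t, b in replicated.items()})
```

The shapes match the ones reported for `t0` and `t1` in the cells below. Sharding keeps the logical shape `(3, 5)`, and replication adds a leading axis of size `len(tiles)`.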

Here is an example showing how to use these two methods.

In [5]:
import numpy as np

from tessellate_ipu import tile_put_replicated, tile_put_sharded

data = np.random.rand(3, 5).astype(np.float32)


@jax.jit
def compute_fn(data):
    # Shard data on tiles (0, 1, 3)
    t0 = tile_put_sharded(data, (0, 1, 3))
    # Replicate data on tiles (1, 3)
    t1 = tile_put_replicated(data, (1, 3))
    return t0, t1


print("First time compilation may take a few seconds... run the next cell to see the results")
t0, t1 = compute_fn(data)

First time compilation may take a few seconds... run the next cell to see the results


In [6]:
# `t0` has the same shape, just sharded between tiles.
print(f"Tensor `t0` shape {t0.shape} and tile mapping {t0.tiles}.")
# `t1` has an additional replication axis.
print(f"Tensor `t1` shape {t1.shape} and tile mapping {t1.tiles}.")

Tensor `t0` shape (3, 5) and tile mapping (0, 1, 3).
Tensor `t1` shape (2, 3, 5) and tile mapping (1, 3).


`TileShardedArray` tensors support the basic array API, such as slicing and indexing. Note that an error will be raised if slicing a tensor would result in non-contiguous on-tile shards.
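You can sketch the contiguity rule on the host with NumPy. The helper `predict_slice_mapping` below is hypothetical. It assumes that the first index selects tiles and that the remaining indices are applied to each C-contiguous on-tile shard. It uses NumPy's `flags.c_contiguous` as a stand-in for TessellateIPU's contiguity check, which may not match the library's rule in every edge case.

```python
import numpy as np


def predict_slice_mapping(shape, tiles, key):
    """Hypothetical helper: predict the tile mapping of `array[key]` for a sharded array.

    `key` is a tuple of slices: key[0] selects along the sharded (first) axis,
    and the remaining entries are applied to every on-tile shard.
    """
    new_tiles = tuple(tiles)[key[0]]
    # Emulate one C-contiguous on-tile shard and apply the trailing slices to it.
    shard = np.empty(shape[1:], dtype=np.float32)
    view = shard[tuple(key[1:])]
    return new_tiles, (len(new_tiles),) + view.shape, bool(view.flags.c_contiguous)


# Same slice as the next cell: t0 has shape (3, 5) on tiles (0, 1, 3).
print(predict_slice_mapping((3, 5), (0, 1, 3), (slice(1, None), slice(2, 5))))
# Taking one column of a (T, 5, 2) array gives strided, non-contiguous shards.
print(predict_slice_mapping((3, 5, 2), (0, 1, 2), (slice(None), slice(None), slice(0, 1))))
```

For the first slice, the prediction (tiles `(1, 3)`, shape `(2, 3)`, contiguous) agrees with the output of the next cell. The second slice picks one column out of each `(5, 2)` shard, and that column is strided in memory.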

In [7]:
t3 = t0[1:, 2:5]  # t0 is on tiles (0,1,3), t3 will be on (1,3)
print(f"Tensor `t3` shape {t3.shape} and tile mapping {t3.tiles}.")
# Extract the underlying tensor/array, or convert to a Numpy array
t3.array, np.array(t3)

Tensor `t3` shape (2, 3) and tile mapping (1, 3).


(DeviceArray([[0.78745633, 0.501618  , 0.13274129],
              [0.7968004 , 0.6544128 , 0.8342705 ]], dtype=float32),
 array([[0.78745633, 0.501618  , 0.13274129],
        [0.7968004 , 0.6544128 , 0.8342705 ]], dtype=float32))

## Tile map a vertex using TessellateIPU

Once tensors have been sharded across IPU tiles, you can then map computation kernels (called IPU vertices) over these arrays.
TessellateIPU supports a subset of [JAX LAX operations](https://jax.readthedocs.io/en/latest/jax.lax.html) out of the box. It maps them to existing optimized vertices from the Graphcore Poplar SDK, so basic operations take only a couple of lines.

In the following example, we write a simple `broadcast_add` operation using TessellateIPU. The left-hand term is sharded across a collection of tiles, whereas the right-hand term is broadcast (i.e. replicated on all tiles).
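Replicating the right-hand term has a memory cost, because every tile stores its own copy. The sketch below gives a rough per-tile estimate. `broadcast_add_bytes_per_tile` is a hypothetical helper that only counts the tensors: one lhs shard, one rhs replica and one output shard. It ignores vertex code, vertex state and exchange buffers. The 640 KB budget is the per-tile SRAM figure quoted at the top of this notebook, taking 1 KB = 1024 bytes.

```python
import numpy as np

# Per-tile SRAM budget quoted at the top of this notebook (assuming 1 KB = 1024 bytes).
TILE_SRAM_BYTES = 640 * 1024


def broadcast_add_bytes_per_tile(lhs_shape, rhs_shape, dtype=np.float32):
    """Rough per-tile footprint of a sharded-lhs / replicated-rhs broadcast add.

    Counts only the tensors: one lhs shard, one rhs replica and one output shard.
    Vertex code, vertex state and exchange buffers are ignored.
    """
    itemsize = np.dtype(dtype).itemsize
    lhs_shard = int(np.prod(lhs_shape[1:])) * itemsize
    rhs_replica = int(np.prod(rhs_shape)) * itemsize
    out_shard = lhs_shard  # the output shard has the same shape as the lhs shard
    return lhs_shard + rhs_replica + out_shard


per_tile = broadcast_add_bytes_per_tile((3, 5, 2), (5, 2))
print(per_tile, "bytes per tile, fraction of budget:", per_tile / TILE_SRAM_BYTES)
```

With the `(3, 5, 2)` and `(5, 2)` inputs used below, each tile needs only 120 bytes. Across `T` tiles, however, the replicated right-hand term occupies `T` times its own size in total.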

In [8]:
import jax
import jax.lax
import numpy as np

from tessellate_ipu import tile_map, tile_put_replicated, tile_put_sharded

lhs_data = np.random.rand(3, 5, 2).astype(np.float32)
rhs_data = np.random.rand(5, 2).astype(np.float32)


@jax.jit
def broadcast_add(lhs, rhs):
    # LHS has shape (T, ...): split it over the first T tiles, one slice per tile.
    T = lhs.shape[0]
    tiles = tuple(range(T))

    # Shard lhs on tiles (0, 1, ..., T-1)
    lhs = tile_put_sharded(lhs, tiles)
    # Replicate rhs on tiles (0, 1, ..., T-1)
    rhs = tile_put_replicated(rhs, tiles)

    # Map Poplar optimized `add` vertex to the sharded data.
    out = tile_map(jax.lax.add_p, lhs, rhs)
    return out


out = broadcast_add(lhs_data, rhs_data)

print(f"{out=}")
print(f"{lhs_data + rhs_data - out=}")

out=TileShardedArray(array=DeviceArray([[[1.9107958 , 1.4595628 ],
              [1.0427275 , 1.2379227 ],
              [1.7571661 , 1.1414461 ],
              [0.6959478 , 1.3926532 ],
              [0.7868924 , 0.9820492 ]],

             [[1.2923156 , 0.92205626],
              [1.100035  , 0.54362   ],
              [1.4347286 , 1.1840247 ],
              [1.069698  , 0.64454913],
              [1.2633996 , 0.600798  ]],

             [[1.234113  , 1.377755  ],
              [1.3052844 , 1.3783791 ],
              [1.69265   , 1.4170127 ],
              [1.2864281 , 0.88172984],
              [0.6535058 , 0.9297954 ]]], dtype=float32), tiles=(0, 1, 2))
lhs_data + rhs_data - out=array([[[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]],

       [[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]],

       [[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]]], dtype=float32)


As seen above, `tile_map` always returns `TileShardedArray` objects, with the tile mapping deduced from the inputs. It also checks that all inputs are `TileShardedArray` instances. The `TileShardedArray` class ensures that data is already sharded on IPU tiles in a contiguous form, so `tile_map` adds no performance overhead: there is no implicit on-tile copy or inter-tile exchange.

**Note:** TessellateIPU will always check that the tile mapping is consistent, and will raise an error if it isn't. As the goal of TessellateIPU is to provide a way to write performant and efficient algorithms directly in Python, implicit exchanges between IPU tiles (or on-tile copies) are not allowed.
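Below is a minimal sketch of the consistency rule. It assumes the rule amounts to requiring identical tile tuples (same tiles, same order) for all inputs. `check_same_tile_mapping` is hypothetical, and its error text is modelled on the message printed by the next cell.

```python
def check_same_tile_mapping(*tile_mappings):
    """Hypothetical sketch of the consistency rule described above (not library code).

    All inputs of a tile-mapped vertex must live on exactly the same tiles, in the
    same order. Otherwise an implicit exchange would be needed, so we raise instead.
    """
    reference = tuple(tile_mappings[0])
    for tiles in tile_mappings[1:]:
        if tuple(tiles) != reference:
            raise ValueError(
                f"Inconsistent tile mapping between input arrays: {reference} vs {tuple(tiles)}."
            )
    # The output inherits the common input tile mapping.
    return reference


print(check_same_tile_mapping((0, 1, 2), range(3)))
try:
    check_same_tile_mapping((0, 1, 2), (1, 2, 3))
except ValueError as err:
    print("Rejected:", err)
```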

In [9]:
@jax.jit
def broadcast_add_error(lhs, rhs):
    lhs = tile_put_sharded(lhs, range(len(lhs)))
    rhs = tile_put_replicated(rhs, range(1, len(lhs) + 1))
    out = tile_map(jax.lax.add_p, lhs, rhs)
    return out


# Raise `ValueError`: inconsistent tile mapping!
try:
    broadcast_add_error(lhs_data, rhs_data)
except Exception as e:
    print(f"Good! Raised exception:\n{e}")

Good! Raised exception:
Inconsistent tile mapping between input arrays: (0, 1, 2) vs (1, 2, 3).


TessellateIPU code written using standard JAX LAX primitives remains fully compatible with **other backends (CPU, GPU, TPU)**. On these backends, `tile_put_sharded` is a no-op and `tile_put_replicated` simply concatenates copies of the input tensor, one per tile. Finally, `tile_map` is translated into a standard JAX `vmap` call.

As a consequence, the identical function can run on the JAX CPU backend:

In [10]:
# JIT function on CPU backend.
broadcast_add_cpu = jax.jit(broadcast_add, device=jax.devices("cpu")[0])
# Running on CPU.
out_cpu = broadcast_add_cpu(lhs_data, rhs_data)
# Check data & device.
print(f"Output: {out_cpu} on device: {out_cpu.array.device()}")

Output: TileShardedArray(array=DeviceArray([[[1.9107958 , 1.4595628 ],
              [1.0427275 , 1.2379227 ],
              [1.7571661 , 1.1414461 ],
              [0.6959478 , 1.3926532 ],
              [0.7868924 , 0.9820492 ]],

             [[1.2923156 , 0.92205626],
              [1.100035  , 0.54362   ],
              [1.4347286 , 1.1840247 ],
              [1.069698  , 0.64454913],
              [1.2633996 , 0.600798  ]],

             [[1.234113  , 1.377755  ],
              [1.3052844 , 1.3783791 ],
              [1.69265   , 1.4170127 ],
              [1.2864281 , 0.88172984],
              [0.6535058 , 0.9297954 ]]], dtype=float32), tiles=(0, 1, 2)) on device: TFRT_CPU_0
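The non-IPU fallback described earlier can be emulated in plain NumPy. The three functions below are hypothetical stand-ins, not TessellateIPU code:

* sharding is a no-op;
* replication concatenates one copy per tile along a new leading axis;
* a Python loop over the first axis stands in for JAX `vmap`.

```python
import numpy as np


# Hypothetical NumPy stand-ins for the non-IPU behaviour described above.
def put_sharded_fallback(x, tiles):
    # No-op: the data stays as it is.
    return x


def put_replicated_fallback(x, tiles):
    # One copy of `x` per tile, concatenated along a new leading axis.
    return np.concatenate([x[None, ...]] * len(tuple(tiles)), axis=0)


def tile_map_fallback(fn, *args):
    # vmap-like: apply `fn` to matching slices along the first axis, then stack.
    return np.stack([fn(*slices) for slices in zip(*args)])


lhs_np = np.random.rand(3, 5, 2).astype(np.float32)
rhs_np = np.random.rand(5, 2).astype(np.float32)
tiles_np = tuple(range(len(lhs_np)))
out_np = tile_map_fallback(
    np.add,
    put_sharded_fallback(lhs_np, tiles_np),
    put_replicated_fallback(rhs_np, tiles_np),
)
print(out_np.shape, np.array_equal(out_np, lhs_np + rhs_np))
```

The stacked per-tile results equal the plain broadcast `lhs + rhs`. This is why `broadcast_add` gives the same answer on the CPU and IPU backends.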


## Micro-benchmarking in TessellateIPU

When writing performant algorithms, micro-benchmarking is recommended practice: it helps you iterate quickly and catch performance regressions. TessellateIPU is fully compatible with the [Graphcore PopVision tools](https://www.graphcore.ai/developer/popvision-tools). It also lets you measure IPU hardware cycle counts directly with the Python function `ipu_cycle_count`.

**Note:** Cycle count is not available on the IPU model simulator. `ipu_cycle_count` will always return a zeroed tensor on the IPU model simulator.

In [11]:
from tessellate_ipu import ipu_cycle_count


@jax.jit
def broadcast_add(lhs, rhs):
    # Tiles to split the workload on.
    tiles = tuple(range(len(lhs)))
    # Shard lhs on tiles (0, 1, ..., N)
    lhs = tile_put_sharded(lhs, tiles)
    # Replicate rhs on tiles (0, 1, ..., N)
    rhs = tile_put_replicated(rhs, tiles)

    # Cycle count after `(lhs,rhs)` are sharded.
    lhs, rhs, start = ipu_cycle_count(lhs, rhs)
    # Map Poplar optimized `add` vertex to the sharded data.
    out = tile_map(jax.lax.add_p, lhs, rhs)
    # Cycle count after `out` is computed.
    out, end = ipu_cycle_count(out)
    return out, start, end


_, start, end = broadcast_add(lhs_data, rhs_data)

print("Start cycle count:", start, start.shape)
print("End cycle count:", end, end.shape)

Start cycle count: TileShardedArray(array=DeviceArray([[539858071,         0],
             [539858071,         0],
             [539858072,         0]], dtype=uint32), tiles=(0, 1, 2)) (3, 2)
End cycle count: TileShardedArray(array=DeviceArray([[539858400,         0],
             [539858400,         0],
             [539858402,         0]], dtype=uint32), tiles=(0, 1, 2)) (3, 2)


The function `ipu_cycle_count` returns raw cycle counts on every IPU tile, measured directly by the hardware. Note that `ipu_cycle_count` takes input arguments and returns them unchanged. This data dependency tells the XLA and Poplar compilers where the measurement sits in the control flow, so the cycle count is read after these tensors have been computed.

TessellateIPU returns the raw values of the IPU C++ [get_scount](https://docs.graphcore.ai/projects/poplar-api/en/3.3.0/ipu_intrinsics/ipu_builtins.html#ipu-functionality-and-memory) intrinsics, with each 64-bit count stored as a pair of 32-bit words. This is why the `start` and `end` tensors have shape `(T, 2)` and dtype `uint32`, with one row per tile.
As shown below, the raw cycle count can be easily translated into time performance figures.
Please note timing measured in this way will differ significantly from simple Python benchmarking of the `broadcast_add` function, as the latter will also include any JAX overhead, as well as all communication between the host and the IPU, and IPU tile exchange.
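The word-combining step can be written out explicitly. This sketch assumes each row holds `(low, high)` 32-bit words, low word first. The `.view(np.int64)` reinterpretation in the next cell relies on the same order on a little-endian host. `cycles_from_words` and `elapsed_seconds` are hypothetical helpers. The clock frequency is passed in as a parameter; the notebook reads it from `ipu_device.tile_clock_frequency`.

```python
import numpy as np


def cycles_from_words(raw):
    """Combine per-tile (low, high) uint32 words into 64-bit cycle counts.

    Assumes the low word comes first in each row.
    """
    raw = np.asarray(raw, dtype=np.uint32)
    low = raw[..., 0].astype(np.uint64)
    high = raw[..., 1].astype(np.uint64)
    return low + (high << np.uint64(32))


def elapsed_seconds(start_raw, end_raw, clock_hz):
    """Slowest-tile elapsed time between two cycle-count snapshots."""
    cycles = cycles_from_words(end_raw) - cycles_from_words(start_raw)
    return int(cycles.max()) / clock_hz


# Raw values printed by the cell above (3 tiles).
start_raw = [[539858071, 0], [539858071, 0], [539858072, 0]]
end_raw = [[539858400, 0], [539858400, 0], [539858402, 0]]
print(cycles_from_words(end_raw) - cycles_from_words(start_raw))  # per-tile cycle counts
```

Applied to the values printed above, the per-tile differences are 329, 329 and 330 cycles. `elapsed_seconds(start_raw, end_raw, device0.tile_clock_frequency)` then keeps the slowest tile, as `cycle_count_to_timing` does.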

In [12]:
def cycle_count_to_timing(start, end, ipu_device):
    """Convert raw cycle count into timing."""
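    # Reinterpret each (low, high) pair of uint32 words as one 64-bit cycle count per tile.
    start = np.array(start).view(dtype=np.int64)
    end = np.array(end).view(dtype=np.int64)
    # Tiles are not synchronized: report the slowest tile.
    cycle_count_max = np.max(end - start)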
    timing = cycle_count_max / ipu_device.tile_clock_frequency
    return timing


timing = cycle_count_to_timing(start, end, device0)
print("Tile map `add` execution time (micro-seconds):", timing * 1e6)

Tile map `add` execution time (micro-seconds): 0.1783783783783784


### Cycle count and IPU tile parallelism

Let's demonstrate IPU tile parallelism in a simple way using the hardware cycle count. With one row per tile, the per-tile workload is fixed, so the performance of `broadcast_add` should be (roughly) independent of the size of the first axis of the input.
Note that timing can vary slightly with the number of tiles used, as the tiles are not synchronized when running compute-independent workloads.
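The toy cost model below is not derived from measurements, but it shows why the timing should stay flat. If rows are spread evenly over tiles and each tile processes its rows sequentially, the slowest tile determines the time. `modeled_cycles` and the value `cycles_per_row=100` are purely illustrative. A single IPU offers at most 1,472 tiles, so the one-row-per-tile regime only holds up to that size.

```python
import math


def modeled_cycles(num_rows, num_tiles, cycles_per_row):
    """Toy cost model (an assumption, not measured): rows are spread evenly over
    tiles, each tile processes its rows sequentially, and the slowest tile sets the time."""
    rows_per_tile = math.ceil(num_rows / num_tiles)
    return rows_per_tile * cycles_per_row


# One row per tile, as in the benchmark below: the cost stays flat as rows grow.
print([modeled_cycles(n, n, cycles_per_row=100) for n in (4, 16, 128, 1024)])
# The same rows squeezed onto 4 tiles: the cost grows linearly instead.
print([modeled_cycles(n, 4, cycles_per_row=100) for n in (4, 16, 128, 1024)])
```

With one row per tile, the modelled cost is constant. Squeezing the same rows onto 4 tiles makes it grow linearly with the number of rows.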

In [13]:
if USE_IPU_MODEL:
    print("Benchmarks don't run on IPUModel")
else:
    # Number of tiles to use.
    num_tiles_list = [4, 16, 128, 1024]

    def ipu_benchmark(num_tiles):
        """IPU benchmarking, splitting the workload over a collection of tiles."""
        print(f"Benchmarking on {num_tiles} tiles")
        # Workload size per tile.
        wsize = 1024
        lhs_data = np.random.rand(num_tiles, wsize).astype(np.float32)
        rhs_data = np.random.rand(wsize).astype(np.float32)

        _, start, end = broadcast_add(lhs_data, rhs_data)
        timing = cycle_count_to_timing(start, end, device0)
        return timing

    benchmarks = [ipu_benchmark(N) * 1e6 for N in num_tiles_list]

    # (Roughly) constant timing independently of the number of tiles used.
    print("Benchmark timing (us):", benchmarks)

Benchmarking on 4 tiles
Benchmarking on 16 tiles
Benchmarking on 128 tiles
Benchmarking on 1024 tiles
Benchmark timing (us): [1.0054054054054054, 1.0054054054054054, 1.0054054054054054, 1.0054054054054054]


And that's it: a three-function API showing how to directly map computations onto IPU hardware.

You might like to look next at the [IPU Peak Flops](IPU%20Peak%20Flops.ipynb) notebook, or at the [`demo_vertex.py`](../examples/demo/demo_vertex.py) example.