Copyright (c) 2023 Graphcore Ltd. All rights reserved.

# IPU Peak teraFlops

This notebook shows how you can reach maximum usage of the IPU's floating-point throughput (FLOPS) in a couple of lines of Python.

## Environment setup

To run the demo using IPU hardware, you need to have the Poplar SDK enabled, as well as an IPU-enabled JAX build and the `tessellate_ipu` package installed. Refer to the [Getting Started guide](https://docs.graphcore.ai/en/latest/getting-started.html#getting-started) for your system for details on how to do this. Also refer to the [Jupyter Quick Start guide](https://docs.graphcore.ai/projects/jupyter-notebook-quick-start/en/latest/index.html) for how to set up Jupyter to be able to run this notebook on a remote IPU machine.

## Dependencies

Import dependencies and define configuration. 

In [80]:
from jax.config import config

# Select how many IPUs will be visible.
# A single device is enough: the benchmark below only uses `jax.devices("ipu")[0]`.
config.FLAGS.jax_ipu_device_count = 1

In [2]:
import jax
import jax.lax
import numpy as np
import plotly.graph_objects as go
from tqdm import tqdm

In [3]:
# Need real IPU hardware!
d = jax.devices("ipu")[0]
if d.is_ipu_model:
    # Explicit check rather than `assert`: asserts are stripped under `python -O`
    # and would give no explanation when they fire.
    raise RuntimeError("This benchmark requires real IPU hardware, not the IPU model.")
# Display the selected device.
d

IpuDevice(id=0, num_tiles=1472, version=ipu2)

In [4]:
# Work is spread over every tile of the device: one tile index per entry.
tiles = tuple(range(d.num_tiles))
num_tiles = len(tiles)
# Display the tile clock frequency (Hz); cycle counts are divided by it to get seconds.
d.tile_clock_frequency

1850000000.0

In [5]:
from tessellate_ipu.tile import IpuConvVertexType, ipu_cycle_count, ipu_cycle_count_overhead, tile_map, tile_put_sharded


def tile_basic_matmul(lhs, rhs, *, accumulator_dtype=None):
    """Run a basic matmul on every tile, using the `ConvPartial1x1` vertex.

    Most optimal hardware usage: loading a small weight matrix once in CCCS
    registers, and then performing a "long" AMP pipeline of size N.

    The contraction is over dimension 1 of both operands
    (`dimension_numbers=(([1], [1]), ([], []))`), so each tile computes
    `lhs @ rhs.T`.

    Args:
        lhs: [N, 8/16] array (per tile).
        rhs: [8/16, 8/16] array (per tile). Its second dimension is contracted
            with the second dimension of `lhs`.
        accumulator_dtype: Accumulation/output dtype, passed to the op as
            `preferred_element_type`. Defaults to `lhs.dtype`. For example,
            `np.float32` accumulates FP16 inputs in FP32. NOTE(review): confirm
            that the vertex supports the requested input/accumulator combination.
    Returns:
        Matmul result. [N, 8/16]
    """
    if accumulator_dtype is None:
        accumulator_dtype = lhs.dtype
    output = tile_map(
        jax.lax.dot_general_p,
        lhs,
        rhs,
        dimension_numbers=(([1], [1]), ([], [])),
        precision=jax.lax.Precision.DEFAULT,
        preferred_element_type=accumulator_dtype,
        ipu_vertex_type=IpuConvVertexType.ConvPartial1x1,
    )
    return output


def tile_benchmark_matmul(lhs, rhs):
    """Time the per-tile matmul on the IPU.

    Shards `lhs` and `rhs` over all `tiles` (presumably one slice of the leading
    axis per tile), runs `tile_basic_matmul` on every tile, and reads the IPU
    cycle counter right before and right after it.

    Returns:
        Tuple `(out, start, end)`: the matmul result and the raw cycle-counter
        values sampled before and after the matmul.
    """
    sharded_lhs, sharded_rhs = (tile_put_sharded(operand, tiles) for operand in (lhs, rhs))
    # The operands are routed through the counter op, presumably so that the
    # "before" sample is ordered ahead of the matmul by data dependency.
    sharded_lhs, sharded_rhs, cycles_before = ipu_cycle_count(sharded_lhs, sharded_rhs)
    product = tile_basic_matmul(sharded_lhs, sharded_rhs)
    # Likewise, routing the result through orders the "after" sample behind it.
    product, cycles_after = ipu_cycle_count(product)
    return product, cycles_before, cycles_after


# Compile the same benchmark for both backends: the IPU build is the one that
# is measured, and the CPU build is available to double check the result.
tile_benchmark_matmul_ipu, tile_benchmark_matmul_cpu = (
    jax.jit(tile_benchmark_matmul, backend=target) for target in ("ipu", "cpu")
)

In [6]:
def measure_tile_matmul_flops(N, dtype=np.float16):
    """Measure IPU flops usage by running small matmuls on every tile.

    Every tile runs an independent `[N, 16]` by `[16, 16]` matmul through
    `tile_benchmark_matmul_ipu`. The mean per-tile cycle count is then converted
    into an aggregate TFLOPS figure for the whole device.

    Args:
        N: Number of rows of each tile's LHS matrix.
        dtype: NumPy dtype of the random input matrices (default float16).
            NOTE(review): only float16 is exercised in this notebook; other
            dtypes depend on what the `ConvPartial1x1` vertex supports.
    Returns:
        Tuple `(cycle_count, tflops)`: the mean matmul cycle count per tile (with
        the cycle-counter overhead subtracted), and the corresponding throughput
        in TFLOPS summed over all tiles.
    """
    # NOTE: the caller's `dtype` is honoured here; it used to be silently
    # overwritten with float16.
    lhs_size = N
    rhs_size = 16
    contract_size = 16

    lhs_data = np.random.randn(len(tiles), lhs_size, contract_size).astype(dtype)
    rhs_data = np.random.randn(len(tiles), rhs_size, contract_size).astype(dtype)
    # Run independent matmuls on every tile, and measure cycle count.
    _, start, end = tile_benchmark_matmul_ipu(lhs_data, rhs_data)
    tile_cycle_count = np.asarray(end.array) - np.asarray(start.array)
    # Only column 0 of each tile's counter is used.
    # NOTE(review): presumably the low word of the cycle counter; confirm.
    cycle_count = np.mean(tile_cycle_count[:, 0]) - ipu_cycle_count_overhead()
    # Flops conversion.
    execution_time = cycle_count / d.tile_clock_frequency
    # Count product & sum in matmul as floating point operation.
    num_ops = 2 * lhs_size * rhs_size * contract_size * num_tiles
    tflops = num_ops / execution_time * 1e-12
    return cycle_count, tflops

In [52]:
# Matrix sizes N to check: 16, 32, ..., 2048. `np.arange` excludes its stop
# value, hence the extra `step_size` so that `max_size` itself is included.
step_size = 16
max_size = 2 * 1024
sizes = np.arange(16, max_size + step_size, step_size)

In [53]:
# IPU Mk2 peak tflops: clock frequency x tiles x 128 flops per tile per cycle.
# NOTE(review): 128 is the assumed FP16 AMP throughput per tile per cycle on
# Mk2; confirm against the IPU documentation.
peak_tflops = d.tile_clock_frequency * num_tiles * 128 * 1e-12
# Repeat the constant once per benchmarked size so it plots as a flat line.
peak_tflops = np.full(len(sizes), peak_tflops)

In [54]:
# Benchmark every size, keeping the (cycle count, TFLOPS) pair of each run.
measurements = [measure_tile_matmul_flops(size, np.float16) for size in tqdm(sizes)]

cycle_counts = np.asarray([cycles for cycles, _ in measurements])
tflops = np.asarray([flops for _, flops in measurements])

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 128/128 [09:17<00:00,  4.36s/it]


In [55]:
# Display the benchmarked sizes alongside the measured TFLOPS.
sizes, tflops

(array([  16,   32,   48,   64,   80,   96,  112,  128,  144,  160,  176,
         192,  208,  224,  240,  256,  272,  288,  304,  320,  336,  352,
         368,  384,  400,  416,  432,  448,  464,  480,  496,  512,  528,
         544,  560,  576,  592,  608,  624,  640,  656,  672,  688,  704,
         720,  736,  752,  768,  784,  800,  816,  832,  848,  864,  880,
         896,  912,  928,  944,  960,  976,  992, 1008, 1024, 1040, 1056,
        1072, 1088, 1104, 1120, 1136, 1152, 1168, 1184, 1200, 1216, 1232,
        1248, 1264, 1280, 1296, 1312, 1328, 1344, 1360, 1376, 1392, 1408,
        1424, 1440, 1456, 1472, 1488, 1504, 1520, 1536, 1552, 1568, 1584,
        1600, 1616, 1632, 1648, 1664, 1680, 1696, 1712, 1728, 1744, 1760,
        1776, 1792, 1808, 1824, 1840, 1856, 1872, 1888, 1904, 1920, 1936,
        1952, 1968, 1984, 2000, 2016, 2032, 2048]),
 array([ 55.49366766,  96.70991447, 127.72015878, 152.79763288,
        170.72418021, 186.94235531, 201.23605773, 211.11875736,
      

In [77]:
# Measured FP16 throughput per size, plotted against the theoretical Mk2 peak
# (dotted line).
fig = go.Figure(
    data=[
        # Peak TFlops.
        go.Scatter(
            x=sizes,
            y=peak_tflops,
            mode="lines",
            name="IPU Mk2 Peak TFlops",
            line=go.scatter.Line(dash="dot"),
            showlegend=True,
        ),
        # FP16 TFlops.
        go.Scatter(
            x=sizes,
            y=tflops,
            mode="markers",
            name="FP16[N,16] @ FP16[16,16] -> FP16",
            marker=go.scatter.Marker(size=5),
            showlegend=True,
        ),
    ]
)
fig.update_layout(
    font_family="Courier New",
    title_text="IPU on-tile AMP micro-benchmarking (TFLOPS)",
    xaxis_title="N (input shape: [N,16])",
    yaxis_title="TFLOPS",
)
fig.update_xaxes(tickvals=[16, 64, 128, 256, 512, 768, 1024, 1536, 2048])
fig.show()