In [1]:
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.

import numpy as np
import jax
import jax.numpy as jnp
import ml_dtypes
import gfloat
from gfloat.formats import format_info_ocp_e5m2
from timeit import Timer

# Enable 64-bit types in JAX. Without this, jnp.array() of a float64 NumPy
# array is converted to float32, so the JAX benchmark would not be rounding the
# same float64 inputs as the NumPy / scalar benchmarks.
jax.config.update("jax_enable_x64", True)

# Automatically reload edited Python modules (e.g. a local gfloat checkout)
# before executing each cell, without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Timing tests

The `gfloat` library is designed for readability over performance, and the reference code for computations is the (slow) scalar code, e.g. `round_float`.

There are vectorized implementations (e.g. `round_ndarray`), and when combined with JAX, these can go reasonably fast.

Let's see how long it takes to encode some values to FP8...

In [2]:
N = 1_000_000  # number of values rounded in the vectorized benchmarks
rng = np.random.default_rng(seed=0)  # fixed seed so every rerun times the same inputs
a = rng.random(N)  # float64 values, uniform in [0, 1)

# JIT-compiled, JAX-backed version of gfloat's vectorized rounding to OCP E5M2.
jax_round_jit = jax.jit(lambda x: gfloat.round_ndarray(format_info_ocp_e5m2, x, np=jnp))
ja = jnp.array(a)
# Warm-up call: triggers (and caches) JIT compilation so it is excluded from the
# timings. JAX dispatches asynchronously, so wait for the result to finish.
jax_round_jit(ja).block_until_ready()


def slow_round_ndarray(fi, a):
    """Round each element of ``a`` to format ``fi`` with gfloat's scalar reference code.

    This is intentionally slow: it loops in Python and calls
    ``gfloat.round_float`` once per element, then packs the results into a
    NumPy array.
    """
    rounded = []
    for value in a:
        rounded.append(gfloat.round_float(fi, value))
    return np.array(rounded)


def time(f, problem_size=1.0):
    """Benchmark ``f`` and describe its best per-element time in nanoseconds.

    ``timeit``'s autorange picks a loop count taking at least ~0.2 s; that count
    is scaled up 10x (so each repeat takes roughly 2 s) and the best of three
    repeats is reported, divided by ``problem_size`` elements per call.
    """
    timer = Timer(f)
    loops_per_repeat, _ = timer.autorange()
    number = loops_per_repeat * 10
    totals = timer.repeat(repeat=3, number=number)
    best_nsec = min(total / number / problem_size * 1e9 for total in totals)
    return f"{best_nsec:8.2f} nsec ({number} runs at size {problem_size})"


# fmt: off
# All timings are per element. The scalar reference code is timed on only 1% of
# the data because it is orders of magnitude slower than the vectorized versions.
# JAX dispatches work asynchronously, so block_until_ready() is required to time
# the rounding computation itself rather than just its dispatch.
print("GFloat scalar                  :", time(lambda: slow_round_ndarray(format_info_ocp_e5m2, a[: N // 100]), N // 100))
print("GFloat vectorized, numpy arrays:", time(lambda: gfloat.round_ndarray(format_info_ocp_e5m2, a), N))
print("GFloat vectorized, JAX JIT     :", time(lambda: jax_round_jit(ja).block_until_ready(), N))
print("ML_dtypes                      :", time(lambda: a.astype(ml_dtypes.float8_e5m2), N))

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


GFloat scalar                  :  6666.04 nsec (50 runs at size 10000)
GFloat vectorized, numpy arrays:    57.84 nsec (50 runs at size 1000000)
GFloat vectorized, JAX JIT     :     3.17 nsec (1000 runs at size 1000000)
ML_dtypes                      :     2.92 nsec (1000 runs at size 1000000)


On one CPU platform the per-element timings were:
```
GFloat scalar                  :  6996.75 nsec (50 runs at size 10000)
GFloat vectorized, numpy arrays:    75.04 nsec (50 runs at size 1000000)
GFloat vectorized, JAX JIT     :     3.18 nsec (1000 runs at size 1000000)
ML_dtypes                      :     3.13 nsec (1000 runs at size 1000000)
```
So the JAX JIT code is roughly 2000x faster than the scalar code (and over 20x faster than the vectorized NumPy code), and comparable to `ml_dtypes`'s C++ CPU implementation.