In [1]:
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.

import numpy as np
import jax
import jax.numpy as jnp
import ml_dtypes
import gfloat
from gfloat.formats import format_info_ocp_e5m2

%load_ext autoreload
%autoreload 2

## Timing tests

The `gfloat` library is designed for readability over performance, and the reference code for computations is the (slow) scalar code, e.g. `round_float`. There are also vectorized implementations (e.g. `round_ndarray`), and when combined with JAX, these can go reasonably fast.

Let's try to convert some values to FP8:

In [2]:
# Benchmark input: N float64 samples drawn uniformly from [0, 1).
# A seeded generator is used so that every "Restart & Run All" times the
# same values instead of a fresh draw from NumPy's global random state.
N = 100_000
SEED = 0
a = np.random.default_rng(SEED).random(N)

# JIT-compiled wrapper around the vectorized rounding, with jax.numpy
# supplied as the array backend through the `np=` argument.
jax_round_jit = jax.jit(lambda x: gfloat.round_ndarray(format_info_ocp_e5m2, x, np=jnp))
ja = jnp.array(a)

# Warm-up call so compilation is cached before anything is timed.
# JAX dispatches work asynchronously, so block until the result is ready;
# otherwise the warm-up could still be running when the timings start.
jax.block_until_ready(jax_round_jit(ja))


def slow_round_ndarray(fi, a):
    """Round each element of `a` individually with the scalar `round_float`.

    This is the scalar baseline for the timing comparison: it calls
    `gfloat.round_float(fi, value)` once per element instead of using a
    vectorized routine.

    Args:
        fi: Format descriptor, passed through unchanged to `gfloat.round_float`.
        a: Iterable of values to round.

    Returns:
        A new NumPy array built from the per-element results.
    """
    rounded = [gfloat.round_float(fi, value) for value in a]
    return np.array(rounded)


print("GFloat scalar                  :", end="")
%timeit slow_round_ndarray(format_info_ocp_e5m2, a)

print("GFloat vectorized, numpy arrays:", end="")
%timeit gfloat.round_ndarray(format_info_ocp_e5m2, a)

print("GFloat vectorized, JAX JIT     :", end="")
%timeit jax_round_jit(ja)

print("ML_dtypes                      :", end="")
%timeit a.astype(ml_dtypes.float8_e5m2)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
  return lax_numpy.astype(arr, dtype, copy=copy, device=device)


GFloat scalar                  :616 ms ± 23.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
GFloat vectorized, numpy arrays:4.49 ms ± 255 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
GFloat vectorized, JAX JIT     :596 µs ± 13.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
ML_dtypes                      :266 µs ± 16.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


On one CPU platform the timings were:
```
GFloat scalar                  :629     ms ± 22.3 ms 
GFloat vectorized, numpy arrays:  4.420 ms ± 153 µs 
GFloat vectorized, JAX JIT     :    585 µs ± 13.7 µs 
ML_dtypes                      :    253 µs ± 12 µs 
```
So the JAX JIT code is roughly 1000x faster than the scalar code, although the C++ implementation in `ml_dtypes` is 2-3x faster still.