# JAX FP8 (fused) matmul tutorial

## Quickstart: FP8 in deep learning

The latest generation of machine learning hardware ([Nvidia H100](https://www.nvidia.com/en-gb/data-center/h100/), [AMD MI300X](https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html), [Graphcore C600](https://www.graphcore.ai/products/c600), ...) has integrated direct **FP8** support in the hardware, improving energy efficiency and throughput.

As shown in the low-precision ML literature, two distinct `float8` formats are needed to achieve accuracy similar to `bfloat16` (or `float16`) training: **`E4M3`** and **`E5M2`**. As presented below, the two formats differ in their trade-off between precision (i.e. mantissa bits) and dynamic range (i.e. exponent bits). In short, `E4M3` is used for storing **weights** and **activations**, whereas `E5M2` is used for representing backward **gradients** (which require a higher dynamic range).

![image](img/fp-formats.webp)

Note that **different variations** of `E4M3` and `E5M2` exist in the literature, depending on whether infinities, NaN or negative zero have special encodings reserved (see below in the references). The Python library [`ml_dtypes`](https://github.com/jax-ml/ml_dtypes) implements these different 8-bit floating point representations as NumPy extensions.
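As a rough illustration of how these encoding choices set the numeric limits, the sketch below derives the largest finite value and the smallest subnormal from the bit layout alone. The helper `fp8_limits` and its `ieee_like` flag are hypothetical names introduced for this example. It assumes an `E4M3` variant without infinities (only the all-ones mantissa at the top exponent encodes NaN) and an IEEE-style `E5M2` (top exponent reserved for infinities and NaNs).

```python
def fp8_limits(exp_bits, man_bits, ieee_like):
    """Largest finite value and smallest subnormal of a small float format.

    ieee_like=True:  the top exponent field is reserved for inf/NaN.
    ieee_like=False: only (top exponent, all-ones mantissa) is NaN, no inf.
    """
    bias = 2 ** (exp_bits - 1) - 1
    top_exp_field = 2 ** exp_bits - 1
    if ieee_like:
        max_exp = top_exp_field - 1 - bias
        max_mantissa = 2 ** man_bits - 1
    else:
        max_exp = top_exp_field - bias
        max_mantissa = 2 ** man_bits - 2
    max_value = 2.0 ** max_exp * (1 + max_mantissa / 2 ** man_bits)
    # Subnormals use the minimum exponent (1 - bias) with no implicit leading one.
    smallest_subnormal = 2.0 ** (1 - bias - man_bits)
    return max_value, smallest_subnormal


e4m3_max, e4m3_subnormal = fp8_limits(4, 3, ieee_like=False)
e5m2_max, e5m2_subnormal = fp8_limits(5, 2, ieee_like=True)
print(e4m3_max, e4m3_subnormal)  # 448.0 0.001953125
print(e5m2_max, e5m2_subnormal)  # 57344.0 1.52587890625e-05
```

The one extra exponent bit of `E5M2` buys a much wider range, at the cost of one mantissa bit of precision.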

These two new FP8 formats introduce a major hardware difference compared to FP16 and BF16 support: FP8 hardware needs to support mixed-input matrix multiplication (i.e. `E4M3 @ E5M2`) for the model training backward pass.

In this tutorial notebook, we investigate how the ML stack JAX + XLA handles the specificities of **FP8 matmuls**, while still generating an optimal fused kernel call including:
* FP8 inputs scaling;
* FP8 output scaling & clamping;
* Non-linearity & bias fusing;
* Abs-max output capture;

## FP8 E4M3 and E5M2 format datatypes

`E4M3` and `E5M2` datatype formats have been integrated in major ML frameworks (e.g. PyTorch and JAX), and can be used as any other classic NumPy dtype. [`ml_dtypes`](https://github.com/jax-ml/ml_dtypes) provides floating point information for these FP8 formats, showing in particular the small dynamic range of `E4M3` datatype (i.e. ±448) compared to `E5M2` (i.e. ±57344).

In [1]:
import ml_dtypes
import numpy as np

import jax
import jax.numpy as jnp

# Note: using the E4M3 format without +-infinity encodings.
print(ml_dtypes.finfo(jnp.float8_e4m3fn))
# Note: E5M2 format with infinity and NaN encodings, in line with the IEEE FP16 standard.
print(ml_dtypes.finfo(jnp.float8_e5m2))

Machine parameters for float8_e4m3fn
---------------------------------------------------------------
precision =   1   resolution = 1.00e-01
machep =     -3   eps =        1.25e-01
negep =      -4   epsneg =     6.25e-02
minexp =     -6   tiny =       1.56e-02
maxexp =      9   max =        4.48e+02
nexp =        4   min =        -max
smallest_normal = 1.56e-02   smallest_subnormal = 1.95e-03
---------------------------------------------------------------

Machine parameters for float8_e5m2
---------------------------------------------------------------
precision =   1   resolution = 1.00e-01
machep =     -2   eps =        2.50e-01
negep =      -3   epsneg =     1.25e-01
minexp =    -14   tiny =       6.10e-05
maxexp =     16   max =        5.73e+04
nexp =        5   min =        -max
smallest_normal = 6.10e-05   smallest_subnormal = 1.53e-05
---------------------------------------------------------------
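To make the precision difference concrete, the following sketch rounds the same `float32` values to both formats using `ml_dtypes` with plain NumPy. The input values are arbitrary choices for illustration.

```python
import ml_dtypes
import numpy as np

values = np.array([1.1, 300.0], dtype=np.float32)

# Round-trip through each FP8 format and back to float32 for printing.
as_e4m3 = values.astype(ml_dtypes.float8_e4m3fn).astype(np.float32)
as_e5m2 = values.astype(ml_dtypes.float8_e5m2).astype(np.float32)

print("E4M3:", as_e4m3)  # [  1.125 288.   ]
print("E5M2:", as_e5m2)  # [  1. 320.]
```

With three mantissa bits, `E4M3` lands closer to the original values than `E5M2`, which only has two.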



## FP8 matmul in JAX: the simple case

With FP8 datatypes added in JAX, basic FP8 matrix multiplication is supported out-of-the-box. As highlighted above, it also means support for **mixed** `E4M3 @ E5M2` FP8 matmuls.

In [2]:
key = jax.random.PRNGKey(4352)
# Split the key so that the random inputs are independent.
key_a, key_b, key_c = jax.random.split(key, 3)
# Random FP8 inputs.
a = jax.random.normal(key_a, (32, 64), jnp.float8_e4m3fn)
b = jax.random.normal(key_b, (128, 64), jnp.float8_e4m3fn)

# E4M3 matrix multiplication (NOTE: transpose to reduce on last axis on both inputs).
c = jax.lax.dot(a, b.T)
print("E4M3 @ E4M3 FP8 matmul output:", c.aval)

# E4M3/E5M2 mixed matrix multiplication (NOTE: transpose to reduce on last axis on both inputs).
b_e5m2 = jax.random.normal(key_c, (128, 64), jnp.float8_e5m2)
d = jax.lax.dot(a, b_e5m2.T)
# Note: default output dtype is E5M2.
print("E4M3 @ E5M2 FP8 matmul output:", d.aval)

E4M3 @ E4M3 FP8 matmul output: ShapedArray(float8_e4m3fn[32,128])
E4M3 @ E5M2 FP8 matmul output: ShapedArray(float8_e5m2[32,128])
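The mixed case above mirrors what happens in the backward pass of a linear layer `y = x @ w`: activations and weights are stored in `E4M3`, while the incoming gradient `dy` is in `E5M2`. The sketch below only illustrates which formats meet in each matmul. The shapes and the helper `upcast_dot` are assumptions made for this example, and everything is upcast to `float32` so that it runs on any backend (it does not exercise the FP8 kernels).

```python
import jax
import jax.numpy as jnp

key_x, key_w, key_dy = jax.random.split(jax.random.PRNGKey(0), 3)
# Forward tensors in E4M3, backward gradient in E5M2.
x = jax.random.normal(key_x, (32, 64), jnp.float32).astype(jnp.float8_e4m3fn)
w = jax.random.normal(key_w, (64, 128), jnp.float32).astype(jnp.float8_e4m3fn)
dy = jax.random.normal(key_dy, (32, 128), jnp.float32).astype(jnp.float8_e5m2)


def upcast_dot(lhs, rhs):
    # Hypothetical helper: float32 reference matmul for FP8 operands.
    return jnp.dot(lhs.astype(jnp.float32), rhs.astype(jnp.float32))


dx = upcast_dot(dy, w.T)  # E5M2 @ E4M3 -> gradient w.r.t. the input
dw = upcast_dot(x.T, dy)  # E4M3 @ E5M2 -> gradient w.r.t. the weights
print(dx.shape, dw.shape)  # (32, 64) (64, 128)
```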


### FP8 matmul compiled HLO

Let's have a look at the compiled HLO module generated by JAX + XLA on latest generation GPUs: the XLA compiler recognizes an FP8 matrix multiplication and generates (on GPUs) a `custom_call` to the target **`__cublas$lt$matmul$f8`**, mapping to the FP8 [**`cublasLtMatmul`**](https://docs.nvidia.com/cuda/cublas/#cublasltmatmul) API (note: it will work similarly on other hardware platforms).

In [3]:
from jax_scalify.utils import print_hlo_module, parse_hlo_module

def matmul_fn(a_fp8, b_fp8):
    # FP8 x FP8 -> FP8 matmul
    return jax.lax.dot(a_fp8, b_fp8.T)

# AOT compilation with JAX, inspecting the (final) HLO module generated.
fn_compiled = jax.jit(matmul_fn).lower(a, b).compile()
# (Human readable) optimized Hlo module generated by XLA (ignoring GEMM backend config).
print_hlo_module(fn_compiled, backend_cfg=False)

HloModule jit_matmul_fn, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0})->f8e4m3fn[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="f27e70a56b27e0bfb1ec7095f85081ca"}

ENTRY %main.5 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64]) -> f8e4m3fn[32,128] {
  %constant_1 = f32[] constant(1)
  %Arg_1.2.0 = f8e4m3fn[128,64]{1,0} parameter(1)
  %Arg_0.1.0 = f8e4m3fn[32,64]{1,0} parameter(0)
  %cublas-gemm.1.0 = (f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[32,64]{1,0} %Arg_0.1.0, f8e4m3fn[128,64]{1,0} %Arg_1.2.0, f32[] %constant_1, f32[] %constant_1, f32[] %constant_1, /*index=5*/f32[] %constant_1), custom_call_target="__cublas$lt$matmul$f8"
  ROOT %get-tuple-element.1 = f8e4m3fn[32,128]{1,0} get-tuple-element((f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) %cublas-gemm.1.0), index=0
}
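The presence of the custom call can also be checked programmatically, without reading the HLO dump by eye. The following is a minimal sketch that relies only on the public JAX ahead-of-time API (`lower`, `compile`, `as_text`). The helper name `uses_fp8_gemm` is hypothetical, and whether it reports `True` depends on the backend and XLA version (on CPU, for instance, no cuBLAS call is expected).

```python
import jax
import jax.numpy as jnp

FP8_GEMM_TARGET = "__cublas$lt$matmul$f8"


def uses_fp8_gemm(fn, *args):
    """Return True if the compiled HLO text of `fn` mentions the FP8 GEMM target."""
    compiled = jax.jit(fn).lower(*args).compile()
    hlo_text = compiled.as_text() or ""
    return FP8_GEMM_TARGET in hlo_text


def dequantized_matmul(a_fp8, b_fp8):
    # Upcast first so the function also compiles on backends without FP8 GEMMs.
    return jax.lax.dot(a_fp8.astype(jnp.float32), b_fp8.astype(jnp.float32).T)


# Abstract inputs are enough for AOT compilation.
a_spec = jax.ShapeDtypeStruct((32, 64), jnp.float8_e4m3fn)
b_spec = jax.ShapeDtypeStruct((128, 64), jnp.float8_e4m3fn)
found = uses_fp8_gemm(dequantized_matmul, a_spec, b_spec)
print("FP8 GEMM custom call found:", found)
```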




A first interesting aspect of the custom call **`__cublas$lt$matmul$f8`** is that it takes **6 input arguments**: the first two are the usual matmul operands, and the other four are FP32 scalars (all set to the constant `%constant_1 = f32[] constant(1)` in this case).

The field `backend_config` in **`__cublas$lt$matmul$f8`** provides additional metadata passed to the GEMM API.

In [4]:
from IPython.display import JSON

hlo_module = parse_hlo_module(fn_compiled)
# Let's extract the `backend_config` dict from the FP8 matmul call.
backend_config = next((m.backend_config for m in hlo_module if "__cublas$lt$matmul$f8" in m.cmd))
JSON(backend_config, expanded=True)

A couple of fields are of interest to us for FP8 matmuls:

* **`alpha_real`**, **`alpha_imag`** and **`beta`**: constant scaling factors which can be integrated into the matrix multiplication:
$$
D = \alpha \cdot (A @ B) + \beta \cdot C
$$
**Note:** these constants are different from the four scalar FP32 tensors presented above!
* **`epilogue`**: enum field describing the fusing of post-matmul operations such as adding a bias or a non-linearity (see below).
* **`damax_output`**: a new FP8 matmul feature: computation of the absolute reduce-max of the output (useful for output re-scaling).
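As a small worked instance of the GEMM formula above, the NumPy sketch below evaluates $D = \alpha \cdot (A @ B) + \beta \cdot C$ on tiny matrices. The values of `alpha`, `beta` and the matrices are arbitrary; this only illustrates the arithmetic and is not a call into the cuBLASLt API.

```python
import numpy as np

alpha, beta = 2.0, 0.5
A = np.array([[1.0, 2.0], [3.0, 4.0]])
B = np.array([[0.0, 1.0], [1.0, 0.0]])  # swaps the two columns of A
C = np.ones((2, 2))

# A @ B = [[2, 1], [4, 3]], then scale by alpha and add beta * C.
D = alpha * (A @ B) + beta * C
print(D)
# [[4.5 2.5]
#  [8.5 6.5]]
```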

## Fused FP8 matmul in JAX: from simple to complicated!

As presented above, the FP8 XLA custom target **`__cublas$lt$matmul$f8`** has an extended API & config allowing **fusing** multiple operations in the GEMM kernel. More specifically:
* Scaling of input & output tensors;
* Capturing absolute-maximum of the output (usually called `damax`);
* Post-matmul bias and/or non-linearity;

We present below how to generate the proper fused matmul call directly from JAX, checking the result in the compiled HLO. We start with input and output scaling, following the interface of **`__cublas$lt$matmul$f8`**.

Let's first try with a naive implementation:

In [5]:
def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, c_scale):
    # First try: can we just scale the input with an FP32 scalar?
    a_fp8 = a_fp8 * a_scale
    out = jax.lax.dot(a_fp8, b_fp8.T)
    return out

# `__cublas$lt$matmul$f8` expecting FP32 scales.
scale_aval = jax.core.ShapedArray((), jnp.float32)
try:
    fn_compiled = jax.jit(matmul_fn_with_scale).lower(a, b, scale_aval, scale_aval, scale_aval).compile()
except Exception as e:
    # Issue: JAX does not support implicit dtype promotion for a mixed FP8 x FP32 multiplication.
    print(f"<<< JAX compilation error >>>\n{e}")

<<< JAX compilation error >>>
Input dtypes ('float8_e4m3fn', 'float32') have no available implicit dtype promotion path. To avoid unintended promotion, 8-bit floats do not support implicit promotion. If you'd like your inputs to be promoted to another type, you can do so explicitly using e.g. x.astype('float32')


### FP8 matmul with scaled inputs & outputs

JAX and XLA do not allow implicit conversion between FP8 and FP32, so we need to write something more explicit that the XLA compiler can pattern-match into the fused call. More specifically, as presented in the [XLA FP8 RFC](https://github.com/openxla/xla/discussions/22), one needs to adopt dequantization/quantization semantics:
* Upcast inputs to `float32` and then scale;
* Scale output, clamp to `float8` range (not optional!) and then downcast to `float8`;

As presented below, when using this pattern, the XLA compiler is able to fuse all the operations into a single call of `__cublas$lt$matmul$f8`.

In [6]:
e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max

# XLA requires a "dequantize/quantize" pattern to properly support scaled FP8 inputs/outputs. 
def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):
    # Dequantize a and b
    a_fp32 = a_fp8.astype(jnp.float32) * a_scale
    b_fp32 = b_fp8.astype(jnp.float32) * b_scale
    
    # Do the matmul (NOTE: adding transpose to reduce on last axis).
    d_fp32 = jax.lax.dot(a_fp32, b_fp32.T)
    
    # Rescale & clamp to -max/+max FP8 E4M3 values.
    d_fp32 = d_fp32 * d_scale
    # NOTE: clamping is NOT optional for proper pattern matching!
    d_fp32 = jax.lax.clamp(jnp.float32(-e4m3_max), d_fp32, jnp.float32(e4m3_max))
    # (Re)Quantize the scaled matmul output.
    return d_fp32.astype(jnp.float8_e4m3fn)

# AOT compilation with JAX, inspecting the (final) HLO module generated.
fn_compiled = jax.jit(matmul_fn_with_scale).lower(a, b, scale_aval, scale_aval, scale_aval).compile()
# (Human readable) optimized Hlo module generated by XLA (ignoring GEMM backend config)
print_hlo_module(fn_compiled, backend_cfg=False)

HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[], f32[])->f8e4m3fn[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="230c40ffa1e1e3ba7f06e4a65ac9e2bd"}

ENTRY %main.22 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> f8e4m3fn[32,128] {
  %constant_1 = f32[] constant(1)
  %Arg_4.5.0 = f32[] parameter(4)
  %Arg_3.4.0 = f32[] parameter(3)
  %Arg_2.3.0 = f32[] parameter(2)
  %Arg_1.2.0 = f8e4m3fn[128,64]{1,0} parameter(1)
  %Arg_0.1.0 = f8e4m3fn[32,64]{1,0} parameter(0)
  %cublas-gemm.clone.1.0 = (f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[32,64]{1,0} %Arg_0.1.0, f8e4m3fn[128,64]{1,0} %Arg_1.2.0, f32[] %Arg_2.3.0, f32[] %Arg_3.4.0, f32[] %constant_1, /*index=5*/f32[] %Arg_4.5.0), custom
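The function above takes the FP8 tensors and their scales as given. One common way to obtain them, sketched here under the assumption of a single per-tensor scale derived from the current abs-max, is shown below. The helper names `quantize_e4m3` and `dequantize` are hypothetical; the returned scale is the dequantization factor, i.e. the value that would be passed as `a_scale` or `b_scale`.

```python
import jax
import jax.numpy as jnp

E4M3_MAX = float(jnp.finfo(jnp.float8_e4m3fn).max)  # 448.0


def quantize_e4m3(x_fp32):
    """Per-tensor quantization to E4M3; returns (x_fp8, dequantization scale)."""
    amax = jnp.max(jnp.abs(x_fp32))
    # NOTE: assumes a non-zero tensor (amax > 0).
    scale = amax / E4M3_MAX
    # Clip before the cast so that rounding noise cannot leave the FP8 range.
    x_scaled = jnp.clip(x_fp32 / scale, -E4M3_MAX, E4M3_MAX)
    return x_scaled.astype(jnp.float8_e4m3fn), scale


def dequantize(x_fp8, scale):
    return x_fp8.astype(jnp.float32) * scale


x = 1000.0 * jax.random.normal(jax.random.PRNGKey(0), (32, 64), jnp.float32)
x_fp8, x_scale = quantize_e4m3(x)
max_err = jnp.max(jnp.abs(dequantize(x_fp8, x_scale) - x))
print("scale:", x_scale, "max abs error:", max_err)
```

After scaling, the largest magnitude in the tensor maps to the largest representable `E4M3` value, so the full FP8 range is used.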

### Adding non-linearity `relu` to the FP8 matmul

Can we get XLA to fuse a post-matmul non-linearity `relu` function as well?

In [7]:
e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max

# XLA requires a "dequantize/quantize" pattern to properly support scaled FP8 inputs/outputs. 
def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):
    # Dequantize a and b
    a_fp32 = a_fp8.astype(jnp.float32) * a_scale
    b_fp32 = b_fp8.astype(jnp.float32) * b_scale
    
    # Do the matmul (NOTE: adding transpose to simplify HLO).
    d_fp32 = jax.lax.dot(a_fp32, b_fp32.T)
    # ReLU non-linearity. Note: applied before scaling.
    d_fp32 = jax.nn.relu(d_fp32)
    
    # Rescale & clamp to -max/+max FP8 E4M3 values.
    d_fp32 = d_fp32 * d_scale
    # NOTE: clamping is NOT optional for proper pattern matching!
    d_fp32 = jax.lax.clamp(jnp.float32(-e4m3_max), d_fp32, jnp.float32(e4m3_max))
    # (Re)Quantize the scaled matmul output.
    return d_fp32.astype(jnp.float8_e4m3fn)

# AOT compilation with JAX, inspecting the (final) HLO module generated.
fn_compiled = jax.jit(matmul_fn_with_scale).lower(a, b, scale_aval, scale_aval, scale_aval).compile()
# (Human readable) optimized Hlo module generated by XLA (ignoring GEMM backend config)
print_hlo_module(fn_compiled, backend_cfg=False)

HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[], f32[])->f8e4m3fn[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="f1fb5db9dad54941d7d17e04fdbe9515"}

ENTRY %main.28 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> f8e4m3fn[32,128] {
  %constant_1_0 = f32[] constant(1)
  %Arg_4.5.0 = f32[] parameter(4)
  %Arg_3.4.0 = f32[] parameter(3)
  %Arg_2.3.0 = f32[] parameter(2)
  %Arg_1.2.0 = f8e4m3fn[128,64]{1,0} parameter(1)
  %Arg_0.1.0 = f8e4m3fn[32,64]{1,0} parameter(0)
  %cublas-gemm.2.clone.1.0 = (f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[32,64]{1,0} %Arg_0.1.0, f8e4m3fn[128,64]{1,0} %Arg_1.2.0, f32[] %Arg_2.3.0, f32[] %Arg_3.4.0, f32[] %constant_1_0, /*index=5*/f32[] %Arg_4.5.0), 
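Independently of which kernel XLA selects, the numerics of this pattern can be sanity-checked by running it on concrete inputs. The sketch below redefines the same computation under the hypothetical name `matmul_relu_fp8`, uses unit scales, and verifies that the FP8 output is non-negative and inside the `E4M3` range. It is a quick check under these assumptions, not a validation of the fused kernel itself.

```python
import jax
import jax.numpy as jnp

E4M3_MAX = float(jnp.finfo(jnp.float8_e4m3fn).max)


def matmul_relu_fp8(a_fp8, b_fp8, a_scale, b_scale, d_scale):
    # Dequantize, matmul, ReLU, then rescale, clamp and requantize.
    a_fp32 = a_fp8.astype(jnp.float32) * a_scale
    b_fp32 = b_fp8.astype(jnp.float32) * b_scale
    d_fp32 = jax.nn.relu(jax.lax.dot(a_fp32, b_fp32.T)) * d_scale
    d_fp32 = jax.lax.clamp(jnp.float32(-E4M3_MAX), d_fp32, jnp.float32(E4M3_MAX))
    return d_fp32.astype(jnp.float8_e4m3fn)


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a_in = jax.random.normal(key_a, (32, 64), jnp.float32).astype(jnp.float8_e4m3fn)
b_in = jax.random.normal(key_b, (128, 64), jnp.float32).astype(jnp.float8_e4m3fn)
one = jnp.float32(1.0)

out = jax.jit(matmul_relu_fp8)(a_in, b_in, one, one, one)
out_fp32 = out.astype(jnp.float32)
print(out.dtype, out.shape, float(out_fp32.min()), float(out_fp32.max()))
```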

As shown in the `backend_config` below, the `epilogue` is changed to `RELU`.

In [8]:
hlo_module = parse_hlo_module(fn_compiled)
backend_config = next((m.backend_config for m in hlo_module if "__cublas$lt$matmul$f8" in m.cmd))
# the `epilogue` is set to `RELU`
JSON(backend_config, expanded=True)

### Extracting the `abs-max` of the output

Delayed rescaling is a common technique in FP8 training: the **abs-max** of the output is recorded and used to compute the scaling factor for the next training iteration. A benefit of delayed rescaling is that the abs-max computation can be merged directly into the FP8 matmul kernel, as shown below, with very small performance impact.

In [9]:
e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max

# XLA requires a "dequantize/quantize" pattern to properly support scaled FP8 inputs/outputs. 
def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):
    # Dequantize a and b
    a_fp32 = a_fp8.astype(jnp.float32) * a_scale
    b_fp32 = b_fp8.astype(jnp.float32) * b_scale
    
    # Do the matmul (NOTE: adding transpose to simplify HLO).
    d_fp32 = jax.lax.dot(a_fp32, b_fp32.T)
    # ReLU non-linearity. Note: needs to be before the scaling.
    d_fp32 = jax.nn.relu(d_fp32)
    # Delayed rescaling: capture the abs-max of the unscaled output for later use.
    out_scale = jnp.max(jnp.abs(d_fp32))

    # Rescale & clamp to -max/+max FP8 E4M3 values.
    d_fp32 = d_fp32 * d_scale
    # NOTE: clamping is NOT optional for proper pattern matching!
    d_fp32 = jax.lax.clamp(jnp.float32(-e4m3_max), d_fp32, jnp.float32(e4m3_max))
    # (Re)Quantize the scaled matmul output.
    return d_fp32.astype(jnp.float8_e4m3fn), out_scale

# AOT compilation with JAX, inspecting the (final) HLO module generated.
fn_compiled = jax.jit(matmul_fn_with_scale).lower(a, b, scale_aval, scale_aval, scale_aval).compile()
# (Human readable) optimized Hlo module generated by XLA (ignoring GEMM backend config)
print_hlo_module(fn_compiled, backend_cfg=False)

HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[], f32[])->(f8e4m3fn[32,128]{1,0}, f32[])}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true,true}, frontend_attributes={fingerprint_before_lhs="352d42558cb7282e2f79d89bdc48e6d6"}

ENTRY %main.36 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> (f8e4m3fn[32,128], f32[]) {
  %constant_1_0 = f32[] constant(1)
  %Arg_4.5.0 = f32[] parameter(4)
  %Arg_3.4.0 = f32[] parameter(3)
  %Arg_2.3.0 = f32[] parameter(2)
  %Arg_1.2.0 = f8e4m3fn[128,64]{1,0} parameter(1)
  %Arg_0.1.0 = f8e4m3fn[32,64]{1,0} parameter(0)
  %cublas-gemm.2.clone.1.0 = (f8e4m3fn[32,128]{1,0}, f32[], s8[33554432]{0}) custom-call(f8e4m3fn[32,64]{1,0} %Arg_0.1.0, f8e4m3fn[128,64]{1,0} %Arg_1.2.0, f32[] %Arg_2.3.0, f32[] %Arg_3.4.0, f32[] %constant_1_0, 

As shown in the `backend_config` below, `damax_output` is now set to `true`, meaning that the **`__cublas$lt$matmul$f8`** call also computes the `abs-max` of the matmul output (prior to converting back to FP8).

In [10]:
hlo_module = parse_hlo_module(fn_compiled)
backend_config = next((m.backend_config for m in hlo_module if "__cublas$lt$matmul$f8" in m.cmd))
# the `epilogue` is set to `RELU` & `damax_output` to `true`
JSON(backend_config, expanded=True)
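How the captured abs-max feeds into the next iteration is left to the training loop. One possible scheme, sketched here as an assumption rather than as the method of any particular library, keeps a short history of abs-max values and derives the next `d_scale` from its maximum. The helper `update_delayed_scale` and the history length are hypothetical.

```python
import jax.numpy as jnp

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def update_delayed_scale(amax_history, new_amax, fp8_max=E4M3_MAX):
    """Push `new_amax` into the history and return (history, next output scale)."""
    # Shift the history by one slot and store the newest abs-max in front.
    history = jnp.roll(amax_history, 1).at[0].set(new_amax)
    amax = jnp.max(history)
    # Guard against an all-zero history (no scaling information yet).
    scale = jnp.where(amax > 0, fp8_max / amax, 1.0)
    return history, scale


history = jnp.zeros(4, dtype=jnp.float32)
history, d_scale = update_delayed_scale(history, 224.0)
history, d_scale_next = update_delayed_scale(history, 112.0)
print(history, d_scale, d_scale_next)
```

Because the scale is driven by the maximum over the history, a single smaller abs-max does not immediately increase the scale, which makes the scheme less sensitive to iteration-to-iteration noise.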

### Additional notebook improvements & clarifications

* Fusing Linear layer `bias` add;
* Fusing `jax.nn.gelu` activation layer;
* FP8 peak flops & performance;

### References

* [8-bit Numerical Formats for Deep Neural Networks](https://arxiv.org/abs/2206.02915)
* [FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433)
* [FP8-LM: Training FP8 Large Language Models](https://arxiv.org/pdf/2310.18313)
* [Training and inference of large language models using 8-bit floating point](https://openreview.net/pdf?id=nErbvDkucY)
* [OCP 8-bit Floating Point Specification (OFP8)](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1)
* [IEEE Working Group P3109 Interim Report on 8-bit Binary Floating-point Formats](https://github.com/P3109/Public/blob/main/Shared%20Reports/P3109%20WG%20Interim%20Report.pdf)