Status: Approved
Initial version: 3/21/2023
Last updated: 4/5/2023
Discussion thread: GitHub
Graphcore, AMD, and Qualcomm have proposed two new FP8 types, Float8E4M3FNUZ and Float8E5M2FNUZ [1]. These types are implemented in commercially available hardware [2], and have been added to MLIR builtin types [3] and LLVM APFloat [4].
These two types appear similar to the existing types Float8E4M3FN and Float8E5M2 [5], but differ in important ways.
Both Float8E4M3FNUZ and Float8E5M2FNUZ differ from typical floating point types in their handling of NaN, infinities, negative zero, and exponent bias. The suffix "FNUZ" is derived from these differences: F is for "finite" (no infinities), N for a special NaN encoding, and UZ for unsigned zero. I propose keeping this naming scheme in StableHLO, matching LLVM/MLIR.
These changes mean there's an additional exponent value available. To keep the same dynamic range as an IEEE-like FP8 type, the exponent is biased one more than would be expected given the number of exponent bits (8 for Float8E4M3FNUZ and 16 for Float8E5M2FNUZ).
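As a rough illustration of the bias shift, the sketch below computes the range of normal exponents for an IEEE-like layout (top exponent field reserved for infinities and NaN) and for the FNUZ layout (no reserved exponent field, bias one larger). The function names are hypothetical and are not part of StableHLO, MLIR, or LLVM.

```python
def ieee_like_exponent_range(exp_bits: int) -> tuple[int, int]:
    """(min, max) unbiased exponent of normal values in an IEEE-like layout."""
    bias = 2 ** (exp_bits - 1) - 1
    max_field = 2 ** exp_bits - 2  # all-ones exponent field is reserved
    return 1 - bias, max_field - bias


def fnuz_exponent_range(exp_bits: int) -> tuple[int, int]:
    """(min, max) unbiased exponent of normal values in an FNUZ layout."""
    bias = 2 ** (exp_bits - 1)     # one more than the IEEE-like bias
    max_field = 2 ** exp_bits - 1  # all-ones exponent field holds normal values
    return 1 - bias, max_field - bias


for exp_bits in (4, 5):
    print(exp_bits, ieee_like_exponent_range(exp_bits), fnuz_exponent_range(exp_bits))
```

Under these assumptions the largest exponent is unchanged (7 for 4 exponent bits, 15 for 5), and the freed exponent value extends the range by one step towards zero.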
Float8E4M3FNUZ: 8-bit floating point with 3-bit mantissa.
An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits mantissa. This is not a standard type as defined by IEEE-754, but it follows similar conventions, with the exception that there are no infinity values, no negative zero, and only one NaN representation. This type has the following characteristics:
- bit encoding: S1E4M3 - 0bSEEEEMMM
- exponent bias: 8
- infinities: Not supported
- NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s
- denormals when exponent is 0
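The characteristics above can be turned into a small decoder. This is a minimal sketch in plain Python, assuming the bias equals 2^(exponent bits - 1) as stated for both FNUZ types; `decode_fnuz` is a hypothetical helper, not an API of StableHLO, MLIR, or LLVM.

```python
import math


def decode_fnuz(byte: int, exp_bits: int, man_bits: int) -> float:
    """Decode one FNUZ-encoded byte into a Python float (hypothetical helper)."""
    assert 0 <= byte <= 0xFF and 1 + exp_bits + man_bits == 8
    if byte == 0x80:
        # Sign bit set with all other bits zero is the single NaN encoding.
        return math.nan
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> man_bits) & ((1 << exp_bits) - 1)
    man = byte & ((1 << man_bits) - 1)
    bias = 1 << (exp_bits - 1)  # 8 for E4M3FNUZ, 16 for E5M2FNUZ
    if exp == 0:
        # Denormal: no implicit leading one, exponent fixed at 1 - bias.
        return sign * math.ldexp(man, 1 - bias - man_bits)
    # Normal: implicit leading one in front of the mantissa bits.
    return sign * math.ldexp((1 << man_bits) + man, exp - bias - man_bits)


print(decode_fnuz(0x7F, 4, 3))  # largest Float8E4M3FNUZ value
print(decode_fnuz(0x08, 4, 3))  # smallest positive normal Float8E4M3FNUZ value
```

The same function covers Float8E5M2FNUZ when called with `exp_bits=5, man_bits=2`.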
| | Float8E4M3FN | Float8E4M3FNUZ |
|---|---|---|
| Bias | 7 | 8 |
| Min Normal Value | 0bS0001000 = (-1)^S × 1.0 × 2^-6 | 0bS0001000 = (-1)^S × 1.0 × 2^-7 |
| Max Normal Value | 0bS1111110 = (-1)^S × 1.75 × 2^8 = ±448 | 0bS1111111 = (-1)^S × 1.875 × 2^7 = ±240 |
| Min Subnormal Value | 0bS0000001 = (-1)^S × 0.125 × 2^-6 | 0bS0000001 = (-1)^S × 0.125 × 2^-7 |
| Max Subnormal Value | 0bS0000111 = (-1)^S × 0.875 × 2^-6 | 0bS0000111 = (-1)^S × 0.875 × 2^-7 |
| NaN | 0bS1111111 | 0b10000000 |
| Infinity | N/A | N/A |
| -0 | 0b10000000 | N/A |
Float8E5M2FNUZ: 8-bit floating point with 2-bit mantissa.
An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits mantissa. This is not a standard type as defined by IEEE-754, but it follows similar conventions, with the exception that there are no infinity values, no negative zero, and only one NaN representation. This type has the following characteristics:
- bit encoding: S1E5M2 - 0bSEEEEEMM
- exponent bias: 16
- infinities: Not supported
- NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s
- denormals when exponent is 0
| | Float8E5M2 | Float8E5M2FNUZ |
|---|---|---|
| Bias | 15 | 16 |
| Min Normal Value | 0bS0000100 = (-1)^S × 1.0 × 2^-14 | 0bS0000100 = (-1)^S × 1.0 × 2^-15 |
| Max Normal Value | 0bS1111011 = (-1)^S × 1.75 × 2^15 = ±57344 | 0bS1111111 = (-1)^S × 1.75 × 2^15 = ±57344 |
| Min Subnormal Value | 0bS0000001 = (-1)^S × 0.25 × 2^-14 | 0bS0000001 = (-1)^S × 0.25 × 2^-15 |
| Max Subnormal Value | 0bS0000011 = (-1)^S × 0.75 × 2^-14 | 0bS0000011 = (-1)^S × 0.75 × 2^-15 |
| NaN | 0bS11111MM, where MM is non-zero. | 0b10000000 |
| Infinity | 0bS1111100 | N/A |
| -0 | 0b10000000 | N/A |
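To make the encoding differences between the two 5-bit-exponent types concrete, the sketch below classifies all 256 bit patterns of each type using only the NaN, infinity, and zero rules listed above. The helper names are hypothetical.

```python
from collections import Counter


def classify_e5m2(byte: int) -> str:
    """Classify a Float8E5M2 bit pattern (IEEE-like special values)."""
    exp = (byte >> 2) & 0x1F
    man = byte & 0x3
    if exp == 0x1F:
        return "inf" if man == 0 else "nan"
    if exp == 0 and man == 0:
        return "zero"  # covers both +0 and -0
    return "finite"


def classify_e5m2fnuz(byte: int) -> str:
    """Classify a Float8E5M2FNUZ bit pattern (one NaN, one zero, no infinities)."""
    if byte == 0x80:
        return "nan"
    if byte == 0x00:
        return "zero"
    return "finite"


e5m2_counts = Counter(classify_e5m2(b) for b in range(256))
fnuz_counts = Counter(classify_e5m2fnuz(b) for b in range(256))
print(dict(e5m2_counts))
print(dict(fnuz_counts))
```

Float8E5M2 spends ten encodings on zeros, infinities, and NaNs, leaving 246 finite non-zero values, while Float8E5M2FNUZ spends two and leaves 254.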
I propose adding these types to StableHLO in a similar way to the previously introduced FP8 types (see the FP8 RFC), with some differences.
To provide a reference implementation, I intend to add support for Float8E4M3FNUZ and Float8E5M2FNUZ in the StableHLO interpreter. This will be useful for testing other backends and validating new implementations. This will be achieved in two ways:
- Map directly to the appropriate APFloat operation.
- Cast up to the appropriate type, use that implementation, cast back down.
I intend for Float8E4M3FNUZ and Float8E5M2FNUZ to be types that support the appropriate arithmetic operations, like any other floating point type. For platforms that don't have hardware support for these types, they may either throw an error and reject the program or cast up to an appropriate higher precision type that is supported, compute the answer, and cast back down.
This is a simple approach that aligns with user expectations of a floating point data type, and is the approach taken by BFloat16. This also gives backends freedom to exploit any hardware support.
Here's an example of a real JAX program (logging the MLIR) computing a simple dot product in Float8E5M2FNUZ. Note the answer is slightly "wrong" (the exact sum of squares of 0 through 15 is 1240), as expected due to the lower precision.
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.arange(16, dtype=jnp.float8_e5m2fnuz)
module @jit_iota {
func.func public @main() -> tensor<16xf8E5M2FNUZ> {
%0 = stablehlo.iota dim = 0 : tensor<16xf8E5M2FNUZ>
return %0 : tensor<16xf8E5M2FNUZ>
}
}
>>> x
Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 10, 12, 12, 12, 14, 16], dtype=float8_e5m2fnuz)
>>> x @ x
module @jit_matmul {
func.func public @main(%arg0: tensor<16xf8E5M2FNUZ> {mhlo.sharding = ""}, %arg1: tensor<16xf8E5M2FNUZ> {mhlo.sharding = ""}) -> tensor<f8E5M2FNUZ> {
%0 = "stablehlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #stablehlo.dot<lhs_contracting_dimensions = [0], rhs_contracting_dimensions = [0]>, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<16xf8E5M2FNUZ>, tensor<16xf8E5M2FNUZ>) -> tensor<f8E5M2FNUZ>
return %0 : tensor<f8E5M2FNUZ>
}
}
Array(1280, dtype=float8_e5m2fnuz)
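The rounded values in the output above can be reproduced in plain Python. The sketch below is not how JAX or any backend implements the type; it assumes round-to-nearest-even, assumes the dot product is accumulated in higher precision and rounded once at the end, and assumes out-of-range results map to NaN. `round_to_e5m2fnuz` is a hypothetical helper.

```python
import math

E5M2FNUZ_MAX = 57344.0  # 1.75 * 2**15
MIN_QUANTUM_EXP = -17   # smallest subnormal is 2**-17
SIGNIFICAND_BITS = 3    # 2 stored mantissa bits plus the implicit leading 1


def round_to_e5m2fnuz(x: float) -> float:
    """Round x to the nearest Float8E5M2FNUZ value, ties to even (sketch)."""
    if math.isnan(x) or math.isinf(x):
        return math.nan  # no infinities in this type
    if x == 0.0:
        return 0.0       # no negative zero in this type
    _, e = math.frexp(abs(x))  # abs(x) = m * 2**e with 0.5 <= m < 1
    quantum = math.ldexp(1.0, max(e - SIGNIFICAND_BITS, MIN_QUANTUM_EXP))
    # Python's round() rounds halfway cases to the nearest even integer.
    rounded = round(abs(x) / quantum) * quantum
    if rounded > E5M2FNUZ_MAX:
        return math.nan  # assumption: overflow becomes NaN rather than saturating
    if rounded == 0.0:
        return 0.0
    return math.copysign(rounded, x)


x = [round_to_e5m2fnuz(float(i)) for i in range(16)]
print(x)
# Accumulate in double precision, then round once (an assumption about the backend).
print(round_to_e5m2fnuz(sum(v * v for v in x)))
```

With these assumptions, 9, 11, 13, and 15 are exact ties and round to 8, 12, 12, and 16, and the accumulated sum 1252 rounds to 1280, matching the JAX output.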
At the StableHLO level, we won't impose any scaling requirements on users. Given that arithmetic operations can be supported, we leave the scaling choice to the user. I believe this is the correct approach given that FP8 applications are an active area of research.
Graphcore's IPU supports hardware scaling by biasing the hardware's interpretation of the exponent by ±32 at runtime [2]. This is a backend-specific peephole optimisation that doesn't impact StableHLO. Other backends may similarly optimise their implementation of these types.
Built on the StableHLO interpreter, I intend to introduce tests for all possible operations with Float8E4M3FNUZ and Float8E5M2FNUZ inputs. This will at a minimum mean adding additional cases to the interpret_X.mlir family of tests.
Given any new data types, we could also consider derived types. One example would be hypothetical complex number types with the real and imaginary components constructed from Float8E4M3FNUZ or Float8E5M2FNUZ. I don't exclude the possibility of this being done in the future, but it is not being proposed here.