# RFC: Float8E4M3FNUZ and Float8E5M2FNUZ

## Summary

Graphcore, AMD, and Qualcomm have proposed two new FP8 types, Float8E4M3FNUZ
and Float8E5M2FNUZ[^1]. These types are implemented in commercially available
hardware[^2] and have been added to the MLIR builtin types[^4] and to LLVM
APFloat[^5].

These two types appear similar to the existing types Float8E4M3FN and
Float8E5M2[^3], but differ in important ways.

## Details

Both Float8E4M3FNUZ and Float8E5M2FNUZ differ from typical floating point types
in their handling of NaN, infinities, negative zero, and exponent bias. The
suffix "FNUZ" is derived from these differences: `F` is for "finite" (no
infinities), `N` for a special NaN encoding, and `UZ` for unsigned zero (no
negative zero). I propose keeping this naming scheme in StableHLO, matching
LLVM/MLIR.

These changes mean an additional exponent value is available, since the
all-ones exponent no longer needs to be reserved for infinities and NaNs. To
keep the same dynamic range as an IEEE-like FP8 type, we bias the exponent one
more than would be expected given the number of exponent bits (8 for
Float8E4M3FNUZ and 16 for Float8E5M2FNUZ).
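
Equivalently, the FNUZ bias is the usual IEEE-style bias of $2^{E-1}-1$ for $E$
exponent bits, plus one:

$$
\text{bias} = (2^{E-1} - 1) + 1 = 2^{E-1}, \quad \text{giving } 2^{4-1} = 8 \text{ for Float8E4M3FNUZ and } 2^{5-1} = 16 \text{ for Float8E5M2FNUZ.}
$$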

### Float8E4M3FNUZ

8-bit floating point with a 3-bit mantissa.

An 8-bit floating point type with 1 sign bit, 4 exponent bits, and 3 mantissa
bits. This is not a standard type as defined by IEEE-754, but it follows
similar conventions, with the exceptions that there are no infinities, no
negative zero, and only one NaN representation. This type has the following
characteristics:

* bit encoding: S1E4M3 - `0bSEEEEMMM`
* exponent bias: 8
* infinities: Not supported
* NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set
to all 0s
* denormals when exponent is 0

### Comparison of Float8E4M3FN and Float8E4M3FNUZ

| |Float8E4M3FN |Float8E4M3FNUZ |
|-------------------|------------------------------------------------------------------------|-------------------------------------------------------------------------|
|Bias |7 |8 |
|Min Normal Value |`0bS0001000` = -1<sup>S</sup> $\times$ 1.0 $\times$ 2<sup>-6</sup> |`0bS0001000` = -1<sup>S</sup> $\times$ 1.0 $\times$ 2<sup>-7</sup> |
|Max Normal Value |`0bS1111110` = -1<sup>S</sup> $\times$ 1.75 $\times$ 2<sup>8</sup> = 448|`0bS1111111` = -1<sup>S</sup> $\times$ 1.875 $\times$ 2<sup>7</sup> = 240|
|Min Subnormal Value|`0bS0000001` = -1<sup>S</sup> $\times$ 0.125 $\times$ 2<sup>-6</sup> |`0bS0000001` = -1<sup>S</sup> $\times$ 0.125 $\times$ 2<sup>-7</sup> |
|Max Subnormal Value|`0bS0000111` = -1<sup>S</sup> $\times$ 0.875 $\times$ 2<sup>-6</sup> |`0bS0000111` = -1<sup>S</sup> $\times$ 0.875 $\times$ 2<sup>-7</sup> |
|NaN |`0bS1111111` |`0b10000000` |
|Infinity |N/A |N/A |
|-0 |`0b10000000` |N/A |

### Float8E5M2FNUZ

8-bit floating point with a 2-bit mantissa.

An 8-bit floating point type with 1 sign bit, 5 exponent bits, and 2 mantissa
bits. This is not a standard type as defined by IEEE-754, but it follows
similar conventions, with the exceptions that there are no infinities, no
negative zero, and only one NaN representation. This type has the following
characteristics:

* bit encoding: S1E5M2 - `0bSEEEEEMM`
* exponent bias: 16
* infinities: Not supported
* NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set
to all 0s
* denormals when exponent is 0

### Comparison of Float8E5M2 and Float8E5M2FNUZ

| |Float8E5M2 |Float8E5M2FNUZ |
|-------------------|----------------------------------------------------------------------------|---------------------------------------------------------------------------|
|Bias |15 |16 |
|Min Normal Value |`0bS0000100` = -1<sup>S</sup> $\times$ 1.0 $\times$ 2<sup>-14</sup> |`0bS0000100` = -1<sup>S</sup> $\times$ 1.0 $\times$ 2<sup>-15</sup> |
|Max Normal Value |`0bS1111011` = -1<sup>S</sup> $\times$ 1.75 $\times$ 2<sup>15</sup> = 57344 |`0bS1111111` = -1<sup>S</sup> $\times$ 1.75 $\times$ 2<sup>15</sup> = 57344|
|Min Subnormal Value|`0bS0000001` = -1<sup>S</sup> $\times$ 0.25 $\times$ 2<sup>-14</sup> |`0bS0000001` = -1<sup>S</sup> $\times$ 0.25 $\times$ 2<sup>-15</sup> |
|Max Subnormal Value|`0bS0000011` = -1<sup>S</sup> $\times$ 0.75 $\times$ 2<sup>-14</sup> |`0bS0000011` = -1<sup>S</sup> $\times$ 0.75 $\times$ 2<sup>-15</sup> |
|NaN |`0bS11111MM`, where `MM` is non-zero. |`0b10000000` |
|Infinity |`0bS1111100` |N/A |
|-0 |`0b10000000` |N/A |
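
To make the encodings concrete, here is a small illustrative Python decoder for
FNUZ-style bit patterns (a sketch for exposition only, not part of the proposal
and not any library's API); it reproduces the table entries above:

```python
import math

def decode_fp8_fnuz(byte, exp_bits, man_bits, bias):
    """Decode one FNUZ-style FP8 byte, e.g. (4, 3, 8) for Float8E4M3FNUZ
    or (5, 2, 16) for Float8E5M2FNUZ."""
    if byte == 0b10000000:                       # the single NaN encoding
        return math.nan
    sign = -1.0 if byte >> 7 else 1.0
    exponent = (byte >> man_bits) & ((1 << exp_bits) - 1)
    fraction = (byte & ((1 << man_bits) - 1)) / (1 << man_bits)
    if exponent == 0:                            # subnormal: no implicit leading 1
        return sign * fraction * 2.0 ** (1 - bias)
    return sign * (1.0 + fraction) * 2.0 ** (exponent - bias)

assert decode_fp8_fnuz(0b01111111, 4, 3, 8) == 240.0        # E4M3FNUZ max normal
assert decode_fp8_fnuz(0b01111111, 5, 2, 16) == 57344.0     # E5M2FNUZ max normal
assert decode_fp8_fnuz(0b00000100, 5, 2, 16) == 2.0 ** -15  # E5M2FNUZ min normal
assert math.isnan(decode_fp8_fnuz(0b10000000, 4, 3, 8))     # NaN, and no -0
```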

## Changes in StableHLO

I propose adding these types to StableHLO, similar to the previously
introduced FP8 types ([FP8 RFC](https://github.com/openxla/xla/discussions/22)),
with some differences.

### StableHLO Interpreter

To provide a reference implementation, I intend to add support for
Float8E4M3FNUZ and Float8E5M2FNUZ in the StableHLO interpreter. This will be
useful for testing other backends and validating new implementations. This will
be achieved in two ways:

1. Map directly to the appropriate APFloat operation.
2. Cast up to an appropriate higher-precision type, use that implementation,
   and cast back down (sketched below).
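
As a rough sketch of the second strategy, expressed in Python/JAX rather than
the interpreter's own implementation, and using the `jnp.float8_e5m2fnuz` dtype
shown later in this RFC (`emulated_add` is a name invented for this example):

```python
import jax.numpy as jnp

def emulated_add(a, b):
    # Compute in a wider type, then round back down to the original FP8 type.
    wide = a.astype(jnp.float32) + b.astype(jnp.float32)
    return wide.astype(a.dtype)

x = jnp.arange(4, dtype=jnp.float8_e5m2fnuz)
print(emulated_add(x, x))  # [0 2 4 6], still float8_e5m2fnuz
```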

### Float8E4M3FNUZ and Float8E5M2FNUZ Arithmetic

I intend for Float8E4M3FNUZ and Float8E5M2FNUZ to be types that support the
appropriate arithmetic operations, like any other floating point type.
Platforms without hardware support for these types may either reject the
program with an error or cast up to a supported higher-precision type, compute
the result, and cast back down.

This is a simple approach that aligns with user expectations of a floating
point data type, and it is the approach taken for BFloat16. It also gives
backends the freedom to exploit any hardware support.

Here's an example of a real JAX program (logging the MLIR) computing a simple
dot product in Float8E5M2FNUZ. Note the answer is slightly "wrong", as expected
due to the lower precision.

```python
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.arange(16, dtype=jnp.float8_e5m2fnuz)
module @jit_iota {
func.func public @main() -> tensor<16xf8E5M2FNUZ> {
%0 = stablehlo.iota dim = 0 : tensor<16xf8E5M2FNUZ>
return %0 : tensor<16xf8E5M2FNUZ>
}
}
>>> x
Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 10, 12, 12, 12, 14, 16], dtype=float8_e5m2fnuz)
>>> x @ x
module @jit_matmul {
func.func public @main(%arg0: tensor<16xf8E5M2FNUZ> {mhlo.sharding = ""}, %arg1: tensor<16xf8E5M2FNUZ> {mhlo.sharding = ""}) -> tensor<f8E5M2FNUZ> {
%0 = "stablehlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #stablehlo.dot<lhs_contracting_dimensions = [0], rhs_contracting_dimensions = [0]>, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<16xf8E5M2FNUZ>, tensor<16xf8E5M2FNUZ>) -> tensor<f8E5M2FNUZ>
return %0 : tensor<f8E5M2FNUZ>
}
}
Array(1280, dtype=float8_e5m2fnuz)
```

### Scaling

At the StableHLO level we won't impose any scaling requirements on users.
Since arithmetic operations are supported directly, we leave the choice of
scaling scheme to the user. I believe this is the correct approach given that
FP8 applications are an active area of research.
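
For illustration only, a user-chosen scaling scheme might look like the
following sketch in JAX (the helper names and the per-tensor max scaling
heuristic are assumptions made for this example, not part of the proposal):

```python
import jax.numpy as jnp

FP8_E5M2FNUZ_MAX = 57344.0  # max normal value from the table above

def to_fp8(x, scale):
    # The user picks `scale` (here from the tensor's observed max) so that
    # x / scale fits within the FP8 dynamic range.
    return (x / scale).astype(jnp.float8_e5m2fnuz)

def from_fp8(x_fp8, scale):
    return x_fp8.astype(jnp.float32) * scale

x = jnp.linspace(0.0, 1.0e6, 8)                 # values well outside FP8's range
scale = jnp.max(jnp.abs(x)) / FP8_E5M2FNUZ_MAX
x_roundtrip = from_fp8(to_fp8(x, scale), scale)
```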

Graphcore's IPU supports hardware scaling by biasing the hardware's
interpretation of the exponent by ±32 at runtime[^2]. This is a
backend-specific peephole optimisation that doesn't impact StableHLO. Other
backends may similarly optimise their implementation of these types.

### Testing

Building on the StableHLO interpreter, I intend to introduce tests for all
operations that accept Float8E4M3FNUZ and Float8E5M2FNUZ inputs. At a minimum,
this will mean adding cases to the `interpret_X.mlir` family of tests.

## Not Included

Given any new data types, we could also consider derived types. One example
would be hypothetical complex number types with the real and imaginary
component being constructed from Float8E4M3FNUZ or Float8E5M2FNUZ. I don't
exclude the possibility of this being done in the future, but that is not being
proposed here.

[^1]: [8-bit Numerical Formats for Deep Neural Networks by Noune et al.](https://arxiv.org/abs/2206.02915)
[^2]: [Graphcore Tile Vertex ISA 1.3.1 IPU21](https://docs.graphcore.ai/projects/isa-mk2-with-fp8/en/latest/_static/TileVertexISA-IPU21-1.3.1.pdf)
[^3]: [FP8 Formats for Deep Learning by Micikevicius et al.](https://arxiv.org/abs/2209.05433)
[^4]: [Add Float8E5M2FNUZ and Float8E4M3FNUZ types to MLIR](https://reviews.llvm.org/D143744)
[^5]: [[llvm][APFloat] Add NaN-in-negative-zero formats by AMD and GraphCore](https://reviews.llvm.org/D141863)