From 38acafc921e6d1c11add2b235a7f98b2fbd3078d Mon Sep 17 00:00:00 2001
From: Jake Hall
Date: Tue, 21 Mar 2023 19:04:43 +0000
Subject: [PATCH 1/4] Propose adding Float8E4M3FNUZ and Float8E5M2FNUZ to
 StableHLO.

---
 rfcs/20230321-fp8_fnuz.md | 122 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 122 insertions(+)
 create mode 100644 rfcs/20230321-fp8_fnuz.md

diff --git a/rfcs/20230321-fp8_fnuz.md b/rfcs/20230321-fp8_fnuz.md
new file mode 100644
index 0000000000..52ba0ff951
--- /dev/null
+++ b/rfcs/20230321-fp8_fnuz.md
@@ -0,0 +1,122 @@
+# RFC: Float8E4M3FNUZ and Float8E5M2FNUZ
+
+## Summary
+Graphcore, AMD, and Qualcomm have proposed two new FP8 types, Float8E4M3FNUZ and Float8E5M2FNUZ[^1].
+These types are implemented in commercially available hardware[^2], and added to MLIR builtin types[^4] and LLVM APFloat[^5].
+
+These two types appear similar to the existing types Float8E4M3FN and Float8E5M2[^3], but differ in important ways.
+
+## Details
+Both Float8E4M3FNUZ and Float8E5M2FNUZ differ from typical floating point types in their support for NaN, infinities, negative zero, and exponent bias.
+The suffix "FNUZ" is derived from these differences. `F` is for "finite" (no infinities), `N` for with special NaN encoding, `UZ` for unsigned zero.
+I propose keeping this naming scheme in StableHLO.
+
+### Float8E4M3FNUZ
+8-bit floating point with 3 bit mantissa.
+
+An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits mantissa. This is not a standard type as defined by IEEE-754, but it follows similar conventions, with the exception that there are no infinity values, no negative zero, and only one NaN representation.
+This type has the following characteristics:
+* bit encoding: S1E4M3 - `0bSEEEEMMM`
+* exponent bias: 8
+* NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s
+* infinities: Not supported
+* denormals when exponent is 0
+
+### Comparison of Float8E4M3FN and Float8E4M3FNUZ
+| |Float8E4M3FN |Float8E4M3FNUZ |
+|-------------------|----------------------------------------------------------|-----------------------------------------------------------|
+|Bias |7 |8 |
+|Min Normal Value |`0bS0001000` = -1<sup>S</sup> * 1.0 * 2<sup>-6</sup> |`0bS0001000` = -1<sup>S</sup> * 1.0 * 2<sup>-7</sup> |
+|Max Normal Value |`0bS1111110` = -1<sup>S</sup> * 1.75 * 2<sup>8</sup> = 448 |`0bS1111111` = -1<sup>S</sup> * 1.875 * 2<sup>7</sup> = 240 |
+|Min Subnormal Value|`0bS0000001` = -1<sup>S</sup> * 0.125 * 2<sup>-6</sup> |`0bS0000001` = -1<sup>S</sup> * 0.125 * 2<sup>-7</sup> |
+|Max Subnormal Value|`0bS0000111` = -1<sup>S</sup> * 0.875 * 2<sup>-6</sup> |`0bS0000111` = -1<sup>S</sup> * 0.875 * 2<sup>-7</sup> |
+|NaN |`0bS1111111` |`0b10000000` |
+|Infinity |N/A |N/A |
+|-0 |`0b10000000` |N/A |
+
+### Float8E5M2FNUZ
+8-bit floating point with 2 bit mantissa.
+
+An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits mantissa. This is not a standard type as defined by IEEE-754, but it follows similar conventions, with the exception that there are no infinity values, no negative zero, and only one NaN representation.
+This type has the following characteristics:
+* bit encoding: S1E5M2 - `0bSEEEEEMM`
+* exponent bias: 16
+* NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s
+* infinities: Not supported
+* denormals when exponent is 0
+
+### Comparison of Float8E5M2 and Float8E5M2FNUZ
+| |Float8E5M2 |Float8E5M2FNUZ |
+|-------------------|--------------------------------------------------------------|-------------------------------------------------------------|
+|Bias |15 |16 |
+|Min Normal Value |`0bS0000100` = -1<sup>S</sup> * 1.0 * 2<sup>-14</sup> |`0bS0000100` = -1<sup>S</sup> * 1.0 * 2<sup>-15</sup> |
+|Max Normal Value |`0bS1111011` = -1<sup>S</sup> * 1.75 * 2<sup>15</sup> = 57344 |`0bS1111111` = -1<sup>S</sup> * 1.75 * 2<sup>15</sup> = 57344|
+|Min Subnormal Value|`0bS0000001` = -1<sup>S</sup> * 0.25 * 2<sup>-14</sup> |`0bS0000001` = -1<sup>S</sup> * 0.25 * 2<sup>-15</sup> |
+|Max Subnormal Value|`0bS0000011` = -1<sup>S</sup> * 0.75 * 2<sup>-14</sup> |`0bS0000011` = -1<sup>S</sup> * 0.75 * 2<sup>-15</sup> |
+|NaN |`0bS11111MM`, where `MM` is non-zero. |`0b10000000` |
+|Infinity |`0bS1111100` |N/A |
+|-0 |`0b10000000` |N/A |
+
+## Changes in StableHLO
+I propose adding these types to StableHLO similar to the previously introduced FP8 types [FP8 RFC](https://github.com/openxla/xla/discussions/22) with some differences.
+
+### StableHLO Interpreter
+To provide a reference implementation, I intend to add support for Float8E4M3FNUZ and Float8E5M2FNUZ in the StableHLO interpreter.
+This will be useful for testing other backends and validating new implementations.
+This will be achieved in two ways:
+1. Map directly to the appropriate APFloat operation.
+2. Cast up to the appropriate type, use that implementation, cast back down.
+
+### Float8E4M3FNUZ and Float8E5M2FNUZ Arithmetic
+I intend for Float8E4M3FNUZ and Float8E5M2FNUZ to be types that support the appropriate arithmetic operations, like any other floating point type.
+For platforms that don't have hardware support for these types, they may either throw an error and reject the program or cast up to an appropriate higher precision type that is supported, compute the answer, and cast back down.
+
+This is a simple approach that aligns with user expectations of a floating point data type, and is the approach taken by BFloat16.
+This also gives backends freedom to exploit any hardware support.
+
+Here's an example of a real JAX program (logging the MLIR) computing a simple dot product in Float8E5M2FNUZ.
+Note the answer is slightly "wrong", as expected due to the lower precision.
+```
+>>> import jax
+>>> import jax.numpy as jnp
+>>> x = jnp.arange(16, dtype=jnp.float8_e5m2fnuz)
+module @jit_iota {
+  func.func public @main() -> tensor<16xf8E5M2FNUZ> {
+    %0 = stablehlo.iota dim = 0 : tensor<16xf8E5M2FNUZ>
+    return %0 : tensor<16xf8E5M2FNUZ>
+  }
+}
+>>> x
+Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 10, 12, 12, 12, 14, 16], dtype=float8_e5m2fnuz)
+>>> x @ x
+module @jit_matmul {
+  func.func public @main(%arg0: tensor<16xf8E5M2FNUZ> {mhlo.sharding = ""}, %arg1: tensor<16xf8E5M2FNUZ> {mhlo.sharding = ""}) -> tensor<f8E5M2FNUZ> {
+    %0 = "stablehlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #stablehlo.dot<lhs_contracting_dimensions = [0], rhs_contracting_dimensions = [0]>, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<16xf8E5M2FNUZ>, tensor<16xf8E5M2FNUZ>) -> tensor<f8E5M2FNUZ>
+    return %0 : tensor<f8E5M2FNUZ>
+  }
+}
+Array(1280, dtype=float8_e5m2fnuz)
+```
+
+### Scaling
+At the StableHLO-level we won't impose any scaling requirements on users.
+Given arithmetic operations can be supported, we leave the scaling choice to the user.
+I believe this is the correct approach given that FP8 applications are an active area of research.
+
+Graphcore's IPU supports hardware scaling by biasing the hardware's interpretation of the exponent ±32 at runtime[^2].
+This is a backend-specific peephole optimisation that doesn't impact StableHLO.
+Other backends may similarly optimise their implementation of these types.
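The cast-up/compute/cast-back-down strategy above can be sketched in dependency-free Python. This is a minimal illustration, not part of the proposed patch: the helper names are illustrative, and round-to-nearest-even is approximated here by breaking ties on the parity of the encoding's last mantissa bit.

```python
# Sketch of "cast up, compute, cast back down" for Float8E5M2FNUZ.
# Names and structure are illustrative, not any proposed API.

def decode_e5m2fnuz(b: int) -> float:
    """Decode one Float8E5M2FNUZ byte (0bSEEEEEMM, exponent bias 16)."""
    if b == 0x80:                      # the single NaN encoding
        return float("nan")
    sign = -1.0 if b >> 7 else 1.0
    exp, mant = (b >> 2) & 0x1F, b & 0x3
    if exp == 0:                       # denormal: no implicit leading 1
        return sign * (mant / 4.0) * 2.0 ** -15
    return sign * (1.0 + mant / 4.0) * 2.0 ** (exp - 16)

# All finite values paired with their encodings, for rounding.
_FINITE = [(decode_e5m2fnuz(b), b) for b in range(256) if b != 0x80]

def round_e5m2fnuz(x: float) -> float:
    """Round x to the nearest Float8E5M2FNUZ value, ties to even mantissa."""
    return min(_FINITE, key=lambda vb: (abs(vb[0] - x), vb[1] & 1))[0]

def dot_e5m2fnuz(xs, ys) -> float:
    """'Cast up' (exact Python floats), accumulate, then cast back down."""
    return round_e5m2fnuz(sum(a * b for a, b in zip(xs, ys)))

x = [round_e5m2fnuz(i) for i in range(16)]
# As in the JAX trace: 9 rounds to 8; 11 and 13 round to 12; 15 rounds to 16;
# and the exact dot product 1252 rounds to 1280.
```

Running `dot_e5m2fnuz(x, x)` reproduces the `1280` from the JAX trace above.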
+ +### Testing +Built on the StableHLO interpreter, I intend to introduce tests for all possible operations with Float8E4M3FNUZ and Float8E5M2FNUZ inputs. +This will at a minimum mean adding additional cases to the `interpret_X.mlir` family of tests. + +## Not Included +Given any new data types, we could also consider derived types. +One example would be hypothetical complex number types with the real and imaginary component being constructed from Float8E4M3FNUZ or Float8E5M2FNUZ. +I don't exclude the possibility of this being done in the future, but that is not being proposed here. + +[^1]: [8-bit Numerical Formats for Deep Neural Networks by Noune et al.](https://arxiv.org/abs/2206.02915) +[^2]: [Graphcore Tile Vertex ISA 1.3.1 IPU21](https://docs.graphcore.ai/projects/isa-mk2-with-fp8/en/latest/_static/TileVertexISA-IPU21-1.3.1.pdf) +[^3]: [FP8 Formats for Deep Learning by Micikevicius et al.](https://arxiv.org/abs/2209.05433) +[^4]: [Add Float8E5M2FNUZ and Float8E4M3FNUZ types to MLIR](https://reviews.llvm.org/D143744) +[^5]: [[llvm][APFloat] Add NaN-in-negative-zero formats by AMD and GraphCore](https://reviews.llvm.org/D141863) From fb1fca1ba54a106ffd27d68b83790fccabde76bd Mon Sep 17 00:00:00 2001 From: Jake Hall Date: Wed, 22 Mar 2023 15:41:40 +0000 Subject: [PATCH 2/4] Fix markdown linter issues. --- rfcs/20230321-fp8_fnuz.md | 145 ++++++++++++++++++++++++-------------- 1 file changed, 94 insertions(+), 51 deletions(-) diff --git a/rfcs/20230321-fp8_fnuz.md b/rfcs/20230321-fp8_fnuz.md index 52ba0ff951..3701bd1e1b 100644 --- a/rfcs/20230321-fp8_fnuz.md +++ b/rfcs/20230321-fp8_fnuz.md @@ -1,81 +1,115 @@ # RFC: Float8E4M3FNUZ and Float8E5M2FNUZ ## Summary -Graphcore, AMD, and Qualcomm have proposed two new FP8 types, Float8E4M3FNUZ and Float8E5M2FNUZ[^1]. -These types are implemented in commercially available hardware[^2], and added to MLIR builtin types[^4] and LLVM APFloat[^5]. 
-These two types appear similar to the existing types Float8E4M3FN and Float8E5M2[^3], but differ in important ways. +Graphcore, AMD, and Qualcomm have proposed two new FP8 types, Float8E4M3FNUZ +and Float8E5M2FNUZ[^1]. These types are implemented in commercially available +hardware[^2], and added to MLIR builtin types[^4] and LLVM APFloat[^5]. + +These two types appear similar to the existing types Float8E4M3FN and +Float8E5M2[^3], but differ in important ways. ## Details -Both Float8E4M3FNUZ and Float8E5M2FNUZ differ from typical floating point types in their support for NaN, infinities, negative zero, and exponent bias. -The suffix "FNUZ" is derived from these differences. `F` is for "finite" (no infinities), `N` for with special NaN encoding, `UZ` for unsigned zero. + +Both Float8E4M3FNUZ and Float8E5M2FNUZ differ from typical floating point types +in their support for NaN, infinities, negative zero, and exponent bias. The +suffix "FNUZ" is derived from these differences. `F` is for "finite" (no +infinities), `N` for with special NaN encoding, `UZ` for unsigned zero. I propose keeping this naming scheme in StableHLO. ### Float8E4M3FNUZ + 8-bit floating point with 3 bit mantissa. -An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits mantissa. This is not a standard type as defined by IEEE-754, but it follows similar conventions, with the exception that there are no infinity values, no negative zero, and only one NaN representation. This type has the following characteristics: +An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits +mantissa. This is not a standard type as defined by IEEE-754, but it follows +similar conventions, with the exception that there are no infinity values, no +negative zero, and only one NaN representation. 
This type has the following +characteristics: + * bit encoding: S1E4M3 - `0bSEEEEMMM` * exponent bias: 8 * infinities: Not supported -* NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s +* NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set +to all 0s * denormals when exponent is 0 ### Comparison of Float8E4M3FN and Float8E4M3FNUZ -| |Float8E4M3FN |Float8E4M3FNUZ | -|-------------------|----------------------------------------------------------|-----------------------------------------------------------| -|Bias |7 |8 | -|Min Normal Value |`0bS0001000` = -1S * 1.0 * 2-6 |`0bS0001000` = -1S * 1.0 * 2-7 | -|Max Normal Value |`0bS1111110` = -1S * 1.75 * 28 = 448|`0bS1111111` = -1S * 1.875 * 27 = 240| -|Min Subnormal Value|`0bS0000001` = -1S * 0.125 * 2-6 |`0bS0000001` = -1S * 0.125 * 2-7 | -|Max Subnormal Value|`0bS0000111` = -1S * 0.875 * 2-6 |`0bS0000111` = -1S * 0.875 * 2-7 | -|NaN |`0bS1111111` |`0b10000000` | -|Infinity |N/A |N/A | -|-0 |`0b10000000` |N/A | + +| |Float8E4M3FN |Float8E4M3FNUZ | +|-------------------|------------------------------------------------------------------------|-------------------------------------------------------------------------| +|Bias |7 |8 | +|Min Normal Value |`0bS0001000` = -1S $\times$ 1.0 $\times$ 2-6 |`0bS0001000` = -1S $\times$ 1.0 $\times$ 2-7 | +|Max Normal Value |`0bS1111110` = -1S $\times$ 1.75 $\times$ 28 = 448|`0bS1111111` = -1S $\times$ 1.875 $\times$ 27 = 240| +|Min Subnormal Value|`0bS0000001` = -1S $\times$ 0.125 $\times$ 2-6 |`0bS0000001` = -1S $\times$ 0.125 $\times$ 2-7 | +|Max Subnormal Value|`0bS0000111` = -1S $\times$ 0.875 $\times$ 2-6 |`0bS0000111` = -1S $\times$ 0.875 $\times$ 2-7 | +|NaN |`0bS1111111` |`0b10000000` | +|Infinity |N/A |N/A | +|-0 |`0b10000000` |N/A | ### Float8E5M2FNUZ + 8-bit floating point with 2 bit mantissa. -An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits mantissa. 
This is not a standard type as defined by IEEE-754, but it follows similar conventions, with the exception that there are no infinity values, no negative zero, and only one NaN representation. This type has the following characteristics: +An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits +mantissa. This is not a standard type as defined by IEEE-754, but it follows +similar conventions, with the exception that there are no infinity values, no +negative zero, and only one NaN representation. This type has the following +characteristics: + * bit encoding: S1E5M2 - `0bSEEEEEMM` * exponent bias: 16 * infinities: Not supported -* NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s +* NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set +to all 0s * denormals when exponent is 0 ### Comparison of Float8E5M2 and Float8E5M2FNUZ -| |Float8E5M2 |Float8E5M2FNUZ | -|-------------------|--------------------------------------------------------------|-------------------------------------------------------------| -|Bias |15 |16 | -|Min Normal Value |`0bS0000100` = -1S * 1.0 * 2-14 |`0bS0001000` = -1S * 1.0 * 2-15 | -|Max Normal Value |`0bS1111011` = -1S * 1.75 * 215 = 57344 |`0bS1111111` = -1S * 1.75 * 215 = 57344| -|Min Subnormal Value|`0bS0000001` = -1S * 0.25 * 2-14 |`0bS0000001` = -1S * 0.25 * 2-15 | -|Max Subnormal Value|`0bS0000011` = -1S * 0.75 * 2-14 |`0bS0000011` = -1S * 0.75 * 2-15 | -|NaN |`0bS11111MM`, where `MM` is non-zero. 
|`0b10000000` |
-|Infinity |`0bS1111100` |N/A |
-|-0 |`0b10000000` |N/A |
+| |Float8E5M2 |Float8E5M2FNUZ |
+|-------------------|----------------------------------------------------------------------------|---------------------------------------------------------------------------|
+|Bias |15 |16 |
+|Min Normal Value |`0bS0000100` = -1<sup>S</sup> $\times$ 1.0 $\times$ 2<sup>-14</sup> |`0bS0000100` = -1<sup>S</sup> $\times$ 1.0 $\times$ 2<sup>-15</sup> |
+|Max Normal Value |`0bS1111011` = -1<sup>S</sup> $\times$ 1.75 $\times$ 2<sup>15</sup> = 57344 |`0bS1111111` = -1<sup>S</sup> $\times$ 1.75 $\times$ 2<sup>15</sup> = 57344|
+|Min Subnormal Value|`0bS0000001` = -1<sup>S</sup> $\times$ 0.25 $\times$ 2<sup>-14</sup> |`0bS0000001` = -1<sup>S</sup> $\times$ 0.25 $\times$ 2<sup>-15</sup> |
+|Max Subnormal Value|`0bS0000011` = -1<sup>S</sup> $\times$ 0.75 $\times$ 2<sup>-14</sup> |`0bS0000011` = -1<sup>S</sup> $\times$ 0.75 $\times$ 2<sup>-15</sup> |
+|NaN |`0bS11111MM`, where `MM` is non-zero. |`0b10000000` |
+|Infinity |`0bS1111100` |N/A |
+|-0 |`0b10000000` |N/A |
 
 ## Changes in StableHLO
 
-I propose adding these types to StableHLO similar to the previously introduced FP8 types [FP8 RFC](https://github.com/openxla/xla/discussions/22) with some differences.
+I propose adding these types to StableHLO similar to the previously introduced
+FP8 types [FP8 RFC](https://github.com/openxla/xla/discussions/22) with some
+differences.
 
 ### StableHLO Interpreter
 
-To provide a reference implementation, I intend to add support for Float8E4M3FNUZ and Float8E5M2FNUZ in the StableHLO interpreter.
-This will be useful for testing other backends and validating new implementations.
-This will be achieved in two ways:
+To provide a reference implementation, I intend to add support for
+Float8E4M3FNUZ and Float8E5M2FNUZ in the StableHLO interpreter. This will be
+useful for testing other backends and validating new implementations. This will
+be achieved in two ways:
+
 1. Map directly to the appropriate APFloat operation.
 2. Cast up to the appropriate type, use that implementation, cast back down.
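A reference decoder makes the FNUZ bit semantics concrete and gives a ground truth that either interpreter strategy can be checked against. This is an illustrative sketch only (the interpreter itself would map to APFloat, and the function name here is mine):

```python
# Illustrative reference semantics for Float8E4M3FNUZ: bias 8, a single NaN
# at 0b10000000, no infinities, and no negative zero.

def decode_e4m3fnuz(b: int) -> float:
    """Decode one Float8E4M3FNUZ byte (0bSEEEEMMM, exponent bias 8)."""
    if b == 0x80:                      # single NaN: sign 1, all other bits 0
        return float("nan")
    sign = -1.0 if b >> 7 else 1.0
    exp, mant = (b >> 3) & 0xF, b & 0x7
    if exp == 0:                       # denormal: no implicit leading 1
        return sign * (mant / 8.0) * 2.0 ** -7
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 8)

decode_e4m3fnuz(0b01111111)   # max normal: 1.875 * 2**7 = 240.0
decode_e4m3fnuz(0b00001000)   # min normal: 1.0 * 2**-7
decode_e4m3fnuz(0b00000001)   # min subnormal: 0.125 * 2**-7
decode_e4m3fnuz(0b00000000)   # +0; note 0b10000000 is NaN, not -0
```

These values match the Float8E4M3FNUZ column of the comparison table above.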
### Float8E4M3FNUZ and Float8E5M2FNUZ Arithmetic -I intend for Float8E4M3FNUZ and Float8E5M2FNUZ to be types that support the appropriate arithmetic operations, like any other floating point type. -For platforms that don't have hardware support for these types, they may either throw an error and reject the program or cast up to an appropriate higher precision type that is supported, compute the answer, and cast back down. -This is a simple approach that aligns with user expectations of a floating point data type, and is the approach taken by BFloat16. -This also gives backends freedom to exploit any hardware support. +I intend for Float8E4M3FNUZ and Float8E5M2FNUZ to be types that support the +appropriate arithmetic operations, like any other floating point type. For +platforms that don't have hardware support for these types, they may either +throw an error and reject the program or cast up to an appropriate higher +precision type that is supported, compute the answer, and cast back down. -Here's an example of a real JAX program (logging the MLIR) computing a simple dot product in Float8E5M2FNUZ. -Note the answer is slightly "wrong", as expected due to the lower precision. -``` +This is a simple approach that aligns with user expectations of a floating +point data type, and is the approach taken by BFloat16. This also gives +backends freedom to exploit any hardware support. + +Here's an example of a real JAX program (logging the MLIR) computing a simple +dot product in Float8E5M2FNUZ. Note the answer is slightly "wrong", as expected +due to the lower precision. + +```python >>> import jax >>> import jax.numpy as jnp >>> x = jnp.arange(16, dtype=jnp.float8_e5m2fnuz) @@ -98,22 +132,31 @@ Array(1280, dtype=float8_e5m2fnuz) ``` ### Scaling -At the StableHLO-level we won't impose any scaling requirements on users. -Given arithmetic operations can be supported, we leave the scaling choice to the user. 
-I believe this is the correct approach given that FP8 applications are an active area of research. -Graphcore's IPU supports hardware scaling by biasing the hardware's interpretation of the exponent ±32 at runtime[^2]. -This is a backend-specific peephole optimisation that doesn't impact StableHLO. -Other backends may similarly optimise their implementation of these types. +At the StableHLO-level we won't impose any scaling requirements on users. Given +arithmetic operations can be supported, we leave the scaling choice to the +user. I believe this is the correct approach given that FP8 applications are an +active area of research. + +Graphcore's IPU supports hardware scaling by biasing the hardware's +interpretation of the exponent ±32 at runtime[^2]. This is a backend-specific +peephole optimisation that doesn't impact StableHLO. Other backends may +similarly optimise their implementation of these types. ### Testing -Built on the StableHLO interpreter, I intend to introduce tests for all possible operations with Float8E4M3FNUZ and Float8E5M2FNUZ inputs. -This will at a minimum mean adding additional cases to the `interpret_X.mlir` family of tests. + +Built on the StableHLO interpreter, I intend to introduce tests for all +possible operations with Float8E4M3FNUZ and Float8E5M2FNUZ inputs. This will at +a minimum mean adding additional cases to the `interpret_X.mlir` family of +tests. ## Not Included -Given any new data types, we could also consider derived types. -One example would be hypothetical complex number types with the real and imaginary component being constructed from Float8E4M3FNUZ or Float8E5M2FNUZ. -I don't exclude the possibility of this being done in the future, but that is not being proposed here. + +Given any new data types, we could also consider derived types. One example +would be hypothetical complex number types with the real and imaginary +component being constructed from Float8E4M3FNUZ or Float8E5M2FNUZ. 
I don't
+exclude the possibility of this being done in the future, but that is not being
+proposed here.
 
 [^1]: [8-bit Numerical Formats for Deep Neural Networks by Noune et al.](https://arxiv.org/abs/2206.02915)
 [^2]: [Graphcore Tile Vertex ISA 1.3.1 IPU21](https://docs.graphcore.ai/projects/isa-mk2-with-fp8/en/latest/_static/TileVertexISA-IPU21-1.3.1.pdf)

From cd78efaae7f2346104249275cb262a2b4fc48400 Mon Sep 17 00:00:00 2001
From: Jake Hall
Date: Fri, 31 Mar 2023 21:23:10 +0100
Subject: [PATCH 3/4] Document different exponent bias.

---
 rfcs/20230321-fp8_fnuz.md | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/rfcs/20230321-fp8_fnuz.md b/rfcs/20230321-fp8_fnuz.md
index 3701bd1e1b..b6421595b4 100644
--- a/rfcs/20230321-fp8_fnuz.md
+++ b/rfcs/20230321-fp8_fnuz.md
@@ -14,8 +14,13 @@ Float8E5M2[^3], but differ in important ways.
 Both Float8E4M3FNUZ and Float8E5M2FNUZ differ from typical floating point types
 in their support for NaN, infinities, negative zero, and exponent bias. The
 suffix "FNUZ" is derived from these differences. `F` is for "finite" (no
-infinities), `N` for with special NaN encoding, `UZ` for unsigned zero.
-I propose keeping this naming scheme in StableHLO.
+infinities), `N` for with special NaN encoding, `UZ` for unsigned zero. I
+propose keeping this naming scheme in StableHLO, matching LLVM/MLIR.
+
+These changes mean there's an additional exponent value available. To keep
+the same dynamic range as an IEEE-like FP8 type, we bias the exponent one more
+than would be expected given the number of exponent bits (8 for Float8E4M3FNUZ
+and 16 for Float8E5M2FNUZ).
 
 ### Float8E4M3FNUZ

From fd5c9a18de4571b85ca0a7bc7332e9193c443522 Mon Sep 17 00:00:00 2001
From: Jake Hall <60800749+jakeh-gc@users.noreply.github.com>
Date: Wed, 5 Apr 2023 15:16:05 +0100
Subject: [PATCH 4/4] Wording change

---
 rfcs/20230321-fp8_fnuz.md | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/rfcs/20230321-fp8_fnuz.md b/rfcs/20230321-fp8_fnuz.md
index b6421595b4..eccc0634f9 100644
--- a/rfcs/20230321-fp8_fnuz.md
+++ b/rfcs/20230321-fp8_fnuz.md
@@ -18,9 +18,9 @@ infinities), `N` for with special NaN encoding, `UZ` for unsigned zero. I
 propose keeping this naming scheme in StableHLO, matching LLVM/MLIR.
 
 These changes mean there's an additional exponent value available. To keep
-the same dynamic range as an IEEE-like FP8 type, we bias the exponent one more
-than would be expected given the number of exponent bits (8 for Float8E4M3FNUZ
-and 16 for Float8E5M2FNUZ).
+the same dynamic range as an IEEE-like FP8 type, the exponent is biased one
+more than would be expected given the number of exponent bits (8 for
+Float8E4M3FNUZ and 16 for Float8E5M2FNUZ).
 
 ### Float8E4M3FNUZ
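The exponent-bias argument in the last two patches can be sanity-checked against the RFC's comparison tables. A short sketch (the helper name is illustrative): freeing the all-ones exponent from Inf/NaN duty adds a binade at the top, and biasing by one more shifts the format back down, keeping the FNUZ dynamic range close to the IEEE-like types'.

```python
# Check the largest finite values implied by each format's bias and top
# encoding against the tables earlier in the RFC.

def max_normal(exp_field: int, mant: int, mant_bits: int, bias: int) -> float:
    """Value of a normal encoding: (1 + mant/2**mant_bits) * 2**(exp_field - bias)."""
    return (1.0 + mant / 2.0 ** mant_bits) * 2.0 ** (exp_field - bias)

max_normal(15, 6, 3, 7)    # Float8E4M3FN   max, 0bS1111110: 448.0
max_normal(15, 7, 3, 8)    # Float8E4M3FNUZ max, 0bS1111111: 240.0
max_normal(30, 3, 2, 15)   # Float8E5M2     max, 0bS1111011: 57344.0
max_normal(31, 3, 2, 16)   # Float8E5M2FNUZ max, 0bS1111111: 57344.0
```

Note the E5M2 pair keeps an identical maximum, while E4M3FNUZ trades its top binade relative to E4M3FN (which reserves only mantissa 111, not the whole all-ones exponent, for NaN).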