Add Float8E4M3FNUZ and Float8E5M2FNUZ type support. (#1379)
As proposed in [RFC: Float8E4M3FNUZ and
Float8E5M2FNUZ](https://github.com/openxla/stablehlo/blob/main/rfcs/20230321-fp8_fnuz.md)
(#1342), this change adds support for these types to StableHLO.

This includes the type definitions, VHLO serialization support, and interpreter
support. The testing approach mirrors the BFloat16 tests, since BFloat16 is also
a "non-standard" floating-point type supported by StableHLO.
jakeh-gc committed Apr 19, 2023
1 parent f2666c0 commit b140075
Showing 18 changed files with 166 additions and 24 deletions.
2 changes: 1 addition & 1 deletion docs/bytecode.md
@@ -21,7 +21,7 @@ this format, which we successfully did for StableHLO

Portable artifacts can be created using either the `stablehlo-translate` tool,
or directly in C++ or Python APIs. Serialization needs a target version of
StableHLO to write an artifact (the current version is 0.9.0). Deserialization
StableHLO to write an artifact (the current version is 0.10.0). Deserialization
uses the current version of StableHLO to read an artifact.
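
Because `f8E4M3FNUZ` and `f8E5M2FNUZ` are introduced at 0.10.0, an artifact
that uses them must target at least that version. A rough sketch of the C++
path, assuming the `serializePortableArtifact` entry point declared in
`stablehlo/dialect/Serialization.h` (the exact header and signature are not
shown in this diff):

```c++
// Sketch under the assumptions noted above: write a portable artifact that may
// contain the new FNUZ element types by targeting StableHLO 0.10.0.
#include "mlir/IR/BuiltinOps.h"
#include "stablehlo/dialect/Serialization.h"
#include "llvm/Support/raw_ostream.h"

mlir::LogicalResult writeArtifact(mlir::ModuleOp module, llvm::raw_ostream &os) {
  // Targeting 0.9.0 instead could not represent the FNUZ types added here.
  return mlir::stablehlo::serializePortableArtifact(
      module, /*targetVersion=*/"0.10.0", os);
}
```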

### Using the `stablehlo-translate` tool
2 changes: 1 addition & 1 deletion docs/compatibility.md
@@ -7,7 +7,7 @@ that StableHLO provides, based on the process established in

## Versions

The current version of StableHLO is **0.9.0**.
The current version of StableHLO is **0.10.0**.

In the 0.x.x series, StableHLO is providing limited compatibility guarantees
as described in the Guarantees section. In H2 2023, we are planning to release
6 changes: 5 additions & 1 deletion docs/spec.md
@@ -226,7 +226,8 @@ ElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
| 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'bf16' | 'f16' | 'f32' | 'f64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
| 'bf16' | 'f16' | 'f32' | 'f64'
ComplexType ::= 'complex' '<' ('f32' | 'f64') '>'
```

@@ -246,6 +247,9 @@ values of type `tensor<T>`).
* `f8E4M3FN` and `f8E5M2` types corresponding to respectively the
`E4M3` and `E5M2` encodings of the FP8 format described in
[FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433).
* `f8E4M3FNUZ` and `f8E5M2FNUZ` types corresponding to the `E4M3` and `E5M2`
encodings of the FP8 formats described in
[8-bit Numerical Formats for Deep Neural Networks](https://arxiv.org/abs/2206.02915).
* `bf16` type corresponding to the `bfloat16` format described in
[BFloat16: The secret to high performance on Cloud TPUs](https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus).
* `f16`, `f32` and `f64` types corresponding to respectively
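
For context only (not part of this change): the new element types are ordinary
MLIR builtin float types, so they compose with tensor types exactly like the
existing FP8 types. A minimal C++ sketch:

```c++
// Sketch: constructing tensors of the new FP8 element types with MLIR's
// builtin type API, mirroring the 'f8E4M3FNUZ' / 'f8E5M2FNUZ' syntax above.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

mlir::RankedTensorType makeFnuzTensor(mlir::MLIRContext &ctx) {
  mlir::Type e4m3fnuz = mlir::Float8E4M3FNUZType::get(&ctx); // 'f8E4M3FNUZ'
  mlir::Type e5m2fnuz = mlir::Float8E5M2FNUZType::get(&ctx); // 'f8E5M2FNUZ'
  (void)e5m2fnuz;
  // tensor<10xf8E4M3FNUZ>, as used in the interpreter tests below.
  return mlir::RankedTensorType::get({10}, e4m3fnuz);
}
```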
3 changes: 2 additions & 1 deletion stablehlo/dialect/Base.td
@@ -33,7 +33,8 @@ def HLO_SInt : SignlessIntOfWidths<[4, 8, 16, 32, 64]>;
def HLO_UInt : UnsignedIntOfWidths<[4, 8, 16, 32, 64]>;
def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>;

def HLO_Float : AnyTypeOf<[F8E4M3FN, F8E5M2, F16, F32, F64, BF16]>;
def HLO_Float : AnyTypeOf<[F8E4M3FN, F8E5M2, F8E4M3FNUZ, F8E5M2FNUZ,
F16, F32, F64, BF16]>;
def HLO_Float32Or64 : AnyTypeOf<[F32, F64]>;

def HLO_Complex : Complex<AnyTypeOf<[F32, F64]>>;
2 changes: 1 addition & 1 deletion stablehlo/dialect/Version.h
@@ -38,7 +38,7 @@ class Version {
static FailureOr<Version> fromString(llvm::StringRef versionRef);

/// Return a Version representing the current VHLO dialect version.
static Version getCurrentVersion() { return Version(0, 9, 0); }
static Version getCurrentVersion() { return Version(0, 10, 0); }

/// Return a Version representing the minimum supported VHLO dialect version.
static Version getMinimumVersion() { return Version(0, 9, 0); }
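
A small usage sketch of the `Version` API touched here (illustrative, not part
of the commit; it assumes the class lives in the `mlir::vhlo` namespace):

```c++
// Sketch: parse a requested target version, falling back to the current VHLO
// version (0.10.0 after this change) if the string does not parse.
#include "stablehlo/dialect/Version.h"

mlir::vhlo::Version pickTargetVersion(llvm::StringRef requested) {
  auto parsed = mlir::vhlo::Version::fromString(requested);
  if (mlir::failed(parsed)) return mlir::vhlo::Version::getCurrentVersion();
  return *parsed;
}
```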
24 changes: 24 additions & 0 deletions stablehlo/dialect/VhloBytecode.cpp
@@ -316,6 +316,14 @@ enum TypeCode {
/// WitnessV1Type {
/// }
kWitnessV1Type = 26,

/// FloatF8E4M3FNUZV1Type {
/// }
kFloatF8E4M3FNUZV1Type = 27,

/// FloatF8E5M2FNUZV1Type {
/// }
kFloatF8E5M2FNUZV1Type = 28,
};

} // namespace vhlo_encoding
@@ -693,7 +701,9 @@ const llvm::fltSemantics &getFloatSemantics(Type type) {
if (type.isa<FloatF16V1Type>()) return APFloat::IEEEhalf();
if (type.isa<FloatF32V1Type>()) return APFloat::IEEEsingle();
if (type.isa<FloatF64V1Type>()) return APFloat::IEEEdouble();
if (type.isa<FloatF8E4M3FNUZV1Type>()) return APFloat::Float8E4M3FNUZ();
if (type.isa<FloatF8E4M3FNV1Type>()) return APFloat::Float8E4M3FN();
if (type.isa<FloatF8E5M2FNUZV1Type>()) return APFloat::Float8E5M2FNUZ();
if (type.isa<FloatF8E5M2V1Type>()) return APFloat::Float8E5M2();
llvm::report_fatal_error("unsupported floating-point type");
}
@@ -961,6 +971,10 @@ Type VhloBytecodeInterface::readType(DialectBytecodeReader &reader) const {
return FloatF8E5M2V1Type::get(getContext());
case vhlo_encoding::kFloatF8E4M3FNV1Type:
return FloatF8E4M3FNV1Type::get(getContext());
case vhlo_encoding::kFloatF8E5M2FNUZV1Type:
return FloatF8E5M2FNUZV1Type::get(getContext());
case vhlo_encoding::kFloatF8E4M3FNUZV1Type:
return FloatF8E4M3FNUZV1Type::get(getContext());
case vhlo_encoding::kFunctionV1Type:
return readFunctionV1Type(reader);
case vhlo_encoding::kIndexV1Type:
@@ -1044,6 +1058,16 @@ LogicalResult VhloBytecodeInterface::writeType(
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kFloatF8E5M2V1Type), success();
})
.Case([&](FloatF8E4M3FNUZV1Type) {
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kFloatF8E4M3FNUZV1Type),
success();
})
.Case([&](FloatF8E5M2FNUZV1Type) {
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kFloatF8E5M2FNUZV1Type),
success();
})
.Case([&](IndexV1Type) {
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kIndexV1Type), success();
1 change: 1 addition & 0 deletions stablehlo/dialect/VhloDialect.td
@@ -28,6 +28,7 @@ def VHLO_Dialect : Dialect {

Version log:
0.9.0: Initial stability guarantees.
0.10.0: Introduce `f8E4M3FNUZ` and `f8E5M2FNUZ` types.
}];

let useDefaultAttributePrinterParser = 0;
12 changes: 12 additions & 0 deletions stablehlo/dialect/VhloTypes.cpp
@@ -77,6 +77,12 @@ void VhloTypeConverter::addBuiltinToVhloConversions() {
addConversion([&](Float8E5M2Type type) {
return FloatF8E5M2V1Type::get(type.getContext());
});
addConversion([&](Float8E4M3FNUZType type) {
return FloatF8E4M3FNUZV1Type::get(type.getContext());
});
addConversion([&](Float8E5M2FNUZType type) {
return FloatF8E5M2FNUZV1Type::get(type.getContext());
});
addConversion([&](FunctionType type) -> Type {
SmallVector<Type> convertedInputs;
SmallVector<Type> convertedResults;
@@ -143,6 +149,12 @@ void VhloTypeConverter::addVhloToBuiltinConversions() {
addConversion([&](FloatF8E5M2V1Type type) {
return Float8E5M2Type::get(type.getContext());
});
addConversion([&](FloatF8E4M3FNUZV1Type type) {
return Float8E4M3FNUZType::get(type.getContext());
});
addConversion([&](FloatF8E5M2FNUZV1Type type) {
return Float8E5M2FNUZType::get(type.getContext());
});
addConversion([&](FunctionV1Type type) -> Type {
SmallVector<Type> convertedInputs;
SmallVector<Type> convertedOutputs;
6 changes: 6 additions & 0 deletions stablehlo/dialect/VhloTypes.td
@@ -86,6 +86,12 @@ def VHLO_FloatF8E4M3FNV1 : VHLO_TypeDef<"FloatF8E4M3FNV1", "f8E4M3FN_v1", "0.9.0", "current">;
// Corresponds to the 'f8E5M2' FloatType from the StableHLO spec.
def VHLO_FloatF8E5M2V1 : VHLO_TypeDef<"FloatF8E5M2V1", "f8E5M2_v1", "0.9.0", "current">;

// Corresponds to the 'f8E4M3FNUZ' FloatType from the StableHLO spec.
def VHLO_FloatF8E4M3FNUZV1 : VHLO_TypeDef<"FloatF8E4M3FNUZV1", "f8E4M3FNUZ_v1", "0.10.0", "current">;

// Corresponds to the 'f8E5M2FNUZ' FloatType from the StableHLO spec.
def VHLO_FloatF8E5M2FNUZV1 : VHLO_TypeDef<"FloatF8E5M2FNUZV1", "f8E5M2FNUZ_v1", "0.10.0", "current">;

// Corresponds to FunctionType from the StableHLO spec.
def VHLO_FunctionV1 : VHLO_TypeDef<"FunctionV1", "func_v1", "0.9.0", "current"> {
let parameters = (ins
2 changes: 1 addition & 1 deletion stablehlo/integrations/python/tests/stablehlo.py
@@ -210,7 +210,7 @@ def test_api_version():
@run
def test_serialization_apis():
curr_version = stablehlo.get_current_version()
assert curr_version == "0.9.0"
assert curr_version == "0.10.0"

ASM = """
func.func @test(%arg0: tensor<2xf32>) -> tensor<2xf32> {
29 changes: 29 additions & 0 deletions stablehlo/reference/Tensor.cpp
@@ -100,6 +100,16 @@ Element Tensor::get(const Index &index) const {
getSizeInBytes(elementType) * flattenIndex(getShape(), index);

// Handle floating-point types.
if (elementType.isFloat8E4M3FNUZ()) {
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3FNUZ(),
APInt(8, *elementData)));
}
if (elementType.isFloat8E5M2FNUZ()) {
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
return Element(elementType, APFloat(llvm::APFloatBase::Float8E5M2FNUZ(),
APInt(8, *elementData)));
}
if (elementType.isF16()) {
auto elementData = reinterpret_cast<const uint16_t *>(elementPtr);
return Element(elementType, APFloat(llvm::APFloatBase::IEEEhalf(),
@@ -207,6 +217,13 @@ void Tensor::set(const Index &index, const Element &element) {
getSizeInBytes(elementType) * flattenIndex(getShape(), index);

// Handle floating-point types.
if (elementType.isFloat8E4M3FNUZ() || elementType.isFloat8E5M2FNUZ()) {
auto elementData = reinterpret_cast<uint8_t *>(elementPtr);
auto value = element.getFloatValue();
*elementData = (uint8_t)value.bitcastToAPInt().getZExtValue();
return;
}

if (elementType.isF16() || elementType.isBF16()) {
auto elementData = reinterpret_cast<uint16_t *>(elementPtr);
auto value = element.getFloatValue();
@@ -354,6 +371,18 @@ Tensor makeTensor(DenseElementsAttr attr) {
auto elemType = type.getElementType();

// Handle floating-point types.
if (elemType.isFloat8E4M3FNUZ() || elemType.isFloat8E5M2FNUZ()) {
auto floatValues = llvm::to_vector(llvm::map_range(
attr.getValues<APFloat>(), [&](APFloat value) -> uint8_t {
return value.bitcastToAPInt().getZExtValue();
}));

// For both f8E4M3FNUZ and f8E5M2FNUZ floating-point types, we use uint8_t
// as their storage type because there are no builtin types for those.
return Tensor(type, HeapAsmResourceBlob::allocateAndCopyInferAlign<uint8_t>(
floatValues));
}

if (elemType.isF16() || elemType.isBF16()) {
auto floatValues = llvm::to_vector(llvm::map_range(
attr.getValues<APFloat>(), [&](APFloat value) -> uint16_t {
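
The get/set paths above pack each FNUZ value into a single byte. A minimal
sketch of that round trip through LLVM's APFloat/APInt (illustrative, not part
of the commit):

```c++
// Sketch: round-tripping an f8E4M3FNUZ value through the uint8_t storage the
// interpreter uses, via the same APFloat/APInt calls as Tensor::get/set.
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include <cassert>
#include <cstdint>

int main() {
  using llvm::APFloat;
  using llvm::APInt;
  const llvm::fltSemantics &sem = APFloat::Float8E4M3FNUZ();

  // Pack: round a double into the FNUZ format and keep its 8 raw bits.
  bool losesInfo = false;
  APFloat value(3.1415);
  value.convert(sem, APFloat::rmNearestTiesToEven, &losesInfo); // rounds to 3.25
  uint8_t storage = static_cast<uint8_t>(value.bitcastToAPInt().getZExtValue());

  // Unpack: rebuild the APFloat from the stored byte plus the format semantics.
  APFloat restored(sem, APInt(8, storage));
  assert(restored.bitwiseIsEqual(value));
  return 0;
}
```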
3 changes: 2 additions & 1 deletion stablehlo/reference/Types.cpp
@@ -44,7 +44,8 @@ bool isSupportedIntegerType(Type type) {
}

bool isSupportedFloatType(Type type) {
return type.isF16() || type.isBF16() || type.isF32() || type.isF64();
return type.isFloat8E4M3FNUZ() || type.isFloat8E5M2FNUZ() || type.isF16() ||
type.isBF16() || type.isF32() || type.isF64();
}

bool isSupportedComplexType(Type type) {
16 changes: 16 additions & 0 deletions stablehlo/tests/interpret_constant.mlir
@@ -80,6 +80,22 @@ func.func @constant_op_test_ui64() {

// -----

func.func @constant_op_test_f8_e4m3_fnuz() {
%0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x7F, 0xFF, 0x01, 0x81]> : tensor<10xf8E4M3FNUZ>
check.expect_eq_const %0, dense<[0.0, 0.0, 1.0, 0.125, 0.1, 3.25, 240.0, -240.0, 0.000976562, -0.000976562]> : tensor<10xf8E4M3FNUZ>
func.return
}

// -----

func.func @constant_op_test_f8_e5m2_fnuz() {
%0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x7F, 0xFF, 0x01, 0x81]> : tensor<10xf8E5M2FNUZ>
check.expect_eq_const %0, dense<[0.0, 0.0, 1.0, 0.125, 0.1, 3.0, 57344.0, -57344.0, 7.62939e-06, -7.62939e-06]> : tensor<10xf8E5M2FNUZ>
func.return
}

// -----

func.func @constant_op_test_bf16() {
%0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.140630, 0x7F80, 0xFF80, 0x7FFF, 0x0001, 0x8001]> : tensor<11xbf16>
check.expect_almost_eq_const %0, dense<[0.000000e+00, -0.000000e+00, 1.000000e+00, 1.250000e-01, 1.000980e-01, 3.140630e+00, 0x7F80, 0xFF80, 0x7FFF, 9.183550e-41, -9.183550e-41]> : tensor<11xbf16>
