Skip to content

Commit

Permalink
Add Float8E4M3B11FNUZ type support.
Browse files Browse the repository at this point in the history
As proposed in [RFC: E4M3B11FNUZ in XLA](https://github.com/openxla/stablehlo/blob/main/rfcs/20230309-e4m3b11.md) (openxla#1308),
this change adds support for this type to StableHLO.

This includes the type definitions, VHLO (serialization) support, and
interpreter support. The testing approach mirrors the Float8E4M3FNUZ
tests, since that is also a "non-standard" floating point type supported
by StableHLO.
  • Loading branch information
majnemer committed May 2, 2023
1 parent 43e9dda commit c42cc10
Show file tree
Hide file tree
Showing 13 changed files with 2,322 additions and 21 deletions.
4 changes: 3 additions & 1 deletion docs/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ BooleanType ::= 'i1'
IntegerType ::= 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
| 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
| 'bf16' | 'f16' | 'f32' | 'f64'
| 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
ComplexType ::= 'complex' '<' ('f32' | 'f64') '>'
```

Expand All @@ -250,6 +250,8 @@ values of type `tensor<T>`).
* `f8E4M3FNUZ` and `f8E5M2FNUZ` types corresponding to the `E4M3` and `E5M2`
encodings of the FP8 formats described in
[8-bit Numerical Formats for Deep Neural Networks](https://arxiv.org/abs/2206.02915).
* `f8E4M3B11FNUZ` type corresponding to the `E4M3` encoding of the FP8 format described in
[Hybrid 8-bit Floating Point (HFP8) Training and Inference for Deep Neural Networks](https://proceedings.neurips.cc/paper_files/paper/2019/file/65fc9fb4897a89789352e211ca2d398f-Paper.pdf).
* `bf16` type corresponding to the `bfloat16` format described in
[BFloat16: The secret to high performance on Cloud TPUs](https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus).
* `f16`, `f32` and `f64` types corresponding to respectively
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/dialect/Base.td
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def HLO_UInt : UnsignedIntOfWidths<[4, 8, 16, 32, 64]>;
def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>;

def HLO_Float : AnyTypeOf<[F8E4M3FN, F8E5M2, F8E4M3FNUZ, F8E5M2FNUZ,
F16, F32, F64, BF16]>;
F8E4M3B11FNUZ, F16, F32, F64, BF16]>;
def HLO_Float32Or64 : AnyTypeOf<[F32, F64]>;

def HLO_Complex : Complex<AnyTypeOf<[F32, F64]>>;
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/dialect/Version.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class Version {
static FailureOr<Version> fromString(llvm::StringRef versionRef);

/// Return a Version representing the current VHLO dialect version.
static Version getCurrentVersion() { return Version(0, 10, 2); }
static Version getCurrentVersion() { return Version(0, 11, 0); }

/// Return a Version representing the minimum supported VHLO dialect version.
static Version getMinimumVersion() { return Version(0, 9, 0); }
Expand Down
12 changes: 12 additions & 0 deletions stablehlo/dialect/VhloBytecode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,10 @@ enum TypeCode {
/// FloatF8E5M2FNUZV1Type {
/// }
kFloatF8E5M2FNUZV1Type = 28,

/// FloatF8E4M3B11FNUZV1Type {
/// }
kFloatF8E4M3B11FNUZV1Type = 29,
};

} // namespace vhlo_encoding
Expand Down Expand Up @@ -702,6 +706,7 @@ const llvm::fltSemantics &getFloatSemantics(Type type) {
if (type.isa<FloatF32V1Type>()) return APFloat::IEEEsingle();
if (type.isa<FloatF64V1Type>()) return APFloat::IEEEdouble();
if (type.isa<FloatF8E4M3FNUZV1Type>()) return APFloat::Float8E4M3FNUZ();
if (type.isa<FloatF8E4M3B11FNUZV1Type>()) return APFloat::Float8E4M3B11FNUZ();
if (type.isa<FloatF8E4M3FNV1Type>()) return APFloat::Float8E4M3FN();
if (type.isa<FloatF8E5M2FNUZV1Type>()) return APFloat::Float8E5M2FNUZ();
if (type.isa<FloatF8E5M2V1Type>()) return APFloat::Float8E5M2();
Expand Down Expand Up @@ -975,6 +980,8 @@ Type VhloBytecodeInterface::readType(DialectBytecodeReader &reader) const {
return FloatF8E5M2FNUZV1Type::get(getContext());
case vhlo_encoding::kFloatF8E4M3FNUZV1Type:
return FloatF8E4M3FNUZV1Type::get(getContext());
case vhlo_encoding::kFloatF8E4M3B11FNUZV1Type:
return FloatF8E4M3B11FNUZV1Type::get(getContext());
case vhlo_encoding::kFunctionV1Type:
return readFunctionV1Type(reader);
case vhlo_encoding::kIndexV1Type:
Expand Down Expand Up @@ -1063,6 +1070,11 @@ LogicalResult VhloBytecodeInterface::writeType(
return writer.writeVarInt(vhlo_encoding::kFloatF8E4M3FNUZV1Type),
success();
})
.Case([&](FloatF8E4M3B11FNUZV1Type) {
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kFloatF8E4M3B11FNUZV1Type),
success();
})
.Case([&](FloatF8E5M2FNUZV1Type) {
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kFloatF8E5M2FNUZV1Type),
Expand Down
1 change: 1 addition & 0 deletions stablehlo/dialect/VhloDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def VHLO_Dialect : Dialect {
Version log:
0.9.0: Initial stability guarantees.
0.10.0: Introduce `f8E4M3FNUZ` and `f8E5M2FNUZ` types.
0.11.0: Introduce `f8E4M3B11FNUZ` type.
}];

let useDefaultAttributePrinterParser = 0;
Expand Down
6 changes: 6 additions & 0 deletions stablehlo/dialect/VhloTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ void VhloTypeConverter::addBuiltinToVhloConversions() {
addConversion([&](Float8E4M3FNUZType type) {
return FloatF8E4M3FNUZV1Type::get(type.getContext());
});
addConversion([&](Float8E4M3B11FNUZType type) {
return FloatF8E4M3B11FNUZV1Type::get(type.getContext());
});
addConversion([&](Float8E5M2FNUZType type) {
return FloatF8E5M2FNUZV1Type::get(type.getContext());
});
Expand Down Expand Up @@ -152,6 +155,9 @@ void VhloTypeConverter::addVhloToBuiltinConversions() {
addConversion([&](FloatF8E4M3FNUZV1Type type) {
return Float8E4M3FNUZType::get(type.getContext());
});
addConversion([&](FloatF8E4M3B11FNUZV1Type type) {
return Float8E4M3B11FNUZType::get(type.getContext());
});
addConversion([&](FloatF8E5M2FNUZV1Type type) {
return Float8E5M2FNUZType::get(type.getContext());
});
Expand Down
3 changes: 3 additions & 0 deletions stablehlo/dialect/VhloTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ def VHLO_FloatF8E5M2V1 : VHLO_TypeDef<"FloatF8E5M2V1", "f8E5M2_v1", "0.9.0", "cu
// Corresponds to the 'f8E4M3FNUZ' FloatType from the StableHLO spec.
def VHLO_FloatF8E4M3FNUZV1 : VHLO_TypeDef<"FloatF8E4M3FNUZV1", "f8E4M3FNUZ_v1", "0.10.0", "current">;

// Corresponds to the 'f8E4M3B11FNUZ' FloatType from the StableHLO spec.
def VHLO_FloatF8E4M3B11FNUZV1 : VHLO_TypeDef<"FloatF8E4M3B11FNUZV1", "f8E4M3B11FNUZ_v1", "0.11.0", "current">;

// Corresponds to the 'f8E5M2FNUZ' FloatType from the StableHLO spec.
def VHLO_FloatF8E5M2FNUZV1 : VHLO_TypeDef<"FloatF8E5M2FNUZV1", "f8E5M2FNUZ_v1", "0.10.0", "current">;

Expand Down
16 changes: 12 additions & 4 deletions stablehlo/reference/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ Element Tensor::get(const Index &index) const {
return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3FNUZ(),
APInt(8, *elementData)));
}
if (elementType.isFloat8E4M3B11FNUZ()) {
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3B11FNUZ(),
APInt(8, *elementData)));
}
if (elementType.isFloat8E5M2FNUZ()) {
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
return Element(elementType, APFloat(llvm::APFloatBase::Float8E5M2FNUZ(),
Expand Down Expand Up @@ -217,7 +222,8 @@ void Tensor::set(const Index &index, const Element &element) {
getSizeInBytes(elementType) * flattenIndex(getShape(), index);

// Handle floating-point types.
if (elementType.isFloat8E4M3FNUZ() || elementType.isFloat8E5M2FNUZ()) {
if (elementType.isFloat8E4M3FNUZ() || elementType.isFloat8E4M3B11FNUZ() ||
elementType.isFloat8E5M2FNUZ()) {
auto elementData = reinterpret_cast<uint8_t *>(elementPtr);
auto value = element.getFloatValue();
*elementData = (uint8_t)value.bitcastToAPInt().getZExtValue();
Expand Down Expand Up @@ -371,14 +377,16 @@ Tensor makeTensor(DenseElementsAttr attr) {
auto elemType = type.getElementType();

// Handle floating-point types.
if (elemType.isFloat8E4M3FNUZ() || elemType.isFloat8E5M2FNUZ()) {
if (elemType.isFloat8E4M3FNUZ() || elemType.isFloat8E4M3B11FNUZ() ||
elemType.isFloat8E5M2FNUZ()) {
auto floatValues = llvm::to_vector(llvm::map_range(
attr.getValues<APFloat>(), [&](APFloat value) -> uint8_t {
return value.bitcastToAPInt().getZExtValue();
}));

// For both f8E4M3FNUZ and f8E5M2FNUZ floating-point types, we use uint8_t
// as their storage type because there are no builtin types for those.
// For f8E4M3FNUZ, f8E4M3B11FNUZ, and f8E5M2FNUZ floating-point types, we
// use uint8_t as their storage type because there are no builtin types for
// those.
return Tensor(type, HeapAsmResourceBlob::allocateAndCopyInferAlign<uint8_t>(
floatValues));
}
Expand Down
5 changes: 3 additions & 2 deletions stablehlo/reference/Types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ bool isSupportedIntegerType(Type type) {
}

bool isSupportedFloatType(Type type) {
return type.isFloat8E4M3FNUZ() || type.isFloat8E5M2FNUZ() || type.isF16() ||
type.isBF16() || type.isF32() || type.isF64();
return type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ() ||
type.isFloat8E5M2FNUZ() || type.isF16() || type.isBF16() ||
type.isF32() || type.isF64();
}

bool isSupportedComplexType(Type type) {
Expand Down
7 changes: 7 additions & 0 deletions stablehlo/tests/interpret_constant.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ func.func @constant_op_test_f8_e4m3_fnuz() {

// -----

// Verifies interpreter constant evaluation for the f8E4M3B11FNUZ type.
// Mirrors the f8E4M3FNUZ test above. The inputs exercise, in order:
//  - signed zero: -0.0 is expected to compare equal to 0.0 (the FNUZ
//    formats have no negative zero encoding);
//  - exactly representable values (1.0, 0.125);
//  - values requiring rounding (0.1; 3.1415 rounds to 3.25 with the
//    format's 3 mantissa bits);
//  - raw bit patterns 0x7F/0xFF, expected to decode to the
//    largest-magnitude finite values +/-30.0;
//  - raw bit patterns 0x01/0x81, expected to decode to the
//    smallest-magnitude subnormals +/-0.0001220703125 (2^-13).
func.func @constant_op_test_f8_e4m3b11_fnuz() {
%0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x7F, 0xFF, 0x01, 0x81]> : tensor<10xf8E4M3B11FNUZ>
check.expect_eq_const %0, dense<[0.0, 0.0, 1.0, 0.125, 0.1, 3.25, 30.0, -30.0, 0.0001220703125, -0.0001220703125]> : tensor<10xf8E4M3B11FNUZ>
func.return
}

// -----
func.func @constant_op_test_f8_e5m2_fnuz() {
%0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x7F, 0xFF, 0x01, 0x81]> : tensor<10xf8E5M2FNUZ>
check.expect_eq_const %0, dense<[0.0, 0.0, 1.0, 0.125, 0.1, 3.0, 57344.0, -57344.0, 7.62939e-06, -7.62939e-06]> : tensor<10xf8E5M2FNUZ>
Expand Down
Loading

0 comments on commit c42cc10

Please sign in to comment.