diff --git a/docs/bytecode.md b/docs/bytecode.md
index f1c4c5a291..5cec0b1155 100644
--- a/docs/bytecode.md
+++ b/docs/bytecode.md
@@ -21,7 +21,7 @@ this format, which we successfully did for StableHLO
 
 Portable artifacts can be created using either the `stablehlo-translate` tool,
 or directly in C++ or Python APIs. Serialization needs a target version of
-StableHLO to write an artifact (the current version is 0.9.0). Deserialization
+StableHLO to write an artifact (the current version is 0.10.0). Deserialization
 uses the current version of StableHLO to read an artifact.
 
 ### Using the `stablehlo-translate` tool
diff --git a/docs/compatibility.md b/docs/compatibility.md
index dd252b58dc..ea4bec6e3a 100644
--- a/docs/compatibility.md
+++ b/docs/compatibility.md
@@ -7,7 +7,7 @@ that StableHLO provides, based on the process established in
 
 ## Versions
 
-The current version of StableHLO is **0.9.0**.
+The current version of StableHLO is **0.10.0**.
 
 In the 0.x.x series, StableHLO is providing limited compatibility guarantees
 as described in the Guarantees section. In H2 2023, we are planning to release
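For context on what the version bump enables: a module that uses the new types can be serialized at the 0.10.0 target and read back with the current version. A minimal sketch, using the `--serialize`, `--target`, and `--deserialize` flags documented in `docs/bytecode.md` (the file names here are illustrative):

```mlir
// stablehlo-translate --serialize add_fnuz.mlir --target=0.10.0 > add_fnuz.mlir.bc
// stablehlo-translate --deserialize add_fnuz.mlir.bc

func.func @add(%arg0: tensor<2xf8E4M3FNUZ>) -> tensor<2xf8E4M3FNUZ> {
  %0 = stablehlo.add %arg0, %arg0 : tensor<2xf8E4M3FNUZ>
  func.return %0 : tensor<2xf8E4M3FNUZ>
}
```

Targeting 0.9.0 instead should fail for this module, since the FNUZ types do not exist in that version; the new `vhlo_to_version_downgrade_invalid.0_9_0.mlir` test at the end of this change pins down exactly that behavior.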
diff --git a/docs/spec.md b/docs/spec.md
index f9c7c2f6fc..23bc50af3b 100644
--- a/docs/spec.md
+++ b/docs/spec.md
@@ -226,7 +226,8 @@ ElementType ::= BooleanType | IntegerType | FloatType | ComplexType
 BooleanType ::= 'i1'
 IntegerType ::= 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
             | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
-FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'bf16' | 'f16' | 'f32' | 'f64'
+FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
+            | 'bf16' | 'f16' | 'f32' | 'f64'
 ComplexType ::= 'complex' '<' ('f32' | 'f64') '>'
 ```
 
@@ -246,6 +247,9 @@ values of type `tensor`).
   * `f8E4M3FN` and `f8E5M2` types corresponding to respectively the `E4M3` and
     `E5M2` encodings of the FP8 format described in
     [FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433).
+  * `f8E4M3FNUZ` and `f8E5M2FNUZ` types corresponding to the `E4M3` and `E5M2`
+    encodings of the FP8 formats described in
+    [8-bit Numerical Formats for Deep Neural Networks](https://arxiv.org/abs/2206.02915).
   * `bf16` type corresponding to the `bfloat16` format described in
     [BFloat16: The secret to high performance on Cloud TPUs](https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus).
   * `f16`, `f32` and `f64` types corresponding to respectively
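A quick orientation on the two FP8 families now in the spec: unlike `f8E4M3FN` and `f8E5M2`, the FNUZ encodings have no infinities, no negative zero, and a single NaN bit pattern (`0x80`), and they shift the exponent bias up by one (8 for `f8E4M3FNUZ`, 16 for `f8E5M2FNUZ`). The extreme values exercised by the interpreter tests later in this change fall out directly:

```latex
\begin{aligned}
x &= (-1)^s \cdot 2^{\,e-\mathrm{bias}} \cdot \bigl(1 + m/2^{M}\bigr)
    \qquad \text{(normal numbers)} \\
\texttt{0x7F}_{\mathrm{E4M3FNUZ}} &= 2^{15-8}\bigl(1+\tfrac{7}{8}\bigr) = 240, &
\texttt{0x7F}_{\mathrm{E5M2FNUZ}} &= 2^{31-16}\bigl(1+\tfrac{3}{4}\bigr) = 57344, \\
\texttt{0x01}_{\mathrm{E4M3FNUZ}} &= 2^{1-8} \cdot 2^{-3} = 2^{-10}, &
\texttt{0x01}_{\mathrm{E5M2FNUZ}} &= 2^{1-16} \cdot 2^{-2} = 2^{-17}.
\end{aligned}
```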
diff --git a/stablehlo/dialect/Base.td b/stablehlo/dialect/Base.td
index 8e0ec7ae07..8934326b50 100644
--- a/stablehlo/dialect/Base.td
+++ b/stablehlo/dialect/Base.td
@@ -33,7 +33,8 @@ def HLO_SInt : SignlessIntOfWidths<[4, 8, 16, 32, 64]>;
 def HLO_UInt : UnsignedIntOfWidths<[4, 8, 16, 32, 64]>;
 def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>;
 
-def HLO_Float : AnyTypeOf<[F8E4M3FN, F8E5M2, F16, F32, F64, BF16]>;
+def HLO_Float : AnyTypeOf<[F8E4M3FN, F8E5M2, F8E4M3FNUZ, F8E5M2FNUZ,
+                           F16, F32, F64, BF16]>;
 def HLO_Float32Or64 : AnyTypeOf<[F32, F64]>;
 
 def HLO_Complex : Complex<AnyTypeOf<[F32, F64]>>;
diff --git a/stablehlo/dialect/Version.h b/stablehlo/dialect/Version.h
index 31d24ab713..a17a05f43f 100644
--- a/stablehlo/dialect/Version.h
+++ b/stablehlo/dialect/Version.h
@@ -38,7 +38,7 @@ class Version {
   static FailureOr<Version> fromString(llvm::StringRef versionRef);
 
   /// Return a Version representing the current VHLO dialect version.
-  static Version getCurrentVersion() { return Version(0, 9, 0); }
+  static Version getCurrentVersion() { return Version(0, 10, 0); }
 
   /// Return a Version representing the minimum supported VHLO dialect version.
   static Version getMinimumVersion() { return Version(0, 9, 0); }
diff --git a/stablehlo/dialect/VhloBytecode.cpp b/stablehlo/dialect/VhloBytecode.cpp
index de59de35f9..4e7c873b10 100644
--- a/stablehlo/dialect/VhloBytecode.cpp
+++ b/stablehlo/dialect/VhloBytecode.cpp
@@ -316,6 +316,14 @@ enum TypeCode {
   ///   WitnessV1Type {
   ///   }
   kWitnessV1Type = 26,
+
+  ///   FloatF8E4M3FNUZV1Type {
+  ///   }
+  kFloatF8E4M3FNUZV1Type = 27,
+
+  ///   FloatF8E5M2FNUZV1Type {
+  ///   }
+  kFloatF8E5M2FNUZV1Type = 28,
 };
 
 }  // namespace vhlo_encoding
@@ -693,7 +701,9 @@ const llvm::fltSemantics &getFloatSemantics(Type type) {
   if (type.isa<FloatF16V1Type>()) return APFloat::IEEEhalf();
   if (type.isa<FloatF32V1Type>()) return APFloat::IEEEsingle();
   if (type.isa<FloatF64V1Type>()) return APFloat::IEEEdouble();
+  if (type.isa<FloatF8E4M3FNUZV1Type>()) return APFloat::Float8E4M3FNUZ();
   if (type.isa<FloatF8E4M3FNV1Type>()) return APFloat::Float8E4M3FN();
+  if (type.isa<FloatF8E5M2FNUZV1Type>()) return APFloat::Float8E5M2FNUZ();
   if (type.isa<FloatF8E5M2V1Type>()) return APFloat::Float8E5M2();
   llvm::report_fatal_error("unsupported floating-point type");
 }
@@ -961,6 +971,10 @@ Type VhloBytecodeInterface::readType(DialectBytecodeReader &reader) const {
       return FloatF8E5M2V1Type::get(getContext());
     case vhlo_encoding::kFloatF8E4M3FNV1Type:
       return FloatF8E4M3FNV1Type::get(getContext());
+    case vhlo_encoding::kFloatF8E5M2FNUZV1Type:
+      return FloatF8E5M2FNUZV1Type::get(getContext());
+    case vhlo_encoding::kFloatF8E4M3FNUZV1Type:
+      return FloatF8E4M3FNUZV1Type::get(getContext());
     case vhlo_encoding::kFunctionV1Type:
       return readFunctionV1Type(reader);
     case vhlo_encoding::kIndexV1Type:
@@ -1044,6 +1058,16 @@ LogicalResult VhloBytecodeInterface::writeType(
         LOG_WRITE_CALL;
         return writer.writeVarInt(vhlo_encoding::kFloatF8E5M2V1Type), success();
       })
+      .Case([&](FloatF8E4M3FNUZV1Type) {
+        LOG_WRITE_CALL;
+        return writer.writeVarInt(vhlo_encoding::kFloatF8E4M3FNUZV1Type),
+               success();
+      })
+      .Case([&](FloatF8E5M2FNUZV1Type) {
+        LOG_WRITE_CALL;
+        return writer.writeVarInt(vhlo_encoding::kFloatF8E5M2FNUZV1Type),
+               success();
+      })
      .Case([&](IndexV1Type) {
        LOG_WRITE_CALL;
        return writer.writeVarInt(vhlo_encoding::kIndexV1Type), success();
diff --git a/stablehlo/dialect/VhloDialect.td b/stablehlo/dialect/VhloDialect.td
index 4adeafac9d..7261c7bbb2 100644
--- a/stablehlo/dialect/VhloDialect.td
+++ b/stablehlo/dialect/VhloDialect.td
@@ -28,6 +28,7 @@ def VHLO_Dialect : Dialect {
 
     Version log:
       0.9.0: Initial stability guarantees.
+      0.10.0: Introduce `f8E4M3FNUZ` and `f8E5M2FNUZ` types.
   }];
 
   let useDefaultAttributePrinterParser = 0;
diff --git a/stablehlo/dialect/VhloTypes.cpp b/stablehlo/dialect/VhloTypes.cpp
index bbfd4aed2d..36cf7e9414 100644
--- a/stablehlo/dialect/VhloTypes.cpp
+++ b/stablehlo/dialect/VhloTypes.cpp
@@ -77,6 +77,12 @@ void VhloTypeConverter::addBuiltinToVhloConversions() {
   addConversion([&](Float8E5M2Type type) {
     return FloatF8E5M2V1Type::get(type.getContext());
   });
+  addConversion([&](Float8E4M3FNUZType type) {
+    return FloatF8E4M3FNUZV1Type::get(type.getContext());
+  });
+  addConversion([&](Float8E5M2FNUZType type) {
+    return FloatF8E5M2FNUZV1Type::get(type.getContext());
+  });
   addConversion([&](FunctionType type) -> Type {
     SmallVector<Type> convertedInputs;
     SmallVector<Type> convertedResults;
@@ -143,6 +149,12 @@ void VhloTypeConverter::addVhloToBuiltinConversions() {
   addConversion([&](FloatF8E5M2V1Type type) {
     return Float8E5M2Type::get(type.getContext());
   });
+  addConversion([&](FloatF8E4M3FNUZV1Type type) {
+    return Float8E4M3FNUZType::get(type.getContext());
+  });
+  addConversion([&](FloatF8E5M2FNUZV1Type type) {
+    return Float8E5M2FNUZType::get(type.getContext());
+  });
   addConversion([&](FunctionV1Type type) -> Type {
     SmallVector<Type> convertedInputs;
     SmallVector<Type> convertedOutputs;
diff --git a/stablehlo/dialect/VhloTypes.td b/stablehlo/dialect/VhloTypes.td
index b4c0d8ca84..c3f8d69c10 100644
--- a/stablehlo/dialect/VhloTypes.td
+++ b/stablehlo/dialect/VhloTypes.td
@@ -86,6 +86,12 @@ def VHLO_FloatF8E4M3FNV1 : VHLO_TypeDef<"FloatF8E4M3FNV1", "f8E4M3FN_v1", "0.9.0
 // Corresponds to the 'f8E5M2' FloatType from the StableHLO spec.
 def VHLO_FloatF8E5M2V1 : VHLO_TypeDef<"FloatF8E5M2V1", "f8E5M2_v1", "0.9.0", "current">;
 
+// Corresponds to the 'f8E4M3FNUZ' FloatType from the StableHLO spec.
+def VHLO_FloatF8E4M3FNUZV1 : VHLO_TypeDef<"FloatF8E4M3FNUZV1", "f8E4M3FNUZ_v1", "0.10.0", "current">;
+
+// Corresponds to the 'f8E5M2FNUZ' FloatType from the StableHLO spec.
+def VHLO_FloatF8E5M2FNUZV1 : VHLO_TypeDef<"FloatF8E5M2FNUZV1", "f8E5M2FNUZ_v1", "0.10.0", "current">;
+
 // Corresponds to FunctionType from the StableHLO spec.
 def VHLO_FunctionV1 : VHLO_TypeDef<"FunctionV1", "func_v1", "0.9.0", "current"> {
   let parameters = (ins
diff --git a/stablehlo/integrations/python/tests/stablehlo.py b/stablehlo/integrations/python/tests/stablehlo.py
index b5fc5ca99d..98a6367a53 100644
--- a/stablehlo/integrations/python/tests/stablehlo.py
+++ b/stablehlo/integrations/python/tests/stablehlo.py
@@ -210,7 +210,7 @@ def test_api_version():
 @run
 def test_serialization_apis():
   curr_version = stablehlo.get_current_version()
-  assert curr_version == "0.9.0"
+  assert curr_version == "0.10.0"
 
   ASM = """
   func.func @test(%arg0: tensor<2xf32>) -> tensor<2xf32> {
diff --git a/stablehlo/reference/Tensor.cpp b/stablehlo/reference/Tensor.cpp
index 4c3c3ac5d3..d2758db2a8 100644
--- a/stablehlo/reference/Tensor.cpp
+++ b/stablehlo/reference/Tensor.cpp
@@ -100,6 +100,16 @@ Element Tensor::get(const Index &index) const {
       getSizeInBytes(elementType) * flattenIndex(getShape(), index);
 
   // Handle floating-point types.
+  if (elementType.isFloat8E4M3FNUZ()) {
+    auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
+    return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3FNUZ(),
+                                        APInt(8, *elementData)));
+  }
+  if (elementType.isFloat8E5M2FNUZ()) {
+    auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
+    return Element(elementType, APFloat(llvm::APFloatBase::Float8E5M2FNUZ(),
+                                        APInt(8, *elementData)));
+  }
   if (elementType.isF16()) {
     auto elementData = reinterpret_cast<const uint16_t *>(elementPtr);
     return Element(elementType, APFloat(llvm::APFloatBase::IEEEhalf(),
@@ -207,6 +217,13 @@ void Tensor::set(const Index &index, const Element &element) {
       getSizeInBytes(elementType) * flattenIndex(getShape(), index);
 
   // Handle floating-point types.
+  if (elementType.isFloat8E4M3FNUZ() || elementType.isFloat8E5M2FNUZ()) {
+    auto elementData = reinterpret_cast<uint8_t *>(elementPtr);
+    auto value = element.getFloatValue();
+    *elementData = (uint8_t)value.bitcastToAPInt().getZExtValue();
+    return;
+  }
+
   if (elementType.isF16() || elementType.isBF16()) {
     auto elementData = reinterpret_cast<uint16_t *>(elementPtr);
     auto value = element.getFloatValue();
@@ -354,6 +371,18 @@ Tensor makeTensor(DenseElementsAttr attr) {
   auto elemType = type.getElementType();
 
   // Handle floating-point types.
+  if (elemType.isFloat8E4M3FNUZ() || elemType.isFloat8E5M2FNUZ()) {
+    auto floatValues = llvm::to_vector(llvm::map_range(
+        attr.getValues<APFloat>(), [&](APFloat value) -> uint8_t {
+          return value.bitcastToAPInt().getZExtValue();
+        }));
+
+    // For both f8E4M3FNUZ and f8E5M2FNUZ floating-point types, we use uint8_t
+    // as their storage type because there are no builtin types for those.
+    return Tensor(type, HeapAsmResourceBlob::allocateAndCopyInferAlign<uint8_t>(
+                            floatValues));
+  }
+
   if (elemType.isF16() || elemType.isBF16()) {
     auto floatValues = llvm::to_vector(llvm::map_range(
         attr.getValues<APFloat>(), [&](APFloat value) -> uint16_t {
diff --git a/stablehlo/reference/Types.cpp b/stablehlo/reference/Types.cpp
index 9326492ea9..b1fdc4b0be 100644
--- a/stablehlo/reference/Types.cpp
+++ b/stablehlo/reference/Types.cpp
@@ -44,7 +44,8 @@ bool isSupportedIntegerType(Type type) {
 }
 
 bool isSupportedFloatType(Type type) {
-  return type.isF16() || type.isBF16() || type.isF32() || type.isF64();
+  return type.isFloat8E4M3FNUZ() || type.isFloat8E5M2FNUZ() || type.isF16() ||
+         type.isBF16() || type.isF32() || type.isF64();
 }
 
 bool isSupportedComplexType(Type type) {
diff --git a/stablehlo/tests/interpret_constant.mlir b/stablehlo/tests/interpret_constant.mlir
index 2c773f4c29..82558065fa 100644
--- a/stablehlo/tests/interpret_constant.mlir
+++ b/stablehlo/tests/interpret_constant.mlir
@@ -80,6 +80,22 @@ func.func @constant_op_test_ui64() {
 
 // -----
 
+func.func @constant_op_test_f8_e4m3_fnuz() {
+  %0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x7F, 0xFF, 0x01, 0x81]> : tensor<10xf8E4M3FNUZ>
+  check.expect_eq_const %0, dense<[0.0, 0.0, 1.0, 0.125, 0.1, 3.25, 240.0, -240.0, 0.000976562, -0.000976562]> : tensor<10xf8E4M3FNUZ>
+  func.return
+}
+
+// -----
+
+func.func @constant_op_test_f8_e5m2_fnuz() {
+  %0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x7F, 0xFF, 0x01, 0x81]> : tensor<10xf8E5M2FNUZ>
+  check.expect_eq_const %0, dense<[0.0, 0.0, 1.0, 0.125, 0.1, 3.0, 57344.0, -57344.0, 7.62939e-06, -7.62939e-06]> : tensor<10xf8E5M2FNUZ>
+  func.return
+}
+
+// -----
+
 func.func @constant_op_test_bf16() {
   %0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.140630, 0x7F80, 0xFF80, 0x7FFF, 0x0001, 0x8001]> : tensor<11xbf16>
   check.expect_almost_eq_const %0, dense<[0.000000e+00, -0.000000e+00, 1.000000e+00, 1.250000e-01, 1.000980e-01, 3.140630e+00, 0x7F80, 0xFF80, 0x7FFF, 9.183550e-41, -9.183550e-41]> : tensor<11xbf16>
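The expected values in the two new constant tests above are worth a second look: `-0.0` round-trips to `0.0` because the FNUZ formats have no negative zero, `0x7F` and `0xFF` decode to the largest finite magnitudes (±240 and ±57344, as computed earlier), and `3.1415` lands on the nearest representable value of each format's grid:

```latex
\begin{aligned}
\text{f8E4M3FNUZ, step } 0.25 \text{ on } [2,4)\!: &\quad
  |3.25 - 3.1415| = 0.1085 < |3.1415 - 3.0| = 0.1415
  \;\Rightarrow\; 3.1415 \mapsto 3.25, \\
\text{f8E5M2FNUZ, step } 0.5 \text{ on } [2,4)\!: &\quad
  |3.1415 - 3.0| = 0.1415 < |3.5 - 3.1415| = 0.3585
  \;\Rightarrow\; 3.1415 \mapsto 3.0.
\end{aligned}
```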
diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir
index 08c46357f1..652dcec13a 100644
--- a/stablehlo/tests/ops_stablehlo.mlir
+++ b/stablehlo/tests/ops_stablehlo.mlir
@@ -1437,7 +1437,7 @@ func.func @cholesky_invalid_rank(%arg0: tensor<1xf32>) -> tensor<1xf32> {
 // -----
 
 func.func @cholesky_invalid_elt(%arg0: tensor<1x2x2xi32>) -> tensor<1x2x2xi32> {
-  // expected-error@+1 {{operand #0 must be tensor of f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements values, but got 'tensor<1x2x2xi32>'}}
+  // expected-error@+1 {{operand #0 must be tensor of f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements values, but got 'tensor<1x2x2xi32>'}}
   %0 = "stablehlo.cholesky"(%arg0) { lower = true } : (tensor<1x2x2xi32>) -> tensor<1x2x2xi32>
   func.return %0: tensor<1x2x2xi32>
 }
@@ -1908,7 +1908,7 @@ func.func @rng_normal_invalid_shape(%arg0: tensor<f32>, %arg1: tensor<f32>) {
 
 func.func @rng_normal_invalid_mu_rank(%mu: tensor<1xf32>, %sigma: tensor<f32>) -> tensor<2x3x5xf32> {
   %shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64>
-  // expected-error@+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}}
+  // expected-error@+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}}
   %0 = "stablehlo.rng"(%mu, %sigma, %shape) {rng_distribution = #stablehlo<rng_distribution NORMAL>}: (tensor<1xf32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
   func.return %0 : tensor<2x3x5xf32>
 }
@@ -1917,7 +1917,7 @@ func.func @rng_normal_invalid_mu_rank(%mu: tensor<1xf32>, %sigma: tensor<f32>) -
 
 func.func @rng_normal_invalid_sigma_rank(%mu: tensor<f32>, %sigma: tensor<1xf32>) -> tensor<2x3x5xf32> {
   %shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64>
-  // expected-error@+1 {{#1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}}
+  // expected-error@+1 {{#1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}}
   %0 = "stablehlo.rng"(%mu, %sigma, %shape) {rng_distribution = #stablehlo<rng_distribution NORMAL>}: (tensor<f32>, tensor<1xf32>, tensor<3xi64>) -> tensor<2x3x5xf32>
   func.return %0 : tensor<2x3x5xf32>
 }
@@ -1935,7 +1935,7 @@ func.func @rng_normal_invalid_shape_rank(%mu: tensor<f32>, %sigma: tensor<f32>)
 
 func.func @rng_normal_invalid_type(%arg0: tensor<complex<f32>>, %arg1: tensor<f32>) {
   %cst = "stablehlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64>
-  // expected-error @+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<complex<f32>>'}}
+  // expected-error @+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<complex<f32>>'}}
   %0 = "stablehlo.rng"(%arg0, %arg1, %cst) {rng_distribution = #stablehlo<rng_distribution NORMAL>}: (tensor<complex<f32>>, tensor<f32>, tensor<1xi64>) -> tensor<7xf32>
   func.return
 }
@@ -1977,7 +1977,7 @@ func.func @rng_uniform_invalid_shape(%arg0: tensor<f32>, %arg1: tensor<f32>, %ar
 
 func.func @rng_uniform_invalid_a_rank(%a: tensor<1xf32>, %b: tensor<f32>) -> tensor<2x3x5xf32> {
   %shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64>
-  // expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}}
+  // expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}}
   %0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo<rng_distribution UNIFORM>}: (tensor<1xf32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
   func.return %0 : tensor<2x3x5xf32>
 }
@@ -1987,7 +1987,7 @@ func.func @rng_uniform_invalid_a_rank(%a: tensor<1xf32>, %b: tensor<f32>) -> ten
 
 func.func @rng_uniform_invalid_b_rank(%a: tensor<f32>, %b: tensor<1xf32>) -> tensor<2x3x5xf32> {
   %shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64>
-  // expected-error@+1 {{operand #1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}}
+  // expected-error@+1 {{operand #1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}}
   %0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo<rng_distribution UNIFORM>}: (tensor<f32>, tensor<1xf32>, tensor<3xi64>) -> tensor<2x3x5xf32>
   func.return %0 : tensor<2x3x5xf32>
 }
@@ -2005,7 +2005,7 @@ func.func @rng_uniform_invalid_shape_rank(%a: tensor<f32>, %b: tensor<f32>) -> t
 
 func.func @rng_uniform_invalid_type(%a: tensor<complex<f32>>, %b: tensor<f32>) -> tensor<2x3x5xf32> {
   %shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64>
-  // expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<complex<f32>>'}}
+  // expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<complex<f32>>'}}
   %0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo<rng_distribution UNIFORM>}: (tensor<complex<f32>>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
   func.return %0 : tensor<2x3x5xf32>
 }
@@ -2611,7 +2611,7 @@ func.func @or_invalid_f32_type(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> te
 
 // -----
 
 func.func @floor_invalid_i32_type(%arg0: tensor<4xi32>) -> tensor<4xi32> {
-  // expected-error@+1 {{must be tensor of f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<4xi32>'}}
+  // expected-error@+1 {{must be tensor of f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<4xi32>'}}
   %0 = "stablehlo.floor"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
   func.return %0 : tensor<4xi32>
 }
@@ -4858,7 +4858,7 @@ func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<4xf
 // -----
 
 func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xi32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> {
-  // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<2x2x2x2xi32>'}}
+  // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<2x2x2x2xi32>'}}
   %0:3 = "stablehlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xi32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>)
   func.return %0#0 : tensor<2x2x2x2xf32>
 }
@@ -4890,7 +4890,7 @@ func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf
 // -----
 
 func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> {
-  // expected-error@+1 {{result #1 must be 1D tensor of f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<2x2xf32>'}}
+  // expected-error@+1 {{result #1 must be 1D tensor of f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<2x2xf32>'}}
   %0:3 = "stablehlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> (tensor<2x2x2x2xf32>, tensor<2x2xf32>, tensor<2xf32>)
   func.return %0#0 : tensor<2x2x2x2xf32>
 }
@@ -4898,7 +4898,7 @@ func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf
 // -----
 
 func.func @error_batch_norm_grad(%input: tensor<*xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tensor<*xf32> {
-  // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<*xf32>'}}
+  // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<*xf32>'}}
   %0:3 = "stablehlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<*xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> (tensor<*xf32>, tensor<2xf32>, tensor<2xf32>)
   func.return %0#0 : tensor<*xf32>
 }
@@ -5655,7 +5655,7 @@ func.func @is_finite(%arg0: tensor<3xf32>) -> tensor<3xi1> {
 // -----
 
 func.func @is_finite_int_input(%arg0: tensor<3xi32>) -> tensor<3xi1> {
-  // expected-error@+1 {{operand #0 must be tensor of f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<3xi32>'}}
+  // expected-error@+1 {{operand #0 must be tensor of f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<3xi32>'}}
   %0 = "stablehlo.is_finite"(%arg0) {} : (tensor<3xi32>) -> tensor<3xi1>
   func.return %0 : tensor<3xi1>
 }
@@ -5710,6 +5710,22 @@ func.func @f8e5m2(%arg0: tensor<f16>) -> tensor<f8E5M2> {
 
 // -----
 
+// CHECK-LABEL: func @f8e4m3fnuz
+func.func @f8e4m3fnuz(%arg0: tensor<f16>) -> tensor<f8E4M3FNUZ> {
+  %0 = "stablehlo.convert"(%arg0) : (tensor<f16>) -> tensor<f8E4M3FNUZ>
+  func.return %0 : tensor<f8E4M3FNUZ>
+}
+
+// -----
+
+// CHECK-LABEL: func @f8e5m2fnuz
+func.func @f8e5m2fnuz(%arg0: tensor<f16>) -> tensor<f8E5M2FNUZ> {
+  %0 = "stablehlo.convert"(%arg0) : (tensor<f16>) -> tensor<f8E5M2FNUZ>
+  func.return %0 : tensor<f8E5M2FNUZ>
+}
+
+// -----
+
 func.func @dynamic_iota_static() -> tensor<4xf32> {
   %0 = stablehlo.constant dense<[4]> : tensor<1xi64>
   %1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<4xf32>
diff --git a/stablehlo/tests/ops_stablehlo_roundtrip.mlir b/stablehlo/tests/ops_stablehlo_roundtrip.mlir
index e727d360b3..39dbb6db9c 100644
--- a/stablehlo/tests/ops_stablehlo_roundtrip.mlir
+++ b/stablehlo/tests/ops_stablehlo_roundtrip.mlir
@@ -179,10 +179,12 @@ func.func @test_constants() {
   %cst_4 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
   %cst_5 = arith.constant dense<[[3, 2], [1, 4]]> : tensor<2x2xi32>
   %cst_6 = arith.constant dense<[[1, 2], [4, 8]]> : tensor<2x2xui32>
-  %cst_7 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16>
-  %cst_8 = arith.constant dense<[1.0e+00, -4.0e+00, -65504.0e+00, 1.5625e-02]> : tensor<4xf16>
-  %cst_9 = arith.constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>
-  %cst_10 = arith.constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f64>>
+  %cst_7 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf8E4M3FNUZ>
+  %cst_8 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf8E5M2FNUZ>
+  %cst_9 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16>
+  %cst_10 = arith.constant dense<[1.0e+00, -4.0e+00, -65504.0e+00, 1.5625e-02]> : tensor<4xf16>
+  %cst_11 = arith.constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>
+  %cst_12 = arith.constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f64>>
 
   func.return
 }
diff --git a/stablehlo/tests/stablehlo_legalize_to_vhlo.mlir b/stablehlo/tests/stablehlo_legalize_to_vhlo.mlir
index 2afd1dbd00..aac8f9cc84 100644
--- a/stablehlo/tests/stablehlo_legalize_to_vhlo.mlir
+++ b/stablehlo/tests/stablehlo_legalize_to_vhlo.mlir
@@ -2152,6 +2152,20 @@ func.func @type_f8E5M2(%arg0: tensor<f8E5M2>, %arg1: tensor<f
 }
 // CHECK-LABEL: "type_f8E5M2"
 
+func.func @type_f8E4M3FNUZ(%arg0: tensor<f8E4M3FNUZ>, %arg1: tensor<f8E4M3FNUZ>) -> tensor<f8E4M3FNUZ> {
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.f8E4M3FNUZ_v1>, !vhlo.tensor_v1<!vhlo.f8E4M3FNUZ_v1>) -> !vhlo.tensor_v1<!vhlo.f8E4M3FNUZ_v1>
  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f8E4M3FNUZ>, tensor<f8E4M3FNUZ>) -> tensor<f8E4M3FNUZ>
"vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "type_f8E4M3FNUZ" + +func.func @type_f8E5M2FNUZ(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "type_f8E5M2FNUZ" + func.func @type_bf16(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor diff --git a/stablehlo/tests/vhlo_to_version_downgrade_invalid.0_9_0.mlir b/stablehlo/tests/vhlo_to_version_downgrade_invalid.0_9_0.mlir new file mode 100644 index 0000000000..23585f1c24 --- /dev/null +++ b/stablehlo/tests/vhlo_to_version_downgrade_invalid.0_9_0.mlir @@ -0,0 +1,15 @@ +// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo --vhlo-to-version='target=0.9.0' --verify-diagnostics --split-input-file %s + +// expected-error @+1 {{failed to legalize operation 'vhlo.func_v1' that was explicitly marked illegal}} +func.func @type_fp8_E5M2FNUZ(%arg0: tensor) -> tensor { + %0 = stablehlo.add %arg0, %arg0 : tensor + func.return %0 : tensor +} + +// ----- + +// expected-error @+1 {{failed to legalize operation 'vhlo.func_v1' that was explicitly marked illegal}} +func.func @type_fp8_E4M3FNUZ(%arg0: tensor) -> tensor { + %0 = stablehlo.add %arg0, %arg0 : tensor + func.return %0 : tensor +} diff --git a/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/transforms/VhloToVersion.cpp index 3fb28d7512..d590dd2054 100644 --- a/stablehlo/transforms/VhloToVersion.cpp +++ b/stablehlo/transforms/VhloToVersion.cpp @@ -293,7 +293,8 @@ namespace stablehlo { void populateVhloToVersionPatterns(RewritePatternSet* patterns, TypeConverter* converter, MLIRContext* context) { - // Currently empty because we're starting from a clean slate in v0.9.0. + // Currently empty because we're starting from a clean slate in v0.9.0 and + // changes so far are additive. } } // namespace stablehlo