From fb87c47444a6dfa7bbe073eacdf9d2e5f71274b6 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Wed, 17 Sep 2025 13:52:13 +0000 Subject: [PATCH 1/3] [mlir][tosa] Add support for mxint8 type in mxfp operations This commit adds support for the OCP-MX INT8 type. This includes the following operations: MATMUL_T_BLOCK_SCALED, CAST_FROM_BLOCK_SCALED, CAST_TO_BLOCK_SCALED and CONST. The support is added via a custom TOSA type "!tosa.mxint8" due to the fact it is not yet a builtin type in mlir. This may change in the future, depending on how this type is used by other frameworks/ dialects. Conversions to/from this type have not yet been implemented for the same reasoning. Co-authored-by: Tat Wai Chong Change-Id: I6dbba8d55075111cae6b3186cef90fd87d9e5ae6 --- .../Dialect/Tosa/IR/TosaComplianceData.h.inc | 17 ++++-- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h | 3 + .../Dialect/Tosa/IR/TosaProfileCompliance.h | 2 +- .../mlir/Dialect/Tosa/IR/TosaTypesBase.td | 33 +++++++---- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 6 ++ .../Tosa/Transforms/TosaProfileCompliance.cpp | 3 + .../Tosa/Transforms/TosaValidation.cpp | 7 ++- mlir/test/Dialect/Tosa/ops.mlir | 21 +++++++ .../tosa-validation-version-1p1-valid.mlir | 56 +++++++++++++++++++ 9 files changed, 127 insertions(+), 21 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc index 8b5934ff0630e..c774d870a8c45 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc @@ -572,6 +572,8 @@ extensionComplianceMap = { {{fp8e4m3T, fp8ue8m0T, fp8e4m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}, {{fp8e5m2T, fp8ue8m0T, fp8e5m2T, fp8ue8m0T, fp32T}, + SpecificationVersion::V_1_1_DRAFT}, + {{mxint8T, fp8ue8m0T, mxint8T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}}, {"tosa.max_pool2d", {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}}, @@ -870,14 +872,16 @@ extensionComplianceMap = { {{fp6e2m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}, {{fp6e3m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}, {{fp8e4m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}, - {{fp8e5m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}}, + {{fp8e5m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}, + {{mxint8T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}}, allOf}, {{Extension::mxfp}, {{{fp4e2m1T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}, {{fp6e2m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}, {{fp6e3m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}, {{fp8e4m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}, - {{fp8e5m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}}, + {{fp8e5m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}, + {{mxint8T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}}, {"tosa.cast_to_block_scaled", {{{Extension::mxfp}, {{{bf16T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, @@ -885,12 +889,14 @@ extensionComplianceMap = { {{fp32T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, {{fp32T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, {{fp32T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, - {{fp32T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}}, + {{fp32T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, + {{fp32T, mxint8T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}}, {{Extension::bf16, Extension::mxfp}, {{{bf16T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, {{bf16T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, {{bf16T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, - {{bf16T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}, + {{bf16T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, + {{bf16T, mxint8T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}, allOf}}}, {"tosa.rescale", {{{Extension::int16}, @@ -908,7 +914,8 @@ extensionComplianceMap = { {{{fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}, {{fp6e3m2T}, SpecificationVersion::V_1_1_DRAFT}, {{fp6e2m3T}, SpecificationVersion::V_1_1_DRAFT}, - {{fp4e2m1T}, SpecificationVersion::V_1_1_DRAFT}}}}}, + {{fp4e2m1T}, SpecificationVersion::V_1_1_DRAFT}, + {{mxint8T}, SpecificationVersion::V_1_1_DRAFT}}}}}, {"tosa.identity", {{{Extension::int4}, {{{i4T, i4T}, SpecificationVersion::V_1_0}}}, {{Extension::int16}, {{{i48T, i48T}, SpecificationVersion::V_1_0}}}, diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h index a15f073bc5fcb..2d4e7cf8b9dbd 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -179,6 +179,9 @@ Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, // returns type of variable op RankedTensorType getVariableType(VariableOp variableOp); +// Returns the bitwidth of a TOSA tensor element type +unsigned getBitWidth(Type type); + } // namespace tosa } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h index 45d380c1b2e6c..ea58f49b64c44 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h @@ -70,7 +70,7 @@ class ProfileInfoDepot { private: TypeInfo convertTypeToInfo(Type type) { - return {type.getTypeID(), type.getIntOrFloatBitWidth()}; + return {type.getTypeID(), tosa::getBitWidth(type)}; } TypeInfo convertValueToInfo(Value value) { diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index 93843e86fd378..414b51bf4b135 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -22,6 +22,12 @@ include "mlir/Dialect/Tosa/IR/TosaOpBase.td" // Tosa Type Definitions. //===----------------------------------------------------------------------===// +// The base class for Tosa dialect types. +class Tosa_Type traits = []> + : TypeDef { + let mnemonic = typeMnemonic; +} + // The base class of a quantized type. // Param tuple is: [bitwidth, zeropt, smantissa, sexp, low_end, high_end]. // Where low and high ends are 0,255 when unsigned, -128,127 when signed, for @@ -78,13 +84,26 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>, Tosa_QuantizedType<"int16", [16, 0], 1>, Tosa_QuantizedType<"int32", [32, 0], 1>]>; +//===----------------------------------------------------------------------===// +// Custom TOSA element types. +//===----------------------------------------------------------------------===// + +// MLIR doesn't have a builtin type for mxint8 yet. For now declared it as a +// custom TOSA type. This may be changed in the future. +def Tosa_MXInt8 : Tosa_Type<"mxint8", "mxint8"> { + let summary = "INT8 type as defined by OCP-MX"; + let description = [{ + 8-bit integer format with an implicit 1/64 scale defined by OCP-MX. + }]; +} + //===----------------------------------------------------------------------===// // Multi-category types. //===----------------------------------------------------------------------===// -def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat], +def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat, Tosa_MXInt8], "number">; -def Tosa_MXFPNumber : AnyTypeOf<[F8E4M3FN, F8E5M2, F4E2M1FN, F6E2M3FN, F6E3M2FN], +def Tosa_MXFPNumber : AnyTypeOf<[F8E4M3FN, F8E5M2, F4E2M1FN, F6E2M3FN, F6E3M2FN, Tosa_MXInt8], "micro-scaling format number">; def Tosa_MXFPScaleNumber : AnyTypeOf<[F8E8M0FNU], "micro-scaling format scale number">; @@ -265,16 +284,6 @@ def Tosa_Buffer : MemRefOf<[Tosa_AnyNumber]>; def Tosa_TupleBuffer : NestedTupleOf<[Tosa_Buffer]>; def Tosa_BufOrTuple : AnyTypeOf<[Tosa_Buffer, Tosa_TupleBuffer]>; -//===----------------------------------------------------------------------===// -// Tosa Type Definitions. -//===----------------------------------------------------------------------===// - -// The base class for Tosa dialect types. -class Tosa_Type traits = []> - : TypeDef { - let mnemonic = typeMnemonic; -} - //===----------------------------------------------------------------------===// // ShapeType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 0aff67f0b5eba..bf3810ff231da 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -606,6 +606,12 @@ Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc, return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr); } +unsigned mlir::tosa::getBitWidth(Type type) { + if (dyn_cast(type)) + return 8; + return type.getIntOrFloatBitWidth(); +} + //===----------------------------------------------------------------------===// // TOSA Operator Verifiers. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index ab363ee6b4d2a..ddd9c70402fdc 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -31,6 +31,7 @@ TosaProfileCompliance::TosaProfileCompliance() { const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6}; const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4}; const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8}; + const TypeInfo mxint8T = {mlir::tosa::mxint8Type::getTypeID(), 8}; // The profile-based compliance content below is auto-generated by a script // in https://git.mlplatform.org/tosa/specification.git @@ -625,6 +626,8 @@ TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) { return {"fp4e2m1"}; } else if (typeInfo.typeID == mlir::Float8E8M0FNUType::getTypeID()) { return {"fp8e8m0"}; + } else if (typeInfo.typeID == tosa::mxint8Type::getTypeID()) { + return {"mxint8"}; } llvm_unreachable("unknown type"); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 4d0b61acc4ea4..9676ea5ca4868 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -693,7 +693,7 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op, << " shape dimension cannot be dynamic"; } - int64_t element_bits = type.getElementTypeBitWidth(); + int64_t element_bits = tosa::getBitWidth(getElementTypeOrSelf(type)); int64_t element_bytes = std::max(INT64_C(1), element_bits / 8); int64_t size = element_bytes * type.getNumElements(); @@ -1217,9 +1217,10 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { return true; } } - } else if (mlir::isa(type)) { + } else if (mlir::isa(type)) + return true; + else if (isa(type)) return true; - } return false; } diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 865f712ce1a5a..22fde3b7d28a5 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -1269,6 +1269,13 @@ func.func @test_matmul_t_block_scaled_broadcast(%arg0: tensor, return %0 : tensor<4x8x16xf32> } +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_mxint8 +func.func @test_matmul_t_block_scaled_mxint8(%arg0: tensor<4x8x32x!tosa.mxint8>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32x!tosa.mxint8>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> { + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size : i32} : (tensor<4x8x32x!tosa.mxint8>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32x!tosa.mxint8>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> + return %0 : tensor<4x8x16xf32> +} + // ----- // CHECK-LABEL: test_cast_from_block_scaled_static func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> { @@ -1296,3 +1303,17 @@ func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<* %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>) return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU> } + +// ----- +// CHECK-LABEL: test_cast_to_block_scaled_mxint8 +func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) { + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) + return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU> +} + +// ----- +// CHECK-LABEL: test_const_mxint8 +func.func @test_const_mxint8(%arg0 : index) -> tensor<2x!tosa.mxint8> { + %0 = "tosa.const"() {values = dense<"0x007F"> : tensor<2x!tosa.mxint8>} : () -> tensor<2x!tosa.mxint8> + return %0 : tensor<2x!tosa.mxint8> +} diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir index f3d8dab2f6b0f..d8cbaa2c356c3 100644 --- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir +++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir @@ -82,3 +82,59 @@ func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor< %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>) return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU> } + +// ----- + +// CHECK-LABEL: test_cast_to_block_scaled_mxint8 +func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) { + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) + return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU> +} + +// ----- + +// CHECK-LABEL: test_const_fp6e3m2 +func.func @test_const_fp6e3m2(%arg0 : index) -> tensor<4xf6E3M2FN> { + %0 = "tosa.const"() {values = dense<[0.0, 0.0, 0.0, 0.0]> : tensor<4xf6E3M2FN>} : () -> tensor<4xf6E3M2FN> + return %0 : tensor<4xf6E3M2FN> +} + +// ----- + +// CHECK-LABEL: test_const_mxint8 +func.func @test_const_mxint8(%arg0 : index) -> tensor<2x!tosa.mxint8> { + %0 = "tosa.const"() {values = dense<"0x007F"> : tensor<2x!tosa.mxint8>} : () -> tensor<2x!tosa.mxint8> + return %0 : tensor<2x!tosa.mxint8> +} + +// ----- + +// CHECK-LABEL: test_cast_f4e2m1 +func.func @test_cast_f4e2m1(%arg0: tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> { + %0 = tosa.cast %arg0 : (tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> + return %0 : tensor<13x21x3xbf16> +} + +// ----- + +// CHECK-LABEL: test_matmul_t_block_scaled_mxint8 +func.func @test_matmul_t_block_scaled_mxint8(%arg0: tensor<4x8x32x!tosa.mxint8>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32x!tosa.mxint8>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> { + %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size} : (tensor<4x8x32x!tosa.mxint8>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32x!tosa.mxint8>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> + return %0 : tensor<4x8x16xf32> +} + +// ----- + +// CHECK-LABEL: test_cast_to_block_scaled_mxint8 +func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) { + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) + return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU> +} + +// ----- + +// CHECK-LABEL: test_const_mxint8 +func.func @test_const_mxint8(%arg0 : index) -> tensor<2x!tosa.mxint8> { + %0 = "tosa.const"() {values = dense<"0x007F"> : tensor<2x!tosa.mxint8>} : () -> tensor<2x!tosa.mxint8> + return %0 : tensor<2x!tosa.mxint8> +} From 6ca24ad1c9e920443dab1306f8612ecd72477909 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Tue, 28 Oct 2025 11:28:40 +0000 Subject: [PATCH 2/3] Clean up test cases * remove duplicate test * remove unused input parameter * use a list of strings, rather than one string with hex values Change-Id: I344eb12ed3b6cbd6d5cca04647667296f49e9df4 --- .../tosa-validation-version-1p1-valid.mlir | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir index d8cbaa2c356c3..9bd7aa8f0783e 100644 --- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir +++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir @@ -38,7 +38,7 @@ func.func @test_argmax_int64(%arg0: tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64 // ----- // CHECK-LABEL: test_const_i64 -func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> { +func.func @test_const_i64() -> tensor<4xi64> { %0 = "tosa.const"() {values = dense<[3, 0, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64> return %0 : tensor<4xi64> } @@ -46,7 +46,7 @@ func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> { // ----- // CHECK-LABEL: test_const_fp6e3m2 -func.func @test_const_fp6e3m2(%arg0 : index) -> tensor<4xf6E3M2FN> { +func.func @test_const_fp6e3m2() -> tensor<4xf6E3M2FN> { %0 = "tosa.const"() {values = dense<[0.0, 0.0, 0.0, 0.0]> : tensor<4xf6E3M2FN>} : () -> tensor<4xf6E3M2FN> return %0 : tensor<4xf6E3M2FN> } @@ -94,7 +94,7 @@ func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor< // ----- // CHECK-LABEL: test_const_fp6e3m2 -func.func @test_const_fp6e3m2(%arg0 : index) -> tensor<4xf6E3M2FN> { +func.func @test_const_fp6e3m2() -> tensor<4xf6E3M2FN> { %0 = "tosa.const"() {values = dense<[0.0, 0.0, 0.0, 0.0]> : tensor<4xf6E3M2FN>} : () -> tensor<4xf6E3M2FN> return %0 : tensor<4xf6E3M2FN> } @@ -102,8 +102,8 @@ func.func @test_const_fp6e3m2(%arg0 : index) -> tensor<4xf6E3M2FN> { // ----- // CHECK-LABEL: test_const_mxint8 -func.func @test_const_mxint8(%arg0 : index) -> tensor<2x!tosa.mxint8> { - %0 = "tosa.const"() {values = dense<"0x007F"> : tensor<2x!tosa.mxint8>} : () -> tensor<2x!tosa.mxint8> +func.func @test_const_mxint8() -> tensor<2x!tosa.mxint8> { + %0 = "tosa.const"() {values = dense<["0x00", "0x7F"]> : tensor<2x!tosa.mxint8>} : () -> tensor<2x!tosa.mxint8> return %0 : tensor<2x!tosa.mxint8> } @@ -130,11 +130,3 @@ func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor< %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU> } - -// ----- - -// CHECK-LABEL: test_const_mxint8 -func.func @test_const_mxint8(%arg0 : index) -> tensor<2x!tosa.mxint8> { - %0 = "tosa.const"() {values = dense<"0x007F"> : tensor<2x!tosa.mxint8>} : () -> tensor<2x!tosa.mxint8> - return %0 : tensor<2x!tosa.mxint8> -} From e9a2cc1354de3bbe3d077524d60657565ed1576a Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Wed, 29 Oct 2025 09:40:33 +0000 Subject: [PATCH 3/3] Drop "mlir::" prefix Change-Id: Iaa685661cc34d5a738260821ecfaff12f63812b4 --- mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 9676ea5ca4868..b54ed5585d72d 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -1217,7 +1217,7 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { return true; } } - } else if (mlir::isa(type)) + } else if (isa(type)) return true; else if (isa(type)) return true;