@lhutton1 (Contributor) commented Oct 15, 2025

This commit adds support for the OCP-MX INT8 type. This includes the following operations: MATMUL_T_BLOCK_SCALED, CAST_FROM_BLOCK_SCALED, CAST_TO_BLOCK_SCALED and CONST.

The support is added via a custom TOSA type "!tosa.mxint8" because mxint8 is not yet a builtin type in MLIR. This may change in the future, depending on how the type is used by other frameworks/dialects. Conversions to/from this type have not yet been implemented for the same reason.
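
To make the element format concrete, here is a minimal C++ sketch of how a block-scaled mxint8 value could be decoded. It is not code from this PR: the function names are hypothetical, and the E8M0 scale interpretation (bias 127, 0xFF as NaN) is taken from the OCP-MX specification rather than from the patch. The only fact drawn from the patch itself is the implicit 1/64 scale of the 8-bit integer element.

```cpp
#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical helper: decode one mxint8 element.
// The element is a two's-complement int8 with an implicit 1/64 (2^-6) scale.
inline float decodeMxint8Element(std::int8_t raw) {
  return static_cast<float>(raw) / 64.0f;
}

// Hypothetical helper: decode a shared E8M0 scale (assumed per OCP-MX):
// an unsigned 8-bit exponent with bias 127, where 0xFF encodes NaN.
inline float decodeE8M0Scale(std::uint8_t bits) {
  if (bits == 0xFF)
    return std::numeric_limits<float>::quiet_NaN();
  return std::ldexp(1.0f, static_cast<int>(bits) - 127);
}

// A block-scaled value is the element multiplied by its block's shared scale.
inline float decodeBlockScaled(std::int8_t raw, std::uint8_t scaleBits) {
  return decodeMxint8Element(raw) * decodeE8M0Scale(scaleBits);
}
```

Under these assumptions, the two bytes in the `dense<"0x007F">` test constant decode to 0 and 127/64 = 1.984375 when the shared scale is 1.0 (scale bits 127).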

Note: This PR relies on #156425, #163433, #163436 and #163641 so also contains their contents.

Co-authored-by: Tat Wai Chong <tatwai.chong@arm.com>

@github-actions bot commented Oct 15, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

This commit adds support for the OCP-MX INT8 type. This includes the
following operations: MATMUL_T_BLOCK_SCALED, CAST_FROM_BLOCK_SCALED,
CAST_TO_BLOCK_SCALED and CONST.

The support is added via a custom TOSA type "!tosa.mxint8" because
mxint8 is not yet a builtin type in MLIR. This may change in the
future, depending on how the type is used by other frameworks/dialects.
Conversions to/from this type have not yet been implemented for the
same reason.

Co-authored-by: Tat Wai Chong <tatwai.chong@arm.com>
Change-Id: I6dbba8d55075111cae6b3186cef90fd87d9e5ae6

@lhutton1 marked this pull request as ready for review October 24, 2025 15:13

@lhutton1 (Contributor, Author) commented:

This is rebased & ready for review.

@llvmbot (Member) commented Oct 24, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/163642.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc (+12-5)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h (+3)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h (+1-1)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+21-12)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+6)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+3)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+4-3)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+21)
  • (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir (+56)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index 8b5934ff0630e..c774d870a8c45 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -572,6 +572,8 @@ extensionComplianceMap = {
         {{fp8e4m3T, fp8ue8m0T, fp8e4m3T, fp8ue8m0T, fp32T},
          SpecificationVersion::V_1_1_DRAFT},
         {{fp8e5m2T, fp8ue8m0T, fp8e5m2T, fp8ue8m0T, fp32T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{mxint8T, fp8ue8m0T, mxint8T, fp8ue8m0T, fp32T},
          SpecificationVersion::V_1_1_DRAFT}}}}},
     {"tosa.max_pool2d",
      {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
@@ -870,14 +872,16 @@ extensionComplianceMap = {
         {{fp6e2m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
         {{fp6e3m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
         {{fp8e4m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
-        {{fp8e5m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}},
+        {{fp8e5m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{mxint8T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}},
        allOf},
       {{Extension::mxfp},
        {{{fp4e2m1T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
         {{fp6e2m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
         {{fp6e3m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
         {{fp8e4m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
-        {{fp8e5m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
+        {{fp8e5m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{mxint8T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
     {"tosa.cast_to_block_scaled",
      {{{Extension::mxfp},
        {{{bf16T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
@@ -885,12 +889,14 @@ extensionComplianceMap = {
         {{fp32T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
         {{fp32T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
         {{fp32T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
-        {{fp32T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}},
+        {{fp32T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp32T, mxint8T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}},
       {{Extension::bf16, Extension::mxfp},
        {{{bf16T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
         {{bf16T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
         {{bf16T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
-        {{bf16T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}},
+        {{bf16T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+        {{bf16T, mxint8T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}},
        allOf}}},
     {"tosa.rescale",
      {{{Extension::int16},
@@ -908,7 +914,8 @@ extensionComplianceMap = {
        {{{fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
         {{fp6e3m2T}, SpecificationVersion::V_1_1_DRAFT},
         {{fp6e2m3T}, SpecificationVersion::V_1_1_DRAFT},
-        {{fp4e2m1T}, SpecificationVersion::V_1_1_DRAFT}}}}},
+        {{fp4e2m1T}, SpecificationVersion::V_1_1_DRAFT},
+        {{mxint8T}, SpecificationVersion::V_1_1_DRAFT}}}}},
     {"tosa.identity",
      {{{Extension::int4}, {{{i4T, i4T}, SpecificationVersion::V_1_0}}},
       {{Extension::int16}, {{{i48T, i48T}, SpecificationVersion::V_1_0}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index a15f073bc5fcb..2d4e7cf8b9dbd 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -179,6 +179,9 @@ Value createPadConstTensor(OpBuilder &builder, Location loc, Value src,
 // returns type of variable op
 RankedTensorType getVariableType(VariableOp variableOp);
 
+// Returns the bitwidth of a TOSA tensor element type
+unsigned getBitWidth(Type type);
+
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 45d380c1b2e6c..ea58f49b64c44 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -70,7 +70,7 @@ class ProfileInfoDepot {
 
 private:
   TypeInfo convertTypeToInfo(Type type) {
-    return {type.getTypeID(), type.getIntOrFloatBitWidth()};
+    return {type.getTypeID(), tosa::getBitWidth(type)};
   }
 
   TypeInfo convertValueToInfo(Value value) {
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 93843e86fd378..414b51bf4b135 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -22,6 +22,12 @@ include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
 // Tosa Type Definitions.
 //===----------------------------------------------------------------------===//
 
+// The base class for Tosa dialect types.
+class Tosa_Type<string name, string typeMnemonic, list<Trait> traits = []>
+    : TypeDef<Tosa_Dialect, name, traits> {
+  let mnemonic = typeMnemonic;
+}
+
 // The base class of a quantized type.
 // Param tuple is: [bitwidth, zeropt, smantissa, sexp, low_end, high_end].
 // Where low and high ends are 0,255 when unsigned, -128,127 when signed, for
@@ -78,13 +84,26 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
                                    Tosa_QuantizedType<"int16", [16, 0], 1>,
                                    Tosa_QuantizedType<"int32", [32, 0], 1>]>;
 
+//===----------------------------------------------------------------------===//
+// Custom TOSA element types.
+//===----------------------------------------------------------------------===//
+
+// MLIR doesn't have a builtin type for mxint8 yet. For now, declare it as a
+// custom TOSA type. This may be changed in the future.
+def Tosa_MXInt8 : Tosa_Type<"mxint8", "mxint8"> {
+  let summary = "INT8 type as defined by OCP-MX";
+  let description = [{
+    8-bit integer format with an implicit 1/64 scale defined by OCP-MX.
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Multi-category types.
 //===----------------------------------------------------------------------===//
-def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
+def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat, Tosa_MXInt8],
                                 "number">;
 
-def Tosa_MXFPNumber : AnyTypeOf<[F8E4M3FN, F8E5M2, F4E2M1FN, F6E2M3FN, F6E3M2FN],
+def Tosa_MXFPNumber : AnyTypeOf<[F8E4M3FN, F8E5M2, F4E2M1FN, F6E2M3FN, F6E3M2FN, Tosa_MXInt8],
                                 "micro-scaling format number">;
 def Tosa_MXFPScaleNumber : AnyTypeOf<[F8E8M0FNU], "micro-scaling format scale number">;
 
@@ -265,16 +284,6 @@ def Tosa_Buffer : MemRefOf<[Tosa_AnyNumber]>;
 def Tosa_TupleBuffer : NestedTupleOf<[Tosa_Buffer]>;
 def Tosa_BufOrTuple : AnyTypeOf<[Tosa_Buffer, Tosa_TupleBuffer]>;
 
-//===----------------------------------------------------------------------===//
-// Tosa Type Definitions.
-//===----------------------------------------------------------------------===//
-
-// The base class for Tosa dialect types.
-class Tosa_Type<string name, string typeMnemonic, list<Trait> traits = []>
-    : TypeDef<Tosa_Dialect, name, traits> {
-  let mnemonic = typeMnemonic;
-}
-
 //===----------------------------------------------------------------------===//
 // ShapeType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0aff67f0b5eba..bf3810ff231da 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -606,6 +606,12 @@ Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
   return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
 }
 
+unsigned mlir::tosa::getBitWidth(Type type) {
+  if (dyn_cast<tosa::mxint8Type>(type))
+    return 8;
+  return type.getIntOrFloatBitWidth();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index ab363ee6b4d2a..ddd9c70402fdc 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -31,6 +31,7 @@ TosaProfileCompliance::TosaProfileCompliance() {
   const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6};
   const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4};
   const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8};
+  const TypeInfo mxint8T = {mlir::tosa::mxint8Type::getTypeID(), 8};
 
 // The profile-based compliance content below is auto-generated by a script
 // in https://git.mlplatform.org/tosa/specification.git
@@ -625,6 +626,8 @@ TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) {
     return {"fp4e2m1"};
   } else if (typeInfo.typeID == mlir::Float8E8M0FNUType::getTypeID()) {
     return {"fp8e8m0"};
+  } else if (typeInfo.typeID == tosa::mxint8Type::getTypeID()) {
+    return {"mxint8"};
   }
   llvm_unreachable("unknown type");
 }
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 4d0b61acc4ea4..9676ea5ca4868 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -693,7 +693,7 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op,
                                  << " shape dimension cannot be dynamic";
     }
 
-    int64_t element_bits = type.getElementTypeBitWidth();
+    int64_t element_bits = tosa::getBitWidth(getElementTypeOrSelf(type));
     int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
     int64_t size = element_bytes * type.getNumElements();
 
@@ -1217,9 +1217,10 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
         return true;
       }
     }
-  } else if (mlir::isa<tosa::shapeType>(type)) {
+  } else if (mlir::isa<tosa::shapeType>(type))
+    return true;
+  else if (isa<tosa::mxint8Type>(type))
     return true;
-  }
   return false;
 }
 
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 865f712ce1a5a..22fde3b7d28a5 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -1269,6 +1269,13 @@ func.func @test_matmul_t_block_scaled_broadcast(%arg0: tensor<?x8x32xf8E4M3FN>,
   return %0 : tensor<4x8x16xf32>
 }
 
+// -----
+// CHECK-LABEL: test_matmul_t_block_scaled_mxint8
+func.func @test_matmul_t_block_scaled_mxint8(%arg0: tensor<4x8x32x!tosa.mxint8>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32x!tosa.mxint8>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x8x32x!tosa.mxint8>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32x!tosa.mxint8>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %0 : tensor<4x8x16xf32>
+}
+
 // -----
 // CHECK-LABEL: test_cast_from_block_scaled_static
 func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
@@ -1296,3 +1303,17 @@ func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<*
   %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
   return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
 }
+
+// -----
+// CHECK-LABEL: test_cast_to_block_scaled_mxint8
+func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) {
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>
+}
+
+// -----
+// CHECK-LABEL: test_const_mxint8
+func.func @test_const_mxint8(%arg0 : index) -> tensor<2x!tosa.mxint8> {
+    %0 = "tosa.const"() {values = dense<"0x007F"> : tensor<2x!tosa.mxint8>} : () -> tensor<2x!tosa.mxint8>
+    return %0 : tensor<2x!tosa.mxint8>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index f3d8dab2f6b0f..d8cbaa2c356c3 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -82,3 +82,59 @@ func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<
   %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>)
   return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>
 }
+
+// -----
+
+// CHECK-LABEL: test_cast_to_block_scaled_mxint8
+func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) {
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>
+}
+
+// -----
+
+// CHECK-LABEL: test_const_fp6e3m2
+func.func @test_const_fp6e3m2(%arg0 : index) -> tensor<4xf6E3M2FN> {
+    %0 = "tosa.const"() {values = dense<[0.0, 0.0, 0.0, 0.0]> : tensor<4xf6E3M2FN>} : () -> tensor<4xf6E3M2FN>
+    return %0 : tensor<4xf6E3M2FN>
+}
+
+// -----
+
+// CHECK-LABEL: test_const_mxint8
+func.func @test_const_mxint8(%arg0 : index) -> tensor<2x!tosa.mxint8> {
+    %0 = "tosa.const"() {values = dense<"0x007F"> : tensor<2x!tosa.mxint8>} : () -> tensor<2x!tosa.mxint8>
+    return %0 : tensor<2x!tosa.mxint8>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_f4e2m1
+func.func @test_cast_f4e2m1(%arg0: tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16>
+  return %0 : tensor<13x21x3xbf16>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_t_block_scaled_mxint8
+func.func @test_matmul_t_block_scaled_mxint8(%arg0: tensor<4x8x32x!tosa.mxint8>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32x!tosa.mxint8>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
+  %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32x!tosa.mxint8>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32x!tosa.mxint8>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
+  return %0 : tensor<4x8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_to_block_scaled_mxint8
+func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) {
+  %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>)
+  return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>
+}
+
+// -----
+
+// CHECK-LABEL: test_const_mxint8
+func.func @test_const_mxint8(%arg0 : index) -> tensor<2x!tosa.mxint8> {
+    %0 = "tosa.const"() {values = dense<"0x007F"> : tensor<2x!tosa.mxint8>} : () -> tensor<2x!tosa.mxint8>
+    return %0 : tensor<2x!tosa.mxint8>
+}
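
The `tosa::getBitWidth` helper and the `levelCheckSize` change in the diff above can be illustrated without MLIR. The sketch below is a standalone approximation, not the patch's code: `ElemKind`, `bitWidthOf` and `tensorSizeInBytes` are hypothetical stand-ins. It shows why a dedicated helper is useful (a custom dialect type is neither a builtin integer nor a builtin float, so the generic int-or-float width query cannot be used on it) and reproduces the size arithmetic from `levelCheckSize`.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical stand-in for an MLIR element type; not part of MLIR.
enum class ElemKind { F32, F4E2M1, MXInt8 };

// Mirrors the shape of tosa::getBitWidth: the custom type is special-cased,
// every other type uses its ordinary integer/float width.
inline unsigned bitWidthOf(ElemKind kind) {
  switch (kind) {
  case ElemKind::MXInt8:
    return 8; // custom TOSA type, fixed at 8 bits
  case ElemKind::F32:
    return 32;
  case ElemKind::F4E2M1:
    return 4;
  }
  return 0; // unreachable for the enumerators above
}

// Same arithmetic as the patched levelCheckSize: element widths below one
// byte are rounded up to one byte before multiplying by the element count.
inline std::int64_t tensorSizeInBytes(ElemKind kind, std::int64_t numElements) {
  std::int64_t elementBits = bitWidthOf(kind);
  std::int64_t elementBytes = std::max<std::int64_t>(1, elementBits / 8);
  return elementBytes * numElements;
}
```

For the `tensor<4x32x!tosa.mxint8>` shape used in the tests, this gives 4 * 32 * 1 = 128 bytes, the same figure the computation yields for a 4-bit element type because of the one-byte floor.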
