diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc index d3fd4c3d1d3e1..a9b458acd87f2 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc @@ -35,11 +35,9 @@ profileComplianceMap = { {fp16T, fp16T, fp32T, fp32T}, {fp32T, fp32T, fp32T, fp32T}}}}}, {"tosa.matmul", - {{{Profile::pro_int}, {{i8T, i8T, i8T, i8T, i32T}}}, + {{{Profile::pro_int}, {{i8T, i8T, i32T}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp16T, fp32T}, - {fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{fp16T, fp16T, fp16T}, {fp16T, fp16T, fp32T}, {fp32T, fp32T, fp32T}}}}}, {"tosa.max_pool2d", {{{Profile::pro_int}, {{i8T, i8T}}}, {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, @@ -275,10 +273,10 @@ extensionComplianceMap = { {{Extension::int16}, {{i16T, i8T, i48T, i48T}}}, {{Extension::bf16}, {{bf16T, bf16T, fp32T, fp32T}}}}}, {"tosa.matmul", - {{{Extension::int16}, {{i16T, i16T, i16T, i16T, i48T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T}}}, - {{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T, fp32T}}}}}, + {{{Extension::int16}, {{i16T, i16T, i48T}}}, + {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T}}}, + {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T}}}, + {{Extension::bf16}, {{bf16T, bf16T, fp32T}}}}}, {"tosa.max_pool2d", {{{Extension::int16}, {{i16T, i16T}}}, {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index ecddc9fe9a13d..097f78cd487ea 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -311,8 +311,8 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> { let arguments = (ins Tosa_Tensor3D:$a, Tosa_Tensor3D:$b, - Tosa_ScalarIntOrFloatTensor:$a_zp, - Tosa_ScalarIntOrFloatTensor:$b_zp + OptionalAttr:$a_zp, + OptionalAttr:$b_zp ); let results = (outs @@ -324,13 +324,6 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> { Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>, ]; - let extraClassDeclaration = [{ - FailureOr getAZeroPoint(); - FailureOr getBZeroPoint(); - LogicalResult verifyAZeroPoint(int64_t zp); - LogicalResult verifyBZeroPoint(int64_t zp); - }]; - let builders = [Tosa_MatMulOpQuantInfoBuilder]; let hasVerifier = 1; } diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index 13c62b2d3e91c..2a2589e19d0ac 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -270,8 +270,8 @@ class ConvConverter : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "weight zero point cannot be statically determined"); - const int64_t inputZpVal = *maybeIZp; - const int64_t weightZpVal = *maybeWZp; + int64_t inputZpVal = *maybeIZp; + int64_t weightZpVal = *maybeWZp; if (op.verifyInputZeroPoint(inputZpVal).failed()) return rewriter.notifyMatchFailure( @@ -466,8 +466,8 @@ class DepthwiseConvConverter return rewriter.notifyMatchFailure( op, "weight zero point cannot be statically determined"); - const int64_t inputZpVal = *maybeIZp; - const int64_t weightZpVal = *maybeWZp; + int64_t inputZpVal = *maybeIZp; + int64_t weightZpVal = *maybeWZp; if (op.verifyInputZeroPoint(inputZpVal).failed()) return rewriter.notifyMatchFailure( @@ -621,38 +621,15 @@ class MatMulConverter : public OpConversionPattern { .create(loc, ValueRange{zero}, ValueRange{emptyTensor}) .result(); - - FailureOr maybeAZp = op.getAZeroPoint(); - FailureOr maybeBZp = op.getBZeroPoint(); - if (failed(maybeAZp)) - return rewriter.notifyMatchFailure( - op, "input a zero point cannot be statically determined"); - if (failed(maybeBZp)) - return rewriter.notifyMatchFailure( - op, "input b zero point cannot be statically determined"); - - const int64_t aZpVal = *maybeAZp; - const int64_t bZpVal = *maybeBZp; - - if (op.verifyAZeroPoint(aZpVal).failed()) - return rewriter.notifyMatchFailure( - op, "input a zero point must be zero for non-int8 integer types"); - - if (op.verifyBZeroPoint(bZpVal).failed()) - return rewriter.notifyMatchFailure( - op, "input b zero point must be zero for non-int8 integer types"); - - if (aZpVal == 0 && bZpVal == 0) { + if (!op.getAZp() && !op.getBZp()) { rewriter.replaceOpWithNewOp( op, TypeRange{op.getType()}, ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor}); return success(); } - auto aZp = rewriter.create( - loc, rewriter.getI32IntegerAttr(aZpVal)); - auto bZp = rewriter.create( - loc, rewriter.getI32IntegerAttr(bZpVal)); + auto aZp = rewriter.create(loc, op.getAZpAttr()); + auto bZp = rewriter.create(loc, op.getBZpAttr()); rewriter.replaceOpWithNewOp( op, TypeRange{op.getType()}, ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor); @@ -857,8 +834,8 @@ class AvgPool2dConverter : public OpRewritePattern { return rewriter.notifyMatchFailure( op, "output zero point could not be statically determined"); - const int64_t inputZpVal = *maybeIZp; - const int64_t outputZpVal = *maybeOZp; + int64_t inputZpVal = *maybeIZp; + int64_t outputZpVal = *maybeOZp; // Apply padding as necessary. llvm::SmallVector pad; diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp index 6dcb7c845b21f..ffbb707344b8c 100644 --- a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp @@ -55,8 +55,6 @@ struct MatMulOpSharding SmallVector maps; maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx)); maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx)); - maps.push_back(AffineMap::get(0, 0, {}, ctx)); - maps.push_back(AffineMap::get(0, 0, {}, ctx)); maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx)); return maps; } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 7a991b3876f69..4711122dc76e2 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -629,13 +629,23 @@ buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b) { - auto zps = createZPsAsConst(builder, a, b); - result.addOperands({a, b, zps.first, zps.second}); + result.addOperands({a, b}); + auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b); - Type finalOutputType{outputType}; - if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) { - auto eType = getStorageElementTypeOrSelf(a.getType()); - auto inputBits = eType.getIntOrFloatBitWidth(); + if (quantAttr) { + result.addAttribute("a_zp", builder.getI32IntegerAttr( + static_cast(quantAttr.getAZp()))); + result.addAttribute("b_zp", builder.getI32IntegerAttr( + static_cast(quantAttr.getBZp()))); + + auto inputType = llvm::dyn_cast(a.getType()); + assert(inputType && "Input must be a shaped tensor type!"); + + auto inputQType = llvm::dyn_cast( + inputType.getElementType()); + assert(inputQType && "Tensor must have quantized datatype!"); + + unsigned inputBits = inputQType.getStorageTypeIntegralWidth(); auto outputShapedType = llvm::dyn_cast(outputType); assert(outputShapedType && "Output must be a shaped type"); @@ -645,10 +655,11 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder, accElementType = builder.getIntegerType(48); else accElementType = builder.getI32Type(); - - finalOutputType = outputShapedType.clone(accElementType); + auto accType = outputShapedType.clone(accElementType); + result.addTypes(accType); + } else { + result.addTypes(outputType); } - result.addTypes(finalOutputType); } /// Both the tosa.avg_pool2d and unary ops use the same @@ -1129,39 +1140,16 @@ LogicalResult MatMulOp::verify() { return emitOpError("expect quantized operands to have same widths, got ") << aQuantWidth << " and " << bQuantWidth; } - } else { - // non-quantized element types - if (aElementType != bElementType) { - return emitOpError("expect same element type for inputs a and b, got ") - << aElementType << " and " << bElementType; - } - } - // check a_zp and b_zp - auto aEType = getStorageElementTypeOrSelf(aType); - auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType()); - if (aEType != aZpEType) { - return emitOpError("expect input a and a_zp have the same " - "element type, got ") - << aEType << " and " << aZpEType; + return success(); } - auto bEType = getStorageElementTypeOrSelf(bType); - auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType()); - if (bEType != bZpEType) { - return emitOpError("expect input b and b_zp have the same " - "element type, got ") - << bEType << " and " << bZpEType; + // non-quantized element types + if (aElementType != bElementType) { + return emitOpError("expect same element type for inputs a and b, got ") + << aElementType << " and " << bElementType; } - FailureOr maybeAZp = getAZeroPoint(); - if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed()) - return failure(); - - FailureOr maybeBZp = getBZeroPoint(); - if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed()) - return failure(); - return success(); } @@ -1726,8 +1714,6 @@ ZERO_POINT_HELPER(TransposeConv2DOp, Input) ZERO_POINT_HELPER(TransposeConv2DOp, Weight) ZERO_POINT_HELPER(AvgPool2dOp, Input) ZERO_POINT_HELPER(AvgPool2dOp, Output) -ZERO_POINT_HELPER(MatMulOp, A) -ZERO_POINT_HELPER(MatMulOp, B) #undef ZERO_POINT_HELPER LogicalResult tosa::TransposeOp::inferReturnTypeComponents( diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index 983062ffd7912..345616c9563b5 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -178,15 +178,6 @@ void ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) { addValue(op.getOutput()); } -template <> -void ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) { - addValue(op.getA()); - addValue(op.getB()); - addValue(op.getAZp()); - addValue(op.getBZp()); - addValue(op.getOutput()); -} - LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { // This helper function only populates the info for the customised operands. #define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \ @@ -227,7 +218,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { POPULATE_PROFILE_INFO_CUSTOM(Resize) POPULATE_PROFILE_INFO_CUSTOM(Select) POPULATE_PROFILE_INFO_CUSTOM(Rescale) - POPULATE_PROFILE_INFO_CUSTOM(MatMul) // Type Invariant Extension, a capability extension that is independent // of the data type, meaning any compatible type can be used. No type @@ -245,6 +235,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { POPULATE_PROFILE_INFO_COMMON(Cast) POPULATE_PROFILE_INFO_COMMON(Const) POPULATE_PROFILE_INFO_COMMON(ArgMax) + POPULATE_PROFILE_INFO_COMMON(MatMul) POPULATE_PROFILE_INFO_COMMON(Sub) POPULATE_PROFILE_INFO_COMMON(Maximum) POPULATE_PROFILE_INFO_COMMON(Minimum) diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir index 341f773c79a5e..5bb4a3bddb51b 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir @@ -8,9 +8,7 @@ func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : f32) outs([[INIT]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32> // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x6xf32>) outs([[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32> - %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> - %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> - %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x3xf32>, tensor<1x3x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x6xf32> + %0 = tosa.matmul %arg0, %arg1 : (tensor<1x5x3xf32>, tensor<1x3x6xf32>) -> tensor<1x5x6xf32> return %0 : tensor<1x5x6xf32> } @@ -25,9 +23,7 @@ func.func @matmul_quantized(%arg0: tensor<1x5x3xi8>, %arg1: tensor<1x3x6xi8>) -> // CHECK: [[ONE:%.+]] = arith.constant 1 // CHECK: [[TWO:%.+]] = arith.constant 2 // CHECK: linalg.quantized_batch_matmul ins(%arg0, %arg1, [[ONE]], [[TWO]] : tensor<1x5x3xi8>, tensor<1x3x6xi8>, i32, i32) outs([[FILLED]] : tensor<1x5x6xi32>) -> tensor<1x5x6xi32> - %a_zp = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8> - %b_zp = "tosa.const"() <{values = dense<2> : tensor<1xi8>}> : () -> tensor<1xi8> - %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x3xi8>, tensor<1x3x6xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x5x6xi32> + %0 = tosa.matmul %arg0, %arg1 {a_zp = 1 : i32, b_zp = 2 : i32} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> tensor<1x5x6xi32> return %0 : tensor<1x5x6xi32> } @@ -41,9 +37,7 @@ func.func @matmul_dyn_batch(%arg0: tensor, %arg1: tensor) // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]]) // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0_0]] : f32) outs(%[[INIT]] : tensor) -> tensor // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor, tensor) outs(%[[FILLED]] : tensor) -> tensor - %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> - %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> - %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor, tensor, tensor<1xf32>, tensor<1xf32>) -> tensor + %0 = tosa.matmul %arg0, %arg1 : (tensor, tensor) -> tensor return %0 : tensor } @@ -57,9 +51,7 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]]) // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x5x?xf32>) -> tensor<1x5x?xf32> // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x?xf32>) outs(%[[FILLED]] : tensor<1x5x?xf32>) -> tensor<1x5x?xf32> - %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> - %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> - %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x3xf32>, tensor<1x3x?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x?xf32> + %0 = tosa.matmul %arg0, %arg1 : (tensor<1x5x3xf32>, tensor<1x3x?xf32>) -> tensor<1x5x?xf32> return %0 : tensor<1x5x?xf32> } @@ -71,9 +63,7 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x?xf32>, %arg1: tensor<1x // CHECK: %[[INIT:.+]] = tensor.empty() // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32> // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x?xf32>, tensor<1x?x6xf32>) outs(%[[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32> - %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> - %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> - %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x?xf32>, tensor<1x?x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x6xf32> + %0 = tosa.matmul %arg0, %arg1 : (tensor<1x5x?xf32>, tensor<1x?x6xf32>) -> tensor<1x5x6xf32> return %0 : tensor<1x5x6xf32> } @@ -87,9 +77,7 @@ func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>) // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor) -> tensor // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x1x8xf32>, tensor<1x8x1xf32>) outs(%[[FILLED]] : tensor) -> tensor - %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> - %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> - %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x1x8xf32>, tensor<1x8x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor + %0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor return %0 : tensor } diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir index 14c67e670e921..83136f613b020 100644 --- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir +++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir @@ -98,16 +98,14 @@ func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor } // CHECK-LABEL: func.func @matmul_on_def_shard_batch_and_m -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32> -func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> { +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32> +func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> { // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32> // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0]] : !mesh.sharding // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32> - // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK-NEXT: %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32> - // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]] - %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32> + // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]] + %0 = tosa.matmul %arg0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32> // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32> %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding @@ -117,16 +115,14 @@ func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: ten } // CHECK-LABEL: func.func @matmul_on_def_shard_m_and_k -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32> -func.func @matmul_on_def_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> { +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32> +func.func @matmul_on_def_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> { // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1], [0]] : !mesh.sharding // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32> // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32> - // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK-NEXT: %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32> - // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]] - %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32> + // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]] + %0 = tosa.matmul %arg0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32> // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] partial = sum [0] : !mesh.sharding // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32> %s1 = mesh.sharding @mesh_2d split_axes = [[], [1]] partial = sum [0] : !mesh.sharding @@ -136,18 +132,16 @@ func.func @matmul_on_def_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor< } // CHECK-LABEL: func.func @matmul_on_use_shard_m_and_k -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32> -func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> { +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32> +func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> { // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1], [0]] : !mesh.sharding // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32> %s0 = mesh.sharding @mesh_2d split_axes = [[], [1], [0]] : !mesh.sharding %0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<2x16x8xf32> // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32> - // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK-NEXT: %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32> - // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]] - %1 = tosa.matmul %0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32> + // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]] + %1 = tosa.matmul %0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32> // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] partial = sum [0] : !mesh.sharding // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32> // CHECK-NEXT: return %[[V3]] @@ -155,8 +149,8 @@ func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor< } // CHECK-LABEL: func.func @matmul_on_use_shard_m_and_duplicted_k -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32> -func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> { +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32> +func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> { // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1], [0]] : !mesh.sharding // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32> %s0 = mesh.sharding @mesh_2d split_axes = [[], [1], [0]] : !mesh.sharding @@ -165,10 +159,8 @@ func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32> %s1 = mesh.sharding @mesh_2d split_axes = [[], [0]] : !mesh.sharding %1 = mesh.shard %arg1 to %s1 annotate_for_users : tensor<2x8x32xf32> - // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK-NEXT: %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32> - // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]] - %2 = tosa.matmul %0, %1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32> + // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]] + %2 = tosa.matmul %0, %1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32> // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] partial = sum [0] : !mesh.sharding // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32> // CHECK-NEXT: return %[[V3]] @@ -207,16 +199,14 @@ func.func @resolve_conflicting_annotations( // https://arxiv.org/abs/2211.05102 Figure 2(a) // CHECK-LABEL: func.func @mlp_1d_weight_stationary -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>, %[[ARG3:.*]]: tensor<1xf32> -func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> { +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32> +func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> { %s0 = mesh.sharding @mesh_1d split_axes = [[], [], [0]] : !mesh.sharding %0 = mesh.shard %arg0 to %s0 : tensor<2x4x8xf32> // CHECK-DAG: %[[S1:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], [0]] : !mesh.sharding // CHECK-DAG: %[[S2:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], [0]] : !mesh.sharding - // CHECK-DAG: %[[S3:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK-DAG: %[[ZP:.*]] = mesh.shard %[[ARG3]] to %[[S3]] annotate_for_users : tensor<1xf32> // CHECK: %[[V0:.*]] = tosa.matmul - %1 = tosa.matmul %0, %arg1, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32> + %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32> // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[V0]] to %[[S2]] : tensor<2x4x32xf32> // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S2]] annotate_for_users : tensor<2x4x32xf32> // CHECK-DAG: %[[V3:.*]] = tosa.sigmoid %[[V2]] @@ -225,8 +215,8 @@ func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x // CHECK-NEXT: %[[V5:.*]] = mesh.shard %[[V4]] to %[[S2]] annotate_for_users : tensor<2x4x32xf32> // CHECK-DAG: %[[S6:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [0]] : !mesh.sharding // CHECK-NEXT: %[[V6:.*]] = mesh.shard %[[ARG2]] to %[[S6]] annotate_for_users : tensor<2x32x8xf32> - // CHECK-DAG: %[[V7:.*]] = tosa.matmul %[[V5]], %[[V6]], %[[ZP]], %[[ZP]] - %3 = tosa.matmul %2, %arg2, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32> + // CHECK-DAG: %[[V7:.*]] = tosa.matmul %[[V5]], %[[V6]] + %3 = tosa.matmul %2, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32> %s4 = mesh.sharding @mesh_1d split_axes = [[], [], []] partial = sum [0] : !mesh.sharding %4 = mesh.shard %3 to %s4 : tensor<2x4x8xf32> // CHECK: %[[S8:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], []] partial = sum [0] : !mesh.sharding @@ -240,8 +230,8 @@ func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x // https://arxiv.org/abs/2211.05102 Figure 2(b) // CHECK-LABEL: func.func @mlp_2d_weight_stationary -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>, %[[ARG3:.*]]: tensor<1xf32> -func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> { +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32> +func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> { // CHECK-DAG: %[[S0:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [0, 1, 2]] : !mesh.sharding // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] : tensor<2x4x8xf32> %s0 = mesh.sharding @mesh_3d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding @@ -250,10 +240,8 @@ func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[V0]] to %[[S1]] annotate_for_users : tensor<2x4x8xf32> // CHECK-DAG: %[[S2:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [0], [1, 2]] : !mesh.sharding // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[ARG1]] to %[[S2]] annotate_for_users : tensor<2x8x32xf32> - // CHECK-DAG: %[[S3:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}]] : !mesh.sharding - // CHECK-DAG: %[[ZP:.*]] = mesh.shard %[[ARG3]] to %[[S3]] annotate_for_users : tensor<1xf32> - // CHECK-DAG: %[[V3:.*]] = tosa.matmul %[[V1]], %[[V2]], %[[ZP]], %[[ZP]] - %1 = tosa.matmul %0, %arg1, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32> + // CHECK-DAG: %[[V3:.*]] = tosa.matmul %[[V1]], %[[V2]] + %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32> // CHECK-DAG: %[[S4:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [1, 2]] partial = sum [0] : !mesh.sharding // CHECK-NEXT: %[[V4:.*]] = mesh.shard %[[V3]] to %[[S4]] : tensor<2x4x32xf32> %s2 = mesh.sharding @mesh_3d split_axes = [[], [], [1, 2]] partial = sum [0] : !mesh.sharding @@ -266,8 +254,8 @@ func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x // CHECK-NEXT: %[[V8:.*]] = mesh.shard %[[V7]] to %[[S5]] annotate_for_users : tensor<2x4x32xf32> // CHECK-DAG: %[[S9:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [1, 2], [0]] : !mesh.sharding // CHECK-NEXT: %[[V9:.*]] = mesh.shard %[[ARG2]] to %[[S9]] annotate_for_users : tensor<2x32x8xf32> - // CHECK-DAG: %[[V10:.*]] = tosa.matmul %[[V8]], %[[V9]], %[[ZP]], %[[ZP]] - %4 = tosa.matmul %3, %arg2, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32> + // CHECK-DAG: %[[V10:.*]] = tosa.matmul %[[V8]], %[[V9]] + %4 = tosa.matmul %3, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32> // CHECK-DAG: %[[S11:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [0]] partial = sum [1, 2] : !mesh.sharding // CHECK-NEXT: %[[V11:.*]] = mesh.shard %[[V10]] to %[[S11]] : tensor<2x4x8xf32> %s5 = mesh.sharding @mesh_3d split_axes = [[], [], [0]] partial = sum[1, 2] : !mesh.sharding diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir index b786264d84106..1952ad79392c7 100644 --- a/mlir/test/Dialect/Tosa/availability.mlir +++ b/mlir/test/Dialect/Tosa/availability.mlir @@ -69,10 +69,10 @@ func.func @test_fft2d(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (te // ----- // CHECK-LABEL: matmul -func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>, %a_zp: tensor<1xf32>, %b_zp: tensor<1xf32>) -> tensor<1x14x28xf32> { +func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> { // CHECK: profiles: [ [pro_int, pro_fp] ] // CHECK: extensions: [ [int16, fp8e4m3, fp8e5m2, bf16] ] - %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32> + %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>) -> tensor<1x14x28xf32> return %0 : tensor<1x14x28xf32> } diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index f536444f6379e..05700ca3765e4 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -287,7 +287,7 @@ func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tenso // ----- func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { - %0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> + %0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> %pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32> // expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 6 (2*rank(shape1)) but got size 4}} %1 = tosa.pad %arg0, %0, %pad_const : (tensor<13x21x3xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<13x21x3xf32> @@ -1612,43 +1612,3 @@ func.func @test_rescale_invalid_non_perchannel_shift_shape(%arg0: tensor<13x21x3 %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>) -> tensor<13x21x3xi16> return %0 : tensor<13x21x3xi16> } - -// ----- -// CHECK-LABEL: test_matmul_a_zp_same_element_type -func.func @test_matmul_a_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> { -%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16> -%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> -// expected-error@+1 {{'tosa.matmul' op expect input a and a_zp have the same element type, got 'f32' and 'f16'}} -%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf16>, tensor<1xf32>) -> tensor<1x14x28xf32> - return %0 : tensor<1x14x28xf32> -} - -// ----- -// CHECK-LABEL: test_matmul_b_zp_same_element_type -func.func @test_matmul_b_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> { -%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> -%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16> -// expected-error@+1 {{'tosa.matmul' op expect input b and b_zp have the same element type, got 'f32' and 'f16'}} -%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf16>) -> tensor<1x14x28xf32> - return %0 : tensor<1x14x28xf32> -} - -// ----- -// CHECK-LABEL: test_matmul_a_zp_non_zero -func.func @test_matmul_a_zp_non_zero(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> { -%azp0 = "tosa.const"() <{values = dense<1.0> : tensor<1xf32>}> : () -> tensor<1xf32> -%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> -// expected-error@+1 {{'tosa.matmul' op a zero point must be zero for non-int8 integer types}} -%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32> - return %0 : tensor<1x14x28xf32> -} - -// ----- -// CHECK-LABEL: test_matmul_b_zp_non_zero -func.func @test_matmul_b_zp_non_zero(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> { -%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> -%bzp0 = "tosa.const"() <{values = dense<-1.0> : tensor<1xf32>}> : () -> tensor<1xf32> -// expected-error@+1 {{'tosa.matmul' op b zero point must be zero for non-int8 integer types}} -%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32> - return %0 : tensor<1x14x28xf32> -} diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 6d8237635d0ec..bc13b614e3f9d 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -1110,9 +1110,8 @@ func.func @test_rfft2d_tensor_size_invalid(%arg0: tensor<536870912x8x16xf32>) -> // ----- func.func @test_matmul_tensor_size_invalid(%arg0: tensor<23178x20000x19xf32>, %arg1: tensor<23178x19x28xf32>) -> tensor<23178x20000x28xf32> { - %zero = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> // expected-error@+1 {{'tosa.matmul' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}} - %0 = tosa.matmul %arg0, %arg1, %zero, %zero : (tensor<23178x20000x19xf32>, tensor<23178x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<23178x20000x28xf32> + %0 = tosa.matmul %arg0, %arg1 : (tensor<23178x20000x19xf32>, tensor<23178x19x28xf32>) -> tensor<23178x20000x28xf32> return %0 : tensor<23178x20000x28xf32> } diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index bebc1a8b748be..96eacc9fb6093 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -145,9 +145,7 @@ func.func @test_fft2d_with_local_bound(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1 // ----- // CHECK-LABEL: test_matmul func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> { -%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> -%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> -%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32> + %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>) -> tensor<1x14x28xf32> return %0 : tensor<1x14x28xf32> } diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir index d0e97e46f1f6a..342c57b0dd85c 100644 --- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir @@ -26,9 +26,9 @@ func.func @test_avg_pool2d(%arg0: tensor<1x7x7x9xf32>, %arg1: tensor<1xf32>, %ar } // ----- -func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>, %arg2: tensor<1xf32>) -> tensor<1x14x28xf32> { +func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> { // expected-error@+1 {{'tosa.matmul' op illegal: requires [pro_fp] but not enabled in target}} - %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2: (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32> + %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>) -> tensor<1x14x28xf32> return %0 : tensor<1x14x28xf32> } diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir index 28c7abdeaf7f7..3dd0344e3647d 100644 --- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir @@ -19,9 +19,9 @@ func.func @test_avg_pool2d(%arg0: tensor<1x7x7x9xf32>, %arg1: tensor<1xf32>, %ar } // ----- -func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>, %arg2: tensor<1xf32>) -> tensor<1x14x28xf32> { +func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> { // expected-error@+1 {{'tosa.matmul' op illegal: requires [pro_fp] but not enabled in target}} - %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32> + %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>) -> tensor<1x14x28xf32> return %0 : tensor<1x14x28xf32> } diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index deede4b0afadc..55c5c3f6bdfb6 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -279,10 +279,8 @@ func.func @test_dynamic_argmax(%arg0 : tensor<2x?xi32>) -> () { // CHECK-LABEL: @test_static_matmul func.func @test_static_matmul(%arg0 : tensor<2x3x4xi32>, %arg1 : tensor<2x4x5xi32>) -> () { - // CHECK tosa.matmul %arg0, %arg1, %0, %1 : (tensor<2x3x4xi32>, tensor<2x4x5xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2x3x5xi32> - %0 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> - %1 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> - %2 = tosa.matmul %arg0, %arg1, %0, %1 : (tensor<2x3x4xi32>, tensor<2x4x5xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + // CHECK: tosa.matmul %arg0, %arg1 : (tensor<2x3x4xi32>, tensor<2x4x5xi32>) -> tensor<2x3x5xi32> + %0 = tosa.matmul %arg0, %arg1 : (tensor<2x3x4xi32>, tensor<2x4x5xi32>) -> tensor return } @@ -291,10 +289,8 @@ func.func @test_static_matmul(%arg0 : tensor<2x3x4xi32>, %arg1 : tensor<2x4x5xi3 // CHECK-LABEL: @test_dynamic_lhs_matmul func.func @test_dynamic_lhs_matmul(%arg0 : tensor, %arg1 : tensor<2x4x5xi32>) -> () { - // CHECK: tosa.matmul %arg0, %arg1, %0, %1 : (tensor, tensor<2x4x5xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2x?x5xi32> - %0 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> - %1 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> - %2 = tosa.matmul %arg0, %arg1, %0, %1 : (tensor, tensor<2x4x5xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + // CHECK: tosa.matmul %arg0, %arg1 : (tensor, tensor<2x4x5xi32>) -> tensor<2x?x5xi32> + %0 = tosa.matmul %arg0, %arg1 : (tensor, tensor<2x4x5xi32>) -> tensor return } @@ -303,10 +299,8 @@ func.func @test_dynamic_lhs_matmul(%arg0 : tensor, %arg1 : tensor<2x4 // CHECK-LABEL: @test_dynamic_rhs_matmul func.func @test_dynamic_rhs_matmul(%arg0 : tensor<2x3x4xi32>, %arg1 : tensor) -> () { - // CHECK: tosa.matmul %arg0, %arg1, %0, %1 : (tensor<2x3x4xi32>, tensor, tensor<1xi32>, tensor<1xi32>) -> tensor<2x3x?xi32> - %0 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> - %1 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> - %2 = tosa.matmul %arg0, %arg1, %0, %1 : (tensor<2x3x4xi32>, tensor, tensor<1xi32>, tensor<1xi32>) -> tensor + // CHECK: tosa.matmul %arg0, %arg1 : (tensor<2x3x4xi32>, tensor) -> tensor<2x3x?xi32> + %0 = tosa.matmul %arg0, %arg1 : (tensor<2x3x4xi32>, tensor) -> tensor return } @@ -315,10 +309,8 @@ func.func @test_dynamic_rhs_matmul(%arg0 : tensor<2x3x4xi32>, %arg1 : tensor, %arg1 : tensor) -> () { - // CHECK: tosa.matmul %arg0, %arg1, %0, %1 : (tensor, tensor, tensor<1xi32>, tensor<1xi32>) -> tensor - %0 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> - %1 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> - %2 = tosa.matmul %arg0, %arg1, %0, %1 : (tensor, tensor, tensor<1xi32>, tensor<1xi32>) -> tensor + // CHECK: tosa.matmul %arg0, %arg1 : (tensor, tensor) -> tensor + %0 = tosa.matmul %arg0, %arg1 : (tensor, tensor) -> tensor return } @@ -1413,13 +1405,11 @@ func.func @test_non_tosa_consumer_extract(%arg0: tensor<4x4xf32>, %arg1: index) // CHECK-LABEL: test_non_tosa_consumer_still_propagates func.func @test_non_tosa_consumer_still_propagates(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>) -> tensor { - // CHECK: tosa.matmul %arg0, %arg1, %0, %1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x1xf32> - %0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> - %1 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> - %2 = tosa.matmul %arg0, %arg1, %0, %1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor - %3 = arith.constant dense<[1, 1]> : tensor<2xindex> - %4 = tensor.reshape %2(%3) : (tensor, tensor<2xindex>) -> tensor - return %4 : tensor + // CHECK: tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<1x1x1xf32> + %0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor + %1 = arith.constant dense<[1, 1]> : tensor<2xindex> + %2 = tensor.reshape %0(%1) : (tensor, tensor<2xindex>) -> tensor + return %2 : tensor } // -----