From 60372437bdc11b6c595eb1cefe7972cf54011a41 Mon Sep 17 00:00:00 2001 From: Shiva Chen Date: Fri, 29 Aug 2025 10:08:07 +0100 Subject: [PATCH 1/3] [mlir][tosa] support NegateOp with dynamic extension in TosaToLinalg --- .../Conversion/TosaToLinalg/TosaToLinalg.cpp | 80 +++++++++++-------- .../TosaToLinalg/tosa-to-linalg.mlir | 31 +++++++ 2 files changed, 76 insertions(+), 35 deletions(-) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 1955eec9964eb..91e0f235349f0 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -186,56 +186,61 @@ static Value createLinalgBodyCalculationForElementwiseOp( if (isa(op)) { auto negate = cast(op); + int64_t inZp = 0, outZp = 0; FailureOr maybeInZp = negate.getInput1ZeroPoint(); - if (failed(maybeInZp)) { - (void)rewriter.notifyMatchFailure( - op, "input1 zero point cannot be statically determined"); - return nullptr; - } - FailureOr maybeOutZp = negate.getOutputZeroPoint(); - if (failed(maybeOutZp)) { - (void)rewriter.notifyMatchFailure( - op, "output zero point cannot be statically determined"); - return nullptr; - } - - int64_t inZp = *maybeInZp; - int64_t outZp = *maybeOutZp; + if (!failed(maybeInZp)) + inZp = *maybeInZp; + if (!failed(maybeOutZp)) + outZp = *maybeOutZp; if (isa(elementTy)) return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]); if (isa(elementTy)) { - if (!inZp && !outZp) { + if (!failed(maybeInZp) && !failed(maybeOutZp) && !inZp && !outZp) { auto constant = arith::ConstantOp::create( rewriter, loc, IntegerAttr::get(elementTy, 0)); return arith::SubIOp::create(rewriter, loc, resultTypes, constant, args[0]); } + Value zpAddValue; + Type intermediateType; // Compute the maximum value that can occur in the intermediate buffer. const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth(); - const int64_t zpAdd = inZp + outZp; - const int64_t maxValue = - APInt::getSignedMaxValue(inputBitWidth).getSExtValue() + - std::abs(zpAdd) + 1; - - // Convert that maximum value into the maximum bitwidth needed to - // represent it. We assume 48-bit numbers may be supported further in - // the pipeline. int intermediateBitWidth = 64; - if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) { - intermediateBitWidth = 16; - } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) { - intermediateBitWidth = 32; - } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) { - intermediateBitWidth = 48; - } - Type intermediateType = rewriter.getIntegerType(intermediateBitWidth); - Value zpAddValue = arith::ConstantOp::create( - rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd)); + if (!failed(maybeInZp) && !failed(maybeOutZp)) { + // Compute the maximum value that can occur in the intermediate buffer. + const int64_t zpAdd = inZp + outZp; + const int64_t maxValue = + APInt::getSignedMaxValue(inputBitWidth).getSExtValue() + + std::abs(zpAdd) + 1; + + // Convert that maximum value into the maximum bitwidth needed to + // represent it. We assume 48-bit numbers may be supported further in + // the pipeline. + if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) { + intermediateBitWidth = 16; + } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) { + intermediateBitWidth = 32; + } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) { + intermediateBitWidth = 48; + } + + intermediateType = rewriter.getIntegerType(intermediateBitWidth); + zpAddValue = rewriter.create( + loc, rewriter.getIntegerAttr(intermediateType, zpAdd)); + } else { + intermediateType = rewriter.getIntegerType(intermediateBitWidth); + auto arg1 = + rewriter.create(loc, intermediateType, args[1]); + auto arg2 = + rewriter.create(loc, intermediateType, args[2]); + zpAddValue = + rewriter.create(loc, intermediateType, arg1, arg2); + } // The negation can be applied by doing: // outputValue = inZp + outZp - inputValue @@ -1013,9 +1018,14 @@ static ValueRange getBroadcastableOperands(Operation *operation, else return operands.take_front(3); } - // Input1_zp and output_zp cannot broadcast - if (isa(operation)) + if (auto negate = dyn_cast(operation)) { + FailureOr maybeInZp = negate.getInput1ZeroPoint(); + FailureOr maybeOutZp = negate.getOutputZeroPoint(); + if (failed(maybeOutZp) && failed(maybeInZp)) + return operands; + // Input1_zp and output_zp cannot broadcast when they are constants. return operands.take_front(1); + } return operands; } diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 37af8b8859852..780344764e014 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -899,6 +899,37 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () { // ----- +func.func @negate_no_const_1(%arg0: tensor<50x42xf16> ,%arg1: tensor<1xf16> , %arg2: tensor<1xf16> ) -> tensor<*xf16> { + // CHECK: %[[GENERIC:.+]] = linalg.generic + // CHECK: ^bb0([[ARG0:%.*]]: f16, [[ARG1:%.*]]: f16, [[ARG2:%.*]]: f16, [[OUT:%.*]]: f16) + // CHECK: [[ELEMENT:%.*]] = arith.negf [[ARG0]] : f16 + %0 = tosa.negate %arg0, %arg1, %arg2 : (tensor<50x42xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<50x42xf16> + %cast = tensor.cast %0 : tensor<50x42xf16> to tensor<*xf16> + return %cast : tensor<*xf16> +} + +// ----- + +func.func @negate_no_const_2(%arg0: tensor<50x42xi16> ,%arg1: tensor<1xi16> , %arg2: tensor<1xi16> ) -> tensor<*xi16> { + // CHECK: %[[GENERIC:.+]] = linalg.generic + // CHECK: ^bb0([[ARG0:%.*]]: i16, [[ARG1:%.*]]: i16, [[ARG2:%.*]]: i16, [[OUT:%.*]]: i16) + // CHECK: [[EXTSI1:%.*]] = arith.extsi [[ARG1]] : i16 to i64 + // CHECK: [[EXTSI2:%.*]] = arith.extsi [[ARG2]] : i16 to i64 + // CHECK: [[SUM:%.*]] = arith.addi [[EXTSI1]], [[EXTSI2]] : i64 + // CHECK: [[EXTSI0:%.*]] = arith.extsi [[ARG0]] : i16 to i64 + // CHECK: [[SUB:%.*]] = arith.subi [[SUM]], [[EXTSI0]] : i64 + // CHECK: [[C_32768:%.*]] = arith.constant -32768 : i64 + // CHECK: [[C32767:%.*]] = arith.constant 32767 : i64 + // CHECK: [[MAX:%.*]] = arith.maxsi [[C_32768]], [[SUB]] : i64 + // CHECK: [[MIN:%.*]] = arith.minsi [[C32767]], [[MAX]] : i64 + // CHECK: [[TRUNC:%.*]] = arith.trunci [[MIN]] : i64 to i16 + %0 = tosa.negate %arg0, %arg1, %arg2 : (tensor<50x42xi16>, tensor<1xi16>, tensor<1xi16>) -> tensor<50x42xi16> + %cast = tensor.cast %0 : tensor<50x42xi16> to tensor<*xi16> + return %cast : tensor<*xi16> +} + +// ----- + // CHECK-LABEL: @test_identity // CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: tensor<1xf32>, // CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<1xi32> From 4bbd46918f36a6dfbe112999959697d5454fef52 Mon Sep 17 00:00:00 2001 From: Shiva Chen Date: Mon, 22 Sep 2025 01:56:05 +0100 Subject: [PATCH 2/3] Define hasInZp and hasOutZp --- mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 91e0f235349f0..e3602111cb1dd 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -189,16 +189,18 @@ static Value createLinalgBodyCalculationForElementwiseOp( int64_t inZp = 0, outZp = 0; FailureOr maybeInZp = negate.getInput1ZeroPoint(); FailureOr maybeOutZp = negate.getOutputZeroPoint(); - if (!failed(maybeInZp)) + bool hasInZp = !failed(maybeInZp); + bool hasOutZp = !failed(maybeOutZp); + if (hasInZp) inZp = *maybeInZp; - if (!failed(maybeOutZp)) + if (hasOutZp) outZp = *maybeOutZp; if (isa(elementTy)) return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]); if (isa(elementTy)) { - if (!failed(maybeInZp) && !failed(maybeOutZp) && !inZp && !outZp) { + if (hasInZp && hasOutZp && !inZp && !outZp) { auto constant = arith::ConstantOp::create( rewriter, loc, IntegerAttr::get(elementTy, 0)); return arith::SubIOp::create(rewriter, loc, resultTypes, constant, @@ -211,7 +213,7 @@ static Value createLinalgBodyCalculationForElementwiseOp( const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth(); int intermediateBitWidth = 64; - if (!failed(maybeInZp) && !failed(maybeOutZp)) { + if (hasInZp && hasOutZp) { // Compute the maximum value that can occur in the intermediate buffer. const int64_t zpAdd = inZp + outZp; const int64_t maxValue = From b62b227a0ec6a7c284215ef97dcb4bacb969df84 Mon Sep 17 00:00:00 2001 From: Shiva Chen Date: Mon, 22 Sep 2025 02:03:13 +0100 Subject: [PATCH 3/3] Add CHECK-LABEL in test cases --- mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 780344764e014..2163dbb0d4561 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -899,7 +899,8 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () { // ----- -func.func @negate_no_const_1(%arg0: tensor<50x42xf16> ,%arg1: tensor<1xf16> , %arg2: tensor<1xf16> ) -> tensor<*xf16> { +// CHECK-LABEL: @test_negate_no_const_1 +func.func @test_negate_no_const_1(%arg0: tensor<50x42xf16> ,%arg1: tensor<1xf16> , %arg2: tensor<1xf16> ) -> tensor<*xf16> { // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK: ^bb0([[ARG0:%.*]]: f16, [[ARG1:%.*]]: f16, [[ARG2:%.*]]: f16, [[OUT:%.*]]: f16) // CHECK: [[ELEMENT:%.*]] = arith.negf [[ARG0]] : f16 @@ -910,7 +911,8 @@ func.func @negate_no_const_1(%arg0: tensor<50x42xf16> ,%arg1: tensor<1xf16> , %a // ----- -func.func @negate_no_const_2(%arg0: tensor<50x42xi16> ,%arg1: tensor<1xi16> , %arg2: tensor<1xi16> ) -> tensor<*xi16> { +// CHECK-LABEL: @test_negate_no_const_2 +func.func @test_negate_no_const_2(%arg0: tensor<50x42xi16> ,%arg1: tensor<1xi16> , %arg2: tensor<1xi16> ) -> tensor<*xi16> { // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK: ^bb0([[ARG0:%.*]]: i16, [[ARG1:%.*]]: i16, [[ARG2:%.*]]: i16, [[OUT:%.*]]: i16) // CHECK: [[EXTSI1:%.*]] = arith.extsi [[ARG1]] : i16 to i64