-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][tosa] support NegateOp with dynamic extension in TosaToLinalg #158782
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: None (ShivaChen) Changes Full diff: https://github.com/llvm/llvm-project/pull/158782.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 1955eec9964eb..91e0f235349f0 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -186,56 +186,61 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::NegateOp>(op)) {
auto negate = cast<tosa::NegateOp>(op);
+ int64_t inZp = 0, outZp = 0;
FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
- if (failed(maybeInZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "input1 zero point cannot be statically determined");
- return nullptr;
- }
-
FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
- if (failed(maybeOutZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "output zero point cannot be statically determined");
- return nullptr;
- }
-
- int64_t inZp = *maybeInZp;
- int64_t outZp = *maybeOutZp;
+ if (!failed(maybeInZp))
+ inZp = *maybeInZp;
+ if (!failed(maybeOutZp))
+ outZp = *maybeOutZp;
if (isa<FloatType>(elementTy))
return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
if (isa<IntegerType>(elementTy)) {
- if (!inZp && !outZp) {
+ if (!failed(maybeInZp) && !failed(maybeOutZp) && !inZp && !outZp) {
auto constant = arith::ConstantOp::create(
rewriter, loc, IntegerAttr::get(elementTy, 0));
return arith::SubIOp::create(rewriter, loc, resultTypes, constant,
args[0]);
}
+ Value zpAddValue;
+ Type intermediateType;
// Compute the maximum value that can occur in the intermediate buffer.
const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
- const int64_t zpAdd = inZp + outZp;
- const int64_t maxValue =
- APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
- std::abs(zpAdd) + 1;
-
- // Convert that maximum value into the maximum bitwidth needed to
- // represent it. We assume 48-bit numbers may be supported further in
- // the pipeline.
int intermediateBitWidth = 64;
- if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
- intermediateBitWidth = 16;
- } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
- intermediateBitWidth = 32;
- } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
- intermediateBitWidth = 48;
- }
- Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
- Value zpAddValue = arith::ConstantOp::create(
- rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+ if (!failed(maybeInZp) && !failed(maybeOutZp)) {
+ // Compute the maximum value that can occur in the intermediate buffer.
+ const int64_t zpAdd = inZp + outZp;
+ const int64_t maxValue =
+ APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
+ std::abs(zpAdd) + 1;
+
+ // Convert that maximum value into the maximum bitwidth needed to
+ // represent it. We assume 48-bit numbers may be supported further in
+ // the pipeline.
+ if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
+ intermediateBitWidth = 16;
+ } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
+ intermediateBitWidth = 32;
+ } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
+ intermediateBitWidth = 48;
+ }
+
+ intermediateType = rewriter.getIntegerType(intermediateBitWidth);
+ zpAddValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+ } else {
+ intermediateType = rewriter.getIntegerType(intermediateBitWidth);
+ auto arg1 =
+ rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[1]);
+ auto arg2 =
+ rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[2]);
+ zpAddValue =
+ rewriter.create<arith::AddIOp>(loc, intermediateType, arg1, arg2);
+ }
// The negation can be applied by doing:
// outputValue = inZp + outZp - inputValue
@@ -1013,9 +1018,14 @@ static ValueRange getBroadcastableOperands(Operation *operation,
else
return operands.take_front(3);
}
- // Input1_zp and output_zp cannot broadcast
- if (isa<tosa::NegateOp>(operation))
+ if (auto negate = dyn_cast<tosa::NegateOp>(operation)) {
+ FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
+ FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
+ if (failed(maybeOutZp) && failed(maybeInZp))
+ return operands;
+ // Input1_zp and output_zp cannot broadcast when they are constants.
return operands.take_front(1);
+ }
return operands;
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 37af8b8859852..780344764e014 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -899,6 +899,37 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
// -----
+func.func @negate_no_const_1(%arg0: tensor<50x42xf16> ,%arg1: tensor<1xf16> , %arg2: tensor<1xf16> ) -> tensor<*xf16> {
+ // CHECK: %[[GENERIC:.+]] = linalg.generic
+ // CHECK: ^bb0([[ARG0:%.*]]: f16, [[ARG1:%.*]]: f16, [[ARG2:%.*]]: f16, [[OUT:%.*]]: f16)
+ // CHECK: [[ELEMENT:%.*]] = arith.negf [[ARG0]] : f16
+ %0 = tosa.negate %arg0, %arg1, %arg2 : (tensor<50x42xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<50x42xf16>
+ %cast = tensor.cast %0 : tensor<50x42xf16> to tensor<*xf16>
+ return %cast : tensor<*xf16>
+}
+
+// -----
+
+func.func @negate_no_const_2(%arg0: tensor<50x42xi16> ,%arg1: tensor<1xi16> , %arg2: tensor<1xi16> ) -> tensor<*xi16> {
+ // CHECK: %[[GENERIC:.+]] = linalg.generic
+ // CHECK: ^bb0([[ARG0:%.*]]: i16, [[ARG1:%.*]]: i16, [[ARG2:%.*]]: i16, [[OUT:%.*]]: i16)
+ // CHECK: [[EXTSI1:%.*]] = arith.extsi [[ARG1]] : i16 to i64
+ // CHECK: [[EXTSI2:%.*]] = arith.extsi [[ARG2]] : i16 to i64
+ // CHECK: [[SUM:%.*]] = arith.addi [[EXTSI1]], [[EXTSI2]] : i64
+ // CHECK: [[EXTSI0:%.*]] = arith.extsi [[ARG0]] : i16 to i64
+ // CHECK: [[SUB:%.*]] = arith.subi [[SUM]], [[EXTSI0]] : i64
+ // CHECK: [[C_32768:%.*]] = arith.constant -32768 : i64
+ // CHECK: [[C32767:%.*]] = arith.constant 32767 : i64
+ // CHECK: [[MAX:%.*]] = arith.maxsi [[C_32768]], [[SUB]] : i64
+ // CHECK: [[MIN:%.*]] = arith.minsi [[C32767]], [[MAX]] : i64
+ // CHECK: [[TRUNC:%.*]] = arith.trunci [[MIN]] : i64 to i16
+ %0 = tosa.negate %arg0, %arg1, %arg2 : (tensor<50x42xi16>, tensor<1xi16>, tensor<1xi16>) -> tensor<50x42xi16>
+ %cast = tensor.cast %0 : tensor<50x42xi16> to tensor<*xi16>
+ return %cast : tensor<*xi16>
+}
+
+// -----
+
// CHECK-LABEL: @test_identity
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: tensor<1xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<1xi32>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @ShivaChen, the lowering looks reasonable, but I haven't had a chance to test it. I had a couple of minor nitpicks; otherwise LGTM!
d96f5f3
to
b5dff75
Compare
I really appreciate your review during these busy days. |
8b672d7
to
b62b227
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks @ShivaChen!
No description provided.