Skip to content

Conversation

ShivaChen
Copy link
Collaborator

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Sep 16, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: None (ShivaChen)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/158782.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+45-35)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+31)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 1955eec9964eb..91e0f235349f0 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -186,56 +186,61 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   if (isa<tosa::NegateOp>(op)) {
     auto negate = cast<tosa::NegateOp>(op);
 
+    int64_t inZp = 0, outZp = 0;
     FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
-    if (failed(maybeInZp)) {
-      (void)rewriter.notifyMatchFailure(
-          op, "input1 zero point cannot be statically determined");
-      return nullptr;
-    }
-
     FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
-    if (failed(maybeOutZp)) {
-      (void)rewriter.notifyMatchFailure(
-          op, "output zero point cannot be statically determined");
-      return nullptr;
-    }
-
-    int64_t inZp = *maybeInZp;
-    int64_t outZp = *maybeOutZp;
+    if (!failed(maybeInZp))
+      inZp = *maybeInZp;
+    if (!failed(maybeOutZp))
+      outZp = *maybeOutZp;
 
     if (isa<FloatType>(elementTy))
       return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
 
     if (isa<IntegerType>(elementTy)) {
-      if (!inZp && !outZp) {
+      if (!failed(maybeInZp) && !failed(maybeOutZp) && !inZp && !outZp) {
         auto constant = arith::ConstantOp::create(
             rewriter, loc, IntegerAttr::get(elementTy, 0));
         return arith::SubIOp::create(rewriter, loc, resultTypes, constant,
                                      args[0]);
       }
 
+      Value zpAddValue;
+      Type intermediateType;
       // Compute the maximum value that can occur in the intermediate buffer.
       const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
-      const int64_t zpAdd = inZp + outZp;
-      const int64_t maxValue =
-          APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
-          std::abs(zpAdd) + 1;
-
-      // Convert that maximum value into the maximum bitwidth needed to
-      // represent it. We assume 48-bit numbers may be supported further in
-      // the pipeline.
       int intermediateBitWidth = 64;
-      if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
-        intermediateBitWidth = 16;
-      } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
-        intermediateBitWidth = 32;
-      } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
-        intermediateBitWidth = 48;
-      }
 
-      Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
-      Value zpAddValue = arith::ConstantOp::create(
-          rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+      if (!failed(maybeInZp) && !failed(maybeOutZp)) {
+        // Compute the maximum value that can occur in the intermediate buffer.
+        const int64_t zpAdd = inZp + outZp;
+        const int64_t maxValue =
+            APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
+            std::abs(zpAdd) + 1;
+
+        // Convert that maximum value into the maximum bitwidth needed to
+        // represent it. We assume 48-bit numbers may be supported further in
+        // the pipeline.
+        if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
+          intermediateBitWidth = 16;
+        } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
+          intermediateBitWidth = 32;
+        } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
+          intermediateBitWidth = 48;
+        }
+
+        intermediateType = rewriter.getIntegerType(intermediateBitWidth);
+        zpAddValue = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+      } else {
+        intermediateType = rewriter.getIntegerType(intermediateBitWidth);
+        auto arg1 =
+            rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[1]);
+        auto arg2 =
+            rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[2]);
+        zpAddValue =
+            rewriter.create<arith::AddIOp>(loc, intermediateType, arg1, arg2);
+      }
 
       // The negation can be applied by doing:
       //  outputValue = inZp + outZp - inputValue
@@ -1013,9 +1018,14 @@ static ValueRange getBroadcastableOperands(Operation *operation,
     else
       return operands.take_front(3);
   }
-  // Input1_zp and output_zp cannot broadcast
-  if (isa<tosa::NegateOp>(operation))
+  if (auto negate = dyn_cast<tosa::NegateOp>(operation)) {
+    FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
+    FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
+    if (failed(maybeOutZp) && failed(maybeInZp))
+      return operands;
+    // Input1_zp and output_zp cannot broadcast when they are constants.
     return operands.take_front(1);
+  }
   return operands;
 }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 37af8b8859852..780344764e014 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -899,6 +899,37 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
 
 // -----
 
+func.func @negate_no_const_1(%arg0: tensor<50x42xf16> ,%arg1: tensor<1xf16> , %arg2: tensor<1xf16> ) -> tensor<*xf16> {
+  // CHECK: %[[GENERIC:.+]] = linalg.generic 
+  // CHECK:   ^bb0([[ARG0:%.*]]: f16, [[ARG1:%.*]]: f16, [[ARG2:%.*]]: f16, [[OUT:%.*]]: f16)
+  // CHECK:   [[ELEMENT:%.*]] = arith.negf [[ARG0]] : f16
+  %0 = tosa.negate %arg0, %arg1, %arg2 : (tensor<50x42xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<50x42xf16>
+  %cast = tensor.cast %0 : tensor<50x42xf16> to tensor<*xf16>
+  return %cast : tensor<*xf16>
+}
+
+// -----
+
+func.func @negate_no_const_2(%arg0: tensor<50x42xi16> ,%arg1: tensor<1xi16> , %arg2: tensor<1xi16> ) -> tensor<*xi16> {
+  // CHECK: %[[GENERIC:.+]] = linalg.generic 
+  // CHECK:   ^bb0([[ARG0:%.*]]: i16, [[ARG1:%.*]]: i16, [[ARG2:%.*]]: i16, [[OUT:%.*]]: i16)
+  // CHECK:   [[EXTSI1:%.*]] = arith.extsi [[ARG1]] : i16 to i64
+  // CHECK:   [[EXTSI2:%.*]] = arith.extsi [[ARG2]] : i16 to i64
+  // CHECK:   [[SUM:%.*]] = arith.addi [[EXTSI1]], [[EXTSI2]] : i64
+  // CHECK:   [[EXTSI0:%.*]] = arith.extsi [[ARG0]] : i16 to i64
+  // CHECK:   [[SUB:%.*]] = arith.subi [[SUM]], [[EXTSI0]] : i64
+  // CHECK:   [[C_32768:%.*]] = arith.constant -32768 : i64
+  // CHECK:   [[C32767:%.*]] = arith.constant 32767 : i64
+  // CHECK:   [[MAX:%.*]] = arith.maxsi [[C_32768]], [[SUB]] : i64
+  // CHECK:   [[MIN:%.*]] = arith.minsi [[C32767]], [[MAX]] : i64
+  // CHECK:   [[TRUNC:%.*]] = arith.trunci [[MIN]] : i64 to i16
+  %0 = tosa.negate %arg0, %arg1, %arg2 : (tensor<50x42xi16>, tensor<1xi16>, tensor<1xi16>) -> tensor<50x42xi16>
+  %cast = tensor.cast %0 : tensor<50x42xi16> to tensor<*xi16>
+  return %cast : tensor<*xi16>
+}
+
+// -----
+
 // CHECK-LABEL: @test_identity
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: tensor<1xf32>,
 // CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<1xi32>

Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @ShivaChen, the lowering looks reasonable but I haven't had chance to test it. Had a couple of minor nitpicks, otherwise LGTM!

@ShivaChen
Copy link
Collaborator Author

Very appreciate your review in busy days.

Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @ShivaChen!

@lhutton1 lhutton1 merged commit bd67b8f into llvm:main Sep 22, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants