Skip to content

Conversation

HerrCai0907
Copy link
Contributor

@HerrCai0907 HerrCai0907 commented Sep 20, 2025

Fixed: #154259

According to TOSA spec, tosa.arithmetic_right_shift should handle round.

if (round == true && static_cast<int32_t>(value2) > 0 &&
    (apply_arith_rshift<in_out_t>(value1, apply_sub_s<in_out_t>(value2, 1)) & 1 != 0)) {
    result = result + 1;
}

The original conversion is the similar as definition, and will convert to pseudo code

result = (value1 >> value2) +
              ( (i1)(value2 > 0) & (i1)((value1 >> (value2 - 1)) & 1) )

But when value2 is 0,value1 >> (value2 - 1) will produce poison value because performing arithmetic right shift on a negative number. Then the poison value propagate to the final result.

This PR wants to change the conversion to arith.select to stop poison propagation.

result = (value1 >> value2) + 
             (value2 > 0) ? (i1)((value1 >> (value2 - 1)) & 1) : (i1)(0)

@llvmbot
Copy link
Member

llvmbot commented Sep 20, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Congcong Cai (HerrCai0907)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/159930.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+5-3)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+2-1)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 1955eec9964eb..1e0aefded19c1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -298,6 +298,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
                                          IntegerAttr::get(elementTy, 1));
     auto zero = arith::ConstantOp::create(rewriter, loc,
                                           IntegerAttr::get(elementTy, 0));
+    auto i1zero =
+        arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 0));
     auto i1one =
         arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1));
 
@@ -315,9 +317,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
                                              ArrayRef<NamedAttribute>());
     auto isInputOdd =
         arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one);
-
-    auto shouldRound = arith::AndIOp::create(
-        rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
+    // shifted, truncated, isInputOdd can be poison when input2 is 0.
+    auto shouldRound = arith::SelectOp::create(
+        rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd, i1zero);
     auto extended =
         arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound);
     return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended);
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 37af8b8859852..e25b1565f39ee 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -698,13 +698,14 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
   // CHECK: linalg.generic
   // CHECK: arith.constant 1
   // CHECK: arith.constant 0
+  // CHECK: arith.constant false
   // CHECK: arith.constant true
   // CHECK: arith.cmpi
   // CHECK: arith.subi
   // CHECK: arith.shrsi
   // CHECK: arith.trunci
   // CHECK: and
-  // CHECK: and
+  // CHECK: arith.select
   // CHECK: arith.extui
   // CHECK: arith.addi
   %12 = tosa.arithmetic_right_shift %arg0, %arg0 {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

@llvmbot
Copy link
Member

llvmbot commented Sep 20, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Congcong Cai (HerrCai0907)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/159930.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+5-3)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+2-1)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 1955eec9964eb..1e0aefded19c1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -298,6 +298,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
                                          IntegerAttr::get(elementTy, 1));
     auto zero = arith::ConstantOp::create(rewriter, loc,
                                           IntegerAttr::get(elementTy, 0));
+    auto i1zero =
+        arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 0));
     auto i1one =
         arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1));
 
@@ -315,9 +317,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
                                              ArrayRef<NamedAttribute>());
     auto isInputOdd =
         arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one);
-
-    auto shouldRound = arith::AndIOp::create(
-        rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
+    // shifted, truncated, isInputOdd can be poison when input2 is 0.
+    auto shouldRound = arith::SelectOp::create(
+        rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd, i1zero);
     auto extended =
         arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound);
     return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended);
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 37af8b8859852..e25b1565f39ee 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -698,13 +698,14 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
   // CHECK: linalg.generic
   // CHECK: arith.constant 1
   // CHECK: arith.constant 0
+  // CHECK: arith.constant false
   // CHECK: arith.constant true
   // CHECK: arith.cmpi
   // CHECK: arith.subi
   // CHECK: arith.shrsi
   // CHECK: arith.trunci
   // CHECK: and
-  // CHECK: and
+  // CHECK: arith.select
   // CHECK: arith.extui
   // CHECK: arith.addi
   %12 = tosa.arithmetic_right_shift %arg0, %arg0 {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

Copy link
Contributor

@RoboTux RoboTux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM thanks! Minor nit to the commit message:

Can you be more specific in what produce the poison value (value2 -1)? I feel that improves clarity.

Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix @HerrCai0907!

@HerrCai0907 HerrCai0907 merged commit 7a8bff4 into llvm:main Sep 22, 2025
13 checks passed
@HerrCai0907 HerrCai0907 deleted the fix/154259 branch September 22, 2025 23:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[MLIR] Inconsistent output with -convert-arith-to-llvm
4 participants