-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][tosa-to-linalg] fix arithmetic_right_shift conversion with round #159930
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir Author: Congcong Cai (HerrCai0907) ChangesFull diff: https://github.com/llvm/llvm-project/pull/159930.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 1955eec9964eb..1e0aefded19c1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -298,6 +298,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
IntegerAttr::get(elementTy, 1));
auto zero = arith::ConstantOp::create(rewriter, loc,
IntegerAttr::get(elementTy, 0));
+ auto i1zero =
+ arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 0));
auto i1one =
arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1));
@@ -315,9 +317,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
ArrayRef<NamedAttribute>());
auto isInputOdd =
arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one);
-
- auto shouldRound = arith::AndIOp::create(
- rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
+ // shifted, truncated, isInputOdd can be poison when input2 is 0.
+ auto shouldRound = arith::SelectOp::create(
+ rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd, i1zero);
auto extended =
arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound);
return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended);
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 37af8b8859852..e25b1565f39ee 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -698,13 +698,14 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
// CHECK: linalg.generic
// CHECK: arith.constant 1
// CHECK: arith.constant 0
+ // CHECK: arith.constant false
// CHECK: arith.constant true
// CHECK: arith.cmpi
// CHECK: arith.subi
// CHECK: arith.shrsi
// CHECK: arith.trunci
// CHECK: and
- // CHECK: and
+ // CHECK: arith.select
// CHECK: arith.extui
// CHECK: arith.addi
%12 = tosa.arithmetic_right_shift %arg0, %arg0 {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
@llvm/pr-subscribers-mlir-linalg Author: Congcong Cai (HerrCai0907) ChangesFull diff: https://github.com/llvm/llvm-project/pull/159930.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 1955eec9964eb..1e0aefded19c1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -298,6 +298,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
IntegerAttr::get(elementTy, 1));
auto zero = arith::ConstantOp::create(rewriter, loc,
IntegerAttr::get(elementTy, 0));
+ auto i1zero =
+ arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 0));
auto i1one =
arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1));
@@ -315,9 +317,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
ArrayRef<NamedAttribute>());
auto isInputOdd =
arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one);
-
- auto shouldRound = arith::AndIOp::create(
- rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
+ // shifted, truncated, isInputOdd can be poison when input2 is 0.
+ auto shouldRound = arith::SelectOp::create(
+ rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd, i1zero);
auto extended =
arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound);
return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended);
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 37af8b8859852..e25b1565f39ee 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -698,13 +698,14 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
// CHECK: linalg.generic
// CHECK: arith.constant 1
// CHECK: arith.constant 0
+ // CHECK: arith.constant false
// CHECK: arith.constant true
// CHECK: arith.cmpi
// CHECK: arith.subi
// CHECK: arith.shrsi
// CHECK: arith.trunci
// CHECK: and
- // CHECK: and
+ // CHECK: arith.select
// CHECK: arith.extui
// CHECK: arith.addi
%12 = tosa.arithmetic_right_shift %arg0, %arg0 {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM thanks! Minor nit to the commit message:
Can you be more specific in what produce the poison value (value2 -1)? I feel that improves clarity.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fix @HerrCai0907!
Fixed: #154259
According to TOSA spec,
tosa.arithmetic_right_shift
should handle round.The original conversion is the similar as definition, and will convert to pseudo code
But when value2 is 0,
value1 >> (value2 - 1)
will produce poison value because performing arithmetic right shift on a negative number. Then the poison value propagate to the final result.This PR wants to change the conversion to
arith.select
to stop poison propagation.