diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 1955eec9964eb..1e0aefded19c1 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -298,6 +298,8 @@ static Value createLinalgBodyCalculationForElementwiseOp( IntegerAttr::get(elementTy, 1)); auto zero = arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(elementTy, 0)); + auto i1zero = + arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 0)); auto i1one = arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1)); @@ -315,9 +317,9 @@ static Value createLinalgBodyCalculationForElementwiseOp( ArrayRef()); auto isInputOdd = arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one); - - auto shouldRound = arith::AndIOp::create( - rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd); + // shifted, truncated, isInputOdd can be poison when input2 is 0. + auto shouldRound = arith::SelectOp::create( + rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd, i1zero); auto extended = arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound); return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended); diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 37af8b8859852..e25b1565f39ee 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -698,13 +698,14 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns // CHECK: linalg.generic // CHECK: arith.constant 1 // CHECK: arith.constant 0 + // CHECK: arith.constant false // CHECK: arith.constant true // CHECK: arith.cmpi // CHECK: arith.subi // CHECK: arith.shrsi // CHECK: arith.trunci // CHECK: and - // CHECK: and + // CHECK: arith.select // CHECK: arith.extui // CHECK: arith.addi %12 = tosa.arithmetic_right_shift %arg0, %arg0 {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>