diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 844905c20210a..f812b3c61b366 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2221,6 +2221,9 @@ LogicalResult arith::SelectOp::verify() { //===----------------------------------------------------------------------===// OpFoldResult arith::ShLIOp::fold(ArrayRef operands) { + // shli(x, 0) -> x + if (matchPattern(getRhs(), m_Zero())) + return getLhs(); // Don't fold if shifting more than the bit width. bool bounded = false; auto result = constFoldBinaryOp( @@ -2236,6 +2239,9 @@ OpFoldResult arith::ShLIOp::fold(ArrayRef operands) { //===----------------------------------------------------------------------===// OpFoldResult arith::ShRUIOp::fold(ArrayRef operands) { + // shrui(x, 0) -> x + if (matchPattern(getRhs(), m_Zero())) + return getLhs(); // Don't fold if shifting more than the bit width. bool bounded = false; auto result = constFoldBinaryOp( @@ -2251,6 +2257,9 @@ OpFoldResult arith::ShRUIOp::fold(ArrayRef operands) { //===----------------------------------------------------------------------===// OpFoldResult arith::ShRSIOp::fold(ArrayRef operands) { + // shrsi(x, 0) -> x + if (matchPattern(getRhs(), m_Zero())) + return getLhs(); // Don't fold if shifting more than the bit width. bool bounded = false; auto result = constFoldBinaryOp( diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index d181ae9d9d111..02cbaa28ab3f7 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -2107,3 +2107,30 @@ func.func @wideMulToMulSIExtendedBadShift2(%a: i32, %b: i32) -> i32 { %hi = arith.trunci %sh: i64 to i32 return %hi : i32 } + +// CHECK-LABEL: @foldShli0 +// CHECK-SAME: (%[[ARG:.*]]: i64) +// CHECK: return %[[ARG]] : i64 +func.func @foldShli0(%x : i64) -> i64 { + %c0 = arith.constant 0 : i64 + %r = arith.shli %x, %c0 : i64 + return %r : i64 +} + +// CHECK-LABEL: @foldShrui0 +// CHECK-SAME: (%[[ARG:.*]]: i64) +// CHECK: return %[[ARG]] : i64 +func.func @foldShrui0(%x : i64) -> i64 { + %c0 = arith.constant 0 : i64 + %r = arith.shrui %x, %c0 : i64 + return %r : i64 +} + +// CHECK-LABEL: @foldShrsi0 +// CHECK-SAME: (%[[ARG:.*]]: i64) +// CHECK: return %[[ARG]] : i64 +func.func @foldShrsi0(%x : i64) -> i64 { + %c0 = arith.constant 0 : i64 + %r = arith.shrsi %x, %c0 : i64 + return %r : i64 +}