Skip to content

Commit 2e4e218

Browse files
rikhuijzerJeff Niu
authored andcommitted
[mlir] Avoid folding index.remu and index.rems for 0 rhs
As discussed in #59714 (comment), the folder for the remainder operations should be resillient when the rhs is 0. The file `IndexOps.cpp` was already checking for multiple divisions by zero, so I tried to stick to the code style from those checks. Fixes #59714. As a side note, is it correct that remainder operations are never optimized away? I would expect that the following code ``` func.func @remu_test() -> index { %c3 = index.constant 2 %c0 = index.constant 1 %0 = index.remu %c3, %c0 return %0 : index } ``` would be optimized to ``` func.func @remu_test() -> index { return index.constant 0 : index } ``` when called with `mlir-opt --convert-scf-to-openmp temp.mlir`, but maybe I'm misunderstanding something. Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D151476
1 parent 3332dc3 commit 2e4e218

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

mlir/lib/Dialect/Index/IR/IndexOps.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,12 @@ OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) {
263263
OpFoldResult RemSOp::fold(FoldAdaptor adaptor) {
264264
return foldBinaryOpChecked(
265265
adaptor.getOperands(),
266-
[](const APInt &lhs, const APInt &rhs) { return lhs.srem(rhs); });
266+
[](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
267+
// Don't fold division by zero.
268+
if (rhs.isZero())
269+
return std::nullopt;
270+
return lhs.srem(rhs);
271+
});
267272
}
268273

269274
//===----------------------------------------------------------------------===//
@@ -273,7 +278,12 @@ OpFoldResult RemSOp::fold(FoldAdaptor adaptor) {
273278
OpFoldResult RemUOp::fold(FoldAdaptor adaptor) {
274279
return foldBinaryOpChecked(
275280
adaptor.getOperands(),
276-
[](const APInt &lhs, const APInt &rhs) { return lhs.urem(rhs); });
281+
[](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
282+
// Don't fold division by zero.
283+
if (rhs.isZero())
284+
return std::nullopt;
285+
return lhs.urem(rhs);
286+
});
277287
}
278288

279289
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Index/index-canonicalize.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,24 @@ func.func @floordivs_nofold() -> index {
198198
return %0 : index
199199
}
200200

201+
// CHECK-LABEL: @rems_zerodiv_nofold
202+
func.func @rems_zerodiv_nofold() -> index {
203+
%lhs = index.constant 2
204+
%rhs = index.constant 0
205+
// CHECK: index.rems
206+
%0 = index.rems %lhs, %rhs
207+
return %0 : index
208+
}
209+
210+
// CHECK-LABEL: @remu_zerodiv_nofold
211+
func.func @remu_zerodiv_nofold() -> index {
212+
%lhs = index.constant 2
213+
%rhs = index.constant 0
214+
// CHECK: index.remu
215+
%0 = index.remu %lhs, %rhs
216+
return %0 : index
217+
}
218+
201219
// CHECK-LABEL: @rems
202220
func.func @rems() -> index {
203221
%lhs = index.constant -5

0 commit comments

Comments
 (0)