Skip to content

Commit

Permalink
Add index::CmpOp canonicalization.
Browse files Browse the repository at this point in the history
Add canonicalization pattern for index::CmpOp

Differential Revision: https://reviews.llvm.org/D157903
  • Loading branch information
weiweichen committed Aug 15, 2023
1 parent f8ad86c commit 2b2889b
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 4 deletions.
7 changes: 5 additions & 2 deletions mlir/include/mlir/Dialect/Index/IR/IndexOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@
// Forward Declarations
//===----------------------------------------------------------------------===//

namespace mlir::index {
namespace mlir {
class PatternRewriter;
namespace index {
enum class IndexCmpPredicate : uint32_t;
class IndexCmpPredicateAttr;
} // namespace mlir::index
} // namespace index
} // namespace mlir

//===----------------------------------------------------------------------===//
// ODS-Generated Declarations
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Index/IR/IndexOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ def Index_CmpOp : IndexOp<"cmp"> {
let results = (outs I1:$result);
let assemblyFormat = "`` $pred `(` $lhs `,` $rhs `)` attr-dict";
let hasFolder = 1;
let hasCanonicalizeMethod = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
32 changes: 32 additions & 0 deletions mlir/lib/Dialect/Index/IR/IndexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
Expand Down Expand Up @@ -549,6 +550,37 @@ OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
return {};
}

/// Canonicalize
/// `x - y cmp 0` to `x cmp y`. or `x - y cmp 0` to `x cmp y`.
/// `0 cmp x - y` to `y cmp x`. or `0 cmp x - y` to `y cmp x`.
LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {
IntegerAttr cmpRhs;
IntegerAttr cmpLhs;

bool rhsIsZero = matchPattern(op.getRhs(), m_Constant(&cmpRhs)) &&
cmpRhs.getValue().isZero();
bool lhsIsZero = matchPattern(op.getLhs(), m_Constant(&cmpLhs)) &&
cmpLhs.getValue().isZero();
if (!rhsIsZero && !lhsIsZero)
return rewriter.notifyMatchFailure(op.getLoc(),
"cmp is not comparing something with 0");
SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp<index::SubOp>()
: op.getRhs().getDefiningOp<index::SubOp>();
if (!subOp)
return rewriter.notifyMatchFailure(
op.getLoc(), "non-zero operand is not a result of subtraction");

index::CmpOp newCmp;
if (rhsIsZero)
newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
subOp.getLhs(), subOp.getRhs());
else
newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
subOp.getRhs(), subOp.getLhs());
rewriter.replaceOp(op, newCmp);
return success();
}

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
Expand Down
13 changes: 11 additions & 2 deletions mlir/test/Dialect/Index/index-canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ func.func @xor() -> index {
}

// CHECK-LABEL: @cmp
func.func @cmp() -> (i1, i1, i1, i1) {
func.func @cmp(%arg0: index) -> (i1, i1, i1, i1, i1, i1) {
%a = index.constant 0
%b = index.constant -1
%c = index.constant -2
Expand All @@ -484,10 +484,19 @@ func.func @cmp() -> (i1, i1, i1, i1) {
%2 = index.cmp ne(%d, %a)
%3 = index.cmp sgt(%b, %a)

%4 = index.sub %a, %arg0
%5 = index.cmp sgt(%4, %a)

%6 = index.sub %a, %arg0
%7 = index.cmp sgt(%a, %6)

// CHECK-DAG: %[[TRUE:.*]] = index.bool.constant true
// CHECK-DAG: %[[FALSE:.*]] = index.bool.constant false
// CHECK-DAG: [[IDX0:%.*]] = index.constant 0
// CHECK-DAG: [[V4:%.*]] = index.cmp sgt([[IDX0]], %arg0)
// CHECK-DAG: [[V5:%.*]] = index.cmp sgt(%arg0, [[IDX0]])
// CHECK: return %[[FALSE]], %[[TRUE]], %[[TRUE]], %[[FALSE]]
return %0, %1, %2, %3 : i1, i1, i1, i1
return %0, %1, %2, %3, %5, %7 : i1, i1, i1, i1, i1, i1
}

// CHECK-LABEL: @cmp_nofold
Expand Down

0 comments on commit 2b2889b

Please sign in to comment.