Skip to content

Commit

Permalink
[MLIR][Arith] Canonicalize cmpi of extui/extsi
Browse files Browse the repository at this point in the history
Canonicalize cmpi(eq, ext a, ext b) and cmpi(ne, ext a, ext b)

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D120620
  • Loading branch information
wsmoses committed Mar 2, 2022
1 parent 17d7134 commit 2af81c6
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 0 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,7 @@ def Arith_CmpIOp : Arith_CompareOpOfAnyRank<"cmpi"> {
}];

let hasFolder = 1;
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
22 changes: 22 additions & 0 deletions mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,28 @@ def XOrINotCmpI :
(Arith_ConstantOp ConstantAttr<I1Attr, "1">)),
(Arith_CmpIOp (InvertPredicate $pred), $a, $b)>;

//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//

// cmpi(== or !=, a ext iNN, b ext iNN) == cmpi(== or !=, a, b)
def CmpIExtSI :
Pat<(Arith_CmpIOp $pred,
(Arith_ExtSIOp $a),
(Arith_ExtSIOp $b)),
(Arith_CmpIOp $pred, $a, $b),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $a, $b),
(Constraint<CPred<"$0.getValue() == arith::CmpIPredicate::eq || $0.getValue() == arith::CmpIPredicate::ne">> $pred)]>;

// cmpi(== or !=, a ext iNN, b ext iNN) == cmpi(== or !=, a, b)
def CmpIExtUI :
Pat<(Arith_CmpIOp $pred,
(Arith_ExtUIOp $a),
(Arith_ExtUIOp $b)),
(Arith_CmpIOp $pred, $a, $b),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $a, $b),
(Constraint<CPred<"$0.getValue() == arith::CmpIPredicate::eq || $0.getValue() == arith::CmpIPredicate::ne">> $pred)]>;

//===----------------------------------------------------------------------===//
// IndexCastOp
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1322,6 +1322,11 @@ OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
return BoolAttr::get(getContext(), val);
}

void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.insert<CmpIExtSI, CmpIExtUI>(context);
}

//===----------------------------------------------------------------------===//
// CmpFOp
//===----------------------------------------------------------------------===//
Expand Down
42 changes: 42 additions & 0 deletions mlir/test/Dialect/Arithmetic/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,48 @@ func @extSIOfExtSI(%arg0: i1) -> i64 {

// -----

// CHECK-LABEL: @cmpIExtSINE
// CHECK: %[[comb:.+]] = arith.cmpi ne, %arg0, %arg1 : i8
// CHECK: return %[[comb]]
func @cmpIExtSINE(%arg0: i8, %arg1: i8) -> i1 {
%ext0 = arith.extsi %arg0 : i8 to i64
%ext1 = arith.extsi %arg1 : i8 to i64
%res = arith.cmpi ne, %ext0, %ext1 : i64
return %res : i1
}

// CHECK-LABEL: @cmpIExtSIEQ
// CHECK: %[[comb:.+]] = arith.cmpi eq, %arg0, %arg1 : i8
// CHECK: return %[[comb]]
func @cmpIExtSIEQ(%arg0: i8, %arg1: i8) -> i1 {
%ext0 = arith.extsi %arg0 : i8 to i64
%ext1 = arith.extsi %arg1 : i8 to i64
%res = arith.cmpi eq, %ext0, %ext1 : i64
return %res : i1
}

// CHECK-LABEL: @cmpIExtUINE
// CHECK: %[[comb:.+]] = arith.cmpi ne, %arg0, %arg1 : i8
// CHECK: return %[[comb]]
func @cmpIExtUINE(%arg0: i8, %arg1: i8) -> i1 {
%ext0 = arith.extui %arg0 : i8 to i64
%ext1 = arith.extui %arg1 : i8 to i64
%res = arith.cmpi ne, %ext0, %ext1 : i64
return %res : i1
}

// CHECK-LABEL: @cmpIExtUIEQ
// CHECK: %[[comb:.+]] = arith.cmpi eq, %arg0, %arg1 : i8
// CHECK: return %[[comb]]
func @cmpIExtUIEQ(%arg0: i8, %arg1: i8) -> i1 {
%ext0 = arith.extui %arg0 : i8 to i64
%ext1 = arith.extui %arg1 : i8 to i64
%res = arith.cmpi eq, %ext0, %ext1 : i64
return %res : i1
}

// -----

// CHECK-LABEL: @andOfExtSI
// CHECK: %[[comb:.+]] = arith.andi %arg0, %arg1 : i8
// CHECK: %[[ext:.+]] = arith.extsi %[[comb]] : i8 to i64
Expand Down

0 comments on commit 2af81c6

Please sign in to comment.