diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index 20c9097b51e6d..739d0439c4bba 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -1008,6 +1008,7 @@ def Arith_MaxNumFOp : Arith_FloatBinaryOp<"maxnumf", [Commutative]> { def Arith_MaxSIOp : Arith_TotalIntBinaryOp<"maxsi", [Commutative]> { let summary = "signed integer maximum operation"; let hasFolder = 1; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// @@ -1017,6 +1018,7 @@ def Arith_MaxSIOp : Arith_TotalIntBinaryOp<"maxsi", [Commutative]> { def Arith_MaxUIOp : Arith_TotalIntBinaryOp<"maxui", [Commutative]> { let summary = "unsigned integer maximum operation"; let hasFolder = 1; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// @@ -1067,6 +1069,7 @@ def Arith_MinNumFOp : Arith_FloatBinaryOp<"minnumf", [Commutative]> { def Arith_MinSIOp : Arith_TotalIntBinaryOp<"minsi", [Commutative]> { let summary = "signed integer minimum operation"; let hasFolder = 1; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// @@ -1076,6 +1079,7 @@ def Arith_MinSIOp : Arith_TotalIntBinaryOp<"minsi", [Commutative]> { def Arith_MinUIOp : Arith_TotalIntBinaryOp<"minui", [Commutative]> { let summary = "unsigned integer minimum operation"; let hasFolder = 1; + let hasCanonicalizer = 1; } diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td index de3efc9fe3506..ef57af86f0540 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td +++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td @@ -24,6 +24,18 @@ def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">; // Multiply two integer attributes and create a new one with the result. def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">; +// Select signed min value of two integer attributes and store to the result +def SMinIntAttrs : NativeCodeCall<"sminIntegerAttrs($_builder, $0, $1, $2)">; + +// Select unsigned min value of two integer attributes and store to the result +def UMinIntAttrs : NativeCodeCall<"uminIntegerAttrs($_builder, $0, $1, $2)">; + +// Select signed max value of two integer attributes and store to the result +def SMaxIntAttrs : NativeCodeCall<"smaxIntegerAttrs($_builder, $0, $1, $2)">; + +// Select unsigned max value of two integer attributes and store to the result +def UMaxIntAttrs : NativeCodeCall<"umaxIntegerAttrs($_builder, $0, $1, $2)">; + // Merge overflow flags from 2 ops, selecting the most conservative combination. def MergeOverflow : NativeCodeCall<"mergeOverflowFlags($0, $1)">; @@ -202,6 +214,62 @@ def MulUIExtendedToMulI : [(Arith_MulIOp $x, $y, DefOverflow), (replaceWithValue $x)], [(Constraint> $res__1)]>; +//===----------------------------------------------------------------------===// +// MaxSIOp +//===----------------------------------------------------------------------===// + +// maxsi is commutative and will be canonicalized to have its constants appear +// as the second operand. + +// maxsi(maxsi(x, c0), c1) -> maxsi(x, maxsi(c0, c1)) +def MaxSIMaxSIConstant : + Pat<(Arith_MaxSIOp:$res + (Arith_MaxSIOp $x, (ConstantLikeMatcher APIntAttr:$c0)), + (ConstantLikeMatcher APIntAttr:$c1)), + (Arith_MaxSIOp $x, (Arith_ConstantOp (SMaxIntAttrs $res, $c0, $c1)))>; + +//===----------------------------------------------------------------------===// +// MaxUIOp +//===----------------------------------------------------------------------===// + +// maxui is commutative and will be canonicalized to have its constants appear +// as the second operand. + +// maxui(maxui(x, c0), c1) -> maxui(x, maxui(c0, c1)) +def MaxUIMaxUIConstant : + Pat<(Arith_MaxUIOp:$res + (Arith_MaxUIOp $x, (ConstantLikeMatcher APIntAttr:$c0)), + (ConstantLikeMatcher APIntAttr:$c1)), + (Arith_MaxUIOp $x, (Arith_ConstantOp (UMaxIntAttrs $res, $c0, $c1)))>; + +//===----------------------------------------------------------------------===// +// MinSIOp +//===----------------------------------------------------------------------===// + +// minsi is commutative and will be canonicalized to have its constants appear +// as the second operand. + +// minsi(minsi(x, c0), c1) -> minsi(x, minsi(c0, c1)) +def MinSIMinSIConstant : + Pat<(Arith_MinSIOp:$res + (Arith_MinSIOp $x, (ConstantLikeMatcher APIntAttr:$c0)), + (ConstantLikeMatcher APIntAttr:$c1)), + (Arith_MinSIOp $x, (Arith_ConstantOp (SMinIntAttrs $res, $c0, $c1)))>; + +//===----------------------------------------------------------------------===// +// MinUIOp +//===----------------------------------------------------------------------===// + +// minui is commutative and will be canonicalized to have its constants appear +// as the second operand. + +// minui(minui(x, c0), c1) -> minui(x, minui(c0, c1)) +def MinUIMinUIConstant : + Pat<(Arith_MinUIOp:$res + (Arith_MinUIOp $x, (ConstantLikeMatcher APIntAttr:$c0)), + (ConstantLikeMatcher APIntAttr:$c1)), + (Arith_MinUIOp $x, (Arith_ConstantOp (UMinIntAttrs $res, $c0, $c1)))>; + //===----------------------------------------------------------------------===// // XOrIOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 7cfd6d3a98df8..82270ab64f7ec 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -63,6 +63,26 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies()); } +static IntegerAttr sminIntegerAttrs(PatternRewriter &builder, Value res, + Attribute lhs, Attribute rhs) { + return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::smin); +} + +static IntegerAttr uminIntegerAttrs(PatternRewriter &builder, Value res, + Attribute lhs, Attribute rhs) { + return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::umin); +} + +static IntegerAttr smaxIntegerAttrs(PatternRewriter &builder, Value res, + Attribute lhs, Attribute rhs) { + return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::smax); +} + +static IntegerAttr umaxIntegerAttrs(PatternRewriter &builder, Value res, + Attribute lhs, Attribute rhs) { + return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::umax); +} + // Merge overflow flags from 2 ops, selecting the most conservative combination. static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, @@ -1162,6 +1182,11 @@ OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) { }); } +void arith::MaxSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(context); +} + //===----------------------------------------------------------------------===// // MaxUIOp //===----------------------------------------------------------------------===// @@ -1187,6 +1212,11 @@ OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) { }); } +void arith::MaxUIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(context); +} + //===----------------------------------------------------------------------===// // MinimumFOp //===----------------------------------------------------------------------===// @@ -1248,6 +1278,11 @@ OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) { }); } +void arith::MinSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(context); +} + //===----------------------------------------------------------------------===// // MinUIOp //===----------------------------------------------------------------------===// @@ -1273,6 +1308,11 @@ OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) { }); } +void arith::MinUIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(context); +} + //===----------------------------------------------------------------------===// // MulFOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index ca3de3a2d7703..1848decc2eb7c 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -1952,6 +1952,30 @@ func.func @bitcastChain(%arg: i16) -> f16 { // ----- +// CHECK-LABEL: @maxsiMaxsiConst1 +// CHECK: %[[C42:.+]] = arith.constant 42 : i32 +// CHECK: %[[RES:.+]] = arith.maxsi %arg0, %[[C42]] : i32 +// CHECK: return %[[RES]] +func.func @maxsiMaxsiConst1(%arg0: i32) -> i32 { + %c17 = arith.constant 17 : i32 + %c42 = arith.constant 42 : i32 + %max1 = arith.maxsi %arg0, %c17 : i32 + %max2 = arith.maxsi %max1, %c42 : i32 + return %max2 : i32 +} + +// CHECK-LABEL: @maxsiMaxsiConst2 +// CHECK: %[[C21:.+]] = arith.constant 21 : i32 +// CHECK: %[[RES:.+]] = arith.maxsi %arg0, %[[C21]] : i32 +// CHECK: return %[[RES]] +func.func @maxsiMaxsiConst2(%arg0: i32) -> i32 { + %c7 = arith.constant 7 : i32 + %c21 = arith.constant 21 : i32 + %max1 = arith.maxsi %arg0, %c7 : i32 + %max2 = arith.maxsi %c21, %max1 : i32 + return %max2 : i32 +} + // CHECK-LABEL: test_maxsi // CHECK-DAG: %[[C0:.+]] = arith.constant 42 // CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant 127 @@ -1986,6 +2010,30 @@ func.func @test_maxsi2(%arg0 : i8) -> (i8, i8, i8, i8) { // ----- +// CHECK-LABEL: @maxuiMaxuiConst1 +// CHECK: %[[C42:.+]] = arith.constant 42 : index +// CHECK: %[[RES:.+]] = arith.maxui %arg0, %[[C42]] : index +// CHECK: return %[[RES]] +func.func @maxuiMaxuiConst1(%arg0: index) -> index { + %c17 = arith.constant 17 : index + %c42 = arith.constant 42 : index + %max1 = arith.maxui %arg0, %c17 : index + %max2 = arith.maxui %max1, %c42 : index + return %max2 : index +} + +// CHECK-LABEL: @maxuiMaxuiConst2 +// CHECK: %[[C21:.+]] = arith.constant 21 : index +// CHECK: %[[RES:.+]] = arith.maxui %arg0, %[[C21]] : index +// CHECK: return %[[RES]] +func.func @maxuiMaxuiConst2(%arg0: index) -> index { + %c7 = arith.constant 7 : index + %c21 = arith.constant 21 : index + %max1 = arith.maxui %arg0, %c7 : index + %max2 = arith.maxui %c21, %max1 : index + return %max2 : index +} + // CHECK-LABEL: test_maxui // CHECK-DAG: %[[C0:.+]] = arith.constant 42 // CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant -1 @@ -2020,6 +2068,30 @@ func.func @test_maxui2(%arg0 : i8) -> (i8, i8, i8, i8) { // ----- +// CHECK-LABEL: @minsiMinsiConst1 +// CHECK: %[[C17:.+]] = arith.constant 17 : i32 +// CHECK: %[[RES:.+]] = arith.minsi %arg0, %[[C17]] : i32 +// CHECK: return %[[RES]] +func.func @minsiMinsiConst1(%arg0: i32) -> i32 { + %c17 = arith.constant 17 : i32 + %c42 = arith.constant 42 : i32 + %min1 = arith.minsi %arg0, %c17 : i32 + %min2 = arith.minsi %min1, %c42 : i32 + return %min2 : i32 +} + +// CHECK-LABEL: @minsiMinsiConst2 +// CHECK: %[[C7:.+]] = arith.constant 7 : i32 +// CHECK: %[[RES:.+]] = arith.minsi %arg0, %[[C7]] : i32 +// CHECK: return %[[RES]] +func.func @minsiMinsiConst2(%arg0: i32) -> i32 { + %c7 = arith.constant 7 : i32 + %c21 = arith.constant 21 : i32 + %min1 = arith.minsi %arg0, %c7 : i32 + %min2 = arith.minsi %c21, %min1 : i32 + return %min2 : i32 +} + // CHECK-LABEL: test_minsi // CHECK-DAG: %[[C0:.+]] = arith.constant 42 // CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant -128 @@ -2054,6 +2126,30 @@ func.func @test_minsi2(%arg0 : i8) -> (i8, i8, i8, i8) { // ----- +// CHECK-LABEL: @minuiMinuiConst1 +// CHECK: %[[C17:.+]] = arith.constant 17 : index +// CHECK: %[[RES:.+]] = arith.minui %arg0, %[[C17]] : index +// CHECK: return %[[RES]] +func.func @minuiMinuiConst1(%arg0: index) -> index { + %c17 = arith.constant 17 : index + %c42 = arith.constant 42 : index + %min1 = arith.minui %arg0, %c17 : index + %min2 = arith.minui %min1, %c42 : index + return %min2 : index +} + +// CHECK-LABEL: @minuiMinuiConst2 +// CHECK: %[[C7:.+]] = arith.constant 7 : index +// CHECK: %[[RES:.+]] = arith.minui %arg0, %[[C7]] : index +// CHECK: return %[[RES]] +func.func @minuiMinuiConst2(%arg0: index) -> index { + %c7 = arith.constant 7 : index + %c21 = arith.constant 21 : index + %min1 = arith.minui %arg0, %c7 : index + %min2 = arith.minui %c21, %min1 : index + return %min2 : index +} + // CHECK-LABEL: test_minui // CHECK-DAG: %[[C0:.+]] = arith.constant 42 // CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant 0 @@ -3377,4 +3473,3 @@ func.func @unreachable() { %add = arith.addi %add, %c1_i64 : i64 cf.br ^unreachable } -