Skip to content

Commit

Permalink
Revert "[mlir][arith] Add canonicalization patterns for 'mul*i_extend…
Browse files Browse the repository at this point in the history
…ed'"

This reverts commit 834c17f.

Revert to properly address post-commit comments by @jpienaar
under https://reviews.llvm.org/D139778.
  • Loading branch information
kuhar committed Dec 13, 2022
1 parent 619b7ce commit 2c33031
Show file tree
Hide file tree
Showing 4 changed files with 1 addition and 257 deletions.
1 change: 0 additions & 1 deletion mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1071,7 +1071,6 @@ def Arith_TruncIOp : Arith_IToICastOp<"trunci"> {
}];

let hasFolder = 1;
let hasCanonicalizer = 1;
let hasVerifier = 1;
}

Expand Down
55 changes: 0 additions & 55 deletions mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,6 @@ def MulSIExtendedToMulI :
[(Arith_MulIOp $x, $y), (replaceWithValue $x)],
[(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;

// mulsi_extended(x, 1) -> [x, extsi(cmpi slt, x, 0)]
def MulSIExtendedRHSOne :
Pattern<(Arith_MulSIExtendedOp $x, (Arith_ConstantOp $c1)),
[(replaceWithValue $x),
(Arith_ExtSIOp(Arith_CmpIOp
(NativeCodeCall<"arith::CmpIPredicate::slt">),
$x,
(Arith_ConstantOp (GetZeroAttr $x))))],
[(Constraint<CPred<"getIntOrSplatIntValue($0) == 1">> $c1)]>;

//===----------------------------------------------------------------------===//
// MulUIExtendedOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -253,51 +243,6 @@ def OrOfExtSI :
Pat<(Arith_OrIOp (Arith_ExtSIOp $x), (Arith_ExtSIOp $y)), (Arith_ExtSIOp (Arith_OrIOp $x, $y)),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;

//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//

def ValuesWithSameType :
Constraint<
CPred<"llvm::all_equal({$0.getType(), $1.getType(), $2.getType()})">>;

def ValueWiderThan :
Constraint<
CPred<"getScalarOrElementWidth($0) > getScalarOrElementWidth($1)">>;

def TruncationMatchesShiftAmount :
Constraint<
CPred<"(getScalarOrElementWidth($0) - getScalarOrElementWidth($1)) == "
"getIntOrSplatIntValue($2)">>;

// trunci(shrsi(x, c)) -> trunci(shrui(x, c))
def TruncIShrSIToTrunciShrUI :
Pat<(Arith_TruncIOp:$tr (Arith_ShRSIOp $x, (Arith_ConstantOp $c0))),
(Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp $c0))),
[(TruncationMatchesShiftAmount $x, $tr, $c0)]>;

// trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y)
def TruncIShrUIMulIToMulSIExtended :
Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
(Arith_MulIOp:$mul
(Arith_ExtSIOp $x), (Arith_ExtSIOp $y)),
(Arith_ConstantOp $c0))),
(Arith_MulSIExtendedOp:$res__1 $x, $y),
[(ValuesWithSameType $tr, $x, $y),
(ValueWiderThan $mul, $x),
(TruncationMatchesShiftAmount $mul, $x, $c0)]>;

// trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y)
def TruncIShrUIMulIToMulUIExtended :
Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
(Arith_MulIOp:$mul
(Arith_ExtUIOp $x), (Arith_ExtUIOp $y)),
(Arith_ConstantOp $c0))),
(Arith_MulUIExtendedOp:$res__1 $x, $y),
[(ValuesWithSameType $tr, $x, $y),
(ValueWiderThan $mul, $x),
(TruncationMatchesShiftAmount $mul, $x, $c0)]>;

//===----------------------------------------------------------------------===//
// MulFOp
//===----------------------------------------------------------------------===//
Expand Down
33 changes: 1 addition & 32 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,31 +74,6 @@ static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
invertPredicate(pred.getValue()));
}

static int64_t getScalarOrElementWidth(Type type) {
if (type.isIntOrFloat())
return type.getIntOrFloatBitWidth();

if (auto shapeTy = type.dyn_cast<ShapedType>())
return shapeTy.getElementTypeBitWidth();

return -1;
}

static int64_t getScalarOrElementWidth(Value value) {
return getScalarOrElementWidth(value.getType());
}

static int64_t getIntOrSplatIntValue(Attribute attr) {
if (auto intAttr = attr.dyn_cast<IntegerAttr>())
return intAttr.getInt();

if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>())
if (splatAttr.getElementType().isa<IntegerType>())
return splatAttr.getSplatValue<APInt>().getLimitedValue();

return -1;
}

//===----------------------------------------------------------------------===//
// TableGen'd canonicalization patterns
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -418,7 +393,7 @@ arith::MulSIExtendedOp::fold(ArrayRef<Attribute> operands,

void arith::MulSIExtendedOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
patterns.add<MulSIExtendedToMulI>(context);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1274,12 +1249,6 @@ bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
}

void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<TruncIShrSIToTrunciShrUI, TruncIShrUIMulIToMulSIExtended,
TruncIShrUIMulIToMulUIExtended>(context);
}

LogicalResult arith::TruncIOp::verify() {
return verifyTruncateOp<IntegerType>(*this);
}
Expand Down
169 changes: 0 additions & 169 deletions mlir/test/Dialect/Arith/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -761,30 +761,6 @@ func.func @mulsiExtendedZeroLhs(%arg0: i32) -> (i32, i32) {
return %low, %high : i32, i32
}

// CHECK-LABEL: @mulsiExtendedOneRhs
// CHECK-SAME: (%[[ARG:.+]]: i32) -> (i32, i32)
// CHECK-NEXT: %[[C0:.+]] = arith.constant 0 : i32
// CHECK-NEXT: %[[CMP:.+]] = arith.cmpi slt, %[[ARG]], %[[C0]] : i32
// CHECK-NEXT: %[[EXT:.+]] = arith.extsi %[[CMP]] : i1 to i32
// CHECK-NEXT: return %[[ARG]], %[[EXT]] : i32, i32
func.func @mulsiExtendedOneRhs(%arg0: i32) -> (i32, i32) {
%one = arith.constant 1 : i32
%low, %high = arith.mulsi_extended %arg0, %one: i32
return %low, %high : i32, i32
}

// CHECK-LABEL: @mulsiExtendedOneRhsSplat
// CHECK-SAME: (%[[ARG:.+]]: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>)
// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0> : vector<3xi32>
// CHECK-NEXT: %[[CMP:.+]] = arith.cmpi slt, %[[ARG]], %[[C0]] : vector<3xi32>
// CHECK-NEXT: %[[EXT:.+]] = arith.extsi %[[CMP]] : vector<3xi1> to vector<3xi32>
// CHECK-NEXT: return %[[ARG]], %[[EXT]] : vector<3xi32>, vector<3xi32>
func.func @mulsiExtendedOneRhsSplat(%arg0: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>) {
%one = arith.constant dense<1> : vector<3xi32>
%low, %high = arith.mulsi_extended %arg0, %one: vector<3xi32>
return %low, %high : vector<3xi32>, vector<3xi32>
}

// CHECK-LABEL: @mulsiExtendedUnusedHigh
// CHECK-SAME: (%[[ARG:.+]]: i32) -> i32
// CHECK-NEXT: %[[RES:.+]] = arith.muli %[[ARG]], %[[ARG]] : i32
Expand Down Expand Up @@ -1940,148 +1916,3 @@ func.func @andand3(%a : i32, %b : i32) -> i32 {
%res = arith.andi %c, %b : i32
return %res : i32
}

// -----

// CHECK-LABEL: @truncIShrSIToTrunciShrUI
// CHECK-SAME: (%[[A:.+]]: i64)
// CHECK-NEXT: %[[C32:.+]] = arith.constant 32 : i64
// CHECK-NEXT: %[[SHR:.+]] = arith.shrui %[[A]], %[[C32]] : i64
// CHECK-NEXT: %[[TRU:.+]] = arith.trunci %[[SHR]] : i64 to i32
// CHECK-NEXT: return %[[TRU]] : i32
func.func @truncIShrSIToTrunciShrUI(%a: i64) -> i32 {
%c32 = arith.constant 32: i64
%sh = arith.shrsi %a, %c32 : i64
%hi = arith.trunci %sh: i64 to i32
return %hi : i32
}

// CHECK-LABEL: @truncIShrSIToTrunciShrUIBadShiftAmt1
// CHECK: arith.shrsi
func.func @truncIShrSIToTrunciShrUIBadShiftAmt1(%a: i64) -> i32 {
%c33 = arith.constant 33: i64
%sh = arith.shrsi %a, %c33 : i64
%hi = arith.trunci %sh: i64 to i32
return %hi : i32
}

// CHECK-LABEL: @truncIShrSIToTrunciShrUIBadShiftAmt2
// CHECK: arith.shrsi
func.func @truncIShrSIToTrunciShrUIBadShiftAmt2(%a: i64) -> i32 {
%c31 = arith.constant 31: i64
%sh = arith.shrsi %a, %c31 : i64
%hi = arith.trunci %sh: i64 to i32
return %hi : i32
}

// CHECK-LABEL: @wideMulToMulSIExtended
// CHECK-SAME: (%[[A:.+]]: i32, %[[B:.+]]: i32)
// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[A]], %[[B]] : i32
// CHECK-NEXT: return %[[HIGH]] : i32
func.func @wideMulToMulSIExtended(%a: i32, %b: i32) -> i32 {
%x = arith.extsi %a: i32 to i64
%y = arith.extsi %b: i32 to i64
%m = arith.muli %x, %y: i64
%c32 = arith.constant 32: i64
%sh = arith.shrui %m, %c32 : i64
%hi = arith.trunci %sh: i64 to i32
return %hi : i32
}

// CHECK-LABEL: @wideMulToMulSIExtendedVector
// CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>)
// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[A]], %[[B]] : vector<3xi32>
// CHECK-NEXT: return %[[HIGH]] : vector<3xi32>
func.func @wideMulToMulSIExtendedVector(%a: vector<3xi32>, %b: vector<3xi32>) -> vector<3xi32> {
%x = arith.extsi %a: vector<3xi32> to vector<3xi64>
%y = arith.extsi %b: vector<3xi32> to vector<3xi64>
%m = arith.muli %x, %y: vector<3xi64>
%c32 = arith.constant dense<32>: vector<3xi64>
%sh = arith.shrui %m, %c32 : vector<3xi64>
%hi = arith.trunci %sh: vector<3xi64> to vector<3xi32>
return %hi : vector<3xi32>
}

// CHECK-LABEL: @wideMulToMulUIExtended
// CHECK-SAME: (%[[A:.+]]: i32, %[[B:.+]]: i32)
// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[A]], %[[B]] : i32
// CHECK-NEXT: return %[[HIGH]] : i32
func.func @wideMulToMulUIExtended(%a: i32, %b: i32) -> i32 {
%x = arith.extui %a: i32 to i64
%y = arith.extui %b: i32 to i64
%m = arith.muli %x, %y: i64
%c32 = arith.constant 32: i64
%sh = arith.shrui %m, %c32 : i64
%hi = arith.trunci %sh: i64 to i32
return %hi : i32
}

// CHECK-LABEL: @wideMulToMulUIExtendedVector
// CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>)
// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[A]], %[[B]] : vector<3xi32>
// CHECK-NEXT: return %[[HIGH]] : vector<3xi32>
func.func @wideMulToMulUIExtendedVector(%a: vector<3xi32>, %b: vector<3xi32>) -> vector<3xi32> {
%x = arith.extui %a: vector<3xi32> to vector<3xi64>
%y = arith.extui %b: vector<3xi32> to vector<3xi64>
%m = arith.muli %x, %y: vector<3xi64>
%c32 = arith.constant dense<32>: vector<3xi64>
%sh = arith.shrui %m, %c32 : vector<3xi64>
%hi = arith.trunci %sh: vector<3xi64> to vector<3xi32>
return %hi : vector<3xi32>
}

// CHECK-LABEL: @wideMulToMulIExtendedMixedExt
// CHECK: arith.muli
// CHECK: arith.shrui
// CHECK: arith.trunci
func.func @wideMulToMulIExtendedMixedExt(%a: i32, %b: i32) -> i32 {
%x = arith.extsi %a: i32 to i64
%y = arith.extui %b: i32 to i64
%m = arith.muli %x, %y: i64
%c32 = arith.constant 32: i64
%sh = arith.shrui %m, %c32 : i64
%hi = arith.trunci %sh: i64 to i32
return %hi : i32
}

// CHECK-LABEL: @wideMulToMulSIExtendedBadExt
// CHECK: arith.muli
// CHECK: arith.shrui
// CHECK: arith.trunci
func.func @wideMulToMulSIExtendedBadExt(%a: i16, %b: i16) -> i32 {
%x = arith.extsi %a: i16 to i64
%y = arith.extsi %b: i16 to i64
%m = arith.muli %x, %y: i64
%c32 = arith.constant 32: i64
%sh = arith.shrui %m, %c32 : i64
%hi = arith.trunci %sh: i64 to i32
return %hi : i32
}

// CHECK-LABEL: @wideMulToMulSIExtendedBadShift1
// CHECK: arith.muli
// CHECK: arith.shrui
// CHECK: arith.trunci
func.func @wideMulToMulSIExtendedBadShift1(%a: i32, %b: i32) -> i32 {
%x = arith.extsi %a: i32 to i64
%y = arith.extsi %b: i32 to i64
%m = arith.muli %x, %y: i64
%c33 = arith.constant 33: i64
%sh = arith.shrui %m, %c33 : i64
%hi = arith.trunci %sh: i64 to i32
return %hi : i32
}

// CHECK-LABEL: @wideMulToMulSIExtendedBadShift2
// CHECK: arith.muli
// CHECK: arith.shrui
// CHECK: arith.trunci
func.func @wideMulToMulSIExtendedBadShift2(%a: i32, %b: i32) -> i32 {
%x = arith.extsi %a: i32 to i64
%y = arith.extsi %b: i32 to i64
%m = arith.muli %x, %y: i64
%c31 = arith.constant 31: i64
%sh = arith.shrui %m, %c31 : i64
%hi = arith.trunci %sh: i64 to i32
return %hi : i32
}

0 comments on commit 2c33031

Please sign in to comment.