12 changes: 3 additions & 9 deletions mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ namespace {
using AddFOpLowering =
VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
arith::AttrConvertFastMathToLLVM>;
using AddIOpLowering =
VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
arith::AttrConvertOverflowToLLVM>;
using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>;
using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
using BitcastOpLowering =
VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
Expand Down Expand Up @@ -80,9 +78,7 @@ using MinUIOpLowering =
using MulFOpLowering =
VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
arith::AttrConvertFastMathToLLVM>;
using MulIOpLowering =
VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
arith::AttrConvertOverflowToLLVM>;
using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>;
using NegFOpLowering =
VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
arith::AttrConvertFastMathToLLVM>;
Expand All @@ -106,9 +102,7 @@ using SIToFPOpLowering =
using SubFOpLowering =
VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
arith::AttrConvertFastMathToLLVM>;
using SubIOpLowering =
VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
arith::AttrConvertOverflowToLLVM>;
using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>;
using TruncFOpLowering =
VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
using TruncIOpLowering =
Expand Down
59 changes: 3 additions & 56 deletions mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,61 +158,8 @@ getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
}

// TODO: Move to some common place?
static std::string getDecorationString(spirv::Decoration decor) {
return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor));
}

namespace {

/// Converts elementwise unary, binary and ternary arith operations to SPIR-V
/// operations. Op can potentially support overflow flags.
template <typename Op, typename SPIRVOp>
struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
using OpConversionPattern<Op>::OpConversionPattern;

LogicalResult
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
assert(adaptor.getOperands().size() <= 3);
auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
Type dstType = converter->convertType(op.getType());
if (!dstType) {
return rewriter.notifyMatchFailure(
op->getLoc(),
llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
}

if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
!getElementTypeOrSelf(op.getType()).isIndex() &&
dstType != op.getType()) {
return op.emitError("bitwidth emulation is not implemented yet on "
"unsigned op pattern version");
}

auto overflowFlags = arith::IntegerOverflowFlags::none;
if (auto overflowIface =
dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
if (converter->getTargetEnv().allows(
spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
overflowFlags = overflowIface.getOverflowAttr().getValue();
}

auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
op, dstType, adaptor.getOperands());

if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap),
rewriter.getUnitAttr());

if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap),
rewriter.getUnitAttr());

return success();
}
};

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1207,9 +1154,9 @@ void mlir::arith::populateArithToSPIRVPatterns(
patterns.add<
ConstantCompositeOpPattern,
ConstantScalarOpPattern,
ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>,
spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>,
spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>,
spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
Expand Down
94 changes: 39 additions & 55 deletions mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,6 @@ def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
// Multiply two integer attributes and create a new one with the result.
def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;

// TODO: Canonicalizations currently doesn't take into account integer overflow
// flags and always reset them to default (wraparound) which is safe but can
// inhibit later optimizations. Individual patterns must be reviewed for
// better handling of overflow flags.
def DefOverflow : NativeCodeCall<"getDefOverflowFlags($_builder)">;

class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;

//===----------------------------------------------------------------------===//
Expand All @@ -42,26 +36,23 @@ class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
// addi(addi(x, c0), c1) -> addi(x, c0 + c1)
def AddIAddConstant :
Pat<(Arith_AddIOp:$res
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
(DefOverflow))>;
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
(ConstantLikeMatcher APIntAttr:$c1)),
(Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)))>;

// addi(subi(x, c0), c1) -> addi(x, c1 - c0)
def AddISubConstantRHS :
Pat<(Arith_AddIOp:$res
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
(DefOverflow))>;
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
(ConstantLikeMatcher APIntAttr:$c1)),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)))>;

// addi(subi(c0, x), c1) -> subi(c0 + c1, x)
def AddISubConstantLHS :
Pat<(Arith_AddIOp:$res
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
(DefOverflow))>;
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x),
(ConstantLikeMatcher APIntAttr:$c1)),
(Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x)>;

def IsScalarOrSplatNegativeOne :
Constraint<And<[
Expand All @@ -72,25 +63,24 @@ def IsScalarOrSplatNegativeOne :
def AddIMulNegativeOneRhs :
Pat<(Arith_AddIOp
$x,
(Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0), $ovf1), $ovf2),
(Arith_SubIOp $x, $y, (DefOverflow)),
(Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0))),
(Arith_SubIOp $x, $y),
[(IsScalarOrSplatNegativeOne $c0)]>;

// addi(muli(x, -1), y) -> subi(y, x)
def AddIMulNegativeOneLhs :
Pat<(Arith_AddIOp
(Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0), $ovf1),
$y, $ovf2),
(Arith_SubIOp $y, $x, (DefOverflow)),
(Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0)),
$y),
(Arith_SubIOp $y, $x),
[(IsScalarOrSplatNegativeOne $c0)]>;

// muli(muli(x, c0), c1) -> muli(x, c0 * c1)
def MulIMulIConstant :
Pat<(Arith_MulIOp:$res
(Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)),
(DefOverflow))>;
(Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
(ConstantLikeMatcher APIntAttr:$c1)),
(Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)))>;

//===----------------------------------------------------------------------===//
// AddUIExtendedOp
Expand All @@ -100,7 +90,7 @@ def MulIMulIConstant :
// uses. Since the 'overflow' result is unused, any replacement value will do.
def AddUIExtendedToAddI:
Pattern<(Arith_AddUIExtendedOp:$res $x, $y),
[(Arith_AddIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
[(Arith_AddIOp $x, $y), (replaceWithValue $x)],
[(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;

//===----------------------------------------------------------------------===//
Expand All @@ -110,55 +100,49 @@ def AddUIExtendedToAddI:
// subi(addi(x, c0), c1) -> addi(x, c0 - c1)
def SubIRHSAddConstant :
Pat<(Arith_SubIOp:$res
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)),
(DefOverflow))>;
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
(ConstantLikeMatcher APIntAttr:$c1)),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)))>;

// subi(c1, addi(x, c0)) -> subi(c1 - c0, x)
def SubILHSAddConstant :
Pat<(Arith_SubIOp:$res
(ConstantLikeMatcher APIntAttr:$c1),
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x,
(DefOverflow))>;
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0))),
(Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x)>;

// subi(subi(x, c0), c1) -> subi(x, c0 + c1)
def SubIRHSSubConstantRHS :
Pat<(Arith_SubIOp:$res
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
(DefOverflow))>;
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
(ConstantLikeMatcher APIntAttr:$c1)),
(Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)))>;

// subi(subi(c0, x), c1) -> subi(c0 - c1, x)
def SubIRHSSubConstantLHS :
Pat<(Arith_SubIOp:$res
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x,
(DefOverflow))>;
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x),
(ConstantLikeMatcher APIntAttr:$c1)),
(Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x)>;

// subi(c1, subi(x, c0)) -> subi(c0 + c1, x)
def SubILHSSubConstantRHS :
Pat<(Arith_SubIOp:$res
(ConstantLikeMatcher APIntAttr:$c1),
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
(DefOverflow))>;
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0))),
(Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x)>;

// subi(c1, subi(c0, x)) -> addi(x, c1 - c0)
def SubILHSSubConstantLHS :
Pat<(Arith_SubIOp:$res
(ConstantLikeMatcher APIntAttr:$c1),
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1), $ovf2),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
(DefOverflow))>;
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x)),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)))>;

// subi(subi(a, b), a) -> subi(0, b)
def SubISubILHSRHSLHS :
Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y, $ovf1), $x, $ovf2),
(Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, (DefOverflow))>;
Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y), $x),
(Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y)>;

//===----------------------------------------------------------------------===//
// MulSIExtendedOp
Expand All @@ -168,7 +152,7 @@ def SubISubILHSRHSLHS :
// Since the `high` result it not used, any replacement value will do.
def MulSIExtendedToMulI :
Pattern<(Arith_MulSIExtendedOp:$res $x, $y),
[(Arith_MulIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
[(Arith_MulIOp $x, $y), (replaceWithValue $x)],
[(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;


Expand All @@ -195,7 +179,7 @@ def MulSIExtendedRHSOne :
// Since the `high` result it not used, any replacement value will do.
def MulUIExtendedToMulI :
Pattern<(Arith_MulUIExtendedOp:$res $x, $y),
[(Arith_MulIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
[(Arith_MulIOp $x, $y), (replaceWithValue $x)],
[(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -419,7 +403,7 @@ def TruncIShrSIToTrunciShrUI :
def TruncIShrUIMulIToMulSIExtended :
Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
(Arith_MulIOp:$mul
(Arith_ExtSIOp $x), (Arith_ExtSIOp $y), $ovf1),
(Arith_ExtSIOp $x), (Arith_ExtSIOp $y)),
(ConstantLikeMatcher AnyAttr:$c0))),
(Arith_MulSIExtendedOp:$res__1 $x, $y),
[(ValuesWithSameType $tr, $x, $y),
Expand All @@ -430,7 +414,7 @@ def TruncIShrUIMulIToMulSIExtended :
def TruncIShrUIMulIToMulUIExtended :
Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
(Arith_MulIOp:$mul
(Arith_ExtUIOp $x), (Arith_ExtUIOp $y), $ovf1),
(Arith_ExtUIOp $x), (Arith_ExtUIOp $y)),
(ConstantLikeMatcher AnyAttr:$c0))),
(Arith_MulUIExtendedOp:$res__1 $x, $y),
[(ValuesWithSameType $tr, $x, $y),
Expand Down
5 changes: 0 additions & 5 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,6 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
}

static IntegerOverflowFlagsAttr getDefOverflowFlags(OpBuilder &builder) {
return IntegerOverflowFlagsAttr::get(builder.getContext(),
IntegerOverflowFlags::none);
}

/// Invert an integer comparison predicate.
arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
switch (pred) {
Expand Down
13 changes: 0 additions & 13 deletions mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -575,16 +575,3 @@ func.func @ops_supporting_fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
%7 = arith.subf %arg0, %arg1 fastmath<fast> : f32
return
}

// -----

// CHECK-LABEL: @ops_supporting_overflow
func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
// CHECK: %{{.*}} = llvm.add %{{.*}}, %{{.*}} overflow<nsw> : i64
%0 = arith.addi %arg0, %arg1 overflow<nsw> : i64
// CHECK: %{{.*}} = llvm.sub %{{.*}}, %{{.*}} overflow<nuw> : i64
%1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
// CHECK: %{{.*}} = llvm.mul %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
return
}
40 changes: 0 additions & 40 deletions mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1407,43 +1407,3 @@ func.func @float_scalar(%arg0: f16) {
}

} // end module

// -----

module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Kernel], [SPV_KHR_no_integer_wrap_decoration]>, #spirv.resource_limits<>>
} {

// CHECK-LABEL: @ops_flags
func.func @ops_flags(%arg0: i64, %arg1: i64) {
// CHECK: %{{.*}} = spirv.IAdd %{{.*}}, %{{.*}} {no_signed_wrap} : i64
%0 = arith.addi %arg0, %arg1 overflow<nsw> : i64
// CHECK: %{{.*}} = spirv.ISub %{{.*}}, %{{.*}} {no_unsigned_wrap} : i64
%1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
// CHECK: %{{.*}} = spirv.IMul %{{.*}}, %{{.*}} {no_signed_wrap, no_unsigned_wrap} : i64
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
return
}

} // end module


// -----

module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64], []>, #spirv.resource_limits<>>
} {

// No decorations should be generated is corresponding Extensions/Capabilities are missing
// CHECK-LABEL: @ops_flags
func.func @ops_flags(%arg0: i64, %arg1: i64) {
// CHECK: %{{.*}} = spirv.IAdd %{{.*}}, %{{.*}} : i64
%0 = arith.addi %arg0, %arg1 overflow<nsw> : i64
// CHECK: %{{.*}} = spirv.ISub %{{.*}}, %{{.*}} : i64
%1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
// CHECK: %{{.*}} = spirv.IMul %{{.*}}, %{{.*}} : i64
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
return
}

} // end module
11 changes: 0 additions & 11 deletions mlir/test/Dialect/Arith/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1138,14 +1138,3 @@ func.func @select_tensor_encoding(
%0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "foo">, tensor<8xi32, "foo">
return %0 : tensor<8xi32, "foo">
}

// CHECK-LABEL: @intflags_func
func.func @intflags_func(%arg0: i64, %arg1: i64) {
// CHECK: %{{.*}} = arith.addi %{{.*}}, %{{.*}} overflow<nsw> : i64
%0 = arith.addi %arg0, %arg1 overflow<nsw> : i64
// CHECK: %{{.*}} = arith.subi %{{.*}}, %{{.*}} overflow<nuw> : i64
%1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
// CHECK: %{{.*}} = arith.muli %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
return
}
2 changes: 1 addition & 1 deletion mlir/test/python/ir/diagnostic_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def testDiagnosticNonEmptyNotes():
def callback(d):
# CHECK: DIAGNOSTIC:
# CHECK: message='arith.addi' op requires one result
# CHECK: notes=['see current operation: "arith.addi"() {{.*}} : () -> ()']
# CHECK: notes=['see current operation: "arith.addi"() : () -> ()']
print(f"DIAGNOSTIC:")
print(f" message={d.message}")
print(f" notes={list(map(str, d.notes))}")
Expand Down