Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][arith] Add overflow flags support to arith ops #77211

Merged
merged 7 commits into from Jan 9, 2024

Conversation

Hardcode84
Copy link
Contributor

@Hardcode84 Hardcode84 commented Jan 6, 2024

Add overflow flags support to the following ops:

  • arith.addi
  • arith.subi
  • arith.muli

Example of new syntax:

%res = arith.addi %arg1, %arg2 overflow<nsw> : i64

Similar to existing LLVM dialect syntax

%res = llvm.add %arg1, %arg2 overflow<nsw> : i64

Tablegen canonicalization patterns updated to always drop flags, proper support with tests will be added later.

Updated LLVMIR translation as part of this commit as it currenly written in a way that it will crash when new attributes added to arith ops otherwise.

Discussion https://discourse.llvm.org/t/rfc-integer-overflow-flags-support-in-arith-dialect/76025

Add overflow flags support to the following ops:
* `arith.addi`
* `arith.subi`
* `arith.muli`

Tablegen canonicalization patterns updated to always drop flags, proper support with tests will be added later.

Updated LLVMIR translation as part of this commit as it currenly written in a way that it will crash when new attributes added to arith ops otherwise.
@llvmbot
Copy link
Collaborator

llvmbot commented Jan 6, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-arith

Author: Ivan Butygin (Hardcode84)

Changes

Add overflow flags support to the following ops:

  • arith.addi
  • arith.subi
  • arith.muli

Tablegen canonicalization patterns updated to always drop flags, proper support with tests will be added later.

Updated LLVMIR translation as part of this commit as it currenly written in a way that it will crash when new attributes added to arith ops otherwise.


Patch is 24.26 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/77211.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h (+35)
  • (modified) mlir/include/mlir/Dialect/Arith/IR/ArithBase.td (+23)
  • (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+19-3)
  • (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td (+57)
  • (modified) mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp (+26-3)
  • (modified) mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp (+9-3)
  • (modified) mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td (+41-39)
  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+5)
  • (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (+13)
  • (modified) mlir/test/Dialect/Arith/ops.mlir (+11)
diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index eea16b4da6a690..dbd0726fe16d1a 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -26,6 +26,14 @@ convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF);
 LLVM::FastmathFlagsAttr
 convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);
 
+// Map arithmetic overflow enum values to LLVMIR enum values.
+LLVM::IntegerOverflowFlags
+convertArithOveflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
+
+// Create an LLVM overflow attribute from a given arithmetic overflow attribute.
+LLVM::IntegerOverflowFlagsAttr
+convertArithOveflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);
+
 // Attribute converter that populates a NamedAttrList by removing the fastmath
 // attribute from the source operation attributes, and replacing it with an
 // equivalent LLVM fastmath attribute.
@@ -49,6 +57,33 @@ class AttrConvertFastMathToLLVM {
 
   ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
 
+private:
+  NamedAttrList convertedAttr;
+};
+
+// Attribute converter that populates a NamedAttrList by removing the overflow
+// attribute from the source operation attributes, and replacing it with an
+// equivalent LLVM fastmath attribute.
+template <typename SourceOp, typename TargetOp>
+class AttrConvertOverflowToLLVM {
+public:
+  AttrConvertOverflowToLLVM(SourceOp srcOp) {
+    // Copy the source attributes.
+    convertedAttr = NamedAttrList{srcOp->getAttrs()};
+    // Get the name of the arith fastmath attribute.
+    llvm::StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
+    // Remove the source fastmath attribute.
+    auto arithAttr = dyn_cast_or_null<arith::IntegerOverflowFlagsAttr>(
+        convertedAttr.erase(arithAttrName));
+    if (arithAttr) {
+      llvm::StringRef targetAttrName = TargetOp::getIntegerOverflowAttrName();
+      convertedAttr.set(targetAttrName,
+                        convertArithOveflowAttrToLLVM(arithAttr));
+    }
+  }
+
+  ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+
 private:
   NamedAttrList convertedAttr;
 };
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index 1e4061392b22d4..3fb7f948b0a45a 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -133,4 +133,27 @@ def Arith_FastMathAttr :
   let assemblyFormat = "`<` $value `>`";
 }
 
+//===----------------------------------------------------------------------===//
+// IntegerOverflowFlags
+//===----------------------------------------------------------------------===//
+
+def IOFnone : I32BitEnumAttrCaseNone<"none">;
+def IOFnsw  : I32BitEnumAttrCaseBit<"nsw", 0>;
+def IOFnuw  : I32BitEnumAttrCaseBit<"nuw", 1>;
+
+def IntegerOverflowFlags : I32BitEnumAttr<
+    "IntegerOverflowFlags",
+    "Integer overflow arith flags",
+    [IOFnone, IOFnsw, IOFnuw]> {
+  let separator = ", ";
+  let cppNamespace = "::mlir::arith";
+  let genSpecializedAttr = 0;
+  let printBitEnumPrimaryGroups = 1;
+}
+
+def Arith_IntegerOverflowAttr :
+    EnumAttr<Arith_Dialect, IntegerOverflowFlags, "overflow"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 #endif // ARITH_BASE
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 6d133d69dd0faf..880718bca9e7ec 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -137,6 +137,22 @@ class Arith_CompareOpOfAnyRank<string mnemonic, list<Trait> traits = []> :
   let results = (outs BoolLikeOfAnyRank:$result);
 }
 
+class Arith_IntArithmeticOpWithOverflowFlag<string mnemonic, list<Trait> traits = []> :
+    Arith_BinaryOp<mnemonic, traits #
+      [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>,
+      DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
+    Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs,
+      DefaultValuedAttr<
+        Arith_IntegerOverflowAttr, "::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags)>,
+    Results<(outs SignlessIntegerLike:$result)> {
+
+  let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
+                          attr-dict `:` type($result) }];
+
+  let hasFolder = 1;
+  let hasCanonicalizer = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//
@@ -192,7 +208,7 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
 // AddIOp
 //===----------------------------------------------------------------------===//
 
-def Arith_AddIOp : Arith_TotalIntBinaryOp<"addi", [Commutative]> {
+def Arith_AddIOp : Arith_IntArithmeticOpWithOverflowFlag<"addi", [Commutative]> {
   let summary = "integer addition operation";
   let description = [{
     Performs N-bit addition on the operands. The operands are interpreted as 
@@ -278,7 +294,7 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
 // SubIOp
 //===----------------------------------------------------------------------===//
 
-def Arith_SubIOp : Arith_TotalIntBinaryOp<"subi"> {
+def Arith_SubIOp : Arith_IntArithmeticOpWithOverflowFlag<"subi"> {
   let summary = [{
     Integer subtraction operation.
   }];
@@ -302,7 +318,7 @@ def Arith_SubIOp : Arith_TotalIntBinaryOp<"subi"> {
 // MulIOp
 //===----------------------------------------------------------------------===//
 
-def Arith_MulIOp : Arith_TotalIntBinaryOp<"muli", [Commutative]> {
+def Arith_MulIOp : Arith_IntArithmeticOpWithOverflowFlag<"muli", [Commutative]> {
   let summary = [{
     Integer multiplication operation.
   }];
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index acaecf6f409dcf..e248422f84db84 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -49,4 +49,61 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
   ];
 }
 
+def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
+  let description = [{
+    Access to op integer overflow flags.
+  }];
+
+  let cppNamespace = "::mlir::arith";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns an IntegerOverflowFlagsAttr attribute for the operation",
+      /*returnType=*/  "IntegerOverflowFlagsAttr",
+      /*methodName=*/  "getOverflowAttr",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getOverflowFlagsAttr();
+      }]
+      >,
+    InterfaceMethod<
+      /*desc=*/        "Returns whether the operation has the No Unsigned Wrap keyword",
+      /*returnType=*/  "bool",
+      /*methodName=*/  "hasNoUnsignedWrap",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
+        return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw);
+      }]
+      >,
+    InterfaceMethod<
+      /*desc=*/        "Returns whether the operation has the No Signed Wrap keyword",
+      /*returnType=*/  "bool",
+      /*methodName=*/  "hasNoSignedWrap",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
+        return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw);
+      }]
+      >,
+    StaticInterfaceMethod<
+      /*desc=*/        [{Returns the name of the IntegerOveflowFlagsAttr attribute
+                         for the operation}],
+      /*returnType=*/  "StringRef",
+      /*methodName=*/  "getIntegerOverflowAttrName",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        return "overflowFlags";
+      }]
+      >
+  ];
+}
+
 #endif // ARITH_OPS_INTERFACES
diff --git a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
index 8c5d76f9f2d72e..7ba12de122bb4d 100644
--- a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
+++ b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
@@ -22,9 +22,9 @@ mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
       {arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
       {arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
       {arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
-  for (auto fmfMap : flags) {
-    if (bitEnumContainsAny(arithFMF, fmfMap.first))
-      llvmFMF = llvmFMF | fmfMap.second;
+  for (auto [arithFlag, llvmFlag] : flags) {
+    if (bitEnumContainsAny(arithFMF, arithFlag))
+      llvmFMF = llvmFMF | llvmFlag;
   }
   return llvmFMF;
 }
@@ -36,3 +36,26 @@ mlir::arith::convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr) {
   return LLVM::FastmathFlagsAttr::get(
       fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF));
 }
+
+// Map arithmetic overflow enum values to LLVMIR enum values.
+LLVM::IntegerOverflowFlags mlir::arith::convertArithOveflowFlagsToLLVM(
+    arith::IntegerOverflowFlags arithFlags) {
+  LLVM::IntegerOverflowFlags llvmFlags{};
+  const std::pair<arith::IntegerOverflowFlags, LLVM::IntegerOverflowFlags>
+      flags[] = {
+          {arith::IntegerOverflowFlags::nsw, LLVM::IntegerOverflowFlags::nsw},
+          {arith::IntegerOverflowFlags::nuw, LLVM::IntegerOverflowFlags::nuw}};
+  for (auto [arithFlag, llvmFlag] : flags) {
+    if (bitEnumContainsAny(arithFlags, arithFlag))
+      llvmFlags = llvmFlags | llvmFlag;
+  }
+  return llvmFlags;
+}
+
+// Create an LLVM overflow attribute from a given arithmetic overflow attribute.
+LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOveflowAttrToLLVM(
+    arith::IntegerOverflowFlagsAttr flagsAttr) {
+  arith::IntegerOverflowFlags arithFlags = flagsAttr.getValue();
+  return LLVM::IntegerOverflowFlagsAttr::get(
+      flagsAttr.getContext(), convertArithOveflowFlagsToLLVM(arithFlags));
+}
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 5e4213cc4e874a..cf46e0d3ac46ac 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -35,7 +35,9 @@ namespace {
 using AddFOpLowering =
     VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
                                arith::AttrConvertFastMathToLLVM>;
-using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>;
+using AddIOpLowering =
+    VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
+                               arith::AttrConvertOverflowToLLVM>;
 using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
 using BitcastOpLowering =
     VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
@@ -78,7 +80,9 @@ using MinUIOpLowering =
 using MulFOpLowering =
     VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
                                arith::AttrConvertFastMathToLLVM>;
-using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>;
+using MulIOpLowering =
+    VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
+                               arith::AttrConvertOverflowToLLVM>;
 using NegFOpLowering =
     VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
                                arith::AttrConvertFastMathToLLVM>;
@@ -102,7 +106,9 @@ using SIToFPOpLowering =
 using SubFOpLowering =
     VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
                                arith::AttrConvertFastMathToLLVM>;
-using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>;
+using SubIOpLowering =
+    VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
+                               arith::AttrConvertOverflowToLLVM>;
 using TruncFOpLowering =
     VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
 using TruncIOpLowering =
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index ef951647ccd146..19f0c0aac31713 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -24,6 +24,8 @@ def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
 // Multiply two integer attributes and create a new one with the result.
 def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
 
+def DefOverflow : NativeCodeCall<"getDefOverflowFlags($_builder)">;
+
 class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
 
 //===----------------------------------------------------------------------===//
@@ -36,23 +38,23 @@ class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
 // addi(addi(x, c0), c1) -> addi(x, c0 + c1)
 def AddIAddConstant :
     Pat<(Arith_AddIOp:$res
-          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)))>;
+          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), (DefOverflow))>;
 
 // addi(subi(x, c0), c1) -> addi(x, c1 - c0)
 def AddISubConstantRHS :
     Pat<(Arith_AddIOp:$res
-          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)))>;
+          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), (DefOverflow))>;
 
 // addi(subi(c0, x), c1) -> subi(c0 + c1, x)
 def AddISubConstantLHS :
     Pat<(Arith_AddIOp:$res
-          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x)>;
+          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x, (DefOverflow))>;
 
 def IsScalarOrSplatNegativeOne :
     Constraint<And<[
@@ -63,24 +65,24 @@ def IsScalarOrSplatNegativeOne :
 def AddIMulNegativeOneRhs :
     Pat<(Arith_AddIOp
            $x,
-           (Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0))),
-        (Arith_SubIOp $x, $y),
+           (Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0), $ovf1), $ovf2),
+        (Arith_SubIOp $x, $y, (DefOverflow)),
         [(IsScalarOrSplatNegativeOne $c0)]>;
 
 // addi(muli(x, -1), y) -> subi(y, x)
 def AddIMulNegativeOneLhs :
     Pat<(Arith_AddIOp
-           (Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0)),
-           $y),
-        (Arith_SubIOp $y, $x),
+           (Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0), $ovf1),
+           $y, $ovf2),
+        (Arith_SubIOp $y, $x, (DefOverflow)),
         [(IsScalarOrSplatNegativeOne $c0)]>;
 
 // muli(muli(x, c0), c1) -> muli(x, c0 * c1)
 def MulIMulIConstant :
     Pat<(Arith_MulIOp:$res
-          (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)))>;
+          (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)), (DefOverflow))>;
 
 //===----------------------------------------------------------------------===//
 // AddUIExtendedOp
@@ -90,7 +92,7 @@ def MulIMulIConstant :
 // uses. Since the 'overflow' result is unused, any replacement value will do.
 def AddUIExtendedToAddI:
     Pattern<(Arith_AddUIExtendedOp:$res $x, $y),
-             [(Arith_AddIOp $x, $y), (replaceWithValue $x)],
+             [(Arith_AddIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
              [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
 
 //===----------------------------------------------------------------------===//
@@ -100,49 +102,49 @@ def AddUIExtendedToAddI:
 // subi(addi(x, c0), c1) -> addi(x, c0 - c1)
 def SubIRHSAddConstant :
     Pat<(Arith_SubIOp:$res
-          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)))>;
+          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), (DefOverflow))>;
 
 // subi(c1, addi(x, c0)) -> subi(c1 - c0, x)
 def SubILHSAddConstant :
     Pat<(Arith_SubIOp:$res
           (ConstantLikeMatcher APIntAttr:$c1),
-          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0))),
-        (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x)>;
+          (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
+        (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x, (DefOverflow))>;
 
 // subi(subi(x, c0), c1) -> subi(x, c0 + c1)
 def SubIRHSSubConstantRHS :
     Pat<(Arith_SubIOp:$res
-          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)))>;
+          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), (DefOverflow))>;
 
 // subi(subi(c0, x), c1) -> subi(c0 - c1, x)
 def SubIRHSSubConstantLHS :
     Pat<(Arith_SubIOp:$res
-          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x),
-          (ConstantLikeMatcher APIntAttr:$c1)),
-        (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x)>;
+          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
+          (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
+        (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x, (DefOverflow))>;
 
 // subi(c1, subi(x, c0)) -> subi(c0 + c1, x)
 def SubILHSSubConstantRHS :
     Pat<(Arith_SubIOp:$res
           (ConstantLikeMatcher APIntAttr:$c1),
-          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0))),
-        (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x)>;
+          (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
+        (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x, (DefOverflow))>;
 
 // subi(c1, subi(c0, x)) -> addi(x, c1 - c0)
 def SubILHSSubConstantLHS :
     Pat<(Arith_SubIOp:$res
           (ConstantLikeMatcher APIntAttr:$c1),
-          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x)),
-        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)))>;
+          (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1), $ovf2),
+        (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), (DefOverflow))>;
 
 // subi(subi(a, b), a) -> subi(0, b)
 def SubISubILHSRHSLHS :
-    Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y), $x),
-        (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y)>;
+    Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y, $ovf1), $x, $ovf2),
+        (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, (DefOverflow))>;
 
 //===----------------------------------------------------------------------===//
 // MulSIExtendedOp
@@ -152,7 +154,7 @@ def SubISubILHSRHSLHS :
 // Since the `h...
[truncated]

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td Outdated Show resolved Hide resolved
Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td Outdated Show resolved Hide resolved
@tblah tblah requested a review from yi-wu-arm January 8, 2024 11:05
Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM once existing comments are addressed. Thank you for working on this!

@yi-wu-arm
Copy link
Contributor

Don't know much about converting to llvmir, but LGTM for the Arith and canonicalization part.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with one suggestion

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

I added two formatting suggestions for the tablegen changes. It seems like most of ArithOps.td stays within the 80 column limit.

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td Outdated Show resolved Hide resolved
@Hardcode84 Hardcode84 merged commit a7262d2 into llvm:main Jan 9, 2024
4 checks passed
@Hardcode84 Hardcode84 deleted the arith-overflow branch January 9, 2024 22:17
LLVM::FastmathFlagsAttr
convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);

/// Maps arithmetic overflow enum values to LLVM enum values.
LLVM::IntegerOverflowFlags
convertArithOveflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: typo, Oveflow -> Overflow

/// Creates an LLVM overflow attribute from a given arithmetic overflow
/// attribute.
LLVM::IntegerOverflowFlagsAttr
convertArithOveflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: typo, Oveflow -> Overflow

}]
>,
StaticInterfaceMethod<
/*desc=*/ [{Returns the name of the IntegerOveflowFlagsAttr attribute
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: typo, Oveflow -> Overflow

Hardcode84 added a commit to Hardcode84/llvm-project that referenced this pull request Jan 10, 2024
Hardcode84 added a commit that referenced this pull request Jan 11, 2024
@@ -133,4 +133,27 @@ def Arith_FastMathAttr :
let assemblyFormat = "`<` $value `>`";
}

//===----------------------------------------------------------------------===//
// IntegerOverflowFlags
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to break the Python bindings, since now both the arith and llvm dialects define IntegerOverflowFlags:

Traceback (most recent call last):
  ...
File "[...]/mlir/dialects/_llvm_enum_gen.py", line 681, in <module>
    @register_attribute_builder("IntegerOverflowFlags")
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "[...]/mlir/ir.py", line 14, in decorator_builder
    AttrBuilder.insert(kind, func, replace=replace)
RuntimeError: Attribute builder for 'IntegerOverflowFlags' is already registered with func: <function _integeroverflowflags at 0x7faf26f8ce00>

The repro is

from mlir.dialects import arith, llvm

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering why this wasn't the case for fastmath flags and they are actually named slightly differently between llvm and arith (FastMathFlags vs FastmathFlags).

Anyways, I'm not familiar with in-tree MLIR python bindings and considering these enums live in different namespaces on C++ level, this sounds like quite a big limitation, which people will continue to hit randomly.

Any ideas how to fix it properly?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just need to namespace (by dialect) the generated enum bindings here and here. Will send a patch soon.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverted for now

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Independent of Maks upcoming change, it's good to "textually namespace" TableGen side too. That's the common case, but I see unfortunately only documented such for ops and here you were being consistent with how fast math flags were defined. (This reminds me of the ODS linter again).

Hardcode84 added a commit that referenced this pull request Jan 11, 2024
Temporarily reverting as it broke python bindings

This reverts commit a7262d2.
jpienaar pushed a commit to jpienaar/llvm-project that referenced this pull request Jan 17, 2024
Add overflow flags support to the following ops:
* `arith.addi`
* `arith.subi`
* `arith.muli`

Example of new syntax:
```
%res = arith.addi %arg1, %arg2 overflow<nsw> : i64
```
Similar to existing LLVM dialect syntax
```
%res = llvm.add %arg1, %arg2 overflow<nsw> : i64
```

Tablegen canonicalization patterns updated to always drop flags, proper
support with tests will be added later.

Updated LLVMIR translation as part of this commit as it currenly written
in a way that it will crash when new attributes added to arith ops
otherwise.

Also lower `arith` overflow flags to corresponding SPIR-V op decorations

Discussion
https://discourse.llvm.org/t/rfc-integer-overflow-flags-support-in-arith-dialect/76025

This effectively rolls forward llvm#77211, llvm#77700 and llvm#77714 while adding a
test to ensure the Python usage is not broken. More follow up needed but
unrelated to the core change here. The changes here are minimal and just
correspond to "textual namespacing" ODS side, no C++ or Python changes
were needed.

---------

Co-authored-by: Ivan Butygin <ivan.butygin@gmail.com>, Yi Wu <yi.wu2@arm.com>
Hardcode84 pushed a commit that referenced this pull request Jan 17, 2024
Add overflow flags support to the following ops:
* `arith.addi`
* `arith.subi`
* `arith.muli`

Example of new syntax:
```
%res = arith.addi %arg1, %arg2 overflow<nsw> : i64
```
Similar to existing LLVM dialect syntax
```
%res = llvm.add %arg1, %arg2 overflow<nsw> : i64
```

Tablegen canonicalization patterns updated to always drop flags, proper
support with tests will be added later.

Updated LLVMIR translation as part of this commit as it currenly written
in a way that it will crash when new attributes added to arith ops
otherwise.

Also lower `arith` overflow flags to corresponding SPIR-V op decorations

Discussion

https://discourse.llvm.org/t/rfc-integer-overflow-flags-support-in-arith-dialect/76025

This effectively rolls forward #77211, #77700 and #77714 while adding a
test to ensure the Python usage is not broken. More follow up needed but
unrelated to the core change here. The changes here are minimal and just
correspond to "textual namespacing" ODS side, no C++ or Python changes
were needed.

---------

---------

Co-authored-by: Ivan Butygin <ivan.butygin@gmail.com>, Yi Wu <yi.wu2@arm.com>
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
Add overflow flags support to the following ops:
* `arith.addi`
* `arith.subi`
* `arith.muli`

Example of new syntax:
```
%res = arith.addi %arg1, %arg2 overflow<nsw> : i64
```
Similar to existing LLVM dialect syntax
```
%res = llvm.add %arg1, %arg2 overflow<nsw> : i64
``` 

Tablegen canonicalization patterns updated to always drop flags, proper
support with tests will be added later.

Updated LLVMIR translation as part of this commit as it currenly written
in a way that it will crash when new attributes added to arith ops
otherwise.

Discussion
https://discourse.llvm.org/t/rfc-integer-overflow-flags-support-in-arith-dialect/76025

---------

Co-authored-by: Yi Wu <yi.wu2@arm.com>
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
…7211)"

Temporarily reverting as it broke python bindings

This reverts commit a7262d2.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

10 participants