Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][arith] Add overflow flags support to arith ops #77211

Merged
merged 7 commits into from Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
47 changes: 42 additions & 5 deletions mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
Expand Up @@ -18,14 +18,24 @@

namespace mlir {
namespace arith {
// Map arithmetic fastmath enum values to LLVMIR enum values.
/// Maps arithmetic fastmath enum values to LLVM enum values.
LLVM::FastmathFlags
convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF);

// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
/// Creates an LLVM fastmath attribute from a given arithmetic fastmath
/// attribute.
LLVM::FastmathFlagsAttr
convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);

/// Maps arithmetic overflow enum values to LLVM enum values.
LLVM::IntegerOverflowFlags
convertArithOveflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: typo, Oveflow -> Overflow


/// Creates an LLVM overflow attribute from a given arithmetic overflow
/// attribute.
LLVM::IntegerOverflowFlagsAttr
convertArithOveflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: typo, Oveflow -> Overflow


// Attribute converter that populates a NamedAttrList by removing the fastmath
// attribute from the source operation attributes, and replacing it with an
// equivalent LLVM fastmath attribute.
Expand All @@ -36,19 +46,46 @@ class AttrConvertFastMathToLLVM {
// Copy the source attributes.
convertedAttr = NamedAttrList{srcOp->getAttrs()};
// Get the name of the arith fastmath attribute.
llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
// Remove the source fastmath attribute.
auto arithFMFAttr = dyn_cast_or_null<arith::FastMathFlagsAttr>(
auto arithFMFAttr = dyn_cast_if_present<arith::FastMathFlagsAttr>(
convertedAttr.erase(arithFMFAttrName));
if (arithFMFAttr) {
llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName();
StringRef targetAttrName = TargetOp::getFastmathAttrName();
convertedAttr.set(targetAttrName,
convertArithFastMathAttrToLLVM(arithFMFAttr));
}
}

ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }

private:
NamedAttrList convertedAttr;
};

// Attribute converter that populates a NamedAttrList by removing the overflow
// attribute from the source operation attributes, and replacing it with an
// equivalent LLVM overflow attribute.
template <typename SourceOp, typename TargetOp>
class AttrConvertOverflowToLLVM {
public:
AttrConvertOverflowToLLVM(SourceOp srcOp) {
// Copy the source attributes.
convertedAttr = NamedAttrList{srcOp->getAttrs()};
// Get the name of the arith overflow attribute.
StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
// Remove the source overflow attribute.
auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
convertedAttr.erase(arithAttrName));
if (arithAttr) {
StringRef targetAttrName = TargetOp::getIntegerOverflowAttrName();
convertedAttr.set(targetAttrName,
convertArithOveflowAttrToLLVM(arithAttr));
}
}

ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }

private:
NamedAttrList convertedAttr;
};
Expand Down
23 changes: 23 additions & 0 deletions mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
Expand Up @@ -133,4 +133,27 @@ def Arith_FastMathAttr :
let assemblyFormat = "`<` $value `>`";
}

//===----------------------------------------------------------------------===//
// IntegerOverflowFlags
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to break the Python bindings, since now both the arith and llvm dialects define IntegerOverflowFlags:

Traceback (most recent call last):
  ...
File "[...]/mlir/dialects/_llvm_enum_gen.py", line 681, in <module>
    @register_attribute_builder("IntegerOverflowFlags")
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "[...]/mlir/ir.py", line 14, in decorator_builder
    AttrBuilder.insert(kind, func, replace=replace)
RuntimeError: Attribute builder for 'IntegerOverflowFlags' is already registered with func: <function _integeroverflowflags at 0x7faf26f8ce00>

The repro is

from mlir.dialects import arith, llvm

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering why this wasn't the case for fastmath flags and they are actually named slightly differently between llvm and arith (FastMathFlags vs FastmathFlags).

Anyways, I'm not familiar with in-tree MLIR python bindings and considering these enums live in different namespaces on C++ level, this sounds like quite a big limitation, which people will continue to hit randomly.

Any ideas how to fix it properly?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just need to namespace (by dialect) the generated enum bindings here and here. Will send a patch soon.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverted for now

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Independent of Maks upcoming change, it's good to "textually namespace" TableGen side too. That's the common case, but I see unfortunately only documented such for ops and here you were being consistent with how fast math flags were defined. (This reminds me of the ODS linter again).

//===----------------------------------------------------------------------===//

def IOFnone : I32BitEnumAttrCaseNone<"none">;
def IOFnsw : I32BitEnumAttrCaseBit<"nsw", 0>;
def IOFnuw : I32BitEnumAttrCaseBit<"nuw", 1>;

def IntegerOverflowFlags : I32BitEnumAttr<
"IntegerOverflowFlags",
"Integer overflow arith flags",
[IOFnone, IOFnsw, IOFnuw]> {
let separator = ", ";
let cppNamespace = "::mlir::arith";
let genSpecializedAttr = 0;
let printBitEnumPrimaryGroups = 1;
}

def Arith_IntegerOverflowAttr :
EnumAttr<Arith_Dialect, IntegerOverflowFlags, "overflow"> {
let assemblyFormat = "`<` $value `>`";
}

#endif // ARITH_BASE
101 changes: 81 additions & 20 deletions mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
Expand Up @@ -137,6 +137,20 @@ class Arith_CompareOpOfAnyRank<string mnemonic, list<Trait> traits = []> :
let results = (outs BoolLikeOfAnyRank:$result);
}

class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic, traits #
[Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>,
DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs,
DefaultValuedAttr<
Arith_IntegerOverflowAttr,
"::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags)>,
Results<(outs SignlessIntegerLike:$result)> {

let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
attr-dict `:` type($result) }];
}

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -192,7 +206,7 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
// AddIOp
//===----------------------------------------------------------------------===//

def Arith_AddIOp : Arith_TotalIntBinaryOp<"addi", [Commutative]> {
def Arith_AddIOp : Arith_IntBinaryOpWithOverflowFlags<"addi", [Commutative]> {
let summary = "integer addition operation";
let description = [{
Performs N-bit addition on the operands. The operands are interpreted as
Expand All @@ -203,16 +217,23 @@ def Arith_AddIOp : Arith_TotalIntBinaryOp<"addi", [Commutative]> {

The `addi` operation takes two operands and returns one result, each of
these is required to be the same type. This type may be an integer scalar type,
a vector whose element type is integer, or a tensor of integers. It has no
standard attributes.
a vector whose element type is integer, or a tensor of integers.

This op supports `nuw`/`nsw` overflow flags which stands stand for
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
`nsw` flags are present, and an unsigned/signed overflow occurs
(respectively), the result is poison.

Example:

```mlir
// Scalar addition.
%a = arith.addi %b, %c : i64

// SIMD vector element-wise addition, e.g. for Intel SSE.
// Scalar addition with overflow flags.
%a = arith.addi %b, %c overflow<nsw, nuw> : i64

// SIMD vector element-wise addition.
%f = arith.addi %g, %h : vector<4xi32>

// Tensor element-wise addition.
Expand Down Expand Up @@ -278,21 +299,41 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
// SubIOp
//===----------------------------------------------------------------------===//

def Arith_SubIOp : Arith_TotalIntBinaryOp<"subi"> {
def Arith_SubIOp : Arith_IntBinaryOpWithOverflowFlags<"subi"> {
let summary = [{
Integer subtraction operation.
}];
let description = [{
Performs N-bit subtraction on the operands. The operands are interpreted as unsigned
bitvectors. The result is represented by a bitvector containing the mathematical
value of the subtraction modulo 2^n, where `n` is the bitwidth. Because `arith`
integers use a two's complement representation, this operation is applicable on
Performs N-bit subtraction on the operands. The operands are interpreted as unsigned
bitvectors. The result is represented by a bitvector containing the mathematical
value of the subtraction modulo 2^n, where `n` is the bitwidth. Because `arith`
integers use a two's complement representation, this operation is applicable on
both signed and unsigned integer operands.

The `subi` operation takes two operands and returns one result, each of
these is required to be the same type. This type may be an integer scalar type,
a vector whose element type is integer, or a tensor of integers. It has no
standard attributes.
these is required to be the same type. This type may be an integer scalar type,
a vector whose element type is integer, or a tensor of integers.

This op supports `nuw`/`nsw` overflow flags which stands stand for
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
`nsw` flags are present, and an unsigned/signed overflow occurs
(respectively), the result is poison.

Example:

```mlir
// Scalar subtraction.
%a = arith.subi %b, %c : i64

// Scalar subtraction with overflow flags.
%a = arith.subi %b, %c overflow<nsw, nuw> : i64

// SIMD vector element-wise subtraction.
%f = arith.subi %g, %h : vector<4xi32>

// Tensor element-wise subtraction.
%x = arith.subi %y, %z : tensor<4x?xi8>
```
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
Expand All @@ -302,21 +343,41 @@ def Arith_SubIOp : Arith_TotalIntBinaryOp<"subi"> {
// MulIOp
//===----------------------------------------------------------------------===//

def Arith_MulIOp : Arith_TotalIntBinaryOp<"muli", [Commutative]> {
def Arith_MulIOp : Arith_IntBinaryOpWithOverflowFlags<"muli", [Commutative]> {
let summary = [{
Integer multiplication operation.
}];
let description = [{
Performs N-bit multiplication on the operands. The operands are interpreted as
unsigned bitvectors. The result is represented by a bitvector containing the
mathematical value of the multiplication modulo 2^n, where `n` is the bitwidth.
Because `arith` integers use a two's complement representation, this operation is
Performs N-bit multiplication on the operands. The operands are interpreted as
unsigned bitvectors. The result is represented by a bitvector containing the
mathematical value of the multiplication modulo 2^n, where `n` is the bitwidth.
Because `arith` integers use a two's complement representation, this operation is
applicable on both signed and unsigned integer operands.

The `muli` operation takes two operands and returns one result, each of
these is required to be the same type. This type may be an integer scalar type,
a vector whose element type is integer, or a tensor of integers. It has no
standard attributes.
these is required to be the same type. This type may be an integer scalar type,
a vector whose element type is integer, or a tensor of integers.

This op supports `nuw`/`nsw` overflow flags which stands stand for
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
`nsw` flags are present, and an unsigned/signed overflow occurs
(respectively), the result is poison.

Example:

```mlir
// Scalar multiplication.
%a = arith.muli %b, %c : i64

// Scalar multiplication with overflow flags.
%a = arith.muli %b, %c overflow<nsw, nuw> : i64

// SIMD vector element-wise multiplication.
%f = arith.muli %g, %h : vector<4xi32>

// Tensor element-wise multiplication.
%x = arith.muli %y, %z : tensor<4x?xi8>
```
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
Expand Down
57 changes: 57 additions & 0 deletions mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
Expand Up @@ -49,4 +49,61 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
];
}

def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
let description = [{
Access to op integer overflow flags.
}];

let cppNamespace = "::mlir::arith";

let methods = [
InterfaceMethod<
/*desc=*/ "Returns an IntegerOverflowFlagsAttr attribute for the operation",
/*returnType=*/ "IntegerOverflowFlagsAttr",
/*methodName=*/ "getOverflowAttr",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
return op.getOverflowFlagsAttr();
}]
>,
InterfaceMethod<
/*desc=*/ "Returns whether the operation has the No Unsigned Wrap keyword",
/*returnType=*/ "bool",
/*methodName=*/ "hasNoUnsignedWrap",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw);
}]
>,
InterfaceMethod<
/*desc=*/ "Returns whether the operation has the No Signed Wrap keyword",
/*returnType=*/ "bool",
/*methodName=*/ "hasNoSignedWrap",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw);
}]
>,
StaticInterfaceMethod<
/*desc=*/ [{Returns the name of the IntegerOveflowFlagsAttr attribute
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: typo, Oveflow -> Overflow

for the operation}],
/*returnType=*/ "StringRef",
/*methodName=*/ "getIntegerOverflowAttrName",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
return "overflowFlags";
}]
>
];
}

#endif // ARITH_OPS_INTERFACES
29 changes: 24 additions & 5 deletions mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
Expand Up @@ -10,7 +10,6 @@

using namespace mlir;

// Map arithmetic fastmath enum values to LLVMIR enum values.
LLVM::FastmathFlags
mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
LLVM::FastmathFlags llvmFMF{};
Expand All @@ -22,17 +21,37 @@ mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
{arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
{arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
{arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
for (auto fmfMap : flags) {
if (bitEnumContainsAny(arithFMF, fmfMap.first))
llvmFMF = llvmFMF | fmfMap.second;
for (auto [arithFlag, llvmFlag] : flags) {
if (bitEnumContainsAny(arithFMF, arithFlag))
llvmFMF = llvmFMF | llvmFlag;
}
return llvmFMF;
}

// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
LLVM::FastmathFlagsAttr
mlir::arith::convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr) {
arith::FastMathFlags arithFMF = fmfAttr.getValue();
return LLVM::FastmathFlagsAttr::get(
fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF));
}

LLVM::IntegerOverflowFlags mlir::arith::convertArithOveflowFlagsToLLVM(
arith::IntegerOverflowFlags arithFlags) {
LLVM::IntegerOverflowFlags llvmFlags{};
const std::pair<arith::IntegerOverflowFlags, LLVM::IntegerOverflowFlags>
flags[] = {
{arith::IntegerOverflowFlags::nsw, LLVM::IntegerOverflowFlags::nsw},
{arith::IntegerOverflowFlags::nuw, LLVM::IntegerOverflowFlags::nuw}};
for (auto [arithFlag, llvmFlag] : flags) {
if (bitEnumContainsAny(arithFlags, arithFlag))
llvmFlags = llvmFlags | llvmFlag;
}
return llvmFlags;
}

LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOveflowAttrToLLVM(
arith::IntegerOverflowFlagsAttr flagsAttr) {
arith::IntegerOverflowFlags arithFlags = flagsAttr.getValue();
return LLVM::IntegerOverflowFlagsAttr::get(
flagsAttr.getContext(), convertArithOveflowFlagsToLLVM(arithFlags));
}