Skip to content

Commit

Permalink
[mlir][arith] Add overflow flags support to arith ops (#77211)
Browse files Browse the repository at this point in the history
Add overflow flags support to the following ops:
* `arith.addi`
* `arith.subi`
* `arith.muli`

Example of new syntax:
```
%res = arith.addi %arg1, %arg2 overflow<nsw> : i64
```
Similar to existing LLVM dialect syntax
```
%res = llvm.add %arg1, %arg2 overflow<nsw> : i64
``` 

Tablegen canonicalization patterns updated to always drop flags, proper
support with tests will be added later.

Updated LLVMIR translation as part of this commit as it currenly written
in a way that it will crash when new attributes added to arith ops
otherwise.

Discussion
https://discourse.llvm.org/t/rfc-integer-overflow-flags-support-in-arith-dialect/76025

---------

Co-authored-by: Yi Wu <yi.wu2@arm.com>
  • Loading branch information
Hardcode84 and yi-wu-arm committed Jan 9, 2024
1 parent b5d4332 commit a7262d2
Show file tree
Hide file tree
Showing 11 changed files with 321 additions and 73 deletions.
47 changes: 42 additions & 5 deletions mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
Expand Up @@ -18,14 +18,24 @@

namespace mlir {
namespace arith {
// Map arithmetic fastmath enum values to LLVMIR enum values.
/// Maps arithmetic fastmath enum values to LLVM enum values.
LLVM::FastmathFlags
convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF);

// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
/// Creates an LLVM fastmath attribute from a given arithmetic fastmath
/// attribute.
LLVM::FastmathFlagsAttr
convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);

/// Maps arithmetic overflow enum values to LLVM enum values.
LLVM::IntegerOverflowFlags
convertArithOveflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);

/// Creates an LLVM overflow attribute from a given arithmetic overflow
/// attribute.
LLVM::IntegerOverflowFlagsAttr
convertArithOveflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);

// Attribute converter that populates a NamedAttrList by removing the fastmath
// attribute from the source operation attributes, and replacing it with an
// equivalent LLVM fastmath attribute.
Expand All @@ -36,19 +46,46 @@ class AttrConvertFastMathToLLVM {
// Copy the source attributes.
convertedAttr = NamedAttrList{srcOp->getAttrs()};
// Get the name of the arith fastmath attribute.
llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
// Remove the source fastmath attribute.
auto arithFMFAttr = dyn_cast_or_null<arith::FastMathFlagsAttr>(
auto arithFMFAttr = dyn_cast_if_present<arith::FastMathFlagsAttr>(
convertedAttr.erase(arithFMFAttrName));
if (arithFMFAttr) {
llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName();
StringRef targetAttrName = TargetOp::getFastmathAttrName();
convertedAttr.set(targetAttrName,
convertArithFastMathAttrToLLVM(arithFMFAttr));
}
}

ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }

private:
NamedAttrList convertedAttr;
};

// Attribute converter that populates a NamedAttrList by removing the overflow
// attribute from the source operation attributes, and replacing it with an
// equivalent LLVM overflow attribute.
template <typename SourceOp, typename TargetOp>
class AttrConvertOverflowToLLVM {
public:
AttrConvertOverflowToLLVM(SourceOp srcOp) {
// Copy the source attributes.
convertedAttr = NamedAttrList{srcOp->getAttrs()};
// Get the name of the arith overflow attribute.
StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
// Remove the source overflow attribute.
auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
convertedAttr.erase(arithAttrName));
if (arithAttr) {
StringRef targetAttrName = TargetOp::getIntegerOverflowAttrName();
convertedAttr.set(targetAttrName,
convertArithOveflowAttrToLLVM(arithAttr));
}
}

ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }

private:
NamedAttrList convertedAttr;
};
Expand Down
23 changes: 23 additions & 0 deletions mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
Expand Up @@ -133,4 +133,27 @@ def Arith_FastMathAttr :
let assemblyFormat = "`<` $value `>`";
}

//===----------------------------------------------------------------------===//
// IntegerOverflowFlags
//===----------------------------------------------------------------------===//

def IOFnone : I32BitEnumAttrCaseNone<"none">;
def IOFnsw : I32BitEnumAttrCaseBit<"nsw", 0>;
def IOFnuw : I32BitEnumAttrCaseBit<"nuw", 1>;

def IntegerOverflowFlags : I32BitEnumAttr<
"IntegerOverflowFlags",
"Integer overflow arith flags",
[IOFnone, IOFnsw, IOFnuw]> {
let separator = ", ";
let cppNamespace = "::mlir::arith";
let genSpecializedAttr = 0;
let printBitEnumPrimaryGroups = 1;
}

def Arith_IntegerOverflowAttr :
EnumAttr<Arith_Dialect, IntegerOverflowFlags, "overflow"> {
let assemblyFormat = "`<` $value `>`";
}

#endif // ARITH_BASE
101 changes: 81 additions & 20 deletions mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
Expand Up @@ -137,6 +137,20 @@ class Arith_CompareOpOfAnyRank<string mnemonic, list<Trait> traits = []> :
let results = (outs BoolLikeOfAnyRank:$result);
}

class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic, traits #
[Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>,
DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs,
DefaultValuedAttr<
Arith_IntegerOverflowAttr,
"::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags)>,
Results<(outs SignlessIntegerLike:$result)> {

let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
attr-dict `:` type($result) }];
}

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -192,7 +206,7 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
// AddIOp
//===----------------------------------------------------------------------===//

def Arith_AddIOp : Arith_TotalIntBinaryOp<"addi", [Commutative]> {
def Arith_AddIOp : Arith_IntBinaryOpWithOverflowFlags<"addi", [Commutative]> {
let summary = "integer addition operation";
let description = [{
Performs N-bit addition on the operands. The operands are interpreted as
Expand All @@ -203,16 +217,23 @@ def Arith_AddIOp : Arith_TotalIntBinaryOp<"addi", [Commutative]> {

The `addi` operation takes two operands and returns one result, each of
these is required to be the same type. This type may be an integer scalar type,
a vector whose element type is integer, or a tensor of integers. It has no
standard attributes.
a vector whose element type is integer, or a tensor of integers.

This op supports `nuw`/`nsw` overflow flags which stands stand for
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
`nsw` flags are present, and an unsigned/signed overflow occurs
(respectively), the result is poison.

Example:

```mlir
// Scalar addition.
%a = arith.addi %b, %c : i64

// SIMD vector element-wise addition, e.g. for Intel SSE.
// Scalar addition with overflow flags.
%a = arith.addi %b, %c overflow<nsw, nuw> : i64

// SIMD vector element-wise addition.
%f = arith.addi %g, %h : vector<4xi32>

// Tensor element-wise addition.
Expand Down Expand Up @@ -278,21 +299,41 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
// SubIOp
//===----------------------------------------------------------------------===//

def Arith_SubIOp : Arith_TotalIntBinaryOp<"subi"> {
def Arith_SubIOp : Arith_IntBinaryOpWithOverflowFlags<"subi"> {
let summary = [{
Integer subtraction operation.
}];
let description = [{
Performs N-bit subtraction on the operands. The operands are interpreted as unsigned
bitvectors. The result is represented by a bitvector containing the mathematical
value of the subtraction modulo 2^n, where `n` is the bitwidth. Because `arith`
integers use a two's complement representation, this operation is applicable on
Performs N-bit subtraction on the operands. The operands are interpreted as unsigned
bitvectors. The result is represented by a bitvector containing the mathematical
value of the subtraction modulo 2^n, where `n` is the bitwidth. Because `arith`
integers use a two's complement representation, this operation is applicable on
both signed and unsigned integer operands.

The `subi` operation takes two operands and returns one result, each of
these is required to be the same type. This type may be an integer scalar type,
a vector whose element type is integer, or a tensor of integers. It has no
standard attributes.
these is required to be the same type. This type may be an integer scalar type,
a vector whose element type is integer, or a tensor of integers.

This op supports `nuw`/`nsw` overflow flags which stands stand for
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
`nsw` flags are present, and an unsigned/signed overflow occurs
(respectively), the result is poison.

Example:

```mlir
// Scalar subtraction.
%a = arith.subi %b, %c : i64

// Scalar subtraction with overflow flags.
%a = arith.subi %b, %c overflow<nsw, nuw> : i64

// SIMD vector element-wise subtraction.
%f = arith.subi %g, %h : vector<4xi32>

// Tensor element-wise subtraction.
%x = arith.subi %y, %z : tensor<4x?xi8>
```
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
Expand All @@ -302,21 +343,41 @@ def Arith_SubIOp : Arith_TotalIntBinaryOp<"subi"> {
// MulIOp
//===----------------------------------------------------------------------===//

def Arith_MulIOp : Arith_TotalIntBinaryOp<"muli", [Commutative]> {
def Arith_MulIOp : Arith_IntBinaryOpWithOverflowFlags<"muli", [Commutative]> {
let summary = [{
Integer multiplication operation.
}];
let description = [{
Performs N-bit multiplication on the operands. The operands are interpreted as
unsigned bitvectors. The result is represented by a bitvector containing the
mathematical value of the multiplication modulo 2^n, where `n` is the bitwidth.
Because `arith` integers use a two's complement representation, this operation is
Performs N-bit multiplication on the operands. The operands are interpreted as
unsigned bitvectors. The result is represented by a bitvector containing the
mathematical value of the multiplication modulo 2^n, where `n` is the bitwidth.
Because `arith` integers use a two's complement representation, this operation is
applicable on both signed and unsigned integer operands.

The `muli` operation takes two operands and returns one result, each of
these is required to be the same type. This type may be an integer scalar type,
a vector whose element type is integer, or a tensor of integers. It has no
standard attributes.
these is required to be the same type. This type may be an integer scalar type,
a vector whose element type is integer, or a tensor of integers.

This op supports `nuw`/`nsw` overflow flags which stands stand for
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
`nsw` flags are present, and an unsigned/signed overflow occurs
(respectively), the result is poison.

Example:

```mlir
// Scalar multiplication.
%a = arith.muli %b, %c : i64

// Scalar multiplication with overflow flags.
%a = arith.muli %b, %c overflow<nsw, nuw> : i64

// SIMD vector element-wise multiplication.
%f = arith.muli %g, %h : vector<4xi32>

// Tensor element-wise multiplication.
%x = arith.muli %y, %z : tensor<4x?xi8>
```
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
Expand Down
57 changes: 57 additions & 0 deletions mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
Expand Up @@ -49,4 +49,61 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
];
}

def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
let description = [{
Access to op integer overflow flags.
}];

let cppNamespace = "::mlir::arith";

let methods = [
InterfaceMethod<
/*desc=*/ "Returns an IntegerOverflowFlagsAttr attribute for the operation",
/*returnType=*/ "IntegerOverflowFlagsAttr",
/*methodName=*/ "getOverflowAttr",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
return op.getOverflowFlagsAttr();
}]
>,
InterfaceMethod<
/*desc=*/ "Returns whether the operation has the No Unsigned Wrap keyword",
/*returnType=*/ "bool",
/*methodName=*/ "hasNoUnsignedWrap",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw);
}]
>,
InterfaceMethod<
/*desc=*/ "Returns whether the operation has the No Signed Wrap keyword",
/*returnType=*/ "bool",
/*methodName=*/ "hasNoSignedWrap",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw);
}]
>,
StaticInterfaceMethod<
/*desc=*/ [{Returns the name of the IntegerOveflowFlagsAttr attribute
for the operation}],
/*returnType=*/ "StringRef",
/*methodName=*/ "getIntegerOverflowAttrName",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
return "overflowFlags";
}]
>
];
}

#endif // ARITH_OPS_INTERFACES
29 changes: 24 additions & 5 deletions mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
Expand Up @@ -10,7 +10,6 @@

using namespace mlir;

// Map arithmetic fastmath enum values to LLVMIR enum values.
LLVM::FastmathFlags
mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
LLVM::FastmathFlags llvmFMF{};
Expand All @@ -22,17 +21,37 @@ mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
{arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
{arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
{arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
for (auto fmfMap : flags) {
if (bitEnumContainsAny(arithFMF, fmfMap.first))
llvmFMF = llvmFMF | fmfMap.second;
for (auto [arithFlag, llvmFlag] : flags) {
if (bitEnumContainsAny(arithFMF, arithFlag))
llvmFMF = llvmFMF | llvmFlag;
}
return llvmFMF;
}

// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
LLVM::FastmathFlagsAttr
mlir::arith::convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr) {
arith::FastMathFlags arithFMF = fmfAttr.getValue();
return LLVM::FastmathFlagsAttr::get(
fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF));
}

LLVM::IntegerOverflowFlags mlir::arith::convertArithOveflowFlagsToLLVM(
arith::IntegerOverflowFlags arithFlags) {
LLVM::IntegerOverflowFlags llvmFlags{};
const std::pair<arith::IntegerOverflowFlags, LLVM::IntegerOverflowFlags>
flags[] = {
{arith::IntegerOverflowFlags::nsw, LLVM::IntegerOverflowFlags::nsw},
{arith::IntegerOverflowFlags::nuw, LLVM::IntegerOverflowFlags::nuw}};
for (auto [arithFlag, llvmFlag] : flags) {
if (bitEnumContainsAny(arithFlags, arithFlag))
llvmFlags = llvmFlags | llvmFlag;
}
return llvmFlags;
}

LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOveflowAttrToLLVM(
arith::IntegerOverflowFlagsAttr flagsAttr) {
arith::IntegerOverflowFlags arithFlags = flagsAttr.getValue();
return LLVM::IntegerOverflowFlagsAttr::get(
flagsAttr.getContext(), convertArithOveflowFlagsToLLVM(arithFlags));
}

0 comments on commit a7262d2

Please sign in to comment.