New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][arith] Add overflow flags support to arith ops #77211
Changes from all commits
0c1f04d
82eeda0
3d09907
4c6f3ac
c36ae9d
b69c5b8
750e0f7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,14 +18,24 @@ | |
|
||
namespace mlir { | ||
namespace arith { | ||
// Map arithmetic fastmath enum values to LLVMIR enum values. | ||
/// Maps arithmetic fastmath enum values to LLVM enum values. | ||
LLVM::FastmathFlags | ||
convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF); | ||
|
||
// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute. | ||
/// Creates an LLVM fastmath attribute from a given arithmetic fastmath | ||
/// attribute. | ||
LLVM::FastmathFlagsAttr | ||
convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr); | ||
|
||
/// Maps arithmetic overflow enum values to LLVM enum values. | ||
LLVM::IntegerOverflowFlags | ||
convertArithOveflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags); | ||
|
||
/// Creates an LLVM overflow attribute from a given arithmetic overflow | ||
/// attribute. | ||
LLVM::IntegerOverflowFlagsAttr | ||
convertArithOveflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: typo, Oveflow -> Overflow |
||
|
||
// Attribute converter that populates a NamedAttrList by removing the fastmath | ||
// attribute from the source operation attributes, and replacing it with an | ||
// equivalent LLVM fastmath attribute. | ||
|
@@ -36,19 +46,46 @@ class AttrConvertFastMathToLLVM { | |
// Copy the source attributes. | ||
convertedAttr = NamedAttrList{srcOp->getAttrs()}; | ||
// Get the name of the arith fastmath attribute. | ||
llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName(); | ||
StringRef arithFMFAttrName = SourceOp::getFastMathAttrName(); | ||
// Remove the source fastmath attribute. | ||
auto arithFMFAttr = dyn_cast_or_null<arith::FastMathFlagsAttr>( | ||
auto arithFMFAttr = dyn_cast_if_present<arith::FastMathFlagsAttr>( | ||
convertedAttr.erase(arithFMFAttrName)); | ||
if (arithFMFAttr) { | ||
llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName(); | ||
StringRef targetAttrName = TargetOp::getFastmathAttrName(); | ||
convertedAttr.set(targetAttrName, | ||
convertArithFastMathAttrToLLVM(arithFMFAttr)); | ||
} | ||
} | ||
|
||
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); } | ||
|
||
private: | ||
NamedAttrList convertedAttr; | ||
}; | ||
|
||
// Attribute converter that populates a NamedAttrList by removing the overflow | ||
// attribute from the source operation attributes, and replacing it with an | ||
// equivalent LLVM overflow attribute. | ||
template <typename SourceOp, typename TargetOp> | ||
class AttrConvertOverflowToLLVM { | ||
public: | ||
AttrConvertOverflowToLLVM(SourceOp srcOp) { | ||
// Copy the source attributes. | ||
convertedAttr = NamedAttrList{srcOp->getAttrs()}; | ||
// Get the name of the arith overflow attribute. | ||
StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName(); | ||
// Remove the source overflow attribute. | ||
auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>( | ||
convertedAttr.erase(arithAttrName)); | ||
if (arithAttr) { | ||
StringRef targetAttrName = TargetOp::getIntegerOverflowAttrName(); | ||
convertedAttr.set(targetAttrName, | ||
convertArithOveflowAttrToLLVM(arithAttr)); | ||
} | ||
} | ||
|
||
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); } | ||
|
||
private: | ||
NamedAttrList convertedAttr; | ||
}; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -133,4 +133,27 @@ def Arith_FastMathAttr : | |
let assemblyFormat = "`<` $value `>`"; | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// IntegerOverflowFlags | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems to break the Python bindings, since now both the arith and llvm dialects define
The repro is
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was wondering why this wasn't the case for fastmath flags and they are actually named slightly differently between llvm and arith ( Anyways, I'm not familiar with in-tree MLIR python bindings and considering these enums live in different namespaces on C++ level, this sounds like quite a big limitation, which people will continue to hit randomly. Any ideas how to fix it properly? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reverted for now There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Independent of Maks upcoming change, it's good to "textually namespace" TableGen side too. That's the common case, but I see unfortunately only documented such for ops and here you were being consistent with how fast math flags were defined. (This reminds me of the ODS linter again). |
||
//===----------------------------------------------------------------------===// | ||
|
||
def IOFnone : I32BitEnumAttrCaseNone<"none">; | ||
def IOFnsw : I32BitEnumAttrCaseBit<"nsw", 0>; | ||
def IOFnuw : I32BitEnumAttrCaseBit<"nuw", 1>; | ||
|
||
def IntegerOverflowFlags : I32BitEnumAttr< | ||
"IntegerOverflowFlags", | ||
"Integer overflow arith flags", | ||
[IOFnone, IOFnsw, IOFnuw]> { | ||
let separator = ", "; | ||
let cppNamespace = "::mlir::arith"; | ||
let genSpecializedAttr = 0; | ||
let printBitEnumPrimaryGroups = 1; | ||
} | ||
|
||
def Arith_IntegerOverflowAttr : | ||
EnumAttr<Arith_Dialect, IntegerOverflowFlags, "overflow"> { | ||
let assemblyFormat = "`<` $value `>`"; | ||
} | ||
|
||
#endif // ARITH_BASE |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -49,4 +49,61 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> { | |
]; | ||
} | ||
|
||
def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> { | ||
let description = [{ | ||
Access to op integer overflow flags. | ||
}]; | ||
|
||
let cppNamespace = "::mlir::arith"; | ||
|
||
let methods = [ | ||
InterfaceMethod< | ||
/*desc=*/ "Returns an IntegerOverflowFlagsAttr attribute for the operation", | ||
/*returnType=*/ "IntegerOverflowFlagsAttr", | ||
/*methodName=*/ "getOverflowAttr", | ||
/*args=*/ (ins), | ||
/*methodBody=*/ [{}], | ||
/*defaultImpl=*/ [{ | ||
auto op = cast<ConcreteOp>(this->getOperation()); | ||
return op.getOverflowFlagsAttr(); | ||
}] | ||
>, | ||
InterfaceMethod< | ||
/*desc=*/ "Returns whether the operation has the No Unsigned Wrap keyword", | ||
/*returnType=*/ "bool", | ||
/*methodName=*/ "hasNoUnsignedWrap", | ||
/*args=*/ (ins), | ||
/*methodBody=*/ [{}], | ||
/*defaultImpl=*/ [{ | ||
auto op = cast<ConcreteOp>(this->getOperation()); | ||
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue(); | ||
return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw); | ||
}] | ||
>, | ||
InterfaceMethod< | ||
/*desc=*/ "Returns whether the operation has the No Signed Wrap keyword", | ||
/*returnType=*/ "bool", | ||
/*methodName=*/ "hasNoSignedWrap", | ||
/*args=*/ (ins), | ||
/*methodBody=*/ [{}], | ||
/*defaultImpl=*/ [{ | ||
auto op = cast<ConcreteOp>(this->getOperation()); | ||
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue(); | ||
return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw); | ||
}] | ||
>, | ||
StaticInterfaceMethod< | ||
/*desc=*/ [{Returns the name of the IntegerOveflowFlagsAttr attribute | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: typo, Oveflow -> Overflow |
||
for the operation}], | ||
/*returnType=*/ "StringRef", | ||
/*methodName=*/ "getIntegerOverflowAttrName", | ||
/*args=*/ (ins), | ||
/*methodBody=*/ [{}], | ||
/*defaultImpl=*/ [{ | ||
return "overflowFlags"; | ||
}] | ||
> | ||
]; | ||
} | ||
|
||
#endif // ARITH_OPS_INTERFACES |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: typo, Oveflow -> Overflow