diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h index 7ffc8613317603..0891e2ba7be760 100644 --- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h +++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h @@ -31,6 +31,11 @@ convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr); LLVM::IntegerOverflowFlags convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags); +/// Creates an LLVM overflow attribute from a given arithmetic overflow +/// attribute. +LLVM::IntegerOverflowFlagsAttr +convertArithOverflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr); + /// Creates an LLVM rounding mode enum value from a given arithmetic rounding /// mode enum value. LLVM::RoundingMode @@ -67,9 +72,6 @@ class AttrConvertFastMathToLLVM { } ArrayRef getAttrs() const { return convertedAttr.getAttrs(); } - LLVM::IntegerOverflowFlags getOverflowFlags() const { - return LLVM::IntegerOverflowFlags::none; - } private: NamedAttrList convertedAttr; @@ -87,18 +89,19 @@ class AttrConvertOverflowToLLVM { // Get the name of the arith overflow attribute. StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName(); // Remove the source overflow attribute. - if (auto arithAttr = dyn_cast_if_present( - convertedAttr.erase(arithAttrName))) { - overflowFlags = convertArithOverflowFlagsToLLVM(arithAttr.getValue()); + auto arithAttr = dyn_cast_if_present( + convertedAttr.erase(arithAttrName)); + if (arithAttr) { + StringRef targetAttrName = TargetOp::getIntegerOverflowAttrName(); + convertedAttr.set(targetAttrName, + convertArithOverflowAttrToLLVM(arithAttr)); } } ArrayRef getAttrs() const { return convertedAttr.getAttrs(); } - LLVM::IntegerOverflowFlags getOverflowFlags() const { return overflowFlags; } private: NamedAttrList convertedAttr; - LLVM::IntegerOverflowFlags overflowFlags = LLVM::IntegerOverflowFlags::none; }; template @@ -129,9 +132,6 @@ class AttrConverterConstrainedFPToLLVM { } ArrayRef getAttrs() const { return convertedAttr.getAttrs(); } - LLVM::IntegerOverflowFlags getOverflowFlags() const { - return LLVM::IntegerOverflowFlags::none; - } private: NamedAttrList convertedAttr; diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h index f3bf5b66398e09..f362167ee93249 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -11,7 +11,6 @@ #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { @@ -19,16 +18,13 @@ class CallOpInterface; namespace LLVM { namespace detail { -/// Handle generically setting flags as native properties on LLVM operations. -void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags); - /// Replaces the given operation "op" with a new operation of type "targetOp" /// and given operands. -LogicalResult oneToOneRewrite( - Operation *op, StringRef targetOp, ValueRange operands, - ArrayRef targetAttrs, - const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, - IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none); +LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, + ValueRange operands, + ArrayRef targetAttrs, + const LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter); } // namespace detail } // namespace LLVM diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h index 964281592cc65e..279175b6128fc7 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h @@ -54,11 +54,11 @@ LogicalResult handleMultidimensionalVectors( std::function createOperand, ConversionPatternRewriter &rewriter); -LogicalResult vectorOneToOneRewrite( - Operation *op, StringRef targetOp, ValueRange operands, - ArrayRef targetAttrs, - const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, - IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none); +LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, + ValueRange operands, + ArrayRef targetAttrs, + const LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter); } // namespace detail } // namespace LLVM @@ -70,9 +70,6 @@ class AttrConvertPassThrough { AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {} ArrayRef getAttrs() const { return srcAttrs; } - LLVM::IntegerOverflowFlags getOverflowFlags() const { - return LLVM::IntegerOverflowFlags::none; - } private: ArrayRef srcAttrs; @@ -103,8 +100,7 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern { return LLVM::detail::vectorOneToOneRewrite( op, TargetOp::getOperationName(), adaptor.getOperands(), - attrConvert.getAttrs(), *this->getTypeConverter(), rewriter, - attrConvert.getOverflowFlags()); + attrConvert.getAttrs(), *this->getTypeConverter(), rewriter); } }; } // namespace mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td index 7085f81e203a1e..cee752aeb269b7 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td @@ -50,40 +50,58 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> { def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface"> { let description = [{ - This interface defines an LLVM operation with integer overflow flags and - provides a uniform API for accessing them. + Access to op integer overflow flags. }]; let cppNamespace = "::mlir::LLVM"; let methods = [ - InterfaceMethod<[{ - Get the integer overflow flags for the operation. - }], "IntegerOverflowFlags", "getOverflowFlags", (ins), [{}], [{ - return $_op.getProperties().overflowFlags; - }]>, - InterfaceMethod<[{ - Set the integer overflow flags for the operation. - }], "void", "setOverflowFlags", (ins "IntegerOverflowFlags":$flags), [{}], [{ - $_op.getProperties().overflowFlags = flags; - }]>, - InterfaceMethod<[{ - Returns whether the operation has the No Unsigned Wrap keyword. - }], "bool", "hasNoUnsignedWrap", (ins), [{}], [{ - return bitEnumContainsAll($_op.getOverflowFlags(), - IntegerOverflowFlags::nuw); - }]>, - InterfaceMethod<[{ - Returns whether the operation has the No Signed Wrap keyword. - }], "bool", "hasNoSignedWrap", (ins), [{}], [{ - return bitEnumContainsAll($_op.getOverflowFlags(), - IntegerOverflowFlags::nsw); - }]>, - StaticInterfaceMethod<[{ - Get the attribute name of the overflow flags property. - }], "StringRef", "getOverflowFlagsAttrName", (ins), [{}], [{ - return "overflowFlags"; - }]>, + InterfaceMethod< + /*desc=*/ "Returns an IntegerOverflowFlagsAttr attribute for the operation", + /*returnType=*/ "IntegerOverflowFlagsAttr", + /*methodName=*/ "getOverflowAttr", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + auto op = cast(this->getOperation()); + return op.getOverflowFlagsAttr(); + }] + >, + InterfaceMethod< + /*desc=*/ "Returns whether the operation has the No Unsigned Wrap keyword", + /*returnType=*/ "bool", + /*methodName=*/ "hasNoUnsignedWrap", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + auto op = cast(this->getOperation()); + IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue(); + return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw); + }] + >, + InterfaceMethod< + /*desc=*/ "Returns whether the operation has the No Signed Wrap keyword", + /*returnType=*/ "bool", + /*methodName=*/ "hasNoSignedWrap", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + auto op = cast(this->getOperation()); + IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue(); + return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw); + }] + >, + StaticInterfaceMethod< + /*desc=*/ [{Returns the name of the IntegerOverflowFlagsAttr attribute + for the operation}], + /*returnType=*/ "StringRef", + /*methodName=*/ "getIntegerOverflowAttrName", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + return "overflowFlags"; + }] + > ]; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index f6dca8e2338816..f8f9264b3889be 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -60,16 +60,16 @@ class LLVM_IntArithmeticOpWithOverflowFlag], traits)> { dag iofArg = ( - ins EnumProperty<"IntegerOverflowFlags">:$overflowFlags); + ins DefaultValuedAttr:$overflowFlags); let arguments = !con(commonArgs, iofArg); string mlirBuilder = [{ auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs); - moduleImport.setIntegerOverflowFlags(inst, op); + moduleImport.setIntegerOverflowFlagsAttr(inst, op); $res = op; }]; let assemblyFormat = [{ - $lhs `,` $rhs `` custom($overflowFlags) - `` custom(attr-dict) `:` type($res) + $lhs `,` $rhs (`overflow` `` $overflowFlags^)? + custom(attr-dict) `:` type($res) }]; string llvmBuilder = "$res = builder.Create" # instName # diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h index 6180d17697c271..b551eb937cfe8d 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -183,7 +183,8 @@ class ModuleImport { /// Sets the integer overflow flags (nsw/nuw) attribute for the imported /// operation `op` given the original instruction `inst`. Asserts if the /// operation does not implement the integer overflow flag interface. - void setIntegerOverflowFlags(llvm::Instruction *inst, Operation *op) const; + void setIntegerOverflowFlagsAttr(llvm::Instruction *inst, + Operation *op) const; /// Sets the fastmath flags attribute for the imported operation `op` given /// the original instruction `inst`. Asserts if the operation does not diff --git a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp index cf60a048f782c6..f12eba98480d33 100644 --- a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp +++ b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp @@ -49,6 +49,13 @@ LLVM::IntegerOverflowFlags mlir::arith::convertArithOverflowFlagsToLLVM( return llvmFlags; } +LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOverflowAttrToLLVM( + arith::IntegerOverflowFlagsAttr flagsAttr) { + arith::IntegerOverflowFlags arithFlags = flagsAttr.getValue(); + return LLVM::IntegerOverflowFlagsAttr::get( + flagsAttr.getContext(), convertArithOverflowFlagsToLLVM(arithFlags)); +} + LLVM::RoundingMode mlir::arith::convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode) { switch (roundingMode) { diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index 1886dfa870961a..83c31a204efc7e 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -329,19 +329,14 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( // Detail methods //===----------------------------------------------------------------------===// -void LLVM::detail::setNativeProperties(Operation *op, - IntegerOverflowFlags overflowFlags) { - if (auto iface = dyn_cast(op)) - iface.setOverflowFlags(overflowFlags); -} - /// Replaces the given operation "op" with a new operation of type "targetOp" /// and given operands. -LogicalResult LLVM::detail::oneToOneRewrite( - Operation *op, StringRef targetOp, ValueRange operands, - ArrayRef targetAttrs, - const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, - IntegerOverflowFlags overflowFlags) { +LogicalResult +LLVM::detail::oneToOneRewrite(Operation *op, StringRef targetOp, + ValueRange operands, + ArrayRef targetAttrs, + const LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter) { unsigned numResults = op->getNumResults(); SmallVector resultTypes; @@ -357,8 +352,6 @@ LogicalResult LLVM::detail::oneToOneRewrite( rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands, resultTypes, targetAttrs); - setNativeProperties(newOp, overflowFlags); - // If the operation produced 0 or 1 result, return them immediately. if (numResults == 0) return rewriter.eraseOp(op), success(); diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp index 626135c10a3e96..544bcc71aca1b5 100644 --- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp @@ -103,11 +103,12 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors( return success(); } -LogicalResult LLVM::detail::vectorOneToOneRewrite( - Operation *op, StringRef targetOp, ValueRange operands, - ArrayRef targetAttrs, - const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, - IntegerOverflowFlags overflowFlags) { +LogicalResult +LLVM::detail::vectorOneToOneRewrite(Operation *op, StringRef targetOp, + ValueRange operands, + ArrayRef targetAttrs, + const LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter) { assert(!operands.empty()); // Cannot convert ops if their operands are not of LLVM type. @@ -117,15 +118,14 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite( auto llvmNDVectorTy = operands[0].getType(); if (!isa(llvmNDVectorTy)) return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter, - rewriter, overflowFlags); + rewriter); - auto callback = [op, targetOp, targetAttrs, overflowFlags, - &rewriter](Type llvm1DVectorTy, ValueRange operands) { - Operation *newOp = - rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), - operands, llvm1DVectorTy, targetAttrs); - LLVM::detail::setNativeProperties(newOp, overflowFlags); - return newOp->getResult(0); + auto callback = [op, targetOp, targetAttrs, &rewriter](Type llvm1DVectorTy, + ValueRange operands) { + return rewriter + .create(op->getLoc(), rewriter.getStringAttr(targetOp), operands, + llvm1DVectorTy, targetAttrs) + ->getResult(0); }; return handleMultidimensionalVectors(op, operands, typeConverter, callback, diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 84994d816ad1a1..f90240a67dcc5f 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -47,74 +47,6 @@ using mlir::LLVM::linkage::getMaxEnumValForLinkage; #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc" -//===----------------------------------------------------------------------===// -// Property Helpers -//===----------------------------------------------------------------------===// - -//===----------------------------------------------------------------------===// -// IntegerOverflowFlags - -namespace mlir { -static Attribute convertToAttribute(MLIRContext *ctx, - IntegerOverflowFlags flags) { - return IntegerOverflowFlagsAttr::get(ctx, flags); -} - -static LogicalResult -convertFromAttribute(IntegerOverflowFlags &flags, Attribute attr, - function_ref emitError) { - auto flagsAttr = dyn_cast(attr); - if (!flagsAttr) { - return emitError() << "expected 'overflowFlags' attribute to be an " - "IntegerOverflowFlagsAttr, but got " - << attr; - } - flags = flagsAttr.getValue(); - return success(); -} -} // namespace mlir - -static ParseResult parseOverflowFlags(AsmParser &p, - IntegerOverflowFlags &flags) { - if (failed(p.parseOptionalKeyword("overflow"))) { - flags = IntegerOverflowFlags::none; - return success(); - } - if (p.parseLess()) - return failure(); - do { - StringRef kw; - SMLoc loc = p.getCurrentLocation(); - if (p.parseKeyword(&kw)) - return failure(); - std::optional flag = - symbolizeIntegerOverflowFlags(kw); - if (!flag) - return p.emitError(loc, - "invalid overflow flag: expected nsw, nuw, or none"); - flags = flags | *flag; - } while (succeeded(p.parseOptionalComma())); - return p.parseGreater(); -} - -static void printOverflowFlags(AsmPrinter &p, Operation *op, - IntegerOverflowFlags flags) { - if (flags == IntegerOverflowFlags::none) - return; - p << " overflow<"; - SmallVector strs; - if (bitEnumContainsAny(flags, IntegerOverflowFlags::nsw)) - strs.push_back("nsw"); - if (bitEnumContainsAny(flags, IntegerOverflowFlags::nuw)) - strs.push_back("nuw"); - llvm::interleaveComma(strs, p); - p << ">"; -} - -//===----------------------------------------------------------------------===// -// Attribute Helpers -//===----------------------------------------------------------------------===// - static constexpr const char kElemTypeAttrName[] = "elem_type"; static auto processFMFAttr(ArrayRef attrs) { @@ -138,12 +70,12 @@ static ParseResult parseLLVMOpAttrs(OpAsmParser &parser, static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, DictionaryAttr attrs) { auto filteredAttrs = processFMFAttr(attrs.getValue()); - if (auto iface = dyn_cast(op)) { + if (auto iface = dyn_cast(op)) printer.printOptionalAttrDict( - filteredAttrs, /*elidedAttrs=*/{iface.getOverflowFlagsAttrName()}); - } else { + filteredAttrs, + /*elidedAttrs=*/{iface.getIntegerOverflowAttrName()}); + else printer.printOptionalAttrDict(filteredAttrs); - } } /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index d964710f8e3f38..af998b99d511f0 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -625,8 +625,8 @@ void ModuleImport::setNonDebugMetadataAttrs(llvm::Instruction *inst, } } -void ModuleImport::setIntegerOverflowFlags(llvm::Instruction *inst, - Operation *op) const { +void ModuleImport::setIntegerOverflowFlagsAttr(llvm::Instruction *inst, + Operation *op) const { auto iface = cast(op); IntegerOverflowFlags value = {}; @@ -634,7 +634,8 @@ void ModuleImport::setIntegerOverflowFlags(llvm::Instruction *inst, value = bitEnumSet(value, IntegerOverflowFlags::nuw, inst->hasNoUnsignedWrap()); - iface.setOverflowFlags(value); + auto attr = IntegerOverflowFlagsAttr::get(op->getContext(), value); + iface->setAttr(iface.getIntegerOverflowAttrName(), attr); } void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,