Skip to content

Commit

Permalink
[mlir][llvm] Port overflowFlags to a native operation property (REL…
Browse files Browse the repository at this point in the history
…AND) (#89410)

This PR changes the LLVM dialect's IntegerOverflowFlags to be stored on
operations as native properties.

Reland to fix flang
  • Loading branch information
Mogball committed Apr 19, 2024
1 parent d86079f commit e553ac4
Show file tree
Hide file tree
Showing 12 changed files with 183 additions and 116 deletions.
10 changes: 4 additions & 6 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2110,9 +2110,8 @@ struct XArrayCoorOpConversion
const bool baseIsBoxed = coor.getMemref().getType().isa<fir::BaseBoxType>();
TypePair baseBoxTyPair =
baseIsBoxed ? getBoxTypePair(coor.getMemref().getType()) : TypePair{};
mlir::LLVM::IntegerOverflowFlagsAttr nsw =
mlir::LLVM::IntegerOverflowFlagsAttr::get(
rewriter.getContext(), mlir::LLVM::IntegerOverflowFlags::nsw);
mlir::LLVM::IntegerOverflowFlags nsw =
mlir::LLVM::IntegerOverflowFlags::nsw;

// For each dimension of the array, generate the offset calculation.
for (unsigned i = 0; i < rank; ++i, ++indexOffset, ++shapeOffset,
Expand Down Expand Up @@ -2396,9 +2395,8 @@ struct CoordinateOpConversion
auto cpnTy = fir::dyn_cast_ptrOrBoxEleTy(boxObjTy);
mlir::Type llvmPtrTy = ::getLlvmPtrType(coor.getContext());
mlir::Type byteTy = ::getI8Type(coor.getContext());
mlir::LLVM::IntegerOverflowFlagsAttr nsw =
mlir::LLVM::IntegerOverflowFlagsAttr::get(
rewriter.getContext(), mlir::LLVM::IntegerOverflowFlags::nsw);
mlir::LLVM::IntegerOverflowFlags nsw =
mlir::LLVM::IntegerOverflowFlags::nsw;

for (unsigned i = 1, last = operands.size(); i < last; ++i) {
if (auto arrTy = cpnTy.dyn_cast<fir::SequenceType>()) {
Expand Down
22 changes: 11 additions & 11 deletions mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,6 @@ convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);
LLVM::IntegerOverflowFlags
convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);

/// Creates an LLVM overflow attribute from a given arithmetic overflow
/// attribute.
LLVM::IntegerOverflowFlagsAttr
convertArithOverflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);

/// Creates an LLVM rounding mode enum value from a given arithmetic rounding
/// mode enum value.
LLVM::RoundingMode
Expand Down Expand Up @@ -72,6 +67,9 @@ class AttrConvertFastMathToLLVM {
}

ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
LLVM::IntegerOverflowFlags getOverflowFlags() const {
return LLVM::IntegerOverflowFlags::none;
}

private:
NamedAttrList convertedAttr;
Expand All @@ -89,19 +87,18 @@ class AttrConvertOverflowToLLVM {
// Get the name of the arith overflow attribute.
StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
// Remove the source overflow attribute.
auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
convertedAttr.erase(arithAttrName));
if (arithAttr) {
StringRef targetAttrName = TargetOp::getIntegerOverflowAttrName();
convertedAttr.set(targetAttrName,
convertArithOverflowAttrToLLVM(arithAttr));
if (auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
convertedAttr.erase(arithAttrName))) {
overflowFlags = convertArithOverflowFlagsToLLVM(arithAttr.getValue());
}
}

ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
LLVM::IntegerOverflowFlags getOverflowFlags() const { return overflowFlags; }

private:
NamedAttrList convertedAttr;
LLVM::IntegerOverflowFlags overflowFlags = LLVM::IntegerOverflowFlags::none;
};

template <typename SourceOp, typename TargetOp>
Expand Down Expand Up @@ -132,6 +129,9 @@ class AttrConverterConstrainedFPToLLVM {
}

ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
LLVM::IntegerOverflowFlags getOverflowFlags() const {
return LLVM::IntegerOverflowFlags::none;
}

private:
NamedAttrList convertedAttr;
Expand Down
14 changes: 9 additions & 5 deletions mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,24 @@

#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
class CallOpInterface;

namespace LLVM {
namespace detail {
/// Handle generically setting flags as native properties on LLVM operations.
void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags);

/// Replaces the given operation "op" with a new operation of type "targetOp"
/// and given operands.
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter);
LogicalResult oneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);

} // namespace detail
} // namespace LLVM
Expand Down
16 changes: 10 additions & 6 deletions mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ LogicalResult handleMultidimensionalVectors(
std::function<Value(Type, ValueRange)> createOperand,
ConversionPatternRewriter &rewriter);

LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter);
LogicalResult vectorOneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
} // namespace detail
} // namespace LLVM

Expand All @@ -70,6 +70,9 @@ class AttrConvertPassThrough {
AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {}

ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; }
LLVM::IntegerOverflowFlags getOverflowFlags() const {
return LLVM::IntegerOverflowFlags::none;
}

private:
ArrayRef<NamedAttribute> srcAttrs;
Expand Down Expand Up @@ -100,7 +103,8 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {

return LLVM::detail::vectorOneToOneRewrite(
op, TargetOp::getOperationName(), adaptor.getOperands(),
attrConvert.getAttrs(), *this->getTypeConverter(), rewriter);
attrConvert.getAttrs(), *this->getTypeConverter(), rewriter,
attrConvert.getOverflowFlags());
}
};
} // namespace mlir
Expand Down
76 changes: 29 additions & 47 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -50,58 +50,40 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {

def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface"> {
let description = [{
Access to op integer overflow flags.
This interface defines an LLVM operation with integer overflow flags and
provides a uniform API for accessing them.
}];

let cppNamespace = "::mlir::LLVM";

let methods = [
InterfaceMethod<
/*desc=*/ "Returns an IntegerOverflowFlagsAttr attribute for the operation",
/*returnType=*/ "IntegerOverflowFlagsAttr",
/*methodName=*/ "getOverflowAttr",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
return op.getOverflowFlagsAttr();
}]
>,
InterfaceMethod<
/*desc=*/ "Returns whether the operation has the No Unsigned Wrap keyword",
/*returnType=*/ "bool",
/*methodName=*/ "hasNoUnsignedWrap",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw);
}]
>,
InterfaceMethod<
/*desc=*/ "Returns whether the operation has the No Signed Wrap keyword",
/*returnType=*/ "bool",
/*methodName=*/ "hasNoSignedWrap",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw);
}]
>,
StaticInterfaceMethod<
/*desc=*/ [{Returns the name of the IntegerOverflowFlagsAttr attribute
for the operation}],
/*returnType=*/ "StringRef",
/*methodName=*/ "getIntegerOverflowAttrName",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
return "overflowFlags";
}]
>
InterfaceMethod<[{
Get the integer overflow flags for the operation.
}], "IntegerOverflowFlags", "getOverflowFlags", (ins), [{}], [{
return $_op.getProperties().overflowFlags;
}]>,
InterfaceMethod<[{
Set the integer overflow flags for the operation.
}], "void", "setOverflowFlags", (ins "IntegerOverflowFlags":$flags), [{}], [{
$_op.getProperties().overflowFlags = flags;
}]>,
InterfaceMethod<[{
Returns whether the operation has the No Unsigned Wrap keyword.
}], "bool", "hasNoUnsignedWrap", (ins), [{}], [{
return bitEnumContainsAll($_op.getOverflowFlags(),
IntegerOverflowFlags::nuw);
}]>,
InterfaceMethod<[{
Returns whether the operation has the No Signed Wrap keyword.
}], "bool", "hasNoSignedWrap", (ins), [{}], [{
return bitEnumContainsAll($_op.getOverflowFlags(),
IntegerOverflowFlags::nsw);
}]>,
StaticInterfaceMethod<[{
Get the attribute name of the overflow flags property.
}], "StringRef", "getOverflowFlagsAttrName", (ins), [{}], [{
return "overflowFlags";
}]>,
];
}

Expand Down
23 changes: 18 additions & 5 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,30 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
!listconcat([DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)> {
dag iofArg = (
ins DefaultValuedAttr<LLVM_IntegerOverflowFlagsAttr, "{}">:$overflowFlags);
dag iofArg = (ins EnumProperty<"IntegerOverflowFlags">:$overflowFlags);
let arguments = !con(commonArgs, iofArg);

let builders = [
OpBuilder<(ins "Type":$type, "Value":$lhs, "Value":$rhs,
"IntegerOverflowFlags":$overflowFlags), [{
build($_builder, $_state, type, lhs, rhs);
$_state.getOrAddProperties<Properties>().overflowFlags = overflowFlags;
}]>,
OpBuilder<(ins "Value":$lhs, "Value":$rhs,
"IntegerOverflowFlags":$overflowFlags), [{
build($_builder, $_state, lhs, rhs);
$_state.getOrAddProperties<Properties>().overflowFlags = overflowFlags;
}]>
];

string mlirBuilder = [{
auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
moduleImport.setIntegerOverflowFlagsAttr(inst, op);
moduleImport.setIntegerOverflowFlags(inst, op);
$res = op;
}];
let assemblyFormat = [{
$lhs `,` $rhs (`overflow` `` $overflowFlags^)?
custom<LLVMOpAttrs>(attr-dict) `:` type($res)
$lhs `,` $rhs `` custom<OverflowFlags>($overflowFlags)
`` custom<LLVMOpAttrs>(attr-dict) `:` type($res)
}];
string llvmBuilder =
"$res = builder.Create" # instName #
Expand Down
3 changes: 1 addition & 2 deletions mlir/include/mlir/Target/LLVMIR/ModuleImport.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,7 @@ class ModuleImport {
/// Sets the integer overflow flags (nsw/nuw) attribute for the imported
/// operation `op` given the original instruction `inst`. Asserts if the
/// operation does not implement the integer overflow flag interface.
void setIntegerOverflowFlagsAttr(llvm::Instruction *inst,
Operation *op) const;
void setIntegerOverflowFlags(llvm::Instruction *inst, Operation *op) const;

/// Sets the fastmath flags attribute for the imported operation `op` given
/// the original instruction `inst`. Asserts if the operation does not
Expand Down
7 changes: 0 additions & 7 deletions mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,6 @@ LLVM::IntegerOverflowFlags mlir::arith::convertArithOverflowFlagsToLLVM(
return llvmFlags;
}

LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOverflowAttrToLLVM(
arith::IntegerOverflowFlagsAttr flagsAttr) {
arith::IntegerOverflowFlags arithFlags = flagsAttr.getValue();
return LLVM::IntegerOverflowFlagsAttr::get(
flagsAttr.getContext(), convertArithOverflowFlagsToLLVM(arithFlags));
}

LLVM::RoundingMode
mlir::arith::convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode) {
switch (roundingMode) {
Expand Down
19 changes: 13 additions & 6 deletions mlir/lib/Conversion/LLVMCommon/Pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,14 +329,19 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
// Detail methods
//===----------------------------------------------------------------------===//

void LLVM::detail::setNativeProperties(Operation *op,
IntegerOverflowFlags overflowFlags) {
if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
iface.setOverflowFlags(overflowFlags);
}

/// Replaces the given operation "op" with a new operation of type "targetOp"
/// and given operands.
LogicalResult
LLVM::detail::oneToOneRewrite(Operation *op, StringRef targetOp,
ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter) {
LogicalResult LLVM::detail::oneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
IntegerOverflowFlags overflowFlags) {
unsigned numResults = op->getNumResults();

SmallVector<Type> resultTypes;
Expand All @@ -352,6 +357,8 @@ LLVM::detail::oneToOneRewrite(Operation *op, StringRef targetOp,
rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
resultTypes, targetAttrs);

setNativeProperties(newOp, overflowFlags);

// If the operation produced 0 or 1 result, return them immediately.
if (numResults == 0)
return rewriter.eraseOp(op), success();
Expand Down
26 changes: 13 additions & 13 deletions mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,11 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
return success();
}

LogicalResult
LLVM::detail::vectorOneToOneRewrite(Operation *op, StringRef targetOp,
ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter) {
LogicalResult LLVM::detail::vectorOneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
IntegerOverflowFlags overflowFlags) {
assert(!operands.empty());

// Cannot convert ops if their operands are not of LLVM type.
Expand All @@ -118,14 +117,15 @@ LLVM::detail::vectorOneToOneRewrite(Operation *op, StringRef targetOp,
auto llvmNDVectorTy = operands[0].getType();
if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy))
return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter,
rewriter);
rewriter, overflowFlags);

auto callback = [op, targetOp, targetAttrs, &rewriter](Type llvm1DVectorTy,
ValueRange operands) {
return rewriter
.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
llvm1DVectorTy, targetAttrs)
->getResult(0);
auto callback = [op, targetOp, targetAttrs, overflowFlags,
&rewriter](Type llvm1DVectorTy, ValueRange operands) {
Operation *newOp =
rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp),
operands, llvm1DVectorTy, targetAttrs);
LLVM::detail::setNativeProperties(newOp, overflowFlags);
return newOp->getResult(0);
};

return handleMultidimensionalVectors(op, operands, typeConverter, callback,
Expand Down

0 comments on commit e553ac4

Please sign in to comment.