458 changes: 210 additions & 248 deletions mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ LLVMTypeConverter::LLVMTypeConverter(
addConversion([](LLVM::LLVMType type) { return type; });
}

/// Returns the MLIR context.
MLIRContext &LLVMTypeConverter::getContext() {
return *getDialect()->getContext();
}

/// Get the LLVM context.
llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
return module->getContext();
Expand Down Expand Up @@ -699,52 +704,35 @@ void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
results.push_back(d.memRefDescPtr(builder, loc));
}

namespace {
// Base class for Standard to LLVM IR op conversions. Matches the Op type
// provided as template argument. Carries a reference to the LLVM dialect in
// case it is necessary for rewriters.
template <typename SourceOp>
class LLVMLegalizationPattern : public ConvertToLLVMPattern {
public:
// Construct a conversion pattern.
explicit LLVMLegalizationPattern(LLVM::LLVMDialect &dialect_,
LLVMTypeConverter &typeConverter_)
: ConvertToLLVMPattern(SourceOp::getOperationName(),
dialect_.getContext(), typeConverter_),
dialect(dialect_) {}

// Get the LLVM IR dialect.
LLVM::LLVMDialect &getDialect() const { return dialect; }
// Get the LLVM context.
llvm::LLVMContext &getContext() const { return dialect.getLLVMContext(); }
// Get the LLVM module in which the types are constructed.
llvm::Module &getModule() const { return dialect.getLLVMModule(); }

// Get the MLIR type wrapping the LLVM integer type whose bit width is defined
// by the pointer size used in the LLVM module.
LLVM::LLVMType getIndexType() const {
return LLVM::LLVMType::getIntNTy(
&dialect, getModule().getDataLayout().getPointerSizeInBits());
}
LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
return *typeConverter.getDialect();
}

LLVM::LLVMType getVoidType() const {
return LLVM::LLVMType::getVoidTy(&dialect);
}
llvm::LLVMContext &ConvertToLLVMPattern::getContext() const {
return typeConverter.getLLVMContext();
}

// Get the MLIR type wrapping the LLVM i8* type.
LLVM::LLVMType getVoidPtrType() const {
return LLVM::LLVMType::getInt8PtrTy(&dialect);
}
llvm::Module &ConvertToLLVMPattern::getModule() const {
return getDialect().getLLVMModule();
}

// Create an LLVM IR pseudo-operation defining the given index constant.
Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
uint64_t value) const {
return createIndexAttrConstant(builder, loc, getIndexType(), value);
}
LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const {
return LLVM::LLVMType::getIntNTy(
&getDialect(), getModule().getDataLayout().getPointerSizeInBits());
}

protected:
LLVM::LLVMDialect &dialect;
};
LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const {
return LLVM::LLVMType::getVoidTy(&getDialect());
}

LLVM::LLVMType ConvertToLLVMPattern::getVoidPtrType() const {
return LLVM::LLVMType::getInt8PtrTy(&getDialect());
}

Value ConvertToLLVMPattern::createIndexConstant(
ConversionPatternRewriter &builder, Location loc, uint64_t value) const {
return createIndexAttrConstant(builder, loc, getIndexType(), value);
}

/// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
Expand Down Expand Up @@ -876,9 +864,11 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
builder.create<LLVM::ReturnOp>(loc, call.getResults());
}

struct FuncOpConversionBase : public LLVMLegalizationPattern<FuncOp> {
namespace {

struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
protected:
using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern;
using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;
using UnsignedTypePair = std::pair<unsigned, Type>;

// Gather the positions and types of memref-typed arguments in a given
Expand Down Expand Up @@ -942,9 +932,8 @@ struct FuncOpConversionBase : public LLVMLegalizationPattern<FuncOp> {
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
/// information.
struct FuncOpConversion : public FuncOpConversionBase {
FuncOpConversion(LLVM::LLVMDialect &dialect, LLVMTypeConverter &converter,
bool emitCWrappers)
: FuncOpConversionBase(dialect, converter), emitWrappers(emitCWrappers) {}
FuncOpConversion(LLVMTypeConverter &converter, bool emitCWrappers)
: FuncOpConversionBase(converter), emitWrappers(emitCWrappers) {}

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
Expand Down Expand Up @@ -1022,7 +1011,6 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
};

//////////////// Support for Lowering operations on n-D vectors ////////////////
namespace {
// Helper struct to "unroll" operations on n-D vectors in terms of operations on
// 1-D LLVM vectors.
struct NDVectorTypeInfo {
Expand Down Expand Up @@ -1098,55 +1086,49 @@ void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
fun(position);
}
}
////////////// End Support for Lowering operations on n-D vectors //////////////

// Basic lowering implementation for one-to-one rewriting from Standard Ops to
// LLVM Dialect Ops.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
using Super = OneToOneLLVMOpLowering<SourceOp, TargetOp>;

// Convert the type of the result to an LLVM type, pass operands as is,
// preserve attributes.
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
unsigned numResults = op->getNumResults();
/// Replaces the given operaiton "op" with a new operation of type "targetOp"
/// and given operands.
LogicalResult LLVM::detail::oneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
unsigned numResults = op->getNumResults();

Type packedType;
if (numResults != 0) {
packedType =
this->typeConverter.packFunctionResults(op->getResultTypes());
if (!packedType)
return failure();
}

auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands,
op->getAttrs());

// If the operation produced 0 or 1 result, return them immediately.
if (numResults == 0)
return rewriter.eraseOp(op), success();
if (numResults == 1)
return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)),
success();
Type packedType;
if (numResults != 0) {
packedType = typeConverter.packFunctionResults(op->getResultTypes());
if (!packedType)
return failure();
}

// Otherwise, it had been converted to an operation producing a structure.
// Extract individual results from the structure and return them as list.
SmallVector<Value, 4> results;
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
auto type = this->typeConverter.convertType(op->getResult(i).getType());
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), type, newOp.getOperation()->getResult(0),
rewriter.getI64ArrayAttr(i)));
}
rewriter.replaceOp(op, results);
return success();
// Create the operation through state since we don't know its C++ type.
OperationState state(op->getLoc(), targetOp);
state.addTypes(packedType);
state.addOperands(operands);
state.addAttributes(op->getAttrs());
Operation *newOp = rewriter.createOperation(state);

// If the operation produced 0 or 1 result, return them immediately.
if (numResults == 0)
return rewriter.eraseOp(op), success();
if (numResults == 1)
return rewriter.replaceOp(op, newOp->getResult(0)), success();

// Otherwise, it had been converted to an operation producing a structure.
// Extract individual results from the structure and return them as list.
SmallVector<Value, 4> results;
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
auto type = typeConverter.convertType(op->getResult(i).getType());
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i)));
}
};
rewriter.replaceOp(op, results);
return success();
}

////////////// End Support for Lowering operations on n-D vectors //////////////
namespace {
template <typename SourceOp, unsigned OpCount>
struct OpCountValidator {
static_assert(
Expand All @@ -1166,9 +1148,10 @@ template <typename SourceOp, unsigned OpCount>
void ValidateOpCount() {
OpCountValidator<SourceOp, OpCount>();
}
} // namespace

static LogicalResult HandleMultidimensionalVectors(
Operation *op, ArrayRef<Value> operands, LLVMTypeConverter &typeConverter,
static LogicalResult handleMultidimensionalVectors(
Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
std::function<Value(LLVM::LLVMType, ValueRange)> createOperand,
ConversionPatternRewriter &rewriter) {
auto vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
Expand Down Expand Up @@ -1197,159 +1180,145 @@ static LogicalResult HandleMultidimensionalVectors(
return success();
}

// Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect
// Ops for N-ary ops with one result. This supports higher-dimensional vector
// types.
template <typename SourceOp, typename TargetOp, unsigned OpCount>
struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
using Super = NaryOpLLVMOpLowering<SourceOp, TargetOp, OpCount>;
LogicalResult LLVM::detail::vectorOneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
assert(!operands.empty());

// Convert the type of the result to an LLVM type, pass operands as is,
// preserve attributes.
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
ValidateOpCount<SourceOp, OpCount>();
static_assert(
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
"expected single result op");
static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
SourceOp>::value,
"expected same operands and result type");

// Cannot convert ops if their operands are not of LLVM type.
for (Value operand : operands) {
if (!operand || !operand.getType().isa<LLVM::LLVMType>())
return failure();
}

auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
// Cannot convert ops if their operands are not of LLVM type.
if (!llvm::all_of(operands.getTypes(),
[](Type t) { return t.isa<LLVM::LLVMType>(); }))
return failure();

if (!llvmArrayTy.isArrayTy()) {
auto newOp = rewriter.create<TargetOp>(
op->getLoc(), operands[0].getType(), operands, op->getAttrs());
rewriter.replaceOp(op, newOp.getResult());
return success();
}
auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
if (!llvmArrayTy.isArrayTy())
return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);

if (succeeded(HandleMultidimensionalVectors(
op, operands, this->typeConverter,
[&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
return rewriter.create<TargetOp>(op->getLoc(), llvmVectorTy,
operands, op->getAttrs());
},
rewriter)))
return success();
return failure();
}
};
auto callback = [op, targetOp, &rewriter](LLVM::LLVMType llvmVectorTy,
ValueRange operands) {
OperationState state(op->getLoc(), targetOp);
state.addTypes(llvmVectorTy);
state.addOperands(operands);
state.addAttributes(op->getAttrs());
return rewriter.createOperation(state)->getResult(0);
};

template <typename SourceOp, typename TargetOp>
using UnaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 1>;
template <typename SourceOp, typename TargetOp>
using BinaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 2>;
return handleMultidimensionalVectors(op, operands, typeConverter, callback,
rewriter);
}

namespace {
// Specific lowerings.
// FIXME: this should be tablegen'ed.
struct AbsFOpLowering : public UnaryOpLLVMOpLowering<AbsFOp, LLVM::FAbsOp> {
struct AbsFOpLowering
: public VectorConvertToLLVMPattern<AbsFOp, LLVM::FAbsOp> {
using Super::Super;
};
struct CeilFOpLowering : public UnaryOpLLVMOpLowering<CeilFOp, LLVM::FCeilOp> {
struct CeilFOpLowering
: public VectorConvertToLLVMPattern<CeilFOp, LLVM::FCeilOp> {
using Super::Super;
};
struct CosOpLowering : public UnaryOpLLVMOpLowering<CosOp, LLVM::CosOp> {
struct CosOpLowering : public VectorConvertToLLVMPattern<CosOp, LLVM::CosOp> {
using Super::Super;
};
struct ExpOpLowering : public UnaryOpLLVMOpLowering<ExpOp, LLVM::ExpOp> {
struct ExpOpLowering : public VectorConvertToLLVMPattern<ExpOp, LLVM::ExpOp> {
using Super::Super;
};
struct LogOpLowering : public UnaryOpLLVMOpLowering<LogOp, LLVM::LogOp> {
struct LogOpLowering : public VectorConvertToLLVMPattern<LogOp, LLVM::LogOp> {
using Super::Super;
};
struct Log10OpLowering : public UnaryOpLLVMOpLowering<Log10Op, LLVM::Log10Op> {
struct Log10OpLowering
: public VectorConvertToLLVMPattern<Log10Op, LLVM::Log10Op> {
using Super::Super;
};
struct Log2OpLowering : public UnaryOpLLVMOpLowering<Log2Op, LLVM::Log2Op> {
struct Log2OpLowering
: public VectorConvertToLLVMPattern<Log2Op, LLVM::Log2Op> {
using Super::Super;
};
struct NegFOpLowering : public UnaryOpLLVMOpLowering<NegFOp, LLVM::FNegOp> {
struct NegFOpLowering
: public VectorConvertToLLVMPattern<NegFOp, LLVM::FNegOp> {
using Super::Super;
};
struct AddIOpLowering : public BinaryOpLLVMOpLowering<AddIOp, LLVM::AddOp> {
struct AddIOpLowering : public VectorConvertToLLVMPattern<AddIOp, LLVM::AddOp> {
using Super::Super;
};
struct SubIOpLowering : public BinaryOpLLVMOpLowering<SubIOp, LLVM::SubOp> {
struct SubIOpLowering : public VectorConvertToLLVMPattern<SubIOp, LLVM::SubOp> {
using Super::Super;
};
struct MulIOpLowering : public BinaryOpLLVMOpLowering<MulIOp, LLVM::MulOp> {
struct MulIOpLowering : public VectorConvertToLLVMPattern<MulIOp, LLVM::MulOp> {
using Super::Super;
};
struct SignedDivIOpLowering
: public BinaryOpLLVMOpLowering<SignedDivIOp, LLVM::SDivOp> {
: public VectorConvertToLLVMPattern<SignedDivIOp, LLVM::SDivOp> {
using Super::Super;
};
struct SqrtOpLowering : public UnaryOpLLVMOpLowering<SqrtOp, LLVM::SqrtOp> {
struct SqrtOpLowering
: public VectorConvertToLLVMPattern<SqrtOp, LLVM::SqrtOp> {
using Super::Super;
};
struct UnsignedDivIOpLowering
: public BinaryOpLLVMOpLowering<UnsignedDivIOp, LLVM::UDivOp> {
: public VectorConvertToLLVMPattern<UnsignedDivIOp, LLVM::UDivOp> {
using Super::Super;
};
struct SignedRemIOpLowering
: public BinaryOpLLVMOpLowering<SignedRemIOp, LLVM::SRemOp> {
: public VectorConvertToLLVMPattern<SignedRemIOp, LLVM::SRemOp> {
using Super::Super;
};
struct UnsignedRemIOpLowering
: public BinaryOpLLVMOpLowering<UnsignedRemIOp, LLVM::URemOp> {
: public VectorConvertToLLVMPattern<UnsignedRemIOp, LLVM::URemOp> {
using Super::Super;
};
struct AndOpLowering : public BinaryOpLLVMOpLowering<AndOp, LLVM::AndOp> {
struct AndOpLowering : public VectorConvertToLLVMPattern<AndOp, LLVM::AndOp> {
using Super::Super;
};
struct OrOpLowering : public BinaryOpLLVMOpLowering<OrOp, LLVM::OrOp> {
struct OrOpLowering : public VectorConvertToLLVMPattern<OrOp, LLVM::OrOp> {
using Super::Super;
};
struct XOrOpLowering : public BinaryOpLLVMOpLowering<XOrOp, LLVM::XOrOp> {
struct XOrOpLowering : public VectorConvertToLLVMPattern<XOrOp, LLVM::XOrOp> {
using Super::Super;
};
struct AddFOpLowering : public BinaryOpLLVMOpLowering<AddFOp, LLVM::FAddOp> {
struct AddFOpLowering
: public VectorConvertToLLVMPattern<AddFOp, LLVM::FAddOp> {
using Super::Super;
};
struct SubFOpLowering : public BinaryOpLLVMOpLowering<SubFOp, LLVM::FSubOp> {
struct SubFOpLowering
: public VectorConvertToLLVMPattern<SubFOp, LLVM::FSubOp> {
using Super::Super;
};
struct MulFOpLowering : public BinaryOpLLVMOpLowering<MulFOp, LLVM::FMulOp> {
struct MulFOpLowering
: public VectorConvertToLLVMPattern<MulFOp, LLVM::FMulOp> {
using Super::Super;
};
struct DivFOpLowering : public BinaryOpLLVMOpLowering<DivFOp, LLVM::FDivOp> {
struct DivFOpLowering
: public VectorConvertToLLVMPattern<DivFOp, LLVM::FDivOp> {
using Super::Super;
};
struct RemFOpLowering : public BinaryOpLLVMOpLowering<RemFOp, LLVM::FRemOp> {
struct RemFOpLowering
: public VectorConvertToLLVMPattern<RemFOp, LLVM::FRemOp> {
using Super::Super;
};
struct CopySignOpLowering
: public BinaryOpLLVMOpLowering<CopySignOp, LLVM::CopySignOp> {
: public VectorConvertToLLVMPattern<CopySignOp, LLVM::CopySignOp> {
using Super::Super;
};
struct SelectOpLowering
: public OneToOneLLVMOpLowering<SelectOp, LLVM::SelectOp> {
: public OneToOneConvertToLLVMPattern<SelectOp, LLVM::SelectOp> {
using Super::Super;
};
struct ConstLLVMOpLowering
: public OneToOneLLVMOpLowering<ConstantOp, LLVM::ConstantOp> {
: public OneToOneConvertToLLVMPattern<ConstantOp, LLVM::ConstantOp> {
using Super::Super;
};
struct ShiftLeftOpLowering
: public OneToOneLLVMOpLowering<ShiftLeftOp, LLVM::ShlOp> {
: public OneToOneConvertToLLVMPattern<ShiftLeftOp, LLVM::ShlOp> {
using Super::Super;
};
struct SignedShiftRightOpLowering
: public OneToOneLLVMOpLowering<SignedShiftRightOp, LLVM::AShrOp> {
: public OneToOneConvertToLLVMPattern<SignedShiftRightOp, LLVM::AShrOp> {
using Super::Super;
};
struct UnsignedShiftRightOpLowering
: public OneToOneLLVMOpLowering<UnsignedShiftRightOp, LLVM::LShrOp> {
: public OneToOneConvertToLLVMPattern<UnsignedShiftRightOp, LLVM::LShrOp> {
using Super::Super;
};

Expand All @@ -1373,13 +1342,11 @@ static bool isSupportedMemRefType(MemRefType type) {
// Alignment is obtained by allocating `alignment - 1` more bytes than requested
// and shifting the aligned pointer relative to the allocated memory. If
// alignment is unspecified, the two pointers are equal.
struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
using LLVMLegalizationPattern<AllocOp>::LLVMLegalizationPattern;
struct AllocOpLowering : public ConvertOpToLLVMPattern<AllocOp> {
using ConvertOpToLLVMPattern<AllocOp>::ConvertOpToLLVMPattern;

AllocOpLowering(LLVM::LLVMDialect &dialect_, LLVMTypeConverter &converter,
bool useAlloca = false)
: LLVMLegalizationPattern<AllocOp>(dialect_, converter),
useAlloca(useAlloca) {}
explicit AllocOpLowering(LLVMTypeConverter &converter, bool useAlloca = false)
: ConvertOpToLLVMPattern<AllocOp>(converter), useAlloca(useAlloca) {}

LogicalResult match(Operation *op) const override {
MemRefType type = cast<AllocOp>(op).getType();
Expand Down Expand Up @@ -1569,10 +1536,10 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
// A CallOp automatically promotes MemRefType to a sequence of alloca/store and
// passes the pointer to the MemRef across function boundaries.
template <typename CallOpType>
struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
using LLVMLegalizationPattern<CallOpType>::LLVMLegalizationPattern;
struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
using Super = CallOpInterfaceLowering<CallOpType>;
using Base = LLVMLegalizationPattern<CallOpType>;
using Base = ConvertOpToLLVMPattern<CallOpType>;

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
Expand Down Expand Up @@ -1639,13 +1606,12 @@ struct CallIndirectOpLowering : public CallOpInterfaceLowering<CallIndirectOp> {
// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// The memref descriptor being an SSA value, there is no need to clean it up
// in any way.
struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern;
struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
using ConvertOpToLLVMPattern<DeallocOp>::ConvertOpToLLVMPattern;

DeallocOpLowering(LLVM::LLVMDialect &dialect_, LLVMTypeConverter &converter,
bool useAlloca = false)
: LLVMLegalizationPattern<DeallocOp>(dialect_, converter),
useAlloca(useAlloca) {}
explicit DeallocOpLowering(LLVMTypeConverter &converter,
bool useAlloca = false)
: ConvertOpToLLVMPattern<DeallocOp>(converter), useAlloca(useAlloca) {}

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
Expand Down Expand Up @@ -1680,8 +1646,8 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
};

// A `rsqrt` is converted into `1 / sqrt`.
struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
using LLVMLegalizationPattern<RsqrtOp>::LLVMLegalizationPattern;
struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
Expand Down Expand Up @@ -1716,29 +1682,26 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
if (!vectorType)
return failure();

if (succeeded(HandleMultidimensionalVectors(
op, operands, typeConverter,
[&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
auto splatAttr = SplatElementsAttr::get(
mlir::VectorType::get({llvmVectorTy.getUnderlyingType()
->getVectorNumElements()},
floatType),
floatOne);
auto one = rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy,
splatAttr);
auto sqrt =
rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, operands[0]);
return rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one,
sqrt);
},
rewriter)))
return success();
return failure();
return handleMultidimensionalVectors(
op, operands, typeConverter,
[&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
auto splatAttr = SplatElementsAttr::get(
mlir::VectorType::get(
{llvmVectorTy.getUnderlyingType()->getVectorNumElements()},
floatType),
floatOne);
auto one =
rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
auto sqrt =
rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, operands[0]);
return rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one, sqrt);
},
rewriter);
}
};

struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern;
struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
using ConvertOpToLLVMPattern<MemRefCastOp>::ConvertOpToLLVMPattern;

LogicalResult match(Operation *op) const override {
auto memRefCastOp = cast<MemRefCastOp>(op);
Expand Down Expand Up @@ -1833,8 +1796,8 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
};

struct DialectCastOpLowering
: public LLVMLegalizationPattern<LLVM::DialectCastOp> {
using LLVMLegalizationPattern<LLVM::DialectCastOp>::LLVMLegalizationPattern;
: public ConvertOpToLLVMPattern<LLVM::DialectCastOp> {
using ConvertOpToLLVMPattern<LLVM::DialectCastOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
Expand All @@ -1852,8 +1815,8 @@ struct DialectCastOpLowering

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
using LLVMLegalizationPattern<DimOp>::LLVMLegalizationPattern;
struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
using ConvertOpToLLVMPattern<DimOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
Expand All @@ -1880,8 +1843,8 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
using LLVMLegalizationPattern<Derived>::LLVMLegalizationPattern;
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
using Base = LoadStoreOpLowering<Derived>;

LogicalResult match(Operation *op) const override {
Expand Down Expand Up @@ -2029,8 +1992,8 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
// an integer. If the bit width of the source and target integer types is the
// same, just erase the cast. If the target type is wider, sign-extend the
// value, otherwise truncate it.
struct IndexCastOpLowering : public LLVMLegalizationPattern<IndexCastOp> {
using LLVMLegalizationPattern<IndexCastOp>::LLVMLegalizationPattern;
struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
using ConvertOpToLLVMPattern<IndexCastOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
Expand Down Expand Up @@ -2064,8 +2027,8 @@ static LLVMPredType convertCmpPredicate(StdPredType pred) {
return static_cast<LLVMPredType>(pred);
}

struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> {
using LLVMLegalizationPattern<CmpIOp>::LLVMLegalizationPattern;
struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> {
using ConvertOpToLLVMPattern<CmpIOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
Expand All @@ -2083,8 +2046,8 @@ struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> {
}
};

struct CmpFOpLowering : public LLVMLegalizationPattern<CmpFOp> {
using LLVMLegalizationPattern<CmpFOp>::LLVMLegalizationPattern;
struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
using ConvertOpToLLVMPattern<CmpFOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
Expand All @@ -2103,39 +2066,40 @@ struct CmpFOpLowering : public LLVMLegalizationPattern<CmpFOp> {
};

struct SIToFPLowering
: public OneToOneLLVMOpLowering<SIToFPOp, LLVM::SIToFPOp> {
: public OneToOneConvertToLLVMPattern<SIToFPOp, LLVM::SIToFPOp> {
using Super::Super;
};

struct FPExtLowering : public OneToOneLLVMOpLowering<FPExtOp, LLVM::FPExtOp> {
struct FPExtLowering
: public OneToOneConvertToLLVMPattern<FPExtOp, LLVM::FPExtOp> {
using Super::Super;
};

struct FPTruncLowering
: public OneToOneLLVMOpLowering<FPTruncOp, LLVM::FPTruncOp> {
: public OneToOneConvertToLLVMPattern<FPTruncOp, LLVM::FPTruncOp> {
using Super::Super;
};

struct SignExtendIOpLowering
: public OneToOneLLVMOpLowering<SignExtendIOp, LLVM::SExtOp> {
: public OneToOneConvertToLLVMPattern<SignExtendIOp, LLVM::SExtOp> {
using Super::Super;
};

struct TruncateIOpLowering
: public OneToOneLLVMOpLowering<TruncateIOp, LLVM::TruncOp> {
: public OneToOneConvertToLLVMPattern<TruncateIOp, LLVM::TruncOp> {
using Super::Super;
};

struct ZeroExtendIOpLowering
: public OneToOneLLVMOpLowering<ZeroExtendIOp, LLVM::ZExtOp> {
: public OneToOneConvertToLLVMPattern<ZeroExtendIOp, LLVM::ZExtOp> {
using Super::Super;
};

// Base class for LLVM IR lowering terminator operations with successors.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMTerminatorLowering
: public LLVMLegalizationPattern<SourceOp> {
using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
: public ConvertOpToLLVMPattern<SourceOp> {
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;

LogicalResult
Expand All @@ -2153,8 +2117,8 @@ struct OneToOneLLVMTerminatorLowering
// can only return 0 or 1 value, we pack multiple values into a structure type.
// Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
// necessary before returning it
struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
using LLVMLegalizationPattern<ReturnOp>::LLVMLegalizationPattern;
struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
Expand Down Expand Up @@ -2202,8 +2166,8 @@ struct CondBranchOpLowering

// The Splat operation is lowered to an insertelement + a shufflevector
// operation. Splat to only 1-d vector result types are lowered.
struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> {
using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;
struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
Expand Down Expand Up @@ -2236,8 +2200,8 @@ struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> {
// The Splat operation is lowered to an insertelement + a shufflevector
// operation. Splat to only 2+-d vector result types are lowered by the
// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;
struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
Expand Down Expand Up @@ -2290,8 +2254,8 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
/// and stride.
/// The subview op is replaced by the descriptor.
struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
using LLVMLegalizationPattern<SubViewOp>::LLVMLegalizationPattern;
struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
using ConvertOpToLLVMPattern<SubViewOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
Expand Down Expand Up @@ -2418,8 +2382,8 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
/// and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
using LLVMLegalizationPattern<ViewOp>::LLVMLegalizationPattern;
struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
using ConvertOpToLLVMPattern<ViewOp>::ConvertOpToLLVMPattern;

// Build and return the value for the idx^th shape dimension, either by
// returning the constant shape dimension or counting the proper dynamic size.
Expand Down Expand Up @@ -2535,8 +2499,8 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
};

struct AssumeAlignmentOpLowering
: public LLVMLegalizationPattern<AssumeAlignmentOp> {
using LLVMLegalizationPattern<AssumeAlignmentOp>::LLVMLegalizationPattern;
: public ConvertOpToLLVMPattern<AssumeAlignmentOp> {
using ConvertOpToLLVMPattern<AssumeAlignmentOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
Expand Down Expand Up @@ -2788,7 +2752,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
UnsignedRemIOpLowering,
UnsignedShiftRightOpLowering,
XOrOpLowering,
ZeroExtendIOpLowering>(*converter.getDialect(), converter);
ZeroExtendIOpLowering>(converter);
// clang-format on
}

Expand All @@ -2803,19 +2767,17 @@ void mlir::populateStdToLLVMMemoryConversionPatters(
MemRefCastOpLowering,
StoreOpLowering,
SubViewOpLowering,
ViewOpLowering>(*converter.getDialect(), converter);
ViewOpLowering>(converter);
patterns.insert<
AllocOpLowering,
DeallocOpLowering>(
*converter.getDialect(), converter, useAlloca);
DeallocOpLowering>(converter, useAlloca);
// clang-format on
}

void mlir::populateStdToLLVMDefaultFuncOpConversionPattern(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
bool emitCWrappers) {
patterns.insert<FuncOpConversion>(*converter.getDialect(), converter,
emitCWrappers);
patterns.insert<FuncOpConversion>(converter, emitCWrappers);
}

void mlir::populateStdToLLVMConversionPatterns(
Expand All @@ -2829,7 +2791,7 @@ void mlir::populateStdToLLVMConversionPatterns(

static void populateStdToLLVMBarePtrFuncOpConversionPattern(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
patterns.insert<BarePtrFuncOpConversion>(*converter.getDialect(), converter);
patterns.insert<BarePtrFuncOpConversion>(converter);
}

void mlir::populateStdToLLVMBarePtrConversionPatterns(
Expand Down