97 changes: 51 additions & 46 deletions mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp

Large diffs are not rendered by default.

48 changes: 20 additions & 28 deletions mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
Expand Up @@ -58,12 +58,11 @@ struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands,
matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
complex::AbsOp::Adaptor transformed(operands);
auto loc = op.getLoc();

ComplexStructBuilder complexStruct(transformed.complex());
ComplexStructBuilder complexStruct(adaptor.complex());
Value real = complexStruct.real(rewriter, op.getLoc());
Value imag = complexStruct.imaginary(rewriter, op.getLoc());

Expand All @@ -81,16 +80,14 @@ struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(complex::CreateOp complexOp, ArrayRef<Value> operands,
matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
complex::CreateOp::Adaptor transformed(operands);

// Pack real and imaginary part in a complex number struct.
auto loc = complexOp.getLoc();
auto structType = typeConverter->convertType(complexOp.getType());
auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
complexStruct.setReal(rewriter, loc, transformed.real());
complexStruct.setImaginary(rewriter, loc, transformed.imaginary());
complexStruct.setReal(rewriter, loc, adaptor.real());
complexStruct.setImaginary(rewriter, loc, adaptor.imaginary());

rewriter.replaceOp(complexOp, {complexStruct});
return success();
Expand All @@ -101,12 +98,10 @@ struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(complex::ReOp op, ArrayRef<Value> operands,
matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
complex::ReOp::Adaptor transformed(operands);

// Extract real part from the complex number struct.
ComplexStructBuilder complexStruct(transformed.complex());
ComplexStructBuilder complexStruct(adaptor.complex());
Value real = complexStruct.real(rewriter, op.getLoc());
rewriter.replaceOp(op, real);

Expand All @@ -118,12 +113,10 @@ struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(complex::ImOp op, ArrayRef<Value> operands,
matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
complex::ImOp::Adaptor transformed(operands);

// Extract imaginary part from the complex number struct.
ComplexStructBuilder complexStruct(transformed.complex());
ComplexStructBuilder complexStruct(adaptor.complex());
Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
rewriter.replaceOp(op, imaginary);

Expand All @@ -138,17 +131,16 @@ struct BinaryComplexOperands {

template <typename OpTy>
BinaryComplexOperands
unpackBinaryComplexOperands(OpTy op, ArrayRef<Value> operands,
unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) {
auto loc = op.getLoc();
typename OpTy::Adaptor transformed(operands);

// Extract real and imaginary values from operands.
BinaryComplexOperands unpacked;
ComplexStructBuilder lhs(transformed.lhs());
ComplexStructBuilder lhs(adaptor.lhs());
unpacked.lhs.real(lhs.real(rewriter, loc));
unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
ComplexStructBuilder rhs(transformed.rhs());
ComplexStructBuilder rhs(adaptor.rhs());
unpacked.rhs.real(rhs.real(rewriter, loc));
unpacked.rhs.imag(rhs.imaginary(rewriter, loc));

Expand All @@ -159,11 +151,11 @@ struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(complex::AddOp op, ArrayRef<Value> operands,
matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
BinaryComplexOperands arg =
unpackBinaryComplexOperands<complex::AddOp>(op, operands, rewriter);
unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);

// Initialize complex number struct for result.
auto structType = typeConverter->convertType(op.getType());
Expand All @@ -187,11 +179,11 @@ struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(complex::DivOp op, ArrayRef<Value> operands,
matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
BinaryComplexOperands arg =
unpackBinaryComplexOperands<complex::DivOp>(op, operands, rewriter);
unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);

// Initialize complex number struct for result.
auto structType = typeConverter->convertType(op.getType());
Expand Down Expand Up @@ -232,11 +224,11 @@ struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(complex::MulOp op, ArrayRef<Value> operands,
matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
BinaryComplexOperands arg =
unpackBinaryComplexOperands<complex::MulOp>(op, operands, rewriter);
unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);

// Initialize complex number struct for result.
auto structType = typeConverter->convertType(op.getType());
Expand Down Expand Up @@ -269,11 +261,11 @@ struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(complex::SubOp op, ArrayRef<Value> operands,
matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
BinaryComplexOperands arg =
unpackBinaryComplexOperands<complex::SubOp>(op, operands, rewriter);
unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);

// Initialize complex number struct for result.
auto structType = typeConverter->convertType(op.getType());
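The change repeated across all of these conversion files is mechanical: `matchAndRewrite` no longer receives the converted operands as an untyped `ArrayRef<Value>` and no longer wraps them in a locally constructed `Adaptor`; it receives the generated, typed `OpAdaptor` directly. A minimal sketch of the new shape of a pattern, using a hypothetical `foo::SplitOp` (the op, its `lhs` operand name, and the trivial lowering are illustrative assumptions, not code from this diff):

#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Illustrative only: forwards the (already type-converted) first operand of
// a hypothetical foo::SplitOp as the op's replacement value.
struct SplitOpConversion : public OpConversionPattern<foo::SplitOp> {
  using OpConversionPattern<foo::SplitOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(foo::SplitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // adaptor.lhs() is the ODS-generated accessor; the old
    // `foo::SplitOp::Adaptor transformed(operands);` line is gone.
    rewriter.replaceOp(op, adaptor.lhs());
    return success();
  }
};

Patterns that are templated over the source op spell the same parameter as `typename SourceOp::Adaptor adaptor`, as the comparison and binary-op templates in the next file do.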
123 changes: 52 additions & 71 deletions mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
Expand Up @@ -26,16 +26,13 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
using OpConversionPattern<complex::AbsOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands,
matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
complex::AbsOp::Adaptor transformed(operands);
auto loc = op.getLoc();
auto type = op.getType();

Value real =
rewriter.create<complex::ReOp>(loc, type, transformed.complex());
Value imag =
rewriter.create<complex::ImOp>(loc, type, transformed.complex());
Value real = rewriter.create<complex::ReOp>(loc, type, adaptor.complex());
Value imag = rewriter.create<complex::ImOp>(loc, type, adaptor.complex());
Value realSqr = rewriter.create<MulFOp>(loc, real, real);
Value imagSqr = rewriter.create<MulFOp>(loc, imag, imag);
Value sqNorm = rewriter.create<AddFOp>(loc, realSqr, imagSqr);
Expand All @@ -53,23 +50,16 @@ struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
AndOp, OrOp>;

LogicalResult
matchAndRewrite(ComparisonOp op, ArrayRef<Value> operands,
matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
typename ComparisonOp::Adaptor transformed(operands);
auto loc = op.getLoc();
auto type = transformed.lhs()
.getType()
.template cast<ComplexType>()
.getElementType();

Value realLhs =
rewriter.create<complex::ReOp>(loc, type, transformed.lhs());
Value imagLhs =
rewriter.create<complex::ImOp>(loc, type, transformed.lhs());
Value realRhs =
rewriter.create<complex::ReOp>(loc, type, transformed.rhs());
Value imagRhs =
rewriter.create<complex::ImOp>(loc, type, transformed.rhs());
auto type =
adaptor.lhs().getType().template cast<ComplexType>().getElementType();

Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.lhs());
Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.lhs());
Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.rhs());
Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.rhs());
Value realComparison = rewriter.create<CmpFOp>(loc, p, realLhs, realRhs);
Value imagComparison = rewriter.create<CmpFOp>(loc, p, imagLhs, imagRhs);

Expand All @@ -87,19 +77,18 @@ struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(BinaryComplexOp op, ArrayRef<Value> operands,
matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
typename BinaryComplexOp::Adaptor transformed(operands);
auto type = transformed.lhs().getType().template cast<ComplexType>();
auto type = adaptor.lhs().getType().template cast<ComplexType>();
auto elementType = type.getElementType().template cast<FloatType>();
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

Value realLhs = b.create<complex::ReOp>(elementType, transformed.lhs());
Value realRhs = b.create<complex::ReOp>(elementType, transformed.rhs());
Value realLhs = b.create<complex::ReOp>(elementType, adaptor.lhs());
Value realRhs = b.create<complex::ReOp>(elementType, adaptor.rhs());
Value resultReal =
b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
Value imagLhs = b.create<complex::ImOp>(elementType, transformed.lhs());
Value imagRhs = b.create<complex::ImOp>(elementType, transformed.rhs());
Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.lhs());
Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.rhs());
Value resultImag =
b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
Expand All @@ -112,21 +101,20 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
using OpConversionPattern<complex::DivOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(complex::DivOp op, ArrayRef<Value> operands,
matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
complex::DivOp::Adaptor transformed(operands);
auto loc = op.getLoc();
auto type = transformed.lhs().getType().cast<ComplexType>();
auto type = adaptor.lhs().getType().cast<ComplexType>();
auto elementType = type.getElementType().cast<FloatType>();

Value lhsReal =
rewriter.create<complex::ReOp>(loc, elementType, transformed.lhs());
rewriter.create<complex::ReOp>(loc, elementType, adaptor.lhs());
Value lhsImag =
rewriter.create<complex::ImOp>(loc, elementType, transformed.lhs());
rewriter.create<complex::ImOp>(loc, elementType, adaptor.lhs());
Value rhsReal =
rewriter.create<complex::ReOp>(loc, elementType, transformed.rhs());
rewriter.create<complex::ReOp>(loc, elementType, adaptor.rhs());
Value rhsImag =
rewriter.create<complex::ImOp>(loc, elementType, transformed.rhs());
rewriter.create<complex::ImOp>(loc, elementType, adaptor.rhs());

// Smith's algorithm to divide complex numbers. It is just a bit smarter
// way to compute the following formula:
Expand Down Expand Up @@ -321,17 +309,16 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
using OpConversionPattern<complex::ExpOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(complex::ExpOp op, ArrayRef<Value> operands,
matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
complex::ExpOp::Adaptor transformed(operands);
auto loc = op.getLoc();
auto type = transformed.complex().getType().cast<ComplexType>();
auto type = adaptor.complex().getType().cast<ComplexType>();
auto elementType = type.getElementType().cast<FloatType>();

Value real =
rewriter.create<complex::ReOp>(loc, elementType, transformed.complex());
rewriter.create<complex::ReOp>(loc, elementType, adaptor.complex());
Value imag =
rewriter.create<complex::ImOp>(loc, elementType, transformed.complex());
rewriter.create<complex::ImOp>(loc, elementType, adaptor.complex());
Value expReal = rewriter.create<math::ExpOp>(loc, real);
Value cosImag = rewriter.create<math::CosOp>(loc, imag);
Value resultReal = rewriter.create<MulFOp>(loc, expReal, cosImag);
Expand All @@ -348,17 +335,16 @@ struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
using OpConversionPattern<complex::LogOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(complex::LogOp op, ArrayRef<Value> operands,
matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
complex::LogOp::Adaptor transformed(operands);
auto type = transformed.complex().getType().cast<ComplexType>();
auto type = adaptor.complex().getType().cast<ComplexType>();
auto elementType = type.getElementType().cast<FloatType>();
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

Value abs = b.create<complex::AbsOp>(elementType, transformed.complex());
Value abs = b.create<complex::AbsOp>(elementType, adaptor.complex());
Value resultReal = b.create<math::LogOp>(elementType, abs);
Value real = b.create<complex::ReOp>(elementType, transformed.complex());
Value imag = b.create<complex::ImOp>(elementType, transformed.complex());
Value real = b.create<complex::ReOp>(elementType, adaptor.complex());
Value imag = b.create<complex::ImOp>(elementType, adaptor.complex());
Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
Expand All @@ -370,15 +356,14 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(complex::Log1pOp op, ArrayRef<Value> operands,
matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
complex::Log1pOp::Adaptor transformed(operands);
auto type = transformed.complex().getType().cast<ComplexType>();
auto type = adaptor.complex().getType().cast<ComplexType>();
auto elementType = type.getElementType().cast<FloatType>();
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

Value real = b.create<complex::ReOp>(elementType, transformed.complex());
Value imag = b.create<complex::ImOp>(elementType, transformed.complex());
Value real = b.create<complex::ReOp>(elementType, adaptor.complex());
Value imag = b.create<complex::ImOp>(elementType, adaptor.complex());
Value one =
b.create<ConstantOp>(elementType, b.getFloatAttr(elementType, 1));
Value realPlusOne = b.create<AddFOp>(real, one);
Expand All @@ -392,20 +377,19 @@ struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
using OpConversionPattern<complex::MulOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(complex::MulOp op, ArrayRef<Value> operands,
matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
complex::MulOp::Adaptor transformed(operands);
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
auto type = transformed.lhs().getType().cast<ComplexType>();
auto type = adaptor.lhs().getType().cast<ComplexType>();
auto elementType = type.getElementType().cast<FloatType>();

Value lhsReal = b.create<complex::ReOp>(elementType, transformed.lhs());
Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.lhs());
Value lhsRealAbs = b.create<AbsFOp>(lhsReal);
Value lhsImag = b.create<complex::ImOp>(elementType, transformed.lhs());
Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.lhs());
Value lhsImagAbs = b.create<AbsFOp>(lhsImag);
Value rhsReal = b.create<complex::ReOp>(elementType, transformed.rhs());
Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.rhs());
Value rhsRealAbs = b.create<AbsFOp>(rhsReal);
Value rhsImag = b.create<complex::ImOp>(elementType, transformed.rhs());
Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.rhs());
Value rhsImagAbs = b.create<AbsFOp>(rhsImag);

Value lhsRealTimesRhsReal = b.create<MulFOp>(lhsReal, rhsReal);
Expand Down Expand Up @@ -530,17 +514,16 @@ struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
using OpConversionPattern<complex::NegOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(complex::NegOp op, ArrayRef<Value> operands,
matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
complex::NegOp::Adaptor transformed(operands);
auto loc = op.getLoc();
auto type = transformed.complex().getType().cast<ComplexType>();
auto type = adaptor.complex().getType().cast<ComplexType>();
auto elementType = type.getElementType().cast<FloatType>();

Value real =
rewriter.create<complex::ReOp>(loc, elementType, transformed.complex());
rewriter.create<complex::ReOp>(loc, elementType, adaptor.complex());
Value imag =
rewriter.create<complex::ImOp>(loc, elementType, transformed.complex());
rewriter.create<complex::ImOp>(loc, elementType, adaptor.complex());
Value negReal = rewriter.create<NegFOp>(loc, real);
Value negImag = rewriter.create<NegFOp>(loc, imag);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
Expand All @@ -552,25 +535,23 @@ struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
using OpConversionPattern<complex::SignOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(complex::SignOp op, ArrayRef<Value> operands,
matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
complex::SignOp::Adaptor transformed(operands);
auto type = transformed.complex().getType().cast<ComplexType>();
auto type = adaptor.complex().getType().cast<ComplexType>();
auto elementType = type.getElementType().cast<FloatType>();
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

Value real = b.create<complex::ReOp>(elementType, transformed.complex());
Value imag = b.create<complex::ImOp>(elementType, transformed.complex());
Value real = b.create<complex::ReOp>(elementType, adaptor.complex());
Value imag = b.create<complex::ImOp>(elementType, adaptor.complex());
Value zero = b.create<ConstantOp>(elementType, b.getZeroAttr(elementType));
Value realIsZero = b.create<CmpFOp>(CmpFPredicate::OEQ, real, zero);
Value imagIsZero = b.create<CmpFOp>(CmpFPredicate::OEQ, imag, zero);
Value isZero = b.create<AndOp>(realIsZero, imagIsZero);
auto abs = b.create<complex::AbsOp>(elementType, transformed.complex());
auto abs = b.create<complex::AbsOp>(elementType, adaptor.complex());
Value realSign = b.create<DivFOp>(real, abs);
Value imagSign = b.create<DivFOp>(imag, abs);
Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
rewriter.replaceOpWithNewOp<SelectOp>(op, isZero, transformed.complex(),
sign);
rewriter.replaceOpWithNewOp<SelectOp>(op, isZero, adaptor.complex(), sign);
return success();
}
};
4 changes: 1 addition & 3 deletions mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
Expand Up @@ -14,10 +14,8 @@
using namespace mlir;

LogicalResult
GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
ArrayRef<Value> operands,
GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
assert(operands.empty() && "func op is not expected to have operands");
Location loc = gpuFuncOp.getLoc();

SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
6 changes: 3 additions & 3 deletions mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
Expand Up @@ -22,7 +22,7 @@ struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
kernelAttributeName(kernelAttributeName) {}

LogicalResult
matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, ArrayRef<Value> operands,
matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;

private:
Expand All @@ -37,9 +37,9 @@ struct GPUReturnOpLowering : public ConvertOpToLLVMPattern<gpu::ReturnOp> {
using ConvertOpToLLVMPattern<gpu::ReturnOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(gpu::ReturnOp op, ArrayRef<Value> operands,
matchAndRewrite(gpu::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands());
return success();
}
};
79 changes: 36 additions & 43 deletions mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
Expand Up @@ -195,7 +195,7 @@ class ConvertHostRegisterOpToGpuRuntimeCallPattern

private:
LogicalResult
matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, ArrayRef<Value> operands,
matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

Expand All @@ -209,7 +209,7 @@ class ConvertAllocOpToGpuRuntimeCallPattern

private:
LogicalResult
matchAndRewrite(gpu::AllocOp allocOp, ArrayRef<Value> operands,
matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

Expand All @@ -223,7 +223,7 @@ class ConvertDeallocOpToGpuRuntimeCallPattern

private:
LogicalResult
matchAndRewrite(gpu::DeallocOp deallocOp, ArrayRef<Value> operands,
matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

Expand All @@ -235,7 +235,7 @@ class ConvertAsyncYieldToGpuRuntimeCallPattern

private:
LogicalResult
matchAndRewrite(async::YieldOp yieldOp, ArrayRef<Value> operands,
matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

Expand All @@ -249,7 +249,7 @@ class ConvertWaitOpToGpuRuntimeCallPattern

private:
LogicalResult
matchAndRewrite(gpu::WaitOp waitOp, ArrayRef<Value> operands,
matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

Expand All @@ -263,7 +263,7 @@ class ConvertWaitAsyncOpToGpuRuntimeCallPattern

private:
LogicalResult
matchAndRewrite(gpu::WaitOp waitOp, ArrayRef<Value> operands,
matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

Expand All @@ -289,13 +289,13 @@ class ConvertLaunchFuncOpToGpuRuntimeCallPattern
gpuBinaryAnnotation(gpuBinaryAnnotation) {}

private:
Value generateParamsArray(gpu::LaunchFuncOp launchOp,
ArrayRef<Value> operands, OpBuilder &builder) const;
Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
OpBuilder &builder) const;
Value generateKernelNameConstant(StringRef moduleName, StringRef name,
Location loc, OpBuilder &builder) const;

LogicalResult
matchAndRewrite(gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;

llvm::SmallString<32> gpuBinaryAnnotation;
Expand Down Expand Up @@ -323,7 +323,7 @@ class ConvertMemcpyOpToGpuRuntimeCallPattern

private:
LogicalResult
matchAndRewrite(gpu::MemcpyOp memcpyOp, ArrayRef<Value> operands,
matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

Expand All @@ -337,7 +337,7 @@ class ConvertMemsetOpToGpuRuntimeCallPattern

private:
LogicalResult
matchAndRewrite(gpu::MemsetOp memsetOp, ArrayRef<Value> operands,
matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
Expand Down Expand Up @@ -398,10 +398,10 @@ isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
}

LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::HostRegisterOp hostRegisterOp, ArrayRef<Value> operands,
gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto *op = hostRegisterOp.getOperation();
if (failed(areAllLLVMTypes(op, operands, rewriter)))
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
return failure();

Location loc = op->getLoc();
Expand All @@ -410,8 +410,8 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
auto elementSize = getSizeInBytes(loc, elementType, rewriter);

auto arguments = getTypeConverter()->promoteOperands(loc, op->getOperands(),
operands, rewriter);
auto arguments = getTypeConverter()->promoteOperands(
loc, op->getOperands(), adaptor.getOperands(), rewriter);
arguments.push_back(elementSize);
hostRegisterCallBuilder.create(loc, rewriter, arguments);

Expand All @@ -420,17 +420,16 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
}

LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::AllocOp allocOp, ArrayRef<Value> operands,
gpu::AllocOp allocOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
MemRefType memRefType = allocOp.getType();

if (failed(areAllLLVMTypes(allocOp, operands, rewriter)) ||
if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
!isConvertibleAndHasIdentityMaps(memRefType) ||
failed(isAsyncWithOneDependency(rewriter, allocOp)))
return failure();

auto loc = allocOp.getLoc();
auto adaptor = gpu::AllocOpAdaptor(operands, allocOp->getAttrDictionary());

// Get shape of the memref as values: static sizes are constant
// values and dynamic sizes are passed to 'alloc' as operands.
Expand Down Expand Up @@ -462,16 +461,14 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
}

LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::DeallocOp deallocOp, ArrayRef<Value> operands,
gpu::DeallocOp deallocOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(deallocOp, operands, rewriter)) ||
if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
failed(isAsyncWithOneDependency(rewriter, deallocOp)))
return failure();

Location loc = deallocOp.getLoc();

auto adaptor =
gpu::DeallocOpAdaptor(operands, deallocOp->getAttrDictionary());
Value pointer =
MemRefDescriptor(adaptor.memref()).allocatedPtr(rewriter, loc);
auto casted = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pointer);
Expand All @@ -491,19 +488,19 @@ static bool isGpuAsyncTokenType(Value value) {
// are passed as events between them. For each !gpu.async.token operand, we
// create an event and record it on the stream.
LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
async::YieldOp yieldOp, ArrayRef<Value> operands,
async::YieldOp yieldOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (llvm::none_of(yieldOp.operands(), isGpuAsyncTokenType))
return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");

Location loc = yieldOp.getLoc();
SmallVector<Value, 4> newOperands(operands.begin(), operands.end());
SmallVector<Value, 4> newOperands(adaptor.getOperands());
llvm::SmallDenseSet<Value> streams;
for (auto &operand : yieldOp->getOpOperands()) {
if (!isGpuAsyncTokenType(operand.get()))
continue;
auto idx = operand.getOperandNumber();
auto stream = operands[idx];
auto stream = adaptor.getOperands()[idx];
auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
eventRecordCallBuilder.create(loc, rewriter, {event, stream});
newOperands[idx] = event;
Expand All @@ -530,14 +527,14 @@ static bool isDefinedByCallTo(Value value, StringRef functionName) {
// assumes that it is not used afterwards or elsewhere. Otherwise we will get a
// runtime error. Eventually, we should guarantee this property.
LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::WaitOp waitOp, ArrayRef<Value> operands,
gpu::WaitOp waitOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (waitOp.asyncToken())
return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");

Location loc = waitOp.getLoc();

for (auto operand : operands) {
for (auto operand : adaptor.getOperands()) {
if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
// The converted operand's definition created a stream.
streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
Expand All @@ -560,7 +557,7 @@ LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
// Otherwise we will get a runtime error. Eventually, we should guarantee this
// property.
LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::WaitOp waitOp, ArrayRef<Value> operands,
gpu::WaitOp waitOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (!waitOp.asyncToken())
return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");
Expand All @@ -569,7 +566,8 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(

auto insertionPoint = rewriter.saveInsertionPoint();
SmallVector<Value, 1> events;
for (auto pair : llvm::zip(waitOp.asyncDependencies(), operands)) {
for (auto pair :
llvm::zip(waitOp.asyncDependencies(), adaptor.getOperands())) {
auto operand = std::get<1>(pair);
if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
// The converted operand's definition created a stream. Insert an event
Expand Down Expand Up @@ -611,13 +609,12 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
// llvm.store %fieldPtr, %elementPtr
// return %array
Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
OpBuilder &builder) const {
gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder &builder) const {
auto loc = launchOp.getLoc();
auto numKernelOperands = launchOp.getNumKernelOperands();
auto arguments = getTypeConverter()->promoteOperands(
loc, launchOp.getOperands().take_back(numKernelOperands),
operands.take_back(numKernelOperands), builder);
adaptor.getOperands().take_back(numKernelOperands), builder);
auto numArguments = arguments.size();
SmallVector<Type, 4> argumentTypes;
argumentTypes.reserve(numArguments);
Expand Down Expand Up @@ -693,9 +690,9 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
// If the op is async, the stream corresponds to the (single) async dependency
// as well as the async token the op produces.
LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(areAllLLVMTypes(launchOp, operands, rewriter)))
if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
return failure();

if (launchOp.asyncDependencies().size() > 1)
Expand Down Expand Up @@ -741,14 +738,12 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
loc, rewriter, {module.getResult(0), kernelName});
auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
rewriter.getI32IntegerAttr(0));
auto adaptor =
gpu::LaunchFuncOpAdaptor(operands, launchOp->getAttrDictionary());
Value stream =
adaptor.asyncDependencies().empty()
? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
: adaptor.asyncDependencies().front();
// Create array of pointers to kernel arguments.
auto kernelParams = generateParamsArray(launchOp, operands, rewriter);
auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter);
auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
launchKernelCallBuilder.create(loc, rewriter,
{function.getResult(0), adaptor.gridSizeX(),
Expand All @@ -775,17 +770,16 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
}

LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::MemcpyOp memcpyOp, ArrayRef<Value> operands,
gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto memRefType = memcpyOp.src().getType().cast<MemRefType>();

if (failed(areAllLLVMTypes(memcpyOp, operands, rewriter)) ||
if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
!isConvertibleAndHasIdentityMaps(memRefType) ||
failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
return failure();

auto loc = memcpyOp.getLoc();
auto adaptor = gpu::MemcpyOpAdaptor(operands, memcpyOp->getAttrDictionary());

MemRefDescriptor srcDesc(adaptor.src());
Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);
Expand All @@ -812,17 +806,16 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
}

LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::MemsetOp memsetOp, ArrayRef<Value> operands,
gpu::MemsetOp memsetOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto memRefType = memsetOp.dst().getType().cast<MemRefType>();

if (failed(areAllLLVMTypes(memsetOp, operands, rewriter)) ||
if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
!isConvertibleAndHasIdentityMaps(memRefType) ||
failed(isAsyncWithOneDependency(rewriter, memsetOp)))
return failure();

auto loc = memsetOp.getLoc();
auto adaptor = gpu::MemsetOpAdaptor(operands, memsetOp->getAttrDictionary());

Type valueType = adaptor.value().getType();
if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 32) {
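Two further consequences of the new signature recur through the GPU runtime-call patterns above: hand-built adaptors such as `gpu::MemsetOpAdaptor(operands, memsetOp->getAttrDictionary())` are simply deleted, because the incoming `adaptor` already carries the converted operands and the op's attributes, and code that consumed the raw operand list now calls `adaptor.getOperands()`. A self-contained sketch of the latter idiom, again using a hypothetical `foo::YieldOp` in place of the real ops and assuming the usual conversion headers (the same shape as `GPUReturnOpLowering` above and the linalg yield lowering further down):

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

using namespace mlir;

// Illustrative only: replaces a hypothetical terminator with llvm.return,
// forwarding every type-converted operand unchanged.
struct YieldOpLowering : public ConvertOpToLLVMPattern<foo::YieldOp> {
  using ConvertOpToLLVMPattern<foo::YieldOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(foo::YieldOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // adaptor.getOperands() replaces the old ArrayRef<Value> parameter.
    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};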
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ struct GPUIndexIntrinsicOpLowering : public ConvertOpToLLVMPattern<Op> {

// Convert the kernel arguments to an LLVM type, preserve the rest.
LogicalResult
matchAndRewrite(Op op, ArrayRef<Value> operands,
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
MLIRContext *context = rewriter.getContext();
16 changes: 7 additions & 9 deletions mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
Expand Up @@ -37,7 +37,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
f64Func(f64Func) {}

LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
using LLVM::LLVMFuncOp;

Expand All @@ -50,7 +50,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
"expected op with same operand and result types");

SmallVector<Value, 1> castedOperands;
for (Value operand : operands)
for (Value operand : adaptor.getOperands())
castedOperands.push_back(maybeCast(operand, rewriter));

Type resultType = castedOperands.front().getType();
Expand All @@ -64,13 +64,14 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
auto callOp = rewriter.create<LLVM::CallOp>(
op->getLoc(), resultType, SymbolRefAttr::get(funcOp), castedOperands);

if (resultType == operands.front().getType()) {
if (resultType == adaptor.getOperands().front().getType()) {
rewriter.replaceOp(op, {callOp.getResult(0)});
return success();
}

Value truncated = rewriter.create<LLVM::FPTruncOp>(
op->getLoc(), operands.front().getType(), callOp.getResult(0));
op->getLoc(), adaptor.getOperands().front().getType(),
callOp.getResult(0));
rewriter.replaceOp(op, {truncated});
return success();
}
Expand All @@ -85,11 +86,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
}

Type getFunctionType(Type resultType, ArrayRef<Value> operands) const {
SmallVector<Type, 1> operandTypes;
for (Value operand : operands) {
operandTypes.push_back(operand.getType());
}
Type getFunctionType(Type resultType, ValueRange operands) const {
SmallVector<Type> operandTypes(operands.getTypes());
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
}

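The `getFunctionType` helper above also switches its parameter from `ArrayRef<Value>` to `ValueRange`, which lets the operand types be collected in one step via `getTypes()` instead of a manual loop. Isolated as a tiny standalone helper (my own sketch, not code from the diff, assuming the usual LLVM-dialect headers):

#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/OperationSupport.h"

using namespace mlir;

// Builds the LLVM function type for a call whose arguments are `operands`:
// ValueRange::getTypes() yields the operand types directly.
static Type buildLLVMFunctionType(Type resultType, ValueRange operands) {
  SmallVector<Type> operandTypes(operands.getTypes());
  return LLVM::LLVMFunctionType::get(resultType, operandTypes);
}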
3 changes: 1 addition & 2 deletions mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
Expand Up @@ -57,10 +57,9 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
/// %shfl_pred = llvm.extractvalue %shfl[1 : index] :
/// !llvm<"{ float, i1 }">
LogicalResult
matchAndRewrite(gpu::ShuffleOp op, ArrayRef<Value> operands,
matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
gpu::ShuffleOpAdaptor adaptor(operands);

auto valueTy = adaptor.value().getType();
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
33 changes: 15 additions & 18 deletions mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
Expand Up @@ -69,10 +69,10 @@ struct WmmaLoadOpToNVVMLowering

LogicalResult
matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
ArrayRef<Value> operands,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Operation *op = subgroupMmaLoadMatrixOp.getOperation();
if (failed(areAllLLVMTypes(op, operands, rewriter)))
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
return failure();

unsigned indexTypeBitwidth =
Expand All @@ -88,7 +88,6 @@ struct WmmaLoadOpToNVVMLowering

auto leadDimension = subgroupMmaLoadMatrixOp.leadDimensionAttr();

gpu::SubgroupMmaLoadMatrixOpAdaptor adaptor(operands);
// MemRefDescriptor to extract alignedPtr and offset.
MemRefDescriptor promotedSrcOp(adaptor.srcMemref());

Expand Down Expand Up @@ -177,10 +176,10 @@ struct WmmaStoreOpToNVVMLowering

LogicalResult
matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
ArrayRef<Value> operands,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Operation *op = subgroupMmaStoreMatrixOp.getOperation();
if (failed(areAllLLVMTypes(op, operands, rewriter)))
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
return failure();

unsigned indexTypeBitwidth =
Expand All @@ -194,7 +193,6 @@ struct WmmaStoreOpToNVVMLowering

Location loc = op->getLoc();

gpu::SubgroupMmaStoreMatrixOpAdaptor adaptor(operands);
// MemRefDescriptor to extract alignedPtr and offset.
MemRefDescriptor promotedDstOp(adaptor.dstMemref());

Expand Down Expand Up @@ -282,10 +280,10 @@ struct WmmaMmaOpToNVVMLowering

LogicalResult
matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
ArrayRef<Value> operands,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Operation *op = subgroupMmaComputeOp.getOperation();
if (failed(areAllLLVMTypes(op, operands, rewriter)))
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
return failure();

Location loc = op->getLoc();
Expand Down Expand Up @@ -317,17 +315,16 @@ struct WmmaMmaOpToNVVMLowering
subgroupMmaComputeOp.opC().getType().cast<gpu::MMAMatrixType>();
ArrayRef<int64_t> cTypeShape = cType.getShape();

gpu::SubgroupMmaComputeOpAdaptor transformedOperands(operands);
unpackOp(transformedOperands.opA());
unpackOp(transformedOperands.opB());
unpackOp(transformedOperands.opC());
unpackOp(adaptor.opA());
unpackOp(adaptor.opB());
unpackOp(adaptor.opC());

if (cType.getElementType().isF16()) {
if (aTypeShape[0] == 16 && aTypeShape[1] == 16 && bTypeShape[0] == 16 &&
bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) {
// Create nvvm.wmma.mma op.
rewriter.replaceOpWithNewOp<NVVM::WMMAMmaF16F16M16N16K16Op>(
op, transformedOperands.opC().getType(), unpackedOps);
op, adaptor.opC().getType(), unpackedOps);

return success();
}
Expand All @@ -338,7 +335,7 @@ struct WmmaMmaOpToNVVMLowering
bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) {
// Create nvvm.wmma.mma op.
rewriter.replaceOpWithNewOp<NVVM::WMMAMmaF32F32M16N16K16Op>(
op, transformedOperands.opC().getType(), unpackedOps);
op, adaptor.opC().getType(), unpackedOps);

return success();
}
Expand All @@ -356,13 +353,13 @@ struct WmmaConstantOpToNVVMLowering

LogicalResult
matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp,
ArrayRef<Value> operands,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(), operands,
rewriter)))
if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(),
adaptor.getOperands(), rewriter)))
return failure();
Location loc = subgroupMmaConstantOp.getLoc();
Value cst = operands[0];
Value cst = adaptor.getOperands()[0];
LLVM::LLVMStructType type = convertMMAToLLVMType(
subgroupMmaConstantOp.getType().cast<gpu::MMAMatrixType>());
// If the element type is a vector create a vector from the operand.
28 changes: 14 additions & 14 deletions mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
Expand Up @@ -33,7 +33,7 @@ class LaunchConfigConversion : public OpConversionPattern<SourceOp> {
using OpConversionPattern<SourceOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

Expand All @@ -45,7 +45,7 @@ class SingleDimLaunchConfigConversion : public OpConversionPattern<SourceOp> {
using OpConversionPattern<SourceOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

Expand All @@ -58,7 +58,7 @@ class WorkGroupSizeConversion : public OpConversionPattern<gpu::BlockDimOp> {
using OpConversionPattern<gpu::BlockDimOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::BlockDimOp op, ArrayRef<Value> operands,
matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

Expand All @@ -68,7 +68,7 @@ class GPUFuncOpConversion final : public OpConversionPattern<gpu::GPUFuncOp> {
using OpConversionPattern<gpu::GPUFuncOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::GPUFuncOp funcOp, ArrayRef<Value> operands,
matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;

private:
Expand All @@ -81,7 +81,7 @@ class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
using OpConversionPattern<gpu::GPUModuleOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::GPUModuleOp moduleOp, ArrayRef<Value> operands,
matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

Expand All @@ -91,7 +91,7 @@ class GPUModuleEndConversion final
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::ModuleEndOp endOp, ArrayRef<Value> operands,
matchAndRewrite(gpu::ModuleEndOp endOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.eraseOp(endOp);
return success();
Expand All @@ -105,7 +105,7 @@ class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
using OpConversionPattern<gpu::ReturnOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::ReturnOp returnOp, ArrayRef<Value> operands,
matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

Expand All @@ -129,7 +129,7 @@ static Optional<int32_t> getLaunchConfigIndex(Operation *op) {

template <typename SourceOp, spirv::BuiltIn builtin>
LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
SourceOp op, ArrayRef<Value> operands,
SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto index = getLaunchConfigIndex(op);
if (!index)
Expand All @@ -150,7 +150,7 @@ LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
template <typename SourceOp, spirv::BuiltIn builtin>
LogicalResult
SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
SourceOp op, ArrayRef<Value> operands,
SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
auto indexType = typeConverter->getIndexType();
Expand All @@ -162,7 +162,7 @@ SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
}

LogicalResult WorkGroupSizeConversion::matchAndRewrite(
gpu::BlockDimOp op, ArrayRef<Value> operands,
gpu::BlockDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto index = getLaunchConfigIndex(op);
if (!index)
Expand Down Expand Up @@ -264,7 +264,7 @@ getDefaultABIAttrs(MLIRContext *context, gpu::GPUFuncOp funcOp,
}

LogicalResult GPUFuncOpConversion::matchAndRewrite(
gpu::GPUFuncOp funcOp, ArrayRef<Value> operands,
gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (!gpu::GPUDialect::isKernel(funcOp))
return failure();
Expand Down Expand Up @@ -306,7 +306,7 @@ LogicalResult GPUFuncOpConversion::matchAndRewrite(
//===----------------------------------------------------------------------===//

LogicalResult GPUModuleConversion::matchAndRewrite(
gpu::GPUModuleOp moduleOp, ArrayRef<Value> operands,
gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnvOrDefault(moduleOp);
spirv::AddressingModel addressingModel = spirv::getAddressingModel(targetEnv);
Expand Down Expand Up @@ -336,9 +336,9 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
//===----------------------------------------------------------------------===//

LogicalResult GPUReturnOpConversion::matchAndRewrite(
gpu::ReturnOp returnOp, ArrayRef<Value> operands,
gpu::ReturnOp returnOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (!operands.empty())
if (!adaptor.getOperands().empty())
return failure();

rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
7 changes: 3 additions & 4 deletions mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
Expand Up @@ -73,15 +73,14 @@ class RangeOpConversion : public ConvertOpToLLVMPattern<RangeOp> {
using ConvertOpToLLVMPattern<RangeOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(RangeOp rangeOp, ArrayRef<Value> operands,
matchAndRewrite(RangeOp rangeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto rangeDescriptorTy = convertRangeType(
rangeOp.getType().cast<RangeType>(), *getTypeConverter());

ImplicitLocOpBuilder b(rangeOp->getLoc(), rewriter);

// Fill in an aggregate value of the descriptor.
RangeOpAdaptor adaptor(operands);
Value desc = b.create<LLVM::UndefOp>(rangeDescriptorTy);
desc = b.create<LLVM::InsertValueOp>(desc, adaptor.min(),
rewriter.getI64ArrayAttr(0));
Expand All @@ -101,9 +100,9 @@ class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
using ConvertOpToLLVMPattern<linalg::YieldOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(linalg::YieldOp op, ArrayRef<Value> operands,
matchAndRewrite(linalg::YieldOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands());
return success();
}
};
7 changes: 4 additions & 3 deletions mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
Expand Up @@ -55,7 +55,7 @@ struct SingleWorkgroupReduction final
matchAsPerformingReduction(linalg::GenericOp genericOp);

LogicalResult
matchAndRewrite(linalg::GenericOp genericOp, ArrayRef<Value> operands,
matchAndRewrite(linalg::GenericOp genericOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

Expand Down Expand Up @@ -109,7 +109,7 @@ SingleWorkgroupReduction::matchAsPerformingReduction(
}

LogicalResult SingleWorkgroupReduction::matchAndRewrite(
linalg::GenericOp genericOp, ArrayRef<Value> operands,
linalg::GenericOp genericOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Operation *op = genericOp.getOperation();
auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
Expand All @@ -134,7 +134,8 @@ LogicalResult SingleWorkgroupReduction::matchAndRewrite(
// TODO: Query the target environment to make sure the current
// workload fits in a local workgroup.

Value convertedInput = operands[0], convertedOutput = operands[1];
Value convertedInput = adaptor.getOperands()[0];
Value convertedOutput = adaptor.getOperands()[1];
Location loc = genericOp.getLoc();

auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
27 changes: 12 additions & 15 deletions mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
Expand Up @@ -34,10 +34,9 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(math::ExpM1Op op, ArrayRef<Value> operands,
matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
math::ExpM1Op::Adaptor transformed(operands);
auto operandType = transformed.operand().getType();
auto operandType = adaptor.operand().getType();

if (!operandType || !LLVM::isCompatibleType(operandType))
return failure();
Expand All @@ -56,7 +55,7 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
} else {
one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
}
auto exp = rewriter.create<LLVM::ExpOp>(loc, transformed.operand());
auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.operand());
rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
return success();
}
Expand All @@ -66,7 +65,7 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
return rewriter.notifyMatchFailure(op, "expected vector result type");

return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), operands, *getTypeConverter(),
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
auto splatAttr = SplatElementsAttr::get(
mlir::VectorType::get(
Expand All @@ -88,10 +87,9 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(math::Log1pOp op, ArrayRef<Value> operands,
matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
math::Log1pOp::Adaptor transformed(operands);
auto operandType = transformed.operand().getType();
auto operandType = adaptor.operand().getType();

if (!operandType || !LLVM::isCompatibleType(operandType))
return rewriter.notifyMatchFailure(op, "unsupported operand type");
Expand All @@ -111,7 +109,7 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
: rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);

auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
transformed.operand());
adaptor.operand());
rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
return success();
}
Expand All @@ -121,7 +119,7 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
return rewriter.notifyMatchFailure(op, "expected vector result type");

return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), operands, *getTypeConverter(),
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
auto splatAttr = SplatElementsAttr::get(
mlir::VectorType::get(
Expand All @@ -143,10 +141,9 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(math::RsqrtOp op, ArrayRef<Value> operands,
matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
math::RsqrtOp::Adaptor transformed(operands);
auto operandType = transformed.operand().getType();
auto operandType = adaptor.operand().getType();

if (!operandType || !LLVM::isCompatibleType(operandType))
return failure();
Expand All @@ -165,7 +162,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
} else {
one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
}
auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.operand());
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
return success();
}
Expand All @@ -175,7 +172,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
return failure();

return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), operands, *getTypeConverter(),
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
auto splatAttr = SplatElementsAttr::get(
mlir::VectorType::get(
14 changes: 8 additions & 6 deletions mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
Expand Up @@ -37,9 +37,9 @@ class UnaryAndBinaryOpPattern final : public OpConversionPattern<StdOp> {
using OpConversionPattern<StdOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
assert(operands.size() <= 2);
assert(adaptor.getOperands().size() <= 2);
auto dstType = this->getTypeConverter()->convertType(operation.getType());
if (!dstType)
return failure();
Expand All @@ -48,7 +48,8 @@ class UnaryAndBinaryOpPattern final : public OpConversionPattern<StdOp> {
return operation.emitError(
"bitwidth emulation is not implemented yet on unsigned op");
}
rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands);
rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType,
adaptor.getOperands());
return success();
}
};
Expand All @@ -62,14 +63,15 @@ class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
using OpConversionPattern<math::Log1pOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(math::Log1pOp operation, ArrayRef<Value> operands,
matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
assert(operands.size() == 1);
assert(adaptor.getOperands().size() == 1);
Location loc = operation.getLoc();
auto type =
this->getTypeConverter()->convertType(operation.operand().getType());
auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
auto onePlus = rewriter.create<spirv::FAddOp>(loc, one, operands[0]);
auto onePlus =
rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperands()[0]);
rewriter.replaceOpWithNewOp<spirv::GLSLLogOp>(operation, type, onePlus);
return success();
}
126 changes: 54 additions & 72 deletions mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Large diffs are not rendered by default.

57 changes: 24 additions & 33 deletions mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
Expand Up @@ -158,7 +158,7 @@ class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
using OpConversionPattern<memref::AllocOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::AllocOp operation, ArrayRef<Value> operands,
matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

Expand All @@ -169,7 +169,7 @@ class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::DeallocOp operation, ArrayRef<Value> operands,
matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

@@ -179,7 +179,7 @@ class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

@@ -189,7 +189,7 @@ class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

@@ -199,7 +199,7 @@ class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::StoreOp storeOp, ArrayRef<Value> operands,
matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

@@ -209,7 +209,7 @@ class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::StoreOp storeOp, ArrayRef<Value> operands,
matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

@@ -220,8 +220,7 @@ class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
//===----------------------------------------------------------------------===//

LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation,
ArrayRef<Value> operands,
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
MemRefType allocType = operation.getType();
if (!isAllocationSupported(allocType))
@@ -260,7 +259,7 @@ AllocOpPattern::matchAndRewrite(memref::AllocOp operation,

LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
ArrayRef<Value> operands,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
if (!isAllocationSupported(deallocType))
@@ -274,19 +273,17 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
//===----------------------------------------------------------------------===//

LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
ArrayRef<Value> operands,
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
memref::LoadOpAdaptor loadOperands(operands);
auto loc = loadOp.getLoc();
auto memrefType = loadOp.memref().getType().cast<MemRefType>();
if (!memrefType.getElementType().isSignlessInteger())
return failure();

auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
spirv::AccessChainOp accessChainOp =
spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
loadOperands.indices(), loc, rewriter);
spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(),
adaptor.indices(), loc, rewriter);

if (!accessChainOp)
return failure();
@@ -372,15 +369,14 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
}

LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
memref::LoadOpAdaptor loadOperands(operands);
auto memrefType = loadOp.memref().getType().cast<MemRefType>();
if (memrefType.getElementType().isSignlessInteger())
return failure();
auto loadPtr = spirv::getElementPtr(
*getTypeConverter<SPIRVTypeConverter>(), memrefType,
loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter);
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.memref(),
adaptor.indices(), loadOp.getLoc(), rewriter);

if (!loadPtr)
return failure();
@@ -390,19 +386,17 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
}

LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
ArrayRef<Value> operands,
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
memref::StoreOpAdaptor storeOperands(operands);
auto memrefType = storeOp.memref().getType().cast<MemRefType>();
if (!memrefType.getElementType().isSignlessInteger())
return failure();

auto loc = storeOp.getLoc();
auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
spirv::AccessChainOp accessChainOp =
spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
storeOperands.indices(), loc, rewriter);
spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(),
adaptor.indices(), loc, rewriter);

if (!accessChainOp)
return failure();
@@ -427,7 +421,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
assert(dstBits % srcBits == 0);

if (srcBits == dstBits) {
Value storeVal = storeOperands.value();
Value storeVal = adaptor.value();
if (isBool)
storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
rewriter.replaceOpWithNewOp<spirv::StoreOp>(
@@ -458,7 +452,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);

Value storeVal = storeOperands.value();
Value storeVal = adaptor.value();
if (isBool)
storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
@@ -487,23 +481,20 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
}

LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
ArrayRef<Value> operands,
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
memref::StoreOpAdaptor storeOperands(operands);
auto memrefType = storeOp.memref().getType().cast<MemRefType>();
if (memrefType.getElementType().isSignlessInteger())
return failure();
auto storePtr =
spirv::getElementPtr(*getTypeConverter<SPIRVTypeConverter>(), memrefType,
storeOperands.memref(), storeOperands.indices(),
storeOp.getLoc(), rewriter);
auto storePtr = spirv::getElementPtr(
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.memref(),
adaptor.indices(), storeOp.getLoc(), rewriter);

if (!storePtr)
return failure();

rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
storeOperands.value());
adaptor.value());
return success();
}

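
In the memref patterns above, the separately constructed `memref::LoadOpAdaptor` / `memref::StoreOpAdaptor` locals go away because the `OpAdaptor` parameter already exposes the converted operands under their ODS names. For orientation, the generated adaptor for `memref.store` behaves roughly like the following hand-written approximation (sketch only; the real class is ODS-generated and also carries attributes and further utilities):

class StoreOpAdaptorSketch {
public:
  explicit StoreOpAdaptorSketch(ValueRange values) : operands(values) {}

  // Accessors follow the operand order and names in the ODS definition of
  // memref.store: the stored value, the base memref, then the indices.
  Value value() { return operands[0]; }
  Value memref() { return operands[1]; }
  ValueRange indices() { return operands.drop_front(2); }
  ValueRange getOperands() { return operands; }

private:
  ValueRange operands;
};
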
6 changes: 3 additions & 3 deletions mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp
@@ -79,16 +79,16 @@ class LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern<Op> {
using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(Op op, ArrayRef<Value> operands,
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &builder) const override {
Location loc = op.getLoc();
TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();

unsigned numDataOperand = op.getNumDataOperands();

// Keep the non data operands without modification.
auto nonDataOperands =
operands.take_front(operands.size() - numDataOperand);
auto nonDataOperands = adaptor.getOperands().take_front(
adaptor.getOperands().size() - numDataOperand);
SmallVector<Value> convertedOperands;
convertedOperands.append(nonDataOperands.begin(), nonDataOperands.end());

6 changes: 3 additions & 3 deletions mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -29,10 +29,10 @@ struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(OpType curOp, ArrayRef<Value> operands,
matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto newOp = rewriter.create<OpType>(curOp.getLoc(), TypeRange(), operands,
curOp->getAttrs());
auto newOp = rewriter.create<OpType>(
curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
rewriter.inlineRegionBefore(curOp.region(), newOp.region(),
newOp.region().end());
if (failed(rewriter.convertRegionTypes(&newOp.region(),
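
The OpenACC and OpenMP patterns above are templated over the op type, so there is no `OpAdaptor` shorthand to name; the parameter is spelled `typename OpType::Adaptor` and the converted operands are forwarded via `adaptor.getOperands()`. A generic sketch under those assumptions (illustrative; it sidesteps the region and result handling the real patterns perform):

template <typename OpType>
struct CloneWithConvertedOperands : public ConvertOpToLLVMPattern<OpType> {
  using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Recreate the op with the type-converted operands and the original
    // attributes; assumes a zero-result op with no regions.
    rewriter.replaceOpWithNewOp<OpType>(op, TypeRange(), adaptor.getOperands(),
                                        op->getAttrs());
    return success();
  }
};
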
33 changes: 16 additions & 17 deletions mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -84,7 +84,7 @@ class ForOpConversion final : public SCFToSPIRVPattern<scf::ForOp> {
using SCFToSPIRVPattern<scf::ForOp>::SCFToSPIRVPattern;

LogicalResult
matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

@@ -95,7 +95,7 @@ class IfOpConversion final : public SCFToSPIRVPattern<scf::IfOp> {
using SCFToSPIRVPattern<scf::IfOp>::SCFToSPIRVPattern;

LogicalResult
matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

@@ -104,7 +104,7 @@ class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> {
using SCFToSPIRVPattern<scf::YieldOp>::SCFToSPIRVPattern;

LogicalResult
matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef<Value> operands,
matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
@@ -146,14 +146,13 @@ static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
//===----------------------------------------------------------------------===//

LogicalResult
ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
ForOpConversion::matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// scf::ForOp can be lowered to the structured control flow represented by
// spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
// latch and the merge block the exit block. The resulting spirv::LoopOp has a
// single back edge from the continue to header block, and a single exit from
// header to merge.
scf::ForOpAdaptor forOperands(operands);
auto loc = forOp.getLoc();
auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
loopOp.addEntryAndMergeBlock();
@@ -165,9 +164,8 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header);

// Create the new induction variable to use.
BlockArgument newIndVar =
header->addArgument(forOperands.lowerBound().getType());
for (Value arg : forOperands.initArgs())
BlockArgument newIndVar = header->addArgument(adaptor.lowerBound().getType());
for (Value arg : adaptor.initArgs())
header->addArgument(arg.getType());
Block *body = forOp.getBody();

@@ -187,8 +185,8 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.body(),
std::next(loopOp.body().begin(), 2));

SmallVector<Value, 8> args(1, forOperands.lowerBound());
args.append(forOperands.initArgs().begin(), forOperands.initArgs().end());
SmallVector<Value, 8> args(1, adaptor.lowerBound());
args.append(adaptor.initArgs().begin(), adaptor.initArgs().end());
// Branch into it from the entry.
rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
rewriter.create<spirv::BranchOp>(loc, header, args);
@@ -197,7 +195,7 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
rewriter.setInsertionPointToEnd(header);
auto *mergeBlock = loopOp.getMergeBlock();
auto cmpOp = rewriter.create<spirv::SLessThanOp>(
loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound());
loc, rewriter.getI1Type(), newIndVar, adaptor.upperBound());

rewriter.create<spirv::BranchConditionalOp>(
loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
@@ -209,15 +207,15 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,

// Add the step to the induction variable and branch to the header.
Value updatedIndVar = rewriter.create<spirv::IAddOp>(
loc, newIndVar.getType(), newIndVar, forOperands.step());
loc, newIndVar.getType(), newIndVar, adaptor.step());
rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);

// Infer the return types from the init operands. Vector type may get
// converted to CooperativeMatrix or to Vector type, to avoid having complex
// extra logic to figure out the right type we just infer it from the Init
// operands.
SmallVector<Type, 8> initTypes;
for (auto arg : forOperands.initArgs())
for (auto arg : adaptor.initArgs())
initTypes.push_back(arg.getType());
replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext, initTypes);
return success();
@@ -228,12 +226,11 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
//===----------------------------------------------------------------------===//

LogicalResult
IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
IfOpConversion::matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// When lowering `scf::IfOp` we explicitly create a selection header block
// before the control flow diverges and a merge block where control flow
// subsequently converges.
scf::IfOpAdaptor ifOperands(operands);
auto loc = ifOp.getLoc();

// Create `spv.selection` operation, selection header block and merge block.
@@ -267,7 +264,7 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,

// Create a `spv.BranchConditional` operation for selection header block.
rewriter.setInsertionPointToEnd(selectionHeaderBlock);
rewriter.create<spirv::BranchConditionalOp>(loc, ifOperands.condition(),
rewriter.create<spirv::BranchConditionalOp>(loc, adaptor.condition(),
thenBlock, ArrayRef<Value>(),
elseBlock, ArrayRef<Value>());

@@ -289,8 +286,10 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
/// parent region. For loops we also need to update the branch looping back to
/// the header with the loop carried values.
LogicalResult TerminatorOpConversion::matchAndRewrite(
scf::YieldOp terminatorOp, ArrayRef<Value> operands,
scf::YieldOp terminatorOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
ValueRange operands = adaptor.getOperands();

// If the region is return values, store each value into the associated
// VariableOp created during lowering of the parent region.
if (!operands.empty()) {
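
The SCF patterns above read every operand off the adaptor because the op and the adaptor can disagree on types once the type converter has run: the op still holds the original SSA values (e.g. `index`-typed loop bounds), while the adaptor holds their converted counterparts, which is what the newly created SPIR-V ops must consume. A small sketch of that distinction (illustrative helper, not code from this patch):

static void inspectForOperands(scf::ForOp forOp, scf::ForOp::Adaptor adaptor) {
  Value origLb = forOp.lowerBound();   // original value, still `index` typed
  Value convLb = adaptor.lowerBound(); // same operand after type conversion,
                                       // e.g. an i32 when targeting SPIR-V
  ValueRange convInits = adaptor.initArgs(); // converted loop-carried values
  (void)origLb;
  (void)convLb;
  (void)convInits;
}
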
@@ -157,7 +157,7 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto *op = launchOp.getOperation();
MLIRContext *context = rewriter.getContext();
@@ -206,7 +206,7 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
Location loc = launchOp.getLoc();
SmallVector<CopyInfo, 4> copyInfo;
auto numKernelOperands = launchOp.getNumKernelOperands();
auto kernelOperands = operands.take_back(numKernelOperands);
auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands);
for (auto operand : llvm::enumerate(kernelOperands)) {
// Check if the kernel's operand is a ranked memref.
auto memRefType = launchOp.getKernelOperand(operand.index())
96 changes: 47 additions & 49 deletions mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp

Large diffs are not rendered by default.

111 changes: 46 additions & 65 deletions mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp

Large diffs are not rendered by default.

101 changes: 44 additions & 57 deletions mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -292,7 +292,7 @@ struct FuncOpConversion : public FuncOpConversionBase {
: FuncOpConversionBase(converter) {}

LogicalResult
matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
if (!newFuncOp)
@@ -319,7 +319,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
using FuncOpConversionBase::FuncOpConversionBase;

LogicalResult
matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// TODO: bare ptr conversion could be handled by argument materialization
@@ -442,10 +442,9 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<AssertOp> {
using ConvertOpToLLVMPattern<AssertOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(AssertOp op, ArrayRef<Value> operands,
matchAndRewrite(AssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
AssertOp::Adaptor transformed(operands);

// Insert the `abort` declaration if necessary.
auto module = op->getParentOfType<ModuleOp>();
@@ -471,7 +470,7 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<AssertOp> {
// Generate assertion test.
rewriter.setInsertionPointToEnd(opBlock);
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
op, transformed.arg(), continuationBlock, failureBlock);
op, adaptor.arg(), continuationBlock, failureBlock);

return success();
}
@@ -481,7 +480,7 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
using ConvertOpToLLVMPattern<ConstantOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
matchAndRewrite(ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// If constant refers to a function, convert it to "addressof".
if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) {
@@ -506,8 +505,8 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
op, "referring to a symbol outside of the current module");

return LLVM::detail::oneToOneRewrite(
op, LLVM::ConstantOp::getOperationName(), operands, *getTypeConverter(),
rewriter);
op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
*getTypeConverter(), rewriter);
}
};

@@ -520,10 +519,8 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
using Base = ConvertOpToLLVMPattern<CallOpType>;

LogicalResult
matchAndRewrite(CallOpType callOp, ArrayRef<Value> operands,
matchAndRewrite(CallOpType callOp, typename CallOpType::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
typename CallOpType::Adaptor transformed(operands);

// Pack the result types into a struct.
Type packedResult = nullptr;
unsigned numResults = callOp.getNumResults();
@@ -536,8 +533,8 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
}

auto promoted = this->getTypeConverter()->promoteOperands(
callOp.getLoc(), /*opOperands=*/callOp->getOperands(), operands,
rewriter);
callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
adaptor.getOperands(), rewriter);
auto newOp = rewriter.create<LLVM::CallOp>(
callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
promoted, callOp->getAttrs());
@@ -591,22 +588,21 @@ struct UnrealizedConversionCastOpLowering
UnrealizedConversionCastOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(UnrealizedConversionCastOp op, ArrayRef<Value> operands,
matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
UnrealizedConversionCastOp::Adaptor transformed(operands);
SmallVector<Type> convertedTypes;
if (succeeded(typeConverter->convertTypes(op.outputs().getTypes(),
convertedTypes)) &&
convertedTypes == transformed.inputs().getTypes()) {
rewriter.replaceOp(op, transformed.inputs());
convertedTypes == adaptor.inputs().getTypes()) {
rewriter.replaceOp(op, adaptor.inputs());
return success();
}

convertedTypes.clear();
if (succeeded(typeConverter->convertTypes(transformed.inputs().getTypes(),
if (succeeded(typeConverter->convertTypes(adaptor.inputs().getTypes(),
convertedTypes)) &&
convertedTypes == op.outputs().getType()) {
rewriter.replaceOp(op, transformed.inputs());
rewriter.replaceOp(op, adaptor.inputs());
return success();
}
return failure();
@@ -617,12 +613,12 @@ struct RankOpLowering : public ConvertOpToLLVMPattern<RankOp> {
using ConvertOpToLLVMPattern<RankOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(RankOp op, ArrayRef<Value> operands,
matchAndRewrite(RankOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Type operandType = op.memrefOrTensor().getType();
if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
UnrankedMemRefDescriptor desc(RankOp::Adaptor(operands).memrefOrTensor());
UnrankedMemRefDescriptor desc(adaptor.memrefOrTensor());
rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
return success();
}
@@ -658,29 +654,27 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
using ConvertOpToLLVMPattern<IndexCastOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(IndexCastOp indexCastOp, ArrayRef<Value> operands,
matchAndRewrite(IndexCastOp indexCastOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
IndexCastOpAdaptor transformed(operands);

auto targetType =
typeConverter->convertType(indexCastOp.getResult().getType());
auto targetElementType =
typeConverter
->convertType(getElementTypeOrSelf(indexCastOp.getResult()))
.cast<IntegerType>();
auto sourceElementType =
getElementTypeOrSelf(transformed.in()).cast<IntegerType>();
getElementTypeOrSelf(adaptor.in()).cast<IntegerType>();
unsigned targetBits = targetElementType.getWidth();
unsigned sourceBits = sourceElementType.getWidth();

if (targetBits == sourceBits)
rewriter.replaceOp(indexCastOp, transformed.in());
rewriter.replaceOp(indexCastOp, adaptor.in());
else if (targetBits < sourceBits)
rewriter.replaceOpWithNewOp<LLVM::TruncOp>(indexCastOp, targetType,
transformed.in());
adaptor.in());
else
rewriter.replaceOpWithNewOp<LLVM::SExtOp>(indexCastOp, targetType,
transformed.in());
adaptor.in());
return success();
}
};
@@ -696,18 +690,17 @@ struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> {
using ConvertOpToLLVMPattern<CmpIOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(CmpIOp cmpiOp, ArrayRef<Value> operands,
matchAndRewrite(CmpIOp cmpiOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
CmpIOpAdaptor transformed(operands);
auto operandType = transformed.lhs().getType();
auto operandType = adaptor.lhs().getType();
auto resultType = cmpiOp.getResult().getType();

// Handle the scalar and 1D vector cases.
if (!operandType.isa<LLVM::LLVMArrayType>()) {
rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
cmpiOp, typeConverter->convertType(resultType),
convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()),
transformed.lhs(), transformed.rhs());
adaptor.lhs(), adaptor.rhs());
return success();
}

@@ -716,13 +709,13 @@ struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> {
return rewriter.notifyMatchFailure(cmpiOp, "expected vector result type");

return LLVM::detail::handleMultidimensionalVectors(
cmpiOp.getOperation(), operands, *getTypeConverter(),
cmpiOp.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
CmpIOpAdaptor transformed(operands);
CmpIOpAdaptor adaptor(operands);
return rewriter.create<LLVM::ICmpOp>(
cmpiOp.getLoc(), llvm1DVectorTy,
convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()),
transformed.lhs(), transformed.rhs());
adaptor.lhs(), adaptor.rhs());
},
rewriter);

@@ -734,18 +727,17 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
using ConvertOpToLLVMPattern<CmpFOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(CmpFOp cmpfOp, ArrayRef<Value> operands,
matchAndRewrite(CmpFOp cmpfOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
CmpFOpAdaptor transformed(operands);
auto operandType = transformed.lhs().getType();
auto operandType = adaptor.lhs().getType();
auto resultType = cmpfOp.getResult().getType();

// Handle the scalar and 1D vector cases.
if (!operandType.isa<LLVM::LLVMArrayType>()) {
rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
cmpfOp, typeConverter->convertType(resultType),
convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()),
transformed.lhs(), transformed.rhs());
adaptor.lhs(), adaptor.rhs());
return success();
}

@@ -754,13 +746,13 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
return rewriter.notifyMatchFailure(cmpfOp, "expected vector result type");

return LLVM::detail::handleMultidimensionalVectors(
cmpfOp.getOperation(), operands, *getTypeConverter(),
cmpfOp.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
CmpFOpAdaptor transformed(operands);
CmpFOpAdaptor adaptor(operands);
return rewriter.create<LLVM::FCmpOp>(
cmpfOp.getLoc(), llvm1DVectorTy,
convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()),
transformed.lhs(), transformed.rhs());
adaptor.lhs(), adaptor.rhs());
},
rewriter);
}
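
One wrinkle in the cmpi/cmpf lowerings above: `LLVM::detail::handleMultidimensionalVectors` hands its callback the unrolled 1-D operands as a plain `ValueRange`, so a fresh adaptor is still wrapped by hand inside the lambda even though the outer `matchAndRewrite` already received one. Condensed to its shape (sketch only, with the predicate hardcoded for brevity):

// Inside a CmpIOp lowering such as the one above:
return LLVM::detail::handleMultidimensionalVectors(
    cmpiOp.getOperation(), adaptor.getOperands(), *getTypeConverter(),
    [&](Type llvm1DVectorTy, ValueRange operands) {
      // Re-wrap the unrolled 1-D slice; the outer adaptor only covers the
      // original n-D operands.
      CmpIOpAdaptor oneDimAdaptor(operands);
      return rewriter.create<LLVM::ICmpOp>(
          cmpiOp.getLoc(), llvm1DVectorTy, LLVM::ICmpPredicate::eq,
          oneDimAdaptor.lhs(), oneDimAdaptor.rhs());
    },
    rewriter);
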
@@ -774,10 +766,10 @@ struct OneToOneLLVMTerminatorLowering
using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;

LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<TargetOp>(op, operands, op->getSuccessors(),
op->getAttrs());
rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(),
op->getSuccessors(), op->getAttrs());
return success();
}
};
@@ -792,7 +784,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
unsigned numArguments = op.getNumOperands();
Expand All @@ -801,7 +793,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
if (getTypeConverter()->getOptions().useBarePtrCallConv) {
// For the bare-ptr calling convention, extract the aligned pointer to
// be returned from the memref descriptor.
for (auto it : llvm::zip(op->getOperands(), operands)) {
for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
Type oldTy = std::get<0>(it).getType();
Value newOperand = std::get<1>(it);
if (oldTy.isa<MemRefType>()) {
@@ -815,7 +807,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
updatedOperands.push_back(newOperand);
}
} else {
updatedOperands = llvm::to_vector<4>(operands);
updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
(void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
updatedOperands,
/*toDynamic=*/true);
@@ -870,14 +862,12 @@ struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(SplatOp splatOp, ArrayRef<Value> operands,
matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
if (!resultType || resultType.getRank() != 1)
return failure();

SplatOp::Adaptor adaptor(operands);

// First insert it into an undef vector so we can shuffle it.
auto vectorType = typeConverter->convertType(splatOp.getType());
Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
@@ -907,9 +897,8 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(SplatOp splatOp, ArrayRef<Value> operands,
matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SplatOp::Adaptor adaptor(operands);
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
if (!resultType || resultType.getRank() == 1)
return failure();
@@ -984,14 +973,13 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
using Base::Base;

LogicalResult
matchAndRewrite(AtomicRMWOp atomicOp, ArrayRef<Value> operands,
matchAndRewrite(AtomicRMWOp atomicOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(match(atomicOp)))
return failure();
auto maybeKind = matchSimpleAtomicOp(atomicOp);
if (!maybeKind)
return failure();
AtomicRMWOp::Adaptor adaptor(operands);
auto resultType = adaptor.value().getType();
auto memRefType = atomicOp.getMemRefType();
auto dataPtr =
@@ -1036,11 +1024,10 @@ struct GenericAtomicRMWOpLowering
using Base::Base;

LogicalResult
matchAndRewrite(GenericAtomicRMWOp atomicOp, ArrayRef<Value> operands,
matchAndRewrite(GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto loc = atomicOp.getLoc();
GenericAtomicRMWOp::Adaptor adaptor(operands);
Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

// Split the block into initial, loop, and ending parts.