diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp index 377cc44392028..b2e0237659963 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp @@ -36,10 +36,10 @@ namespace { /// Base class for passes converting transformational intrinsic operations into /// runtime calls template -class HlfirIntrinsicConversion : public mlir::OpRewritePattern { +class HlfirIntrinsicConversion : public mlir::OpConversionPattern { public: explicit HlfirIntrinsicConversion(mlir::MLIRContext *ctx) - : mlir::OpRewritePattern{ctx} { + : mlir::OpConversionPattern{ctx} { // required for cases where intrinsics are chained together e.g. // matmul(matmul(a, b), c) // because converting the inner operation then invalidates the @@ -145,7 +145,7 @@ class HlfirIntrinsicConversion : public mlir::OpRewritePattern { void processReturnValue(mlir::Operation *op, const fir::ExtendedValue &resultExv, bool mustBeFreed, fir::FirOpBuilder &builder, - mlir::PatternRewriter &rewriter) const { + mlir::ConversionPatternRewriter &rewriter) const { mlir::Location loc = op->getLoc(); mlir::Value firBase = fir::getBase(resultExv); @@ -176,13 +176,9 @@ class HlfirIntrinsicConversion : public mlir::OpRewritePattern { rewriter.eraseOp(use); } } - // TODO: This entire pass should be a greedy pattern rewrite or a manual - // IR traversal. A dialect conversion cannot be used here because - // `replaceAllUsesWith` is not supported. Similarly, `replaceOp` is not - // suitable because "op->getResult(0)" and "base" can have different types. - // In such a case, the dialect conversion will attempt to convert the type, - // but no type converter is specified in this pass. Also note that all - // patterns in this pass are actually rewrite patterns. + // the types might not match exactly (but are safe) + // e.g. !hlfir.expr vs !hlfir.expr<2xi32> + // TODO: is this allowed by MLIR? op->getResult(0).replaceAllUsesWith(base); rewriter.replaceOp(op, base); } @@ -203,48 +199,53 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion { typename HlfirIntrinsicConversion::IntrinsicArgument; using HlfirIntrinsicConversion::lowerArguments; using HlfirIntrinsicConversion::processReturnValue; + using Adaptor = typename OP::Adaptor; protected: - auto buildNumericalArgs(OP operation, mlir::Type i32, mlir::Type logicalType, + auto buildNumericalArgs(mlir::Operation *operation, Adaptor adaptor, + mlir::Type i32, mlir::Type logicalType, mlir::PatternRewriter &rewriter, std::string opName) const { llvm::SmallVector inArgs; - inArgs.push_back({operation.getArray(), operation.getArray().getType()}); - inArgs.push_back({operation.getDim(), i32}); - inArgs.push_back({operation.getMask(), logicalType}); + inArgs.push_back({adaptor.getArray(), adaptor.getArray().getType()}); + inArgs.push_back({adaptor.getDim(), i32}); + inArgs.push_back({adaptor.getMask(), logicalType}); auto *argLowering = fir::getIntrinsicArgumentLowering(opName); return lowerArguments(operation, inArgs, rewriter, argLowering); }; - auto buildMinMaxLocArgs(OP operation, mlir::Type i32, mlir::Type logicalType, + auto buildMinMaxLocArgs(mlir::Operation *operation, Adaptor adaptor, + mlir::Type i32, mlir::Type logicalType, mlir::PatternRewriter &rewriter, std::string opName, fir::FirOpBuilder builder) const { llvm::SmallVector inArgs; - inArgs.push_back({operation.getArray(), operation.getArray().getType()}); - inArgs.push_back({operation.getDim(), i32}); - inArgs.push_back({operation.getMask(), logicalType}); + inArgs.push_back({adaptor.getArray(), adaptor.getArray().getType()}); + inArgs.push_back({adaptor.getDim(), i32}); + inArgs.push_back({adaptor.getMask(), logicalType}); mlir::Value kind = builder.createIntegerConstant( - operation->getLoc(), i32, getKindForType(operation.getType())); + operation->getLoc(), i32, + getKindForType(operation->getResult(0).getType())); inArgs.push_back({kind, i32}); - inArgs.push_back({operation.getBack(), i32}); + inArgs.push_back({adaptor.getBack(), i32}); auto *argLowering = fir::getIntrinsicArgumentLowering(opName); return lowerArguments(operation, inArgs, rewriter, argLowering); }; - auto buildLogicalArgs(OP operation, mlir::Type i32, mlir::Type logicalType, + auto buildLogicalArgs(mlir::Operation *operation, Adaptor adaptor, + mlir::Type i32, mlir::Type logicalType, mlir::PatternRewriter &rewriter, std::string opName) const { llvm::SmallVector inArgs; - inArgs.push_back({operation.getMask(), logicalType}); - inArgs.push_back({operation.getDim(), i32}); + inArgs.push_back({adaptor.getMask(), logicalType}); + inArgs.push_back({adaptor.getDim(), i32}); auto *argLowering = fir::getIntrinsicArgumentLowering(opName); return lowerArguments(operation, inArgs, rewriter, argLowering); }; public: mlir::LogicalResult - matchAndRewrite(OP operation, - mlir::PatternRewriter &rewriter) const override { + matchAndRewrite(OP operation, Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { std::string opName; if constexpr (std::is_same_v) { opName = "sum"; @@ -279,13 +280,15 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion { std::is_same_v || std::is_same_v || std::is_same_v) { - args = buildNumericalArgs(operation, i32, logicalType, rewriter, opName); + args = buildNumericalArgs(operation, adaptor, i32, logicalType, rewriter, + opName); } else if constexpr (std::is_same_v || std::is_same_v) { - args = buildMinMaxLocArgs(operation, i32, logicalType, rewriter, opName, - builder); + args = buildMinMaxLocArgs(operation, adaptor, i32, logicalType, rewriter, + opName, builder); } else { - args = buildLogicalArgs(operation, i32, logicalType, rewriter, opName); + args = buildLogicalArgs(operation, adaptor, i32, logicalType, rewriter, + opName); } mlir::Type scalarResultType = @@ -319,8 +322,8 @@ struct CountOpConversion : public HlfirIntrinsicConversion { using HlfirIntrinsicConversion::HlfirIntrinsicConversion; mlir::LogicalResult - matchAndRewrite(hlfir::CountOp count, - mlir::PatternRewriter &rewriter) const override { + matchAndRewrite(hlfir::CountOp count, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { fir::FirOpBuilder builder{rewriter, count.getOperation()}; const mlir::Location &loc = count->getLoc(); @@ -329,8 +332,8 @@ struct CountOpConversion : public HlfirIntrinsicConversion { builder.getContext(), builder.getKindMap().defaultLogicalKind()); llvm::SmallVector inArgs; - inArgs.push_back({count.getMask(), logicalType}); - inArgs.push_back({count.getDim(), i32}); + inArgs.push_back({adaptor.getMask(), logicalType}); + inArgs.push_back({adaptor.getDim(), i32}); mlir::Value kind = builder.createIntegerConstant( count->getLoc(), i32, getKindForType(count.getType())); inArgs.push_back({kind, i32}); @@ -353,13 +356,13 @@ struct MatmulOpConversion : public HlfirIntrinsicConversion { using HlfirIntrinsicConversion::HlfirIntrinsicConversion; mlir::LogicalResult - matchAndRewrite(hlfir::MatmulOp matmul, - mlir::PatternRewriter &rewriter) const override { + matchAndRewrite(hlfir::MatmulOp matmul, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { fir::FirOpBuilder builder{rewriter, matmul.getOperation()}; const mlir::Location &loc = matmul->getLoc(); - mlir::Value lhs = matmul.getLhs(); - mlir::Value rhs = matmul.getRhs(); + mlir::Value lhs = adaptor.getLhs(); + mlir::Value rhs = adaptor.getRhs(); llvm::SmallVector inArgs; inArgs.push_back({lhs, lhs.getType()}); inArgs.push_back({rhs, rhs.getType()}); @@ -384,13 +387,13 @@ struct DotProductOpConversion using HlfirIntrinsicConversion::HlfirIntrinsicConversion; mlir::LogicalResult - matchAndRewrite(hlfir::DotProductOp dotProduct, - mlir::PatternRewriter &rewriter) const override { + matchAndRewrite(hlfir::DotProductOp dotProduct, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { fir::FirOpBuilder builder{rewriter, dotProduct.getOperation()}; const mlir::Location &loc = dotProduct->getLoc(); - mlir::Value lhs = dotProduct.getLhs(); - mlir::Value rhs = dotProduct.getRhs(); + mlir::Value lhs = adaptor.getLhs(); + mlir::Value rhs = adaptor.getRhs(); llvm::SmallVector inArgs; inArgs.push_back({lhs, lhs.getType()}); inArgs.push_back({rhs, rhs.getType()}); @@ -415,12 +418,12 @@ class TransposeOpConversion using HlfirIntrinsicConversion::HlfirIntrinsicConversion; mlir::LogicalResult - matchAndRewrite(hlfir::TransposeOp transpose, - mlir::PatternRewriter &rewriter) const override { + matchAndRewrite(hlfir::TransposeOp transpose, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { fir::FirOpBuilder builder{rewriter, transpose.getOperation()}; const mlir::Location &loc = transpose->getLoc(); - mlir::Value arg = transpose.getArray(); + mlir::Value arg = adaptor.getArray(); llvm::SmallVector inArgs; inArgs.push_back({arg, arg.getType()}); @@ -445,13 +448,13 @@ struct MatmulTransposeOpConversion hlfir::MatmulTransposeOp>::HlfirIntrinsicConversion; mlir::LogicalResult - matchAndRewrite(hlfir::MatmulTransposeOp multranspose, - mlir::PatternRewriter &rewriter) const override { + matchAndRewrite(hlfir::MatmulTransposeOp multranspose, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { fir::FirOpBuilder builder{rewriter, multranspose.getOperation()}; const mlir::Location &loc = multranspose->getLoc(); - mlir::Value lhs = multranspose.getLhs(); - mlir::Value rhs = multranspose.getRhs(); + mlir::Value lhs = adaptor.getLhs(); + mlir::Value rhs = adaptor.getRhs(); llvm::SmallVector inArgs; inArgs.push_back({lhs, lhs.getType()}); inArgs.push_back({rhs, rhs.getType()});