Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang] de-duplicate AbstractResult pass #88867

Merged
merged 4 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions flang/include/flang/Optimizer/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ namespace fir {
// Passes defined in Passes.td
//===----------------------------------------------------------------------===//

#define GEN_PASS_DECL_ABSTRACTRESULTONFUNCOPT
#define GEN_PASS_DECL_ABSTRACTRESULTONGLOBALOPT
#define GEN_PASS_DECL_ABSTRACTRESULTOPT
#define GEN_PASS_DECL_AFFINEDIALECTPROMOTION
#define GEN_PASS_DECL_AFFINEDIALECTDEMOTION
#define GEN_PASS_DECL_ANNOTATECONSTANTOPERANDS
Expand All @@ -50,8 +49,6 @@ namespace fir {
#define GEN_PASS_DECL_OPENACCDATAOPERANDCONVERSION
#include "flang/Optimizer/Transforms/Passes.h.inc"

std::unique_ptr<mlir::Pass> createAbstractResultOnFuncOptPass();
std::unique_ptr<mlir::Pass> createAbstractResultOnGlobalOptPass();
std::unique_ptr<mlir::Pass> createAffineDemotionPass();
std::unique_ptr<mlir::Pass>
createArrayValueCopyPass(fir::ArrayValueCopyOptions options = {});
Expand Down
12 changes: 2 additions & 10 deletions flang/include/flang/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

include "mlir/Pass/PassBase.td"

class AbstractResultOptBase<string optExt, string operation>
: Pass<"abstract-result-on-" # optExt # "-opt", operation> {
def AbstractResultOpt
: Pass<"abstract-result"> {
let summary = "Convert fir.array, fir.box and fir.rec function result to "
"function argument";
let description = [{
Expand All @@ -35,14 +35,6 @@ class AbstractResultOptBase<string optExt, string operation>
];
}

def AbstractResultOnFuncOpt : AbstractResultOptBase<"func", "mlir::func::FuncOp"> {
let constructor = "::fir::createAbstractResultOnFuncOptPass()";
}

def AbstractResultOnGlobalOpt : AbstractResultOptBase<"global", "fir::GlobalOp"> {
let constructor = "::fir::createAbstractResultOnGlobalOptPass()";
}

def AffineDialectPromotion : Pass<"promote-to-affine", "::mlir::func::FuncOp"> {
let summary = "Promotes `fir.{do_loop,if}` to `affine.{for,if}`.";
let description = [{
Expand Down
28 changes: 25 additions & 3 deletions flang/include/flang/Tools/CLOptions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "flang/Optimizer/Transforms/Passes.h"
#include "llvm/Passes/OptimizationLevel.h"
#include "llvm/Support/CommandLine.h"
#include <type_traits>

#define DisableOption(DOName, DOOption, DODescription) \
static llvm::cl::opt<bool> disable##DOName("disable-" DOOption, \
Expand Down Expand Up @@ -86,6 +87,29 @@ DisableOption(BoxedProcedureRewrite, "boxed-procedure-rewrite",
DisableOption(ExternalNameConversion, "external-name-interop",
"convert names with external convention");

// TODO: remove once these are used for non-codegen passes
#if !defined(FLANG_EXCLUDE_CODEGEN)
using PassConstructor = std::unique_ptr<mlir::Pass>();

template <typename OP>
void addNestedPassToOps(mlir::PassManager &pm, PassConstructor ctor) {
pm.addNestedPass<OP>(ctor());
}

template <typename OP, typename... OPS,
typename = std::enable_if_t<sizeof...(OPS) != 0>>
void addNestedPassToOps(mlir::PassManager &pm, PassConstructor ctor) {
addNestedPassToOps<OP>(pm, ctor);
addNestedPassToOps<OPS...>(pm, ctor);
}

void addNestedPassToAllTopLevelOperations(
mlir::PassManager &pm, PassConstructor ctor) {
addNestedPassToOps<mlir::func::FuncOp, mlir::omp::DeclareReductionOp,
fir::GlobalOp>(pm, ctor);
}
#endif

/// Generic for adding a pass to the pass manager if it is not disabled.
template <typename F>
void addPassConditionally(
Expand Down Expand Up @@ -304,9 +328,7 @@ inline void createDebugPasses(
inline void createDefaultFIRCodeGenPassPipeline(
mlir::PassManager &pm, MLIRToLLVMPassPipelineConfig config) {
fir::addBoxedProcedurePass(pm);
pm.addNestedPass<mlir::func::FuncOp>(
fir::createAbstractResultOnFuncOptPass());
pm.addNestedPass<fir::GlobalOp>(fir::createAbstractResultOnGlobalOptPass());
addNestedPassToAllTopLevelOperations(pm, fir::createAbstractResultOpt);
fir::addCodeGenRewritePass(pm);
fir::addTargetRewritePass(pm);
fir::addExternalNameConversionPass(pm, config.Underscoring);
Expand Down
170 changes: 90 additions & 80 deletions flang/lib/Optimizer/Transforms/AbstractResult.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/TypeSwitch.h"

namespace fir {
#define GEN_PASS_DEF_ABSTRACTRESULTONFUNCOPT
#define GEN_PASS_DEF_ABSTRACTRESULTONGLOBALOPT
#define GEN_PASS_DEF_ABSTRACTRESULTOPT
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

Expand Down Expand Up @@ -285,59 +284,12 @@ class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
bool shouldBoxResult;
};

/// @brief Base CRTP class for AbstractResult pass family.
/// Contains common logic for abstract result conversion in a reusable fashion.
/// @tparam Pass target class that implements operation-specific logic.
/// @tparam PassBase base class template for the pass generated by TableGen.
/// The `Pass` class must define runOnSpecificOperation(OpTy, bool,
/// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function.
/// This function should implement operation-specific functionality.
template <typename Pass, template <typename> class PassBase>
class AbstractResultOptTemplate : public PassBase<Pass> {
class AbstractResultOpt
: public fir::impl::AbstractResultOptBase<AbstractResultOpt> {
public:
void runOnOperation() override {
auto *context = &this->getContext();
auto op = this->getOperation();

mlir::RewritePatternSet patterns(context);
mlir::ConversionTarget target = *context;
const bool shouldBoxResult = this->passResultAsBox.getValue();

auto &self = static_cast<Pass &>(*this);
self.runOnSpecificOperation(op, shouldBoxResult, patterns, target);

// Convert the calls and, if needed, the ReturnOp in the function body.
target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
mlir::func::FuncDialect>();
target.addIllegalOp<fir::SaveResultOp>();
target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
return !hasAbstractResult(call.getFunctionType());
});
target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
return !hasAbstractResult(funTy);
return true;
});
target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
return !hasAbstractResult(dispatch.getFunctionType());
});

patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
patterns.insert<SaveResultOpConversion>(context);
patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
if (mlir::failed(
mlir::applyPartialConversion(op, target, std::move(patterns)))) {
mlir::emitError(op.getLoc(), "error in converting abstract results\n");
this->signalPassFailure();
}
}
};
using fir::impl::AbstractResultOptBase<
AbstractResultOpt>::AbstractResultOptBase;

class AbstractResultOnFuncOpt
: public AbstractResultOptTemplate<AbstractResultOnFuncOpt,
fir::impl::AbstractResultOnFuncOptBase> {
public:
void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
mlir::RewritePatternSet &patterns,
mlir::ConversionTarget &target) {
Expand Down Expand Up @@ -386,40 +338,98 @@ class AbstractResultOnFuncOpt
}
}
}
};

inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
return mlir::TypeSwitch<mlir::Type, bool>(type)
.Case([](fir::BoxProcType boxProc) {
return fir::hasAbstractResult(
boxProc.getEleTy().cast<mlir::FunctionType>());
})
.Case([](fir::PointerType pointer) {
return fir::hasAbstractResult(
pointer.getEleTy().cast<mlir::FunctionType>());
})
.Default([](auto &&) { return false; });
}
inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
return mlir::TypeSwitch<mlir::Type, bool>(type)
.Case([](fir::BoxProcType boxProc) {
return fir::hasAbstractResult(
boxProc.getEleTy().cast<mlir::FunctionType>());
})
.Case([](fir::PointerType pointer) {
return fir::hasAbstractResult(
pointer.getEleTy().cast<mlir::FunctionType>());
})
.Default([](auto &&) { return false; });
}

class AbstractResultOnGlobalOpt
: public AbstractResultOptTemplate<
AbstractResultOnGlobalOpt, fir::impl::AbstractResultOnGlobalOptBase> {
public:
void runOnSpecificOperation(fir::GlobalOp global, bool,
mlir::RewritePatternSet &,
mlir::ConversionTarget &) {
if (containsFunctionTypeWithAbstractResult(global.getType())) {
TODO(global->getLoc(), "support for procedure pointers");
}
}
};
} // end anonymous namespace
} // namespace fir

std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass() {
return std::make_unique<AbstractResultOnFuncOpt>();
}
/// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work.
void runOnModule() {
mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation());

auto pass = std::make_unique<AbstractResultOpt>();
pass->copyOptionValuesFrom(this);
mlir::OpPassManager pipeline;
pipeline.addPass(std::unique_ptr<mlir::Pass>{pass.release()});

// Run the pass on all operations directly nested inside of the ModuleOp
// we can't just call runOnSpecificOperation here because the pass
// implementation only works when scoped to a particular func.func or
// fir.global
for (mlir::Region &region : mod->getRegions()) {
for (mlir::Block &block : region.getBlocks()) {
for (mlir::Operation &op : block.getOperations()) {
if (mlir::failed(runPipeline(pipeline, &op))) {
mlir::emitError(op.getLoc(), "Failed to run abstract result pass");
signalPassFailure();
return;
}
}
}
}
}

std::unique_ptr<mlir::Pass> fir::createAbstractResultOnGlobalOptPass() {
return std::make_unique<AbstractResultOnGlobalOpt>();
}
void runOnOperation() override {
auto *context = &this->getContext();
mlir::Operation *op = this->getOperation();
if (mlir::isa<mlir::ModuleOp>(op)) {
runOnModule();
return;
}

mlir::RewritePatternSet patterns(context);
mlir::ConversionTarget target = *context;
const bool shouldBoxResult = this->passResultAsBox.getValue();

mlir::TypeSwitch<mlir::Operation *, void>(op)
.Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto op) {
runOnSpecificOperation(op, shouldBoxResult, patterns, target);
});

// Convert the calls and, if needed, the ReturnOp in the function body.
target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
mlir::func::FuncDialect>();
target.addIllegalOp<fir::SaveResultOp>();
target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
return !hasAbstractResult(call.getFunctionType());
});
target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
return !hasAbstractResult(funTy);
return true;
});
target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
return !hasAbstractResult(dispatch.getFunctionType());
});

patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
patterns.insert<SaveResultOpConversion>(context);
patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
if (mlir::failed(
mlir::applyPartialConversion(op, target, std::move(patterns)))) {
mlir::emitError(op->getLoc(), "error in converting abstract results\n");
this->signalPassFailure();
}
}
};

} // end anonymous namespace
} // namespace fir
8 changes: 5 additions & 3 deletions flang/test/Driver/mlir-debug-pass-pipeline.f90
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,13 @@
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
! ALL-NEXT: BoxedProcedurePass

! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func']
! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
! ALL-NEXT: 'fir.global' Pipeline
! ALL-NEXT: AbstractResultOnGlobalOpt
! ALL-NEXT: AbstractResultOpt
! ALL-NEXT: 'func.func' Pipeline
! ALL-NEXT: AbstractResultOnFuncOpt
! ALL-NEXT: AbstractResultOpt
! ALL-NEXT: 'omp.declare_reduction' Pipeline
! ALL-NEXT: AbstractResultOpt

! ALL-NEXT: CodeGenRewrite
! ALL-NEXT: (S) 0 num-dce'd - Number of operations eliminated
Expand Down
8 changes: 5 additions & 3 deletions flang/test/Driver/mlir-pass-pipeline.f90
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,13 @@
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
! ALL-NEXT: BoxedProcedurePass

! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func']
! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
! ALL-NEXT: 'fir.global' Pipeline
! ALL-NEXT: AbstractResultOnGlobalOpt
! ALL-NEXT: AbstractResultOpt
! ALL-NEXT: 'func.func' Pipeline
! ALL-NEXT: AbstractResultOnFuncOpt
! ALL-NEXT: AbstractResultOpt
! ALL-NEXT: 'omp.declare_reduction' Pipeline
! ALL-NEXT: AbstractResultOpt

! ALL-NEXT: CodeGenRewrite
! ALL-NEXT: (S) 0 num-dce'd - Number of operations eliminated
Expand Down
2 changes: 1 addition & 1 deletion flang/test/Fir/abstract-result-2.fir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: fir-opt %s --abstract-result-on-func-opt | FileCheck %s
// RUN: fir-opt %s --abstract-result | FileCheck %s

// Check that the attributes are shifted along with their corresponding arguments

Expand Down
8 changes: 4 additions & 4 deletions flang/test/Fir/abstract-results.fir
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// Test rewrite of functions that return fir.array<>, fir.type<>, fir.box<> to
// functions that take an additional argument for the result.

// RUN: fir-opt %s --abstract-result-on-func-opt | FileCheck %s --check-prefix=FUNC-REF
// RUN: fir-opt %s --abstract-result-on-func-opt=abstract-result-as-box | FileCheck %s --check-prefix=FUNC-BOX
// RUN: fir-opt %s --abstract-result-on-global-opt | FileCheck %s --check-prefix=GLOBAL-REF
// RUN: fir-opt %s --abstract-result-on-global-opt=abstract-result-as-box | FileCheck %s --check-prefix=GLOBAL-BOX
// RUN: fir-opt %s --abstract-result | FileCheck %s --check-prefix=FUNC-REF
// RUN: fir-opt %s --abstract-result=abstract-result-as-box | FileCheck %s --check-prefix=FUNC-BOX
// RUN: fir-opt %s --abstract-result | FileCheck %s --check-prefix=GLOBAL-REF
// RUN: fir-opt %s --abstract-result=abstract-result-as-box | FileCheck %s --check-prefix=GLOBAL-BOX

// ----------------------- Test declaration rewrite ----------------------------

Expand Down
8 changes: 5 additions & 3 deletions flang/test/Fir/basic-program.fir
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,13 @@ func.func @_QQmain() {
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
// PASSES-NEXT: BoxedProcedurePass

// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func']
// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
// PASSES-NEXT: 'fir.global' Pipeline
// PASSES-NEXT: AbstractResultOnGlobalOpt
// PASSES-NEXT: AbstractResultOpt
// PASSES-NEXT: 'func.func' Pipeline
// PASSES-NEXT: AbstractResultOnFuncOpt
// PASSES-NEXT: AbstractResultOpt
// PASSES-NEXT: 'omp.declare_reduction' Pipeline
// PASSES-NEXT: AbstractResultOpt

// PASSES-NEXT: CodeGenRewrite
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations eliminated
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
! RUN: %flang_fc1 -emit-mlir %s -o - | FileCheck %s --check-prefix=BEFORE
! RUN: %flang_fc1 -emit-mlir %s -o - | fir-opt --abstract-result-on-global-opt | FileCheck %s --check-prefix=AFTER
! RUN: %flang_fc1 -emit-mlir %s -o - | fir-opt --abstract-result | FileCheck %s --check-prefix=AFTER
module a
type f
contains
Expand Down
Loading