Skip to content

Commit

Permalink
[flang] de-duplicate AbstractResult pass (#88867)
Browse files Browse the repository at this point in the history
This is the first proof of concept of the modification of FIR codegen to
fully support a variety of top level operations (beyond just func.func)
proposed in
https://discourse.llvm.org/t/rfc-add-an-interface-for-top-level-container-operations
  • Loading branch information
tblah committed Apr 22, 2024
1 parent c2d665b commit bfd1944
Show file tree
Hide file tree
Showing 10 changed files with 139 additions and 112 deletions.
5 changes: 1 addition & 4 deletions flang/include/flang/Optimizer/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ namespace fir {
// Passes defined in Passes.td
//===----------------------------------------------------------------------===//

#define GEN_PASS_DECL_ABSTRACTRESULTONFUNCOPT
#define GEN_PASS_DECL_ABSTRACTRESULTONGLOBALOPT
#define GEN_PASS_DECL_ABSTRACTRESULTOPT
#define GEN_PASS_DECL_AFFINEDIALECTPROMOTION
#define GEN_PASS_DECL_AFFINEDIALECTDEMOTION
#define GEN_PASS_DECL_ANNOTATECONSTANTOPERANDS
Expand All @@ -50,8 +49,6 @@ namespace fir {
#define GEN_PASS_DECL_OPENACCDATAOPERANDCONVERSION
#include "flang/Optimizer/Transforms/Passes.h.inc"

std::unique_ptr<mlir::Pass> createAbstractResultOnFuncOptPass();
std::unique_ptr<mlir::Pass> createAbstractResultOnGlobalOptPass();
std::unique_ptr<mlir::Pass> createAffineDemotionPass();
std::unique_ptr<mlir::Pass>
createArrayValueCopyPass(fir::ArrayValueCopyOptions options = {});
Expand Down
12 changes: 2 additions & 10 deletions flang/include/flang/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

include "mlir/Pass/PassBase.td"

class AbstractResultOptBase<string optExt, string operation>
: Pass<"abstract-result-on-" # optExt # "-opt", operation> {
def AbstractResultOpt
: Pass<"abstract-result"> {
let summary = "Convert fir.array, fir.box and fir.rec function result to "
"function argument";
let description = [{
Expand All @@ -35,14 +35,6 @@ class AbstractResultOptBase<string optExt, string operation>
];
}

def AbstractResultOnFuncOpt : AbstractResultOptBase<"func", "mlir::func::FuncOp"> {
let constructor = "::fir::createAbstractResultOnFuncOptPass()";
}

def AbstractResultOnGlobalOpt : AbstractResultOptBase<"global", "fir::GlobalOp"> {
let constructor = "::fir::createAbstractResultOnGlobalOptPass()";
}

def AffineDialectPromotion : Pass<"promote-to-affine", "::mlir::func::FuncOp"> {
let summary = "Promotes `fir.{do_loop,if}` to `affine.{for,if}`.";
let description = [{
Expand Down
28 changes: 25 additions & 3 deletions flang/include/flang/Tools/CLOptions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "flang/Optimizer/Transforms/Passes.h"
#include "llvm/Passes/OptimizationLevel.h"
#include "llvm/Support/CommandLine.h"
#include <type_traits>

#define DisableOption(DOName, DOOption, DODescription) \
static llvm::cl::opt<bool> disable##DOName("disable-" DOOption, \
Expand Down Expand Up @@ -86,6 +87,29 @@ DisableOption(BoxedProcedureRewrite, "boxed-procedure-rewrite",
DisableOption(ExternalNameConversion, "external-name-interop",
"convert names with external convention");

// TODO: remove once these are used for non-codegen passes
#if !defined(FLANG_EXCLUDE_CODEGEN)
using PassConstructor = std::unique_ptr<mlir::Pass>();

template <typename OP>
void addNestedPassToOps(mlir::PassManager &pm, PassConstructor ctor) {
pm.addNestedPass<OP>(ctor());
}

template <typename OP, typename... OPS,
typename = std::enable_if_t<sizeof...(OPS) != 0>>
void addNestedPassToOps(mlir::PassManager &pm, PassConstructor ctor) {
addNestedPassToOps<OP>(pm, ctor);
addNestedPassToOps<OPS...>(pm, ctor);
}

void addNestedPassToAllTopLevelOperations(
mlir::PassManager &pm, PassConstructor ctor) {
addNestedPassToOps<mlir::func::FuncOp, mlir::omp::DeclareReductionOp,
fir::GlobalOp>(pm, ctor);
}
#endif

/// Generic for adding a pass to the pass manager if it is not disabled.
template <typename F>
void addPassConditionally(
Expand Down Expand Up @@ -304,9 +328,7 @@ inline void createDebugPasses(
inline void createDefaultFIRCodeGenPassPipeline(
mlir::PassManager &pm, MLIRToLLVMPassPipelineConfig config) {
fir::addBoxedProcedurePass(pm);
pm.addNestedPass<mlir::func::FuncOp>(
fir::createAbstractResultOnFuncOptPass());
pm.addNestedPass<fir::GlobalOp>(fir::createAbstractResultOnGlobalOptPass());
addNestedPassToAllTopLevelOperations(pm, fir::createAbstractResultOpt);
fir::addCodeGenRewritePass(pm);
fir::addTargetRewritePass(pm);
fir::addExternalNameConversionPass(pm, config.Underscoring);
Expand Down
170 changes: 90 additions & 80 deletions flang/lib/Optimizer/Transforms/AbstractResult.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/TypeSwitch.h"

namespace fir {
#define GEN_PASS_DEF_ABSTRACTRESULTONFUNCOPT
#define GEN_PASS_DEF_ABSTRACTRESULTONGLOBALOPT
#define GEN_PASS_DEF_ABSTRACTRESULTOPT
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

Expand Down Expand Up @@ -285,59 +284,12 @@ class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
bool shouldBoxResult;
};

/// @brief Base CRTP class for AbstractResult pass family.
/// Contains common logic for abstract result conversion in a reusable fashion.
/// @tparam Pass target class that implements operation-specific logic.
/// @tparam PassBase base class template for the pass generated by TableGen.
/// The `Pass` class must define runOnSpecificOperation(OpTy, bool,
/// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function.
/// This function should implement operation-specific functionality.
template <typename Pass, template <typename> class PassBase>
class AbstractResultOptTemplate : public PassBase<Pass> {
class AbstractResultOpt
: public fir::impl::AbstractResultOptBase<AbstractResultOpt> {
public:
void runOnOperation() override {
auto *context = &this->getContext();
auto op = this->getOperation();

mlir::RewritePatternSet patterns(context);
mlir::ConversionTarget target = *context;
const bool shouldBoxResult = this->passResultAsBox.getValue();

auto &self = static_cast<Pass &>(*this);
self.runOnSpecificOperation(op, shouldBoxResult, patterns, target);

// Convert the calls and, if needed, the ReturnOp in the function body.
target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
mlir::func::FuncDialect>();
target.addIllegalOp<fir::SaveResultOp>();
target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
return !hasAbstractResult(call.getFunctionType());
});
target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
return !hasAbstractResult(funTy);
return true;
});
target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
return !hasAbstractResult(dispatch.getFunctionType());
});

patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
patterns.insert<SaveResultOpConversion>(context);
patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
if (mlir::failed(
mlir::applyPartialConversion(op, target, std::move(patterns)))) {
mlir::emitError(op.getLoc(), "error in converting abstract results\n");
this->signalPassFailure();
}
}
};
using fir::impl::AbstractResultOptBase<
AbstractResultOpt>::AbstractResultOptBase;

class AbstractResultOnFuncOpt
: public AbstractResultOptTemplate<AbstractResultOnFuncOpt,
fir::impl::AbstractResultOnFuncOptBase> {
public:
void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
mlir::RewritePatternSet &patterns,
mlir::ConversionTarget &target) {
Expand Down Expand Up @@ -386,40 +338,98 @@ class AbstractResultOnFuncOpt
}
}
}
};

inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
return mlir::TypeSwitch<mlir::Type, bool>(type)
.Case([](fir::BoxProcType boxProc) {
return fir::hasAbstractResult(
boxProc.getEleTy().cast<mlir::FunctionType>());
})
.Case([](fir::PointerType pointer) {
return fir::hasAbstractResult(
pointer.getEleTy().cast<mlir::FunctionType>());
})
.Default([](auto &&) { return false; });
}
inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
return mlir::TypeSwitch<mlir::Type, bool>(type)
.Case([](fir::BoxProcType boxProc) {
return fir::hasAbstractResult(
boxProc.getEleTy().cast<mlir::FunctionType>());
})
.Case([](fir::PointerType pointer) {
return fir::hasAbstractResult(
pointer.getEleTy().cast<mlir::FunctionType>());
})
.Default([](auto &&) { return false; });
}

class AbstractResultOnGlobalOpt
: public AbstractResultOptTemplate<
AbstractResultOnGlobalOpt, fir::impl::AbstractResultOnGlobalOptBase> {
public:
void runOnSpecificOperation(fir::GlobalOp global, bool,
mlir::RewritePatternSet &,
mlir::ConversionTarget &) {
if (containsFunctionTypeWithAbstractResult(global.getType())) {
TODO(global->getLoc(), "support for procedure pointers");
}
}
};
} // end anonymous namespace
} // namespace fir

std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass() {
return std::make_unique<AbstractResultOnFuncOpt>();
}
/// Run the pass on a ModuleOp. This makes fir-opt --abstract-result work.
void runOnModule() {
mlir::ModuleOp mod = mlir::cast<mlir::ModuleOp>(getOperation());

auto pass = std::make_unique<AbstractResultOpt>();
pass->copyOptionValuesFrom(this);
mlir::OpPassManager pipeline;
pipeline.addPass(std::unique_ptr<mlir::Pass>{pass.release()});

// Run the pass on all operations directly nested inside of the ModuleOp
// we can't just call runOnSpecificOperation here because the pass
// implementation only works when scoped to a particular func.func or
// fir.global
for (mlir::Region &region : mod->getRegions()) {
for (mlir::Block &block : region.getBlocks()) {
for (mlir::Operation &op : block.getOperations()) {
if (mlir::failed(runPipeline(pipeline, &op))) {
mlir::emitError(op.getLoc(), "Failed to run abstract result pass");
signalPassFailure();
return;
}
}
}
}
}

std::unique_ptr<mlir::Pass> fir::createAbstractResultOnGlobalOptPass() {
return std::make_unique<AbstractResultOnGlobalOpt>();
}
void runOnOperation() override {
auto *context = &this->getContext();
mlir::Operation *op = this->getOperation();
if (mlir::isa<mlir::ModuleOp>(op)) {
runOnModule();
return;
}

mlir::RewritePatternSet patterns(context);
mlir::ConversionTarget target = *context;
const bool shouldBoxResult = this->passResultAsBox.getValue();

mlir::TypeSwitch<mlir::Operation *, void>(op)
.Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto op) {
runOnSpecificOperation(op, shouldBoxResult, patterns, target);
});

// Convert the calls and, if needed, the ReturnOp in the function body.
target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
mlir::func::FuncDialect>();
target.addIllegalOp<fir::SaveResultOp>();
target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
return !hasAbstractResult(call.getFunctionType());
});
target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
return !hasAbstractResult(funTy);
return true;
});
target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
return !hasAbstractResult(dispatch.getFunctionType());
});

patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
patterns.insert<SaveResultOpConversion>(context);
patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
if (mlir::failed(
mlir::applyPartialConversion(op, target, std::move(patterns)))) {
mlir::emitError(op->getLoc(), "error in converting abstract results\n");
this->signalPassFailure();
}
}
};

} // end anonymous namespace
} // namespace fir
8 changes: 5 additions & 3 deletions flang/test/Driver/mlir-debug-pass-pipeline.f90
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,13 @@
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
! ALL-NEXT: BoxedProcedurePass

! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func']
! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
! ALL-NEXT: 'fir.global' Pipeline
! ALL-NEXT: AbstractResultOnGlobalOpt
! ALL-NEXT: AbstractResultOpt
! ALL-NEXT: 'func.func' Pipeline
! ALL-NEXT: AbstractResultOnFuncOpt
! ALL-NEXT: AbstractResultOpt
! ALL-NEXT: 'omp.declare_reduction' Pipeline
! ALL-NEXT: AbstractResultOpt

! ALL-NEXT: CodeGenRewrite
! ALL-NEXT: (S) 0 num-dce'd - Number of operations eliminated
Expand Down
8 changes: 5 additions & 3 deletions flang/test/Driver/mlir-pass-pipeline.f90
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,13 @@
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
! ALL-NEXT: BoxedProcedurePass

! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func']
! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
! ALL-NEXT: 'fir.global' Pipeline
! ALL-NEXT: AbstractResultOnGlobalOpt
! ALL-NEXT: AbstractResultOpt
! ALL-NEXT: 'func.func' Pipeline
! ALL-NEXT: AbstractResultOnFuncOpt
! ALL-NEXT: AbstractResultOpt
! ALL-NEXT: 'omp.declare_reduction' Pipeline
! ALL-NEXT: AbstractResultOpt

! ALL-NEXT: CodeGenRewrite
! ALL-NEXT: (S) 0 num-dce'd - Number of operations eliminated
Expand Down
2 changes: 1 addition & 1 deletion flang/test/Fir/abstract-result-2.fir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: fir-opt %s --abstract-result-on-func-opt | FileCheck %s
// RUN: fir-opt %s --abstract-result | FileCheck %s

// Check that the attributes are shifted along with their corresponding arguments

Expand Down
8 changes: 4 additions & 4 deletions flang/test/Fir/abstract-results.fir
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// Test rewrite of functions that return fir.array<>, fir.type<>, fir.box<> to
// functions that take an additional argument for the result.

// RUN: fir-opt %s --abstract-result-on-func-opt | FileCheck %s --check-prefix=FUNC-REF
// RUN: fir-opt %s --abstract-result-on-func-opt=abstract-result-as-box | FileCheck %s --check-prefix=FUNC-BOX
// RUN: fir-opt %s --abstract-result-on-global-opt | FileCheck %s --check-prefix=GLOBAL-REF
// RUN: fir-opt %s --abstract-result-on-global-opt=abstract-result-as-box | FileCheck %s --check-prefix=GLOBAL-BOX
// RUN: fir-opt %s --abstract-result | FileCheck %s --check-prefix=FUNC-REF
// RUN: fir-opt %s --abstract-result=abstract-result-as-box | FileCheck %s --check-prefix=FUNC-BOX
// RUN: fir-opt %s --abstract-result | FileCheck %s --check-prefix=GLOBAL-REF
// RUN: fir-opt %s --abstract-result=abstract-result-as-box | FileCheck %s --check-prefix=GLOBAL-BOX

// ----------------------- Test declaration rewrite ----------------------------

Expand Down
8 changes: 5 additions & 3 deletions flang/test/Fir/basic-program.fir
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,13 @@ func.func @_QQmain() {
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
// PASSES-NEXT: BoxedProcedurePass

// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func']
// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
// PASSES-NEXT: 'fir.global' Pipeline
// PASSES-NEXT: AbstractResultOnGlobalOpt
// PASSES-NEXT: AbstractResultOpt
// PASSES-NEXT: 'func.func' Pipeline
// PASSES-NEXT: AbstractResultOnFuncOpt
// PASSES-NEXT: AbstractResultOpt
// PASSES-NEXT: 'omp.declare_reduction' Pipeline
// PASSES-NEXT: AbstractResultOpt

// PASSES-NEXT: CodeGenRewrite
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations eliminated
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
! RUN: %flang_fc1 -emit-mlir %s -o - | FileCheck %s --check-prefix=BEFORE
! RUN: %flang_fc1 -emit-mlir %s -o - | fir-opt --abstract-result-on-global-opt | FileCheck %s --check-prefix=AFTER
! RUN: %flang_fc1 -emit-mlir %s -o - | fir-opt --abstract-result | FileCheck %s --check-prefix=AFTER
module a
type f
contains
Expand Down

0 comments on commit bfd1944

Please sign in to comment.