Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang] de-duplicate AbstractResult pass #88867

Merged
merged 4 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,13 @@ inline mlir::NamedAttribute getAdaptToByRefAttr(Builder &builder) {
builder.getUnitAttr()};
}

/// Returns true if the operation name is for a container operation expected to
/// contain (HL)FIR operations which need to be lowered by FIR passes. The
/// simplest example of this is func.func.
/// This operates on mlir::RegisteredOperationName so that it can be used to
/// implement mlir::Pass::canScheduleOn.
bool isa_toplevel(mlir::RegisteredOperationName opName);

} // namespace fir

#endif // FORTRAN_OPTIMIZER_DIALECT_FIROPSSUPPORT_H
6 changes: 2 additions & 4 deletions flang/include/flang/Optimizer/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ namespace fir {
// Passes defined in Passes.td
//===----------------------------------------------------------------------===//

#define GEN_PASS_DECL_ABSTRACTRESULTONFUNCOPT
#define GEN_PASS_DECL_ABSTRACTRESULTONGLOBALOPT
#define GEN_PASS_DECL_ABSTRACTRESULTOPT
#define GEN_PASS_DECL_AFFINEDIALECTPROMOTION
#define GEN_PASS_DECL_AFFINEDIALECTDEMOTION
#define GEN_PASS_DECL_ANNOTATECONSTANTOPERANDS
Expand All @@ -50,8 +49,7 @@ namespace fir {
#define GEN_PASS_DECL_OPENACCDATAOPERANDCONVERSION
#include "flang/Optimizer/Transforms/Passes.h.inc"

std::unique_ptr<mlir::Pass> createAbstractResultOnFuncOptPass();
std::unique_ptr<mlir::Pass> createAbstractResultOnGlobalOptPass();
std::unique_ptr<mlir::Pass> createAbstractResultOptPass();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This declaration can be auto-generated by TableGen by just removing the
let constructor = ... line in the Pass entry. This will also automatically create an overload that can be constructed with the option object.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this. I'll update the flang passes as I go along

std::unique_ptr<mlir::Pass> createAffineDemotionPass();
std::unique_ptr<mlir::Pass>
createArrayValueCopyPass(fir::ArrayValueCopyOptions options = {});
Expand Down
13 changes: 3 additions & 10 deletions flang/include/flang/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

include "mlir/Pass/PassBase.td"

class AbstractResultOptBase<string optExt, string operation>
: Pass<"abstract-result-on-" # optExt # "-opt", operation> {
def AbstractResultOpt
: Pass<"abstract-result"> {
let summary = "Convert fir.array, fir.box and fir.rec function result to "
"function argument";
let description = [{
Expand All @@ -33,14 +33,7 @@ class AbstractResultOptBase<string optExt, string operation>
"Pass fir.array<T> result as fir.box<fir.array<T>> argument instead"
" of fir.ref<fir.array<T>>.">
];
}

def AbstractResultOnFuncOpt : AbstractResultOptBase<"func", "mlir::func::FuncOp"> {
let constructor = "::fir::createAbstractResultOnFuncOptPass()";
}

def AbstractResultOnGlobalOpt : AbstractResultOptBase<"global", "fir::GlobalOp"> {
let constructor = "::fir::createAbstractResultOnGlobalOptPass()";
let constructor = "::fir::createAbstractResultOptPass()";
}

def AffineDialectPromotion : Pass<"promote-to-affine", "::mlir::func::FuncOp"> {
Expand Down
30 changes: 27 additions & 3 deletions flang/include/flang/Tools/CLOptions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "flang/Optimizer/Transforms/Passes.h"
#include "llvm/Passes/OptimizationLevel.h"
#include "llvm/Support/CommandLine.h"
#include <type_traits>

#define DisableOption(DOName, DOOption, DODescription) \
static llvm::cl::opt<bool> disable##DOName("disable-" DOOption, \
Expand Down Expand Up @@ -86,6 +87,31 @@ DisableOption(BoxedProcedureRewrite, "boxed-procedure-rewrite",
DisableOption(ExternalNameConversion, "external-name-interop",
"convert names with external convention");

// TODO: remove once these are used for non-codegen passes
#if !defined(FLANG_EXCLUDE_CODEGEN)
using PassConstructor = std::function<std::unique_ptr<mlir::Pass>()>;
tblah marked this conversation as resolved.
Show resolved Hide resolved

template <typename OP>
void addNestedPassToOps(mlir::PassManager &pm, PassConstructor ctor) {
pm.addNestedPass<OP>(ctor());
}

template <typename OP, typename... OPS,
typename = std::enable_if_t<sizeof...(OPS) != 0>>
void addNestedPassToOps(mlir::PassManager &pm, PassConstructor ctor) {
addNestedPassToOps<OP>(pm, ctor);
addNestedPassToOps<OPS...>(pm, ctor);
}

void addNestedPassToAllTopLevelOperations(
mlir::PassManager &pm, PassConstructor ctor) {
// TODO: add more operations that might need full lowering support
// any operations also need to be added to fir::isa_toplevel
addNestedPassToOps<mlir::func::FuncOp, mlir::omp::DeclareReductionOp,
fir::GlobalOp>(pm, ctor);
}
#endif

/// Generic for adding a pass to the pass manager if it is not disabled.
template <typename F>
void addPassConditionally(
Expand Down Expand Up @@ -304,9 +330,7 @@ inline void createDebugPasses(
inline void createDefaultFIRCodeGenPassPipeline(
mlir::PassManager &pm, MLIRToLLVMPassPipelineConfig config) {
fir::addBoxedProcedurePass(pm);
pm.addNestedPass<mlir::func::FuncOp>(
fir::createAbstractResultOnFuncOptPass());
pm.addNestedPass<fir::GlobalOp>(fir::createAbstractResultOnGlobalOptPass());
addNestedPassToAllTopLevelOperations(pm, fir::createAbstractResultOptPass);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't using pm.nestAny(fir::createAbstractResultOptPass()) work?

It seems to me that the using both addNestedPass<OP> and canScheduleOn to restrict what this pass is being run on is a bit redundant (and probably less optimal it it means that the ModuleOp operations need to be walked three times to schedule the passes instead of scheduling it in a single pass, although I do not know the pass scheduling details enough to be sure how the module would be walked in both cases).

Although, maybe the reverse is better from a conceptual point of view: canScheduleOn could be removed from the pass definition. There is no conceptual aspects of the pass that restrict it from running on any operation that may contain FIR calls I think, so it would make sense to me that the pipeline is the only place describing which top level operations needs to be translated here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for review.

I think the idea is that addNested<OP> could work on a subset of the operations supported by the pass (reported by canBeScheduled). Conceptually, addNested<OP> means you want to run the pass on this operation in this particular pipeline. But a different pipeline could be constructed which tries to use the pass on a different operation type (e.g. fir-opt --abstract-result module.mlir). canScheduleOn guards against running pipelines on operation types which the pass is not intended for.

We have to implement canScheduleOn because it is pure virtual in mlir::Pass. Even if there were a way to tell the pass manager to run this pass on every operation on which the pass is supported, this would have to be implemented by calling canScheduleOn with every operation and then only scheduling the pass on supported operations. Unfortunately, canScheduleOn is implemented with string comparisons (it is defined as always taking a RegisteredOperationName argument) so I would prefer the bit of duplication so that we can limit these string comparisons.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

canScheduleOn guards against running pipelines on operation types which the pass is not intended for.

In that case, I think it should be runnable on any op (the ReturnOp handling prevents that currently since you noticed it does not work on ModuleOp).

We have to implement canScheduleOn because it is pure virtual in mlir::Pass.

You do not need to in that case because AbstractResultOptPass actually inherits from mlir::OperationPass<> that defines canScheduleOn here in a way that it would say true to any operation it is being schedule on. The .td Pass<> syntax actually creates a pass inheriting from mlir::OperationPass<>, not directly from mlir::Pass (see here).

So all in all, I am OK with your patch, it is an improvement, and it could be further improved by modifying the ReturnOp handling and removing the canScheduleOn restriction.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh I had missunderstood the tablegen change I made. I thought it made it inherit mlir::Pass, but you are right it is mlir::OperationPass<>.

Thanks for explaining. I will see if I can create a nested pass pipeline inside the AbstractResultPass so that even when run on a module it behaves the same way as if you ran the old function and global passes. And I agree that the canScheduleOn should be removed if possible.

fir::addCodeGenRewritePass(pm);
fir::addTargetRewritePass(pm);
fir::addExternalNameConversionPass(pm, config.Underscoring);
Expand Down
13 changes: 13 additions & 0 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3846,6 +3846,19 @@ std::optional<std::int64_t> fir::getIntIfConstant(mlir::Value value) {
return {};
}

bool fir::isa_toplevel(mlir::RegisteredOperationName opName) {
const std::initializer_list<llvm::StringLiteral> topLevelOps{
fir::GlobalOp::getOperationName(),
mlir::func::FuncOp::getOperationName(),
mlir::omp::DeclareReductionOp::getOperationName(),
};

llvm::StringRef opStr = opName.getStringRef();
return llvm::any_of(topLevelOps, [&](const llvm::StringRef &topLevelOp) {
return opStr == topLevelOp;
});
}

mlir::Type fir::applyPathToType(mlir::Type eleTy, mlir::ValueRange path) {
for (auto i = path.begin(), end = path.end(); eleTy && i < end;) {
eleTy = llvm::TypeSwitch<mlir::Type, mlir::Type>(eleTy)
Expand Down
139 changes: 62 additions & 77 deletions flang/lib/Optimizer/Transforms/AbstractResult.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
#include "llvm/ADT/TypeSwitch.h"

namespace fir {
#define GEN_PASS_DEF_ABSTRACTRESULTONFUNCOPT
#define GEN_PASS_DEF_ABSTRACTRESULTONGLOBALOPT
#define GEN_PASS_DEF_ABSTRACTRESULTOPT
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

Expand Down Expand Up @@ -285,58 +284,8 @@ class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
bool shouldBoxResult;
};

/// @brief Base CRTP class for AbstractResult pass family.
/// Contains common logic for abstract result conversion in a reusable fashion.
/// @tparam Pass target class that implements operation-specific logic.
/// @tparam PassBase base class template for the pass generated by TableGen.
/// The `Pass` class must define runOnSpecificOperation(OpTy, bool,
/// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function.
/// This function should implement operation-specific functionality.
template <typename Pass, template <typename> class PassBase>
class AbstractResultOptTemplate : public PassBase<Pass> {
public:
void runOnOperation() override {
auto *context = &this->getContext();
auto op = this->getOperation();

mlir::RewritePatternSet patterns(context);
mlir::ConversionTarget target = *context;
const bool shouldBoxResult = this->passResultAsBox.getValue();

auto &self = static_cast<Pass &>(*this);
self.runOnSpecificOperation(op, shouldBoxResult, patterns, target);

// Convert the calls and, if needed, the ReturnOp in the function body.
target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
mlir::func::FuncDialect>();
target.addIllegalOp<fir::SaveResultOp>();
target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
return !hasAbstractResult(call.getFunctionType());
});
target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
return !hasAbstractResult(funTy);
return true;
});
target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
return !hasAbstractResult(dispatch.getFunctionType());
});

patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
patterns.insert<SaveResultOpConversion>(context);
patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
if (mlir::failed(
mlir::applyPartialConversion(op, target, std::move(patterns)))) {
mlir::emitError(op.getLoc(), "error in converting abstract results\n");
this->signalPassFailure();
}
}
};

class AbstractResultOnFuncOpt
: public AbstractResultOptTemplate<AbstractResultOnFuncOpt,
fir::impl::AbstractResultOnFuncOptBase> {
class AbstractResultOpt
: public fir::impl::AbstractResultOptBase<AbstractResultOpt> {
public:
void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
mlir::RewritePatternSet &patterns,
Expand Down Expand Up @@ -386,40 +335,76 @@ class AbstractResultOnFuncOpt
}
}
}
};

inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
return mlir::TypeSwitch<mlir::Type, bool>(type)
.Case([](fir::BoxProcType boxProc) {
return fir::hasAbstractResult(
boxProc.getEleTy().cast<mlir::FunctionType>());
})
.Case([](fir::PointerType pointer) {
return fir::hasAbstractResult(
pointer.getEleTy().cast<mlir::FunctionType>());
})
.Default([](auto &&) { return false; });
}
inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
return mlir::TypeSwitch<mlir::Type, bool>(type)
.Case([](fir::BoxProcType boxProc) {
return fir::hasAbstractResult(
boxProc.getEleTy().cast<mlir::FunctionType>());
})
.Case([](fir::PointerType pointer) {
return fir::hasAbstractResult(
pointer.getEleTy().cast<mlir::FunctionType>());
})
.Default([](auto &&) { return false; });
}

class AbstractResultOnGlobalOpt
: public AbstractResultOptTemplate<
AbstractResultOnGlobalOpt, fir::impl::AbstractResultOnGlobalOptBase> {
public:
void runOnSpecificOperation(fir::GlobalOp global, bool,
mlir::RewritePatternSet &,
mlir::ConversionTarget &) {
if (containsFunctionTypeWithAbstractResult(global.getType())) {
TODO(global->getLoc(), "support for procedure pointers");
}
}

virtual bool canScheduleOn(RegisteredOperationName opName) const override {
return fir::isa_toplevel(opName);
}

void runOnOperation() override {
auto *context = &this->getContext();
mlir::Operation *op = this->getOperation();

mlir::RewritePatternSet patterns(context);
mlir::ConversionTarget target = *context;
const bool shouldBoxResult = this->passResultAsBox.getValue();

mlir::TypeSwitch<mlir::Operation *, void>(op)
.Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto op) {
runOnSpecificOperation(op, shouldBoxResult, patterns, target);
});

// Convert the calls and, if needed, the ReturnOp in the function body.
target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
mlir::func::FuncDialect>();
target.addIllegalOp<fir::SaveResultOp>();
target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
return !hasAbstractResult(call.getFunctionType());
});
target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
return !hasAbstractResult(funTy);
return true;
});
target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
return !hasAbstractResult(dispatch.getFunctionType());
});

patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
patterns.insert<SaveResultOpConversion>(context);
patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
if (mlir::failed(
mlir::applyPartialConversion(op, target, std::move(patterns)))) {
mlir::emitError(op->getLoc(), "error in converting abstract results\n");
this->signalPassFailure();
}
}
};

} // end anonymous namespace
} // namespace fir

std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass() {
return std::make_unique<AbstractResultOnFuncOpt>();
}

std::unique_ptr<mlir::Pass> fir::createAbstractResultOnGlobalOptPass() {
return std::make_unique<AbstractResultOnGlobalOpt>();
std::unique_ptr<mlir::Pass> fir::createAbstractResultOptPass() {
return std::make_unique<AbstractResultOpt>();
}
8 changes: 5 additions & 3 deletions flang/test/Driver/mlir-debug-pass-pipeline.f90
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,13 @@
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
! ALL-NEXT: BoxedProcedurePass

! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func']
! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
! ALL-NEXT: 'fir.global' Pipeline
! ALL-NEXT: AbstractResultOnGlobalOpt
! ALL-NEXT: AbstractResultOpt
! ALL-NEXT: 'func.func' Pipeline
! ALL-NEXT: AbstractResultOnFuncOpt
! ALL-NEXT: AbstractResultOpt
! ALL-NEXT: 'omp.declare_reduction' Pipeline
! ALL-NEXT: AbstractResultOpt

! ALL-NEXT: CodeGenRewrite
! ALL-NEXT: (S) 0 num-dce'd - Number of operations eliminated
Expand Down
8 changes: 5 additions & 3 deletions flang/test/Driver/mlir-pass-pipeline.f90
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,13 @@
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
! ALL-NEXT: BoxedProcedurePass

! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func']
! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
! ALL-NEXT: 'fir.global' Pipeline
! ALL-NEXT: AbstractResultOnGlobalOpt
! ALL-NEXT: AbstractResultOpt
! ALL-NEXT: 'func.func' Pipeline
! ALL-NEXT: AbstractResultOnFuncOpt
! ALL-NEXT: AbstractResultOpt
! ALL-NEXT: 'omp.declare_reduction' Pipeline
! ALL-NEXT: AbstractResultOpt

! ALL-NEXT: CodeGenRewrite
! ALL-NEXT: (S) 0 num-dce'd - Number of operations eliminated
Expand Down
2 changes: 1 addition & 1 deletion flang/test/Fir/abstract-result-2.fir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: fir-opt %s --abstract-result-on-func-opt | FileCheck %s
// RUN: fir-opt %s --pass-pipeline="builtin.module(func.func(abstract-result))" | FileCheck %s

// Check that the attributes are shifted along with their corresponding arguments

Expand Down
8 changes: 4 additions & 4 deletions flang/test/Fir/abstract-results.fir
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// Test rewrite of functions that return fir.array<>, fir.type<>, fir.box<> to
// functions that take an additional argument for the result.

// RUN: fir-opt %s --abstract-result-on-func-opt | FileCheck %s --check-prefix=FUNC-REF
// RUN: fir-opt %s --abstract-result-on-func-opt=abstract-result-as-box | FileCheck %s --check-prefix=FUNC-BOX
// RUN: fir-opt %s --abstract-result-on-global-opt | FileCheck %s --check-prefix=GLOBAL-REF
// RUN: fir-opt %s --abstract-result-on-global-opt=abstract-result-as-box | FileCheck %s --check-prefix=GLOBAL-BOX
// RUN: fir-opt %s --pass-pipeline="builtin.module(func.func(abstract-result))" | FileCheck %s --check-prefix=FUNC-REF
// RUN: fir-opt %s --pass-pipeline="builtin.module(func.func(abstract-result{abstract-result-as-box}))" | FileCheck %s --check-prefix=FUNC-BOX
// RUN: fir-opt %s --pass-pipeline="builtin.module(fir.global(abstract-result))" | FileCheck %s --check-prefix=GLOBAL-REF
// RUN: fir-opt %s --pass-pipeline="builtin.module(fir.global(abstract-result{abstract-result-as-box}))" | FileCheck %s --check-prefix=GLOBAL-BOX

// ----------------------- Test declaration rewrite ----------------------------

Expand Down
8 changes: 5 additions & 3 deletions flang/test/Fir/basic-program.fir
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,13 @@ func.func @_QQmain() {
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
// PASSES-NEXT: BoxedProcedurePass

// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func']
// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
// PASSES-NEXT: 'fir.global' Pipeline
// PASSES-NEXT: AbstractResultOnGlobalOpt
// PASSES-NEXT: AbstractResultOpt
// PASSES-NEXT: 'func.func' Pipeline
// PASSES-NEXT: AbstractResultOnFuncOpt
// PASSES-NEXT: AbstractResultOpt
// PASSES-NEXT: 'omp.declare_reduction' Pipeline
// PASSES-NEXT: AbstractResultOpt

// PASSES-NEXT: CodeGenRewrite
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations eliminated
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
! RUN: %flang_fc1 -emit-mlir %s -o - | FileCheck %s --check-prefix=BEFORE
! RUN: %flang_fc1 -emit-mlir %s -o - | fir-opt --abstract-result-on-global-opt | FileCheck %s --check-prefix=AFTER
! RUN: %flang_fc1 -emit-mlir %s -o - | fir-opt --pass-pipeline="builtin.module(fir.global(abstract-result))" | FileCheck %s --check-prefix=AFTER
module a
type f
contains
Expand Down
Loading