Skip to content

Commit

Permalink
[flang] use greedy mlir driver for stack arrays pass
Browse files Browse the repository at this point in the history
In upstream mlir, the dialect conversion infrastructure is used for
lowering from one dialect to another: the passes are of the form
XToYPass. Whereas, transformations within the same dialect tend to use
applyPatternsAndFoldGreedily.

In this case, the full complexity of applyPatternsAndFoldGreedily isn't
needed so we can get away with the simpler applyOpPatternsAndFold.

This change was suggested by @jeanPerier

Differential Revision: https://reviews.llvm.org/D150853
  • Loading branch information
tblah committed May 23, 2023
1 parent 111d274 commit 74c2ec5
Showing 1 changed file with 34 additions and 41 deletions.
75 changes: 34 additions & 41 deletions flang/lib/Optimizer/Transforms/StackArrays.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
Expand Down Expand Up @@ -167,25 +167,22 @@ class StackArraysAnalysisWrapper {

StackArraysAnalysisWrapper(mlir::Operation *op) {}

bool hasErrors() const;

const AllocMemMap &getCandidateOps(mlir::Operation *func);
// returns nullptr if analysis failed
const AllocMemMap *getCandidateOps(mlir::Operation *func);

private:
llvm::DenseMap<mlir::Operation *, AllocMemMap> funcMaps;
bool gotError = false;

void analyseFunction(mlir::Operation *func);
mlir::LogicalResult analyseFunction(mlir::Operation *func);
};

/// Converts a fir.allocmem to a fir.alloca
class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
public:
using OpRewritePattern::OpRewritePattern;

AllocMemConversion(
explicit AllocMemConversion(
mlir::MLIRContext *ctx,
const llvm::DenseMap<mlir::Operation *, InsertionPoint> &candidateOps);
const StackArraysAnalysisWrapper::AllocMemMap &candidateOps)
: OpRewritePattern(ctx), candidateOps{candidateOps} {}

mlir::LogicalResult
matchAndRewrite(fir::AllocMemOp allocmem,
Expand All @@ -196,9 +193,8 @@ class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
static InsertionPoint findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc);

private:
/// allocmem operations that DFA has determined are safe to move to the stack
/// mapping to where to insert replacement freemem operations
const llvm::DenseMap<mlir::Operation *, InsertionPoint> &candidateOps;
/// Handle to the DFA (already run)
const StackArraysAnalysisWrapper::AllocMemMap &candidateOps;

/// If we failed to find an insertion point not inside a loop, see if it would
/// be safe to use an llvm.stacksave/llvm.stackrestore inside the loop
Expand Down Expand Up @@ -412,7 +408,8 @@ void AllocationAnalysis::processOperation(mlir::Operation *op) {
visitOperationImpl(op, *before, after);
}

void StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
mlir::LogicalResult
StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
assert(mlir::isa<mlir::func::FuncOp>(func));
mlir::DataFlowSolver solver;
// constant propagation is required for dead code analysis, dead code analysis
Expand All @@ -426,8 +423,7 @@ void StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
solver.load<AllocationAnalysis>();
if (failed(solver.initializeAndRun(func))) {
llvm::errs() << "DataFlowSolver failed!";
gotError = true;
return;
return mlir::failure();
}

LatticePoint point{func};
Expand Down Expand Up @@ -458,22 +454,17 @@ void StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
: candidateOps) {
llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n';
});
return mlir::success();
}

bool StackArraysAnalysisWrapper::hasErrors() const { return gotError; }

const StackArraysAnalysisWrapper::AllocMemMap &
const StackArraysAnalysisWrapper::AllocMemMap *
StackArraysAnalysisWrapper::getCandidateOps(mlir::Operation *func) {
if (!funcMaps.count(func))
analyseFunction(func);
return funcMaps[func];
if (!funcMaps.contains(func))
if (mlir::failed(analyseFunction(func)))
return nullptr;
return &funcMaps[func];
}

AllocMemConversion::AllocMemConversion(
mlir::MLIRContext *ctx,
const llvm::DenseMap<mlir::Operation *, InsertionPoint> &candidateOps)
: OpRewritePattern(ctx), candidateOps(candidateOps) {}

mlir::LogicalResult
AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
mlir::PatternRewriter &rewriter) const {
Expand Down Expand Up @@ -709,29 +700,31 @@ void StackArraysPass::runOnFunc(mlir::Operation *func) {
assert(mlir::isa<mlir::func::FuncOp>(func));

auto &analysis = getAnalysis<StackArraysAnalysisWrapper>();
const auto &candidateOps = analysis.getCandidateOps(func);
if (analysis.hasErrors()) {
const StackArraysAnalysisWrapper::AllocMemMap *candidateOps =
analysis.getCandidateOps(func);
if (!candidateOps) {
signalPassFailure();
return;
}

if (candidateOps.empty())
if (candidateOps->empty())
return;
runCount += candidateOps.size();
runCount += candidateOps->size();

llvm::SmallVector<mlir::Operation *> opsToConvert;
opsToConvert.reserve(candidateOps->size());
for (auto [op, _] : *candidateOps)
opsToConvert.push_back(op);

mlir::MLIRContext &context = getContext();
mlir::RewritePatternSet patterns(&context);
mlir::ConversionTarget target(context);

target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
mlir::func::FuncDialect>();
target.addDynamicallyLegalOp<fir::AllocMemOp>([&](fir::AllocMemOp alloc) {
return !candidateOps.count(alloc.getOperation());
});
mlir::GreedyRewriteConfig config;
// prevent the pattern driver form merging blocks
config.enableRegionSimplification = false;

patterns.insert<AllocMemConversion>(&context, candidateOps);
if (mlir::failed(
mlir::applyPartialConversion(func, target, std::move(patterns)))) {
patterns.insert<AllocMemConversion>(&context, *candidateOps);
if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert,
std::move(patterns), config))) {
mlir::emitError(func->getLoc(), "error in stack arrays optimization\n");
signalPassFailure();
}
Expand Down

0 comments on commit 74c2ec5

Please sign in to comment.