Skip to content

Commit

Permalink
[flang] use greedy mlir driver for stack arrays pass
Browse files Browse the repository at this point in the history
In upstream mlir, the dialect conversion infrastructure is used for
lowering from one dialect to another: the passes are of the form
XToYPass, whereas transformations within the same dialect tend to use
applyPatternsAndFoldGreedily.

In this case, the full complexity of applyPatternsAndFoldGreedily isn't
needed so we can get away with the simpler applyOpPatternsAndFold.

This change was suggested by @jeanPerier

The old differential revision for this patch was
https://reviews.llvm.org/D150853

Re-applying here with a fix for the issue which led to the patch being reverted. The
issue was caused by erasing uses of the allocation operation while still iterating
over those uses (leading to a use-after-free). I have added a regression
test which catches this bug for -fsanitize=address builds, but it is
hard to reliably cause a crash from the use-after-free in normal builds.

Differential Revision: https://reviews.llvm.org/D151728
  • Loading branch information
tblah committed May 31, 2023
1 parent 5437056 commit 408f419
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 42 deletions.
81 changes: 39 additions & 42 deletions flang/lib/Optimizer/Transforms/StackArrays.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
Expand Down Expand Up @@ -167,25 +167,22 @@ class StackArraysAnalysisWrapper {

StackArraysAnalysisWrapper(mlir::Operation *op) {}

bool hasErrors() const;

const AllocMemMap &getCandidateOps(mlir::Operation *func);
// returns nullptr if analysis failed
const AllocMemMap *getCandidateOps(mlir::Operation *func);

private:
llvm::DenseMap<mlir::Operation *, AllocMemMap> funcMaps;
bool gotError = false;

void analyseFunction(mlir::Operation *func);
mlir::LogicalResult analyseFunction(mlir::Operation *func);
};

/// Converts a fir.allocmem to a fir.alloca
class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
public:
using OpRewritePattern::OpRewritePattern;

AllocMemConversion(
explicit AllocMemConversion(
mlir::MLIRContext *ctx,
const llvm::DenseMap<mlir::Operation *, InsertionPoint> &candidateOps);
const StackArraysAnalysisWrapper::AllocMemMap &candidateOps)
: OpRewritePattern(ctx), candidateOps{candidateOps} {}

mlir::LogicalResult
matchAndRewrite(fir::AllocMemOp allocmem,
Expand All @@ -196,9 +193,8 @@ class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
static InsertionPoint findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc);

private:
/// allocmem operations that DFA has determined are safe to move to the stack
/// mapping to where to insert replacement freemem operations
const llvm::DenseMap<mlir::Operation *, InsertionPoint> &candidateOps;
/// Handle to the DFA (already run)
const StackArraysAnalysisWrapper::AllocMemMap &candidateOps;

/// If we failed to find an insertion point not inside a loop, see if it would
/// be safe to use an llvm.stacksave/llvm.stackrestore inside the loop
Expand Down Expand Up @@ -412,7 +408,8 @@ void AllocationAnalysis::processOperation(mlir::Operation *op) {
visitOperationImpl(op, *before, after);
}

void StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
mlir::LogicalResult
StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
assert(mlir::isa<mlir::func::FuncOp>(func));
mlir::DataFlowSolver solver;
// constant propagation is required for dead code analysis, dead code analysis
Expand All @@ -426,8 +423,7 @@ void StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
solver.load<AllocationAnalysis>();
if (failed(solver.initializeAndRun(func))) {
llvm::errs() << "DataFlowSolver failed!";
gotError = true;
return;
return mlir::failure();
}

LatticePoint point{func};
Expand Down Expand Up @@ -458,22 +454,17 @@ void StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
: candidateOps) {
llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n';
});
return mlir::success();
}

bool StackArraysAnalysisWrapper::hasErrors() const { return gotError; }

const StackArraysAnalysisWrapper::AllocMemMap &
const StackArraysAnalysisWrapper::AllocMemMap *
StackArraysAnalysisWrapper::getCandidateOps(mlir::Operation *func) {
if (!funcMaps.count(func))
analyseFunction(func);
return funcMaps[func];
if (!funcMaps.contains(func))
if (mlir::failed(analyseFunction(func)))
return nullptr;
return &funcMaps[func];
}

AllocMemConversion::AllocMemConversion(
mlir::MLIRContext *ctx,
const llvm::DenseMap<mlir::Operation *, InsertionPoint> &candidateOps)
: OpRewritePattern(ctx), candidateOps(candidateOps) {}

mlir::LogicalResult
AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
mlir::PatternRewriter &rewriter) const {
Expand All @@ -485,9 +476,13 @@ AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
return mlir::failure();

// remove freemem operations
llvm::SmallVector<mlir::Operation *> erases;
for (mlir::Operation *user : allocmem.getOperation()->getUsers())
if (mlir::isa<fir::FreeMemOp>(user))
rewriter.eraseOp(user);
erases.push_back(user);
// now we are done iterating the users, it is safe to mutate them
for (mlir::Operation *erase : erases)
rewriter.eraseOp(erase);

// replace references to heap allocation with references to stack allocation
rewriter.replaceAllUsesWith(allocmem.getResult(), alloca->getResult());
Expand Down Expand Up @@ -709,29 +704,31 @@ void StackArraysPass::runOnFunc(mlir::Operation *func) {
assert(mlir::isa<mlir::func::FuncOp>(func));

auto &analysis = getAnalysis<StackArraysAnalysisWrapper>();
const auto &candidateOps = analysis.getCandidateOps(func);
if (analysis.hasErrors()) {
const StackArraysAnalysisWrapper::AllocMemMap *candidateOps =
analysis.getCandidateOps(func);
if (!candidateOps) {
signalPassFailure();
return;
}

if (candidateOps.empty())
if (candidateOps->empty())
return;
runCount += candidateOps.size();
runCount += candidateOps->size();

llvm::SmallVector<mlir::Operation *> opsToConvert;
opsToConvert.reserve(candidateOps->size());
for (auto [op, _] : *candidateOps)
opsToConvert.push_back(op);

mlir::MLIRContext &context = getContext();
mlir::RewritePatternSet patterns(&context);
mlir::ConversionTarget target(context);

target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
mlir::func::FuncDialect>();
target.addDynamicallyLegalOp<fir::AllocMemOp>([&](fir::AllocMemOp alloc) {
return !candidateOps.count(alloc.getOperation());
});
mlir::GreedyRewriteConfig config;
// prevent the pattern driver from merging blocks
config.enableRegionSimplification = false;

patterns.insert<AllocMemConversion>(&context, candidateOps);
if (mlir::failed(
mlir::applyPartialConversion(func, target, std::move(patterns)))) {
patterns.insert<AllocMemConversion>(&context, *candidateOps);
if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert,
std::move(patterns), config))) {
mlir::emitError(func->getLoc(), "error in stack arrays optimization\n");
signalPassFailure();
}
Expand Down
27 changes: 27 additions & 0 deletions flang/test/Transforms/stack-arrays.fir
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,33 @@ func.func @dfa3(%arg0: i1) {
// CHECK-NEXT: return
// CHECK-NEXT: }

func.func private @dfa3a_foo(!fir.ref<!fir.array<1xi8>>) -> ()
func.func private @dfa3a_bar(!fir.ref<!fir.array<1xi8>>) -> ()

// Check freemem in both regions, with other uses
//
// The heap allocation %a has four users: a fir.convert and a fir.freemem in
// each branch of the fir.if. Converting the allocation to a stack allocation
// therefore has to erase the two freemem users while the convert users are
// still attached to %a.
func.func @dfa3a(%arg0: i1) {
// single heap allocation shared by both branches
%a = fir.allocmem !fir.array<1xi8>
fir.if %arg0 {
// non-freemem use of %a: convert the heap pointer to a reference and pass it on
%ref = fir.convert %a : (!fir.heap<!fir.array<1xi8>>) -> !fir.ref<!fir.array<1xi8>>
func.call @dfa3a_foo(%ref) : (!fir.ref<!fir.array<1xi8>>) -> ()
// freed on the "then" path
fir.freemem %a : !fir.heap<!fir.array<1xi8>>
} else {
// same pattern as the "then" branch, calling a different external function
%ref = fir.convert %a : (!fir.heap<!fir.array<1xi8>>) -> !fir.ref<!fir.array<1xi8>>
func.call @dfa3a_bar(%ref) : (!fir.ref<!fir.array<1xi8>>) -> ()
// freed on the "else" path, so the allocation is freed on every path to return
fir.freemem %a : !fir.heap<!fir.array<1xi8>>
}
return
}
// CHECK: func.func @dfa3a(%arg0: i1) {
// CHECK-NEXT: %[[MEM:.*]] = fir.alloca !fir.array<1xi8>
// CHECK-NEXT: fir.if %arg0 {
// CHECK-NEXT: func.call @dfa3a_foo(%[[MEM]])
// CHECK-NEXT: } else {
// CHECK-NEXT: func.call @dfa3a_bar(%[[MEM]])
// CHECK-NEXT: }
// CHECK-NEXT: return
// CHECK-NEXT: }

// check the alloca is placed after all operands become available
func.func @placement1() {
// do some stuff with other ssa values
Expand Down

0 comments on commit 408f419

Please sign in to comment.