Skip to content

Commit

Permalink
[flang][NFC] speed-up external name conversion pass (#86814)
Browse files Browse the repository at this point in the history
The ExternalNameConversion pass can be surprisingly slow on big
programs. On an example with a 50kloc Fortran file with about 10000
calls to external procedures, the pass alone took 25s on my machine.
This patch reduces this to 0.16s.

The root cause is that using `replaceAllSymbolUses` on each modified
FuncOp is very expensive: it is walking all operations and attribute
every time.

An alternative would be to use mlir::SymbolUserMap to avoid walking the
module again and again, but this is still much more expensive than what
is needed because it is essentially caching all symbol uses of the
module, and there is no need to such caching here.

Instead:
- Do a shallow walk of the module (only top level operation) to detect
FuncOp/GlobalOp that needs to be updated. Update them and place the name
remapping in a DenseMap.
- If any remapping were done, do a single deep walk of the module
operation, and update any SymbolRefAttr that matches a name that was
remapped.
  • Loading branch information
jeanPerier committed Apr 2, 2024
1 parent fa8dc36 commit 2d14ea6
Showing 1 changed file with 34 additions and 121 deletions.
155 changes: 34 additions & 121 deletions flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,9 @@
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Support/InternalNames.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace fir {
#define GEN_PASS_DEF_EXTERNALNAMECONVERSION
Expand All @@ -44,102 +40,8 @@ mangleExternalName(const std::pair<fir::NameUniquer::NameKind,
appendUnderscore);
}

//===----------------------------------------------------------------------===//
// Rewrite patterns
//===----------------------------------------------------------------------===//

namespace {

struct MangleNameOnFuncOp : public mlir::OpRewritePattern<mlir::func::FuncOp> {
public:
using OpRewritePattern::OpRewritePattern;

MangleNameOnFuncOp(mlir::MLIRContext *ctx, bool appendUnderscore)
: mlir::OpRewritePattern<mlir::func::FuncOp>(ctx),
appendUnderscore(appendUnderscore) {}

mlir::LogicalResult
matchAndRewrite(mlir::func::FuncOp op,
mlir::PatternRewriter &rewriter) const override {
mlir::LogicalResult ret = success();
rewriter.startOpModification(op);
llvm::StringRef oldName = op.getSymName();
auto result = fir::NameUniquer::deconstruct(oldName);
if (fir::NameUniquer::isExternalFacingUniquedName(result)) {
auto newSymbol =
rewriter.getStringAttr(mangleExternalName(result, appendUnderscore));

// Try to update all SymbolRef's in the module that match the current op
if (mlir::ModuleOp mod = op->getParentOfType<mlir::ModuleOp>())
ret = op.replaceAllSymbolUses(newSymbol, mod);

op.setSymNameAttr(newSymbol);
mlir::SymbolTable::setSymbolName(op, newSymbol);

op->setAttr(fir::getInternalFuncNameAttrName(),
mlir::StringAttr::get(op->getContext(), oldName));
}
rewriter.finalizeOpModification(op);
return ret;
}

private:
bool appendUnderscore;
};

struct MangleNameForCommonBlock : public mlir::OpRewritePattern<fir::GlobalOp> {
public:
using OpRewritePattern::OpRewritePattern;

MangleNameForCommonBlock(mlir::MLIRContext *ctx, bool appendUnderscore)
: mlir::OpRewritePattern<fir::GlobalOp>(ctx),
appendUnderscore(appendUnderscore) {}

mlir::LogicalResult
matchAndRewrite(fir::GlobalOp op,
mlir::PatternRewriter &rewriter) const override {
rewriter.startOpModification(op);
auto result = fir::NameUniquer::deconstruct(
op.getSymref().getRootReference().getValue());
if (fir::NameUniquer::isExternalFacingUniquedName(result)) {
auto newName = mangleExternalName(result, appendUnderscore);
op.setSymrefAttr(mlir::SymbolRefAttr::get(op.getContext(), newName));
SymbolTable::setSymbolName(op, newName);
}
rewriter.finalizeOpModification(op);
return success();
}

private:
bool appendUnderscore;
};

struct MangleNameOnAddrOfOp : public mlir::OpRewritePattern<fir::AddrOfOp> {
public:
using OpRewritePattern::OpRewritePattern;

MangleNameOnAddrOfOp(mlir::MLIRContext *ctx, bool appendUnderscore)
: mlir::OpRewritePattern<fir::AddrOfOp>(ctx),
appendUnderscore(appendUnderscore) {}

mlir::LogicalResult
matchAndRewrite(fir::AddrOfOp op,
mlir::PatternRewriter &rewriter) const override {
auto result = fir::NameUniquer::deconstruct(
op.getSymbol().getRootReference().getValue());
if (fir::NameUniquer::isExternalFacingUniquedName(result)) {
auto newName = SymbolRefAttr::get(
op.getContext(), mangleExternalName(result, appendUnderscore));
rewriter.replaceOpWithNewOp<fir::AddrOfOp>(op, op.getResTy().getType(),
newName);
}
return success();
}

private:
bool appendUnderscore;
};

class ExternalNameConversionPass
: public fir::impl::ExternalNameConversionBase<ExternalNameConversionPass> {
public:
Expand All @@ -162,31 +64,42 @@ void ExternalNameConversionPass::runOnOperation() {
auto *context = &getContext();

appendUnderscores = (usePassOpt) ? appendUnderscoreOpt : appendUnderscores;
llvm::DenseMap<mlir::StringAttr, mlir::FlatSymbolRefAttr> remappings;
// Update names of external Fortran functions and names of Common Block
// globals.
for (auto &funcOrGlobal : op->getRegion(0).front()) {
if (llvm::isa<mlir::func::FuncOp>(funcOrGlobal) ||
llvm::isa<fir::GlobalOp>(funcOrGlobal)) {
auto symName = funcOrGlobal.getAttrOfType<mlir::StringAttr>(
mlir::SymbolTable::getSymbolAttrName());
auto deconstructedName = fir::NameUniquer::deconstruct(symName);
if (fir::NameUniquer::isExternalFacingUniquedName(deconstructedName)) {
auto newName = mangleExternalName(deconstructedName, appendUnderscores);
auto newAttr = mlir::StringAttr::get(context, newName);
mlir::SymbolTable::setSymbolName(&funcOrGlobal, newAttr);
auto newSymRef = mlir::FlatSymbolRefAttr::get(newAttr);
remappings.try_emplace(symName, newSymRef);
if (llvm::isa<mlir::func::FuncOp>(funcOrGlobal))
funcOrGlobal.setAttr(fir::getInternalFuncNameAttrName(), symName);
}
}
}

mlir::RewritePatternSet patterns(context);
patterns.insert<MangleNameOnFuncOp, MangleNameForCommonBlock,
MangleNameOnAddrOfOp>(context, appendUnderscores);

ConversionTarget target(*context);
target.addLegalDialect<fir::FIROpsDialect, LLVM::LLVMDialect,
acc::OpenACCDialect, omp::OpenMPDialect>();

target.addDynamicallyLegalOp<mlir::func::FuncOp>([](mlir::func::FuncOp op) {
return !fir::NameUniquer::needExternalNameMangling(op.getSymName());
});

target.addDynamicallyLegalOp<fir::GlobalOp>([](fir::GlobalOp op) {
return !fir::NameUniquer::needExternalNameMangling(
op.getSymref().getRootReference().getValue());
});

target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp op) {
return !fir::NameUniquer::needExternalNameMangling(
op.getSymbol().getRootReference().getValue());
if (remappings.empty())
return;

// Update all uses of the functions and globals that have been renamed.
op.walk([&remappings](mlir::Operation *nestedOp) {
llvm::SmallVector<std::pair<mlir::StringAttr, mlir::SymbolRefAttr>> updates;
for (const mlir::NamedAttribute &attr : nestedOp->getAttrDictionary())
if (auto symRef = llvm::dyn_cast<mlir::SymbolRefAttr>(attr.getValue()))
if (auto remap = remappings.find(symRef.getRootReference());
remap != remappings.end())
updates.emplace_back(std::pair<mlir::StringAttr, mlir::SymbolRefAttr>{
attr.getName(), mlir::SymbolRefAttr(remap->second)});
for (auto update : updates)
nestedOp->setAttr(update.first, update.second);
});

if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}

std::unique_ptr<mlir::Pass> fir::createExternalNameConversionPass() {
Expand Down

0 comments on commit 2d14ea6

Please sign in to comment.