-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Flang] Extracting internal constants from scalar literals #73829
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-flang-driver Author: Mats Petersson (Leporacanthicus) ChangesConstants actual arguments in function/subroutine calls are currently lowered as allocas + store. This can sometimes inhibit LTO and the constant will not be propagated to the called function. Particularly in cases where the function/subroutine call happens inside a condition. This patch changes the lowering of these constant actual arguments to a global constant + fir.address_of_op. This lowering makes it easier for LTO to propagate the constant. Full diff: https://github.com/llvm/llvm-project/pull/73829.diff 14 Files Affected:
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index 92bc7246eca7005..f1c38a026660243 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -60,6 +60,7 @@ createExternalNameConversionPass(bool appendUnderscore);
std::unique_ptr<mlir::Pass> createMemDataFlowOptPass();
std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
std::unique_ptr<mlir::Pass> createMemoryAllocationPass();
+std::unique_ptr<mlir::Pass> createConstExtruderPass();
std::unique_ptr<mlir::Pass> createStackArraysPass();
std::unique_ptr<mlir::Pass> createAliasTagsPass();
std::unique_ptr<mlir::Pass> createSimplifyIntrinsicsPass();
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index c3768fd2d689c1a..179833876a7b333 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -242,6 +242,16 @@ def MemoryAllocationOpt : Pass<"memory-allocation-opt", "mlir::func::FuncOp"> {
let constructor = "::fir::createMemoryAllocationPass()";
}
+// This needs to be a "mlir::ModuleOp" pass, because it inserts global constants
+def ConstExtruderOpt : Pass<"const-extruder-opt", "mlir::ModuleOp"> {
+ let summary = "Convert scalar literals of function arguments to global constants.";
+ let description = [{
+ Convert scalar literals of function arguments to global constants.
+ }];
+ let dependentDialects = [ "fir::FIROpsDialect" ];
+ let constructor = "::fir::createConstExtruderPass()";
+}
+
def StackArrays : Pass<"stack-arrays", "mlir::ModuleOp"> {
let summary = "Move local array allocations from heap memory into stack memory";
let description = [{
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index d3e4dc6cd4a243e..b902621dfe42177 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -216,6 +216,8 @@ inline void createDefaultFIROptimizerPassPipeline(
else
fir::addMemoryAllocationOpt(pm);
+ pm.addPass(fir::createConstExtruderPass());
+
// The default inliner pass adds the canonicalizer pass with the default
// configuration. Create the inliner pass with tco config.
llvm::StringMap<mlir::OpPassManager> pipelines;
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 03b67104a93b575..bada67729ede95c 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_flang_library(FIRTransforms
AffineDemotion.cpp
AnnotateConstant.cpp
CharacterConversion.cpp
+ ConstExtruder.cpp
ControlFlowConverter.cpp
ArrayValueCopy.cpp
ExternalNameConversion.cpp
diff --git a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
new file mode 100644
index 000000000000000..1bb1cd226987110
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
@@ -0,0 +1,216 @@
+//===- ConstExtruder.cpp -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <atomic>
+
+namespace fir {
+#define GEN_PASS_DEF_CONSTEXTRUDEROPT
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+#define DEBUG_TYPE "flang-const-extruder-opt"
+
+namespace {
+std::atomic<int> uniqueLitId = 1;
+
+static bool needsExtrusion(const mlir::Value *a) {
+ if (!a || !a->getDefiningOp())
+ return false;
+
+ // is alloca
+ if (auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a->getDefiningOp())) {
+ // alloca has annotation
+ if (alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
+ for (mlir::Operation *s : alloca.getOperation()->getUsers()) {
+ if (const auto store = mlir::dyn_cast_or_null<fir::StoreOp>(s)) {
+ auto constant_def = store->getOperand(0).getDefiningOp();
+ // Expect constant definition operation
+ if (mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
+ return true;
+ }
+ }
+ }
+ }
+ }
+ return false;
+}
+
+class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
+protected:
+ mlir::DominanceInfo &di;
+
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ CallOpRewriter(mlir::MLIRContext *ctx, mlir::DominanceInfo &_di)
+ : OpRewritePattern(ctx), di(_di) {}
+
+ mlir::LogicalResult
+ matchAndRewrite(fir::CallOp callOp,
+ mlir::PatternRewriter &rewriter) const override {
+ LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
+ auto module = callOp->getParentOfType<mlir::ModuleOp>();
+ fir::FirOpBuilder builder(rewriter, module);
+ llvm::SmallVector<mlir::Value> newOperands;
+ llvm::SmallVector<mlir::Operation *> toErase;
+ for (const auto &a : callOp.getArgs()) {
+ if (auto alloca =
+ mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp())) {
+ if (needsExtrusion(&a)) {
+
+ mlir::Type varTy = alloca.getInType();
+ assert(!fir::hasDynamicSize(varTy) &&
+ "only expect statically sized scalars to be by value");
+
+ // find immediate store with const argument
+ llvm::SmallVector<mlir::Operation *> stores;
+ for (mlir::Operation *s : alloca.getOperation()->getUsers())
+ if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp))
+ stores.push_back(s);
+ assert(stores.size() == 1 && "expected exactly one store");
+ LLVM_DEBUG(llvm::dbgs() << " found store " << *stores[0] << "\n");
+
+ auto constant_def = stores[0]->getOperand(0).getDefiningOp();
+ // Expect constant definition operation or force legalisation of the
+ // callOp and continue with its next argument
+ if (!mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
+ // unable to remove alloca arg
+ newOperands.push_back(a);
+ continue;
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << " found define " << *constant_def << "\n");
+
+ auto loc = callOp.getLoc();
+ llvm::StringRef globalPrefix = "_extruded_";
+
+ std::string globalName;
+ while (!globalName.length() || builder.getNamedGlobal(globalName))
+ globalName =
+ globalPrefix.str() + "." + std::to_string(uniqueLitId++);
+
+ if (alloca->hasOneUse()) {
+ toErase.push_back(alloca);
+ toErase.push_back(stores[0]);
+ } else {
+ int count = -2;
+ for (mlir::Operation *s : alloca.getOperation()->getUsers())
+ if (di.dominates(stores[0], s))
+ ++count;
+
+ // delete if dominates itself and one more operation (which should
+ // be callOp)
+ if (!count)
+ toErase.push_back(stores[0]);
+ }
+ auto global = builder.createGlobalConstant(
+ loc, varTy, globalName,
+ [&](fir::FirOpBuilder &builder) {
+ mlir::Operation *cln = constant_def->clone();
+ builder.insert(cln);
+ fir::ExtendedValue exv{cln->getResult(0)};
+ mlir::Value valBase = fir::getBase(exv);
+ mlir::Value val = builder.createConvert(loc, varTy, valBase);
+ builder.create<fir::HasValueOp>(loc, val);
+ },
+ builder.createInternalLinkage());
+ mlir::Value ope = {builder.create<fir::AddrOfOp>(
+ loc, global.resultType(), global.getSymbol())};
+ newOperands.push_back(ope);
+ } else {
+ // alloca but without attr, add it
+ newOperands.push_back(a);
+ }
+ } else {
+ // non-alloca operand, add it
+ newOperands.push_back(a);
+ }
+ }
+
+ auto loc = callOp.getLoc();
+ llvm::SmallVector<mlir::Type> newResultTypes;
+ newResultTypes.append(callOp.getResultTypes().begin(),
+ callOp.getResultTypes().end());
+ fir::CallOp newOp = builder.create<fir::CallOp>(
+ loc, newResultTypes,
+ callOp.getCallee().has_value() ? callOp.getCallee().value()
+ : mlir::SymbolRefAttr{},
+ newOperands, callOp.getFastmathAttr());
+ rewriter.replaceOp(callOp, newOp);
+
+ for (auto e : toErase)
+ rewriter.eraseOp(e);
+
+ LLVM_DEBUG(llvm::dbgs() << "extruded constant for " << callOp << " as "
+ << newOp << '\n');
+ return mlir::success();
+ }
+};
+
+// This pass attempts to convert immediate scalar literals in function calls
+// to global constants to allow transformations as Dead Argument Elimination
+class ConstExtruderOpt
+ : public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
+protected:
+ mlir::DominanceInfo *di;
+
+public:
+ ConstExtruderOpt() {}
+
+ void runOnOperation() override {
+ mlir::ModuleOp mod = getOperation();
+ di = &getAnalysis<mlir::DominanceInfo>();
+ mod.walk([this](mlir::func::FuncOp func) { runOnFunc(func); });
+ }
+
+ void runOnFunc(mlir::func::FuncOp &func) {
+ auto *context = &getContext();
+ mlir::RewritePatternSet patterns(context);
+ mlir::ConversionTarget target(*context);
+
+ // If func is a declaration, skip it.
+ if (func.empty())
+ return;
+
+ target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
+ mlir::func::FuncDialect>();
+ target.addDynamicallyLegalOp<fir::CallOp>([&](fir::CallOp op) {
+ for (auto a : op.getArgs()) {
+ if (needsExtrusion(&a))
+ return false;
+ }
+ return true;
+ });
+
+ patterns.insert<CallOpRewriter>(context, *di);
+ if (mlir::failed(
+ mlir::applyPartialConversion(func, target, std::move(patterns)))) {
+ mlir::emitError(func.getLoc(),
+ "error in constant extrusion optimization\n");
+ signalPassFailure();
+ }
+ }
+};
+} // namespace
+
+std::unique_ptr<mlir::Pass> fir::createConstExtruderPass() {
+ return std::make_unique<ConstExtruderOpt>();
+}
diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index 243a620a9fd003c..c43149e07bf55f3 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -31,6 +31,7 @@
! CHECK-NEXT: 'func.func' Pipeline
! CHECK-NEXT: MemoryAllocationOpt
+! CHECK-NEXT: ConstExtruderOpt
! CHECK-NEXT: Inliner
! CHECK-NEXT: SimplifyRegionLite
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
index a3ff416f4d77951..5eb2354b67012e0 100644
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -51,6 +51,7 @@
! ALL-NEXT: 'func.func' Pipeline
! ALL-NEXT: MemoryAllocationOpt
+! ALL-NEXT: ConstExtruderOpt
! ALL-NEXT: Inliner
! ALL-NEXT: SimplifyRegionLite
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index 3d8c42f123e2eb0..10bc9b90cfd769a 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -42,6 +42,7 @@
! ALL-NEXT: 'func.func' Pipeline
! ALL-NEXT: MemoryAllocationOpt
+! ALL-NEXT: ConstExtruderOpt
! ALL-NEXT: Inliner
! ALL-NEXT: SimplifyRegionLite
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index d8a9e74c318ce18..b9aabd322399ca2 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -48,6 +48,7 @@ func.func @_QQmain() {
// PASSES-NEXT: 'func.func' Pipeline
// PASSES-NEXT: MemoryAllocationOpt
+// PASSES-NEXT: ConstExtruderOpt
// PASSES-NEXT: Inliner
// PASSES-NEXT: SimplifyRegionLite
diff --git a/flang/test/Fir/boxproc.fir b/flang/test/Fir/boxproc.fir
index 1fed16a808af042..2ddc0ef525ac481 100644
--- a/flang/test/Fir/boxproc.fir
+++ b/flang/test/Fir/boxproc.fir
@@ -16,9 +16,7 @@
// CHECK-LABEL: define void @_QPtest_proc_dummy_other(ptr
// CHECK-SAME: %[[VAL_0:.*]])
-// CHECK: %[[VAL_1:.*]] = alloca i32, i64 1, align 4
-// CHECK: store i32 4, ptr %[[VAL_1]], align 4
-// CHECK: call void %[[VAL_0]](ptr %[[VAL_1]])
+// CHECK: call void %[[VAL_0]](ptr @{{.*}})
func.func @_QPtest_proc_dummy() {
%c0_i32 = arith.constant 0 : i32
diff --git a/flang/test/Lower/character-local-variables.f90 b/flang/test/Lower/character-local-variables.f90
index 0cf61a2623c4e73..b1cfc540f438966 100644
--- a/flang/test/Lower/character-local-variables.f90
+++ b/flang/test/Lower/character-local-variables.f90
@@ -116,8 +116,7 @@ subroutine dyn_array_dyn_len_lb(l, n)
subroutine assumed_length_param(n)
character(*), parameter :: c(1)=(/"abcd"/)
integer :: n
- ! CHECK: %[[c4:.*]] = arith.constant 4 : i64
- ! CHECK: fir.store %[[c4]] to %[[tmp:.*]] : !fir.ref<i64>
+ ! CHECK: %[[tmp:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i64>
! CHECK: fir.call @_QPtake_int(%[[tmp]]) {{.*}}: (!fir.ref<i64>) -> ()
call take_int(len(c(n), kind=8))
end
diff --git a/flang/test/Lower/dummy-arguments.f90 b/flang/test/Lower/dummy-arguments.f90
index 43d8e3c1e5d4485..46e4323e8862049 100644
--- a/flang/test/Lower/dummy-arguments.f90
+++ b/flang/test/Lower/dummy-arguments.f90
@@ -2,9 +2,7 @@
! CHECK-LABEL: _QQmain
program test1
- ! CHECK-DAG: %[[TMP:.*]] = fir.alloca
- ! CHECK-DAG: %[[TEN:.*]] = arith.constant
- ! CHECK: fir.store %[[TEN]] to %[[TMP]]
+ ! CHECK-DAG: %[[TEN:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i32>
! CHECK-NEXT: fir.call @_QFPfoo
call foo(10)
contains
diff --git a/flang/test/Lower/host-associated.f90 b/flang/test/Lower/host-associated.f90
index e2db6deb8803d08..26598ef1f16ea31 100644
--- a/flang/test/Lower/host-associated.f90
+++ b/flang/test/Lower/host-associated.f90
@@ -448,11 +448,10 @@ subroutine bar()
! CHECK-LABEL: func @_QPtest_proc_dummy_other(
! CHECK-SAME: %[[VAL_0:.*]]: !fir.boxproc<() -> ()>) {
-! CHECK: %[[VAL_1:.*]] = arith.constant 4 : i32
-! CHECK: %[[VAL_2:.*]] = fir.alloca i32 {adapt.valuebyref}
-! CHECK: fir.store %[[VAL_1]] to %[[VAL_2]] : !fir.ref<i32>
! CHECK: %[[VAL_3:.*]] = fir.box_addr %[[VAL_0]] : (!fir.boxproc<() -> ()>) -> ((!fir.ref<i32>) -> ())
-! CHECK: fir.call %[[VAL_3]](%[[VAL_2]]) {{.*}}: (!fir.ref<i32>) -> ()
+! CHECK: %[[VAL_1:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i32>
+! CHECK: fir.call %[[VAL_3]](%[[VAL_1]]) {{.*}}: (!fir.ref<i32>) -> ()
+
! CHECK: return
! CHECK: }
diff --git a/flang/test/Transforms/const-extrude.f90 b/flang/test/Transforms/const-extrude.f90
new file mode 100644
index 000000000000000..70cdaf496f34acc
--- /dev/null
+++ b/flang/test/Transforms/const-extrude.f90
@@ -0,0 +1,32 @@
+! RUN: %flang_fc1 -emit-fir %s -o - | fir-opt --const-extruder-opt | FileCheck %s
+
+subroutine sub1(x,y)
+ implicit none
+ integer x, y
+
+ call sub2(0.0d0, 1.0d0, x, y, 1)
+end subroutine sub1
+
+!CHECK-LABEL: func.func @_QPsub1
+!CHECK-SAME: [[ARG0:%.*]]: !fir.ref<i32> {{{.*}}},
+!CHECK-SAME: [[ARG1:%.*]]: !fir.ref<i32> {{{.*}}}) {
+!CHECK: [[X:%.*]] = fir.declare [[ARG0]] {{.*}}
+!CHECK: [[Y:%.*]] = fir.declare [[ARG1]] {{.*}}
+!CHECK: [[CONST_R0:%.*]] = fir.address_of([[EXTR_0:@.*]]) : !fir.ref<f64>
+!CHECK: [[CONST_R1:%.*]] = fir.address_of([[EXTR_1:@.*]]) : !fir.ref<f64>
+!CHECK: [[CONST_I:%.*]] = fir.address_of([[EXTR_2:@.*]]) : !fir.ref<i32>
+!CHECK: fir.call @_QPsub2([[CONST_R0]], [[CONST_R1]], [[X]], [[Y]], [[CONST_I]])
+!CHECK: return
+
+!CHECK: fir.global internal [[EXTR_0]] constant : f64 {
+!CHECK: %{{.*}} = arith.constant 0.000000e+00 : f64
+!CHECK: fir.has_value %{{.*}} : f64
+!CHECK: }
+!CHECK: fir.global internal [[EXTR_1]] constant : f64 {
+!CHECK: %{{.*}} = arith.constant 1.000000e+00 : f64
+!CHECK: fir.has_value %{{.*}} : f64
+!CHECK: }
+!CHECK: fir.global internal [[EXTR_2]] constant : i32 {
+!CHECK: %{{.*}} = arith.constant 1 : i32
+!CHECK: fir.has_value %{{.*}} : i32
+!CHECK: }
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @Leporacanthicus. If I remember correctly, this pass requires the alias tags pass. Since that pass is reverted, this will have to wait.
And performance results have to be mentioned or present along with the patch.
Benchmark for Spec-2017 503.bwaves is about 11% better on AArch64, and about 13% better on a fairly old x86-64 machine. As far as I can tell, there's no slow-down, nor any improvement on other benchmarks. (I will add this to the commit message too). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few minor comments.
|
Constants actual arguments in function/subroutine calls are currently lowered as allocas + store. This can sometimes inhibit LTO and the constant will not be propagated to the called function. Particularly in cases where the function/subroutine call happens inside a condition. This patch changes the lowering of these constant actual arguments to a global constant + fir.address_of_op. This lowering makes it easier for LTO to propagate the constant. Co-authored-by: Dmitriy Smirnov <dmitriy.smirnov@arm.com>
f936dde
to
bbd355c
Compare
Thanks for the changes. LGTM once CI passes |
bbd355c
to
a3309fd
Compare
Constants actual arguments in function/subroutine calls are currently lowered as allocas + store. This can sometimes inhibit LTO and the constant will not be propagated to the called function. Particularly in cases where the function/subroutine call happens inside a condition.
This patch changes the lowering of these constant actual arguments to a global constant + fir.address_of_op. This lowering makes it easier for LTO to propagate the constant.