Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][NFC] use mlir::SymbolTable in lowering #86673

Merged
merged 4 commits into from
Apr 2, 2024

Conversation

jeanPerier
Copy link
Contributor

Whenever lowering is checking if a function or global already exists in the mlir::Module, it was doing module->lookup.

On big programs (~5000 globals and functions), this causes important slowdowns because these lookups are linear. Use mlir::SymbolTable to speed-up these lookups. The SymbolTable has to be created from the ModuleOp and maintained in sync. It is therefore placed in the converter, and FirOPBuilders can take a pointer to it to speed-up the lookups.

This patch does not bring mlir::SymbolTable to FIR/HLFIR passes, but some passes creating a lot of runtime calls could benefit from it too. More analysis will be needed.

As an example of the speed-ups, this patch speeds-up compilation of Whizard compare_amplitude_UFO.F90 from 5 mins to 2 mins on my machine (there is still room for speed-ups).

Whenever lowering is checking if a function or global already
exists in the mlir::Module, it was doing module->lookup.

On big programs (~5000 globals and functions), this causes
important slowdowns because these lookups are linear. Use
mlir::SymbolTable to speed-up these lookups. The SymbolTable
has to be created from the ModuleOp and maintained in sync.
It is therefore placed in the converter, and FirOPBuilders
can take a pointer to it to speed-up the lookups.

This patch does not bring mlir::SymbolTable to FIR/HLFIR
passes, but some passes creating a lot of runtime calls could
benefit from it too. More analysis will be needed.

As an example of the speed-ups, this patch speeds-up compilation
of Whizard compare_amplitude_UFO from 5 mins to 2 mins on
my machine (there is still room for speed-ups).
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir openacc labels Mar 26, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Mar 26, 2024

@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-openacc

Author: None (jeanPerier)

Changes

Whenever lowering is checking if a function or global already exists in the mlir::Module, it was doing module->lookup.

On big programs (~5000 globals and functions), this causes important slowdowns because these lookups are linear. Use mlir::SymbolTable to speed-up these lookups. The SymbolTable has to be created from the ModuleOp and maintained in sync. It is therefore placed in the converter, and FirOPBuilders can take a pointer to it to speed-up the lookups.

This patch does not bring mlir::SymbolTable to FIR/HLFIR passes, but some passes creating a lot of runtime calls could benefit from it too. More analysis will be needed.

As an example of the speed-ups, this patch speeds-up compilation of Whizard compare_amplitude_UFO.F90 from 5 mins to 2 mins on my machine (there is still room for speed-ups).


Patch is 40.87 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86673.diff

12 Files Affected:

  • (modified) flang/include/flang/Lower/AbstractConverter.h (+13)
  • (modified) flang/include/flang/Optimizer/Builder/FIRBuilder.h (+35-35)
  • (modified) flang/include/flang/Optimizer/Dialect/FIROpsSupport.h (+11-8)
  • (modified) flang/lib/Lower/Bridge.cpp (+17-6)
  • (modified) flang/lib/Lower/CallInterface.cpp (+6-3)
  • (modified) flang/lib/Lower/OpenACC.cpp (+4-2)
  • (modified) flang/lib/Optimizer/Builder/FIRBuilder.cpp (+31-14)
  • (modified) flang/lib/Optimizer/Builder/IntrinsicCall.cpp (+3-3)
  • (modified) flang/lib/Optimizer/Builder/LowLevelIntrinsics.cpp (+30-32)
  • (modified) flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp (+17-17)
  • (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+13-5)
  • (modified) flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp (+2-5)
diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
index 32e7a5e2b04061..d5dab9040d22bd 100644
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -23,6 +23,10 @@
 #include "mlir/IR/Operation.h"
 #include "llvm/ADT/ArrayRef.h"
 
+namespace mlir {
+class SymbolTable;
+}
+
 namespace fir {
 class KindMapping;
 class FirOpBuilder;
@@ -305,6 +309,15 @@ class AbstractConverter {
   virtual Fortran::lower::SymbolBox
   lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) = 0;
 
+  /// Return the mlir::SymbolTable associated to the ModuleOp.
+  /// Look-ups are faster using it than using module.lookup<>,
+  /// but the module op should be queried in case of failure
+  /// because this symbol table is not guaranteed to contain
+  /// all the symbols from the ModuleOp (the symbol table should
+  /// always be provided to the builder helper creating globals and
+  /// functions in order to be in sync).
+  virtual mlir::SymbolTable *getMLIRSymbolTable() = 0;
+
 private:
   /// Options controlling lowering behavior.
   const Fortran::lower::LoweringOptions &loweringOptions;
diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index d61bf681be6194..8537f29b2e549c 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -28,6 +28,10 @@
 #include <optional>
 #include <utility>
 
+namespace mlir {
+class SymbolTable;
+}
+
 namespace fir {
 class AbstractArrayBox;
 class ExtendedValue;
@@ -42,8 +46,10 @@ class BoxValue;
 /// patterns.
 class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
 public:
-  explicit FirOpBuilder(mlir::Operation *op, fir::KindMapping kindMap)
-      : OpBuilder{op, /*listener=*/this}, kindMap{std::move(kindMap)} {}
+  explicit FirOpBuilder(mlir::Operation *op, fir::KindMapping kindMap,
+                        mlir::SymbolTable *symbolTable = nullptr)
+      : OpBuilder{op, /*listener=*/this}, kindMap{std::move(kindMap)},
+        symbolTable{symbolTable} {}
   explicit FirOpBuilder(mlir::OpBuilder &builder, fir::KindMapping kindMap)
       : OpBuilder(builder), OpBuilder::Listener(), kindMap{std::move(kindMap)} {
     setListener(this);
@@ -69,13 +75,14 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
   // The listener self-reference has to be updated in case of copy-construction.
   FirOpBuilder(const FirOpBuilder &other)
       : OpBuilder(other), OpBuilder::Listener(), kindMap{other.kindMap},
-        fastMathFlags{other.fastMathFlags} {
+        fastMathFlags{other.fastMathFlags}, symbolTable{other.symbolTable} {
     setListener(this);
   }
 
   FirOpBuilder(FirOpBuilder &&other)
-      : OpBuilder(other), OpBuilder::Listener(),
-        kindMap{std::move(other.kindMap)}, fastMathFlags{other.fastMathFlags} {
+      : OpBuilder(other), OpBuilder::Listener(), kindMap{std::move(
+                                                     other.kindMap)},
+        fastMathFlags{other.fastMathFlags}, symbolTable{other.symbolTable} {
     setListener(this);
   }
 
@@ -95,6 +102,9 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
   /// Get a reference to the kind map.
   const fir::KindMapping &getKindMap() { return kindMap; }
 
+  /// Get func.func/fir.global symbol table attached to this builder if any.
+  mlir::SymbolTable *getMLIRSymbolTable() { return symbolTable; }
+
   /// Get the default integer type
   [[maybe_unused]] mlir::IntegerType getDefaultIntegerType() {
     return getIntegerType(
@@ -280,25 +290,28 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
   /// Get a function by name. If the function exists in the current module, it
   /// is returned. Otherwise, a null FuncOp is returned.
   mlir::func::FuncOp getNamedFunction(llvm::StringRef name) {
-    return getNamedFunction(getModule(), name);
+    return getNamedFunction(getModule(), name, getMLIRSymbolTable());
   }
-  static mlir::func::FuncOp getNamedFunction(mlir::ModuleOp module,
-                                             llvm::StringRef name);
+  static mlir::func::FuncOp
+  getNamedFunction(mlir::ModuleOp module, llvm::StringRef name,
+                   const mlir::SymbolTable *symbolTable);
 
   /// Get a function by symbol name. The result will be null if there is no
   /// function with the given symbol in the module.
   mlir::func::FuncOp getNamedFunction(mlir::SymbolRefAttr symbol) {
-    return getNamedFunction(getModule(), symbol);
+    return getNamedFunction(getModule(), symbol, getMLIRSymbolTable());
   }
-  static mlir::func::FuncOp getNamedFunction(mlir::ModuleOp module,
-                                             mlir::SymbolRefAttr symbol);
+  static mlir::func::FuncOp
+  getNamedFunction(mlir::ModuleOp module, mlir::SymbolRefAttr symbol,
+                   const mlir::SymbolTable *symbolTable);
 
   fir::GlobalOp getNamedGlobal(llvm::StringRef name) {
-    return getNamedGlobal(getModule(), name);
+    return getNamedGlobal(getModule(), name, getMLIRSymbolTable());
   }
 
   static fir::GlobalOp getNamedGlobal(mlir::ModuleOp module,
-                                      llvm::StringRef name);
+                                      llvm::StringRef name,
+                                      const mlir::SymbolTable *symbolTable);
 
   /// Lazy creation of fir.convert op.
   mlir::Value createConvert(mlir::Location loc, mlir::Type toTy,
@@ -313,35 +326,18 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
   /// result of the load if it was created, otherwise return \p val
   mlir::Value loadIfRef(mlir::Location loc, mlir::Value val);
 
-  /// Create a new FuncOp. If the function may have already been created, use
-  /// `addNamedFunction` instead.
+  /// Determine if the named function is already in the module. Return the
+  /// instance if found, otherwise add a new named function to the module.
   mlir::func::FuncOp createFunction(mlir::Location loc, llvm::StringRef name,
                                     mlir::FunctionType ty) {
-    return createFunction(loc, getModule(), name, ty);
+    return createFunction(loc, getModule(), name, ty, getMLIRSymbolTable());
   }
 
   static mlir::func::FuncOp createFunction(mlir::Location loc,
                                            mlir::ModuleOp module,
                                            llvm::StringRef name,
-                                           mlir::FunctionType ty);
-
-  /// Determine if the named function is already in the module. Return the
-  /// instance if found, otherwise add a new named function to the module.
-  mlir::func::FuncOp addNamedFunction(mlir::Location loc, llvm::StringRef name,
-                                      mlir::FunctionType ty) {
-    if (auto func = getNamedFunction(name))
-      return func;
-    return createFunction(loc, name, ty);
-  }
-
-  static mlir::func::FuncOp addNamedFunction(mlir::Location loc,
-                                             mlir::ModuleOp module,
-                                             llvm::StringRef name,
-                                             mlir::FunctionType ty) {
-    if (auto func = getNamedFunction(module, name))
-      return func;
-    return createFunction(loc, module, name, ty);
-  }
+                                           mlir::FunctionType ty,
+                                           mlir::SymbolTable *);
 
   /// Cast the input value to IndexType.
   mlir::Value convertToIndexType(mlir::Location loc, mlir::Value val) {
@@ -515,6 +511,10 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
   /// FastMathFlags that need to be set for operations that support
   /// mlir::arith::FastMathAttr.
   mlir::arith::FastMathFlags fastMathFlags{};
+
+  /// fir::GlobalOp and func::FuncOp symbol table to speed-up
+  /// lookups.
+  mlir::SymbolTable *symbolTable = nullptr;
 };
 
 } // namespace fir
diff --git a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
index e8226b6df58ca2..f29e44504acb63 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
+++ b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
@@ -52,16 +52,19 @@ inline bool pureCall(mlir::Operation *op) {
 /// Get or create a FuncOp in a module.
 ///
 /// If `module` already contains FuncOp `name`, it is returned. Otherwise, a new
-/// FuncOp is created, and that new FuncOp is returned.
-mlir::func::FuncOp
-createFuncOp(mlir::Location loc, mlir::ModuleOp module, llvm::StringRef name,
-             mlir::FunctionType type,
-             llvm::ArrayRef<mlir::NamedAttribute> attrs = {});
-
-/// Get or create a GlobalOp in a module.
+/// FuncOp is created, and that new FuncOp is returned. A symbol table can
+/// be provided to speed-up the lookups.
+mlir::func::FuncOp createFuncOp(mlir::Location loc, mlir::ModuleOp module,
+                                llvm::StringRef name, mlir::FunctionType type,
+                                llvm::ArrayRef<mlir::NamedAttribute> attrs = {},
+                                const mlir::SymbolTable *symbolTable = nullptr);
+
+/// Get or create a GlobalOp in a module. A symbol table can be provided to
+/// speed-up the lookups.
 fir::GlobalOp createGlobalOp(mlir::Location loc, mlir::ModuleOp module,
                              llvm::StringRef name, mlir::Type type,
-                             llvm::ArrayRef<mlir::NamedAttribute> attrs = {});
+                             llvm::ArrayRef<mlir::NamedAttribute> attrs = {},
+                             const mlir::SymbolTable *symbolTable = nullptr);
 
 /// Attribute to mark Fortran entities with the CONTIGUOUS attribute.
 constexpr llvm::StringRef getContiguousAttrName() { return "fir.contiguous"; }
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 48830dc55578c2..46a259d9ae86c9 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -273,7 +273,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 public:
   explicit FirConverter(Fortran::lower::LoweringBridge &bridge)
       : Fortran::lower::AbstractConverter(bridge.getLoweringOptions()),
-        bridge{bridge}, foldingContext{bridge.createFoldingContext()} {}
+        bridge{bridge}, foldingContext{bridge.createFoldingContext()},
+        mlirSymbolTable{bridge.getModule()} {}
   virtual ~FirConverter() = default;
 
   /// Convert the PFT to FIR.
@@ -329,8 +330,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
               [&](Fortran::lower::pft::BlockDataUnit &b) {},
               [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
               [&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {
-                builder = new fir::FirOpBuilder(bridge.getModule(),
-                                                bridge.getKindMap());
+                builder = new fir::FirOpBuilder(
+                    bridge.getModule(), bridge.getKindMap(), &mlirSymbolTable);
                 Fortran::lower::genOpenACCRoutineConstruct(
                     *this, bridge.getSemanticsContext(), bridge.getModule(),
                     d.routine, accRoutineInfos);
@@ -1036,6 +1037,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     return {};
   }
 
+  mlir::SymbolTable *getMLIRSymbolTable() override { return &mlirSymbolTable; }
+
   /// Add the symbol to the local map and return `true`. If the symbol is
   /// already in the map and \p forced is `false`, the map is not updated.
   /// Instead the value `false` is returned.
@@ -4570,7 +4573,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                llvm::dbgs() << "\n");
     Fortran::lower::CalleeInterface callee(funit, *this);
     mlir::func::FuncOp func = callee.addEntryBlockAndMapArguments();
-    builder = new fir::FirOpBuilder(func, bridge.getKindMap());
+    builder =
+        new fir::FirOpBuilder(func, bridge.getKindMap(), &mlirSymbolTable);
     assert(builder && "FirOpBuilder did not instantiate");
     builder->setFastMathFlags(bridge.getLoweringOptions().getMathOptions());
     builder->setInsertionPointToStart(&func.front());
@@ -4838,12 +4842,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     // FIXME: get rid of the bogus function context and instantiate the
     // globals directly into the module.
     mlir::MLIRContext *context = &getMLIRContext();
+    mlir::SymbolTable *symbolTable = getMLIRSymbolTable();
     mlir::func::FuncOp func = fir::FirOpBuilder::createFunction(
         mlir::UnknownLoc::get(context), getModuleOp(),
         fir::NameUniquer::doGenerated("Sham"),
-        mlir::FunctionType::get(context, std::nullopt, std::nullopt));
+        mlir::FunctionType::get(context, std::nullopt, std::nullopt),
+        symbolTable);
     func.addEntryBlock();
-    builder = new fir::FirOpBuilder(func, bridge.getKindMap());
+    builder = new fir::FirOpBuilder(func, bridge.getKindMap(), symbolTable);
     assert(builder && "FirOpBuilder did not instantiate");
     builder->setFastMathFlags(bridge.getLoweringOptions().getMathOptions());
     createGlobals();
@@ -5335,6 +5341,11 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// utilities to deal with procedure pointer components whose arguments have
   /// the type of the containing derived type.
   Fortran::lower::TypeConstructionStack typeConstructionStack;
+  /// MLIR symbol table of the fir.global/func.func operations. Note that it is
+  /// not guaranteed to contain all operations of the ModuleOp with Symbol
+  /// attribute since mlirSymbolTable must pro-actively be maintained when
+  /// new Symbol operations are created.
+  mlir::SymbolTable mlirSymbolTable;
 };
 
 } // namespace
diff --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp
index c65becc497459c..fef38da0133060 100644
--- a/flang/lib/Lower/CallInterface.cpp
+++ b/flang/lib/Lower/CallInterface.cpp
@@ -667,11 +667,13 @@ void Fortran::lower::CallInterface<T>::declare() {
   if (!side().isIndirectCall()) {
     std::string name = side().getMangledName();
     mlir::ModuleOp module = converter.getModuleOp();
-    func = fir::FirOpBuilder::getNamedFunction(module, name);
+    mlir::SymbolTable *symbolTable = converter.getMLIRSymbolTable();
+    func = fir::FirOpBuilder::getNamedFunction(module, name, symbolTable);
     if (!func) {
       mlir::Location loc = side().getCalleeLocation();
       mlir::FunctionType ty = genFunctionType();
-      func = fir::FirOpBuilder::createFunction(loc, module, name, ty);
+      func =
+          fir::FirOpBuilder::createFunction(loc, module, name, ty, symbolTable);
       if (const Fortran::semantics::Symbol *sym = side().getProcedureSymbol()) {
         if (side().isMainProgram()) {
           func->setAttr(fir::getSymbolAttrName(),
@@ -1644,7 +1646,8 @@ mlir::func::FuncOp Fortran::lower::getOrDeclareFunction(
     Fortran::lower::AbstractConverter &converter) {
   mlir::ModuleOp module = converter.getModuleOp();
   std::string name = getProcMangledName(proc, converter);
-  mlir::func::FuncOp func = fir::FirOpBuilder::getNamedFunction(module, name);
+  mlir::func::FuncOp func = fir::FirOpBuilder::getNamedFunction(
+      module, name, converter.getMLIRSymbolTable());
   if (func)
     return func;
 
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 7b7e4a875cd8e8..0ef3baa19c0199 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3809,7 +3809,8 @@ void Fortran::lower::genOpenACCRoutineConstruct(
   std::string funcName;
   if (name) {
     funcName = converter.mangleName(*name->symbol);
-    funcOp = builder.getNamedFunction(mod, funcName);
+    funcOp =
+        builder.getNamedFunction(mod, funcName, builder.getMLIRSymbolTable());
   } else {
     Fortran::semantics::Scope &scope =
         semanticsContext.FindScope(routineConstruct.source);
@@ -3821,7 +3822,8 @@ void Fortran::lower::genOpenACCRoutineConstruct(
             : nullptr};
     if (subpDetails && subpDetails->isInterface()) {
       funcName = converter.mangleName(*progUnit.symbol());
-      funcOp = builder.getNamedFunction(mod, funcName);
+      funcOp =
+          builder.getNamedFunction(mod, funcName, builder.getMLIRSymbolTable());
     } else {
       funcOp = builder.getFunction();
       funcName = funcOp.getName();
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index 2bcd5e5914027d..a8606a79af1671 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -36,26 +36,39 @@ static llvm::cl::opt<std::size_t>
                                       "name"),
                        llvm::cl::init(32));
 
-mlir::func::FuncOp fir::FirOpBuilder::createFunction(mlir::Location loc,
-                                                     mlir::ModuleOp module,
-                                                     llvm::StringRef name,
-                                                     mlir::FunctionType ty) {
-  return fir::createFuncOp(loc, module, name, ty);
+mlir::func::FuncOp
+fir::FirOpBuilder::createFunction(mlir::Location loc, mlir::ModuleOp module,
+                                  llvm::StringRef name, mlir::FunctionType ty,
+                                  mlir::SymbolTable *symbolTable) {
+  return fir::createFuncOp(loc, module, name, ty, /*attrs*/ {}, symbolTable);
 }
 
-mlir::func::FuncOp fir::FirOpBuilder::getNamedFunction(mlir::ModuleOp modOp,
-                                                       llvm::StringRef name) {
+mlir::func::FuncOp
+fir::FirOpBuilder::getNamedFunction(mlir::ModuleOp modOp, llvm::StringRef name,
+                                    const mlir::SymbolTable *symbolTable) {
+  if (symbolTable)
+    if (auto func = symbolTable->lookup<mlir::func::FuncOp>(name))
+      return func;
   return modOp.lookupSymbol<mlir::func::FuncOp>(name);
 }
 
 mlir::func::FuncOp
 fir::FirOpBuilder::getNamedFunction(mlir::ModuleOp modOp,
-                                    mlir::SymbolRefAttr symbol) {
+                                    mlir::SymbolRefAttr symbol,
+                                    const mlir::SymbolTable *symbolTable) {
+  if (symbolTable)
+    if (auto func =
+            symbolTable->lookup<mlir::func::FuncOp>(symbol.getLeafReference()))
+      return func;
   return modOp.lookupSymbol<mlir::func::FuncOp>(symbol);
 }
 
-fir::GlobalOp fir::FirOpBuilder::getNamedGlobal(mlir::ModuleOp modOp,
-                                                llvm::StringRef name) {
+fir::GlobalOp
+fir::FirOpBuilder::getNamedGlobal(mlir::ModuleOp modOp, llvm::StringRef name,
+                                  const mlir::SymbolTable *symbolTable) {
+  if (symbolTable)
+    if (auto global = symbolTable->lookup<fir::GlobalOp>(name))
+      return global;
   return modOp.lookupSymbol<fir::GlobalOp>(name);
 }
 
@@ -279,10 +292,10 @@ fir::GlobalOp fir::FirOpBuilder::createGlobal(
     mlir::Location loc, mlir::Type type, llvm::StringRef name,
     mlir::StringAttr linkage, mlir::Attribute value, bool isConst,
     bool isTarget, fir::CUDADataAttributeAttr cudaAttr) {
+  if (auto global = getNamedGlobal(name))
+    return global;
   auto module = getModule();
   auto insertPt = saveInsertionPoint();
-  if (auto glob = module.lookupSymbol<fir::GlobalOp>(name))
-    return glob;
   setInsertionPoint(module.getBody(), module.getBody()->end());
   llvm::SmallVector<mlir::NamedAttribute> attrs;
   if (cudaAttr) {
@@ -294,6 +307,8 @@ fir::GlobalOp fir::FirOpBuilder::createGlobal(
   auto glob = create<fir::GlobalOp>(loc, name, isConst, isTarget, type, value,
                                     linkage, attrs);
   restoreInsertionPoint(insertPt);
+  if (symbolTable)
+    symbolTable->insert(glob);
   return glob;
 }
 
@@ -301,10 +316,10 @@ fir::GlobalOp fir::FirOpBuilder::createGlobal(
     mlir::Location loc, mlir::Type type, llvm::StringRef name, bool isConst,
     bool isTarget, std::function<void(FirOpBuilder &)> bodyBuilder,
     mlir::StringAttr linkage, fir::CUDADataAttributeAttr cudaAttr) {
+  if (auto global = getNamedGlobal(name))
+    return global;
   auto module = getModule();
...
[truncated]

Copy link

github-actions bot commented Mar 26, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@vdonaldson
Copy link
Contributor

Looks ok to me, but someone more familiar with C++ should have a look.

Copy link
Contributor

@vzakhari vzakhari left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the fix, Jean!

I have a question about in an inline comment:

... it is
  /// not guaranteed to contain all operations of the ModuleOp with Symbol
  /// attribute since mlirSymbolTable must pro-actively be maintained when
  /// new Symbol operations are created.

const mlir::SymbolTable *symbolTable) {
if (symbolTable)
if (auto f = symbolTable->lookup<mlir::func::FuncOp>(name))
return f;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to make sure that the symbolTable is kept in sync with the module? E.g. can we put assertions here and other functions where we lookup symbols to verify that symbolTable->lookup(name) == module.lookupSymbol(name). Whenever the assertion fails this will indicate that there is code that adds/removes the symbol to/form the module and does not update the symbolTable - then we can track it down and fix.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is OK from a correctness point of view if symbols are not added to the symbolTable (when not found there, the code fallbacks to looking at the module), but you are right that if the symbols are deleted/replaced, that would be an issue. I added the suggested assert in the lookups (under EXPENSIVE_CHECKS, since doing the check kills any compilation time improvement of using the map).

@clementval
Copy link
Contributor

Looks good to me.

Copy link
Contributor

@vzakhari vzakhari left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, Jean,

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good. Thanks.

I guess we could theoretically always build a symbol table for each builder (as it only needs to know the ModuleOp). Is the concern here that it would be too hard to keep in sync or that it would make instantiating a builder too slow?

@jeanPerier
Copy link
Contributor Author

I guess we could theoretically always build a symbol table for each builder (as it only needs to know the ModuleOp). Is the concern here that it would be too hard to keep in sync or that it would make instantiating a builder too slow?

Thanks for the review! It would be too slow, especially in the passes where FIROpBuilder are created on the fly. Creating a new SymbolTable from a ModuleOp requires doing a shallow walk of the module (linear with the number of function + global + deriver types).

@jeanPerier jeanPerier merged commit a4798bb into llvm:main Apr 2, 2024
4 checks passed
@jeanPerier jeanPerier deleted the jpr-mlir-symbol-table-2 branch April 2, 2024 12:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category openacc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants