Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang] de-duplicate AbstractResult pass #88867

Merged
merged 4 commits into from
Apr 22, 2024

Conversation

tblah
Copy link
Contributor

@tblah tblah commented Apr 16, 2024

This is the first proof of concept of the modification of FIR codegen to fully support a variety of top level operations (beyond just func.func) proposed in
https://discourse.llvm.org/t/rfc-add-an-interface-for-top-level-container-operations

This is the first proof of concept of the modification of FIR lowering
to fully support a variety of top level operations (beyond just
func.func) proposed in
https://discourse.llvm.org/t/rfc-add-an-interface-for-top-level-container-operations

One unfortunate side-effect of this is that the new AbstractResult pass
cannot be scheduled on a builtin.module operation and so we can't use
  fir-opt --abstract-result < file.fir

I tried adding support for operating on a module to the pass, but this
wasn't straightforward. Operating at module scope means that conversions
added for return operations run on every return operation in the module
rather than just in the current function and this violates assumptions
in the pass: producing incorrect results. This doesn't effect normal
operation because the pass manager will always run the pass on a
specific top level operation not on a whole module. I have worked around
this by specifying the pass pipeline more specifically in the tests.

I expect most other passes will be able to keep their old fir-opt
interface.
@llvmbot
Copy link
Collaborator

llvmbot commented Apr 16, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Tom Eccles (tblah)

Changes

This is the first proof of concept of the modification of FIR lowering to fully support a variety of top level operations (beyond just func.func) proposed in
https://discourse.llvm.org/t/rfc-add-an-interface-for-top-level-container-operations

One unfortunate side-effect of this is that the new AbstractResult pass cannot be scheduled on a builtin.module operation and so we can't use
fir-opt --abstract-result < file.fir

I tried adding support for operating on a module to the pass, but this wasn't straightforward. Operating at module scope means that conversions added for return operations run on every return operation in the module rather than just in the current function and this violates assumptions in the pass: producing incorrect results. This doesn't effect normal operation because the pass manager will always run the pass on a specific top level operation not on a whole module. I have worked around this by specifying the pass pipeline more specifically in the tests.

I expect most other passes will be able to keep their old fir-opt interface.


Full diff: https://github.com/llvm/llvm-project/pull/88867.diff

12 Files Affected:

  • (modified) flang/include/flang/Optimizer/Dialect/FIROpsSupport.h (+7)
  • (modified) flang/include/flang/Optimizer/Transforms/Passes.h (+2-4)
  • (modified) flang/include/flang/Optimizer/Transforms/Passes.td (+3-10)
  • (modified) flang/include/flang/Tools/CLOptions.inc (+27-3)
  • (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+13)
  • (modified) flang/lib/Optimizer/Transforms/AbstractResult.cpp (+62-77)
  • (modified) flang/test/Driver/mlir-debug-pass-pipeline.f90 (+5-3)
  • (modified) flang/test/Driver/mlir-pass-pipeline.f90 (+5-3)
  • (modified) flang/test/Fir/abstract-result-2.fir (+1-1)
  • (modified) flang/test/Fir/abstract-results.fir (+4-4)
  • (modified) flang/test/Fir/basic-program.fir (+5-3)
  • (modified) flang/test/Fir/non-trivial-procedure-binding-description.f90 (+1-1)
diff --git a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
index 3266ea3aa7fdc6..44f2985e573785 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
+++ b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
@@ -173,6 +173,13 @@ inline mlir::NamedAttribute getAdaptToByRefAttr(Builder &builder) {
           builder.getUnitAttr()};
 }
 
+/// Returns true if the operation name is for a container operation expected to
+/// contain (HL)FIR operations which need to be lowered by FIR passes. The
+/// simplest example of this is func.func.
+/// This operates on mlir::RegisteredOperationName so that it can be used to
+/// implement mlir::Pass::canScheduleOn.
+bool isa_toplevel(mlir::RegisteredOperationName opName);
+
 } // namespace fir
 
 #endif // FORTRAN_OPTIMIZER_DIALECT_FIROPSSUPPORT_H
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index d8840d9e967b48..8520324e5491e1 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -31,8 +31,7 @@ namespace fir {
 // Passes defined in Passes.td
 //===----------------------------------------------------------------------===//
 
-#define GEN_PASS_DECL_ABSTRACTRESULTONFUNCOPT
-#define GEN_PASS_DECL_ABSTRACTRESULTONGLOBALOPT
+#define GEN_PASS_DECL_ABSTRACTRESULTOPT
 #define GEN_PASS_DECL_AFFINEDIALECTPROMOTION
 #define GEN_PASS_DECL_AFFINEDIALECTDEMOTION
 #define GEN_PASS_DECL_ANNOTATECONSTANTOPERANDS
@@ -50,8 +49,7 @@ namespace fir {
 #define GEN_PASS_DECL_OPENACCDATAOPERANDCONVERSION
 #include "flang/Optimizer/Transforms/Passes.h.inc"
 
-std::unique_ptr<mlir::Pass> createAbstractResultOnFuncOptPass();
-std::unique_ptr<mlir::Pass> createAbstractResultOnGlobalOptPass();
+std::unique_ptr<mlir::Pass> createAbstractResultOptPass();
 std::unique_ptr<mlir::Pass> createAffineDemotionPass();
 std::unique_ptr<mlir::Pass>
 createArrayValueCopyPass(fir::ArrayValueCopyOptions options = {});
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 187796d77cf5c1..06887091a1d3ac 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -16,8 +16,8 @@
 
 include "mlir/Pass/PassBase.td"
 
-class AbstractResultOptBase<string optExt, string operation> 
-  : Pass<"abstract-result-on-" # optExt # "-opt", operation> {
+def AbstractResultOpt
+  : Pass<"abstract-result"> {
   let summary = "Convert fir.array, fir.box and fir.rec function result to "
                 "function argument";
   let description = [{
@@ -33,14 +33,7 @@ class AbstractResultOptBase<string optExt, string operation>
            "Pass fir.array<T> result as fir.box<fir.array<T>> argument instead"
            " of fir.ref<fir.array<T>>.">
   ];
-}
-
-def AbstractResultOnFuncOpt : AbstractResultOptBase<"func", "mlir::func::FuncOp"> {
-  let constructor = "::fir::createAbstractResultOnFuncOptPass()";
-}
-
-def AbstractResultOnGlobalOpt : AbstractResultOptBase<"global", "fir::GlobalOp"> {
-  let constructor = "::fir::createAbstractResultOnGlobalOptPass()";
+  let constructor = "::fir::createAbstractResultOptPass()";
 }
 
 def AffineDialectPromotion : Pass<"promote-to-affine", "::mlir::func::FuncOp"> {
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 268d00b5a60535..2735a0944e8e9e 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -19,6 +19,7 @@
 #include "flang/Optimizer/Transforms/Passes.h"
 #include "llvm/Passes/OptimizationLevel.h"
 #include "llvm/Support/CommandLine.h"
+#include <type_traits>
 
 #define DisableOption(DOName, DOOption, DODescription) \
   static llvm::cl::opt<bool> disable##DOName("disable-" DOOption, \
@@ -86,6 +87,31 @@ DisableOption(BoxedProcedureRewrite, "boxed-procedure-rewrite",
 DisableOption(ExternalNameConversion, "external-name-interop",
     "convert names with external convention");
 
+// TODO: remove once these are used for non-codegen passes
+#if !defined(FLANG_EXCLUDE_CODEGEN)
+using PassConstructor = std::function<std::unique_ptr<mlir::Pass>()>;
+
+template <typename OP>
+void addNestedPassToOps(mlir::PassManager &pm, PassConstructor ctor) {
+  pm.addNestedPass<OP>(ctor());
+}
+
+template <typename OP, typename... OPS,
+    typename = std::enable_if_t<sizeof...(OPS) != 0>>
+void addNestedPassToOps(mlir::PassManager &pm, PassConstructor ctor) {
+  addNestedPassToOps<OP>(pm, ctor);
+  addNestedPassToOps<OPS...>(pm, ctor);
+}
+
+void addNestedPassToAllTopLevelOperations(
+    mlir::PassManager &pm, PassConstructor ctor) {
+  // TODO: add more operations that might need full lowering support
+  // any operations also need to be added to fir::isa_toplevel
+  addNestedPassToOps<mlir::func::FuncOp, mlir::omp::DeclareReductionOp,
+      fir::GlobalOp>(pm, ctor);
+}
+#endif
+
 /// Generic for adding a pass to the pass manager if it is not disabled.
 template <typename F>
 void addPassConditionally(
@@ -304,9 +330,7 @@ inline void createDebugPasses(
 inline void createDefaultFIRCodeGenPassPipeline(
     mlir::PassManager &pm, MLIRToLLVMPassPipelineConfig config) {
   fir::addBoxedProcedurePass(pm);
-  pm.addNestedPass<mlir::func::FuncOp>(
-      fir::createAbstractResultOnFuncOptPass());
-  pm.addNestedPass<fir::GlobalOp>(fir::createAbstractResultOnGlobalOptPass());
+  addNestedPassToAllTopLevelOperations(pm, fir::createAbstractResultOptPass);
   fir::addCodeGenRewritePass(pm);
   fir::addTargetRewritePass(pm);
   fir::addExternalNameConversionPass(pm, config.Underscoring);
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 88710880174d21..0bbbf59dbb352a 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3846,6 +3846,19 @@ std::optional<std::int64_t> fir::getIntIfConstant(mlir::Value value) {
   return {};
 }
 
+bool fir::isa_toplevel(mlir::RegisteredOperationName opName) {
+  const std::initializer_list<llvm::StringLiteral> topLevelOps{
+      fir::GlobalOp::getOperationName(),
+      mlir::func::FuncOp::getOperationName(),
+      mlir::omp::DeclareReductionOp::getOperationName(),
+  };
+
+  llvm::StringRef opStr = opName.getStringRef();
+  return llvm::any_of(topLevelOps, [&](const llvm::StringRef &topLevelOp) {
+    return opStr == topLevelOp;
+  });
+}
+
 mlir::Type fir::applyPathToType(mlir::Type eleTy, mlir::ValueRange path) {
   for (auto i = path.begin(), end = path.end(); eleTy && i < end;) {
     eleTy = llvm::TypeSwitch<mlir::Type, mlir::Type>(eleTy)
diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
index dd1ddd16f2ded5..e295694f84d3fc 100644
--- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp
+++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
@@ -21,8 +21,7 @@
 #include "llvm/ADT/TypeSwitch.h"
 
 namespace fir {
-#define GEN_PASS_DEF_ABSTRACTRESULTONFUNCOPT
-#define GEN_PASS_DEF_ABSTRACTRESULTONGLOBALOPT
+#define GEN_PASS_DEF_ABSTRACTRESULTOPT
 #include "flang/Optimizer/Transforms/Passes.h.inc"
 } // namespace fir
 
@@ -285,58 +284,8 @@ class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
   bool shouldBoxResult;
 };
 
-/// @brief Base CRTP class for AbstractResult pass family.
-/// Contains common logic for abstract result conversion in a reusable fashion.
-/// @tparam Pass target class that implements operation-specific logic.
-/// @tparam PassBase base class template for the pass generated by TableGen.
-/// The `Pass` class must define runOnSpecificOperation(OpTy, bool,
-/// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function.
-/// This function should implement operation-specific functionality.
-template <typename Pass, template <typename> class PassBase>
-class AbstractResultOptTemplate : public PassBase<Pass> {
-public:
-  void runOnOperation() override {
-    auto *context = &this->getContext();
-    auto op = this->getOperation();
-
-    mlir::RewritePatternSet patterns(context);
-    mlir::ConversionTarget target = *context;
-    const bool shouldBoxResult = this->passResultAsBox.getValue();
-
-    auto &self = static_cast<Pass &>(*this);
-    self.runOnSpecificOperation(op, shouldBoxResult, patterns, target);
-
-    // Convert the calls and, if needed,  the ReturnOp in the function body.
-    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
-                           mlir::func::FuncDialect>();
-    target.addIllegalOp<fir::SaveResultOp>();
-    target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
-      return !hasAbstractResult(call.getFunctionType());
-    });
-    target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
-      if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
-        return !hasAbstractResult(funTy);
-      return true;
-    });
-    target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
-      return !hasAbstractResult(dispatch.getFunctionType());
-    });
-
-    patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
-    patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
-    patterns.insert<SaveResultOpConversion>(context);
-    patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
-    if (mlir::failed(
-            mlir::applyPartialConversion(op, target, std::move(patterns)))) {
-      mlir::emitError(op.getLoc(), "error in converting abstract results\n");
-      this->signalPassFailure();
-    }
-  }
-};
-
-class AbstractResultOnFuncOpt
-    : public AbstractResultOptTemplate<AbstractResultOnFuncOpt,
-                                       fir::impl::AbstractResultOnFuncOptBase> {
+class AbstractResultOpt
+    : public fir::impl::AbstractResultOptBase<AbstractResultOpt> {
 public:
   void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
                               mlir::RewritePatternSet &patterns,
@@ -386,25 +335,20 @@ class AbstractResultOnFuncOpt
       }
     }
   }
-};
 
-inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
-  return mlir::TypeSwitch<mlir::Type, bool>(type)
-      .Case([](fir::BoxProcType boxProc) {
-        return fir::hasAbstractResult(
-            boxProc.getEleTy().cast<mlir::FunctionType>());
-      })
-      .Case([](fir::PointerType pointer) {
-        return fir::hasAbstractResult(
-            pointer.getEleTy().cast<mlir::FunctionType>());
-      })
-      .Default([](auto &&) { return false; });
-}
+  inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
+    return mlir::TypeSwitch<mlir::Type, bool>(type)
+        .Case([](fir::BoxProcType boxProc) {
+          return fir::hasAbstractResult(
+              boxProc.getEleTy().cast<mlir::FunctionType>());
+        })
+        .Case([](fir::PointerType pointer) {
+          return fir::hasAbstractResult(
+              pointer.getEleTy().cast<mlir::FunctionType>());
+        })
+        .Default([](auto &&) { return false; });
+  }
 
-class AbstractResultOnGlobalOpt
-    : public AbstractResultOptTemplate<
-          AbstractResultOnGlobalOpt, fir::impl::AbstractResultOnGlobalOptBase> {
-public:
   void runOnSpecificOperation(fir::GlobalOp global, bool,
                               mlir::RewritePatternSet &,
                               mlir::ConversionTarget &) {
@@ -412,14 +356,55 @@ class AbstractResultOnGlobalOpt
       TODO(global->getLoc(), "support for procedure pointers");
     }
   }
+
+  virtual bool canScheduleOn(RegisteredOperationName opName) const override {
+    return fir::isa_toplevel(opName);
+  }
+
+  void runOnOperation() override {
+    auto *context = &this->getContext();
+    mlir::Operation *op = this->getOperation();
+
+    mlir::RewritePatternSet patterns(context);
+    mlir::ConversionTarget target = *context;
+    const bool shouldBoxResult = this->passResultAsBox.getValue();
+
+    mlir::TypeSwitch<mlir::Operation *, void>(op)
+        .Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto op) {
+          runOnSpecificOperation(op, shouldBoxResult, patterns, target);
+        });
+
+    // Convert the calls and, if needed,  the ReturnOp in the function body.
+    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
+                           mlir::func::FuncDialect>();
+    target.addIllegalOp<fir::SaveResultOp>();
+    target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
+      return !hasAbstractResult(call.getFunctionType());
+    });
+    target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
+      if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
+        return !hasAbstractResult(funTy);
+      return true;
+    });
+    target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
+      return !hasAbstractResult(dispatch.getFunctionType());
+    });
+
+    patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
+    patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
+    patterns.insert<SaveResultOpConversion>(context);
+    patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
+    if (mlir::failed(
+            mlir::applyPartialConversion(op, target, std::move(patterns)))) {
+      mlir::emitError(op->getLoc(), "error in converting abstract results\n");
+      this->signalPassFailure();
+    }
+  }
 };
+
 } // end anonymous namespace
 } // namespace fir
 
-std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass() {
-  return std::make_unique<AbstractResultOnFuncOpt>();
-}
-
-std::unique_ptr<mlir::Pass> fir::createAbstractResultOnGlobalOptPass() {
-  return std::make_unique<AbstractResultOnGlobalOpt>();
+std::unique_ptr<mlir::Pass> fir::createAbstractResultOptPass() {
+  return std::make_unique<AbstractResultOpt>();
 }
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
index 04d432f854ca35..ef84cb80ecf1db 100644
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -72,11 +72,13 @@
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
 ! ALL-NEXT: BoxedProcedurePass
 
-! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func']
+! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
 ! ALL-NEXT:   'fir.global' Pipeline
-! ALL-NEXT:   AbstractResultOnGlobalOpt
+! ALL-NEXT:     AbstractResultOpt
 ! ALL-NEXT:   'func.func' Pipeline
-! ALL-NEXT:   AbstractResultOnFuncOpt
+! ALL-NEXT:     AbstractResultOpt
+! ALL-NEXT:   'omp.declare_reduction' Pipeline
+! ALL-NEXT:     AbstractResultOpt
 
 ! ALL-NEXT: CodeGenRewrite
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations eliminated
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index cfa0de63cde5e8..d1ff2869b0a6a9 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -67,11 +67,13 @@
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
 ! ALL-NEXT: BoxedProcedurePass
 
-! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func']
+! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
 ! ALL-NEXT:   'fir.global' Pipeline
-! ALL-NEXT:    AbstractResultOnGlobalOpt
+! ALL-NEXT:    AbstractResultOpt
 ! ALL-NEXT:  'func.func' Pipeline
-! ALL-NEXT:    AbstractResultOnFuncOpt
+! ALL-NEXT:    AbstractResultOpt
+! ALL-NEXT:  'omp.declare_reduction' Pipeline
+! ALL-NEXT:    AbstractResultOpt
 
 ! ALL-NEXT: CodeGenRewrite
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations eliminated
diff --git a/flang/test/Fir/abstract-result-2.fir b/flang/test/Fir/abstract-result-2.fir
index 08b723b8305936..d0cba7a9a63431 100644
--- a/flang/test/Fir/abstract-result-2.fir
+++ b/flang/test/Fir/abstract-result-2.fir
@@ -1,4 +1,4 @@
-// RUN: fir-opt %s --abstract-result-on-func-opt | FileCheck %s 
+// RUN: fir-opt %s --pass-pipeline="builtin.module(func.func(abstract-result))" | FileCheck %s
 
 // Check that the attributes are shifted along with their corresponding arguments
 
diff --git a/flang/test/Fir/abstract-results.fir b/flang/test/Fir/abstract-results.fir
index 42ff2a5c8eb2a8..4aac7f70d21039 100644
--- a/flang/test/Fir/abstract-results.fir
+++ b/flang/test/Fir/abstract-results.fir
@@ -1,10 +1,10 @@
 // Test rewrite of functions that return fir.array<>, fir.type<>, fir.box<> to
 // functions that take an additional argument for the result.
 
-// RUN: fir-opt %s --abstract-result-on-func-opt | FileCheck %s --check-prefix=FUNC-REF
-// RUN: fir-opt %s --abstract-result-on-func-opt=abstract-result-as-box | FileCheck %s --check-prefix=FUNC-BOX
-// RUN: fir-opt %s --abstract-result-on-global-opt | FileCheck %s --check-prefix=GLOBAL-REF
-// RUN: fir-opt %s --abstract-result-on-global-opt=abstract-result-as-box | FileCheck %s --check-prefix=GLOBAL-BOX
+// RUN: fir-opt %s --pass-pipeline="builtin.module(func.func(abstract-result))" | FileCheck %s --check-prefix=FUNC-REF
+// RUN: fir-opt %s --pass-pipeline="builtin.module(func.func(abstract-result{abstract-result-as-box}))" | FileCheck %s --check-prefix=FUNC-BOX
+// RUN: fir-opt %s --pass-pipeline="builtin.module(fir.global(abstract-result))" | FileCheck %s --check-prefix=GLOBAL-REF
+// RUN: fir-opt %s --pass-pipeline="builtin.module(fir.global(abstract-result{abstract-result-as-box}))" | FileCheck %s --check-prefix=GLOBAL-BOX
 
 // ----------------------- Test declaration rewrite ----------------------------
 
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 80d3520bc7f7d4..28c597fc918cd7 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -74,11 +74,13 @@ func.func @_QQmain() {
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
 // PASSES-NEXT: BoxedProcedurePass
 
-// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func']
+// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
 // PASSES-NEXT:   'fir.global' Pipeline
-// PASSES-NEXT:    AbstractResultOnGlobalOpt
+// PASSES-NEXT:    AbstractResultOpt
 // PASSES-NEXT:  'func.func' Pipeline
-// PASSES-NEXT:    AbstractResultOnFuncOpt
+// PASSES-NEXT:    AbstractResultOpt
+// PASSES-NEXT:  'omp.declare_reduction' Pipeline
+// PASSES-NEXT:    AbstractResultOpt
 
 // PASSES-NEXT: CodeGenRewrite
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations eliminated
diff --git a/flang/test/Fir/non-trivial-procedure-binding-description.f90 b/flang/test/Fir/non-trivial-procedure-binding-description.f90
index 695d7fdfe232d3..f59248961d2ea1 100644
--- a/flang/test/Fir/non-trivial-procedure-binding-description.f90
+++ b/flang/test/Fir/non-trivial-procedure-binding-description.f90
@@ -1,5 +1,5 @@
 ! RUN: %flang_fc1 -emit-mlir %s -o - | FileCheck %s --check-prefix=BEFORE
-! RUN: %flang_fc1 -emit-mlir %s -o - | fir-opt --abstract-result-on-global-opt | FileCheck %s --check-prefix=AFTER
+! RUN: %flang_fc1 -emit-mlir %s -o - | fir-opt --pass-pipeline="builtin.module(fir.global(abstract-result))" | FileCheck %s --check-prefix=AFTER
 module a
   type f
   contains

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 16, 2024

@llvm/pr-subscribers-flang-driver

Author: Tom Eccles (tblah)

Changes

This is the first proof of concept of the modification of FIR lowering to fully support a variety of top level operations (beyond just func.func) proposed in
https://discourse.llvm.org/t/rfc-add-an-interface-for-top-level-container-operations

One unfortunate side-effect of this is that the new AbstractResult pass cannot be scheduled on a builtin.module operation and so we can't use
fir-opt --abstract-result < file.fir

I tried adding support for operating on a module to the pass, but this wasn't straightforward. Operating at module scope means that conversions added for return operations run on every return operation in the module rather than just in the current function and this violates assumptions in the pass: producing incorrect results. This doesn't effect normal operation because the pass manager will always run the pass on a specific top level operation not on a whole module. I have worked around this by specifying the pass pipeline more specifically in the tests.

I expect most other passes will be able to keep their old fir-opt interface.


Full diff: https://github.com/llvm/llvm-project/pull/88867.diff

12 Files Affected:

  • (modified) flang/include/flang/Optimizer/Dialect/FIROpsSupport.h (+7)
  • (modified) flang/include/flang/Optimizer/Transforms/Passes.h (+2-4)
  • (modified) flang/include/flang/Optimizer/Transforms/Passes.td (+3-10)
  • (modified) flang/include/flang/Tools/CLOptions.inc (+27-3)
  • (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+13)
  • (modified) flang/lib/Optimizer/Transforms/AbstractResult.cpp (+62-77)
  • (modified) flang/test/Driver/mlir-debug-pass-pipeline.f90 (+5-3)
  • (modified) flang/test/Driver/mlir-pass-pipeline.f90 (+5-3)
  • (modified) flang/test/Fir/abstract-result-2.fir (+1-1)
  • (modified) flang/test/Fir/abstract-results.fir (+4-4)
  • (modified) flang/test/Fir/basic-program.fir (+5-3)
  • (modified) flang/test/Fir/non-trivial-procedure-binding-description.f90 (+1-1)
diff --git a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
index 3266ea3aa7fdc6..44f2985e573785 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
+++ b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
@@ -173,6 +173,13 @@ inline mlir::NamedAttribute getAdaptToByRefAttr(Builder &builder) {
           builder.getUnitAttr()};
 }
 
+/// Returns true if the operation name is for a container operation expected to
+/// contain (HL)FIR operations which need to be lowered by FIR passes. The
+/// simplest example of this is func.func.
+/// This operates on mlir::RegisteredOperationName so that it can be used to
+/// implement mlir::Pass::canScheduleOn.
+bool isa_toplevel(mlir::RegisteredOperationName opName);
+
 } // namespace fir
 
 #endif // FORTRAN_OPTIMIZER_DIALECT_FIROPSSUPPORT_H
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index d8840d9e967b48..8520324e5491e1 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -31,8 +31,7 @@ namespace fir {
 // Passes defined in Passes.td
 //===----------------------------------------------------------------------===//
 
-#define GEN_PASS_DECL_ABSTRACTRESULTONFUNCOPT
-#define GEN_PASS_DECL_ABSTRACTRESULTONGLOBALOPT
+#define GEN_PASS_DECL_ABSTRACTRESULTOPT
 #define GEN_PASS_DECL_AFFINEDIALECTPROMOTION
 #define GEN_PASS_DECL_AFFINEDIALECTDEMOTION
 #define GEN_PASS_DECL_ANNOTATECONSTANTOPERANDS
@@ -50,8 +49,7 @@ namespace fir {
 #define GEN_PASS_DECL_OPENACCDATAOPERANDCONVERSION
 #include "flang/Optimizer/Transforms/Passes.h.inc"
 
-std::unique_ptr<mlir::Pass> createAbstractResultOnFuncOptPass();
-std::unique_ptr<mlir::Pass> createAbstractResultOnGlobalOptPass();
+std::unique_ptr<mlir::Pass> createAbstractResultOptPass();
 std::unique_ptr<mlir::Pass> createAffineDemotionPass();
 std::unique_ptr<mlir::Pass>
 createArrayValueCopyPass(fir::ArrayValueCopyOptions options = {});
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 187796d77cf5c1..06887091a1d3ac 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -16,8 +16,8 @@
 
 include "mlir/Pass/PassBase.td"
 
-class AbstractResultOptBase<string optExt, string operation> 
-  : Pass<"abstract-result-on-" # optExt # "-opt", operation> {
+def AbstractResultOpt
+  : Pass<"abstract-result"> {
   let summary = "Convert fir.array, fir.box and fir.rec function result to "
                 "function argument";
   let description = [{
@@ -33,14 +33,7 @@ class AbstractResultOptBase<string optExt, string operation>
            "Pass fir.array<T> result as fir.box<fir.array<T>> argument instead"
            " of fir.ref<fir.array<T>>.">
   ];
-}
-
-def AbstractResultOnFuncOpt : AbstractResultOptBase<"func", "mlir::func::FuncOp"> {
-  let constructor = "::fir::createAbstractResultOnFuncOptPass()";
-}
-
-def AbstractResultOnGlobalOpt : AbstractResultOptBase<"global", "fir::GlobalOp"> {
-  let constructor = "::fir::createAbstractResultOnGlobalOptPass()";
+  let constructor = "::fir::createAbstractResultOptPass()";
 }
 
 def AffineDialectPromotion : Pass<"promote-to-affine", "::mlir::func::FuncOp"> {
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 268d00b5a60535..2735a0944e8e9e 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -19,6 +19,7 @@
 #include "flang/Optimizer/Transforms/Passes.h"
 #include "llvm/Passes/OptimizationLevel.h"
 #include "llvm/Support/CommandLine.h"
+#include <type_traits>
 
 #define DisableOption(DOName, DOOption, DODescription) \
   static llvm::cl::opt<bool> disable##DOName("disable-" DOOption, \
@@ -86,6 +87,31 @@ DisableOption(BoxedProcedureRewrite, "boxed-procedure-rewrite",
 DisableOption(ExternalNameConversion, "external-name-interop",
     "convert names with external convention");
 
+// TODO: remove once these are used for non-codegen passes
+#if !defined(FLANG_EXCLUDE_CODEGEN)
+using PassConstructor = std::function<std::unique_ptr<mlir::Pass>()>;
+
+template <typename OP>
+void addNestedPassToOps(mlir::PassManager &pm, PassConstructor ctor) {
+  pm.addNestedPass<OP>(ctor());
+}
+
+template <typename OP, typename... OPS,
+    typename = std::enable_if_t<sizeof...(OPS) != 0>>
+void addNestedPassToOps(mlir::PassManager &pm, PassConstructor ctor) {
+  addNestedPassToOps<OP>(pm, ctor);
+  addNestedPassToOps<OPS...>(pm, ctor);
+}
+
+void addNestedPassToAllTopLevelOperations(
+    mlir::PassManager &pm, PassConstructor ctor) {
+  // TODO: add more operations that might need full lowering support
+  // any operations also need to be added to fir::isa_toplevel
+  addNestedPassToOps<mlir::func::FuncOp, mlir::omp::DeclareReductionOp,
+      fir::GlobalOp>(pm, ctor);
+}
+#endif
+
 /// Generic for adding a pass to the pass manager if it is not disabled.
 template <typename F>
 void addPassConditionally(
@@ -304,9 +330,7 @@ inline void createDebugPasses(
 inline void createDefaultFIRCodeGenPassPipeline(
     mlir::PassManager &pm, MLIRToLLVMPassPipelineConfig config) {
   fir::addBoxedProcedurePass(pm);
-  pm.addNestedPass<mlir::func::FuncOp>(
-      fir::createAbstractResultOnFuncOptPass());
-  pm.addNestedPass<fir::GlobalOp>(fir::createAbstractResultOnGlobalOptPass());
+  addNestedPassToAllTopLevelOperations(pm, fir::createAbstractResultOptPass);
   fir::addCodeGenRewritePass(pm);
   fir::addTargetRewritePass(pm);
   fir::addExternalNameConversionPass(pm, config.Underscoring);
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 88710880174d21..0bbbf59dbb352a 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3846,6 +3846,19 @@ std::optional<std::int64_t> fir::getIntIfConstant(mlir::Value value) {
   return {};
 }
 
+bool fir::isa_toplevel(mlir::RegisteredOperationName opName) {
+  const std::initializer_list<llvm::StringLiteral> topLevelOps{
+      fir::GlobalOp::getOperationName(),
+      mlir::func::FuncOp::getOperationName(),
+      mlir::omp::DeclareReductionOp::getOperationName(),
+  };
+
+  llvm::StringRef opStr = opName.getStringRef();
+  return llvm::any_of(topLevelOps, [&](const llvm::StringRef &topLevelOp) {
+    return opStr == topLevelOp;
+  });
+}
+
 mlir::Type fir::applyPathToType(mlir::Type eleTy, mlir::ValueRange path) {
   for (auto i = path.begin(), end = path.end(); eleTy && i < end;) {
     eleTy = llvm::TypeSwitch<mlir::Type, mlir::Type>(eleTy)
diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
index dd1ddd16f2ded5..e295694f84d3fc 100644
--- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp
+++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
@@ -21,8 +21,7 @@
 #include "llvm/ADT/TypeSwitch.h"
 
 namespace fir {
-#define GEN_PASS_DEF_ABSTRACTRESULTONFUNCOPT
-#define GEN_PASS_DEF_ABSTRACTRESULTONGLOBALOPT
+#define GEN_PASS_DEF_ABSTRACTRESULTOPT
 #include "flang/Optimizer/Transforms/Passes.h.inc"
 } // namespace fir
 
@@ -285,58 +284,8 @@ class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
   bool shouldBoxResult;
 };
 
-/// @brief Base CRTP class for AbstractResult pass family.
-/// Contains common logic for abstract result conversion in a reusable fashion.
-/// @tparam Pass target class that implements operation-specific logic.
-/// @tparam PassBase base class template for the pass generated by TableGen.
-/// The `Pass` class must define runOnSpecificOperation(OpTy, bool,
-/// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function.
-/// This function should implement operation-specific functionality.
-template <typename Pass, template <typename> class PassBase>
-class AbstractResultOptTemplate : public PassBase<Pass> {
-public:
-  void runOnOperation() override {
-    auto *context = &this->getContext();
-    auto op = this->getOperation();
-
-    mlir::RewritePatternSet patterns(context);
-    mlir::ConversionTarget target = *context;
-    const bool shouldBoxResult = this->passResultAsBox.getValue();
-
-    auto &self = static_cast<Pass &>(*this);
-    self.runOnSpecificOperation(op, shouldBoxResult, patterns, target);
-
-    // Convert the calls and, if needed,  the ReturnOp in the function body.
-    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
-                           mlir::func::FuncDialect>();
-    target.addIllegalOp<fir::SaveResultOp>();
-    target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
-      return !hasAbstractResult(call.getFunctionType());
-    });
-    target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
-      if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
-        return !hasAbstractResult(funTy);
-      return true;
-    });
-    target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
-      return !hasAbstractResult(dispatch.getFunctionType());
-    });
-
-    patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
-    patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
-    patterns.insert<SaveResultOpConversion>(context);
-    patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
-    if (mlir::failed(
-            mlir::applyPartialConversion(op, target, std::move(patterns)))) {
-      mlir::emitError(op.getLoc(), "error in converting abstract results\n");
-      this->signalPassFailure();
-    }
-  }
-};
-
-class AbstractResultOnFuncOpt
-    : public AbstractResultOptTemplate<AbstractResultOnFuncOpt,
-                                       fir::impl::AbstractResultOnFuncOptBase> {
+class AbstractResultOpt
+    : public fir::impl::AbstractResultOptBase<AbstractResultOpt> {
 public:
   void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
                               mlir::RewritePatternSet &patterns,
@@ -386,25 +335,20 @@ class AbstractResultOnFuncOpt
       }
     }
   }
-};
 
-inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
-  return mlir::TypeSwitch<mlir::Type, bool>(type)
-      .Case([](fir::BoxProcType boxProc) {
-        return fir::hasAbstractResult(
-            boxProc.getEleTy().cast<mlir::FunctionType>());
-      })
-      .Case([](fir::PointerType pointer) {
-        return fir::hasAbstractResult(
-            pointer.getEleTy().cast<mlir::FunctionType>());
-      })
-      .Default([](auto &&) { return false; });
-}
+  inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
+    return mlir::TypeSwitch<mlir::Type, bool>(type)
+        .Case([](fir::BoxProcType boxProc) {
+          return fir::hasAbstractResult(
+              boxProc.getEleTy().cast<mlir::FunctionType>());
+        })
+        .Case([](fir::PointerType pointer) {
+          return fir::hasAbstractResult(
+              pointer.getEleTy().cast<mlir::FunctionType>());
+        })
+        .Default([](auto &&) { return false; });
+  }
 
-class AbstractResultOnGlobalOpt
-    : public AbstractResultOptTemplate<
-          AbstractResultOnGlobalOpt, fir::impl::AbstractResultOnGlobalOptBase> {
-public:
   void runOnSpecificOperation(fir::GlobalOp global, bool,
                               mlir::RewritePatternSet &,
                               mlir::ConversionTarget &) {
@@ -412,14 +356,55 @@ class AbstractResultOnGlobalOpt
       TODO(global->getLoc(), "support for procedure pointers");
     }
   }
+
+  virtual bool canScheduleOn(RegisteredOperationName opName) const override {
+    return fir::isa_toplevel(opName);
+  }
+
+  void runOnOperation() override {
+    auto *context = &this->getContext();
+    mlir::Operation *op = this->getOperation();
+
+    mlir::RewritePatternSet patterns(context);
+    mlir::ConversionTarget target = *context;
+    const bool shouldBoxResult = this->passResultAsBox.getValue();
+
+    mlir::TypeSwitch<mlir::Operation *, void>(op)
+        .Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto op) {
+          runOnSpecificOperation(op, shouldBoxResult, patterns, target);
+        });
+
+    // Convert the calls and, if needed,  the ReturnOp in the function body.
+    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
+                           mlir::func::FuncDialect>();
+    target.addIllegalOp<fir::SaveResultOp>();
+    target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
+      return !hasAbstractResult(call.getFunctionType());
+    });
+    target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
+      if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
+        return !hasAbstractResult(funTy);
+      return true;
+    });
+    target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
+      return !hasAbstractResult(dispatch.getFunctionType());
+    });
+
+    patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
+    patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
+    patterns.insert<SaveResultOpConversion>(context);
+    patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
+    if (mlir::failed(
+            mlir::applyPartialConversion(op, target, std::move(patterns)))) {
+      mlir::emitError(op->getLoc(), "error in converting abstract results\n");
+      this->signalPassFailure();
+    }
+  }
 };
+
 } // end anonymous namespace
 } // namespace fir
 
-std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass() {
-  return std::make_unique<AbstractResultOnFuncOpt>();
-}
-
-std::unique_ptr<mlir::Pass> fir::createAbstractResultOnGlobalOptPass() {
-  return std::make_unique<AbstractResultOnGlobalOpt>();
+std::unique_ptr<mlir::Pass> fir::createAbstractResultOptPass() {
+  return std::make_unique<AbstractResultOpt>();
 }
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
index 04d432f854ca35..ef84cb80ecf1db 100644
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -72,11 +72,13 @@
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
 ! ALL-NEXT: BoxedProcedurePass
 
-! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func']
+! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
 ! ALL-NEXT:   'fir.global' Pipeline
-! ALL-NEXT:   AbstractResultOnGlobalOpt
+! ALL-NEXT:     AbstractResultOpt
 ! ALL-NEXT:   'func.func' Pipeline
-! ALL-NEXT:   AbstractResultOnFuncOpt
+! ALL-NEXT:     AbstractResultOpt
+! ALL-NEXT:   'omp.declare_reduction' Pipeline
+! ALL-NEXT:     AbstractResultOpt
 
 ! ALL-NEXT: CodeGenRewrite
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations eliminated
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index cfa0de63cde5e8..d1ff2869b0a6a9 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -67,11 +67,13 @@
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
 ! ALL-NEXT: BoxedProcedurePass
 
-! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func']
+! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
 ! ALL-NEXT:   'fir.global' Pipeline
-! ALL-NEXT:    AbstractResultOnGlobalOpt
+! ALL-NEXT:    AbstractResultOpt
 ! ALL-NEXT:  'func.func' Pipeline
-! ALL-NEXT:    AbstractResultOnFuncOpt
+! ALL-NEXT:    AbstractResultOpt
+! ALL-NEXT:  'omp.declare_reduction' Pipeline
+! ALL-NEXT:    AbstractResultOpt
 
 ! ALL-NEXT: CodeGenRewrite
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations eliminated
diff --git a/flang/test/Fir/abstract-result-2.fir b/flang/test/Fir/abstract-result-2.fir
index 08b723b8305936..d0cba7a9a63431 100644
--- a/flang/test/Fir/abstract-result-2.fir
+++ b/flang/test/Fir/abstract-result-2.fir
@@ -1,4 +1,4 @@
-// RUN: fir-opt %s --abstract-result-on-func-opt | FileCheck %s 
+// RUN: fir-opt %s --pass-pipeline="builtin.module(func.func(abstract-result))" | FileCheck %s
 
 // Check that the attributes are shifted along with their corresponding arguments
 
diff --git a/flang/test/Fir/abstract-results.fir b/flang/test/Fir/abstract-results.fir
index 42ff2a5c8eb2a8..4aac7f70d21039 100644
--- a/flang/test/Fir/abstract-results.fir
+++ b/flang/test/Fir/abstract-results.fir
@@ -1,10 +1,10 @@
 // Test rewrite of functions that return fir.array<>, fir.type<>, fir.box<> to
 // functions that take an additional argument for the result.
 
-// RUN: fir-opt %s --abstract-result-on-func-opt | FileCheck %s --check-prefix=FUNC-REF
-// RUN: fir-opt %s --abstract-result-on-func-opt=abstract-result-as-box | FileCheck %s --check-prefix=FUNC-BOX
-// RUN: fir-opt %s --abstract-result-on-global-opt | FileCheck %s --check-prefix=GLOBAL-REF
-// RUN: fir-opt %s --abstract-result-on-global-opt=abstract-result-as-box | FileCheck %s --check-prefix=GLOBAL-BOX
+// RUN: fir-opt %s --pass-pipeline="builtin.module(func.func(abstract-result))" | FileCheck %s --check-prefix=FUNC-REF
+// RUN: fir-opt %s --pass-pipeline="builtin.module(func.func(abstract-result{abstract-result-as-box}))" | FileCheck %s --check-prefix=FUNC-BOX
+// RUN: fir-opt %s --pass-pipeline="builtin.module(fir.global(abstract-result))" | FileCheck %s --check-prefix=GLOBAL-REF
+// RUN: fir-opt %s --pass-pipeline="builtin.module(fir.global(abstract-result{abstract-result-as-box}))" | FileCheck %s --check-prefix=GLOBAL-BOX
 
 // ----------------------- Test declaration rewrite ----------------------------
 
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 80d3520bc7f7d4..28c597fc918cd7 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -74,11 +74,13 @@ func.func @_QQmain() {
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
 // PASSES-NEXT: BoxedProcedurePass
 
-// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func']
+// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction']
 // PASSES-NEXT:   'fir.global' Pipeline
-// PASSES-NEXT:    AbstractResultOnGlobalOpt
+// PASSES-NEXT:    AbstractResultOpt
 // PASSES-NEXT:  'func.func' Pipeline
-// PASSES-NEXT:    AbstractResultOnFuncOpt
+// PASSES-NEXT:    AbstractResultOpt
+// PASSES-NEXT:  'omp.declare_reduction' Pipeline
+// PASSES-NEXT:    AbstractResultOpt
 
 // PASSES-NEXT: CodeGenRewrite
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations eliminated
diff --git a/flang/test/Fir/non-trivial-procedure-binding-description.f90 b/flang/test/Fir/non-trivial-procedure-binding-description.f90
index 695d7fdfe232d3..f59248961d2ea1 100644
--- a/flang/test/Fir/non-trivial-procedure-binding-description.f90
+++ b/flang/test/Fir/non-trivial-procedure-binding-description.f90
@@ -1,5 +1,5 @@
 ! RUN: %flang_fc1 -emit-mlir %s -o - | FileCheck %s --check-prefix=BEFORE
-! RUN: %flang_fc1 -emit-mlir %s -o - | fir-opt --abstract-result-on-global-opt | FileCheck %s --check-prefix=AFTER
+! RUN: %flang_fc1 -emit-mlir %s -o - | fir-opt --pass-pipeline="builtin.module(fir.global(abstract-result))" | FileCheck %s --check-prefix=AFTER
 module a
   type f
   contains

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! This looks like I expected it would. I am bit surprised that the --abstract-result syntax doesn't work as I was under the impression that I'd automatically try and find nested ops it could run on, but maybe I am wrong. Doesn't really matter for this patch.

One question I have looking at the code, it seems the pass can also run on omp.declare_reduction now. Is this intentional and if yes, should there be a test exercising this?

flang/include/flang/Tools/CLOptions.inc Outdated Show resolved Hide resolved
Co-authored-by: Markus Böck <markus.boeck02@gmail.com>
@tblah
Copy link
Contributor Author

tblah commented Apr 16, 2024

Thanks for taking a look

One question I have looking at the code, it seems the pass can also run on omp.declare_reduction now. Is this intentional and if yes, should there be a test exercising this?

We don't currently need this pass to run on omp.declare_reduction. I included that operation because my intention is for as many passes as possible to share this same infrastructure. Some other passes do need to run on omp.declare_reduction, for example, CFGConversion. I don't want to set a precedent that every combination of pass and top level operation has to be tested because there are a very large number of passes so that would create a lot of work to add a new top level operation (say omp.private). Most of these top level operations will only ever have very specific code constructs inside of them (e.g. fir.global, omp.*) so I think it is easier to test only specifically those cases.

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM but a Flang person should probably approve as well

@@ -50,8 +49,7 @@ namespace fir {
#define GEN_PASS_DECL_OPENACCDATAOPERANDCONVERSION
#include "flang/Optimizer/Transforms/Passes.h.inc"

std::unique_ptr<mlir::Pass> createAbstractResultOnFuncOptPass();
std::unique_ptr<mlir::Pass> createAbstractResultOnGlobalOptPass();
std::unique_ptr<mlir::Pass> createAbstractResultOptPass();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This declaration can be auto-generated by TableGen by just removing the
let constructor = ... line in the Pass entry. This will also automatically create an overload that can be constructed with the option object.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this. I'll update the flang passes as I go along

Copy link
Contributor

@jeanPerier jeanPerier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Tom for cleaning this up! Making this more modular makes sense to me.
Question inlined because I do not get why both addNested<Op> and canBeScheduled are needed here (but I am not an expert in pass pipeline at all).

Regarding the ReturnOp issue you saw, is it related to the weird way the ReturnOp op pattern is inserted (only inserted when the top level op is a function)? Maybe the pattern could just always be registered (just unused if there is not ReturnOp).

pm.addNestedPass<mlir::func::FuncOp>(
fir::createAbstractResultOnFuncOptPass());
pm.addNestedPass<fir::GlobalOp>(fir::createAbstractResultOnGlobalOptPass());
addNestedPassToAllTopLevelOperations(pm, fir::createAbstractResultOptPass);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't using pm.nestAny(fir::createAbstractResultOptPass()) work?

It seems to me that the using both addNestedPass<OP> and canScheduleOn to restrict what this pass is being run on is a bit redundant (and probably less optimal it it means that the ModuleOp operations need to be walked three times to schedule the passes instead of scheduling it in a single pass, although I do not know the pass scheduling details enough to be sure how the module would be walked in both cases).

Although, maybe the reverse is better from a conceptual point of view: canScheduleOn could be removed from the pass definition. There is no conceptual aspects of the pass that restrict it from running on any operation that may contain FIR calls I think, so it would make sense to me that the pipeline is the only place describing which top level operations needs to be translated here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for review.

I think the idea is that addNested<OP> could work on a subset of the operations supported by the pass (reported by canBeScheduled). Conceptually, addNested<OP> means you want to run the pass on this operation in this particular pipeline. But a different pipeline could be constructed which tries to use the pass on a different operation type (e.g. fir-opt --abstract-result module.mlir). canScheduleOn guards against running pipelines on operation types which the pass is not intended for.

We have to implement canScheduleOn because it is pure virtual in mlir::Pass. Even if there were a way to tell the pass manager to run this pass on every operation on which the pass is supported, this would have to be implemented by calling canScheduleOn with every operation and then only scheduling the pass on supported operations. Unfortunately, canScheduleOn is implemented with string comparisons (it is defined as always taking a RegisteredOperationName argument) so I would prefer the bit of duplication so that we can limit these string comparisons.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

canScheduleOn guards against running pipelines on operation types which the pass is not intended for.

In that case, I think it should be runnable on any op (the ReturnOp handling prevents that currently since you noticed it does not work on ModuleOp).

We have to implement canScheduleOn because it is pure virtual in mlir::Pass.

You do not need to in that case because AbstractResultOptPass actually inherits from mlir::OperationPass<> that defines canScheduleOn here in a way that it would say true to any operation it is being schedule on. The .td Pass<> syntax actually creates a pass inheriting from mlir::OperationPass<>, not directly from mlir::Pass (see here).

So all in all, I am OK with your patch, it is an improvement, and it could be further improved by modifying the ReturnOp handling and removing the canScheduleOn restriction.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh I had missunderstood the tablegen change I made. I thought it made it inherit mlir::Pass, but you are right it is mlir::OperationPass<>.

Thanks for explaining. I will see if I can create a nested pass pipeline inside the AbstractResultPass so that even when run on a module it behaves the same way as if you ran the old function and global passes. And I agree that the canScheduleOn should be removed if possible.

@tblah
Copy link
Contributor Author

tblah commented Apr 18, 2024

Regarding the ReturnOp issue you saw, is it related to the weird way the ReturnOp op pattern is inserted (only inserted when the top level op is a function)? Maybe the pattern could just always be registered (just unused if there is not ReturnOp).

I think it is the difference in scope. For a pass scheduled on a function, creating a pattern for every ReturnOp will only match return operations inside of that function. If the pass is invoked on a module, the pattern will match every ReturnOp in that module. The way the pass is currently written, it uses different legality checks for the return operation depending on the prototype of the function being processed by the pass, so running it at the module scope adds several conflicting operation legality checks for return operations and each tries to process every return operation in the module.

The pass could be fixed to work correctly at module scope, or to create its own internal nested pipeline running on each function. I preferred not to modify the pass too much in this commit so the changes didn't get confusing to review. Let me know if you would prefer I make sure fir-opt --abstract-result works before merging this.

Personally, I don't think the exact fir-opt invocation is too important, especially in this case where it has to change anyway (previously it would have been fir-opt --abstract-result-on-func).

@tblah
Copy link
Contributor Author

tblah commented Apr 19, 2024

Thanks for the review. The pass should work on module operations now

Copy link
Contributor

@jeanPerier jeanPerier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks

@tblah tblah merged commit bfd1944 into llvm:main Apr 22, 2024
4 checks passed
tblah added a commit that referenced this pull request Apr 24, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:driver flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants