-
Notifications
You must be signed in to change notification settings - Fork 15.5k
Enable pass instrumentation to signal failures. #163126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Jacques Pienaar (jpienaar) ChangesEnables adding instrumentation to pass manager that can track/flag invariants. This would be useful for cases where one some tighter requirements than the general dialects or for a phase of conversion that elsewhere. It would enable making verify also just a regular instrumentation I believe, but also a non-goal as that is a first class concept and baseline for the ops and passes. Would have enabled some of the requirements of https://discourse.llvm.org/t/pre-verification-logic-before-running-conversion-pass-in-mlir/88318/10 . Full diff: https://github.com/llvm/llvm-project/pull/163126.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 16893c6db87b1..f0b0979a81ee3 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -17,6 +17,7 @@
#include <optional>
namespace mlir {
+class PassInstrumentation;
namespace detail {
class OpToOpPassAdaptor;
struct OpPassManagerImpl;
@@ -334,6 +335,9 @@ class Pass {
/// Allow access to 'passOptions'.
friend class PassInfo;
+
+ /// Allow access to 'signalPassFailure'.
+ friend class PassInstrumentation;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Pass/PassInstrumentation.h b/mlir/include/mlir/Pass/PassInstrumentation.h
index 917bac4b22288..25a8e77be75ee 100644
--- a/mlir/include/mlir/Pass/PassInstrumentation.h
+++ b/mlir/include/mlir/Pass/PassInstrumentation.h
@@ -80,6 +80,8 @@ class PassInstrumentation {
/// name of the analysis that was computed, its TypeID, as well as the
/// current operation being analyzed.
virtual void runAfterAnalysis(StringRef name, TypeID id, Operation *op) {}
+
+ static void signalPassFailure(Pass *pass);
};
/// This class holds a collection of PassInstrumentation objects, and invokes
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 521c7c6be17b6..17ac475b42f4b 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -599,17 +599,20 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
if (pi)
pi->runBeforePass(pass, op);
- bool passFailed = false;
- op->getContext()->executeAction<PassExecutionAction>(
- [&]() {
- // Invoke the virtual runOnOperation method.
- if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
- adaptor->runOnOperation(verifyPasses);
- else
- pass->runOnOperation();
- passFailed = pass->passState->irAndPassFailed.getInt();
- },
- {op}, *pass);
+ bool passFailed = pass->passState->irAndPassFailed.getInt();
+ if (!passFailed) {
+ op->getContext()->executeAction<PassExecutionAction>(
+ [&]() {
+ // Invoke the virtual runOnOperation method.
+ if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
+ adaptor->runOnOperation(verifyPasses);
+ else
+ pass->runOnOperation();
+ passFailed = pass->passState->irAndPassFailed.getInt();
+ },
+ {op}, *pass);
+ }
+
// Invalidate any non preserved analyses.
am.invalidate(pass->passState->preservedAnalyses);
@@ -640,10 +643,12 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
// Instrument after the pass has run.
if (pi) {
- if (passFailed)
+ if (passFailed) {
pi->runAfterPassFailed(pass, op);
- else
+ } else {
pi->runAfterPass(pass, op);
+ passFailed = passFailed || pass->passState->irAndPassFailed.getInt();
+ }
}
// Return if the pass signaled a failure.
@@ -1198,6 +1203,8 @@ void PassInstrumentation::runBeforePipeline(
void PassInstrumentation::runAfterPipeline(
std::optional<OperationName> name, const PipelineParentInfo &parentInfo) {}
+void PassInstrumentation::signalPassFailure(Pass *pass) { pass->signalPassFailure(); }
+
//===----------------------------------------------------------------------===//
// PassInstrumentor
//===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
index 7e618811eabf4..86c793384db11 100644
--- a/mlir/unittests/Pass/PassManagerTest.cpp
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -14,6 +14,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassInstrumentation.h"
#include "gtest/gtest.h"
#include <memory>
@@ -117,6 +118,103 @@ struct AddSecondAttrFunctionPass
}
};
+/// PassInstrumentation to count pass callbacks and signal pass failures.
+struct TestPassInstrumentation : public PassInstrumentation {
+ int beforePassCallbackCount = 0;
+ int afterPassCallbackCount = 0;
+ int afterPassFailedCallbackCount = 0;
+
+ bool failBeforePass = false;
+ bool failAfterPass = false;
+
+ void runBeforePass(Pass *pass, Operation *op) override {
+ if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>()) return;
+
+ ++beforePassCallbackCount;
+ if (failBeforePass)
+ signalPassFailure(pass);
+ }
+ void runAfterPass(Pass *pass, Operation *op) override {
+ if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>()) return;
+
+ ++afterPassCallbackCount;
+ if (failAfterPass)
+ signalPassFailure(pass);
+ }
+ void runAfterPassFailed(Pass *pass, Operation *op) override {
+ if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>()) return;
+
+ ++afterPassFailedCallbackCount;
+ }
+};
+
+TEST(PassManagerTest, PassInstrumentation) {
+ MLIRContext context;
+ context.loadDialect<func::FuncDialect>();
+ Builder b(&context);
+
+ // Create a module with 1 function.
+ OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
+ auto func = func::FuncOp::create(b.getUnknownLoc(), "test_func",
+ b.getFunctionType({}, {}));
+ func.setPrivate();
+ module->push_back(func);
+
+ struct InstrumentationCounts {
+ int beforePass;
+ int afterPass;
+ int afterPassFailed;
+ };
+
+ auto runInstrumentation =
+ [&](bool failBefore,
+ bool failAfter) -> std::pair<LogicalResult, InstrumentationCounts> {
+ // Instantiate and run our pass.
+ auto pm = PassManager::on<ModuleOp>(&context);
+ auto instrumentation = std::make_unique<TestPassInstrumentation>();
+ auto *instrumentationPtr = instrumentation.get();
+ instrumentation->failBeforePass = failBefore;
+ instrumentation->failAfterPass = failAfter;
+ pm.addInstrumentation(std::move(instrumentation));
+ pm.addNestedPass<func::FuncOp>(std::make_unique<AddAttrFunctionPass>());
+ LogicalResult result = pm.run(module.get());
+
+ InstrumentationCounts counts = {
+ .beforePass = instrumentationPtr->beforePassCallbackCount,
+ .afterPass = instrumentationPtr->afterPassCallbackCount,
+ .afterPassFailed = instrumentationPtr->afterPassFailedCallbackCount};
+ return {result, counts};
+ };
+
+ for (bool failBefore : {false, true}) {
+ for (bool failAfter : {false, true}) {
+ auto [result, counts] = runInstrumentation(failBefore, failAfter);
+
+ InstrumentationCounts expected;
+ if (failBefore) {
+ EXPECT_TRUE(failed(result))
+ << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+ expected = {.beforePass = 1, .afterPass = 0, .afterPassFailed = 1};
+ } else if (failAfter) {
+ EXPECT_TRUE(failed(result))
+ << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+ expected = {.beforePass = 1, .afterPass = 1, .afterPassFailed = 0};
+ } else {
+ EXPECT_TRUE(succeeded(result))
+ << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+ expected = {.beforePass = 1, .afterPass = 1, .afterPassFailed = 0};
+ }
+
+ EXPECT_EQ(counts.beforePass, expected.beforePass)
+ << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+ EXPECT_EQ(counts.afterPass, expected.afterPass)
+ << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+ EXPECT_EQ(counts.afterPassFailed, expected.afterPassFailed)
+ << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+ }
+ }
+}
+
TEST(PassManagerTest, ExecutionAction) {
MLIRContext context;
context.loadDialect<func::FuncDialect>();
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
| /// current operation being analyzed. | ||
| virtual void runAfterAnalysis(StringRef name, TypeID id, Operation *op) {} | ||
|
|
||
| static void signalPassFailure(Pass *pass); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add documentation for this method, also should this be a static method or an instance method? I know it doesn't have to be an instance method but that would help keep the scope of API exposure slimmer (otherwise, should we just make signalPassFailure public?)
68d3faa to
b754842
Compare
🐧 Linux x64 Test Results
✅ The build succeeded and all tests passed. |
Enables adding instrumentation to pass manager that can track/flag invariants. This would be useful for cases where one some tighter requirements than the general dialects or for a phase of conversion that elsewhere.
It would enable making verify also just a regular instrumentation I believe, but also a non-goal as that is a first class concept and baseline for the ops and passes.
Would have enabled some of the requirements of https://discourse.llvm.org/t/pre-verification-logic-before-running-conversion-pass-in-mlir/88318/10 .