-
Notifications
You must be signed in to change notification settings - Fork 12.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
MLIR: add flag to conditionally disable automatic dialect loading #120100
Conversation
|
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: William Moses (wsmoses) ChangesFull diff: https://github.com/llvm/llvm-project/pull/120100.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index d9bab431e2e0cc..235078bbe0b225 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -274,6 +274,9 @@ class PassManager : public OpPassManager {
/// Runs the verifier after each individual pass.
void enableVerifier(bool enabled = true);
+ /// Whether dependent dialects should be automatically loaded.
+ void setAutomaticDialectLoading(bool shouldLoad);
+
//===--------------------------------------------------------------------===//
// Instrumentations
//===--------------------------------------------------------------------===//
@@ -354,6 +357,9 @@ class PassManager : public OpPassManager {
/// Flags to control printing behavior.
OpPrintingFlags opPrintingFlags;
+
+ /// A flag to disable dependent dialect registration.
+ bool loadDialects;
};
/// Add an instrumentation to print the IR before and after pass execution,
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 6fd51c1e3cb538..fd3798652ec31c 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -853,11 +853,13 @@ LogicalResult PassManager::run(Operation *op) {
<< op->getName() << "' op";
// Register all dialects for the current pipeline.
- DialectRegistry dependentDialects;
- getDependentDialects(dependentDialects);
- context->appendDialectRegistry(dependentDialects);
- for (StringRef name : dependentDialects.getDialectNames())
- context->getOrLoadDialect(name);
+ if (loadDialects) {
+ DialectRegistry dependentDialects;
+ getDependentDialects(dependentDialects);
+ context->appendDialectRegistry(dependentDialects);
+ for (StringRef name : dependentDialects.getDialectNames())
+ context->getOrLoadDialect(name);
+ }
// Before running, make sure to finalize the pipeline pass list.
if (failed(getImpl().finalizePassList(context)))
@@ -893,6 +895,11 @@ LogicalResult PassManager::run(Operation *op) {
return result;
}
+
+void PassManager::setAutomaticDialectLoading(bool shouldLoad) {
+ loadDialects = shouldLoad;
+}
+
/// Add the provided instrumentation to the pass manager.
void PassManager::addInstrumentation(std::unique_ptr<PassInstrumentation> pi) {
if (!instrumentor)
|
You can test this locally with the following command:git-clang-format --diff d576021853fd64c10fd746389a9b263cf10c5295 5dfbfc6a976b554c62483a201b1263dd4c8d17fe --extensions cpp,h -- mlir/include/mlir/Pass/PassManager.h mlir/lib/Pass/Pass.cppView the diff from clang-format here.diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index fd3798652e..7f6a56e66e 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -895,7 +895,6 @@ LogicalResult PassManager::run(Operation *op) {
return result;
}
-
void PassManager::setAutomaticDialectLoading(bool shouldLoad) {
loadDialects = shouldLoad;
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please include motivation in the description. I don't quite get why we'd want that right now.
|
I was just about to flag this to you per @ftynse's suggestion :) The broader context is I'm making a system for automatically compiling and inserting custom kernels into xla as native MLIR, see EnzymeAD/Enzyme-JAX#191 So instead of an optimization barrier of a regular stablehlo.customcall to a magic unknown address you can have something like the following (which can be nvvm/gpu dialect/etc): Of course when you actually want to run this eventually someone needs to do the jit and make a pointer with the custom kernel. So that is what our LowerPass is trying to do. In particular, we want to run MLIR's GPU codegen, which is accessible via Unfortunately that quickly hits some dialect loading hell. The error I get on my machine (but not @ftynse's) is: Where this is happening in this registration code of the pass pipeline, specifically for this I believe
I also tried doing the module for codegen in an entirely new MLIRContext [that wouldn't hit this same issue], but that hit other errors in dialect loading (specifically not finding a llvmir translation for llvm.mlir.addressof). Since the "use the same context" approach works for Alex's machine and doesn't hit this issue (which I hit), I'm hoping that adding a flag to the passmanager to avoid making the double registration will help, or at least give a more informative error message than crashing. |
This feels like the infra correctly identifying when things can go very wrong. The general way that this is supposed to work is that you create the pipeline you want to run, and then call into it inside of your Pass::getDependentDialects, e.g.
|
|
@River707 yup that resolved, TIL! closing |
No description provided.