Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][async] Avoid crash when not using func.func #72801

Merged
merged 6 commits into from
Nov 21, 2023
Merged

[mlir][async] Avoid crash when not using func.func #72801

merged 6 commits into from
Nov 21, 2023

Conversation

rikhuijzer
Copy link
Member

@rikhuijzer rikhuijzer commented Nov 19, 2023

The createParallelComputeFunction crashed when calling getFunctionTypeAttrName during the creation of a new FuncOp inside the pass. The problem is that getFunctionTypeAttrName looks up the attribute name for the function type which in this case is func.func. However, name.getAttributeNames() was empty when clients used llvm.func instead of func.func.

To fix this, the func dialect is now registered as a dependent dialect. Also, I've added an assertion which could save other people some time.

Fixes #71281, fixes #64326.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:async labels Nov 19, 2023
@llvmbot
Copy link
Collaborator

llvmbot commented Nov 19, 2023

@llvm/pr-subscribers-mlir-async
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Rik Huijzer (rikhuijzer)

Changes

The createParallelComputeFunction crashed when calling getFunctionTypeAttrName during the creation of a new FuncOp inside the pass. The problem is that getFunctionTypeAttrName looks up the attribute name for the function type which in this case is func.func. However, name.getAttributeNames() was empty when clients used llvm.func instead of func.func.

To fix this, the func dialect is now registered as a dependent dialect. Also, I've added an assertion which could save other people some time.


Full diff: https://github.com/llvm/llvm-project/pull/72801.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp (+4)
  • (modified) mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir (+19)
  • (modified) mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp (+2)
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 12a28c2e23b221a..639bc7f9ec7f112 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -102,6 +102,10 @@ struct AsyncParallelForPass
     : public impl::AsyncParallelForBase<AsyncParallelForPass> {
   AsyncParallelForPass() = default;
 
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<async::AsyncDialect, func::FuncDialect>();
+  }
+
   AsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
                        int32_t minTaskSize) {
     this->asyncDispatch = asyncDispatch;
diff --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
index 2115b1881fa6d66..fa3b53dd839c6c6 100644
--- a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
@@ -69,6 +69,25 @@ func.func @sink_constant_step(%arg0: memref<?xf32>, %lb: index, %ub: index) {
 
 // -----
 
+// Smoke test that parallel for doesn't crash when func dialect is not loaded.
+
+// CHECK-LABEL: llvm.func @without_func_dialect()
+llvm.func @without_func_dialect() {
+  %cst = arith.constant 0.0 : f32
+
+  %c0 = arith.constant 0 : index
+  %c22 = arith.constant 22 : index
+  %c1 = arith.constant 1 : index
+  %54 = memref.alloc() : memref<22xf32>
+  %alloc_4 = memref.alloc() : memref<22xf32>
+  scf.parallel (%arg0) = (%c0) to (%c22) step (%c1) {
+    memref.store %cst, %alloc_4[%arg0] : memref<22xf32>
+  }
+  llvm.return
+}
+
+// -----
+
 // Check that for statically known inner loop bound block size is aligned and
 // inner loop uses statically known loop trip counts.
 
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 842964b853d084d..963c52fd4191657 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1143,6 +1143,8 @@ void OpEmitter::genAttrNameGetters() {
       const char *const getAttrName = R"(
   assert(index < {0} && "invalid attribute index");
   assert(name.getStringRef() == getOperationName() && "invalid operation name");
+  assert(!name.getAttributeNames().empty() && "empty attribute names. Is a new "
+         "op created without having initialized its dialect?");
   return name.getAttributeNames()[index];
 )";
       method->body() << formatv(getAttrName, attributes.size());

@@ -1143,6 +1143,8 @@ void OpEmitter::genAttrNameGetters() {
const char *const getAttrName = R"(
assert(index < {0} && "invalid attribute index");
assert(name.getStringRef() == getOperationName() && "invalid operation name");
assert(!name.getAttributeNames().empty() && "empty attribute names. Is a new "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if it is correct for getAttributeNames to return an empty list if the dialect is not loaded. The documentation of getAttributeNames talks about caching and maybe this should have fallen back to OpTy::getAttributeNames? Maybe @joker-eph knows.

Copy link
Member Author

@rikhuijzer rikhuijzer Nov 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering about that too. The attribute names are all known at compile time, so you would guess that getFunctionTypeAttrName (which figures out the name of the attribute function_type) can directly return the result. However, getFunctionTypeAttrName returns not only the name but also some context. I've tried passing the name as string instead of an OperationName, but that will crash since, if I remember correctly, I got an error saying that FuncOp could not be created since it was an unregistered op. So the fallback to (static?) OpTy::getAttributeNames is probably not gonna work.

Returning a nullptr for getAttributeNames sounds like something that may work.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since fixing this would require a separate PR anyway, I've opened an issue for this: #72900.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect a better assert here may be assert(name.isRegistered() && "Operation isn't registered, missing a dependent dialect loading?");

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good tip. Thanks! I've implemented it and verified that it is triggered when reverting the fix from this PR (func::FuncDialect in dependentDialects).

@joker-eph
Copy link
Collaborator

FYI the build is failing (maybe re-running would be enough though)

@rikhuijzer rikhuijzer merged commit 88f0e4c into llvm:main Nov 21, 2023
3 checks passed
@rikhuijzer rikhuijzer deleted the rh/llvm-func-async-parallel-for branch November 21, 2023 07:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:async mlir:core MLIR Core Infrastructure mlir
Projects
None yet
4 participants