-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][func] Fix multiple bugs in DuplicateFunctionElimination
#109571
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-func Author: Longsheng Mou (CoTinker) ChangesThis PR fixes multiple bugs in
Fixes #93483. Full diff: https://github.com/llvm/llvm-project/pull/109571.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
index d41d6c3e8972f9..5e23207eabf9c4 100644
--- a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
@@ -54,6 +54,10 @@ struct DuplicateFuncOpEquivalenceInfo
if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
rhs == getTombstoneKey() || rhs == getEmptyKey())
return false;
+
+ if (lhs.isDeclaration() || rhs.isDeclaration())
+ return false;
+
// Check discardable attributes equivalence
if (lhs->getDiscardableAttrDictionary() !=
rhs->getDiscardableAttrDictionary())
@@ -87,11 +91,11 @@ struct DuplicateFunctionEliminationPass
// Find unique representant per equivalent func ops.
DenseSet<func::FuncOp, DuplicateFuncOpEquivalenceInfo> uniqueFuncOps;
- DenseMap<StringAttr, func::FuncOp> getRepresentant;
+ DenseMap<StringRef, func::FuncOp> getRepresentant;
DenseSet<func::FuncOp> toBeErased;
module.walk([&](func::FuncOp f) {
auto [repr, inserted] = uniqueFuncOps.insert(f);
- getRepresentant[f.getSymNameAttr()] = *repr;
+ getRepresentant[f.getSymName()] = *repr;
if (!inserted) {
toBeErased.insert(f);
}
@@ -99,9 +103,14 @@ struct DuplicateFunctionEliminationPass
// Update call ops to call unique func op representants.
module.walk([&](func::CallOp callOp) {
- func::FuncOp callee = getRepresentant[callOp.getCalleeAttr().getAttr()];
+ func::FuncOp callee = getRepresentant[callOp.getCallee()];
callOp.setCallee(callee.getSymName());
});
+ // Update constant ops to reference unique func op representants.
+ module.walk([&](func::ConstantOp constantOp) {
+ func::FuncOp value = getRepresentant[constantOp.getValue()];
+ constantOp.setValue(value.getSymName());
+ });
// Erase redundant func ops.
for (auto it : toBeErased) {
diff --git a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
index 28d059a149bde8..1c6876c1327bc6 100644
--- a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
+++ b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
@@ -366,3 +366,51 @@ func.func @user(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %odd: f32, %even: f32)
// CHECK: @user
// CHECK-2: call @deep_tree
// CHECK: call @reverse_deep_tree
+
+// -----
+
+func.func private @func_declaration(i32, i32) -> i32
+func.func private @func_declaration1(i32, i32) -> i32
+
+func.func @user(%arg0: i32, %arg1: i32) -> (i32, i32) {
+ %0 = call @func_declaration(%arg0, %arg1) : (i32, i32) -> i32
+ %1 = call @func_declaration1(%arg0, %arg1) : (i32, i32) -> i32
+ return %0, %1 : i32, i32
+}
+
+// CHECK: @func_declaration
+// CHECK: @func_declaration1
+// CHECK: @user
+// CHECK: call @func_declaration
+// CHECK: call @func_declaration1
+
+
+// -----
+
+func.func @identity(%arg0: tensor<f32>) -> tensor<f32> {
+ return %arg0 : tensor<f32>
+}
+
+func.func @also_identity(%arg0: tensor<f32>) -> tensor<f32> {
+ return %arg0 : tensor<f32>
+}
+
+func.func @yet_another_identity(%arg0: tensor<f32>) -> tensor<f32> {
+ return %arg0 : tensor<f32>
+}
+
+func.func @user(%arg0: tensor<f32>) -> tensor<f32> {
+ %f = constant @identity : (tensor<f32>) -> tensor<f32>
+ %0 = call_indirect %f(%arg0) : (tensor<f32>) -> tensor<f32>
+ %f_0 = constant @also_identity : (tensor<f32>) -> tensor<f32>
+ %1 = call_indirect %f_0(%0) : (tensor<f32>) -> tensor<f32>
+ %2 = call @yet_another_identity(%1) : (tensor<f32>) -> tensor<f32>
+ return %2 : tensor<f32>
+}
+
+// CHECK: @identity
+// CHECK-NOT: @also_identity
+// CHECK-NOT: @yet_another_identity
+// CHECK: @user
+// CHECK-2: constant @identity
+// CHECK: call @identity
|
@llvm/pr-subscribers-mlir Author: Longsheng Mou (CoTinker) ChangesThis PR fixes multiple bugs in
Fixes #93483. Full diff: https://github.com/llvm/llvm-project/pull/109571.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
index d41d6c3e8972f9..5e23207eabf9c4 100644
--- a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
@@ -54,6 +54,10 @@ struct DuplicateFuncOpEquivalenceInfo
if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
rhs == getTombstoneKey() || rhs == getEmptyKey())
return false;
+
+ if (lhs.isDeclaration() || rhs.isDeclaration())
+ return false;
+
// Check discardable attributes equivalence
if (lhs->getDiscardableAttrDictionary() !=
rhs->getDiscardableAttrDictionary())
@@ -87,11 +91,11 @@ struct DuplicateFunctionEliminationPass
// Find unique representant per equivalent func ops.
DenseSet<func::FuncOp, DuplicateFuncOpEquivalenceInfo> uniqueFuncOps;
- DenseMap<StringAttr, func::FuncOp> getRepresentant;
+ DenseMap<StringRef, func::FuncOp> getRepresentant;
DenseSet<func::FuncOp> toBeErased;
module.walk([&](func::FuncOp f) {
auto [repr, inserted] = uniqueFuncOps.insert(f);
- getRepresentant[f.getSymNameAttr()] = *repr;
+ getRepresentant[f.getSymName()] = *repr;
if (!inserted) {
toBeErased.insert(f);
}
@@ -99,9 +103,14 @@ struct DuplicateFunctionEliminationPass
// Update call ops to call unique func op representants.
module.walk([&](func::CallOp callOp) {
- func::FuncOp callee = getRepresentant[callOp.getCalleeAttr().getAttr()];
+ func::FuncOp callee = getRepresentant[callOp.getCallee()];
callOp.setCallee(callee.getSymName());
});
+ // Update constant ops to reference unique func op representants.
+ module.walk([&](func::ConstantOp constantOp) {
+ func::FuncOp value = getRepresentant[constantOp.getValue()];
+ constantOp.setValue(value.getSymName());
+ });
// Erase redundant func ops.
for (auto it : toBeErased) {
diff --git a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
index 28d059a149bde8..1c6876c1327bc6 100644
--- a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
+++ b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
@@ -366,3 +366,51 @@ func.func @user(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %odd: f32, %even: f32)
// CHECK: @user
// CHECK-2: call @deep_tree
// CHECK: call @reverse_deep_tree
+
+// -----
+
+func.func private @func_declaration(i32, i32) -> i32
+func.func private @func_declaration1(i32, i32) -> i32
+
+func.func @user(%arg0: i32, %arg1: i32) -> (i32, i32) {
+ %0 = call @func_declaration(%arg0, %arg1) : (i32, i32) -> i32
+ %1 = call @func_declaration1(%arg0, %arg1) : (i32, i32) -> i32
+ return %0, %1 : i32, i32
+}
+
+// CHECK: @func_declaration
+// CHECK: @func_declaration1
+// CHECK: @user
+// CHECK: call @func_declaration
+// CHECK: call @func_declaration1
+
+
+// -----
+
+func.func @identity(%arg0: tensor<f32>) -> tensor<f32> {
+ return %arg0 : tensor<f32>
+}
+
+func.func @also_identity(%arg0: tensor<f32>) -> tensor<f32> {
+ return %arg0 : tensor<f32>
+}
+
+func.func @yet_another_identity(%arg0: tensor<f32>) -> tensor<f32> {
+ return %arg0 : tensor<f32>
+}
+
+func.func @user(%arg0: tensor<f32>) -> tensor<f32> {
+ %f = constant @identity : (tensor<f32>) -> tensor<f32>
+ %0 = call_indirect %f(%arg0) : (tensor<f32>) -> tensor<f32>
+ %f_0 = constant @also_identity : (tensor<f32>) -> tensor<f32>
+ %1 = call_indirect %f_0(%0) : (tensor<f32>) -> tensor<f32>
+ %2 = call @yet_another_identity(%1) : (tensor<f32>) -> tensor<f32>
+ return %2 : tensor<f32>
+}
+
+// CHECK: @identity
+// CHECK-NOT: @also_identity
+// CHECK-NOT: @yet_another_identity
+// CHECK: @user
+// CHECK-2: constant @identity
+// CHECK: call @identity
|
780b3d5
to
69f301a
Compare
mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
Outdated
Show resolved
Hide resolved
I've noticed that llvm-project/mlir/lib/IR/SymbolTable.cpp Lines 877 to 891 in 8dd817b
I'm not familiar with this, could you please take a look at it. @River707 |
In the following scenario, llvm-project/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp Lines 214 to 243 in 0df8880
|
I suspect It's just as bad |
@CoTinker I commented above with a suggestion for how to fix. |
Thanks, I'll refer to it. |
This PR fixes multiple bugs in `DuplicateFunctionElimination`. - Prevents elimination of function declarations. - Updates all symbol uses to reference unique function representatives.
This PR fixes multiple bugs in
DuplicateFunctionElimination
.Fixes #93483.