CoTinker
Contributor

@CoTinker CoTinker commented Sep 22, 2024

This PR fixes multiple bugs in DuplicateFunctionElimination.

  • Prevents elimination of function declarations.
  • Updates all symbol uses to reference unique function representatives.

Fixes #93483.
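
The two fixes listed above can be illustrated with a small standalone model. This is a sketch only, not the pass itself: `ToyFunc`, `pickRepresentatives`, and `rewriteUses` are hypothetical names, and function bodies are reduced to strings in place of the structural comparison the real pass performs. It shows declarations always representing themselves and every kind of reference being rewritten through the same representative map.

```cpp
#include <map>
#include <string>
#include <vector>

// Toy stand-in for a function symbol; NOT the MLIR func::FuncOp API.
struct ToyFunc {
  std::string name;
  std::string body;   // canonical text of the body; unused for declarations
  bool isDeclaration; // true when the function has no body
};

// Maps every function name to the name of its representative.
// Declarations always represent themselves: two body-less functions with the
// same signature may still resolve to different external definitions.
std::map<std::string, std::string>
pickRepresentatives(const std::vector<ToyFunc> &funcs) {
  std::map<std::string, std::string> firstWithBody; // body -> first name seen
  std::map<std::string, std::string> representative;
  for (const ToyFunc &f : funcs) {
    if (f.isDeclaration) {
      representative[f.name] = f.name;
      continue;
    }
    // try_emplace keeps the first function seen with this body.
    auto it = firstWithBody.try_emplace(f.body, f.name).first;
    representative[f.name] = it->second;
  }
  return representative;
}

// Rewrites every reference through the map, whether it came from a direct
// call or from a function constant.
void rewriteUses(std::vector<std::string> &uses,
                 const std::map<std::string, std::string> &representative) {
  for (std::string &use : uses) {
    auto it = representative.find(use);
    if (it != representative.end())
      use = it->second;
  }
}
```

In this model, two declarations with identical signatures stay distinct, while a duplicate definition such as `also_identity` is redirected to `identity` at every use site.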

@llvmbot
Member

llvmbot commented Sep 22, 2024

@llvm/pr-subscribers-mlir-func

Author: Longsheng Mou (CoTinker)

Changes

This PR fixes multiple bugs in DuplicateFunctionElimination.

  • Prevents elimination of function declarations.
  • Updates constant ops to reference unique function representatives.
  • Simplifies DenseMap by using StringRef as the key instead of StringAttr.

Fixes #93483.


Full diff: https://github.com/llvm/llvm-project/pull/109571.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp (+12-3)
  • (modified) mlir/test/Dialect/Func/duplicate-function-elimination.mlir (+48)
diff --git a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
index d41d6c3e8972f9..5e23207eabf9c4 100644
--- a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
@@ -54,6 +54,10 @@ struct DuplicateFuncOpEquivalenceInfo
     if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
         rhs == getTombstoneKey() || rhs == getEmptyKey())
       return false;
+
+    if (lhs.isDeclaration() || rhs.isDeclaration())
+      return false;
+
     // Check discardable attributes equivalence
     if (lhs->getDiscardableAttrDictionary() !=
         rhs->getDiscardableAttrDictionary())
@@ -87,11 +91,11 @@ struct DuplicateFunctionEliminationPass
 
     // Find unique representant per equivalent func ops.
     DenseSet<func::FuncOp, DuplicateFuncOpEquivalenceInfo> uniqueFuncOps;
-    DenseMap<StringAttr, func::FuncOp> getRepresentant;
+    DenseMap<StringRef, func::FuncOp> getRepresentant;
     DenseSet<func::FuncOp> toBeErased;
     module.walk([&](func::FuncOp f) {
       auto [repr, inserted] = uniqueFuncOps.insert(f);
-      getRepresentant[f.getSymNameAttr()] = *repr;
+      getRepresentant[f.getSymName()] = *repr;
       if (!inserted) {
         toBeErased.insert(f);
       }
@@ -99,9 +103,14 @@ struct DuplicateFunctionEliminationPass
 
     // Update call ops to call unique func op representants.
     module.walk([&](func::CallOp callOp) {
-      func::FuncOp callee = getRepresentant[callOp.getCalleeAttr().getAttr()];
+      func::FuncOp callee = getRepresentant[callOp.getCallee()];
       callOp.setCallee(callee.getSymName());
     });
+    // Update constant ops to reference unique func op representants.
+    module.walk([&](func::ConstantOp constantOp) {
+      func::FuncOp value = getRepresentant[constantOp.getValue()];
+      constantOp.setValue(value.getSymName());
+    });
 
     // Erase redundant func ops.
     for (auto it : toBeErased) {
diff --git a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
index 28d059a149bde8..1c6876c1327bc6 100644
--- a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
+++ b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
@@ -366,3 +366,51 @@ func.func @user(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %odd: f32, %even: f32)
 // CHECK:     @user
 // CHECK-2:     call @deep_tree
 // CHECK:       call @reverse_deep_tree
+
+// -----
+
+func.func private @func_declaration(i32, i32) -> i32
+func.func private @func_declaration1(i32, i32) -> i32
+
+func.func @user(%arg0: i32, %arg1: i32) -> (i32, i32) {
+  %0 = call @func_declaration(%arg0, %arg1) : (i32, i32) -> i32
+  %1 = call @func_declaration1(%arg0, %arg1) : (i32, i32) -> i32
+  return %0, %1 : i32, i32
+}
+
+// CHECK:  @func_declaration
+// CHECK:  @func_declaration1
+// CHECK:  @user
+// CHECK:    call @func_declaration
+// CHECK:    call @func_declaration1
+
+
+// -----
+
+func.func @identity(%arg0: tensor<f32>) -> tensor<f32> {
+  return %arg0 : tensor<f32>
+}
+
+func.func @also_identity(%arg0: tensor<f32>) -> tensor<f32> {
+  return %arg0 : tensor<f32>
+}
+
+func.func @yet_another_identity(%arg0: tensor<f32>) -> tensor<f32> {
+  return %arg0 : tensor<f32>
+}
+
+func.func @user(%arg0: tensor<f32>) -> tensor<f32> {
+  %f = constant @identity : (tensor<f32>) -> tensor<f32>
+  %0 = call_indirect %f(%arg0) : (tensor<f32>) -> tensor<f32>
+  %f_0 = constant @also_identity : (tensor<f32>) -> tensor<f32>
+  %1 = call_indirect %f_0(%0) : (tensor<f32>) -> tensor<f32>
+  %2 = call @yet_another_identity(%1) : (tensor<f32>) -> tensor<f32>
+  return %2 : tensor<f32>
+}
+
+// CHECK:     @identity
+// CHECK-NOT: @also_identity
+// CHECK-NOT: @yet_another_identity
+// CHECK:     @user
+// CHECK-2:     constant @identity
+// CHECK:       call @identity

@llvmbot
Member

llvmbot commented Sep 22, 2024

@llvm/pr-subscribers-mlir

@CoTinker CoTinker force-pushed the duplicate branch 2 times, most recently from 780b3d5 to 69f301a Compare September 22, 2024 08:54
@llvm llvm deleted a comment from github-actions bot Sep 22, 2024
@CoTinker
Contributor Author

I've noticed that replaceAllSymbolUsesImpl uses collectSymbolScopes to traverse the region and addReplacement to perform replacements. Perhaps using SymbolTable::replaceAllSymbolUses inside the loop isn't as costly as we think.

/// The implementation of SymbolTable::replaceAllSymbolUses below.
template <typename SymbolT, typename IRUnitT>
static LogicalResult
replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
  // Generate a new attribute to replace the given attribute.
  FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol);
  for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
    SymbolRefAttr oldAttr = scope.symbol;
    SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
    AttrTypeReplacer replacer;
    replacer.addReplacement(
        [&](SymbolRefAttr attr) -> std::pair<Attribute, WalkResult> {
          // Regardless of the match, don't walk nested SymbolRefAttrs, we don't
          // want to accidentally replace an inner reference.
          if (attr == oldAttr)

I'm not familiar with this code. Could you please take a look, @River707?

@CoTinker
Contributor Author

In the following scenario, SymbolTable::replaceAllSymbolUses is also inside a loop.

for (auto &op : *combinedModule.getBody()) {
  SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
  if (!symbolOp)
    continue;
  // Do not support ops with operands or results.
  // Global variables, spec constants, and functions won't have
  // operands/results, but just for safety here.
  if (op.getNumOperands() != 0 || op.getNumResults() != 0)
    continue;
  // Deduplicating functions are not supported yet.
  if (isa<FuncOp>(op))
    continue;
  auto result = hashToSymbolOp.try_emplace(computeHash(symbolOp), symbolOp);
  if (result.second)
    continue;
  SymbolOpInterface replacementSymOp = result.first->second;
  if (failed(SymbolTable::replaceAllSymbolUses(
          symbolOp, replacementSymOp.getNameAttr(), combinedModule))) {
    symbolOp.emitError("unable to update all symbol uses for ")
        << symbolOp.getName() << " to " << replacementSymOp.getName();
    return nullptr;
  }
  eraseList.push_back(symbolOp);
}

@joker-eph
Collaborator

In the following scenario, SymbolTable::replaceAllSymbolUses is also inside a loop.

I suspect it's just as bad.
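
The cost concern in this exchange can be made concrete with a standalone model. This is a hedged sketch, not MLIR code: the "module" is just a flat list of symbol references, and `replaceOneByOne` and `replaceInOneWalk` are hypothetical helpers. It assumes each replace-all-uses call walks every reference in scope, and that rename targets are never themselves renamed (as with representatives in deduplication).

```cpp
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical model: a "module" is a flat list of symbol references.
using Uses = std::vector<std::string>;
using Renames = std::map<std::string, std::string>;

// One full walk per replaced symbol, as when a replace-all-uses helper is
// called inside a loop. Returns how many references were visited.
std::size_t replaceOneByOne(Uses &uses, const Renames &renames) {
  std::size_t visited = 0;
  for (const auto &[from, to] : renames) {
    for (std::string &use : uses) {
      ++visited;
      if (use == from)
        use = to;
    }
  }
  return visited;
}

// A single walk that looks each reference up in the rename map.
std::size_t replaceInOneWalk(Uses &uses, const Renames &renames) {
  std::size_t visited = 0;
  for (std::string &use : uses) {
    ++visited;
    auto it = renames.find(use);
    if (it != renames.end())
      use = it->second;
  }
  return visited;
}
```

Under these assumptions, the looped variant visits (number of renamed symbols) × (number of references) entries, while the single walk visits each reference once and pays only a map lookup per reference.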

@CoTinker CoTinker requested a review from River707 October 17, 2024 06:52
@christopherbate
Contributor

@CoTinker I commented above with a suggestion for how to fix.

@CoTinker
Contributor Author

@CoTinker I commented above with a suggestion for how to fix.

Thanks, I'll refer to it.

This PR fixes multiple bugs in `DuplicateFunctionElimination`.
- Prevents elimination of function declarations.
- Updates all symbol uses to reference unique function representatives.
@CoTinker CoTinker merged commit 2ce655c into llvm:main Oct 22, 2024
8 checks passed
@CoTinker CoTinker deleted the duplicate branch October 22, 2024 01:19
Successfully merging this pull request may close these issues.

[mlir][func] Duplicate function elimination pass can create invalid IR
4 participants