[mlir] Handle simple commutative cases in CSE. #75274

Closed

jpienaar
Member

Tried to keep this simple while handling obvious CSE instances. For more complicated cases the expectation is still that the sorting pass would run before. While simple, this case did turn up in a real deployed instance where it had a large (>10% e2e) impact. This can of course be refined.
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Dec 13, 2023
@llvmbot
Collaborator

llvmbot commented Dec 13, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Jacques Pienaar (jpienaar)

Changes

Tried to keep this simple while handling obvious CSE instances. For more complicated cases the expectation is still that the sorting pass would run before. While simple, this case did turn up in a real deployed instance where it had a large e2e impact. This can of course be refined.


Full diff: https://github.com/llvm/llvm-project/pull/75274.diff

2 Files Affected:

  • (modified) mlir/include/mlir/IR/OperationSupport.h (+6-2)
  • (modified) mlir/lib/IR/OperationSupport.cpp (+71-14)
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 6a5ec129ad564..ba66dffeeb8e9 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1271,7 +1271,9 @@ struct OperationEquivalence {
   isEquivalentTo(Operation *lhs, Operation *rhs,
                  function_ref<LogicalResult(Value, Value)> checkEquivalent,
                  function_ref<void(Value, Value)> markEquivalent = nullptr,
-                 Flags flags = Flags::None);
+                 Flags flags = Flags::None,
+                 function_ref<LogicalResult(ValueRange, ValueRange)>
+                     checkCommutativeEquivalent = nullptr);
 
   /// Compare two operations and return if they are equivalent.
   static bool isEquivalentTo(Operation *lhs, Operation *rhs, Flags flags);
@@ -1282,7 +1284,9 @@ struct OperationEquivalence {
       Region *lhs, Region *rhs,
       function_ref<LogicalResult(Value, Value)> checkEquivalent,
       function_ref<void(Value, Value)> markEquivalent,
-      OperationEquivalence::Flags flags);
+      OperationEquivalence::Flags flags,
+      function_ref<LogicalResult(ValueRange, ValueRange)>
+          checkCommutativeEquivalent = nullptr);
 
   /// Compare two regions and return if they are equivalent.
   static bool isRegionEquivalentTo(Region *lhs, Region *rhs,
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index fc5ccd23b5108..630a3bc5016ff 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -683,8 +683,16 @@ llvm::hash_code OperationEquivalence::computeHash(
     hash = llvm::hash_combine(hash, op->getLoc());
 
   //   - Operands
-  for (Value operand : op->getOperands())
-    hash = llvm::hash_combine(hash, hashOperands(operand));
+  if (op->hasTrait<mlir::OpTrait::IsCommutative>() && op->getNumOperands() > 0) {
+    // If commutative, don't hash the operands as hash is not order independent
+    // and even if it were would not be sufficient for CSE usage.
+    // FIXME: This has the effect of resulting in more hash collisions
+    // for the sake of CSE, this could be improved.
+    hash = llvm::hash_combine(hash, op->getNumOperands());
+  } else {
+    for (Value operand : op->getOperands())
+      hash = llvm::hash_combine(hash, hashOperands(operand));
+  }
 
   //   - Results
   for (Value result : op->getResults())
@@ -696,7 +704,9 @@ llvm::hash_code OperationEquivalence::computeHash(
     Region *lhs, Region *rhs,
     function_ref<LogicalResult(Value, Value)> checkEquivalent,
     function_ref<void(Value, Value)> markEquivalent,
-    OperationEquivalence::Flags flags) {
+    OperationEquivalence::Flags flags,
+    function_ref<LogicalResult(ValueRange, ValueRange)>
+        checkCommutativeEquivalent) {
   DenseMap<Block *, Block *> blocksMap;
   auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
     // Check block arguments.
@@ -751,6 +761,36 @@ struct ValueEquivalenceCache {
     return success(lhsValue == rhsValue ||
                    equivalentValues.lookup(lhsValue) == rhsValue);
   }
+  LogicalResult checkCommutativeEquivalent(ValueRange lhsRange,
+                                           ValueRange rhsRange) {
+    // Handle simple case where sizes mismatch.
+    if (lhsRange.size() != rhsRange.size())
+      return failure();
+
+    // Handle where operands in order are equivalent.
+    auto lhsIt = lhsRange.begin();
+    auto rhsIt = rhsRange.begin();
+    for (; lhsIt != lhsRange.end(); ++lhsIt, ++rhsIt) {
+      if (failed(checkEquivalent(*lhsIt, *rhsIt)))
+        break;
+    }
+    if (lhsIt == lhsRange.end())
+      return success();
+
+    // Handle another simple case where operands are just a permutation.
+    // Note: This is not sufficient, this handles simple cases relatively
+    // cheaply.
+    auto sortValues = [](ValueRange values) {
+      SmallVector<Value> sortedValues = llvm::to_vector(values);
+      llvm::sort(sortedValues, [](Value a, Value b) {
+        return a.getAsOpaquePointer() < b.getAsOpaquePointer();
+      });
+      return sortedValues;
+    };
+    auto lhsSorted = sortValues({lhsIt, lhsRange.end()});
+    auto rhsSorted = sortValues({rhsIt, rhsRange.end()});
+    return success(lhsSorted == rhsSorted);
+  }
   void markEquivalent(Value lhsResult, Value rhsResult) {
     auto insertion = equivalentValues.insert({lhsResult, rhsResult});
     // Make sure that the value was not already marked equivalent to some other
@@ -773,13 +813,18 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
       [&](Value lhsResult, Value rhsResult) {
         cache.markEquivalent(lhsResult, rhsResult);
       },
-      flags);
+      flags,
+      [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
+        return cache.checkCommutativeEquivalent(lhs, rhs);
+      });
 }
 
 /*static*/ bool OperationEquivalence::isEquivalentTo(
     Operation *lhs, Operation *rhs,
     function_ref<LogicalResult(Value, Value)> checkEquivalent,
-    function_ref<void(Value, Value)> markEquivalent, Flags flags) {
+    function_ref<void(Value, Value)> markEquivalent, Flags flags,
+    function_ref<LogicalResult(ValueRange, ValueRange)>
+        checkCommutativeEquivalent) {
   if (lhs == rhs)
     return true;
 
@@ -798,15 +843,24 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
     return false;
 
   // 2. Compare operands.
-  for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
-    Value curArg = std::get<0>(operandPair);
-    Value otherArg = std::get<1>(operandPair);
-    if (curArg == otherArg)
-      continue;
-    if (curArg.getType() != otherArg.getType())
-      return false;
-    if (failed(checkEquivalent(curArg, otherArg)))
+  if (checkCommutativeEquivalent &&
+      lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
+    auto lhsRange = lhs->getOperands();
+    auto rhsRange = rhs->getOperands();
+    if (failed(checkCommutativeEquivalent(lhsRange, rhsRange)))
       return false;
+  } else {
+    // Check pair wise for equivalence.
+    for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
+      Value curArg = std::get<0>(operandPair);
+      Value otherArg = std::get<1>(operandPair);
+      if (curArg == otherArg)
+        continue;
+      if (curArg.getType() != otherArg.getType())
+        return false;
+      if (failed(checkEquivalent(curArg, otherArg)))
+        return false;
+    }
   }
 
   // 3. Compare result types and mark results as equivalent.
@@ -841,7 +895,10 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
       [&](Value lhsResult, Value rhsResult) {
         cache.markEquivalent(lhsResult, rhsResult);
       },
-      flags);
+      flags,
+      [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
+        return cache.checkCommutativeEquivalent(lhs, rhs);
+      });
 }
 
 //===----------------------------------------------------------------------===//
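The hashing change in the diff above can be looked at in isolation. What follows is a minimal standalone sketch, not MLIR code: `ToyOp`, `combine`, and `hashOp` are hypothetical stand-ins, and plain integers play the role of SSA values. It only illustrates the idea that a commutative op hashes its operand count rather than the operands themselves, so that reordered operands land in the same bucket and the equivalence check gets a chance to compare them.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for an operation; ints play the role of SSA values.
struct ToyOp {
  std::string name;
  bool isCommutative;
  std::vector<int> operands;
};

// Simple hash mixer (an assumed stand-in for llvm::hash_combine).
inline std::size_t combine(std::size_t seed, std::size_t value) {
  return seed ^ (value + 0x9e3779b9u + (seed << 6) + (seed >> 2));
}

std::size_t hashOp(const ToyOp &op) {
  std::size_t hash = std::hash<std::string>{}(op.name);
  if (op.isCommutative && !op.operands.empty()) {
    // Order-sensitive operand hashing would separate add(a, b) from
    // add(b, a), so only the operand count contributes here.
    hash = combine(hash, op.operands.size());
  } else {
    // Non-commutative ops keep the order-sensitive operand hash.
    for (int operand : op.operands)
      hash = combine(hash, std::hash<int>{}(operand));
  }
  return hash;
}
```

The cost noted in the FIXME shows up directly in this sketch: every commutative `add` with two operands hashes identically, whatever its operands are, so the equivalence check has to do more of the work.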

github-actions bot commented Dec 13, 2023

:white_check_mark: With the latest revision this PR passed the C/C++ code formatter.

@llvmbot llvmbot added the flang Flang issues not falling into any other category label Dec 13, 2023
mlir/lib/IR/OperationSupport.cpp
      lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
    auto lhsRange = lhs->getOperands();
    auto rhsRange = rhs->getOperands();
    if (failed(checkCommutativeEquivalent(lhsRange, rhsRange)))
Member

Do we actually need checkCommutativeEquivalent? Can we do something like this if the op has the IsCommutative trait?

SmallVector<Value> rhsOperands(rhs->getOperands().begin(), rhs->getOperands().end());
for (Value lhsOperand : lhs->getOperands()) {
  auto it = llvm::find_if(rhsOperands, [&](Value rhsOperand) {
      return succeeded(checkEquivalent(lhsOperand, rhsOperand));
  });
  if (it == rhsOperands.end()) {
    // Could not find equivalent operand.
    return false;
  }
  rhsOperands.erase(it);
}
assert(rhsOperands.empty());

Member Author

We could, but this is O(n^2) vs O(n log n). Now, this should be replaced by the explicit commutative sort pass in the pipeline, so that here it should suffice to just do the linear scan. The default param means it's easy to drop again at that point without needing to update callers.
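To make the trade-off in this exchange concrete, here is a minimal sketch of the two strategies on plain integers. The names `isPermutationQuadratic` and `isPermutationSorted` are hypothetical, and `std::function` stands in for `function_ref`. The first mirrors the find-and-erase suggestion and accepts an arbitrary equivalence predicate. The second sorts both sides and can therefore only compare by identity.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>

// O(n^2): for each lhs value, find and consume one equivalent rhs value.
// Greedy matching assumes `equivalent` is a true equivalence relation.
bool isPermutationQuadratic(const std::vector<int> &lhs, std::vector<int> rhs,
                            const std::function<bool(int, int)> &equivalent) {
  if (lhs.size() != rhs.size())
    return false;
  for (int l : lhs) {
    auto it = std::find_if(rhs.begin(), rhs.end(),
                           [&](int r) { return equivalent(l, r); });
    if (it == rhs.end())
      return false; // No equivalent operand left on the rhs.
    rhs.erase(it);
  }
  return true;
}

// O(n log n): sort both sides and compare element-wise. This needs a total
// order, so it only recognizes identical values, not equivalent ones.
bool isPermutationSorted(std::vector<int> lhs, std::vector<int> rhs) {
  if (lhs.size() != rhs.size())
    return false;
  std::sort(lhs.begin(), lhs.end());
  std::sort(rhs.begin(), rhs.end());
  return lhs == rhs;
}
```

With an "equal modulo 10" predicate, the quadratic version accepts `{1, 12}` against `{2, 11}`, while the sorted version rejects the same input because no two values are identical. For the small operand counts typical of commutative ops, the asymptotic difference may matter less than this difference in what each check can recognize.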

-                 Flags flags = Flags::None);
+                 Flags flags = Flags::None,
+                 function_ref<LogicalResult(ValueRange, ValueRange)>
+                     checkCommutativeEquivalent = nullptr);
Collaborator

This should be documented.

I don't quite get why we need injection here by the way?

Member Author

This is a static function, and we need to be able to inject a way to query the equivalence classes. Otherwise we have to do the N^2 search that Matthias suggested.

Collaborator

I'm not sure that answers my question, let me rephrase: why isn't this just a Flag?

Member Author

The equivalence classes are not exposed, so they are only accessible via the functors passed in. If this were a flag, one could still only access them via checkEquivalent. Another option is to not care about equivalence in the commutative case, except for the pairwise part.
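The injection argument can be sketched outside MLIR. In the hypothetical code below, `operandsEquivalent` stands in for the static comparison routine, which has no access to the caller's state, and `EquivalenceCache` stands in for the caller-private cache. `std::function` replaces `function_ref`, integers replace `Value`, and the identity-only `std::is_permutation` call is a placeholder for whatever whole-range check the caller chooses to supply. None of these names or signatures are the real MLIR API.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <unordered_map>
#include <vector>

using Value = int;
using Values = std::vector<Value>;

// Static-style routine: it sees only the callbacks, never the cache itself.
bool operandsEquivalent(
    const Values &lhs, const Values &rhs, bool isCommutative,
    const std::function<bool(Value, Value)> &checkEquivalent,
    const std::function<bool(const Values &, const Values &)>
        &checkCommutativeEquivalent = nullptr) {
  // The optional callback takes over for commutative ops when provided.
  if (isCommutative && checkCommutativeEquivalent)
    return checkCommutativeEquivalent(lhs, rhs);
  if (lhs.size() != rhs.size())
    return false;
  for (std::size_t i = 0; i < lhs.size(); ++i)
    if (lhs[i] != rhs[i] && !checkEquivalent(lhs[i], rhs[i]))
      return false;
  return true;
}

// Caller-owned state that the static routine cannot reach directly.
struct EquivalenceCache {
  std::unordered_map<Value, Value> equivalentValues;

  bool checkEquivalent(Value lhs, Value rhs) const {
    auto it = equivalentValues.find(lhs);
    return lhs == rhs || (it != equivalentValues.end() && it->second == rhs);
  }
  // Whole-range check; here an identity-only permutation test.
  bool checkCommutativeEquivalent(const Values &lhs, const Values &rhs) const {
    return std::is_permutation(lhs.begin(), lhs.end(), rhs.begin(), rhs.end());
  }
};

// The caller wires its private cache in through lambdas.
bool compareWithCache(const EquivalenceCache &cache, const Values &lhs,
                      const Values &rhs, bool isCommutative) {
  return operandsEquivalent(
      lhs, rhs, isCommutative,
      [&](Value l, Value r) { return cache.checkEquivalent(l, r); },
      [&](const Values &l, const Values &r) {
        return cache.checkCommutativeEquivalent(l, r);
      });
}
```

A boolean flag could switch the commutative path on, but the routine would still only reach the cache one pair at a time through `checkEquivalent`. Passing a second callback lets the cache owner decide how a whole operand range is compared, and the defaulted parameter keeps existing callers unchanged.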

Collaborator

OK, LG, but you're still missing the doc :)

mlir/lib/IR/OperationSupport.cpp Outdated
        break;
    }
    if (lhsIt == lhsRange.end())
      return success();
Collaborator

if (llvm::all_of(llvm::zip(lhsRange, rhsRange), checkEquivalent)) return success();

Edit: oh you're trying to optimize to only sort a subrange?

Member Author

Well, it's more that I'm trying to avoid sorting. Once commutative is handled as a presort somewhere, it would just be the above, and commutative would not be special here. On the sorted side I'm not considering equivalence classes: it's a pure Value comparison rather than also querying the equivalentValues map. That's, as mentioned in the PR description, to keep this rather minimal.
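The two-phase check described here can be sketched on plain integers. This is a hedged illustration rather than the code from the PR: `Value` is an `int`, `std::function` stands in for `function_ref`, and the function takes `checkEquivalent` as a parameter instead of being a member of the cache. It scans the common prefix with the equivalence callback and only sorts whatever remains after the first mismatch, comparing that remainder by identity.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Value = int;

bool checkCommutativeEquivalent(
    const std::vector<Value> &lhs, const std::vector<Value> &rhs,
    const std::function<bool(Value, Value)> &checkEquivalent) {
  if (lhs.size() != rhs.size())
    return false;

  // Phase 1: linear scan; uses the equivalence callback.
  std::size_t i = 0;
  while (i < lhs.size() && checkEquivalent(lhs[i], rhs[i]))
    ++i;
  if (i == lhs.size())
    return true; // Already equivalent in order, no sorting needed.

  // Phase 2: sort only the tails after the first mismatch and compare them
  // by identity. Equivalence classes are deliberately not consulted here.
  auto offset = static_cast<std::ptrdiff_t>(i);
  std::vector<Value> lhsRest(lhs.begin() + offset, lhs.end());
  std::vector<Value> rhsRest(rhs.begin() + offset, rhs.end());
  std::sort(lhsRest.begin(), lhsRest.end());
  std::sort(rhsRest.begin(), rhsRest.end());
  return lhsRest == rhsRest;
}
```

The limitation mentioned in the reply is visible in the sketch: if `1` is equivalent to `10`, then `{1, 2, 3}` matches `{10, 2, 3}` through the prefix scan, but `{2, 1}` does not match `{10, 2}`, because that case would need reordering and an equivalence lookup at the same time.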

@jpienaar jpienaar closed this Dec 30, 2023
Labels
flang Flang issues not falling into any other category mlir:core MLIR Core Infrastructure mlir

4 participants