-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] Handle simple commutative cases in CSE. #75274
Conversation
Tried to keep this simple while handling obvious CSE instances. For more complicated cases the expectation is still that the sorting pass would run before. While simple, this case did turn up in a real deployed instance where it had a large (>10% e2e) impact. This can of course be refined.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Jacques Pienaar (jpienaar) ChangesTried to keep this simple while handling obvious CSE instances. For more complicated cases the expectation is still that the sorting pass would run before. While simple, this case did turn up in a real deployed instance where it had a large e2e impact. This can of course be refined. Full diff: https://github.com/llvm/llvm-project/pull/75274.diff 2 Files Affected:
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 6a5ec129ad564..ba66dffeeb8e9 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1271,7 +1271,9 @@ struct OperationEquivalence {
isEquivalentTo(Operation *lhs, Operation *rhs,
function_ref<LogicalResult(Value, Value)> checkEquivalent,
function_ref<void(Value, Value)> markEquivalent = nullptr,
- Flags flags = Flags::None);
+ Flags flags = Flags::None,
+ function_ref<LogicalResult(ValueRange, ValueRange)>
+ checkCommutativeEquivalent = nullptr);
/// Compare two operations and return if they are equivalent.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, Flags flags);
@@ -1282,7 +1284,9 @@ struct OperationEquivalence {
Region *lhs, Region *rhs,
function_ref<LogicalResult(Value, Value)> checkEquivalent,
function_ref<void(Value, Value)> markEquivalent,
- OperationEquivalence::Flags flags);
+ OperationEquivalence::Flags flags,
+ function_ref<LogicalResult(ValueRange, ValueRange)>
+ checkCommutativeEquivalent = nullptr);
/// Compare two regions and return if they are equivalent.
static bool isRegionEquivalentTo(Region *lhs, Region *rhs,
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index fc5ccd23b5108..630a3bc5016ff 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -683,8 +683,16 @@ llvm::hash_code OperationEquivalence::computeHash(
hash = llvm::hash_combine(hash, op->getLoc());
// - Operands
- for (Value operand : op->getOperands())
- hash = llvm::hash_combine(hash, hashOperands(operand));
+ if (op->hasTrait<mlir::OpTrait::IsCommutative>() && op->getNumOperands() > 0) {
+ // If commutative, don't hash the operands as hash is not order independent
+ // and even if it were would not be sufficient for CSE usage.
+ // FIXME: This has the effect of resulting in more hash collisions
+ // for the sake of CSE, this could be improved.
+ hash = llvm::hash_combine(hash, op->getNumOperands());
+ } else {
+ for (Value operand : op->getOperands())
+ hash = llvm::hash_combine(hash, hashOperands(operand));
+ }
// - Results
for (Value result : op->getResults())
@@ -696,7 +704,9 @@ llvm::hash_code OperationEquivalence::computeHash(
Region *lhs, Region *rhs,
function_ref<LogicalResult(Value, Value)> checkEquivalent,
function_ref<void(Value, Value)> markEquivalent,
- OperationEquivalence::Flags flags) {
+ OperationEquivalence::Flags flags,
+ function_ref<LogicalResult(ValueRange, ValueRange)>
+ checkCommutativeEquivalent) {
DenseMap<Block *, Block *> blocksMap;
auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
// Check block arguments.
@@ -751,6 +761,36 @@ struct ValueEquivalenceCache {
return success(lhsValue == rhsValue ||
equivalentValues.lookup(lhsValue) == rhsValue);
}
+ LogicalResult checkCommutativeEquivalent(ValueRange lhsRange,
+ ValueRange rhsRange) {
+ // Handle simple case where sizes mismatch.
+ if (lhsRange.size() != rhsRange.size())
+ return failure();
+
+ // Handle where operands in order are equivalent.
+ auto lhsIt = lhsRange.begin();
+ auto rhsIt = rhsRange.begin();
+ for (; lhsIt != lhsRange.end(); ++lhsIt, ++rhsIt) {
+ if (failed(checkEquivalent(*lhsIt, *rhsIt)))
+ break;
+ }
+ if (lhsIt == lhsRange.end())
+ return success();
+
+ // Handle another simple case where operands are just a permutation.
+ // Note: This is not sufficient, this handles simple cases relatively
+ // cheaply.
+ auto sortValues = [](ValueRange values) {
+ SmallVector<Value> sortedValues = llvm::to_vector(values);
+ llvm::sort(sortedValues, [](Value a, Value b) {
+ return a.getAsOpaquePointer() < b.getAsOpaquePointer();
+ });
+ return sortedValues;
+ };
+ auto lhsSorted = sortValues({lhsIt, lhsRange.end()});
+ auto rhsSorted = sortValues({rhsIt, rhsRange.end()});
+ return success(lhsSorted == rhsSorted);
+ }
void markEquivalent(Value lhsResult, Value rhsResult) {
auto insertion = equivalentValues.insert({lhsResult, rhsResult});
// Make sure that the value was not already marked equivalent to some other
@@ -773,13 +813,18 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
[&](Value lhsResult, Value rhsResult) {
cache.markEquivalent(lhsResult, rhsResult);
},
- flags);
+ flags,
+ [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
+ return cache.checkCommutativeEquivalent(lhs, rhs);
+ });
}
/*static*/ bool OperationEquivalence::isEquivalentTo(
Operation *lhs, Operation *rhs,
function_ref<LogicalResult(Value, Value)> checkEquivalent,
- function_ref<void(Value, Value)> markEquivalent, Flags flags) {
+ function_ref<void(Value, Value)> markEquivalent, Flags flags,
+ function_ref<LogicalResult(ValueRange, ValueRange)>
+ checkCommutativeEquivalent) {
if (lhs == rhs)
return true;
@@ -798,15 +843,24 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
return false;
// 2. Compare operands.
- for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
- Value curArg = std::get<0>(operandPair);
- Value otherArg = std::get<1>(operandPair);
- if (curArg == otherArg)
- continue;
- if (curArg.getType() != otherArg.getType())
- return false;
- if (failed(checkEquivalent(curArg, otherArg)))
+ if (checkCommutativeEquivalent &&
+ lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
+ auto lhsRange = lhs->getOperands();
+ auto rhsRange = rhs->getOperands();
+ if (failed(checkCommutativeEquivalent(lhsRange, rhsRange)))
return false;
+ } else {
+ // Check pair wise for equivalence.
+ for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
+ Value curArg = std::get<0>(operandPair);
+ Value otherArg = std::get<1>(operandPair);
+ if (curArg == otherArg)
+ continue;
+ if (curArg.getType() != otherArg.getType())
+ return false;
+ if (failed(checkEquivalent(curArg, otherArg)))
+ return false;
+ }
}
// 3. Compare result types and mark results as equivalent.
@@ -841,7 +895,10 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
[&](Value lhsResult, Value rhsResult) {
cache.markEquivalent(lhsResult, rhsResult);
},
- flags);
+ flags,
+ [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
+ return cache.checkCommutativeEquivalent(lhs, rhs);
+ });
}
//===----------------------------------------------------------------------===//
|
|
lhs->hasTrait<mlir::OpTrait::IsCommutative>()) { | ||
auto lhsRange = lhs->getOperands(); | ||
auto rhsRange = rhs->getOperands(); | ||
if (failed(checkCommutativeEquivalent(lhsRange, rhsRange))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we actually need checkCommutativeEquivalent
? Can we do something like this if the op has the IsCommutative
trait?
SmallVector<Value> rhsOperands(rhs->getOperands().begin(), rhs->getOperands().end());
for (Value lhsOperand : lhs->getOperands()) {
auto it = llvm::find_if(rhsOperands, [&](Value rhsOperand) {
return succeeded(checkEquivalent(lhsOperand, rhsOperand));
});
if (it == rhsOperands.end()) {
// Could not find equivalent operand.
return false;
}
rhsOperands.erase(it);
}
assert(rhsOperands.empty());
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could, but this is O(n^2) vs O(n log n). Now, this should be replaced by the explicit commutative sort pass in pipeline so that here it should suffice to just do the linear scan. The default param means its easy to drop again at that point without needing to update.
Flags flags = Flags::None); | ||
Flags flags = Flags::None, | ||
function_ref<LogicalResult(ValueRange, ValueRange)> | ||
checkCommutativeEquivalent = nullptr); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be documented.
I don't quite get why we need injection here by the way?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As this is a static function and we need to be able to inject querying the equivalence classes. Else we have to do N^2 as Matthias suggested.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure that answers my question, let me rephrase: why isn't this just a Flag?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The equivalence classes are not exposed, so only accessible via the functors passed in. If a flag, then one could only access them via checkEquivalent still. Now, an option is to just not care about equivalence for commutative case except for pairwise part.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, LG, but you're still missing the doc :)
break; | ||
} | ||
if (lhsIt == lhsRange.end()) | ||
return success(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (!llvm::all_of(llvm::zip(lhsRange, rhsRange), checkEquivalent) return success();`
Edit: oh you're trying to optimize to only sort a subrange?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, more I'm trying to avoid sorting. So once commutative is handled as presort somewhere, it would just be the above/commutative would not be special here. And the sorted side I'm not considering equivalence classes there, its pure Value comparison rather than querying equivalentValues
map too. That's as mentioned in PR description to keep this rather minimal.
…e & cheap combine for the operands.
Tried to keep this simple while handling obvious CSE instances. For more complicated cases the expectation is still that the sorting pass would run before. While simple, this case did turn up in a real deployed instance where it had a large e2e impact. This can of course be refined.