Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,9 @@ class AnalysisState {
/// regions.
DenseMap<std::pair<Operation *, Operation *>, bool>
insideMutuallyExclusiveRegionsCache;

/// Cache for getAliasingOpOperands results to avoid expensive recomputation.
mutable DenseMap<Value, AliasingOpOperandList> aliasingOpOperandsCache;
};

/// BufferizationState provides information about the state of the IR during the
Expand Down
34 changes: 29 additions & 5 deletions mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ bool AnalysisState::insideMutuallyExclusiveRegions(Operation *op0,
void AnalysisState::resetCache() {
enclosingRepetitiveRegionCache.clear();
insideMutuallyExclusiveRegionsCache.clear();
aliasingOpOperandsCache.clear();
}

SymbolTableCollection &BufferizationState::getSymbolTables() {
Expand Down Expand Up @@ -413,12 +414,35 @@ static void setInsertionPointAfter(OpBuilder &b, Value value) {
/// Determine which OpOperand* will alias with `value` if the op is bufferized
/// in place. Return all tensor OpOperand* if the op is not bufferizable.
AliasingOpOperandList AnalysisState::getAliasingOpOperands(Value value) const {
if (Operation *op = getOwnerOfValue(value))
if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
return bufferizableOp.getAliasingOpOperands(value, *this);
// Lambda to compute aliasing operands
auto computeAliasingOpOperands = [&]() -> AliasingOpOperandList {
AliasingOpOperandList result;
if (Operation *op = getOwnerOfValue(value))
if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
result = bufferizableOp.getAliasingOpOperands(value, *this);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can just write return bufferizableOp.getAliasingOpOperands(value, *this); here and return ... at the end of the function.

else
// The op is not bufferizable.
result = detail::unknownGetAliasingOpOperands(value);
else
// The op is not bufferizable.
result = detail::unknownGetAliasingOpOperands(value);
return result;
};

// The op is not bufferizable.
return detail::unknownGetAliasingOpOperands(value);
// Check cache first
auto it = aliasingOpOperandsCache.find(value);
if (it != aliasingOpOperandsCache.end()) {
#ifndef NDEBUG
assert(it->second == computeAliasingOpOperands() &&
"inconsistent cache result");
#endif // NDEBUG
return it->second;
}

// Cache the result
AliasingOpOperandList result = computeAliasingOpOperands();
aliasingOpOperandsCache[value] = result;
return result;
}

/// Determine which Values will alias with `opOperand` if the op is bufferized
Expand Down