[WIP][mlir][Bufferization] Accelerate bufferization pass #160655
@llvm/pr-subscribers-mlir
Author: None (mingzheTerapines)
Changes
Accelerate bufferization pass by caching the result of getAliasingOpOperands function.
Full diff: https://github.com/llvm/llvm-project/pull/160655.diff
2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index f3b34f9fded7f..b79be1a06cef4 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -61,8 +61,7 @@ struct AliasingValue {
bool isDefinite;
};
-template <typename T>
-class AliasList {
+template <typename T> class AliasList {
public:
/// Create an empty list of aliases.
AliasList() = default;
@@ -124,8 +123,7 @@ class OpFilter {
/// Allow the given dialects.
///
/// This function adds one or multiple ALLOW entries.
- template <typename... DialectTs>
- void allowDialect() {
+ template <typename... DialectTs> void allowDialect() {
// The following expands a call to allowDialectImpl for each dialect
// in 'DialectTs'.
(allowDialectImpl<DialectTs>(), ...);
@@ -134,8 +132,7 @@ class OpFilter {
/// Deny the given dialects.
///
/// This function adds one or multiple DENY entries.
- template <typename... DialectTs>
- void denyDialect() {
+ template <typename... DialectTs> void denyDialect() {
(denyDialectImpl<DialectTs>(), ...);
}
@@ -162,16 +159,14 @@ class OpFilter {
/// Allow the given ops.
///
/// This function adds one or multiple ALLOW entries.
- template <typename... OpTys>
- void allowOperation() {
+ template <typename... OpTys> void allowOperation() {
(allowOperationImpl<OpTys>(), ...);
}
/// Deny the given ops.
///
/// This function adds one or multiple DENY entries.
- template <typename... OpTys>
- void denyOperation() {
+ template <typename... OpTys> void denyOperation() {
(denyOperationImpl<OpTys>(), ...);
}
@@ -219,26 +214,22 @@ class OpFilter {
}
/// Allow a dialect.
- template <typename DialectT>
- void allowDialectImpl() {
+ template <typename DialectT> void allowDialectImpl() {
allowDialect(DialectT::getDialectNamespace());
}
/// Deny a dialect.
- template <typename DialectT>
- void denyDialectImpl() {
+ template <typename DialectT> void denyDialectImpl() {
denyDialect(DialectT::getDialectNamespace());
}
/// Allow an op.
- template <typename OpTy>
- void allowOperationImpl() {
+ template <typename OpTy> void allowOperationImpl() {
allowOperation(OpTy::getOperationName());
}
/// Deny an op.
- template <typename OpTy>
- void denyOperationImpl() {
+ template <typename OpTy> void denyOperationImpl() {
denyOperation(OpTy::getOperationName());
}
@@ -577,6 +568,9 @@ class AnalysisState {
/// regions.
DenseMap<std::pair<Operation *, Operation *>, bool>
insideMutuallyExclusiveRegionsCache;
+
+ /// Cache for getAliasingOpOperands results to avoid expensive recomputation.
+ mutable DenseMap<Value, AliasingOpOperandList> aliasingOpOperandsCache;
};
/// BufferizationState provides information about the state of the IR during the
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index f7b0b87085f3d..0d8f3c331410d 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -119,6 +119,7 @@ bool AnalysisState::insideMutuallyExclusiveRegions(Operation *op0,
void AnalysisState::resetCache() {
enclosingRepetitiveRegionCache.clear();
insideMutuallyExclusiveRegionsCache.clear();
+ aliasingOpOperandsCache.clear();
}
SymbolTableCollection &BufferizationState::getSymbolTables() {
@@ -413,12 +414,26 @@ static void setInsertionPointAfter(OpBuilder &b, Value value) {
/// Determine which OpOperand* will alias with `value` if the op is bufferized
/// in place. Return all tensor OpOperand* if the op is not bufferizable.
AliasingOpOperandList AnalysisState::getAliasingOpOperands(Value value) const {
+ // Check cache first
+ auto it = aliasingOpOperandsCache.find(value);
+ if (it != aliasingOpOperandsCache.end()) {
+ return it->second;
+ }
+
+ AliasingOpOperandList result;
if (Operation *op = getOwnerOfValue(value))
if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
- return bufferizableOp.getAliasingOpOperands(value, *this);
-
- // The op is not bufferizable.
- return detail::unknownGetAliasingOpOperands(value);
+ result = bufferizableOp.getAliasingOpOperands(value, *this);
+ else
+ // The op is not bufferizable.
+ result = detail::unknownGetAliasingOpOperands(value);
+ else
+ // The op is not bufferizable.
+ result = detail::unknownGetAliasingOpOperands(value);
+
+ // Cache the result
+ aliasingOpOperandsCache[value] = result;
+ return result;
}
/// Determine which Values will alias with `opOperand` if the op is bufferized
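The caching pattern in the diff above (a `mutable` map consulted from a `const` query and cleared by `resetCache()`) can be illustrated outside MLIR. The following is a minimal sketch under stated assumptions: `ToyAnalysisState`, `getAliases`, `computeAliases`, and `numComputations` are hypothetical stand-ins using standard containers, not the actual `AnalysisState` API.

```cpp
#include <cassert>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for AnalysisState; this is not the MLIR class.
class ToyAnalysisState {
public:
  // Const query that memoizes its result, mirroring the shape of the patch.
  std::vector<int> getAliases(int value) const {
    auto it = cache.find(value);
    if (it != cache.end())
      return it->second; // Cache hit: skip the recomputation.

    std::vector<int> result = computeAliases(value);
    cache.emplace(value, result);
    return result;
  }

  // Drops all memoized results; callers would invoke this whenever the
  // data the query depends on has changed.
  void resetCache() { cache.clear(); }

  // Number of times the uncached computation actually ran.
  int numComputations() const { return computations; }

private:
  // Placeholder for the expensive computation (an assumption for this demo).
  std::vector<int> computeAliases(int value) const {
    ++computations;
    return {value, value + 1};
  }

  // `mutable` lets the const query update the cache and the counter.
  mutable std::unordered_map<int, std::vector<int>> cache;
  mutable int computations = 0;
};
```

In this sketch a repeated query for the same key runs the computation once, and a query after `resetCache()` runs it again. The cached answers are only as fresh as the last reset, which is why the patch also clears the new map in `resetCache()`.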
Accelerate bufferization pass by caching the result of getAliasingOpOperands function.
e3c7bdf to aae4245
How much compilation time improvement are you seeing due to this? Can you also measure the memory overhead? You can measure memory usage with |
// Check cache first
auto it = aliasingOpOperandsCache.find(value);
if (it != aliasingOpOperandsCache.end()) {
return it->second;
Can you put something like:
#ifndef NDEBUG
assert(it->second == computeAliasingOpOperands() && "inconsistent cache result");
#endif // NDEBUG
It may help to put the code below into a lambda.
This is to guard against incorrect getAliasingOpOperands implementations.
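The reviewer's suggestion could look roughly like the sketch below: the uncached computation is wrapped in a lambda, and on a cache hit a debug build recomputes the value and compares it with the cached one. `CheckedAliasCache`, `getAliases`, and the toy computation are hypothetical names for illustration, not the MLIR code.

```cpp
#include <cassert>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for the cached query; not the MLIR implementation.
class CheckedAliasCache {
public:
  std::vector<int> getAliases(int value) const {
    // The uncached computation lives in a lambda so that both the
    // cache-miss path and the debug-only check can call it.
    auto compute = [&]() -> std::vector<int> {
      ++computations;
      return {value, value + 1};
    };

    auto it = cache.find(value);
    if (it != cache.end()) {
      // In builds without NDEBUG, recompute and compare with the cache.
      assert(it->second == compute() && "inconsistent cache result");
      return it->second;
    }

    std::vector<int> result = compute();
    cache.emplace(value, result);
    return result;
  }

  // Number of times the uncached computation actually ran.
  int numComputations() const { return computations; }

private:
  mutable std::unordered_map<int, std::vector<int>> cache;
  mutable int computations = 0;
};
```

Since `assert` already expands to nothing when `NDEBUG` is defined, an explicit `#ifndef NDEBUG` guard mainly matters when the check needs extra statements, such as storing the recomputed value in a local variable. Note that the check removes the speedup in debug builds, because every hit recomputes the result.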
Hi @matthias-springer, we tried this optimization when compiling the UNet model (in our downstream custom model compiler), and it reduced the compilation time from hours to around 30 minutes. We haven't measured the memory overhead yet, but maybe @mingzheTerapines and I can take a look at that. Also, I have another optimization that can reduce the YOLO8x model compilation time from 12 minutes to 6 minutes in debug builds. I might add it to this PR or create a new PR a bit later.
Thanks. This PR could later be closed and merged into that new PR.
Thx, later this PR could be closed and merge to that new PR
Do you want to merge this now or later?
AliasingOpOperandList result;
if (Operation *op = getOwnerOfValue(value))
if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
result = bufferizableOp.getAliasingOpOperands(value, *this);
You can just write `return bufferizableOp.getAliasingOpOperands(value, *this);` here and `return ...` at the end of the function.
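One possible way to combine early returns with the cache is to move the uncached logic into a helper, so the nested `else` branches disappear and the cache is still written in exactly one place. This is a sketch under assumptions: `Op`, `computeUncached`, and `computeCached` are hypothetical names, and strings stand in for the real result type.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for an operation; not an MLIR type.
struct Op {
  bool bufferizable;
};

// Uncached logic written with early returns instead of nested if/else.
static std::string computeUncached(const Op *owner) {
  if (owner)
    if (owner->bufferizable)
      return "from-interface";
  // Either there is no owner, or the owner is not bufferizable.
  return "unknown";
}

// Cached wrapper: one lookup, one store, regardless of which path
// the helper took.
static std::string
computeCached(std::unordered_map<const Op *, std::string> &cache,
              const Op *owner) {
  auto it = cache.find(owner);
  if (it != cache.end())
    return it->second;
  return cache.emplace(owner, computeUncached(owner)).first->second;
}
```

Returning directly from inside the caching function would skip the store on that path, which is why the early returns sit in the helper here rather than in the wrapper.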
Sorry, it seems this needs more testing on our models. I will change the status to WIP.
Accelerate bufferization pass by caching the result of getAliasingOpOperands function.