-
Notifications
You must be signed in to change notification settings - Fork 52
[Enhancement][TileLang] Introduce Pass LegalizeSafeMemoryAccess to auto protect memory access by Injecting IfThenElse Node
#267
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Contributor
Author
|
ops, Test Looks like we transformed ir: A_shared[0, i * 4 + v // 32, v % 32 // 4 * 32 + (v % 32 // 16 + v % 4 // 2) % 2 * 16 + (v % 16 // 8 + v % 2) % 2 * 8 + vec_1] = T.if_then_else(by * 64 + i * 32 + v // 4 < m and k_1 * 32 + v % 4 * 8 < k, A[by * 64 + i * 32 + v // 4, k_1 * 32 + v % 4 * 8 + vec_1], T.float16(0))into if k_1 * 32 + v % 4 * 8 + vec_1 < k:
if by * 64 + i * 32 + v // 4 < m:
A_shared[0, i * 4 + v // 32, v % 32 // 4 * 32 + (v % 32 // 16 + v % 4 // 2) % 2 * 16 + (v % 16 // 8 + v % 2) % 2 * 8 + vec_1] = T.if_then_else(k_1 * 32 + v % 4 * 8 < k, A[by * 64 + i * 32 + v // 4, k_1 * 32 + v % 4 * 8 + vec_1], T.float16(0))Which must be incorrect behavior, we should have an else node to set A_shared to zero instead of skip the BufferStoreNode. |
Contributor
Author
|
Summarize current solution:
Detailed information please checkout Pass // Class to legalize safe memory access by transforming them appropriately
class SafeMemoryLegalizer : IRMutatorWithAnalyzer {
public:
// Static method to substitute and transform the given PrimFunc
static PrimFunc Substitute(PrimFunc f) {
arith::Analyzer analyzer;
// Create an instance of the legalizer with the analyzer
SafeMemoryLegalizer substituter(&analyzer);
// Get a mutable copy of the function node
PrimFuncNode* fptr = f.CopyOnWrite();
// Apply the legalizer to the function body
fptr->body = substituter.VisitStmt(f->body);
return f;
}
private:
// Constructor initializing the base class with the analyzer
SafeMemoryLegalizer(arith::Analyzer* analyzer) : arith::IRMutatorWithAnalyzer(analyzer) {}
// Override the VisitStmt_ method to handle ForNode (loop statements)
Stmt VisitStmt_(const ForNode* op) final {
// Visit and potentially modify the loop node
For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
auto has_inner_loop = HasInnerLoop(for_node->body);
if (!has_inner_loop) {
SafeMemorysRewriter rewriter(analyzer_);
for_node.CopyOnWrite()->body = rewriter(for_node->body);
// // Detect Buffer Load Node in the loop body, collect the indices and buffer size
// // Run the checker on the loop body
// GlobalMemChecker checker(analyzer_);
// checker(for_node->body);
// Array<PrimExpr> conditions = checker.GetConditions();
// auto body = for_node->body;
// // Note that we might have duplicate conditions
// // Which will be optimzied by simplify pass
// // Replace the loop body with the new body
// for (auto cond : conditions) {
// body = IfThenElse(cond, body);
// }
// for_node.CopyOnWrite()->body = body;
return std::move(for_node);
}
// Visit a For Node
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
static bool HasInnerLoop(const Stmt& stmt) {
LeafForFinder finder;
finder(stmt);
return finder.leaf_for_nodes.size() > 0;
}
};
|
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
ref to issue #265 , This pull request introduces Pass
LegalizeSafeMemoryAccessTo address the need for auto-injection of if-then-else checks for out-of-bounds memory access in TileLang, we’ve implemented a new pass:
tl.LegalizeSafeMemoryAccess. This pass is designed to analyze memory accesses within loops, identify potential out-of-bounds conditions, and automatically insert the appropriate if-then-else statements to guard such accesses.Key Features of the Pass:
Here’s an example of how the pass transforms a loop:
Input:
Output (after the pass):
The injected condition ensures safe memory access by checking that m < M and n < N before accessing C[m, n].
How to Use the Pass
The pass can be applied as part of the TVM compilation pipeline by invoking tl.LegalizeSafeMemoryAccess. The registration of the pass allows it to be seamlessly integrated into existing workflows. For example:
This should address scenarios where loops access memory blocks with dimensions larger than the actual data, such as in the case of padding or tiling.