Skip to content

Conversation

@LeiWang1999
Copy link
Contributor

ref to issue #265 , This pull request introduces Pass LegalizeSafeMemoryAccess

To address the need for auto-injection of if-then-else checks for out-of-bounds memory access in TileLang, we’ve implemented a new pass: tl.LegalizeSafeMemoryAccess. This pass is designed to analyze memory accesses within loops, identify potential out-of-bounds conditions, and automatically insert the appropriate if-then-else statements to guard such accesses.

Key Features of the Pass:

  1. Global Memory Access Detection: The pass scans for BufferLoad and BufferStore operations targeting global memory and analyzes their access patterns.
  2. Bounds Checking: It compares indices against the buffer’s shape dimensions and automatically generates conditions to ensure access safety.
  3. Condition Injection: If the pass detects that a memory access could potentially be out of bounds, it wraps the loop body in the necessary if-then-else conditions to enforce valid access.
  4. Leaf Loop Analysis: The pass only applies the transformations to loops that do not have inner loops, ensuring that it does not over-constrain the program unnecessarily.
  5. Simplification-Ready Conditions: The generated conditions are compatible with downstream simplification passes, ensuring they are optimized efficiently.

Here’s an example of how the pass transforms a loop:

Input:

for i, j in T.Parallel(block_M, block_N):
    m, n = by * block_M + i, bx * block_N + j
    C[m, n] = C_shared[
        i // micro_size_x,
        j // micro_size_y,
        i % micro_size_x,
        j % micro_size_y,
    ]

Output (after the pass):

for i, j in T.Parallel(block_M, block_N):
    m, n = by * block_M + i, bx * block_N + j
    if m < M and n < N:
        C[m, n] = C_shared[
            i // micro_size_x,
            j // micro_size_y,
            i % micro_size_x,
            j % micro_size_y,
        ]

The injected condition ensures safe memory access by checking that m < M and n < N before accessing C[m, n].

How to Use the Pass

The pass can be applied as part of the TVM compilation pipeline by invoking tl.LegalizeSafeMemoryAccess. The registration of the pass allows it to be seamlessly integrated into existing workflows. For example:

    mod = tl.transform.LegalizeSafeMemoryAccess()(mod)
    # Inject Simplify to remove the duplicated conditions
    # which may be introduced by the LegalizeSafeMemoryAccess
    mod = tir.transform.Simplify()(mod)

This should address scenarios where loops access memory blocks with dimensions larger than the actual data, such as in the case of padding or tiling.

@LeiWang1999
Copy link
Contributor Author

ops, Test Loop All Dynamic is failed.

Looks like we transformed ir:

A_shared[0, i * 4 + v // 32, v % 32 // 4 * 32 + (v % 32 // 16 + v % 4 // 2) % 2 * 16 + (v % 16 // 8 + v % 2) % 2 * 8 + vec_1] = T.if_then_else(by * 64 + i * 32 + v // 4 < m and k_1 * 32 + v % 4 * 8 < k, A[by * 64 + i * 32 + v // 4, k_1 * 32 + v % 4 * 8 + vec_1], T.float16(0))

into

if k_1 * 32 + v % 4 * 8 + vec_1 < k:
    if by * 64 + i * 32 + v // 4 < m:
        A_shared[0, i * 4 + v // 32, v % 32 // 4 * 32 + (v % 32 // 16 + v % 4 // 2) % 2 * 16 + (v % 16 // 8 + v % 2) % 2 * 8 + vec_1] = T.if_then_else(k_1 * 32 + v % 4 * 8 < k, A[by * 64 + i * 32 + v // 4, k_1 * 32 + v % 4 * 8 + vec_1], T.float16(0))

Which must be incorrect behavior, we should have an else node to set A_shared to zero instead of skip the BufferStoreNode.

@LeiWang1999
Copy link
Contributor Author

Summarize current solution:

  • Traverse the final for node:
    • Rewrite each buffer store node with an if_then_else buffer load, unless storing to global memory (skip in this case).
    • If storing to shared memory from global memory, replace the store with zeros (shared memory might be uninitialized).
    • Rewrite each buffer call node with an if_then_else buffer load.

Detailed information please checkout Pass SafeMemoryLegalizer

// Class to legalize safe memory access by transforming them appropriately
class SafeMemoryLegalizer : IRMutatorWithAnalyzer {
 public:
  // Static method to substitute and transform the given PrimFunc
  static PrimFunc Substitute(PrimFunc f) {
    arith::Analyzer analyzer;
    // Create an instance of the legalizer with the analyzer
    SafeMemoryLegalizer substituter(&analyzer);
    // Get a mutable copy of the function node
    PrimFuncNode* fptr = f.CopyOnWrite();
    // Apply the legalizer to the function body
    fptr->body = substituter.VisitStmt(f->body);
    return f;
  }

 private:
  // Constructor initializing the base class with the analyzer
  SafeMemoryLegalizer(arith::Analyzer* analyzer) : arith::IRMutatorWithAnalyzer(analyzer) {}

  // Override the VisitStmt_ method to handle ForNode (loop statements)
  Stmt VisitStmt_(const ForNode* op) final {
    // Visit and potentially modify the loop node
    For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
    auto has_inner_loop = HasInnerLoop(for_node->body);
    if (!has_inner_loop) {
      SafeMemorysRewriter rewriter(analyzer_);
      for_node.CopyOnWrite()->body = rewriter(for_node->body);
      // // Detect Buffer Load Node in the loop body, collect the indices and buffer size

      // // Run the checker on the loop body
      // GlobalMemChecker checker(analyzer_);
      // checker(for_node->body);
      // Array<PrimExpr> conditions = checker.GetConditions();
      // auto body = for_node->body;
      // // Note that we might have duplicate conditions
      // // Which will be optimzied by simplify pass
      // // Replace the loop body with the new body
      // for (auto cond : conditions) {
      //   body = IfThenElse(cond, body);
      // }
      // for_node.CopyOnWrite()->body = body;
      return std::move(for_node);
    }

    // Visit a For Node
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

  static bool HasInnerLoop(const Stmt& stmt) {
    LeafForFinder finder;
    finder(stmt);
    return finder.leaf_for_nodes.size() > 0;
  }
};

@LeiWang1999 LeiWang1999 merged commit fe8e435 into microsoft:main Dec 16, 2024
5 of 6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant