diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp index 81cd3296de294..337eb279b7154 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp @@ -178,17 +178,25 @@ struct SparseSpaceCollapsePass // %space2 = extract_space %t2 ... // sparse_tensor.iterate(%sp1) ... // + // Collect all groups to collapse before performing any IR mutations. + // Mutating (erasing) ops during the walk would invalidate the walk's + // internal iterator and cause use-after-free crashes. + SmallVector> groups; SmallVector toCollapse; func->walk([&](ExtractIterSpaceOp op) { if (!legalToCollapse(toCollapse, op)) { - // if not legal to collapse one more space, collapse the existing ones - // and clear. - collapseSparseSpace(toCollapse); + // Save the current group and start a new one. + groups.push_back(std::move(toCollapse)); toCollapse.clear(); + // Try to start a new group with the current op. + legalToCollapse(toCollapse, op); } }); + groups.push_back(std::move(toCollapse)); - collapseSparseSpace(toCollapse); + // Apply all collapse transformations after the walk is complete. + for (auto &group : groups) + collapseSparseSpace(group); } }; diff --git a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir index b5d041273f440..73bd9dacf8c24 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir @@ -34,3 +34,40 @@ func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>) -> index { } return %r1 : index } + +// Verify that --sparse-space-collapse does not crash when an +// ExtractIterSpaceOp inside a collapsable loop body is not consumed by an +// IterateOp. Previously the pass erased ops during the walk, invalidating the +// walk iterator and causing a use-after-free. See: +// https://github.com/llvm/llvm-project/issues/130021 + +// The inner %l3 (from %sp2) is not consumed by an IterateOp, so it cannot be +// collapsed. Before the fix, processing the collapsable group {%l1,%l2} during +// the walk would erase %r1 (and everything nested inside, including %l3), +// causing the walk to access freed memory on the next step. + +// CHECK-LABEL: func.func @no_crash_unconsumed_iter_space( +// CHECK: sparse_tensor.extract_iteration_space {{.*}} lvls = 0 to 2 +// CHECK: sparse_tensor.iterate +// CHECK: sparse_tensor.extract_iteration_space {{.*}} lvls = 0 +func.func @no_crash_unconsumed_iter_space( + %sp : tensor<4x8xf32, #COO>, %sp2 : tensor<4x8xf32, #COO>) -> index { + %i = arith.constant 0 : index + %c1 = arith.constant 1 : index + %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 + : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0> + %r1 = sparse_tensor.iterate %it1 in %l1 iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index { + %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 + : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1> -> !sparse_tensor.iter_space<#COO, lvls = 1> + %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index { + // This space is from a different tensor and is not consumed by an IterateOp, + // so it breaks the collapsable chain. It must not cause a crash. + %l3 = sparse_tensor.extract_iteration_space %sp2 lvls = 0 + : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0> + %k = arith.addi %inner, %c1 : index + sparse_tensor.yield %k : index + } + sparse_tensor.yield %r2 : index + } + return %r1 : index +}