Skip to content

Conversation

@sakupan102
Copy link
Contributor

Commit 7e7ea9c added tensor support for scatter, but running the existing canonicalization on tensors causes bugs, so we now disable the canonicalization when the result is a tensor.

Closes #168695

…or output

Commit llvm@7e7ea9c added tensor support for scatter, but running the existing canonicalization on tensors causes bugs, so we now disable the canonicalization when the result is a tensor.

Closes llvm#168695

Signed-off-by: Ryutaro Okada <1015ryu88@gmail.com>
@llvmbot
Copy link
Member

llvmbot commented Nov 20, 2025

@llvm/pr-subscribers-mlir

Author: Ryutaro Okada (sakupan102)

Changes

Commit 7e7ea9c added tensor support for scatter, but running the existing canonicalization on tensors causes bugs, so we now disable the canonicalization when the result is a tensor.

Closes #168695


Full diff: https://github.com/llvm/llvm-project/pull/168824.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+6)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a97d0cd7f755b..c4d49334602db 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6087,6 +6087,9 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
   using Base::Base;
   LogicalResult matchAndRewrite(ScatterOp scatter,
                                 PatternRewriter &rewriter) const override {
+    if (!isa<MemRefType>(scatter.getBase().getType()))
+      return failure();
+
     switch (getMaskFormat(scatter.getMask())) {
     case MaskFormat::AllTrue:
       return failure(); // no unmasked equivalent
@@ -6107,6 +6110,9 @@ class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
   using Base::Base;
   LogicalResult matchAndRewrite(ScatterOp op,
                                 PatternRewriter &rewriter) const override {
+    if (!isa<MemRefType>(op.getBase().getType()))
+      return failure();
+
     if (failed(isZeroBasedContiguousSeq(op.getIndices())))
       return failure();
 

@llvmbot
Copy link
Member

llvmbot commented Nov 20, 2025

@llvm/pr-subscribers-mlir-vector

Author: Ryutaro Okada (sakupan102)

Changes

Commit 7e7ea9c added tensor support for scatter, but running the existing canonicalization on tensors causes bugs, so we now disable the canonicalization when the result is a tensor.

Closes #168695


Full diff: https://github.com/llvm/llvm-project/pull/168824.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+6)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a97d0cd7f755b..c4d49334602db 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6087,6 +6087,9 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
   using Base::Base;
   LogicalResult matchAndRewrite(ScatterOp scatter,
                                 PatternRewriter &rewriter) const override {
+    if (!isa<MemRefType>(scatter.getBase().getType()))
+      return failure();
+
     switch (getMaskFormat(scatter.getMask())) {
     case MaskFormat::AllTrue:
       return failure(); // no unmasked equivalent
@@ -6107,6 +6110,9 @@ class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
   using Base::Base;
   LogicalResult matchAndRewrite(ScatterOp op,
                                 PatternRewriter &rewriter) const override {
+    if (!isa<MemRefType>(op.getBase().getType()))
+      return failure();
+
     if (failed(isZeroBasedContiguousSeq(op.getIndices())))
       return failure();
 

@sakupan102
Copy link
Contributor Author

Should I add a test to verify that we have correctly disabled canonicalization?

@github-actions
Copy link

🐧 Linux x64 Test Results

  • 7101 tests passed
  • 594 tests skipped

@sakupan102
Copy link
Contributor Author

@dcaballe
Could you review this?
It is related to #165548

@dcaballe
Copy link
Contributor

Could we try to fix the issues instead of disabling canonicalization?

@sakupan102
Copy link
Contributor Author

sakupan102 commented Nov 25, 2025

@dcaballe
The crash comes from applying the existing canonicalization patterns to tensor semantics: they assume memref erase the op (e.g., ScatterFolder when mask is all false https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Vector/IR/VectorOps.cpp#L6085-L6100), which has no result. On a tensor vector.scatter, this drops the result and trips the replaceOp assertion.

I’m proposing to guard these patterns so they bail out early for tensor semantics.

@sakupan102
Copy link
Contributor Author

@dcaballe Ping

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[MLIR] [Vector] Error when canonicalizing vector.scatter with tensor output

3 participants