From f5b54e5f41c285b52d1e22b7fb9912c39088722e Mon Sep 17 00:00:00 2001 From: Ryutaro Okada <1015ryu88@gmail.com> Date: Thu, 20 Nov 2025 13:57:33 +0900 Subject: [PATCH] [MLIR] [Vector] Disable canonicalization for vector.scatter with tensor output Commit https://github.com/llvm/llvm-project/commit/7e7ea9c5357efcdf9ba6bd7ea3669e607a9af400 added tensor support for scatter, but running the existing canonicalization on tensors causes bugs, so we now disable the canonicalization when the result is a tensor. Closes https://github.com/llvm/llvm-project/issues/168695 Signed-off-by: Ryutaro Okada <1015ryu88@gmail.com> --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index a97d0cd7f755b..c4d49334602db 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6087,6 +6087,9 @@ class ScatterFolder final : public OpRewritePattern { using Base::Base; LogicalResult matchAndRewrite(ScatterOp scatter, PatternRewriter &rewriter) const override { + if (!isa(scatter.getBase().getType())) + return failure(); + switch (getMaskFormat(scatter.getMask())) { case MaskFormat::AllTrue: return failure(); // no unmasked equivalent @@ -6107,6 +6110,9 @@ class FoldContiguousScatter final : public OpRewritePattern { using Base::Base; LogicalResult matchAndRewrite(ScatterOp op, PatternRewriter &rewriter) const override { + if (!isa(op.getBase().getType())) + return failure(); + if (failed(isZeroBasedContiguousSeq(op.getIndices()))) return failure();