-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[MLIR] [Vector] Disable canonicalization for vector.scatter with tensor output #168824
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…or output Commit llvm@7e7ea9c added tensor support for scatter, but running the existing canonicalization on tensors causes bugs, so we now disable the canonicalization when the result is a tensor. Closes llvm#168695 Signed-off-by: Ryutaro Okada <1015ryu88@gmail.com>
|
@llvm/pr-subscribers-mlir Author: Ryutaro Okada (sakupan102) ChangesCommit 7e7ea9c added tensor support for scatter, but running the existing canonicalization on tensors causes bugs, so we now disable the canonicalization when the result is a tensor. Closes #168695 Full diff: https://github.com/llvm/llvm-project/pull/168824.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a97d0cd7f755b..c4d49334602db 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6087,6 +6087,9 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
using Base::Base;
LogicalResult matchAndRewrite(ScatterOp scatter,
PatternRewriter &rewriter) const override {
+ if (!isa<MemRefType>(scatter.getBase().getType()))
+ return failure();
+
switch (getMaskFormat(scatter.getMask())) {
case MaskFormat::AllTrue:
return failure(); // no unmasked equivalent
@@ -6107,6 +6110,9 @@ class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
using Base::Base;
LogicalResult matchAndRewrite(ScatterOp op,
PatternRewriter &rewriter) const override {
+ if (!isa<MemRefType>(op.getBase().getType()))
+ return failure();
+
if (failed(isZeroBasedContiguousSeq(op.getIndices())))
return failure();
|
|
@llvm/pr-subscribers-mlir-vector Author: Ryutaro Okada (sakupan102) ChangesCommit 7e7ea9c added tensor support for scatter, but running the existing canonicalization on tensors causes bugs, so we now disable the canonicalization when the result is a tensor. Closes #168695 Full diff: https://github.com/llvm/llvm-project/pull/168824.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a97d0cd7f755b..c4d49334602db 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6087,6 +6087,9 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
using Base::Base;
LogicalResult matchAndRewrite(ScatterOp scatter,
PatternRewriter &rewriter) const override {
+ if (!isa<MemRefType>(scatter.getBase().getType()))
+ return failure();
+
switch (getMaskFormat(scatter.getMask())) {
case MaskFormat::AllTrue:
return failure(); // no unmasked equivalent
@@ -6107,6 +6110,9 @@ class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
using Base::Base;
LogicalResult matchAndRewrite(ScatterOp op,
PatternRewriter &rewriter) const override {
+ if (!isa<MemRefType>(op.getBase().getType()))
+ return failure();
+
if (failed(isZeroBasedContiguousSeq(op.getIndices())))
return failure();
|
|
Should I add a test to verify that we have correctly disabled canonicalization? |
🐧 Linux x64 Test Results
|
|
Could we try to fix the issues instead of disabling canonicalization? |
|
@dcaballe I’m proposing to guard these patterns so they bail out early for tensor semantics. |
|
@dcaballe Ping |
Commit 7e7ea9c added tensor support for scatter, but running the existing canonicalization on tensors causes bugs, so we now disable the canonicalization when the result is a tensor.
Closes #168695