[mlir][memref] Make memref.cast areCastCompatible return true when meet same types#192029
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-memref Author: lonely eagle (linuxlonelyeagle) ChangesWhen both the source and destination types of Full diff: https://github.com/llvm/llvm-project/pull/192029.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 27c1649ee4ed3..31e4640499276 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -737,6 +737,8 @@ bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
+ if (inputs == outputs)
+ return true;
Type a = inputs.front(), b = outputs.front();
auto aT = llvm::dyn_cast<MemRefType>(a);
auto bT = llvm::dyn_cast<MemRefType>(b);
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index d3670fde08d81..2f061a1bb773e 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -894,12 +894,23 @@ func.func @invalid_memref_cast() {
// -----
-// unranked to unranked
+// unranked incompatible element types
func.func @invalid_memref_cast() {
%0 = memref.alloc() : memref<2x5xf32, 0>
%1 = memref.cast %0 : memref<2x5xf32, 0> to memref<*xf32, 0>
- // expected-error@+1 {{operand type 'memref<*xf32>' and result type 'memref<*xf32>' are cast incompatible}}
- %2 = memref.cast %1 : memref<*xf32, 0> to memref<*xf32, 0>
+ // expected-error@+1 {{operand type 'memref<*xf32>' and result type 'memref<*xi32>' are cast incompatible}}
+ %2 = memref.cast %1 : memref<*xf32, 0> to memref<*xi32, 0>
+ return
+}
+
+// -----
+
+// unranked incompatible memory space
+func.func @invalid_memref_cast() {
+ %0 = memref.alloc() : memref<2x5xf32, 0>
+ %1 = memref.cast %0 : memref<2x5xf32, 0> to memref<*xf32, 0>
+ // expected-error@+1 {{operand type 'memref<*xf32>' and result type 'memref<*xf32, 1>' are cast incompatible}}
+ %2 = memref.cast %1 : memref<*xf32, 0> to memref<*xf32, 1>
return
}
|
| func.func @invalid_memref_cast() { | ||
| %0 = memref.alloc() : memref<2x5xf32, 0> | ||
| %1 = memref.cast %0 : memref<2x5xf32, 0> to memref<*xf32, 0> | ||
| // expected-error@+1 {{operand type 'memref<*xf32>' and result type 'memref<*xf32>' are cast incompatible}} | ||
| %2 = memref.cast %1 : memref<*xf32, 0> to memref<*xf32, 0> |
There was a problem hiding this comment.
If you're erasing this old testcase, better to move it to a valid.mlir?
|
The changes lgtm, but the initial design why we are not allowing this semantic (unranked to unranked) maybe has another meaning? |
|
Please update the op documentation along with the change, it currently spells out that only ranked memrefs are supported. I don't have a strong reason not to allow this, but this needs a better, more specific justification than "impacts downstream projects", ideally that relates to upstream mlir. |
I have reviewed the MLIR documentation regarding memref.cast https://mlir.llvm.org/docs/Dialects/MemRef/#memrefcast-memrefcastop, it say "b. Either or both memref types are unranked with the same element type, and address space.", |
|
Good point, I missed that in the documentation. |
|
When merging this PR, I added to the description: 'Fixed the issue where its behavior was inconsistent with the documentation. |
…et same types (llvm#192029) When both the source and destination types of `memref.cast` are unranked, it causes an IR verification failure, which impacts downstream projects and its behavior is inconsistent with the documentation. To address this, this PR now allows the operation to return true if the source and destination types are identical.
…et same types (llvm#192029) When both the source and destination types of `memref.cast` are unranked, it causes an IR verification failure, which impacts downstream projects and its behavior is inconsistent with the documentation. To address this, this PR now allows the operation to return true if the source and destination types are identical.
When both the source and destination types of
memref.castare unranked, it causes an IR verification failure, which impacts downstream projects and its behavior is inconsistent with the documentation. To address this, this PR now allows the operation to return true if the source and destination types are identical.