Skip to content

Conversation

@hanhanW
Copy link
Contributor

@hanhanW hanhanW commented Nov 4, 2025

Reverts #157330

The original revision introduces a bug in isGuaranteedCollapsible. The memref<3x3x1x96xf32, strided<[288, 96, 96, 1], offset: 864>> is no longer collapsable with the change. The revision reverts the change to bring back correct behavior. stride should be computed as 96 like the old behavior in the failed iteration.

bool CollapseShapeOp::isGuaranteedCollapsible(
MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
// MemRefs with identity layout are always collapsible.
if (srcType.getLayout().isIdentity())
return true;
return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
/*strict=*/true));
}

@llvmbot
Copy link
Member

llvmbot commented Nov 4, 2025

@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

Changes

Reverts llvm/llvm-project#157330

The original revision introduces a bug in isGuaranteedCollapsible. The memref&lt;3x3x1x96xf32, strided&lt;[288, 96, 96, 1], offset: 864&gt;&gt; is no longer collapsable with the change. The revision reverts the change to bring back correct behavior. stride should be computed as 96 like the old behavior in the failed iteration.

bool CollapseShapeOp::isGuaranteedCollapsible(
MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
// MemRefs with identity layout are always collapsible.
if (srcType.getLayout().isIdentity())
return true;
return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
/*strict=*/true));
}


Full diff: https://github.com/llvm/llvm-project/pull/166448.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+5-5)
  • (modified) mlir/test/Dialect/MemRef/ops.mlir (+1-6)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e271ac58db327..1c21a2f270da6 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2568,11 +2568,6 @@ computeCollapsedLayoutMap(MemRefType srcType,
     auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
     auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
     for (int64_t idx : llvm::reverse(trailingReassocs)) {
-      // Dimensions of size 1 should be skipped, because their strides are
-      // meaningless and could have any arbitrary value.
-      if (srcShape[idx - 1] == 1)
-        continue;
-
       stride = stride * SaturatedInteger::wrap(srcShape[idx]);
 
       // Both source and result stride must have the same static value. In that
@@ -2587,6 +2582,11 @@ computeCollapsedLayoutMap(MemRefType srcType,
       if (strict && (stride.saturated || srcStride.saturated))
         return failure();
 
+      // Dimensions of size 1 should be skipped, because their strides are
+      // meaningless and could have any arbitrary value.
+      if (srcShape[idx - 1] == 1)
+        continue;
+
       if (!stride.saturated && !srcStride.saturated && stride != srcStride)
         return failure();
     }
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index b1db99bb3ad08..a90c9505a8405 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -440,8 +440,7 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
          %arg4: index,
          %arg5: index,
          %arg6: index,
-         %arg7: memref<4x?x4xf32>,
-         %arg8: memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>>) {
+         %arg7: memref<4x?x4xf32>) {
 //       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
 //  CHECK-SAME:     memref<?x?x?xf32> into memref<?x?xf32>
   %0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
@@ -490,10 +489,6 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
 //       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
   %4 = memref.expand_shape %arg7 [[0, 1], [2], [3, 4]] output_shape [2, 2, %arg4, 2, 2]
         : memref<4x?x4xf32> into memref<2x2x?x2x2xf32>
-
-//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3]]
-//  CHECK-SAME:     memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>> into memref<1x18x?xsi8, strided<[?, ?, 1], offset: ?>>
-  %5 = memref.collapse_shape %arg8 [[0, 1], [2], [3]] : memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>> into memref<1x18x?xsi8, strided<[?, ?, 1], offset: ?>>
   return
 }
 

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert approved

@hanhanW hanhanW enabled auto-merge (squash) November 4, 2025 21:48
@hanhanW hanhanW merged commit b21949e into main Nov 4, 2025
11 of 12 checks passed
@hanhanW hanhanW deleted the revert-157330-amrami/collapse_1 branch November 4, 2025 21:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants