diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 844e6183cff06..2c2a4cf65f2eb 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1071,22 +1071,20 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) { memrefType.getDynamicDimIndex(unsignedIndex)); if (auto subview = dyn_cast_or_null(definingOp)) { - llvm::SmallBitVector unusedDims = subview.getDroppedDims(); - unsigned resultIndex = 0; - unsigned sourceRank = subview.getSourceType().getRank(); - unsigned sourceIndex = 0; - for (auto i : llvm::seq(0, sourceRank)) { - if (unusedDims.test(i)) + // The result dim is dynamic (the static case was handled above). Dropped + // dims always have static size 1, so dynamic source sizes are never + // dropped and map in order to the dynamic result dims. Find the k-th + // dynamic source size, where k is the dynamic dim index of the result dim. + unsigned dynamicResultDimIdx = memrefType.getDynamicDimIndex(unsignedIndex); + unsigned dynamicIdx = 0; + for (OpFoldResult size : subview.getMixedSizes()) { + if (llvm::isa(size)) continue; - if (resultIndex == unsignedIndex) { - sourceIndex = i; - break; - } - resultIndex++; + if (dynamicIdx == dynamicResultDimIdx) + return size; + dynamicIdx++; } - assert(subview.isDynamicSize(sourceIndex) && - "expected dynamic subview size"); - return subview.getDynamicSize(sourceIndex); + return {}; } // dim(memrefcast) -> dim diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index 3cfea1e8cd961..9c081398e8d55 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -1639,3 +1639,24 @@ func.func @non_replace_view_negative_static_dims(%src: memref, %offset : i %res = memref.view %src[%offset][%c-1] : memref to memref return %res : memref } + +// ----- + +// Verify that canonicalization does not crash when a memref.dim is applied to +// a subview with ambiguous dropped dimensions (multiple size-1 source dims with +// all-dynamic strides). The dim should be folded to the corresponding subview +// size operand. +// See: https://github.com/llvm/llvm-project/issues/111244 + +// CHECK-LABEL: func @no_crash_dim_of_ambiguous_subview +// CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: index) -> index +// CHECK-NOT: memref.dim +// CHECK: return %[[ARG1]] +func.func @no_crash_dim_of_ambiguous_subview( + %arg0: memref>, %arg1: index) -> index { + %c1 = arith.constant 1 : index + %subview = memref.subview %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] + : memref> to memref<1x?xf32, strided<[?, ?], offset: ?>> + %dim = memref.dim %subview, %c1 : memref<1x?xf32, strided<[?, ?], offset: ?>> + return %dim : index +}