@Hanumanth04
Contributor

Previously, the runtime verification pass would insert assertion statements with conditions that always evaluate to false for semantically valid memref.subview operations where one of the dimensions had a size of 0.

The memref.subview runtime verification logic unconditionally generated a check on the position of the last element, offset + (size - 1) * stride. When size is 0, that position is offset - stride, one stride before the slice begins. For an offset of 0 and a positive stride it is negative, so the assertion fails at runtime even though the operation is semantically valid.
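
To make the arithmetic concrete, here is a minimal C++ sketch of the last-position check for a single dimension. It is a model of the generated check, not the pass's code: the names `lastPos` and `inBounds` are hypothetical, and the sketch assumes the bounds test has the form 0 <= pos < dimSize.

```cpp
#include <cassert>
#include <cstdint>

// Position of the last element touched along one dimension:
// offset + (size - 1) * stride.
int64_t lastPos(int64_t offset, int64_t size, int64_t stride) {
  return offset + (size - 1) * stride;
}

// Assumed form of the bounds test: 0 <= pos < dimSize.
bool inBounds(int64_t pos, int64_t dimSize) {
  assert(dimSize >= 0 && "dimension extents are non-negative");
  return 0 <= pos && pos < dimSize;
}
```

For the first dimension of the reproducer (offset 0, stride 1, source extent 10), `lastPos(0, 0, 1)` is -1, so `inBounds(-1, 10)` is false even though an empty slice touches no element. With size 10 the same check passes, since `lastPos(0, 10, 1)` is 9.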

This patch fixes the issue by making the lastPos check conditional. The offset is always verified, but the end-point check now runs only when size > 0, so empty slices no longer trigger spurious assertion failures.
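
The guarded logic can be sketched in plain C++ as follows. This is an illustration only: `subviewDimVerifies` is a hypothetical name, and both bounds tests are assumed to have the form 0 <= value < dimSize. The actual patch emits the equivalent control flow as `cf.cond_br` and `cf.assert` operations rather than host-side code.

```cpp
#include <cstdint>

// Hypothetical model of the per-dimension verification after the patch.
// Returns true when every runtime assertion for this dimension would pass.
bool subviewDimVerifies(int64_t offset, int64_t size, int64_t stride,
                        int64_t dimSize) {
  // The offset check is always performed.
  if (offset < 0 || offset >= dimSize)
    return false;
  // The end-point check is skipped for empty slices (size == 0).
  if (size > 0) {
    int64_t lastPos = offset + (size - 1) * stride;
    if (lastPos < 0 || lastPos >= dimSize)
      return false;
  }
  return true;
}
```

Under these assumptions, `subviewDimVerifies(0, 0, 1, 10)` now succeeds, while a slice that really overruns the source, such as `subviewDimVerifies(0, 11, 1, 10)`, is still rejected.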

This issue was discovered through a LiteRT model, where a dynamic shape calculation resulted in a zero-sized dimension being passed to memref.subview. The following is a simplified IR snippet from the model. After running the runtime verification pass, the generated assertion fails at runtime because the SSA value %5 evaluates to 0.

module {
  memref.global "private" constant @__constant_2xi32 : memref<2xi32> = dense<-1> {alignment = 64 : i64}
  memref.global "private" constant @__constant_1xi32 : memref<1xi32> = dense<0> {alignment = 64 : i64}
  func.func @simpleRepro(%arg0: memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
    %c2 = arith.constant 2 : index
    %c4 = arith.constant 4 : index
    %c1 = arith.constant 1 : index
    %c10 = arith.constant 10 : index
    %c0 = arith.constant 0 : index
    %c-1 = arith.constant -1 : index
    %0 = memref.get_global @__constant_1xi32 : memref<1xi32>
    %1 = memref.get_global @__constant_2xi32 : memref<2xi32>
    %alloca = memref.alloca() {alignment = 64 : i64} : memref<3xi32>
    %subview = memref.subview %alloca[0] [1] [1] : memref<3xi32> to memref<1xi32, strided<[1]>>
    memref.copy %0, %subview : memref<1xi32> to memref<1xi32, strided<[1]>>
    %subview_0 = memref.subview %alloca[1] [2] [1] : memref<3xi32> to memref<2xi32, strided<[1], offset: 1>>
    memref.copy %1, %subview_0 : memref<2xi32> to memref<2xi32, strided<[1], offset: 1>>
    %2 = memref.load %alloca[%c0] : memref<3xi32>
    %3 = index.casts %2 : i32 to index
    %4 = arith.cmpi eq, %3, %c-1 : index
    %5 = arith.select %4, %c10, %3 : index
    %6 = memref.load %alloca[%c1] : memref<3xi32>
    %7 = index.casts %6 : i32 to index
    %8 = arith.cmpi eq, %7, %c-1 : index
    %9 = arith.select %8, %c4, %7 : index
    %10 = memref.load %alloca[%c2] : memref<3xi32>
    %11 = index.casts %10 : i32 to index
    %12 = arith.cmpi eq, %11, %c-1 : index
    %13 = arith.select %12, %c1, %11 : index
    %subview_1 = memref.subview %arg0[0, 0, 0] [%5, %9, %13] [1, 1, 1] : memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
    return %subview_1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
  }
}
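
The shape calculation in the snippet can be traced with a short C++ sketch. The function name `resolveSizes` is hypothetical; it mirrors the three `arith.cmpi`/`arith.select` pairs, which replace a requested size of -1 with a default and pass any other value through unchanged.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Mirrors the arith.cmpi/arith.select pairs in the reproducer: a requested
// size of -1 is replaced by the corresponding default extent.
std::array<int64_t, 3> resolveSizes(const std::array<int32_t, 3> &requested,
                                    const std::array<int64_t, 3> &defaults) {
  std::array<int64_t, 3> sizes{};
  for (std::size_t i = 0; i < sizes.size(); ++i) {
    int64_t r = requested[i]; // index.casts: sign-extend i32 to index
    sizes[i] = (r == -1) ? defaults[i] : r;
  }
  return sizes;
}
```

In the reproducer the `memref<3xi32>` buffer holds {0, -1, -1} after the two copies (element 0 comes from `@__constant_1xi32`, elements 1 and 2 from `@__constant_2xi32`), and the defaults are {10, 4, 1}. The resolved sizes are therefore {0, 4, 1}: the first size, `%5`, is 0 because the stored value is 0 rather than -1.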

P.S. This is a similar issue to the one fixed for tensor.extract_slice in #164878

@llvmbot
Member

llvmbot commented Oct 23, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Author: Hanumanth (Hanumanth04)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/164897.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp (+41-1)
  • (modified) mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir (+19)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 291da1f76ca9b..1979d5b7e6310 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
@@ -273,7 +274,9 @@ struct SubViewOpInterface
     Value one = arith::ConstantIndexOp::create(builder, loc, 1);
     auto metadataOp =
         ExtractStridedMetadataOp::create(builder, loc, subView.getSource());
-    for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
+    for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
+      // Reset insertion point to before the operation for each dimension
+      builder.setInsertionPoint(subView);
       Value offset = getValueOrCreateConstantIndexOp(
           builder, loc, subView.getMixedOffsets()[i]);
       Value size = getValueOrCreateConstantIndexOp(builder, loc,
@@ -290,6 +293,42 @@ struct SubViewOpInterface
                                                         std::to_string(i) +
                                                         " is out-of-bounds"));
 
+      // Only verify if size > 0
+      Value sizeIsNonZero = arith::CmpIOp::create(
+          builder, loc, arith::CmpIPredicate::sgt, size, zero);
+
+      /*
+       * Split the current block to create the below control flow structure:
+       *
+       * ^preCondBlock:
+       *   ... // offset check already done above
+       *   %size_nonzero = arith.cmpi sgt, %size, %zero
+       *   cf.cond_br %size_nonzero, ^sizeBoundsCheckBlock, ^afterCheckBlock
+       *
+       * ^sizeBoundsCheckBlock:
+       *   %last_pos = ... // compute offset + (size-1) * stride
+       *   %last_pos_ok = ... // last position bounds check
+       *   cf.assert %last_pos_ok, "subview runs out-of-bounds"
+       *   cf.br ^afterCheckBlock
+       *
+       * ^afterCheckBlock:
+       *   memref.subview ... // the original operation
+       */
+      Block *preCondBlock = builder.getBlock();
+      Block *afterCheckBlock = preCondBlock->splitBlock(subView);
+
+      // Create the block for conditional size bounds verification.
+      Block *sizeBoundsCheckBlock = builder.createBlock(
+          preCondBlock->getParent(), Region::iterator(afterCheckBlock));
+
+      // Terminate the pre-condition block with the conditional branch.
+      builder.setInsertionPointToEnd(preCondBlock);
+      cf::CondBranchOp::create(builder, loc, sizeIsNonZero,
+                               sizeBoundsCheckBlock, afterCheckBlock);
+
+      // Populate the size bounds check block with lastPos verification.
+      builder.setInsertionPointToStart(sizeBoundsCheckBlock);
+
       // Verify that slice does not run out-of-bounds.
       Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
       Value sizeMinusOneTimesStride =
@@ -303,6 +342,7 @@ struct SubViewOpInterface
           generateErrorMessage(op,
                                "subview runs out-of-bounds along dimension " +
                                    std::to_string(i)));
+      cf::BranchOp::create(builder, loc, afterCheckBlock);
     }
   }
 };
diff --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
index 71e813c0a6300..001c435086976 100644
--- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
@@ -38,6 +38,17 @@ func.func @subview_dynamic_rank_reduce(%memref: memref<?x4xf32>, %offset: index,
     return
 }
 
+func.func @subview_zero_size_dim(%memref: memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, 
+                                 %dim_0: index, 
+                                 %dim_1: index, 
+                                 %dim_2: index) {
+    %subview = memref.subview %memref[0, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] :
+        memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> to
+        memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+    return
+}
+
+
 func.func @main() {
   %0 = arith.constant 0 : index
   %1 = arith.constant 1 : index
@@ -105,6 +116,14 @@ func.func @main() {
   // CHECK-NOT: ERROR: Runtime op verification failed
   func.call @subview_dynamic_rank_reduce(%alloca_4_dyn, %0, %1, %0) : (memref<?x4xf32>, index, index, index) -> ()
 
+  %alloca_10x4x1 = memref.alloca() : memref<10x4x1xf32>
+  %alloca_10x4x1_dyn_stride = memref.cast %alloca_10x4x1 : memref<10x4x1xf32> to memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>
+  // CHECK-NOT: ERROR: Runtime op verification failed
+  %dim_0 = arith.constant 0 : index
+  %dim_1 = arith.constant 4 : index
+  %dim_2 = arith.constant 1 : index
+  func.call @subview_zero_size_dim(%alloca_10x4x1_dyn_stride, %dim_0, %dim_1, %dim_2)
+                                        : (memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, index, index, index) -> ()
 
   return
 }

@Hanumanth04
Contributor Author

Hi @matthias-springer, could you please look at this PR when you get a chance? This is similar to the fix in #164878. Thanks!

@matthias-springer merged commit cbe7c49 into llvm:main on Oct 27, 2025
10 checks passed
dvbuka pushed a commit to dvbuka/llvm-project that referenced this pull request Oct 27, 2025
…dimension value is 0 (llvm#164897)

Previously, the runtime verification pass would insert assertion
statements with conditions that always evaluate to false for
semantically valid `memref.subview` operations where one of the
dimensions had a size of 0.

The `memref.subview` runtime verification logic was unconditionally
generating checks for the position of the last element (`offset + (size
- 1) * stride`). When `size` is 0, this causes the assertion condition
to always be false, leading to runtime failures even though the
operation is semantically valid.

This patch fixes the issue by making the `lastPos` check conditional.
The offset is always verified, but the endpoint check is only performed
when `size > 0` to avoid generating spurious assert statements.

This issue was discovered through a LiteRT model, where a dynamic shape
calculation resulted in a zero-sized dimension being passed to
`memref.subview`. The following is a simplified IR snippet from the
model. After running the runtime verification pass, an assertion that
always fails is generated because the SSA value `%5` becomes 0.

```mlir
module {
  memref.global "private" constant @__constant_2xi32 : memref<2xi32> = dense<-1> {alignment = 64 : i64}
  memref.global "private" constant @__constant_1xi32 : memref<1xi32> = dense<0> {alignment = 64 : i64}
  func.func @simpleRepro(%arg0: memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
    %c2 = arith.constant 2 : index
    %c4 = arith.constant 4 : index
    %c1 = arith.constant 1 : index
    %c10 = arith.constant 10 : index
    %c0 = arith.constant 0 : index
    %c-1 = arith.constant -1 : index
    %0 = memref.get_global @__constant_1xi32 : memref<1xi32>
    %1 = memref.get_global @__constant_2xi32 : memref<2xi32>
    %alloca = memref.alloca() {alignment = 64 : i64} : memref<3xi32>
    %subview = memref.subview %alloca[0] [1] [1] : memref<3xi32> to memref<1xi32, strided<[1]>>
    memref.copy %0, %subview : memref<1xi32> to memref<1xi32, strided<[1]>>
    %subview_0 = memref.subview %alloca[1] [2] [1] : memref<3xi32> to memref<2xi32, strided<[1], offset: 1>>
    memref.copy %1, %subview_0 : memref<2xi32> to memref<2xi32, strided<[1], offset: 1>>
    %2 = memref.load %alloca[%c0] : memref<3xi32>
    %3 = index.casts %2 : i32 to index
    %4 = arith.cmpi eq, %3, %c-1 : index
    %5 = arith.select %4, %c10, %3 : index
    %6 = memref.load %alloca[%c1] : memref<3xi32>
    %7 = index.casts %6 : i32 to index
    %8 = arith.cmpi eq, %7, %c-1 : index
    %9 = arith.select %8, %c4, %7 : index
    %10 = memref.load %alloca[%c2] : memref<3xi32>
    %11 = index.casts %10 : i32 to index
    %12 = arith.cmpi eq, %11, %c-1 : index
    %13 = arith.select %12, %c1, %11 : index
    %subview_1 = memref.subview %arg0[0, 0, 0] [%5, %9, %13] [1, 1, 1] : memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
    return %subview_1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
  }
}
```

P.S. This is a similar issue to the one fixed for `tensor.extract_slice`
in llvm#164878

---------

Co-authored-by: Hanumanth Hanumantharayappa <hhanuman@ah-hhanuman-l.dhcp.mathworks.com>