diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index 291da1f76ca9b..14152c5a1af0c 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" using namespace mlir; @@ -273,7 +274,9 @@ struct SubViewOpInterface Value one = arith::ConstantIndexOp::create(builder, loc, 1); auto metadataOp = ExtractStridedMetadataOp::create(builder, loc, subView.getSource()); - for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) { + for (int64_t i : llvm::seq(0, sourceType.getRank())) { + // Reset insertion point to before the operation for each dimension + builder.setInsertionPoint(subView); Value offset = getValueOrCreateConstantIndexOp( builder, loc, subView.getMixedOffsets()[i]); Value size = getValueOrCreateConstantIndexOp(builder, loc, @@ -290,6 +293,16 @@ struct SubViewOpInterface std::to_string(i) + " is out-of-bounds")); + // Only verify if size > 0 + Value sizeIsNonZero = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::sgt, size, zero); + + auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(), + sizeIsNonZero, /*withElseRegion=*/true); + + // Populate the "then" region (for size > 0). + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + // Verify that slice does not run out-of-bounds. Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one); Value sizeMinusOneTimesStride = @@ -298,8 +311,20 @@ struct SubViewOpInterface arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride); Value lastPosInBounds = generateInBoundsCheck(builder, loc, lastPos, zero, dimSize); + + scf::YieldOp::create(builder, loc, lastPosInBounds); + + // Populate the "else" region (for size == 0). + builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + Value trueVal = + arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true)); + scf::YieldOp::create(builder, loc, trueVal); + + builder.setInsertionPointAfter(ifOp); + Value finalCondition = ifOp.getResult(0); + cf::AssertOp::create( - builder, loc, lastPosInBounds, + builder, loc, finalCondition, generateErrorMessage(op, "subview runs out-of-bounds along dimension " + std::to_string(i))); diff --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir index 71e813c0a6300..84875675ac3d0 100644 --- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir @@ -2,6 +2,7 @@ // RUN: -expand-strided-metadata \ // RUN: -lower-affine \ // RUN: -test-cf-assert \ +// RUN: -convert-scf-to-cf \ // RUN: -convert-to-llvm | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ @@ -11,6 +12,7 @@ // RUN: -expand-strided-metadata \ // RUN: -lower-affine \ // RUN: -test-cf-assert \ +// RUN: -convert-scf-to-cf \ // RUN: -convert-to-llvm="allow-pattern-rollback=0" \ // RUN: -reconcile-unrealized-casts | \ // RUN: mlir-runner -e main -entry-point-result=void \ @@ -38,6 +40,17 @@ func.func @subview_dynamic_rank_reduce(%memref: memref, %offset: index, return } +func.func @subview_zero_size_dim(%memref: memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, + %dim_0: index, + %dim_1: index, + %dim_2: index) { + %subview = memref.subview %memref[0, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] : + memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> to + memref> + return +} + + func.func @main() { %0 = arith.constant 0 : index %1 = arith.constant 1 : index @@ -105,6 +118,14 @@ func.func @main() { // CHECK-NOT: ERROR: Runtime op verification failed func.call @subview_dynamic_rank_reduce(%alloca_4_dyn, %0, %1, %0) : (memref, index, index, index) -> () + %alloca_10x4x1 = memref.alloca() : memref<10x4x1xf32> + %alloca_10x4x1_dyn_stride = memref.cast %alloca_10x4x1 : memref<10x4x1xf32> to memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> + // CHECK-NOT: ERROR: Runtime op verification failed + %dim_0 = arith.constant 0 : index + %dim_1 = arith.constant 4 : index + %dim_2 = arith.constant 1 : index + func.call @subview_zero_size_dim(%alloca_10x4x1_dyn_stride, %dim_0, %dim_1, %dim_2) + : (memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, index, index, index) -> () return }