Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"

using namespace mlir;
Expand Down Expand Up @@ -273,7 +274,9 @@ struct SubViewOpInterface
Value one = arith::ConstantIndexOp::create(builder, loc, 1);
auto metadataOp =
ExtractStridedMetadataOp::create(builder, loc, subView.getSource());
for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
// Reset insertion point to before the operation for each dimension
builder.setInsertionPoint(subView);
Value offset = getValueOrCreateConstantIndexOp(
builder, loc, subView.getMixedOffsets()[i]);
Value size = getValueOrCreateConstantIndexOp(builder, loc,
Expand All @@ -290,6 +293,16 @@ struct SubViewOpInterface
std::to_string(i) +
" is out-of-bounds"));

// Only verify if size > 0
Value sizeIsNonZero = arith::CmpIOp::create(
builder, loc, arith::CmpIPredicate::sgt, size, zero);

auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(),
sizeIsNonZero, /*withElseRegion=*/true);

// Populate the "then" region (for size > 0).
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());

// Verify that slice does not run out-of-bounds.
Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
Value sizeMinusOneTimesStride =
Expand All @@ -298,8 +311,20 @@ struct SubViewOpInterface
arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
Value lastPosInBounds =
generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);

scf::YieldOp::create(builder, loc, lastPosInBounds);

// Populate the "else" region (for size == 0).
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
Value trueVal =
arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true));
scf::YieldOp::create(builder, loc, trueVal);

builder.setInsertionPointAfter(ifOp);
Value finalCondition = ifOp.getResult(0);

cf::AssertOp::create(
builder, loc, lastPosInBounds,
builder, loc, finalCondition,
generateErrorMessage(op,
"subview runs out-of-bounds along dimension " +
std::to_string(i)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// RUN: -expand-strided-metadata \
// RUN: -lower-affine \
// RUN: -test-cf-assert \
// RUN: -convert-scf-to-cf \
// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
Expand All @@ -11,6 +12,7 @@
// RUN: -expand-strided-metadata \
// RUN: -lower-affine \
// RUN: -test-cf-assert \
// RUN: -convert-scf-to-cf \
// RUN: -convert-to-llvm="allow-pattern-rollback=0" \
// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-runner -e main -entry-point-result=void \
Expand Down Expand Up @@ -38,6 +40,17 @@ func.func @subview_dynamic_rank_reduce(%memref: memref<?x4xf32>, %offset: index,
return
}

func.func @subview_zero_size_dim(%memref: memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>,
%dim_0: index,
%dim_1: index,
%dim_2: index) {
%subview = memref.subview %memref[0, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] :
memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> to
memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
return
}


func.func @main() {
%0 = arith.constant 0 : index
%1 = arith.constant 1 : index
Expand Down Expand Up @@ -105,6 +118,14 @@ func.func @main() {
// CHECK-NOT: ERROR: Runtime op verification failed
func.call @subview_dynamic_rank_reduce(%alloca_4_dyn, %0, %1, %0) : (memref<?x4xf32>, index, index, index) -> ()

%alloca_10x4x1 = memref.alloca() : memref<10x4x1xf32>
%alloca_10x4x1_dyn_stride = memref.cast %alloca_10x4x1 : memref<10x4x1xf32> to memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>
// CHECK-NOT: ERROR: Runtime op verification failed
%dim_0 = arith.constant 0 : index
%dim_1 = arith.constant 4 : index
%dim_2 = arith.constant 1 : index
func.call @subview_zero_size_dim(%alloca_10x4x1_dyn_stride, %dim_0, %dim_1, %dim_2)
: (memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, index, index, index) -> ()

return
}