diff --git a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp index c031118606823..753cb95b1c906 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" @@ -158,7 +159,11 @@ struct ExtractSliceOpInterface // 0 <= offset + (size - 1) * stride < dim_size Value zero = arith::ConstantIndexOp::create(builder, loc, 0); Value one = arith::ConstantIndexOp::create(builder, loc, 1); - for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) { + + for (int64_t i : llvm::seq(0, sourceType.getRank())) { + // Reset insertion point to before the operation for each dimension + builder.setInsertionPoint(extractSliceOp); + Value offset = getValueOrCreateConstantIndexOp( builder, loc, extractSliceOp.getMixedOffsets()[i]); Value size = getValueOrCreateConstantIndexOp( @@ -176,6 +181,16 @@ struct ExtractSliceOpInterface std::to_string(i) + " is out-of-bounds")); + // Only verify if size > 0 + Value sizeIsNonZero = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::sgt, size, zero); + + auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(), + sizeIsNonZero, /*withElseRegion=*/true); + + // Populate the "then" region (for size > 0). + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + // Verify that slice does not run out-of-bounds. Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one); Value sizeMinusOneTimesStride = @@ -184,8 +199,19 @@ struct ExtractSliceOpInterface arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride); Value lastPosInBounds = generateInBoundsCheck(builder, loc, lastPos, zero, dimSize); + scf::YieldOp::create(builder, loc, lastPosInBounds); + + // Populate the "else" region (for size == 0). + builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + Value trueVal = + arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true)); + scf::YieldOp::create(builder, loc, trueVal); + + builder.setInsertionPointAfter(ifOp); + Value finalCondition = ifOp.getResult(0); + cf::AssertOp::create( - builder, loc, lastPosInBounds, + builder, loc, finalCondition, generateErrorMessage( op, "extract_slice runs out-of-bounds along dimension " + std::to_string(i))); diff --git a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir index 0c7c4a6cb2d6f..a77fa310a3699 100644 --- a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir @@ -34,6 +34,12 @@ func.func @extract_slice_dynamic_rank_reduce(%tensor: tensor, %offset: return } +func.func @extract_slice_zero_size_dim(%arg0: tensor<10x4x1xf32>, %dim_0: index, %dim_1: index, %dim_2: index) { + tensor.extract_slice %arg0[0, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] : tensor<10x4x1xf32> to tensor + return +} + + func.func @main() { %0 = arith.constant 0 : index %1 = arith.constant 1 : index @@ -101,6 +107,13 @@ func.func @main() { // CHECK-NOT: ERROR: Runtime op verification failed func.call @extract_slice_dynamic_rank_reduce(%alloca_4_dyn, %0, %1, %0) : (tensor, index, index, index) -> () + %cst10x4x1xf32 = arith.constant dense<1.0> : tensor<10x4x1xf32> + + // CHECK-NOT: ERROR: Runtime op verification failed + %dim_0 = arith.constant 0 : index + %dim_1 = arith.constant 4 : index + %dim_2 = arith.constant 1 : index + func.call @extract_slice_zero_size_dim(%cst10x4x1xf32, %dim_0, %dim_1, %dim_2) : (tensor<10x4x1xf32>, index, index, index) -> () return }