diff --git a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp index 753cb95b1c906..d35f458cbdb36 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp @@ -155,13 +155,15 @@ struct ExtractSliceOpInterface RankedTensorType sourceType = extractSliceOp.getSource().getType(); // For each dimension, assert that: - // 0 <= offset < dim_size - // 0 <= offset + (size - 1) * stride < dim_size + // For empty slices (size == 0) : 0 <= offset <= dim_size + // For non-empty slices (size > 0): 0 <= offset < dim_size + // 0 <= offset + (size - 1) * stride < + // dim_size Value zero = arith::ConstantIndexOp::create(builder, loc, 0); Value one = arith::ConstantIndexOp::create(builder, loc, 1); for (int64_t i : llvm::seq(0, sourceType.getRank())) { - // Reset insertion point to before the operation for each dimension + builder.setInsertionPoint(extractSliceOp); Value offset = getValueOrCreateConstantIndexOp( @@ -170,46 +172,63 @@ struct ExtractSliceOpInterface builder, loc, extractSliceOp.getMixedSizes()[i]); Value stride = getValueOrCreateConstantIndexOp( builder, loc, extractSliceOp.getMixedStrides()[i]); - - // Verify that offset is in-bounds. Value dimSize = builder.createOrFold( loc, extractSliceOp.getSource(), i); - Value offsetInBounds = - generateInBoundsCheck(builder, loc, offset, zero, dimSize); - cf::AssertOp::create(builder, loc, offsetInBounds, + + // Verify that offset is in-bounds (conditional on slice size). + Value sizeIsZero = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::eq, size, zero); + auto offsetCheckIf = scf::IfOp::create( + builder, loc, sizeIsZero, + [&](OpBuilder &b, Location loc) { + // For empty slices, offset can be at the boundary: 0 <= offset <= + // dimSize. + Value offsetGEZero = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::sge, offset, zero); + Value offsetLEDimSize = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::sle, offset, dimSize); + Value emptyOffsetValid = + arith::AndIOp::create(b, loc, offsetGEZero, offsetLEDimSize); + scf::YieldOp::create(b, loc, emptyOffsetValid); + }, + [&](OpBuilder &b, Location loc) { + // For non-empty slices, offset must be a valid index: 0 <= offset < + // dimSize. + Value offsetInBounds = + generateInBoundsCheck(b, loc, offset, zero, dimSize); + scf::YieldOp::create(b, loc, offsetInBounds); + }); + + Value offsetCondition = offsetCheckIf.getResult(0); + cf::AssertOp::create(builder, loc, offsetCondition, generateErrorMessage(op, "offset " + std::to_string(i) + " is out-of-bounds")); - // Only verify if size > 0 + // Verify that the slice endpoint is in-bounds (only for non-empty + // slices). Value sizeIsNonZero = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::sgt, size, zero); + auto ifOp = scf::IfOp::create( + builder, loc, sizeIsNonZero, + [&](OpBuilder &b, Location loc) { + // Verify that slice does not run out-of-bounds. + Value sizeMinusOne = arith::SubIOp::create(b, loc, size, one); + Value sizeMinusOneTimesStride = + arith::MulIOp::create(b, loc, sizeMinusOne, stride); + Value lastPos = + arith::AddIOp::create(b, loc, offset, sizeMinusOneTimesStride); + Value lastPosInBounds = + generateInBoundsCheck(b, loc, lastPos, zero, dimSize); + scf::YieldOp::create(b, loc, lastPosInBounds); + }, + [&](OpBuilder &b, Location loc) { + Value trueVal = + arith::ConstantOp::create(b, loc, b.getBoolAttr(true)); + scf::YieldOp::create(b, loc, trueVal); + }); - auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(), - sizeIsNonZero, /*withElseRegion=*/true); - - // Populate the "then" region (for size > 0). - builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - - // Verify that slice does not run out-of-bounds. - Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one); - Value sizeMinusOneTimesStride = - arith::MulIOp::create(builder, loc, sizeMinusOne, stride); - Value lastPos = - arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride); - Value lastPosInBounds = - generateInBoundsCheck(builder, loc, lastPos, zero, dimSize); - scf::YieldOp::create(builder, loc, lastPosInBounds); - - // Populate the "else" region (for size == 0). - builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - Value trueVal = - arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true)); - scf::YieldOp::create(builder, loc, trueVal); - - builder.setInsertionPointAfter(ifOp); Value finalCondition = ifOp.getResult(0); - cf::AssertOp::create( builder, loc, finalCondition, generateErrorMessage( diff --git a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir index a77fa310a3699..745eea37f7fca 100644 --- a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir @@ -39,6 +39,11 @@ func.func @extract_slice_zero_size_dim(%arg0: tensor<10x4x1xf32>, %dim_0: index, return } +func.func @extract_slice_empty_tensor(%arg0: tensor<10x4x1xf32>, %dim_0: index, %dim_1: index, %dim_2: index, %offset: index) { + tensor.extract_slice %arg0[%offset, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] : tensor<10x4x1xf32> to tensor + return +} + func.func @main() { %0 = arith.constant 0 : index @@ -115,5 +120,9 @@ func.func @main() { %dim_2 = arith.constant 1 : index func.call @extract_slice_zero_size_dim(%cst10x4x1xf32, %dim_0, %dim_1, %dim_2) : (tensor<10x4x1xf32>, index, index, index) -> () + // CHECK-NOT: ERROR: Runtime op verification failed + %offset = arith.constant 10 : index + func.call @extract_slice_empty_tensor(%cst10x4x1xf32, %dim_0, %dim_1, %dim_2, %offset) : (tensor<10x4x1xf32>, index, index, index, index) -> () + return }