Skip to content

Commit

Permalink
[mlir] Harden verifiers for DMA ops
Browse files Browse the repository at this point in the history
DMA operation classes in the Standard dialect (`DmaStartOp` and `DmaWaitOp`)
provide helper functions that make numerous assumptions about the number and
order of operands, and about their types. However, these assumptions were not
checked in the verifier, leading to assertion failures or crashes when helper
functions were used on ill-formed ops. Some of the assuptions were checked in
the custom parser (and thus could not check assumption violations in ops
constructed programmatically, e.g., during rewrites) and others were not
checked at all. Introduce the verifiers for all these assumptions and drop
unnecessary checks in the parser that are now covered by the verifier.

Addresses PR45560.

Differential Revision: https://reviews.llvm.org/D79408
  • Loading branch information
ftynse committed May 5, 2020
1 parent 0195b3a commit 9d273c0
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 37 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
Expand Up @@ -286,6 +286,7 @@ class DmaWaitOp
void print(OpAsmPrinter &p);
LogicalResult fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results);
LogicalResult verify();
};

/// Prints dimension and symbol list.
Expand Down
122 changes: 86 additions & 36 deletions mlir/lib/Dialect/StandardOps/IR/Ops.cpp
Expand Up @@ -1444,49 +1444,82 @@ ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
parser.resolveOperands(tagIndexInfos, indexType, result.operands))
return failure();

auto memrefType0 = types[0].dyn_cast<MemRefType>();
if (!memrefType0)
return parser.emitError(parser.getNameLoc(),
"expected source to be of memref type");

auto memrefType1 = types[1].dyn_cast<MemRefType>();
if (!memrefType1)
return parser.emitError(parser.getNameLoc(),
"expected destination to be of memref type");

auto memrefType2 = types[2].dyn_cast<MemRefType>();
if (!memrefType2)
return parser.emitError(parser.getNameLoc(),
"expected tag to be of memref type");

if (isStrided) {
if (parser.resolveOperands(strideInfo, indexType, result.operands))
return failure();
}

// Check that source/destination index list size matches associated rank.
if (static_cast<int64_t>(srcIndexInfos.size()) != memrefType0.getRank() ||
static_cast<int64_t>(dstIndexInfos.size()) != memrefType1.getRank())
return parser.emitError(parser.getNameLoc(),
"memref rank not equal to indices count");
if (static_cast<int64_t>(tagIndexInfos.size()) != memrefType2.getRank())
return parser.emitError(parser.getNameLoc(),
"tag memref rank not equal to indices count");

return success();
}

LogicalResult DmaStartOp::verify() {
unsigned numOperands = getNumOperands();

// Mandatory non-variadic operands are: src memref, dst memref, tag memref and
// the number of elements.
if (numOperands < 4)
return emitOpError("expected at least 4 operands");

// Check types of operands. The order of these calls is important: the later
// calls rely on some type properties to compute the operand position.
// 1. Source memref.
if (!getSrcMemRef().getType().isa<MemRefType>())
return emitOpError("expected source to be of memref type");
if (numOperands < getSrcMemRefRank() + 4)
return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
<< " operands";
if (!getSrcIndices().empty() &&
!llvm::all_of(getSrcIndices().getTypes(),
[](Type t) { return t.isIndex(); }))
return emitOpError("expected source indices to be of index type");

// 2. Destination memref.
if (!getDstMemRef().getType().isa<MemRefType>())
return emitOpError("expected destination to be of memref type");
unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
if (numOperands < numExpectedOperands)
return emitOpError() << "expected at least " << numExpectedOperands
<< " operands";
if (!getDstIndices().empty() &&
!llvm::all_of(getDstIndices().getTypes(),
[](Type t) { return t.isIndex(); }))
return emitOpError("expected destination indices to be of index type");

// 3. Number of elements.
if (!getNumElements().getType().isIndex())
return emitOpError("expected num elements to be of index type");

// 4. Tag memref.
if (!getTagMemRef().getType().isa<MemRefType>())
return emitOpError("expected tag to be of memref type");
numExpectedOperands += getTagMemRefRank();
if (numOperands < numExpectedOperands)
return emitOpError() << "expected at least " << numExpectedOperands
<< " operands";
if (!getTagIndices().empty() &&
!llvm::all_of(getTagIndices().getTypes(),
[](Type t) { return t.isIndex(); }))
return emitOpError("expected tag indices to be of index type");

// DMAs from different memory spaces supported.
if (getSrcMemorySpace() == getDstMemorySpace())
return emitOpError("DMA should be between different memory spaces");

if (getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
getDstMemRefRank() + 3 + 1 &&
getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
getDstMemRefRank() + 3 + 1 + 2) {
// Optional stride-related operands must be either both present or both
// absent.
if (numOperands != numExpectedOperands &&
numOperands != numExpectedOperands + 2)
return emitOpError("incorrect number of operands");

// 5. Strides.
if (isStrided()) {
if (!getStride().getType().isIndex() ||
!getNumElementsPerStride().getType().isIndex())
return emitOpError(
"expected stride and num elements per stride to be of type index");
}

return success();
}

Expand Down Expand Up @@ -1536,15 +1569,6 @@ ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
parser.resolveOperand(numElementsInfo, indexType, result.operands))
return failure();

auto memrefType = type.dyn_cast<MemRefType>();
if (!memrefType)
return parser.emitError(parser.getNameLoc(),
"expected tag to be of memref type");

if (static_cast<int64_t>(tagIndexInfos.size()) != memrefType.getRank())
return parser.emitError(parser.getNameLoc(),
"tag memref rank not equal to indices count");

return success();
}

Expand All @@ -1554,6 +1578,32 @@ LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
return foldMemRefCast(*this);
}

LogicalResult DmaWaitOp::verify() {
// Mandatory non-variadic operands are tag and the number of elements.
if (getNumOperands() < 2)
return emitOpError() << "expected at least 2 operands";

// Check types of operands. The order of these calls is important: the later
// calls rely on some type properties to compute the operand position.
if (!getTagMemRef().getType().isa<MemRefType>())
return emitOpError() << "expected tag to be of memref type";

if (getNumOperands() != 2 + getTagMemRefRank())
return emitOpError() << "expected " << 2 + getTagMemRefRank()
<< " operands";

if (!getTagIndices().empty() &&
!llvm::all_of(getTagIndices().getTypes(),
[](Type t) { return t.isIndex(); }))
return emitOpError() << "expected tag indices to be of index type";

if (!getNumElements().getType().isIndex())
return emitOpError()
<< "expected the number of elements to be of index type";

return success();
}

//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//
Expand Down
128 changes: 127 additions & 1 deletion mlir/test/IR/invalid-ops.mlir
Expand Up @@ -303,13 +303,38 @@ func @invalid_cmp_shape(%idx : () -> ()) {

// -----

func @dma_start_not_enough_operands() {
// expected-error@+1 {{expected at least 4 operands}}
"std.dma_start"() : () -> ()
}

// -----

func @dma_no_src_memref(%m : f32, %tag : f32, %c0 : index) {
// expected-error@+1 {{expected source to be of memref type}}
dma_start %m[%c0], %m[%c0], %c0, %tag[%c0] : f32, f32, f32
}

// -----

func @dma_start_not_enough_operands_for_src(
%src: memref<2x2x2xf32>, %idx: index) {
// expected-error@+1 {{expected at least 7 operands}}
"std.dma_start"(%src, %idx, %idx, %idx) : (memref<2x2x2xf32>, index, index, index) -> ()
}

// -----

func @dma_start_src_index_wrong_type(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>, %flt: f32) {
// expected-error@+1 {{expected source indices to be of index type}}
"std.dma_start"(%src, %idx, %flt, %dst, %idx, %tag, %idx)
: (memref<2x2xf32>, index, f32, memref<2xf32,1>, index, memref<i32,2>, index) -> ()
}

// -----

func @dma_no_dst_memref(%m : f32, %tag : f32, %c0 : index) {
%mref = alloc() : memref<8 x f32>
// expected-error@+1 {{expected destination to be of memref type}}
Expand All @@ -318,6 +343,36 @@ func @dma_no_dst_memref(%m : f32, %tag : f32, %c0 : index) {

// -----

func @dma_start_not_enough_operands_for_dst(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>) {
// expected-error@+1 {{expected at least 7 operands}}
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index) -> ()
}

// -----

func @dma_start_dst_index_wrong_type(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>, %flt: f32) {
// expected-error@+1 {{expected destination indices to be of index type}}
"std.dma_start"(%src, %idx, %idx, %dst, %flt, %tag, %idx)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, f32, memref<i32,2>, index) -> ()
}

// -----

func @dma_start_dst_index_wrong_type(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>, %flt: f32) {
// expected-error@+1 {{expected num elements to be of index type}}
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %flt, %tag)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, f32, memref<i32,2>) -> ()
}

// -----

func @dma_no_tag_memref(%tag : f32, %c0 : index) {
%mref = alloc() : memref<8 x f32>
// expected-error@+1 {{expected tag to be of memref type}}
Expand All @@ -326,9 +381,80 @@ func @dma_no_tag_memref(%tag : f32, %c0 : index) {

// -----

func @dma_start_not_enough_operands_for_tag(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<2xi32,2>) {
// expected-error@+1 {{expected at least 8 operands}}
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>) -> ()
}

// -----

func @dma_start_dst_index_wrong_type(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<2xi32,2>, %flt: f32) {
// expected-error@+1 {{expected tag indices to be of index type}}
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %flt)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>, f32) -> ()
}

// -----

func @dma_start_same_space(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32>,
%tag: memref<i32,2>) {
// expected-error@+1 {{DMA should be between different memory spaces}}
dma_start %src[%idx, %idx], %dst[%idx], %idx, %tag[] : memref<2x2xf32>, memref<2xf32>, memref<i32,2>
}

// -----

func @dma_start_too_many_operands(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>) {
// expected-error@+1 {{incorrect number of operands}}
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %idx, %idx)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, index, index) -> ()
}


// -----

func @dma_start_wrong_stride_type(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>, %flt: f32) {
// expected-error@+1 {{expected stride and num elements per stride to be of type index}}
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %flt)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, f32) -> ()
}

// -----

func @dma_wait_not_enough_operands() {
// expected-error@+1 {{expected at least 2 operands}}
"std.dma_wait"() : () -> ()
}

// -----

func @dma_wait_no_tag_memref(%tag : f32, %c0 : index) {
// expected-error@+1 {{expected tag to be of memref type}}
dma_wait %tag[%c0], %arg0 : f32
"std.dma_wait"(%tag, %c0, %c0) : (f32, index, index) -> ()
}

// -----

func @dma_wait_wrong_index_type(%tag : memref<2xi32>, %idx: index, %flt: f32) {
// expected-error@+1 {{expected tag indices to be of index type}}
"std.dma_wait"(%tag, %flt, %idx) : (memref<2xi32>, f32, index) -> ()
}

// -----

func @dma_wait_wrong_num_elements_type(%tag : memref<2xi32>, %idx: index, %flt: f32) {
// expected-error@+1 {{expected the number of elements to be of index type}}
"std.dma_wait"(%tag, %idx, %flt) : (memref<2xi32>, index, f32) -> ()
}

// -----
Expand Down

0 comments on commit 9d273c0

Please sign in to comment.