Skip to content

Commit

Permalink
[mlir][scf][bufferize] Fix bug in WhileOp analysis verification
Browse files Browse the repository at this point in the history
Block arguments and yielded values are not equivalent if there are not enough block arguments. This fixes #59442.

Differential Revision: https://reviews.llvm.org/D145575
  • Loading branch information
matthias-springer committed May 15, 2023
1 parent bb9d1b5 commit ae8cb64
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
12 changes: 8 additions & 4 deletions mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
Expand Up @@ -902,21 +902,25 @@ struct WhileOpInterface

auto conditionOp = whileOp.getConditionOp();
for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
Block *block = conditionOp->getBlock();
if (!isa<TensorType>(it.value().getType()))
continue;
if (!state.areEquivalentBufferizedValues(
it.value(), conditionOp->getBlock()->getArgument(it.index())))
if (it.index() >= block->getNumArguments() ||
!state.areEquivalentBufferizedValues(it.value(),
block->getArgument(it.index())))
return conditionOp->emitError()
<< "Condition arg #" << it.index()
<< " is not equivalent to the corresponding iter bbArg";
}

auto yieldOp = whileOp.getYieldOp();
for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
Block *block = yieldOp->getBlock();
if (!isa<TensorType>(it.value().getType()))
continue;
if (!state.areEquivalentBufferizedValues(
it.value(), yieldOp->getBlock()->getArgument(it.index())))
if (it.index() >= block->getNumArguments() ||
!state.areEquivalentBufferizedValues(it.value(),
block->getArgument(it.index())))
return yieldOp->emitError()
<< "Yield operand #" << it.index()
<< " is not equivalent to the corresponding iter bbArg";
Expand Down
Expand Up @@ -324,3 +324,17 @@ func.func @copy_of_unranked_tensor(%t: tensor<*xf32>) -> tensor<*xf32> {

// This function may write to buffer(%ptr).
func.func private @maybe_writing_func(%ptr : tensor<*xf32>)

// -----

func.func @regression_scf_while() {
%false = arith.constant false
%8 = bufferization.alloc_tensor() : tensor<10x10xf32>
scf.while (%arg0 = %8) : (tensor<10x10xf32>) -> () {
scf.condition(%false)
} do {
// expected-error @+1 {{Yield operand #0 is not equivalent to the corresponding iter bbArg}}
scf.yield %8 : tensor<10x10xf32>
}
return
}

0 comments on commit ae8cb64

Please sign in to comment.