Skip to content

Commit

Permalink
[mlir] Fix ControlFlowInterfaces implementation for Async dialect
Browse files Browse the repository at this point in the history
* Add `RegionBranchTerminatorOpInterface` to `YieldOp`.
* Implement `getSuccessorEntryOperands` in `ExecuteOp`.
* Fix `getSuccessorRegions` implementation in `ExecuteOp`.

Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D108373
  • Loading branch information
Vladislav Vinogradov committed Aug 20, 2021
1 parent 119146f commit 9775c0c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
7 changes: 5 additions & 2 deletions mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ class Async_Op<string mnemonic, list<OpTrait> traits = []> :
def Async_ExecuteOp :
Async_Op<"execute", [SingleBlockImplicitTerminator<"YieldOp">,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getNumRegionInvocations"]>,
["getSuccessorEntryOperands",
"getNumRegionInvocations"]>,
AttrSizedOperandSegments]> {
let summary = "Asynchronous execute operation";
let description = [{
Expand Down Expand Up @@ -99,7 +100,9 @@ def Async_ExecuteOp :
}

def Async_YieldOp :
Async_Op<"yield", [HasParent<"ExecuteOp">, NoSideEffect, Terminator]> {
Async_Op<"yield", [
HasParent<"ExecuteOp">, NoSideEffect, Terminator,
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>]> {
let summary = "terminator for Async execute operation";
let description = [{
The `async.yield` is a special terminator operation for the block inside
Expand Down
22 changes: 16 additions & 6 deletions mlir/lib/Dialect/Async/IR/Async.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,31 +48,41 @@ static LogicalResult verify(YieldOp op) {
return success();
}

MutableOperandRange
YieldOp::getMutableSuccessorOperands(Optional<unsigned> index) {
assert(!index.hasValue());
return operandsMutable();
}

//===----------------------------------------------------------------------===//
/// ExecuteOp
//===----------------------------------------------------------------------===//

constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";

void ExecuteOp::getNumRegionInvocations(
ArrayRef<Attribute> operands, SmallVectorImpl<int64_t> &countPerRegion) {
(void)operands;
ArrayRef<Attribute>, SmallVectorImpl<int64_t> &countPerRegion) {
assert(countPerRegion.empty());
countPerRegion.push_back(1);
}

OperandRange ExecuteOp::getSuccessorEntryOperands(unsigned index) {
assert(index == 0 && "invalid region index");
return operands();
}

void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
ArrayRef<Attribute> operands,
ArrayRef<Attribute>,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `body` region branch back to the parent operation.
if (index.hasValue()) {
assert(*index == 0);
regions.push_back(RegionSuccessor(getResults()));
assert(*index == 0 && "invalid region index");
regions.push_back(RegionSuccessor(results()));
return;
}

// Otherwise the successor is the body region.
regions.push_back(RegionSuccessor(&body()));
regions.push_back(RegionSuccessor(&body(), body().getArguments()));
}

void ExecuteOp::build(OpBuilder &builder, OperationState &result,
Expand Down

0 comments on commit 9775c0c

Please sign in to comment.