Skip to content

Commit

Permalink
Optionally eliminate blocking runtime.await calls by converting funct…
Browse files Browse the repository at this point in the history
…ions to coroutines.

Interop parallelism requires needs awaiting on results. Blocking awaits are bad for performance. TFRT supports lightweight resumption on threads, and coroutines are an abstraction than can be used to lower the kernels onto TFRT threads.

Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D106508
  • Loading branch information
bakhtiyar authored and ezhulenev committed Jul 28, 2021
1 parent 0f4b41e commit 6ea22d4
Show file tree
Hide file tree
Showing 3 changed files with 467 additions and 17 deletions.
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/Async/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> {
let summary = "Lower high level async operations (e.g. async.execute) to the"
"explicit async.runtime and async.coro operations";
let constructor = "mlir::createAsyncToAsyncRuntimePass()";
let options = [
// Temporary for bringup, should become the default.
Option<"eliminateBlockingAwaitOps", "eliminate-blocking-await-ops", "bool",
/*default=*/"false",
"Rewrite functions with blocking async.runtime.await as coroutines "
"with async.runtime.await_and_resume.">
];
let dependentDialects = ["async::AsyncDialect"];
}

Expand Down
173 changes: 156 additions & 17 deletions mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,18 @@ struct CoroMachinery {
};
} // namespace

/// Builds an coroutine template compatible with LLVM coroutines switched-resume
/// lowering using `async.runtime.*` and `async.coro.*` operations.
/// Utility to partially update the regular function CFG to the coroutine CFG
/// compatible with LLVM coroutines switched-resume lowering using
/// `async.runtime.*` and `async.coro.*` operations. Modifies the entry block
/// by prepending its ops with coroutine setup. Also inserts trailing blocks.
///
/// The result types of the passed `func` must start with an `async.token`
/// and be continued with some number of `async.value`s.
///
/// It's up to the caller of this function to fix up the terminators of the
/// preexisting blocks of the passed func op. If the passed `func` is legal,
/// this typically means rewriting every return op as a yield op and a branch op
/// to the suspend block.
///
/// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
///
Expand All @@ -87,15 +97,16 @@ struct CoroMachinery {
///
/// Coroutine structure (only the important bits):
///
/// func @async_execute_fn(<function-arguments>)
/// -> (!async.token, !async.value<T>)
/// func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>)
/// {
/// ^entry(<function-arguments>):
/// %token = <async token> : !async.token // create async runtime token
/// %value = <async value> : !async.value<T> // create async value
/// %id = async.coro.id // create a coroutine id
/// %hdl = async.coro.begin %id // create a coroutine handle
/// br ^cleanup
/// /* other ops of the preexisting entry block */
///
/// /* other preexisting blocks */
///
/// ^set_error: // this block created lazily only if needed (see code below)
/// async.runtime.set_error %token : !async.token
Expand All @@ -111,16 +122,11 @@ struct CoroMachinery {
/// return %token, %value : !async.token, !async.value<T>
/// }
///
/// The actual code for the async.execute operation body region will be inserted
/// before the entry block terminator.
///
///
static CoroMachinery setupCoroMachinery(FuncOp func) {
assert(func.getBody().empty() && "Function must have empty body");
assert(!func.getBlocks().empty() && "Function must have an entry block");

MLIRContext *ctx = func.getContext();
Block *entryBlock = func.addEntryBlock();

Block *entryBlock = &func.getBlocks().front();
auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);

// ------------------------------------------------------------------------ //
Expand Down Expand Up @@ -166,10 +172,6 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
ret.insert(ret.end(), retValues.begin(), retValues.end());
builder.create<ReturnOp>(ret);

// Branch from the entry block to the cleanup block to create a valid CFG.
builder.setInsertionPointToEnd(entryBlock);
builder.create<BranchOp>(cleanupBlock);

// `async.await` op lowering will create resume blocks for async
// continuations, and will conditionally branch to cleanup or suspend blocks.

Expand Down Expand Up @@ -242,20 +244,22 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {

// Prepare a function for coroutine lowering by adding entry/cleanup/suspend
// blocks, adding async.coro operations and setting up control flow.
func.addEntryBlock();
CoroMachinery coro = setupCoroMachinery(func);

// Suspend async function at the end of an entry block, and resume it using
// Async resume operation (execution will be resumed in a thread managed by
// the async runtime).
Block *entryBlock = &func.getBlocks().front();
auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock);
auto builder = ImplicitLocOpBuilder::atBlockEnd(loc, entryBlock);

// Save the coroutine state: async.coro.save
auto coroSaveOp =
builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);

// Pass coroutine to the runtime to be resumed on a runtime managed thread.
builder.create<RuntimeResumeOp>(coro.coroHandle);
builder.create<BranchOp>(coro.cleanup);

// Split the entry block before the terminator (branch to suspend block).
auto *terminatorOp = entryBlock->getTerminator();
Expand Down Expand Up @@ -557,6 +561,132 @@ class AssertOpLowering : public OpConversionPattern<AssertOp> {

//===----------------------------------------------------------------------===//

/// Rewrite a func as a coroutine by:
/// 1) Wrapping the results into `async.value`.
/// 2) Prepending the results with `async.token`.
/// 3) Setting up coroutine blocks.
/// 4) Rewriting return ops as yield op and branch op into the suspend block.
static CoroMachinery rewriteFuncAsCoroutine(FuncOp func) {
auto *ctx = func->getContext();
auto loc = func.getLoc();
SmallVector<Type> resultTypes;
resultTypes.reserve(func.getCallableResults().size());
llvm::transform(func.getCallableResults(), std::back_inserter(resultTypes),
[](Type type) { return ValueType::get(type); });
func.setType(FunctionType::get(ctx, func.getType().getInputs(), resultTypes));
func.insertResult(0, TokenType::get(ctx), {});
CoroMachinery coro = setupCoroMachinery(func);
for (Block &block : func.getBlocks()) {
if (&block == coro.suspend)
continue;

Operation *terminator = block.getTerminator();
if (auto returnOp = dyn_cast<ReturnOp>(*terminator)) {
ImplicitLocOpBuilder builder(loc, returnOp);
builder.create<YieldOp>(returnOp.getOperands());
builder.create<BranchOp>(coro.cleanup);
returnOp.erase();
}
}
return coro;
}

/// Rewrites a call into a function that has been rewritten as a coroutine.
///
/// The invocation of this function is safe only when call ops are traversed in
/// reverse order of how they appear in a single block. See `funcsToCoroutines`.
static void rewriteCallsiteForCoroutine(CallOp oldCall, FuncOp func) {
auto loc = func.getLoc();
ImplicitLocOpBuilder callBuilder(loc, oldCall);
auto newCall = callBuilder.create<CallOp>(
func.getName(), func.getCallableResults(), oldCall.getArgOperands());

// Await on the async token and all the value results and unwrap the latter.
callBuilder.create<AwaitOp>(loc, newCall.getResults().front());
SmallVector<Value> unwrappedResults;
unwrappedResults.reserve(newCall->getResults().size() - 1);
for (Value result : newCall.getResults().drop_front())
unwrappedResults.push_back(
callBuilder.create<AwaitOp>(loc, result).result());
// Careful, when result of a call is piped into another call this could lead
// to a dangling pointer.
oldCall.replaceAllUsesWith(unwrappedResults);
oldCall.erase();
}

static LogicalResult
funcsToCoroutines(ModuleOp module,
llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) {
// The following code supports the general case when 2 functions mutually
// recurse into each other. Because of this and that we are relying on
// SymbolUserMap to find pointers to calling FuncOps, we cannot simply erase
// a FuncOp while inserting an equivalent coroutine, because that could lead
// to dangling pointers.

SmallVector<FuncOp> funcWorklist;

// Careful, it's okay to add a func to the worklist multiple times if and only
// if the loop processing the worklist will skip the functions that have
// already been converted to coroutines.
auto addToWorklist = [&outlinedFunctions, &funcWorklist](FuncOp func) {
// N.B. To refactor this code into a separate pass the lookup in
// outlinedFunctions is the most obvious obstacle. Looking at an arbitrary
// func and recognizing if it has a coroutine structure is messy. Passing
// this dict between the passes is ugly.
if (outlinedFunctions.find(func) == outlinedFunctions.end()) {
for (Operation &op : func.body().getOps()) {
if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) {
funcWorklist.push_back(func);
break;
}
}
}
};

// Traverse in post-order collecting for each func op the await ops it has.
for (FuncOp func : module.getOps<FuncOp>())
addToWorklist(func);

SymbolTableCollection symbolTable;
SymbolUserMap symbolUserMap(symbolTable, module);

// Rewrite funcs, while updating call sites and adding them to the worklist.
while (!funcWorklist.empty()) {
auto func = funcWorklist.pop_back_val();
auto insertion = outlinedFunctions.insert({func, CoroMachinery{}});
if (!insertion.second)
// This function has already been processed because this is either
// the corecursive case, or a caller with multiple calls to a newly
// created corouting. Either way, skip updating the call sites.
continue;
insertion.first->second = rewriteFuncAsCoroutine(func);
SmallVector<Operation *> users(symbolUserMap.getUsers(func).begin(),
symbolUserMap.getUsers(func).end());
// If there are multiple calls from the same block they need to be traversed
// in reverse order so that symbolUserMap references are not invalidated
// when updating the users of the call op which is earlier in the block.
llvm::sort(users, [](Operation *a, Operation *b) {
Block *blockA = a->getBlock();
Block *blockB = b->getBlock();
// Impose arbitrary order on blocks so that there is a well-defined order.
return blockA > blockB || (blockA == blockB && !a->isBeforeInBlock(b));
});
// Rewrite the callsites to await on results of the newly created coroutine.
for (Operation *op : users) {
if (CallOp call = dyn_cast<mlir::CallOp>(*op)) {
FuncOp caller = call->getParentOfType<FuncOp>();
rewriteCallsiteForCoroutine(call, func); // Careful, erases the call op.
addToWorklist(caller);
} else {
op->emitError("Unexpected reference to func referenced by symbol");
return failure();
}
}
}
return success();
}

//===----------------------------------------------------------------------===//
void AsyncToAsyncRuntimePass::runOnOperation() {
ModuleOp module = getOperation();
SymbolTable symbolTable(module);
Expand All @@ -579,6 +709,12 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
return outlinedFunctions.find(parentFunc) != outlinedFunctions.end();
};

if (eliminateBlockingAwaitOps &&
failed(funcsToCoroutines(module, outlinedFunctions))) {
signalPassFailure();
return;
}

// Lower async operations to async.runtime operations.
MLIRContext *ctx = module->getContext();
RewritePatternSet asyncPatterns(ctx);
Expand Down Expand Up @@ -622,6 +758,9 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
return outlinedFunctions.find(func) == outlinedFunctions.end();
});

if (eliminateBlockingAwaitOps)
runtimeTarget.addIllegalOp<RuntimeAwaitOp>();

if (failed(applyPartialConversion(module, runtimeTarget,
std::move(asyncPatterns)))) {
signalPassFailure();
Expand Down
Loading

0 comments on commit 6ea22d4

Please sign in to comment.