diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index e612d6a0ba886..4cd703942036f 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -802,8 +802,8 @@ def CIR_ConditionOp : CIR_Op<"condition", [ //===----------------------------------------------------------------------===// defvar CIR_YieldableScopes = [ - "ArrayCtor", "ArrayDtor", "CaseOp", "DoWhileOp", "ForOp", "GlobalOp", "IfOp", - "ScopeOp", "SwitchOp", "TernaryOp", "WhileOp", "TryOp" + "ArrayCtor", "ArrayDtor", "AwaitOp", "CaseOp", "DoWhileOp", "ForOp", + "GlobalOp", "IfOp", "ScopeOp", "SwitchOp", "TernaryOp", "WhileOp", "TryOp" ]; def CIR_YieldOp : CIR_Op<"yield", [ @@ -2752,6 +2752,100 @@ def CIR_CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> { ]; } +//===----------------------------------------------------------------------===// +// AwaitOp +//===----------------------------------------------------------------------===// + +def CIR_AwaitKind : CIR_I32EnumAttr<"AwaitKind", "await kind", [ + I32EnumAttrCase<"Init", 0, "init">, + I32EnumAttrCase<"User", 1, "user">, + I32EnumAttrCase<"Yield", 2, "yield">, + I32EnumAttrCase<"Final", 3, "final"> +]>; + +def CIR_AwaitOp : CIR_Op<"await",[ + DeclareOpInterfaceMethods, + RecursivelySpeculatable, NoRegionArguments +]> { + let summary = "Wraps C++ co_await implicit logic"; + let description = [{ + The under the hood effect of using C++ `co_await expr` roughly + translates to: + + ```c++ + // co_await expr; + + auto &&x = CommonExpr(); + if (!x.await_ready()) { + ... + x.await_suspend(...); + ... + } + x.await_resume(); + ``` + + `cir.await` represents this logic by using 3 regions: + - ready: covers veto power from x.await_ready() + - suspend: wraps actual x.await_suspend() logic + - resume: handles x.await_resume() + + Breaking this up in regions allows individual scrutiny of conditions + which might lead to folding some of them out. 
Lowerings coming out + of CIR, e.g. LLVM, should use the `suspend` region to track more + lower level codegen (e.g. intrinsic emission for coro.save/coro.suspend). + + There are also 4 flavors of `cir.await` available: + - `init`: compiler generated initial suspend via implicit `co_await`. + - `user`: also known as normal, representing a user written `co_await`. + - `yield`: user written `co_yield` expressions. + - `final`: compiler generated final suspend via implicit `co_await`. + + ```mlir + cir.scope { + ... // auto &&x = CommonExpr(); + cir.await(user, ready : { + ... // x.await_ready() + }, suspend : { + ... // x.await_suspend() + }, resume : { + ... // x.await_resume() + },) + } + ``` + + Note that resolution of the common expression is assumed to happen + as part of the enclosing await scope. + }]; + + let arguments = (ins CIR_AwaitKind:$kind); + let regions = (region SizedRegion<1>:$ready, + SizedRegion<1>:$suspend, + SizedRegion<1>:$resume); + let assemblyFormat = [{ + `(` $kind `,` + `ready` `:` $ready `,` + `suspend` `:` $suspend `,` + `resume` `:` $resume `,` + `)` + attr-dict + }]; + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<(ins + "cir::AwaitKind":$kind, + CArg<"BuilderCallbackRef", + "nullptr">:$readyBuilder, + CArg<"BuilderCallbackRef", + "nullptr">:$suspendBuilder, + CArg<"BuilderCallbackRef", + "nullptr">:$resumeBuilder + )> + ]; + + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // CopyOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/CodeGen/CIRGenCoroutine.cpp b/clang/lib/CIR/CodeGen/CIRGenCoroutine.cpp index 05fb1aedcbf4a..bb55991d9366a 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCoroutine.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenCoroutine.cpp @@ -22,6 +22,10 @@ using namespace clang; using namespace clang::CIRGen; struct clang::CIRGen::CGCoroData { + // What is the current await expression kind and how many +
// await/yield expressions were encountered so far. + // These are used to generate pretty labels for await expressions in LLVM IR. + cir::AwaitKind currentAwaitKind = cir::AwaitKind::Init; // Stores the __builtin_coro_id emitted in the function so that we can supply // it as the first argument to other builtins. cir::CallOp coroId = nullptr; @@ -249,7 +253,114 @@ CIRGenFunction::emitCoroutineBody(const CoroutineBodyStmt &s) { emitAnyExprToMem(s.getReturnValue(), returnValue, s.getReturnValue()->getType().getQualifiers(), /*isInit*/ true); + + assert(!cir::MissingFeatures::ehCleanupScope()); + // FIXME(cir): EHStack.pushCleanup(EHCleanup); + curCoro.data->currentAwaitKind = cir::AwaitKind::Init; + if (emitStmt(s.getInitSuspendStmt(), /*useCurrentScope=*/true).failed()) + return mlir::failure(); assert(!cir::MissingFeatures::emitBodyAndFallthrough()); } return mlir::success(); } +// Given a suspend expression which roughly looks like: +// +// auto && x = CommonExpr(); +// if (!x.await_ready()) { +// x.await_suspend(...); (*) +// } +// x.await_resume(); +// +// where the result of the entire expression is the result of x.await_resume() +// +// (*) If the return type of x.await_suspend is bool, it can veto the suspend: +// if (x.await_suspend(...)) +// llvm_coro_suspend(); +// +// This is higher level than LLVM codegen; for that, see LLVM's +// docs/Coroutines.rst for more details.
+namespace { +struct LValueOrRValue { + LValue lv; + RValue rv; +}; +} // namespace + +static LValueOrRValue +emitSuspendExpression(CIRGenFunction &cgf, CGCoroData &coro, + CoroutineSuspendExpr const &s, cir::AwaitKind kind, + AggValueSlot aggSlot, bool ignoreResult, + mlir::Block *scopeParentBlock, + mlir::Value &tmpResumeRValAddr, bool forLValue) { + mlir::LogicalResult awaitBuild = mlir::success(); + LValueOrRValue awaitRes; + + CIRGenFunction::OpaqueValueMapping binder = + CIRGenFunction::OpaqueValueMapping(cgf, s.getOpaqueValue()); + CIRGenBuilderTy &builder = cgf.getBuilder(); + [[maybe_unused]] cir::AwaitOp awaitOp = cir::AwaitOp::create( + builder, cgf.getLoc(s.getSourceRange()), kind, + /*readyBuilder=*/ + [&](mlir::OpBuilder &b, mlir::Location loc) { + builder.createCondition( + cgf.createDummyValue(loc, cgf.getContext().BoolTy)); + }, + /*suspendBuilder=*/ + [&](mlir::OpBuilder &b, mlir::Location loc) { + cir::YieldOp::create(builder, loc); + }, + /*resumeBuilder=*/ + [&](mlir::OpBuilder &b, mlir::Location loc) { + cir::YieldOp::create(builder, loc); + }); + + assert(awaitBuild.succeeded() && "Should know how to codegen"); + return awaitRes; +} + +static RValue emitSuspendExpr(CIRGenFunction &cgf, + const CoroutineSuspendExpr &e, + cir::AwaitKind kind, AggValueSlot aggSlot, + bool ignoreResult) { + RValue rval; + mlir::Location scopeLoc = cgf.getLoc(e.getSourceRange()); + + // Since we model suspend / resume as an inner region, we must store + // resume scalar results in a tmp alloca, and load it after we build the + // suspend expression. An alternative way to do this would be to make + // every region return a value when promise.return_value() is used, but + // it's a bit awkward given that resume is the only region that actually + // returns a value. 
+ mlir::Block *currEntryBlock = cgf.curLexScope->getEntryBlock(); + [[maybe_unused]] mlir::Value tmpResumeRValAddr; + + // No need to explicitly wrap this into a scope since the AST already uses a + // ExprWithCleanups, which will wrap this into a cir.scope anyways. + rval = emitSuspendExpression(cgf, *cgf.curCoro.data, e, kind, aggSlot, + ignoreResult, currEntryBlock, tmpResumeRValAddr, + /*forLValue*/ false) + .rv; + + if (ignoreResult || rval.isIgnored()) + return rval; + + if (rval.isScalar()) { + rval = RValue::get(cir::LoadOp::create(cgf.getBuilder(), scopeLoc, + rval.getValue().getType(), + tmpResumeRValAddr)); + } else if (rval.isAggregate()) { + // This is probably already handled via AggSlot, remove this assertion + // once we have a testcase and prove all pieces work. + cgf.cgm.errorNYI("emitSuspendExpr Aggregate"); + } else { // complex + cgf.cgm.errorNYI("emitSuspendExpr Complex"); + } + return rval; +} + +RValue CIRGenFunction::emitCoawaitExpr(const CoawaitExpr &e, + AggValueSlot aggSlot, + bool ignoreResult) { + return emitSuspendExpr(*this, e, curCoro.data->currentAwaitKind, aggSlot, + ignoreResult); +} diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp index f777562ba6309..b90caccf35c7d 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp @@ -156,6 +156,10 @@ class ScalarExprEmitter : public StmtVisitor { return cgf.emitLoadOfLValue(lv, e->getExprLoc()).getValue(); } + mlir::Value VisitCoawaitExpr(CoawaitExpr *s) { + return cgf.emitCoawaitExpr(*s).getValue(); + } + mlir::Value emitLoadOfLValue(LValue lv, SourceLocation loc) { return cgf.emitLoadOfLValue(lv, loc).getValue(); } diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h index b22bf2d87fc10..b426f3389ff1b 100644 --- a/clang/lib/CIR/CodeGen/CIRGenFunction.h +++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h @@ -1576,6 +1576,9 @@ class CIRGenFunction : public 
CIRGenTypeCache {
   void emitForwardingCallToLambda(const CXXMethodDecl *lambdaCallOperator,
                                   CallArgList &callArgs);

+  RValue emitCoawaitExpr(const CoawaitExpr &e,
+                         AggValueSlot aggSlot = AggValueSlot::ignored(),
+                         bool ignoreResult = false);
   /// Emit the computation of the specified expression of complex type,
   /// returning the result.
   mlir::Value emitComplexExpr(const Expr *e);
diff --git a/clang/lib/CIR/CodeGen/CIRGenValue.h b/clang/lib/CIR/CodeGen/CIRGenValue.h
index 20a3d0ef61341..2002bd7e7c488 100644
--- a/clang/lib/CIR/CodeGen/CIRGenValue.h
+++ b/clang/lib/CIR/CodeGen/CIRGenValue.h
@@ -49,6 +49,7 @@ class RValue {
   bool isScalar() const { return flavor == Scalar; }
   bool isComplex() const { return flavor == Complex; }
   bool isAggregate() const { return flavor == Aggregate; }
+  bool isIgnored() const { return isScalar() && !getValue(); }

   bool isVolatileQualified() const { return isVolatile; }

diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 22aada882defc..b5840ff12438e 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -289,7 +289,12 @@ void cir::ConditionOp::getSuccessorRegions(
     regions.emplace_back(getOperation(), loopOp->getResults());
   }

-  assert(!cir::MissingFeatures::awaitOp());
+  // Parent is an await: condition may branch to resume or suspend regions.
+  if (auto await = dyn_cast<AwaitOp>(getOperation()->getParentOp())) {
+    regions.emplace_back(&await.getResume(), await.getResume().getArguments());
+    regions.emplace_back(&await.getSuspend(),
+                         await.getSuspend().getArguments());
+  }
 }

 MutableOperandRange
@@ -299,8 +304,7 @@ cir::ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
 }

 LogicalResult cir::ConditionOp::verify() {
-  assert(!cir::MissingFeatures::awaitOp());
-  if (!isa<LoopOpInterface>(getOperation()->getParentOp()))
+  if (!isa<LoopOpInterface, AwaitOp>(getOperation()->getParentOp()))
     return emitOpError("condition must be within a conditional region");
   return success();
 }
@@ -1910,6 +1914,17 @@ void cir::FuncOp::print(OpAsmPrinter &p) {


 mlir::LogicalResult cir::FuncOp::verify() {
+  if (!isDeclaration() && getCoroutine()) {
+    // One cir.await is enough: interrupt the walk at the first match instead
+    // of visiting the rest of the body.
+    bool foundAwait =
+        this->walk([](AwaitOp) { return mlir::WalkResult::interrupt(); })
+            .wasInterrupted();
+    if (!foundAwait)
+      return emitOpError()
+             << "coroutine body must use at least one cir.await op";
+  }
+
   llvm::SmallSet labels;
   llvm::SmallSet gotos;
   llvm::SmallSet blockAddresses;
@@ -2149,6 +2164,54 @@ OpFoldResult cir::UnaryOp::fold(FoldAdaptor adaptor) {
   return {};
 }

+//===----------------------------------------------------------------------===//
+// AwaitOp
+//===----------------------------------------------------------------------===//
+
+void cir::AwaitOp::build(OpBuilder &builder, OperationState &result,
+                         cir::AwaitKind kind, BuilderCallbackRef readyBuilder,
+                         BuilderCallbackRef suspendBuilder,
+                         BuilderCallbackRef resumeBuilder) {
+  result.addAttribute(getKindAttrName(result.name),
+                      cir::AwaitKindAttr::get(builder.getContext(), kind));
+
+  // Create the ready, suspend and resume regions in that order, each with a
+  // single entry block. A null callback is skipped rather than called, which
+  // leaves that block empty for the caller to populate.
+  for (BuilderCallbackRef regionBuilder :
+       {readyBuilder, suspendBuilder, resumeBuilder}) {
+    OpBuilder::InsertionGuard guard(builder);
+    builder.createBlock(result.addRegion());
+    if (regionBuilder)
+      regionBuilder(builder, result.location);
+  }
+}
+
+void cir::AwaitOp::getSuccessorRegions(
+    mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  // Every region of this op branches back to the parent operation.
+  if (!point.isParent()) {
+    regions.push_back(
+        RegionSuccessor(getOperation(), getOperation()->getResults()));
+    return;
+  }
+
+  // TODO: retrieve information from the promise and only push the
+  // necessary ones. Example: `std::suspend_never` on initial or final
+  // awaits might allow the suspend region to be skipped.
+  regions.push_back(RegionSuccessor(&this->getReady()));
+  regions.push_back(RegionSuccessor(&this->getSuspend()));
+  regions.push_back(RegionSuccessor(&this->getResume()));
+}
+
+LogicalResult cir::AwaitOp::verify() {
+  // Look at the last operation directly instead of Block::getTerminator(),
+  // so an empty or unterminated ready block is reported as a diagnostic.
+  mlir::Block &readyBlock = this->getReady().back();
+  if (readyBlock.empty() || !isa<ConditionOp>(&readyBlock.back()))
+    return emitOpError("ready region must end with cir.condition");
+  return success();
+}

 //===----------------------------------------------------------------------===//
 // CopyOp Definitions
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 4912bd197dba4..28fc1c7a7e068 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -3841,6 +3841,12 @@ mlir::LogicalResult CIRToLLVMBlockAddressOpLowering::matchAndRewrite(
   return mlir::failure();
 }

+mlir::LogicalResult CIRToLLVMAwaitOpLowering::matchAndRewrite(
+    cir::AwaitOp op, OpAdaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  return mlir::failure();
+}
+
 std::unique_ptr createConvertCIRToLLVMPass() {
   return std::make_unique();
 }
diff --git a/clang/lib/CodeGen/CGValue.h b/clang/lib/CodeGen/CGValue.h
index c4ec8d207d2e3..6b381b59e71cd 100644
--- a/clang/lib/CodeGen/CGValue.h
+++ b/clang/lib/CodeGen/CGValue.h
@@ -64,6
+64,7 @@ class RValue { bool isScalar() const { return Flavor == Scalar; } bool isComplex() const { return Flavor == Complex; } bool isAggregate() const { return Flavor == Aggregate; } + bool isIgnored() const { return isScalar() && !getScalarVal(); } bool isVolatileQualified() const { return IsVolatile; } diff --git a/clang/test/CIR/CodeGen/coro-task.cpp b/clang/test/CIR/CodeGen/coro-task.cpp index 5738c815909ea..01e0786fbda71 100644 --- a/clang/test/CIR/CodeGen/coro-task.cpp +++ b/clang/test/CIR/CodeGen/coro-task.cpp @@ -111,6 +111,8 @@ co_invoke_fn co_invoke; // CIR-DAG: ![[VoidPromisse:.*]] = !cir.record::promise_type" padded {!u8i}> // CIR-DAG: ![[IntPromisse:.*]] = !cir.record::promise_type" padded {!u8i}> // CIR-DAG: ![[StdString:.*]] = !cir.record +// CIR-DAG: ![[SuspendAlways:.*]] = !cir.record + // CIR: module {{.*}} { // CIR-NEXT: cir.global external @_ZN5folly4coro9co_invokeE = #cir.zero : !rec_folly3A3Acoro3A3Aco_invoke_fn @@ -153,6 +155,33 @@ VoidTask silly_task() { // CIR: %[[RetObj:.*]] = cir.call @_ZN5folly4coro4TaskIvE12promise_type17get_return_objectEv(%[[VoidPromisseAddr]]) nothrow : {{.*}} -> ![[VoidTask]] // CIR: cir.store{{.*}} %[[RetObj]], %[[VoidTaskAddr]] : ![[VoidTask]] +// Start a new scope for the actual codegen for co_await, create temporary allocas for +// holding coroutine handle and the suspend_always struct. + +// CIR: cir.scope { +// CIR: %[[SuspendAlwaysAddr:.*]] = cir.alloca ![[SuspendAlways]], {{.*}} ["ref.tmp0"] {alignment = 1 : i64} + +// Effectively execute `co_await promise_type::initial_suspend()` by calling initial_suspend() and getting +// the suspend_always struct to use for cir.await. Note that we return by-value since we defer ABI lowering +// to later passes, same is done elsewhere. + +// CIR: %[[Tmp0:.*]] = cir.call @_ZN5folly4coro4TaskIvE12promise_type15initial_suspendEv(%[[VoidPromisseAddr]]) +// CIR: cir.store{{.*}} %[[Tmp0:.*]], %[[SuspendAlwaysAddr]] + +// +// Here we start mapping co_await to cir.await.
+// + +// The first region, `ready`, ends with cir.condition, which can veto suspension. + +// CIR: cir.await(init, ready : { +// CIR: cir.condition({{.*}}) +// CIR: }, suspend : { +// CIR: cir.yield +// CIR: }, resume : { +// CIR: cir.yield +// CIR: },) +// CIR: } folly::coro::Task byRef(const std::string& s) { co_return s.size(); @@ -172,3 +201,13 @@ folly::coro::Task byRef(const std::string& s) { // CIR: cir.store {{.*}} %[[LOAD]], %[[AllocaFnUse]] : !cir.ptr, !cir.ptr> // CIR: %[[RetObj:.*]] = cir.call @_ZN5folly4coro4TaskIiE12promise_type17get_return_objectEv(%4) nothrow : {{.*}} -> ![[IntTask]] // CIR: cir.store {{.*}} %[[RetObj]], %[[IntTaskAddr]] : ![[IntTask]] +// CIR: cir.scope { +// CIR: %[[SuspendAlwaysAddr:.*]] = cir.alloca ![[SuspendAlways]], {{.*}} ["ref.tmp0"] {alignment = 1 : i64} +// CIR: %[[Tmp0:.*]] = cir.call @_ZN5folly4coro4TaskIiE12promise_type15initial_suspendEv(%[[IntPromisseAddr]]) +// CIR: cir.await(init, ready : { +// CIR: cir.condition({{.*}}) +// CIR: }, suspend : { +// CIR: cir.yield +// CIR: }, resume : { +// CIR: cir.yield +// CIR: },) diff --git a/clang/test/CIR/IR/await.cir b/clang/test/CIR/IR/await.cir new file mode 100644 index 0000000000000..c1fb0d6d7c57c --- /dev/null +++ b/clang/test/CIR/IR/await.cir @@ -0,0 +1,21 @@ +// RUN: cir-opt %s --verify-roundtrip | FileCheck %s + +cir.func coroutine @checkPrintParse(%arg0 : !cir.bool) { + cir.await(user, ready : { + cir.condition(%arg0) + }, suspend : { + cir.yield + }, resume : { + cir.yield + },) + cir.return +} + +// CHECK: cir.func coroutine @checkPrintParse +// CHECK: cir.await(user, ready : { +// CHECK: cir.condition(%arg0) +// CHECK: }, suspend : { +// CHECK: cir.yield +// CHECK: }, resume : { +// CHECK: cir.yield +// CHECK: },) diff --git a/clang/test/CIR/IR/func.cir b/clang/test/CIR/IR/func.cir index 6e91898a3b452..4e79149315a9d 100644 --- a/clang/test/CIR/IR/func.cir +++ b/clang/test/CIR/IR/func.cir @@ -101,6 +101,15 @@ cir.func @ullfunc() -> !u64i { // CHECK: } cir.func coroutine
@coro() { + cir.await(init, ready : { + %0 = cir.alloca !cir.bool, !cir.ptr, [""] {alignment = 1 : i64} + %1 = cir.load align(1) %0 : !cir.ptr, !cir.bool + cir.condition(%1) + }, suspend : { + cir.yield + }, resume : { + cir.yield + },) cir.return } // CHECK: cir.func{{.*}} coroutine @coro() diff --git a/clang/test/CIR/IR/invalid-await.cir b/clang/test/CIR/IR/invalid-await.cir new file mode 100644 index 0000000000000..5f422ca7e954e --- /dev/null +++ b/clang/test/CIR/IR/invalid-await.cir @@ -0,0 +1,19 @@ +// RUN: cir-opt %s -verify-diagnostics -split-input-file +cir.func coroutine @bad_task() { // expected-error {{coroutine body must use at least one cir.await op}} + cir.return +} + +// ----- + +cir.func coroutine @missing_condition() { + cir.scope { + cir.await(user, ready : { // expected-error {{ready region must end with cir.condition}} + cir.yield + }, suspend : { + cir.yield + }, resume : { + cir.yield + },) + } + cir.return +}