diff --git a/clang/include/clang/AST/StmtCXX.h b/clang/include/clang/AST/StmtCXX.h index 288ffc28f40ee..60fc3f3a63f49 100644 --- a/clang/include/clang/AST/StmtCXX.h +++ b/clang/include/clang/AST/StmtCXX.h @@ -375,9 +375,10 @@ class CoroutineBodyStmt final } /// Retrieve the body of the coroutine as written. This will be either - /// a CompoundStmt or a TryStmt. - Stmt *getBody() const { - return getStoredStmts()[SubStmt::Body]; + /// a CompoundStmt. If the coroutine is in function-try-block, we will + /// wrap the CXXTryStmt into a CompoundStmt to keep consistency. + CompoundStmt *getBody() const { + return cast(getStoredStmts()[SubStmt::Body]); } Stmt *getPromiseDeclStmt() const { diff --git a/clang/lib/CodeGen/CGCoroutine.cpp b/clang/lib/CodeGen/CGCoroutine.cpp index 90ab82ed88a11..da3da5e600104 100644 --- a/clang/lib/CodeGen/CGCoroutine.cpp +++ b/clang/lib/CodeGen/CGCoroutine.cpp @@ -593,18 +593,6 @@ static void emitBodyAndFallthrough(CodeGenFunction &CGF, CGF.EmitStmt(OnFallthrough); } -static CompoundStmt *CoroutineStmtBuilder(ASTContext &Context, - const CoroutineBodyStmt &S) { - Stmt *Stmt = S.getBody(); - if (CompoundStmt *Body = dyn_cast(Stmt)) - return Body; - // We are about to create a `CXXTryStmt` which requires a `CompoundStmt`. - // If the function body is not a `CompoundStmt` yet then we have to create - // a new one. This happens for cases like the "function-try-block" syntax. - return CompoundStmt::Create(Context, {Stmt}, FPOptionsOverride(), - SourceLocation(), SourceLocation()); -} - void CodeGenFunction::EmitCoroutineBody(const CoroutineBodyStmt &S) { auto *NullPtr = llvm::ConstantPointerNull::get(Builder.getInt8PtrTy()); auto &TI = CGM.getContext().getTargetInfo(); @@ -733,8 +721,8 @@ void CodeGenFunction::EmitCoroutineBody(const CoroutineBodyStmt &S) { auto Loc = S.getBeginLoc(); CXXCatchStmt Catch(Loc, /*exDecl=*/nullptr, CurCoro.Data->ExceptionHandler); - CompoundStmt *Body = CoroutineStmtBuilder(getContext(), S); - auto *TryStmt = CXXTryStmt::Create(getContext(), Loc, Body, &Catch); + auto *TryStmt = + CXXTryStmt::Create(getContext(), Loc, S.getBody(), &Catch); EnterCXXTryStmt(*TryStmt); emitBodyAndFallthrough(*this, S, TryStmt->getTryBlock()); diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp index e87f2a78e2394..deb67337a2aec 100644 --- a/clang/lib/Sema/SemaCoroutine.cpp +++ b/clang/lib/Sema/SemaCoroutine.cpp @@ -1137,6 +1137,18 @@ void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) { Body = CoroutineBodyStmt::Create(Context, Builder); } +static CompoundStmt *buildCoroutineBody(Stmt *Body, ASTContext &Context) { + if (auto *CS = dyn_cast(Body)) + return CS; + + // The body of the coroutine may be a try statement if it is in + // 'function-try-block' syntax. Here we wrap it into a compound + // statement for consistency. + assert(isa(Body) && "Unimaged coroutine body type"); + return CompoundStmt::Create(Context, {Body}, FPOptionsOverride(), + SourceLocation(), SourceLocation()); +} + CoroutineStmtBuilder::CoroutineStmtBuilder(Sema &S, FunctionDecl &FD, sema::FunctionScopeInfo &Fn, Stmt *Body) @@ -1144,7 +1156,7 @@ CoroutineStmtBuilder::CoroutineStmtBuilder(Sema &S, FunctionDecl &FD, IsPromiseDependentType( !Fn.CoroutinePromise || Fn.CoroutinePromise->getType()->isDependentType()) { - this->Body = Body; + this->Body = buildCoroutineBody(Body, S.getASTContext()); for (auto KV : Fn.CoroutineParameterMoves) this->ParamMovesVector.push_back(KV.second); diff --git a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp index e6d2bb46dfcd3..89954711804aa 100644 --- a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp +++ b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp @@ -717,8 +717,8 @@ void coro() try { )cpp"; EXPECT_TRUE(matchesConditionally( CoroWithTryCatchDeclCode, - coroutineBodyStmt(hasBody(cxxTryStmt(has(compoundStmt(has( - declStmt(containsDeclaration(0, varDecl(hasName("thevar")))))))))), + coroutineBodyStmt(hasBody(compoundStmt(has(cxxTryStmt(has(compoundStmt(has( + declStmt(containsDeclaration(0, varDecl(hasName("thevar")))))))))))), true, {"-std=c++20", "-I/"}, M)); }