diff --git a/clang/include/clang/AST/Decl.h b/clang/include/clang/AST/Decl.h index 61117cc5ce71f..a5879591f4c65 100644 --- a/clang/include/clang/AST/Decl.h +++ b/clang/include/clang/AST/Decl.h @@ -4419,7 +4419,7 @@ class FileScopeAsmDecl : public Decl { /// /// \note This is used in libInterpreter, clang -cc1 -fincremental-extensions /// and in tools such as clang-repl. -class TopLevelStmtDecl : public Decl { +class TopLevelStmtDecl : public Decl, public DeclContext { friend class ASTDeclReader; friend class ASTDeclWriter; @@ -4427,7 +4427,7 @@ class TopLevelStmtDecl : public Decl { bool IsSemiMissing = false; TopLevelStmtDecl(DeclContext *DC, SourceLocation L, Stmt *S) - : Decl(TopLevelStmt, DC, L), Statement(S) {} + : Decl(TopLevelStmt, DC, L), DeclContext(TopLevelStmt), Statement(S) {} virtual void anchor(); @@ -4438,15 +4438,19 @@ class TopLevelStmtDecl : public Decl { SourceRange getSourceRange() const override LLVM_READONLY; Stmt *getStmt() { return Statement; } const Stmt *getStmt() const { return Statement; } - void setStmt(Stmt *S) { - assert(IsSemiMissing && "Operation supported for printing values only!"); - Statement = S; - } + void setStmt(Stmt *S); bool isSemiMissing() const { return IsSemiMissing; } void setSemiMissing(bool Missing = true) { IsSemiMissing = Missing; } static bool classof(const Decl *D) { return classofKind(D->getKind()); } static bool classofKind(Kind K) { return K == TopLevelStmt; } + + static DeclContext *castToDeclContext(const TopLevelStmtDecl *D) { + return static_cast(const_cast(D)); + } + static TopLevelStmtDecl *castFromDeclContext(const DeclContext *DC) { + return static_cast(const_cast(DC)); + } }; /// Represents a block literal declaration, which is like an diff --git a/clang/include/clang/AST/DeclBase.h b/clang/include/clang/AST/DeclBase.h index 9a4736019d1b1..76810a86a78a4 100644 --- a/clang/include/clang/AST/DeclBase.h +++ b/clang/include/clang/AST/DeclBase.h @@ -2120,6 +2120,7 @@ class DeclContext { case Decl::Block: case Decl::Captured: case Decl::ObjCMethod: + case Decl::TopLevelStmt: return true; default: return getDeclKind() >= Decl::firstFunction && diff --git a/clang/include/clang/Basic/DeclNodes.td b/clang/include/clang/Basic/DeclNodes.td index 8b1f415dd5fe2..48396e85c5ada 100644 --- a/clang/include/clang/Basic/DeclNodes.td +++ b/clang/include/clang/Basic/DeclNodes.td @@ -95,7 +95,7 @@ def LinkageSpec : DeclNode, DeclContext; def Export : DeclNode, DeclContext; def ObjCPropertyImpl : DeclNode; def FileScopeAsm : DeclNode; -def TopLevelStmt : DeclNode; +def TopLevelStmt : DeclNode, DeclContext; def AccessSpec : DeclNode; def Friend : DeclNode; def FriendTemplate : DeclNode; diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h index 2d949f3fc9a71..592c7871a4a55 100644 --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -3263,7 +3263,8 @@ class Sema final { Decl *ActOnFileScopeAsmDecl(Expr *expr, SourceLocation AsmLoc, SourceLocation RParenLoc); - Decl *ActOnTopLevelStmtDecl(Stmt *Statement); + TopLevelStmtDecl *ActOnStartTopLevelStmtDecl(Scope *S); + void ActOnFinishTopLevelStmtDecl(TopLevelStmtDecl *D, Stmt *Statement); void ActOnPopScope(SourceLocation Loc, Scope *S); diff --git a/clang/lib/AST/Decl.cpp b/clang/lib/AST/Decl.cpp index 59c039f1f8dae..d681791d3920c 100644 --- a/clang/lib/AST/Decl.cpp +++ b/clang/lib/AST/Decl.cpp @@ -5552,14 +5552,13 @@ FileScopeAsmDecl *FileScopeAsmDecl::CreateDeserialized(ASTContext &C, void TopLevelStmtDecl::anchor() {} TopLevelStmtDecl *TopLevelStmtDecl::Create(ASTContext &C, Stmt *Statement) { - assert(Statement); assert(C.getLangOpts().IncrementalExtensions && "Must be used only in incremental mode"); - SourceLocation BeginLoc = Statement->getBeginLoc(); + SourceLocation Loc = Statement ? Statement->getBeginLoc() : SourceLocation(); DeclContext *DC = C.getTranslationUnitDecl(); - return new (C, DC) TopLevelStmtDecl(DC, BeginLoc, Statement); + return new (C, DC) TopLevelStmtDecl(DC, Loc, Statement); } TopLevelStmtDecl *TopLevelStmtDecl::CreateDeserialized(ASTContext &C, @@ -5572,6 +5571,12 @@ SourceRange TopLevelStmtDecl::getSourceRange() const { return SourceRange(getLocation(), Statement->getEndLoc()); } +void TopLevelStmtDecl::setStmt(Stmt *S) { + assert(S); + Statement = S; + setLocation(Statement->getBeginLoc()); +} + void EmptyDecl::anchor() {} EmptyDecl *EmptyDecl::Create(ASTContext &C, DeclContext *DC, SourceLocation L) { diff --git a/clang/lib/AST/DeclBase.cpp b/clang/lib/AST/DeclBase.cpp index 10fe8bb97ce66..fcedb3cfd176a 100644 --- a/clang/lib/AST/DeclBase.cpp +++ b/clang/lib/AST/DeclBase.cpp @@ -1352,6 +1352,7 @@ DeclContext *DeclContext::getPrimaryContext() { case Decl::ExternCContext: case Decl::LinkageSpec: case Decl::Export: + case Decl::TopLevelStmt: case Decl::Block: case Decl::Captured: case Decl::OMPDeclareReduction: diff --git a/clang/lib/Parse/ParseDecl.cpp b/clang/lib/Parse/ParseDecl.cpp index 81f1c71126944..64b234eb460d2 100644 --- a/clang/lib/Parse/ParseDecl.cpp +++ b/clang/lib/Parse/ParseDecl.cpp @@ -5678,24 +5678,32 @@ Parser::DeclGroupPtrTy Parser::ParseTopLevelStmtDecl() { // Parse a top-level-stmt. Parser::StmtVector Stmts; ParsedStmtContext SubStmtCtx = ParsedStmtContext(); - Actions.PushFunctionScope(); + ParseScope FnScope(this, Scope::FnScope | Scope::DeclScope | + Scope::CompoundStmtScope); + TopLevelStmtDecl *TLSD = Actions.ActOnStartTopLevelStmtDecl(getCurScope()); StmtResult R = ParseStatementOrDeclaration(Stmts, SubStmtCtx); - Actions.PopFunctionScopeInfo(); if (!R.isUsable()) return nullptr; - SmallVector DeclsInGroup; - DeclsInGroup.push_back(Actions.ActOnTopLevelStmtDecl(R.get())); + Actions.ActOnFinishTopLevelStmtDecl(TLSD, R.get()); if (Tok.is(tok::annot_repl_input_end) && Tok.getAnnotationValue() != nullptr) { ConsumeAnnotationToken(); - cast(DeclsInGroup.back())->setSemiMissing(); + TLSD->setSemiMissing(); } - // Currently happens for things like -fms-extensions and use `__if_exists`. - for (Stmt *S : Stmts) - DeclsInGroup.push_back(Actions.ActOnTopLevelStmtDecl(S)); + SmallVector DeclsInGroup; + DeclsInGroup.push_back(TLSD); + + // Currently happens for things like -fms-extensions and use `__if_exists`. + for (Stmt *S : Stmts) { + // Here we should be safe as `__if_exists` and friends are not introducing + // new variables which need to live outside file scope. + TopLevelStmtDecl *D = Actions.ActOnStartTopLevelStmtDecl(getCurScope()); + Actions.ActOnFinishTopLevelStmtDecl(D, S); + DeclsInGroup.push_back(D); + } return Actions.BuildDeclaratorGroup(DeclsInGroup); } diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp index 6b81ee183cc44..67e56a917a51d 100644 --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -20519,12 +20519,22 @@ Decl *Sema::ActOnFileScopeAsmDecl(Expr *expr, return New; } -Decl *Sema::ActOnTopLevelStmtDecl(Stmt *Statement) { - auto *New = TopLevelStmtDecl::Create(Context, Statement); - Context.getTranslationUnitDecl()->addDecl(New); +TopLevelStmtDecl *Sema::ActOnStartTopLevelStmtDecl(Scope *S) { + auto *New = TopLevelStmtDecl::Create(Context, /*Statement=*/nullptr); + CurContext->addDecl(New); + PushDeclContext(S, New); + PushFunctionScope(); + PushCompoundScope(false); return New; } +void Sema::ActOnFinishTopLevelStmtDecl(TopLevelStmtDecl *D, Stmt *Statement) { + D->setStmt(Statement); + PopCompoundScope(); + PopFunctionScopeInfo(); + PopDeclContext(); +} + void Sema::ActOnPragmaRedefineExtname(IdentifierInfo* Name, IdentifierInfo* AliasName, SourceLocation PragmaLoc, diff --git a/clang/test/Interpreter/execute-stmts.cpp b/clang/test/Interpreter/execute-stmts.cpp index 2d4c17e0c91e6..433c6811777da 100644 --- a/clang/test/Interpreter/execute-stmts.cpp +++ b/clang/test/Interpreter/execute-stmts.cpp @@ -9,7 +9,6 @@ //CODEGEN-CHECK-COUNT-2: define internal void @__stmts__ //CODEGEN-CHECK-NOT: define internal void @__stmts__ - extern "C" int printf(const char*,...); template T call() { printf("called\n"); return T(); } @@ -41,3 +40,26 @@ for (; i > 4; --i) { printf("i = %d\n", i); }; int j = i; printf("j = %d\n", j); // CHECK-NEXT: j = 4 + +{i = 0; printf("i = %d (global scope)\n", i);} +// CHECK-NEXT: i = 0 + +while (int i = 1) { printf("i = %d (while condition)\n", i--); break; } +// CHECK-NEXT: i = 1 + +if (int i = 2) printf("i = %d (if condition)\n", i); +// CHECK-NEXT: i = 2 + +switch (int i = 3) { default: printf("i = %d (switch condition)\n", i); } +// CHECK-NEXT: i = 3 + +for (int i = 4; i > 3; --i) printf("i = %d (for-init)\n", i); +// CHECK-NEXT: i = 4 + +for (const auto &i : "5") printf("i = %c (range-based for-init)\n", i); +// CHECK-NEXT: i = 5 + +int *aa=nullptr; +if (auto *b=aa) *b += 1; +while (auto *b=aa) ; +for (auto *b=aa; b; *b+=1) ;