diff --git a/clang/include/clang/Lex/PPCallbacks.h b/clang/include/clang/Lex/PPCallbacks.h index 313b730afbab8..e6120c5648798 100644 --- a/clang/include/clang/Lex/PPCallbacks.h +++ b/clang/include/clang/Lex/PPCallbacks.h @@ -499,10 +499,10 @@ class PPChainedCallbacks : public PPCallbacks { } bool EmbedFileNotFound(StringRef FileName) override { - bool Skip = First->FileNotFound(FileName); + bool Skip = First->EmbedFileNotFound(FileName); // Make sure to invoke the second callback, no matter if the first already // returned true to skip the file. - Skip |= Second->FileNotFound(FileName); + Skip |= Second->EmbedFileNotFound(FileName); return Skip; } diff --git a/clang/lib/Lex/PPDirectives.cpp b/clang/lib/Lex/PPDirectives.cpp index 891c8ab7f3155..764a893eebe3c 100644 --- a/clang/lib/Lex/PPDirectives.cpp +++ b/clang/lib/Lex/PPDirectives.cpp @@ -1397,11 +1397,12 @@ void Preprocessor::HandleDirective(Token &Result) { return HandleIdentSCCSDirective(Result); case tok::pp_sccs: return HandleIdentSCCSDirective(Result); - case tok::pp_embed: - return HandleEmbedDirective(SavedHash.getLocation(), Result, - getCurrentFileLexer() - ? *getCurrentFileLexer()->getFileEntry() - : static_cast(nullptr)); + case tok::pp_embed: { + if (PreprocessorLexer *CurrentFileLexer = getCurrentFileLexer()) + if (OptionalFileEntryRef FERef = CurrentFileLexer->getFileEntry()) + return HandleEmbedDirective(SavedHash.getLocation(), Result, *FERef); + return HandleEmbedDirective(SavedHash.getLocation(), Result, nullptr); + } case tok::pp_assert: //isExtension = true; // FIXME: implement #assert break; diff --git a/clang/unittests/Lex/PPCallbacksTest.cpp b/clang/unittests/Lex/PPCallbacksTest.cpp index 990689c6b1e45..9533fbc776e6e 100644 --- a/clang/unittests/Lex/PPCallbacksTest.cpp +++ b/clang/unittests/Lex/PPCallbacksTest.cpp @@ -437,6 +437,7 @@ TEST_F(PPCallbacksTest, FileNotFoundSkipped) { PreprocessorOptions PPOpts; HeaderSearch HeaderInfo(HSOpts, SourceMgr, Diags, LangOpts, Target.get()); + unsigned int NumCalls = 0; DiagnosticConsumer *DiagConsumer = new DiagnosticConsumer; DiagnosticsEngine FileNotFoundDiags(DiagID, DiagOpts, DiagConsumer); Preprocessor PP(PPOpts, FileNotFoundDiags, LangOpts, SourceMgr, HeaderInfo, @@ -445,21 +446,68 @@ TEST_F(PPCallbacksTest, FileNotFoundSkipped) { class FileNotFoundCallbacks : public PPCallbacks { public: - unsigned int NumCalls = 0; + unsigned int &NumCalls; + + FileNotFoundCallbacks(unsigned int &NumCalls) : NumCalls(NumCalls) {} + bool FileNotFound(StringRef FileName) override { NumCalls++; return FileName == "skipped.h"; } }; - auto *Callbacks = new FileNotFoundCallbacks; - PP.addPPCallbacks(std::unique_ptr(Callbacks)); + PP.addPPCallbacks(std::make_unique(NumCalls)); + + // Lex source text. + PP.EnterMainSourceFile(); + PP.LexTokensUntilEOF(); + + ASSERT_EQ(1u, NumCalls); + ASSERT_EQ(0u, DiagConsumer->getNumErrors()); +} + +TEST_F(PPCallbacksTest, EmbedFileNotFoundChained) { + const char *SourceText = "#embed \"notfound.h\"\n"; + + std::unique_ptr SourceBuf = + llvm::MemoryBuffer::getMemBuffer(SourceText); + SourceMgr.setMainFileID(SourceMgr.createFileID(std::move(SourceBuf))); + + unsigned int NumCalls = 0; + HeaderSearchOptions HSOpts; + TrivialModuleLoader ModLoader; + PreprocessorOptions PPOpts; + HeaderSearch HeaderInfo(HSOpts, SourceMgr, Diags, LangOpts, Target.get()); + + DiagnosticConsumer *DiagConsumer = new DiagnosticConsumer; + DiagnosticsEngine EmbedFileNotFoundDiags(DiagID, DiagOpts, DiagConsumer); + Preprocessor PP(PPOpts, EmbedFileNotFoundDiags, LangOpts, SourceMgr, + HeaderInfo, ModLoader, /*IILookup=*/nullptr, + /*OwnsHeaderSearch=*/false); + PP.Initialize(*Target); + + class EmbedFileNotFoundCallbacks : public PPCallbacks { + public: + unsigned int &NumCalls; + + EmbedFileNotFoundCallbacks(unsigned int &NumCalls) : NumCalls(NumCalls) {} + + bool EmbedFileNotFound(StringRef FileName) override { + NumCalls++; + return true; + } + }; + + // Add two instances of `EmbedFileNotFoundCallbacks` to ensure the + // preprocessor is using an instance of `PPChainedCallbaks`. + PP.addPPCallbacks(std::make_unique(NumCalls)); + PP.addPPCallbacks(std::make_unique(NumCalls)); // Lex source text. PP.EnterMainSourceFile(); PP.LexTokensUntilEOF(); - ASSERT_EQ(1u, Callbacks->NumCalls); + ASSERT_EQ(2u, NumCalls); ASSERT_EQ(0u, DiagConsumer->getNumErrors()); }