Skip to content

Commit

Permalink
[StandardInstrumentations] Check function analysis invalidation in mo…
Browse files Browse the repository at this point in the history
…dule passes as well

See comments for why we now need to pass in the MAM instead of the FAM.

Reviewed By: nikic

Differential Revision: https://reviews.llvm.org/D146160
  • Loading branch information
aeubanks committed Mar 15, 2023
1 parent 20a7ea4 commit d6c0724
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 60 deletions.
4 changes: 2 additions & 2 deletions llvm/include/llvm/Passes/StandardInstrumentations.h
Expand Up @@ -153,7 +153,7 @@ class PreservedCFGCheckerInstrumentation {
#endif

void registerCallbacks(PassInstrumentationCallbacks &PIC,
FunctionAnalysisManager &FAM);
ModuleAnalysisManager &MAM);
};

// Base class for classes that report changes to the IR.
Expand Down Expand Up @@ -574,7 +574,7 @@ class StandardInstrumentations {
// Register all the standard instrumentation callbacks. If \p FAM is nullptr
// then PreservedCFGChecker is not enabled.
void registerCallbacks(PassInstrumentationCallbacks &PIC,
FunctionAnalysisManager *FAM = nullptr);
ModuleAnalysisManager *MAM = nullptr);

TimePassesHandler &getTimePasses() { return TimePasses; }
};
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/LTO/LTOBackend.cpp
Expand Up @@ -260,7 +260,7 @@ static void runNewPMPasses(const Config &Conf, Module &Mod, TargetMachine *TM,

PassInstrumentationCallbacks PIC;
StandardInstrumentations SI(Mod.getContext(), Conf.DebugPassManager);
SI.registerCallbacks(PIC, &FAM);
SI.registerCallbacks(PIC, &MAM);
PassBuilder PB(TM, Conf.PTO, PGOOpt, &PIC);

RegisterPassPlugins(Conf.PassPlugins, PB);
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/LTO/ThinLTOCodeGenerator.cpp
Expand Up @@ -245,7 +245,7 @@ static void optimizeModule(Module &TheModule, TargetMachine &TM,

PassInstrumentationCallbacks PIC;
StandardInstrumentations SI(TheModule.getContext(), DebugPassManager);
SI.registerCallbacks(PIC, &FAM);
SI.registerCallbacks(PIC, &MAM);
PipelineTuningOptions PTO;
PTO.LoopVectorization = true;
PTO.SLPVectorization = true;
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Passes/PassBuilderBindings.cpp
Expand Up @@ -66,7 +66,7 @@ LLVMErrorRef LLVMRunPasses(LLVMModuleRef M, const char *Passes,
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

StandardInstrumentations SI(Mod->getContext(), Debug, VerifyEach);
SI.registerCallbacks(PIC, &FAM);
SI.registerCallbacks(PIC, &MAM);
ModulePassManager MPM;
if (VerifyEach) {
MPM.addPass(VerifierPass());
Expand Down
119 changes: 71 additions & 48 deletions llvm/lib/Passes/StandardInstrumentations.cpp
Expand Up @@ -1075,29 +1075,46 @@ bool PreservedCFGCheckerInstrumentation::CFG::invalidate(
PAC.preservedSet<CFGAnalyses>());
}

static SmallVector<Function *, 1> GetFunctions(Any IR) {
SmallVector<Function *, 1> Functions;

if (const auto **MaybeF = any_cast<const Function *>(&IR)) {
Functions.push_back(*const_cast<Function **>(MaybeF));
} else if (const auto **MaybeM = any_cast<const Module *>(&IR)) {
for (Function &F : **const_cast<Module **>(MaybeM))
Functions.push_back(&F);
}
return Functions;
}

void PreservedCFGCheckerInstrumentation::registerCallbacks(
PassInstrumentationCallbacks &PIC, FunctionAnalysisManager &FAM) {
PassInstrumentationCallbacks &PIC, ModuleAnalysisManager &MAM) {
if (!VerifyAnalysisInvalidation)
return;

FAM.registerPass([&] { return PreservedCFGCheckerAnalysis(); });
FAM.registerPass([&] { return PreservedFunctionHashAnalysis(); });

PIC.registerBeforeNonSkippedPassCallback(
[this, &FAM](StringRef P, Any IR) {
bool Registered = false;
PIC.registerBeforeNonSkippedPassCallback([this, &MAM, Registered](
StringRef P, Any IR) mutable {
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
assert(&PassStack.emplace_back(P));
assert(&PassStack.emplace_back(P));
#endif
(void)this;
const auto **F = any_cast<const Function *>(&IR);
if (!F)
return;
(void)this;

// Make sure a fresh CFG snapshot is available before the pass.
FAM.getResult<PreservedCFGCheckerAnalysis>(*const_cast<Function *>(*F));
FAM.getResult<PreservedFunctionHashAnalysis>(
*const_cast<Function *>(*F));
});
auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(
*const_cast<Module *>(unwrapModule(IR, /*Force=*/true)))
.getManager();
if (!Registered) {
FAM.registerPass([&] { return PreservedCFGCheckerAnalysis(); });
FAM.registerPass([&] { return PreservedFunctionHashAnalysis(); });
Registered = true;
}

for (Function *F : GetFunctions(IR)) {
// Make sure a fresh CFG snapshot is available before the pass.
FAM.getResult<PreservedCFGCheckerAnalysis>(*F);
FAM.getResult<PreservedFunctionHashAnalysis>(*F);
}
});

PIC.registerAfterPassInvalidatedCallback(
[this](StringRef P, const PreservedAnalyses &PassPA) {
Expand All @@ -1108,44 +1125,50 @@ void PreservedCFGCheckerInstrumentation::registerCallbacks(
(void)this;
});

PIC.registerAfterPassCallback([this, &FAM](StringRef P, Any IR,
PIC.registerAfterPassCallback([this, &MAM](StringRef P, Any IR,
const PreservedAnalyses &PassPA) {
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
assert(PassStack.pop_back_val() == P &&
"Before and After callbacks must correspond");
#endif
(void)this;

const auto **MaybeF = any_cast<const Function *>(&IR);
if (!MaybeF)
return;
Function &F = *const_cast<Function *>(*MaybeF);

if (auto *HashBefore =
FAM.getCachedResult<PreservedFunctionHashAnalysis>(F)) {
if (HashBefore->Hash != StructuralHash(F)) {
report_fatal_error(formatv(
"Function @{0} changed by {1} without invalidating analyses",
F.getName(), P));
// We have to get the FAM via the MAM, rather than directly use a passed in
// FAM because if MAM has not cached the FAM, it won't invalidate function
// analyses in FAM.
auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(
*const_cast<Module *>(unwrapModule(IR, /*Force=*/true)))
.getManager();

for (Function *F : GetFunctions(IR)) {
if (auto *HashBefore =
FAM.getCachedResult<PreservedFunctionHashAnalysis>(*F)) {
if (HashBefore->Hash != StructuralHash(*F)) {
report_fatal_error(formatv(
"Function @{0} changed by {1} without invalidating analyses",
F->getName(), P));
}
}
}

auto CheckCFG = [](StringRef Pass, StringRef FuncName,
const CFG &GraphBefore, const CFG &GraphAfter) {
if (GraphAfter == GraphBefore)
return;

dbgs() << "Error: " << Pass
<< " does not invalidate CFG analyses but CFG changes detected in "
"function @"
<< FuncName << ":\n";
CFG::printDiff(dbgs(), GraphBefore, GraphAfter);
report_fatal_error(Twine("CFG unexpectedly changed by ", Pass));
};

if (auto *GraphBefore = FAM.getCachedResult<PreservedCFGCheckerAnalysis>(F))
CheckCFG(P, F.getName(), *GraphBefore,
CFG(&F, /* TrackBBLifetime */ false));
auto CheckCFG = [](StringRef Pass, StringRef FuncName,
const CFG &GraphBefore, const CFG &GraphAfter) {
if (GraphAfter == GraphBefore)
return;

dbgs()
<< "Error: " << Pass
<< " does not invalidate CFG analyses but CFG changes detected in "
"function @"
<< FuncName << ":\n";
CFG::printDiff(dbgs(), GraphBefore, GraphAfter);
report_fatal_error(Twine("CFG unexpectedly changed by ", Pass));
};

if (auto *GraphBefore =
FAM.getCachedResult<PreservedCFGCheckerAnalysis>(*F))
CheckCFG(P, F->getName(), *GraphBefore,
CFG(F, /* TrackBBLifetime */ false));
}
});
}

Expand Down Expand Up @@ -2175,7 +2198,7 @@ void PrintCrashIRInstrumentation::registerCallbacks(
}

void StandardInstrumentations::registerCallbacks(
PassInstrumentationCallbacks &PIC, FunctionAnalysisManager *FAM) {
PassInstrumentationCallbacks &PIC, ModuleAnalysisManager *MAM) {
PrintIR.registerCallbacks(PIC);
PrintPass.registerCallbacks(PIC);
TimePasses.registerCallbacks(PIC);
Expand All @@ -2189,8 +2212,8 @@ void StandardInstrumentations::registerCallbacks(
WebsiteChangeReporter.registerCallbacks(PIC);
ChangeTester.registerCallbacks(PIC);
PrintCrashIR.registerCallbacks(PIC);
if (FAM)
PreservedCFGChecker.registerCallbacks(PIC, *FAM);
if (MAM)
PreservedCFGChecker.registerCallbacks(PIC, *MAM);

// TimeProfiling records the pass running time cost.
// Its 'BeforePassCallback' can be appended at the tail of all the
Expand Down
2 changes: 1 addition & 1 deletion llvm/tools/opt/NewPMDriver.cpp
Expand Up @@ -395,7 +395,7 @@ bool llvm::runPassPipeline(StringRef Arg0, Module &M, TargetMachine *TM,
PrintPassOpts.SkipAnalyses = DebugPM == DebugLogging::Quiet;
StandardInstrumentations SI(M.getContext(), DebugPM != DebugLogging::None,
VerifyEachPass, PrintPassOpts);
SI.registerCallbacks(PIC, &FAM);
SI.registerCallbacks(PIC, &MAM);
DebugifyEachInstrumentation Debugify;
DebugifyStatsMap DIStatsMap;
DebugInfoPerPass DebugInfoBeforePass;
Expand Down
57 changes: 51 additions & 6 deletions llvm/unittests/IR/PassManagerTest.cpp
Expand Up @@ -824,10 +824,13 @@ TEST_F(PassManagerTest, FunctionPassCFGChecker) {

auto *F = M->getFunction("foo");
FunctionAnalysisManager FAM;
ModuleAnalysisManager MAM;
FunctionPassManager FPM;
PassInstrumentationCallbacks PIC;
StandardInstrumentations SI(M->getContext(), /*DebugLogging*/ true);
SI.registerCallbacks(PIC, &FAM);
SI.registerCallbacks(PIC, &MAM);
MAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
FAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
FAM.registerPass([&] { return DominatorTreeAnalysis(); });
FAM.registerPass([&] { return AssumptionAnalysis(); });
Expand Down Expand Up @@ -870,10 +873,13 @@ TEST_F(PassManagerTest, FunctionPassCFGCheckerInvalidateAnalysis) {

auto *F = M->getFunction("foo");
FunctionAnalysisManager FAM;
ModuleAnalysisManager MAM;
FunctionPassManager FPM;
PassInstrumentationCallbacks PIC;
StandardInstrumentations SI(M->getContext(), /*DebugLogging*/ true);
SI.registerCallbacks(PIC, &FAM);
SI.registerCallbacks(PIC, &MAM);
MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
MAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
FAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
FAM.registerPass([&] { return DominatorTreeAnalysis(); });
FAM.registerPass([&] { return AssumptionAnalysis(); });
Expand Down Expand Up @@ -935,10 +941,13 @@ TEST_F(PassManagerTest, FunctionPassCFGCheckerWrapped) {

auto *F = M->getFunction("foo");
FunctionAnalysisManager FAM;
ModuleAnalysisManager MAM;
FunctionPassManager FPM;
PassInstrumentationCallbacks PIC;
StandardInstrumentations SI(M->getContext(), /*DebugLogging*/ true);
SI.registerCallbacks(PIC, &FAM);
SI.registerCallbacks(PIC, &MAM);
MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
MAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
FAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
FAM.registerPass([&] { return DominatorTreeAnalysis(); });
FAM.registerPass([&] { return AssumptionAnalysis(); });
Expand All @@ -961,17 +970,20 @@ struct WrongFunctionPass : PassInfoMixin<WrongFunctionPass> {
static StringRef name() { return "WrongFunctionPass"; }
};

TEST_F(PassManagerTest, FunctionAnalysisMissedInvalidation) {
TEST_F(PassManagerTest, FunctionPassMissedFunctionAnalysisInvalidation) {
LLVMContext Context;
auto M = parseIR(Context, "define void @foo() {\n"
" %a = add i32 0, 0\n"
" ret void\n"
"}\n");

FunctionAnalysisManager FAM;
ModuleAnalysisManager MAM;
PassInstrumentationCallbacks PIC;
StandardInstrumentations SI(M->getContext(), /*DebugLogging*/ false);
SI.registerCallbacks(PIC, &FAM);
SI.registerCallbacks(PIC, &MAM);
MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
MAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
FAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });

FunctionPassManager FPM;
Expand All @@ -981,6 +993,39 @@ TEST_F(PassManagerTest, FunctionAnalysisMissedInvalidation) {
EXPECT_DEATH(FPM.run(*F, FAM), "Function @foo changed by WrongFunctionPass without invalidating analyses");
}

#endif
struct WrongModulePass : PassInfoMixin<WrongModulePass> {
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM) {
for (Function &F : M)
F.getEntryBlock().begin()->eraseFromParent();

return PreservedAnalyses::all();
}
static StringRef name() { return "WrongModulePass"; }
};

TEST_F(PassManagerTest, ModulePassMissedFunctionAnalysisInvalidation) {
LLVMContext Context;
auto M = parseIR(Context, "define void @foo() {\n"
" %a = add i32 0, 0\n"
" ret void\n"
"}\n");

FunctionAnalysisManager FAM;
ModuleAnalysisManager MAM;
PassInstrumentationCallbacks PIC;
StandardInstrumentations SI(M->getContext(), /*DebugLogging*/ false);
SI.registerCallbacks(PIC, &MAM);
MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
MAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
FAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });

ModulePassManager MPM;
MPM.addPass(WrongModulePass());

EXPECT_DEATH(
MPM.run(*M, MAM),
"Function @foo changed by WrongModulePass without invalidating analyses");
}

#endif
}

0 comments on commit d6c0724

Please sign in to comment.