diff --git a/clang/test/Interpreter/recovery-after-failure.cpp b/clang/test/Interpreter/recovery-after-failure.cpp new file mode 100644 index 0000000000000..c9f1b52cbede9 --- /dev/null +++ b/clang/test/Interpreter/recovery-after-failure.cpp @@ -0,0 +1,13 @@ +// REQUIRES: host-supports-jit +// UNSUPPORTED: system-aix +// RUN: cat %s | clang-repl 2>&1 | FileCheck %s + +// Failed materialization shouldn't poison subsequent statements +extern "C" int undefined_function(); +int result = undefined_function(); +// CHECK: error: Failed to materialize symbols + +int x = 42; +// CHECK-NOT: error: Failed to materialize symbols + +%quit diff --git a/llvm/include/llvm/ExecutionEngine/Orc/Core.h b/llvm/include/llvm/ExecutionEngine/Orc/Core.h index 24a0cb74b1fbc..2f1a66e1af5f5 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/Core.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/Core.h @@ -1289,6 +1289,13 @@ class LLVM_ABI Platform { /// ResourceTracker is removed. virtual Error notifyRemoving(ResourceTracker &RT) = 0; + /// This method will be called when materialization fails for symbols managed + /// by the given MaterializationResponsibility. Platforms can override this to + /// clean up internal bookkeeping (e.g., init/deinit symbol tracking). + virtual Error notifyFailed(MaterializationResponsibility &MR) { + return Error::success(); + } + /// A utility function for looking up initializer symbols. Performs a blocking /// lookup for the given symbols in each of the given JITDylibs. /// diff --git a/llvm/lib/ExecutionEngine/Orc/Core.cpp b/llvm/lib/ExecutionEngine/Orc/Core.cpp index d029ac587fb9a..9f0dc1def9001 100644 --- a/llvm/lib/ExecutionEngine/Orc/Core.cpp +++ b/llvm/lib/ExecutionEngine/Orc/Core.cpp @@ -3162,6 +3162,10 @@ void ExecutionSession::OL_notifyFailed(MaterializationResponsibility &MR) { if (MR.SymbolFlags.empty()) return; + // Notify the platform to clean up any failed-symbol bookkeeping. + if (auto *Plat = getPlatform()) + cantFail(Plat->notifyFailed(MR)); + SymbolNameVector SymbolsToFail; for (auto &[Name, Flags] : MR.SymbolFlags) SymbolsToFail.push_back(Name); diff --git a/llvm/lib/ExecutionEngine/Orc/LLJIT.cpp b/llvm/lib/ExecutionEngine/Orc/LLJIT.cpp index 7487526c5d059..b2da46bc89264 100644 --- a/llvm/lib/ExecutionEngine/Orc/LLJIT.cpp +++ b/llvm/lib/ExecutionEngine/Orc/LLJIT.cpp @@ -99,6 +99,7 @@ class GenericLLVMIRPlatform : public Platform { // Noop -- Nothing to do (yet). return Error::success(); } + Error notifyFailed(MaterializationResponsibility &MR) override; private: GenericLLVMIRPlatformSupport &S; @@ -227,6 +228,44 @@ class GenericLLVMIRPlatformSupport : public LLJIT::PlatformSupport { return Error::success(); } + Error notifyFailed(MaterializationResponsibility &MR) { + auto &JD = MR.getTargetJITDylib(); + + // We only care about symbols with our known (init/deinit) prefixes + DenseSet FailedInitSyms; + DenseSet FailedDeInitSyms; + + for (auto &[Name, Flags] : MR.getSymbols()) { + if ((*Name).starts_with(InitFunctionPrefix)) + FailedInitSyms.insert(Name); + else if ((*Name).starts_with(DeInitFunctionPrefix)) + FailedDeInitSyms.insert(Name); + } + + // Remove failed symbols from tracking maps. + auto cleanMap = [&](auto &Map, + const DenseSet &FailedSyms) { + if (FailedSyms.empty()) + return; + + auto It = Map.find(&JD); + if (It == Map.end()) + return; + + It->second.remove_if([&](const SymbolStringPtr &Name, SymbolLookupFlags) { + return FailedSyms.contains(Name); + }); + + if (It->second.empty()) + Map.erase(It); + }; + + cleanMap(InitFunctions, FailedInitSyms); + cleanMap(DeInitFunctions, FailedDeInitSyms); + + return Error::success(); + } + Error initialize(JITDylib &JD) override { LLVM_DEBUG({ dbgs() << "GenericLLVMIRPlatformSupport getting initializers to run\n"; @@ -505,6 +544,10 @@ Error GenericLLVMIRPlatform::notifyAdding(ResourceTracker &RT, return S.notifyAdding(RT, MU); } +Error GenericLLVMIRPlatform::notifyFailed(MaterializationResponsibility &MR) { + return S.notifyFailed(MR); +} + Expected GlobalCtorDtorScraper::operator()(ThreadSafeModule TSM, MaterializationResponsibility &R) { diff --git a/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt b/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt index 7b563d7bcc68c..92bc935c3fd78 100644 --- a/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt +++ b/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt @@ -27,6 +27,7 @@ add_llvm_unittest(OrcJITTests JITTargetMachineBuilderTest.cpp LazyCallThroughAndReexportsTest.cpp LibraryResolverTest.cpp + LLJITTest.cpp LookupAndRecordAddrsTest.cpp MachOPlatformTest.cpp MapperJITLinkMemoryManagerTest.cpp diff --git a/llvm/unittests/ExecutionEngine/Orc/LLJITTest.cpp b/llvm/unittests/ExecutionEngine/Orc/LLJITTest.cpp new file mode 100644 index 0000000000000..a8432f2d0e2e6 --- /dev/null +++ b/llvm/unittests/ExecutionEngine/Orc/LLJITTest.cpp @@ -0,0 +1,105 @@ +#include "llvm/ExecutionEngine/Orc/LLJIT.h" +#include "OrcTestCommon.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Testing/Support/Error.h" + +using namespace llvm; +using namespace llvm::orc; + +namespace { + +static ThreadSafeModule parseModule(llvm::StringRef Source, + llvm::StringRef Name) { + auto Ctx = std::make_unique(); + SMDiagnostic Err; + auto M = parseIR(MemoryBufferRef(Source, Name), Err, *Ctx); + if (!M) { + Err.print("Testcase source failed to parse: ", errs()); + exit(1); + } + return ThreadSafeModule(std::move(M), std::move(Ctx)); +} + +TEST(LLJITTest, CleanupFailedInitializers) { + OrcNativeTarget::initialize(); + auto J = cantFail(LLJITBuilder().create()); + auto &JD = J->getMainJITDylib(); + + // ctor references undefined symbol 'testing' + auto TSM_A = parseModule(R"( + declare void @testing() + + define internal void @ctor_A() { + call void @testing() + ret void + } + + @llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] [ + { i32, ptr, ptr } { i32 65535, ptr @ctor_A, ptr null } + ] + )", + "A"); + + cantFail(J->addIRModule(std::move(TSM_A))); + + // Initialize fails: "Symbols not found: [ testing ]" + EXPECT_THAT_ERROR(J->initialize(JD), Failed()); + + // Clean module should succeed if A's bookkeeping was cleaned up + auto TSM_B = parseModule(R"( + @i = global i32 42 + )", + "B"); + + cantFail(J->addIRModule(std::move(TSM_B))); + + EXPECT_THAT_ERROR(J->initialize(JD), Succeeded()); +} + +TEST(LLJITTest, RepeatedInitializationFailures) { + // Consecutive failures don't accumulate stale state + OrcNativeTarget::initialize(); + auto J = cantFail(LLJITBuilder().create()); + auto &JD = J->getMainJITDylib(); + + // First failure + auto TSM_A = parseModule(R"( + declare void @undefined_a() + define internal void @ctor_A() { + call void @undefined_a() + ret void + } + @llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] [ + { i32, ptr, ptr } { i32 65535, ptr @ctor_A, ptr null } + ] + )", + "A"); + cantFail(J->addIRModule(std::move(TSM_A))); + EXPECT_THAT_ERROR(J->initialize(JD), Failed()); + + // Second failure + auto TSM_B = parseModule(R"( + declare void @undefined_b() + define internal void @ctor_B() { + call void @undefined_b() + ret void + } + @llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] [ + { i32, ptr, ptr } { i32 65535, ptr @ctor_B, ptr null } + ] + )", + "B"); + cantFail(J->addIRModule(std::move(TSM_B))); + EXPECT_THAT_ERROR(J->initialize(JD), Failed()); + + // Should succeed, both A and B cleaned up + auto TSM_C = parseModule(R"( + @x = global i32 0 + )", + "C"); + cantFail(J->addIRModule(std::move(TSM_C))); + EXPECT_THAT_ERROR(J->initialize(JD), Succeeded()); +} + +} // anonymous namespace