From 9dcd4bb53bbb8956d7d4570a0479822d62394c08 Mon Sep 17 00:00:00 2001 From: Brian Pepin Date: Fri, 15 May 2026 15:12:45 -0700 Subject: [PATCH 1/2] Cleanup --- Source/Task/TaskQueue.cpp | 25 +++++++--------- Source/Task/TaskQueueImpl.h | 13 +-------- Source/Task/TaskQueueP.h | 3 +- Source/Task/ThreadPool_win32.cpp | 36 +++++++++++++++++++++++- Source/Task/XTaskQueuePriv.h | 9 ++---- Tests/UnitTests/Tests/TaskQueueTests.cpp | 2 +- 6 files changed, 51 insertions(+), 37 deletions(-) diff --git a/Source/Task/TaskQueue.cpp b/Source/Task/TaskQueue.cpp index cc6b1e3e..83a3eb79 100644 --- a/Source/Task/TaskQueue.cpp +++ b/Source/Task/TaskQueue.cpp @@ -306,7 +306,7 @@ HRESULT TaskQueuePortImpl::Initialize( RETURN_IF_FAILED(m_timer.Initialize(this, [](void* context) { TaskQueuePortImpl* pthis = static_cast(context); - pthis->SubmitPendingCallback(); + pthis->SubmitPendingCallbacks(); })); #ifdef _WIN32 @@ -1000,15 +1000,13 @@ void TaskQueuePortImpl::CancelPendingEntries( } } -#ifdef HC_UNITTEST_API // Test hook: let unit tests enqueue a sibling delayed callback while this // termination path still owns the interleaving window that used to race - // with SubmitPendingCallback(). + // with SubmitPendingCallbacks(). if (auto hooks = portContext->GetQueue()->GetTestHooks(); hooks != nullptr) { hooks->PendingEntriesRemovedDuringTermination(portContext->GetType()); } -#endif #ifdef _WIN32 @@ -1169,7 +1167,6 @@ void TaskQueuePortImpl::PromoteReadyPendingCallbacks( // No future entries remain in the pending list. uint64_t noDueTime = UINT64_MAX; -#ifdef HC_UNITTEST_API m_attachedContexts.Visit([&](ITaskQueuePortContext* portContext) { auto hooks = portContext->GetQueue()->GetTestHooks(); @@ -1180,7 +1177,6 @@ void TaskQueuePortImpl::PromoteReadyPendingCallbacks( dueTime); } }); -#endif if (m_timerDue.compare_exchange_strong(dueTime, noDueTime)) { @@ -1188,14 +1184,13 @@ void TaskQueuePortImpl::PromoteReadyPendingCallbacks( // in lost delayed task wakes. Don't cancel the timer here // as another scheduled callback could have been added. // The CAS above is sufficient: the timer has already fired - // (call site 1: SubmitPendingCallback) or was already + // (call site 1: SubmitPendingCallbacks) or was already // canceled (call site 2: CancelPendingEntries). A Cancel() // here raced with concurrent QueueItem/Start calls on other // threads, permanently stranding entries in m_pendingList. // See VerifyDelayedCallbackTimerRaceOnManualQueue for full // analysis. The test hook here allows unit tests to verify // there is no race. -#ifdef HC_UNITTEST_API m_attachedContexts.Visit([&](ITaskQueuePortContext* portContext) { auto hooks = portContext->GetQueue()->GetTestHooks(); @@ -1206,7 +1201,6 @@ void TaskQueuePortImpl::PromoteReadyPendingCallbacks( noDueTime); } }); -#endif // A concurrent QueueItem can append a future entry after our // sweep has already concluded there is no next item, but before @@ -1226,7 +1220,7 @@ void TaskQueuePortImpl::PromoteReadyPendingCallbacks( } } -void TaskQueuePortImpl::SubmitPendingCallback() +void TaskQueuePortImpl::SubmitPendingCallbacks() { while (true) { @@ -2464,7 +2458,6 @@ STDAPI_(bool) XTaskQueueUninitialize( return ApiRefs::WaitZeroRefs(timeoutMilliseconds); } -#ifdef HC_UNITTEST_API /// /// Sets or clears test hooks on a task queue. /// @@ -2479,7 +2472,11 @@ STDAPI XTaskQueueSetTestHooks( return S_OK; } -STDAPI XTaskQueueSubmitPendingCallbackForTests( +/// +/// Submits any pending delayed callbacks that are due to run. This is +/// intended for use in unit tests. +/// +STDAPI XTaskQueueSubmitPendingCallbacks( _In_ XTaskQueueHandle queue, _In_ XTaskQueuePort port ) noexcept @@ -2490,9 +2487,7 @@ STDAPI XTaskQueueSubmitPendingCallbackForTests( referenced_ptr portContext; RETURN_IF_FAILED(aq->GetPortContext(port, portContext.address_of())); - auto* portImpl = static_cast(portContext->GetPort()); - portImpl->SubmitPendingCallbackForTests(); + portContext->GetPort()->SubmitPendingCallbacks(); return S_OK; } -#endif diff --git a/Source/Task/TaskQueueImpl.h b/Source/Task/TaskQueueImpl.h index 9dd18f0b..4b65fcb2 100644 --- a/Source/Task/TaskQueueImpl.h +++ b/Source/Task/TaskQueueImpl.h @@ -215,12 +215,7 @@ class TaskQueuePortImpl: public Api void __stdcall SuspendPort(); void __stdcall ResumePort(); -#ifdef HC_UNITTEST_API - void __stdcall SubmitPendingCallbackForTests() - { - SubmitPendingCallback(); - } -#endif + void __stdcall SubmitPendingCallbacks(); private: @@ -315,8 +310,6 @@ class TaskQueuePortImpl: public Api _In_ uint64_t dueTime, _In_ uint64_t now); - void SubmitPendingCallback(); - void SignalTerminations(); void ScheduleTermination(_In_ TerminationEntry* term); bool TerminationListEmpty(); @@ -408,10 +401,8 @@ class TaskQueueImpl : public Api _In_ XTaskQueuePortHandle completionPort); XTaskQueueHandle __stdcall GetHandle() override { return &m_header; } -#ifdef HC_UNITTEST_API XTaskQueueTestHooks* __stdcall GetTestHooks() override { return m_testHooks; } void __stdcall SetTestHooks(_In_ XTaskQueueTestHooks* testHooks) override { m_testHooks = testHooks; } -#endif HRESULT __stdcall GetPortContext( _In_ XTaskQueuePort port, @@ -483,9 +474,7 @@ class TaskQueueImpl : public Api TerminationData m_termination; TaskQueuePortContextImpl m_work; TaskQueuePortContextImpl m_completion; -#ifdef HC_UNITTEST_API XTaskQueueTestHooks* m_testHooks = nullptr; -#endif #ifdef SUSPEND_API SuspendResumeHandler m_suspendHandler; diff --git a/Source/Task/TaskQueueP.h b/Source/Task/TaskQueueP.h index 0f70d3c7..b34be4d4 100644 --- a/Source/Task/TaskQueueP.h +++ b/Source/Task/TaskQueueP.h @@ -83,6 +83,7 @@ struct ITaskQueuePort: IApi virtual void __stdcall SuspendPort() = 0; virtual void __stdcall ResumePort() = 0; + virtual void __stdcall SubmitPendingCallbacks() = 0; }; // The status of a port on the queue. This status is used in @@ -125,10 +126,8 @@ struct ITaskQueuePortContext : IApi struct ITaskQueue : IApi { virtual XTaskQueueHandle __stdcall GetHandle() = 0; -#ifdef HC_UNITTEST_API virtual XTaskQueueTestHooks* __stdcall GetTestHooks() = 0; virtual void __stdcall SetTestHooks(_In_ XTaskQueueTestHooks* testHooks) = 0; -#endif virtual HRESULT __stdcall GetPortContext( _In_ XTaskQueuePort port, diff --git a/Source/Task/ThreadPool_win32.cpp b/Source/Task/ThreadPool_win32.cpp index 082f1a1a..c8506ddb 100644 --- a/Source/Task/ThreadPool_win32.cpp +++ b/Source/Task/ThreadPool_win32.cpp @@ -3,6 +3,37 @@ namespace OS { + // This API is documented but not defined in the headers. Load it dynamcially + // so the world doesn't have to add linkage to ntdll. Ntdll is not unloadable, + // so safe to leak the module ref here. + static inline BOOLEAN __stdcall RtlDllShutdownInProgress() noexcept + { + static decltype(RtlDllShutdownInProgress)* s_pfnRtlDllShutdownInProgress = nullptr; + static HMODULE s_ntdllModuleHandle = nullptr; + + // No locking needed -- if these race the threads race to copy + // the same values, and worst case is we get an addl ref on + // a dll we'll never unload anyway. GetModuleHandle is not defined + // for UWP apps, so don't do this safety check for them. + +#ifdef GetModuleHandle + if (s_pfnRtlDllShutdownInProgress == nullptr) + { + if (s_ntdllModuleHandle == nullptr) + { + s_ntdllModuleHandle = GetModuleHandleW(L"ntdll.dll"); + } + + if (s_ntdllModuleHandle != nullptr) + { + s_pfnRtlDllShutdownInProgress = reinterpret_cast(GetProcAddress(s_ntdllModuleHandle, "RtlDllShutdownInProgress")); + } + } +#endif + + return s_pfnRtlDllShutdownInProgress ? s_pfnRtlDllShutdownInProgress() : FALSE; + } + class ThreadPoolImpl { public: @@ -52,7 +83,10 @@ namespace OS void Submit() noexcept { - SubmitThreadpoolWork(m_work); + if (!RtlDllShutdownInProgress()) + { + SubmitThreadpoolWork(m_work); + } } private: diff --git a/Source/Task/XTaskQueuePriv.h b/Source/Task/XTaskQueuePriv.h index ae9833f9..5e1d90e8 100644 --- a/Source/Task/XTaskQueuePriv.h +++ b/Source/Task/XTaskQueuePriv.h @@ -40,7 +40,6 @@ STDAPI_(void) XTaskQueueResumeTermination( _In_ XTaskQueueHandle queue ) noexcept; -#ifdef HC_UNITTEST_API /// /// This structure can be passed as a pointer to the task queue so unit tests /// can hook into its behavior. Some race conditions are very difficult to get @@ -84,15 +83,13 @@ STDAPI XTaskQueueSetTestHooks( ) noexcept; /// -/// Directly invokes the delayed-callback notification path for unit tests. -/// This is used to model stale threadpool timer callbacks that were already -/// queued before the timer was retargeted. +/// Submits any pending delayed callbacks that are due to run. This is +/// intended for use in unit tests. /// -STDAPI XTaskQueueSubmitPendingCallbackForTests( +STDAPI XTaskQueueSubmitPendingCallbacks( _In_ XTaskQueueHandle queue, _In_ XTaskQueuePort port ) noexcept; -#endif //----------------------------------------------------------------// // diff --git a/Tests/UnitTests/Tests/TaskQueueTests.cpp b/Tests/UnitTests/Tests/TaskQueueTests.cpp index 970fec36..443d8986 100644 --- a/Tests/UnitTests/Tests/TaskQueueTests.cpp +++ b/Tests/UnitTests/Tests/TaskQueueTests.cpp @@ -2390,7 +2390,7 @@ DEFINE_TEST_CLASS(TaskQueueTests) // Simulate a stale delayed-callback notification that was already // queued before the timer was re-armed for secondState. This must not // promote the later pending entry before its own deadline. - VERIFY_SUCCEEDED(XTaskQueueSubmitPendingCallbackForTests(queue, XTaskQueuePort::Work)); + VERIFY_SUCCEEDED(XTaskQueueSubmitPendingCallbacks(queue, XTaskQueuePort::Work)); VERIFY_IS_FALSE(XTaskQueueDispatch(queue, XTaskQueuePort::Work, 0)); VERIFY_IS_FALSE(XTaskQueueDispatch(queue, XTaskQueuePort::Work, 200)); From e513a71e373e8960139b030162724375b9485212 Mon Sep 17 00:00:00 2001 From: Brian Pepin Date: Tue, 19 May 2026 10:14:34 -0700 Subject: [PATCH 2/2] PR feedback --- Source/Task/TaskQueue.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/Source/Task/TaskQueue.cpp b/Source/Task/TaskQueue.cpp index 83a3eb79..da68e4f1 100644 --- a/Source/Task/TaskQueue.cpp +++ b/Source/Task/TaskQueue.cpp @@ -1255,7 +1255,14 @@ void TaskQueuePortImpl::SubmitPendingCallbacks() if (m_timerDue.compare_exchange_weak(expectedDueTime, dueTime)) { m_timer.Start(dueTime); - return; + + // It's possible someone snuck a change into m_timerDue after the CAS + // but before the start call, so we've just written the wrong value to + // the timer. Verify dueTime again before returning. + if (m_timerDue.load() == dueTime) + { + return; + } } continue;