Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 18 additions & 16 deletions Source/Task/TaskQueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ HRESULT TaskQueuePortImpl::Initialize(
RETURN_IF_FAILED(m_timer.Initialize(this, [](void* context)
{
TaskQueuePortImpl* pthis = static_cast<TaskQueuePortImpl*>(context);
pthis->SubmitPendingCallback();
pthis->SubmitPendingCallbacks();
}));

#ifdef _WIN32
Expand Down Expand Up @@ -1000,15 +1000,13 @@ void TaskQueuePortImpl::CancelPendingEntries(
}
}

#ifdef HC_UNITTEST_API
// Test hook: let unit tests enqueue a sibling delayed callback while this
// termination path still owns the interleaving window that used to race
// with SubmitPendingCallback().
// with SubmitPendingCallbacks().
if (auto hooks = portContext->GetQueue()->GetTestHooks(); hooks != nullptr)
{
hooks->PendingEntriesRemovedDuringTermination(portContext->GetType());
}
#endif

#ifdef _WIN32

Expand Down Expand Up @@ -1169,7 +1167,6 @@ void TaskQueuePortImpl::PromoteReadyPendingCallbacks(
// No future entries remain in the pending list.
uint64_t noDueTime = UINT64_MAX;

#ifdef HC_UNITTEST_API
m_attachedContexts.Visit([&](ITaskQueuePortContext* portContext)
{
auto hooks = portContext->GetQueue()->GetTestHooks();
Expand All @@ -1180,22 +1177,20 @@ void TaskQueuePortImpl::PromoteReadyPendingCallbacks(
dueTime);
}
});
#endif

if (m_timerDue.compare_exchange_strong(dueTime, noDueTime))
{
// Bug fix: ScheduleNextPendingCallback timer race results
// in lost delayed task wakes. Don't cancel the timer here
// as another scheduled callback could have been added.
// The CAS above is sufficient: the timer has already fired
// (call site 1: SubmitPendingCallback) or was already
// (call site 1: SubmitPendingCallbacks) or was already
// canceled (call site 2: CancelPendingEntries). A Cancel()
// here raced with concurrent QueueItem/Start calls on other
// threads, permanently stranding entries in m_pendingList.
// See VerifyDelayedCallbackTimerRaceOnManualQueue for full
// analysis. The test hook here allows unit tests to verify
// there is no race.
#ifdef HC_UNITTEST_API
m_attachedContexts.Visit([&](ITaskQueuePortContext* portContext)
{
auto hooks = portContext->GetQueue()->GetTestHooks();
Expand All @@ -1206,7 +1201,6 @@ void TaskQueuePortImpl::PromoteReadyPendingCallbacks(
noDueTime);
}
});
#endif

// A concurrent QueueItem can append a future entry after our
// sweep has already concluded there is no next item, but before
Expand All @@ -1226,7 +1220,7 @@ void TaskQueuePortImpl::PromoteReadyPendingCallbacks(
}
}

void TaskQueuePortImpl::SubmitPendingCallback()
void TaskQueuePortImpl::SubmitPendingCallbacks()
{
while (true)
{
Expand Down Expand Up @@ -1261,7 +1255,14 @@ void TaskQueuePortImpl::SubmitPendingCallback()
if (m_timerDue.compare_exchange_weak(expectedDueTime, dueTime))
{
m_timer.Start(dueTime);
return;

// It's possible someone snuck a change into m_timerDue after the CAS
// but before the start call, so we've just written the wrong value to
// the timer. Verify dueTime again before returning.
if (m_timerDue.load() == dueTime)
{
return;
}
}

continue;
Expand Down Expand Up @@ -2464,7 +2465,6 @@ STDAPI_(bool) XTaskQueueUninitialize(
return ApiRefs::WaitZeroRefs(timeoutMilliseconds);
}

#ifdef HC_UNITTEST_API
/// <summary>
/// Sets or clears test hooks on a task queue.
/// </summary>
Expand All @@ -2479,7 +2479,11 @@ STDAPI XTaskQueueSetTestHooks(
return S_OK;
}

STDAPI XTaskQueueSubmitPendingCallbackForTests(
/// <summary>
/// Submits any pending delayed callbacks that are due to run. This is
/// intended for use in unit tests.
/// </summary>
STDAPI XTaskQueueSubmitPendingCallbacks(
_In_ XTaskQueueHandle queue,
_In_ XTaskQueuePort port
) noexcept
Expand All @@ -2490,9 +2494,7 @@ STDAPI XTaskQueueSubmitPendingCallbackForTests(
referenced_ptr<ITaskQueuePortContext> portContext;
RETURN_IF_FAILED(aq->GetPortContext(port, portContext.address_of()));

auto* portImpl = static_cast<TaskQueuePortImpl*>(portContext->GetPort());
portImpl->SubmitPendingCallbackForTests();
portContext->GetPort()->SubmitPendingCallbacks();
return S_OK;
}
#endif

13 changes: 1 addition & 12 deletions Source/Task/TaskQueueImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,12 +215,7 @@ class TaskQueuePortImpl: public Api<ApiId::TaskQueuePort, ITaskQueuePort>
void __stdcall SuspendPort();
void __stdcall ResumePort();

#ifdef HC_UNITTEST_API
void __stdcall SubmitPendingCallbackForTests()
{
SubmitPendingCallback();
}
#endif
void __stdcall SubmitPendingCallbacks();

private:

Expand Down Expand Up @@ -315,8 +310,6 @@ class TaskQueuePortImpl: public Api<ApiId::TaskQueuePort, ITaskQueuePort>
_In_ uint64_t dueTime,
_In_ uint64_t now);

void SubmitPendingCallback();

void SignalTerminations();
void ScheduleTermination(_In_ TerminationEntry* term);
bool TerminationListEmpty();
Expand Down Expand Up @@ -408,10 +401,8 @@ class TaskQueueImpl : public Api<ApiId::TaskQueue, ITaskQueue>
_In_ XTaskQueuePortHandle completionPort);

XTaskQueueHandle __stdcall GetHandle() override { return &m_header; }
#ifdef HC_UNITTEST_API
XTaskQueueTestHooks* __stdcall GetTestHooks() override { return m_testHooks; }
void __stdcall SetTestHooks(_In_ XTaskQueueTestHooks* testHooks) override { m_testHooks = testHooks; }
#endif

HRESULT __stdcall GetPortContext(
_In_ XTaskQueuePort port,
Expand Down Expand Up @@ -483,9 +474,7 @@ class TaskQueueImpl : public Api<ApiId::TaskQueue, ITaskQueue>
TerminationData m_termination;
TaskQueuePortContextImpl m_work;
TaskQueuePortContextImpl m_completion;
#ifdef HC_UNITTEST_API
XTaskQueueTestHooks* m_testHooks = nullptr;
#endif

#ifdef SUSPEND_API
SuspendResumeHandler m_suspendHandler;
Expand Down
3 changes: 1 addition & 2 deletions Source/Task/TaskQueueP.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ struct ITaskQueuePort: IApi
virtual void __stdcall SuspendPort() = 0;
virtual void __stdcall ResumePort() = 0;

virtual void __stdcall SubmitPendingCallbacks() = 0;
};

// The status of a port on the queue. This status is used in
Expand Down Expand Up @@ -125,10 +126,8 @@ struct ITaskQueuePortContext : IApi
struct ITaskQueue : IApi
{
virtual XTaskQueueHandle __stdcall GetHandle() = 0;
#ifdef HC_UNITTEST_API
virtual XTaskQueueTestHooks* __stdcall GetTestHooks() = 0;
virtual void __stdcall SetTestHooks(_In_ XTaskQueueTestHooks* testHooks) = 0;
#endif

virtual HRESULT __stdcall GetPortContext(
_In_ XTaskQueuePort port,
Expand Down
36 changes: 35 additions & 1 deletion Source/Task/ThreadPool_win32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,37 @@

namespace OS
{
// This API is documented but not defined in the headers. Load it dynamcially
// so the world doesn't have to add linkage to ntdll. Ntdll is not unloadable,
// so safe to leak the module ref here.
static inline BOOLEAN __stdcall RtlDllShutdownInProgress() noexcept
{
static decltype(RtlDllShutdownInProgress)* s_pfnRtlDllShutdownInProgress = nullptr;
static HMODULE s_ntdllModuleHandle = nullptr;

// No locking needed -- if these race the threads race to copy
// the same values, and worst case is we get an addl ref on
// a dll we'll never unload anyway. GetModuleHandle is not defined
// for UWP apps, so don't do this safety check for them.

#ifdef GetModuleHandle
if (s_pfnRtlDllShutdownInProgress == nullptr)
{
if (s_ntdllModuleHandle == nullptr)
{
s_ntdllModuleHandle = GetModuleHandleW(L"ntdll.dll");
}

if (s_ntdllModuleHandle != nullptr)
{
s_pfnRtlDllShutdownInProgress = reinterpret_cast<decltype(RtlDllShutdownInProgress)*>(GetProcAddress(s_ntdllModuleHandle, "RtlDllShutdownInProgress"));
}
}
#endif

return s_pfnRtlDllShutdownInProgress ? s_pfnRtlDllShutdownInProgress() : FALSE;
}

class ThreadPoolImpl
{
public:
Expand Down Expand Up @@ -52,7 +83,10 @@ namespace OS

void Submit() noexcept
{
SubmitThreadpoolWork(m_work);
if (!RtlDllShutdownInProgress())
{
SubmitThreadpoolWork(m_work);
}
}

private:
Expand Down
9 changes: 3 additions & 6 deletions Source/Task/XTaskQueuePriv.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ STDAPI_(void) XTaskQueueResumeTermination(
_In_ XTaskQueueHandle queue
) noexcept;

#ifdef HC_UNITTEST_API
/// <summary>
/// This structure can be passed as a pointer to the task queue so unit tests
/// can hook into its behavior. Some race conditions are very difficult to get
Expand Down Expand Up @@ -84,15 +83,13 @@ STDAPI XTaskQueueSetTestHooks(
) noexcept;

/// <summary>
/// Directly invokes the delayed-callback notification path for unit tests.
/// This is used to model stale threadpool timer callbacks that were already
/// queued before the timer was retargeted.
/// Submits any pending delayed callbacks that are due to run. This is
/// intended for use in unit tests.
/// </summary>
STDAPI XTaskQueueSubmitPendingCallbackForTests(
STDAPI XTaskQueueSubmitPendingCallbacks(
_In_ XTaskQueueHandle queue,
_In_ XTaskQueuePort port
) noexcept;
#endif

//----------------------------------------------------------------//
//
Expand Down
2 changes: 1 addition & 1 deletion Tests/UnitTests/Tests/TaskQueueTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2390,7 +2390,7 @@ DEFINE_TEST_CLASS(TaskQueueTests)
// Simulate a stale delayed-callback notification that was already
// queued before the timer was re-armed for secondState. This must not
// promote the later pending entry before its own deadline.
VERIFY_SUCCEEDED(XTaskQueueSubmitPendingCallbackForTests(queue, XTaskQueuePort::Work));
VERIFY_SUCCEEDED(XTaskQueueSubmitPendingCallbacks(queue, XTaskQueuePort::Work));

VERIFY_IS_FALSE(XTaskQueueDispatch(queue, XTaskQueuePort::Work, 0));
VERIFY_IS_FALSE(XTaskQueueDispatch(queue, XTaskQueuePort::Work, 200));
Expand Down