From 6b75f74519c99dc4d605ca740ee88a0ef351eac4 Mon Sep 17 00:00:00 2001 From: "Bernhart, Bryan" Date: Fri, 16 Dec 2022 15:14:59 -0800 Subject: [PATCH] Process async tasks using a reusable thread pool. --- src/gpgmm/common/BUILD.gn | 4 +- src/gpgmm/common/CMakeLists.txt | 4 +- src/gpgmm/common/MemoryAllocator.cpp | 9 +- src/gpgmm/common/MemoryAllocator.h | 3 +- src/gpgmm/common/ThreadPool.cpp | 250 ++++++++++++++++++++++ src/gpgmm/common/ThreadPool.h | 108 ++++++++++ src/gpgmm/common/WorkerThread.cpp | 99 --------- src/gpgmm/common/WorkerThread.h | 90 -------- src/gpgmm/d3d12/ResidencyManagerD3D12.cpp | 15 +- src/gpgmm/d3d12/ResidencyManagerD3D12.h | 3 +- src/tests/BUILD.gn | 1 + src/tests/GPGMMTest.cpp | 1 + src/tests/GPGMMTest.h | 3 + src/tests/unittests/ThreadPoolTests.cpp | 73 +++++++ 14 files changed, 452 insertions(+), 211 deletions(-) create mode 100644 src/gpgmm/common/ThreadPool.cpp create mode 100644 src/gpgmm/common/ThreadPool.h delete mode 100644 src/gpgmm/common/WorkerThread.cpp delete mode 100644 src/gpgmm/common/WorkerThread.h create mode 100644 src/tests/unittests/ThreadPoolTests.cpp diff --git a/src/gpgmm/common/BUILD.gn b/src/gpgmm/common/BUILD.gn index 2566a7c83..d64206b2e 100644 --- a/src/gpgmm/common/BUILD.gn +++ b/src/gpgmm/common/BUILD.gn @@ -213,9 +213,9 @@ source_set("gpgmm_common_sources") { "SlabBlockAllocator.h", "SlabMemoryAllocator.cpp", "SlabMemoryAllocator.h", + "ThreadPool.cpp", + "ThreadPool.h", "TraceEvent.cpp", "TraceEvent.h", - "WorkerThread.cpp", - "WorkerThread.h", ] } diff --git a/src/gpgmm/common/CMakeLists.txt b/src/gpgmm/common/CMakeLists.txt index 960d5774d..e7b6d64f1 100644 --- a/src/gpgmm/common/CMakeLists.txt +++ b/src/gpgmm/common/CMakeLists.txt @@ -52,10 +52,10 @@ target_sources(gpgmm_common PRIVATE "SlabBlockAllocator.h" "SlabMemoryAllocator.cpp" "SlabMemoryAllocator.h" + "ThreadPool.cpp" + "ThreadPool.h" "TraceEvent.cpp" "TraceEvent.h" - "WorkerThread.cpp" - "WorkerThread.h" ) target_link_libraries(gpgmm_common PRIVATE gpgmm_common_config) diff --git a/src/gpgmm/common/MemoryAllocator.cpp b/src/gpgmm/common/MemoryAllocator.cpp index f8c11f6af..cd06fb355 100644 --- a/src/gpgmm/common/MemoryAllocator.cpp +++ b/src/gpgmm/common/MemoryAllocator.cpp @@ -20,8 +20,6 @@ namespace gpgmm { - static constexpr const char* kPrefetchMemoryWorkerThreadName = "GPGMM_ThreadBudgetChangeWorker"; - class AllocateMemoryTask : public VoidCallback { public: AllocateMemoryTask(MemoryAllocator* allocator, const MemoryAllocationRequest& request) @@ -86,11 +84,10 @@ namespace gpgmm { // MemoryAllocator - MemoryAllocator::MemoryAllocator() : mThreadPool(ThreadPool::Create()) { + MemoryAllocator::MemoryAllocator() { } - MemoryAllocator::MemoryAllocator(std::unique_ptr next) - : mThreadPool(ThreadPool::Create()) { + MemoryAllocator::MemoryAllocator(std::unique_ptr next) { InsertIntoChain(std::move(next)); } @@ -126,7 +123,7 @@ namespace gpgmm { std::shared_ptr task = std::make_shared(this, request); return std::make_shared( - ThreadPool::PostTask(mThreadPool, task, kPrefetchMemoryWorkerThreadName), task); + TaskScheduler::GetOrCreateInstance()->PostTask(task), task); } uint64_t MemoryAllocator::ReleaseMemory(uint64_t bytesToRelease) { diff --git a/src/gpgmm/common/MemoryAllocator.h b/src/gpgmm/common/MemoryAllocator.h index 0aac145e8..fec2ccfa8 100644 --- a/src/gpgmm/common/MemoryAllocator.h +++ b/src/gpgmm/common/MemoryAllocator.h @@ -20,7 +20,7 @@ #include "gpgmm/common/Error.h" #include "gpgmm/common/Memory.h" #include "gpgmm/common/MemoryAllocation.h" -#include "gpgmm/common/WorkerThread.h" +#include "gpgmm/common/ThreadPool.h" #include "gpgmm/utils/Assert.h" #include "gpgmm/utils/Limits.h" #include "gpgmm/utils/LinkedList.h" @@ -282,7 +282,6 @@ namespace gpgmm { MemoryAllocatorStats mStats = {}; mutable std::mutex mMutex; - std::shared_ptr mThreadPool; private: MemoryAllocator* mNext = nullptr; diff --git a/src/gpgmm/common/ThreadPool.cpp b/src/gpgmm/common/ThreadPool.cpp new file mode 100644 index 000000000..a1cbf525f --- /dev/null +++ b/src/gpgmm/common/ThreadPool.cpp @@ -0,0 +1,250 @@ +// Copyright 2022 The GPGMM Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gpgmm/common/ThreadPool.h" + +#include "gpgmm/common/TraceEvent.h" +#include "gpgmm/utils/Assert.h" +#include "gpgmm/utils/PlatformUtils.h" +#include "gpgmm/utils/Utils.h" + +#include +#include +#include +#include +#include + +namespace gpgmm { + + static constexpr const char* kBackgroundThreadName = "GPGMM Background Thread"; + + // Limit minimum running threads: one for allocation and another for budgeting. + static constexpr uint32_t kMinThreadCount = 2u; + + using AsyncTask = std::pair, std::shared_ptr>; + + class AsyncEventImpl final : public Event { + public: + AsyncEventImpl() = default; + + void Wait() override { + TRACE_EVENT0(TraceEventCategory::kDefault, "Event.Wait"); + + std::unique_lock lock(mMutex); + mCondition.wait(lock, [this] { return mIsSignaled; }); + } + + bool IsSignaled() override { + std::unique_lock lock(mMutex); + return mIsSignaled; + } + + void Signal() override { + { + std::unique_lock lock(mMutex); + mIsSignaled = true; + } + mCondition.notify_all(); + } + + private: + std::mutex mMutex; + std::condition_variable mCondition; + bool mIsSignaled = false; + }; + + class AsyncTaskThreadPoolImpl final : public ThreadPool { + public: + AsyncTaskThreadPoolImpl(uint32_t minThreadCount, uint32_t maxThreadCount) { + ASSERT(minThreadCount <= maxThreadCount); + + // Create the storage upfront so Resize can't modify the storage in-use by an existing + // std::thread. + mThreads.reserve(maxThreadCount); + + Resize(minThreadCount); + } + + ~AsyncTaskThreadPoolImpl() override { + Shutdown(); + } + + uint32_t GetCurrentThreadCount() const override { + return mThreads.size(); + } + + uint32_t GetMaxThreadCount() const override { + return mThreads.capacity(); + } + + void Resize(uint32_t threadCount) override { + const uint32_t numThreadsToCreate = (GetCurrentThreadCount() < threadCount) + ? threadCount - GetCurrentThreadCount() + : 0u; + + if (numThreadsToCreate + GetCurrentThreadCount() > GetMaxThreadCount()) { + return; + } + + for (uint32_t threadIndex = 0; threadIndex < numThreadsToCreate; ++threadIndex) { + // Concat the assigned thread index to the name for debugging. + std::string threadNameWithIndex(kBackgroundThreadName); + threadNameWithIndex += " "; + threadNameWithIndex += std::to_string(mThreadIndex++); + + mThreads.push_back(std::thread([this, threadNameWithIndex]() { + SetThreadName(threadNameWithIndex.c_str()); + TRACE_EVENT_METADATA1(TraceEventCategory::kMetadata, "thread_name", "name", + threadNameWithIndex.c_str()); + RunExecutionLoop(); + })); + } + } + + bool HasTasksToExecute() override { + bool hasTasksToExecute; + { + std::unique_lock lock(mQueueMutex); + hasTasksToExecute = !mTaskQueue.empty(); + } + return hasTasksToExecute; + } + + void Shutdown() override { + if (mThreads.size() == 0) { + return; + } + + // Inform thread to terminate after it finishes the current job, if any in progress. + { + std::unique_lock lock(mQueueMutex); + mStopQueueProcessingTasks = true; + } + + // Wait for the threads to terminate. + mQueueCondition.notify_all(); + for (std::thread& thread : mThreads) { + thread.join(); + } + + mThreads.clear(); + } + + private: + std::shared_ptr postTaskImpl(std::shared_ptr callback) override { + std::shared_ptr event = std::make_shared(); + { + std::unique_lock lock(mQueueMutex); + mTaskQueue.push(std::make_pair(callback, event)); + } + mQueueCondition.notify_one(); + return event; + } + + void RunExecutionLoop() { + for (;;) { + AsyncTask task; + { + std::unique_lock lock(mQueueMutex); + mQueueCondition.wait( + lock, [this] { return !mTaskQueue.empty() || mStopQueueProcessingTasks; }); + if (mStopQueueProcessingTasks) { + return; + } + task = mTaskQueue.front(); + mTaskQueue.pop(); + } + (*task.first)(); // Execute + task.second->Signal(); + } + } + + uint32_t mThreadIndex = 0; + std::vector mThreads; + + std::mutex mQueueMutex; // Protects access for below members. + std::condition_variable mQueueCondition; // Allow threads to wait on new tasks. + std::queue mTaskQueue; + bool mStopQueueProcessingTasks = false; + }; + + // Event + + void Event::SetThreadPool(std::shared_ptr pool) { + mPool = pool; + } + + // ThreadPool + + // static + std::shared_ptr ThreadPool::Create(uint32_t minThreadCount, + uint32_t maxThreadCount) { + return std::shared_ptr( + new AsyncTaskThreadPoolImpl(minThreadCount, maxThreadCount)); + } + + // static + std::shared_ptr ThreadPool::PostTask(std::shared_ptr pool, + std::shared_ptr task) { + // Grow the pool only when tasks need processing and the thread limit hasn't been reached. + const uint32_t currentThreadCount = pool->GetCurrentThreadCount(); + if (currentThreadCount == 0 || + (pool->HasTasksToExecute() && currentThreadCount < pool->GetMaxThreadCount())) { + pool->Resize(currentThreadCount + 1); + } + + // Ensure the pool is able to process the returned event by ensuring the event cannot + // outlive it. + std::shared_ptr event = pool->postTaskImpl(task); + if (event != nullptr) { + event->SetThreadPool(pool); + } + + return event; + } + + // TaskScheduler + + static TaskScheduler* sTaskScheduler = nullptr; + static std::mutex sTaskSchedulerAccessMutex; + + TaskScheduler::TaskScheduler() + : mThreadPool(new AsyncTaskThreadPoolImpl( + kMinThreadCount, + std::max(kMinThreadCount, std::thread::hardware_concurrency()))) { + } + + // static + TaskScheduler* TaskScheduler::GetOrCreateInstance() { + std::lock_guard lock(sTaskSchedulerAccessMutex); + if (!sTaskScheduler) { + sTaskScheduler = new TaskScheduler(); + } + return sTaskScheduler; + } + + // static + void TaskScheduler::ReleaseInstanceForTesting() { + std::lock_guard lock(sTaskSchedulerAccessMutex); + if (sTaskScheduler) { + SafeDelete(sTaskScheduler); + } + } + + std::shared_ptr TaskScheduler::PostTask(std::shared_ptr task) { + ASSERT(mThreadPool); + return mThreadPool->PostTask(mThreadPool, task); + } + +} // namespace gpgmm diff --git a/src/gpgmm/common/ThreadPool.h b/src/gpgmm/common/ThreadPool.h new file mode 100644 index 000000000..da735a261 --- /dev/null +++ b/src/gpgmm/common/ThreadPool.h @@ -0,0 +1,108 @@ +// Copyright 2022 The GPGMM Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GPGMM_COMMON_THREADPOOL_H_ +#define GPGMM_COMMON_THREADPOOL_H_ + +#include "gpgmm/utils/NonCopyable.h" + +#include + +namespace gpgmm { + + class VoidCallback : public NonCopyable { + public: + virtual ~VoidCallback() = default; + + // Define operator () that accepts no parameters (). + virtual void operator()() = 0; + }; + + class ThreadPool; + + // An event that we can wait on. + // Used for waiting for results or joining worker threads. + class Event : public NonCopyable { + public: + Event() = default; + virtual ~Event() = default; + + // Wait for the event to complete. + // Blocks the calling thread indefinitely until the event gets signaled. + virtual void Wait() = 0; + + // Check if event was signaled. + // Event will be in signaled state once the event is completed. + virtual bool IsSignaled() = 0; + + // Signals the event is ready. + // If ready, wait() will not block. + virtual void Signal() = 0; + + // Associates a thread pool with this event. + void SetThreadPool(std::shared_ptr pool); + + private: + std::shared_ptr mPool; + }; + + // Collection of threads that can process tasks as function call-backs. + class ThreadPool : public NonCopyable { + public: + ThreadPool() = default; + virtual ~ThreadPool() = default; + + // Creates a pool with up to |maxThreadCount| threads. + static std::shared_ptr Create(uint32_t minThreadCount, uint32_t maxThreadCount); + + static std::shared_ptr PostTask(std::shared_ptr pool, + std::shared_ptr task); + + // Returns True if threads in the pool have tasks to execute. + virtual bool HasTasksToExecute() = 0; + + // Tells the pool to stop processing more tasks and to exit threads. + virtual void Shutdown() = 0; + + // Returns the number of running threads in the pool. + virtual uint32_t GetCurrentThreadCount() const = 0; + + // Returns the maximum number of running threads allowed in the pool. + virtual uint32_t GetMaxThreadCount() const = 0; + + // Expands the size of the pool, to the specified number of threads. + virtual void Resize(uint32_t threadCount) = 0; + + private: + // Return event to wait on until the callback runs. + virtual std::shared_ptr postTaskImpl(std::shared_ptr task) = 0; + }; + + // Singleton class to process tasks using a single thread pool. + class TaskScheduler : public NonCopyable { + public: + static TaskScheduler* GetOrCreateInstance(); + std::shared_ptr PostTask(std::shared_ptr task); + + static void ReleaseInstanceForTesting(); + + private: + TaskScheduler(); + + std::shared_ptr mThreadPool; + }; + +} // namespace gpgmm + +#endif // GPGMM_COMMON_THREADPOOL_H_ diff --git a/src/gpgmm/common/WorkerThread.cpp b/src/gpgmm/common/WorkerThread.cpp deleted file mode 100644 index 14833e476..000000000 --- a/src/gpgmm/common/WorkerThread.cpp +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2022 The GPGMM Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "gpgmm/common/WorkerThread.h" - -#include "gpgmm/common/TraceEvent.h" -#include "gpgmm/utils/PlatformUtils.h" - -#include -#include -#include - -namespace gpgmm { - - class AsyncEventImpl final : public Event { - public: - AsyncEventImpl() = default; - - void Wait() override { - TRACE_EVENT0(TraceEventCategory::kDefault, "AsyncEventImpl.Wait"); - - std::unique_lock lock(mMutex); - mCondition.wait(lock, [this] { return mIsSignaled; }); - } - - bool IsSignaled() override { - std::unique_lock lock(mMutex); - return mIsSignaled; - } - - void Signal() override { - { - std::unique_lock lock(mMutex); - mIsSignaled = true; - } - mCondition.notify_all(); - } - - private: - std::mutex mMutex; - std::condition_variable mCondition; - bool mIsSignaled = false; - }; - - class AsyncThreadPoolImpl final : public ThreadPool { - public: - AsyncThreadPoolImpl() = default; - ~AsyncThreadPoolImpl() override = default; - - std::shared_ptr postTaskImpl(std::shared_ptr callback, - const char* name) override { - std::shared_ptr event = std::make_shared(); - std::thread thread([callback, event, name]() { - SetThreadName(name); - TRACE_EVENT_METADATA1(TraceEventCategory::kMetadata, "thread_name", "name", name); - (*callback)(); - event->Signal(); - }); - thread.detach(); - return event; - } - }; - - // Event - - void Event::SetThreadPool(std::shared_ptr pool) { - mPool = pool; - } - - // ThreadPool - - // static - std::shared_ptr ThreadPool::Create() { - return std::shared_ptr(new AsyncThreadPoolImpl()); - } - - // static - std::shared_ptr ThreadPool::PostTask(std::shared_ptr pool, - std::shared_ptr callback, - const char* name) { - std::shared_ptr event = pool->postTaskImpl(callback, name); - if (event != nullptr) { - event->SetThreadPool(pool); - } - return event; - } - -} // namespace gpgmm diff --git a/src/gpgmm/common/WorkerThread.h b/src/gpgmm/common/WorkerThread.h deleted file mode 100644 index 4a892700b..000000000 --- a/src/gpgmm/common/WorkerThread.h +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2022 The GPGMM Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GPGMM_COMMON_WORKERTHREAD_H_ -#define GPGMM_COMMON_WORKERTHREAD_H_ - -#include "gpgmm/utils/NonCopyable.h" - -#include - -namespace gpgmm { - - class VoidCallback : public NonCopyable { - public: - virtual ~VoidCallback() = default; - - // Define operator () that accepts no parameters (). - virtual void operator()() = 0; - }; - - class ThreadPool; - - /** \brief An event that we can wait on. - - Used for waiting for results or joining worker threads. - */ - class Event : public NonCopyable { - public: - Event() = default; - virtual ~Event() = default; - - /** \brief Wait for the event to complete. - - Wait blocks the calling thread indefinitely until the event gets signaled. - */ - virtual void Wait() = 0; - - /** \brief Check if event was signaled. - - Event will be in signaled state once the event is completed. - */ - virtual bool IsSignaled() = 0; - - /** \brief Signals the event is ready. - - If ready, wait() will not block. - */ - virtual void Signal() = 0; - - /** \brief Associates a thread pool with this event. - - @param pool Shared pointer to the thread pool this event belongs with. - */ - void SetThreadPool(std::shared_ptr pool); - - private: - std::shared_ptr mPool; - }; - - class ThreadPool : public NonCopyable { - public: - ThreadPool() = default; - virtual ~ThreadPool() = default; - - static std::shared_ptr Create(); - - static std::shared_ptr PostTask(std::shared_ptr pool, - std::shared_ptr callback, - const char* name); - - private: - // Return event to wait on until the callback runs. - virtual std::shared_ptr postTaskImpl(std::shared_ptr callback, - const char* name) = 0; - }; - -} // namespace gpgmm - -#endif // GPGMM_COMMON_WORKERTHREAD_H_ diff --git a/src/gpgmm/d3d12/ResidencyManagerD3D12.cpp b/src/gpgmm/d3d12/ResidencyManagerD3D12.cpp index 398c859cb..befc8bb92 100644 --- a/src/gpgmm/d3d12/ResidencyManagerD3D12.cpp +++ b/src/gpgmm/d3d12/ResidencyManagerD3D12.cpp @@ -17,8 +17,8 @@ #include "gpgmm/common/EventMessage.h" #include "gpgmm/common/SizeClass.h" +#include "gpgmm/common/ThreadPool.h" #include "gpgmm/common/TraceEvent.h" -#include "gpgmm/common/WorkerThread.h" #include "gpgmm/d3d12/CapsD3D12.h" #include "gpgmm/d3d12/ErrorD3D12.h" #include "gpgmm/d3d12/FenceD3D12.h" @@ -36,7 +36,6 @@ namespace gpgmm::d3d12 { static constexpr uint64_t kDefaultEvictSizeInBytes = GPGMM_MB_TO_BYTES(50); static constexpr float kDefaultMaxPctOfVideoMemoryToBudget = 0.95f; // 95% static constexpr float kDefaultMinPctOfBudgetToReserve = 0.50f; // 50% - static constexpr const char* kBudgetChangeWorkerThreadName = "GPGMM_ThreadBudgetChangeWorker"; // Creates a long-lived task to recieve and process OS budget change events. class BudgetUpdateTask : public VoidCallback { @@ -252,7 +251,7 @@ namespace gpgmm::d3d12 { // Dump out the initialized memory segment status. residencyManager->ReportSegmentInfoForTesting(DXGI_MEMORY_SEGMENT_GROUP_LOCAL); - if (!residencyManager->mIsUMA){ + if (!residencyManager->mIsUMA) { residencyManager->ReportSegmentInfoForTesting(DXGI_MEMORY_SEGMENT_GROUP_NON_LOCAL); } @@ -283,8 +282,7 @@ namespace gpgmm::d3d12 { RESIDENCY_FLAG_NEVER_UPDATE_BUDGET_ON_WORKER_THREAD), mFlushEventBuffersOnDestruct(descriptor.RecordOptions.EventScope & EVENT_RECORD_SCOPE_PER_INSTANCE), - mResidencyFence(std::move(residencyFence)), - mThreadPool(ThreadPool::Create()) { + mResidencyFence(std::move(residencyFence)) { GPGMM_TRACE_EVENT_OBJECT_NEW(this); ASSERT(mDevice != nullptr); @@ -863,7 +861,7 @@ namespace gpgmm::d3d12 { std::shared_ptr task = std::make_shared(this, mAdapter); mBudgetNotificationUpdateEvent = std::make_shared( - ThreadPool::PostTask(mThreadPool, task, kBudgetChangeWorkerThreadName), task); + TaskScheduler::GetOrCreateInstance()->PostTask(task), task); } ASSERT(mBudgetNotificationUpdateEvent != nullptr); @@ -895,8 +893,9 @@ namespace gpgmm::d3d12 { void ResidencyManager::ReportSegmentInfoForTesting(DXGI_MEMORY_SEGMENT_GROUP segmentGroup) { DXGI_QUERY_VIDEO_MEMORY_INFO* info = GetVideoMemoryInfo(segmentGroup); ASSERT(info != nullptr); - - gpgmm::DebugLog() << "GPU memory segment status (" << GetMemorySegmentName(segmentGroup, IsUMA()) << "):"; + + gpgmm::DebugLog() << "GPU memory segment status (" + << GetMemorySegmentName(segmentGroup, IsUMA()) << "):"; gpgmm::DebugLog() << "\tBudget: " << GPGMM_BYTES_TO_MB(info->Budget) << " MBs (" << GPGMM_BYTES_TO_MB(info->CurrentUsage) << " used)."; gpgmm::DebugLog() << "\tReserved: " << GPGMM_BYTES_TO_MB(info->CurrentReservation) diff --git a/src/gpgmm/d3d12/ResidencyManagerD3D12.h b/src/gpgmm/d3d12/ResidencyManagerD3D12.h index 185a46371..1b263dda3 100644 --- a/src/gpgmm/d3d12/ResidencyManagerD3D12.h +++ b/src/gpgmm/d3d12/ResidencyManagerD3D12.h @@ -26,7 +26,7 @@ #include namespace gpgmm { - class ThreadPool; + class TaskScheduler; } // namespace gpgmm namespace gpgmm::d3d12 { @@ -135,7 +135,6 @@ namespace gpgmm::d3d12 { VideoMemorySegment mNonLocalVideoMemorySegment; RESIDENCY_STATS mStats = {}; - std::shared_ptr mThreadPool; std::shared_ptr mBudgetNotificationUpdateEvent; }; diff --git a/src/tests/BUILD.gn b/src/tests/BUILD.gn index cdc399653..c38203a29 100644 --- a/src/tests/BUILD.gn +++ b/src/tests/BUILD.gn @@ -120,6 +120,7 @@ test("gpgmm_unittests") { "unittests/SlabBlockAllocatorTests.cpp", "unittests/SlabMemoryAllocatorTests.cpp", "unittests/StableListTests.cpp", + "unittests/ThreadPoolTests.cpp", "unittests/UtilsTest.cpp", ] diff --git a/src/tests/GPGMMTest.cpp b/src/tests/GPGMMTest.cpp index cfc1a3a29..4c14fa1b5 100644 --- a/src/tests/GPGMMTest.cpp +++ b/src/tests/GPGMMTest.cpp @@ -32,6 +32,7 @@ void GPGMMTestBase::SetUp() { } void GPGMMTestBase::TearDown() { + gpgmm::TaskScheduler::ReleaseInstanceForTesting(); } gpgmm::DebugPlatform* GPGMMTestBase::GetDebugPlatform() { diff --git a/src/tests/GPGMMTest.h b/src/tests/GPGMMTest.h index bfaa64ef2..c26c922c3 100644 --- a/src/tests/GPGMMTest.h +++ b/src/tests/GPGMMTest.h @@ -37,8 +37,11 @@ GetDebugPlatform()->StartMemoryCheck(); \ } while (0) +// TearDown must be called before the end check since process-wide memory +// created by the test needs a chance to release to avoid become a false-positive. #define GPGMM_TEST_MEMORY_LEAK_END() \ do { \ + TearDown(); \ EXPECT_FALSE(GetDebugPlatform()->EndMemoryCheck()); \ } while (0) diff --git a/src/tests/unittests/ThreadPoolTests.cpp b/src/tests/unittests/ThreadPoolTests.cpp new file mode 100644 index 000000000..4641b980a --- /dev/null +++ b/src/tests/unittests/ThreadPoolTests.cpp @@ -0,0 +1,73 @@ +// Copyright 2022 The GPGMM Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gpgmm/common/ThreadPool.h" +#include "gpgmm/utils/Assert.h" +#include "tests/GPGMMTest.h" + +#include +#include + +using namespace gpgmm; + +class Task : public VoidCallback { + public: + void operator()() { + ASSERT(true); + } +}; + +TEST(ThreadPoolTests, Create) { + std::shared_ptr pool = + ThreadPool::Create(/*minThreadCount*/ 0, /*maxThreadCount*/ 0); + EXPECT_NE(pool, nullptr); + EXPECT_EQ(pool->GetCurrentThreadCount(), 0u); + EXPECT_EQ(pool->GetMaxThreadCount(), 0u); +} + +TEST(ThreadPoolTests, SingleTask) { + std::shared_ptr pool = + ThreadPool::Create(/*minThreadCount*/ 0, /*maxThreadCount*/ 1); + EXPECT_NE(pool, nullptr); + + auto event = ThreadPool::PostTask(pool, std::make_shared()); + EXPECT_NE(event, nullptr); + EXPECT_EQ(pool->GetCurrentThreadCount(), 1u); + EXPECT_EQ(pool->GetMaxThreadCount(), 1u); + + event->Wait(); + EXPECT_TRUE(event->IsSignaled()); + + pool->Shutdown(); + EXPECT_EQ(pool->GetCurrentThreadCount(), 0u); +} + +TEST(ThreadPoolTests, ManyTasks) { + std::shared_ptr pool = + ThreadPool::Create(/*minThreadCount*/ 0, /*maxThreadCount*/ 2); + + constexpr uint32_t kMaxTaskCount = 10000u; + for (uint32_t numOfTasks = 0; numOfTasks < kMaxTaskCount; numOfTasks++) { + std::shared_ptr task = std::make_shared(); + EXPECT_NE(ThreadPool::PostTask(pool, task), nullptr); + } + + EXPECT_GT(pool->GetCurrentThreadCount(), 0u); + EXPECT_EQ(pool->GetMaxThreadCount(), 2u); + + pool->Shutdown(); + EXPECT_EQ(pool->GetCurrentThreadCount(), 0u); +}