From d73a8aad9676ca601c4d882556f41c61af614121 Mon Sep 17 00:00:00 2001 From: Bryan Bernhart Date: Thu, 7 Jul 2022 09:21:29 -0700 Subject: [PATCH] Make ScopedTraceBufferInTLS thread safe. Prevents unsafe access to trace event buffer when using multiple threads. --- src/gpgmm/common/EventTraceWriter.cpp | 58 ++++++++++++------- src/gpgmm/common/EventTraceWriter.h | 14 ++--- src/gpgmm/common/TraceEvent.cpp | 12 ++-- src/gpgmm/common/TraceEvent.h | 3 +- src/gpgmm/d3d12/ResidencyManagerD3D12.cpp | 3 +- src/gpgmm/d3d12/ResourceAllocatorD3D12.cpp | 6 +- src/tests/unittests/EventTraceWriterTests.cpp | 2 +- 7 files changed, 52 insertions(+), 46 deletions(-) diff --git a/src/gpgmm/common/EventTraceWriter.cpp b/src/gpgmm/common/EventTraceWriter.cpp index 056d51e31..30fb87f65 100644 --- a/src/gpgmm/common/EventTraceWriter.cpp +++ b/src/gpgmm/common/EventTraceWriter.cpp @@ -31,20 +31,36 @@ namespace gpgmm { // Trace buffer that flushes and unlinks itself from the cache once destroyed. class ScopedTraceBufferInTLS { public: - ScopedTraceBufferInTLS(EventTraceWriter* writer) : mWriter(writer) { - ASSERT(writer != nullptr); + ScopedTraceBufferInTLS(std::shared_ptr writer) + : mWriter(std::move(writer)) { + ASSERT(mWriter != nullptr); } ~ScopedTraceBufferInTLS() { - mWriter->FlushAndRemoveBufferEntry(GetBuffer()); + mWriter->FlushAndRemoveBufferEntry(&mBuffer); } - std::vector* GetBuffer() { - return &mBuffer; + void AddEvent(const TraceEvent& event) { + std::unique_lock lock(mMutex); + mBuffer.push_back(event); + } + + std::vector GetAndClearBuffer() { + std::unique_lock lock(mMutex); + std::vector tmp = mBuffer; + mBuffer.clear(); + return tmp; + } + + size_t GetBufferSize() const { + std::unique_lock lock(mMutex); + return mBuffer.size(); } private: - EventTraceWriter* mWriter = nullptr; + std::shared_ptr mWriter; + + mutable std::mutex mMutex; // Protect access for members below. std::vector mBuffer; }; @@ -53,20 +69,17 @@ namespace gpgmm { } void EventTraceWriter::SetConfiguration(const std::string& traceFile, - const TraceEventPhase& ignoreMask, - bool flushOnDestruct) { + const TraceEventPhase& ignoreMask) { mTraceFile = (traceFile.empty()) ? mTraceFile : traceFile; mIgnoreMask = ignoreMask; - mFlushOnDestruct = flushOnDestruct; } EventTraceWriter::~EventTraceWriter() { - if (mFlushOnDestruct) { - FlushQueuedEventsToDisk(); - } + FlushQueuedEventsToDisk(); } - void EventTraceWriter::EnqueueTraceEvent(char phase, + void EventTraceWriter::EnqueueTraceEvent(std::shared_ptr writer, + char phase, TraceEventCategory category, const char* name, uint64_t id, @@ -75,8 +88,8 @@ namespace gpgmm { const double timestampInSeconds = mPlatformTime->GetRelativeTime(); const uint32_t threadID = std::stoi(ToString(std::this_thread::get_id())); if (timestampInSeconds != 0) { - GetOrCreateBufferFromTLS()->push_back( - {phase, category, name, id, threadID, timestampInSeconds, flags, args}); + GetOrCreateBufferFromTLS(std::move(writer)) + ->AddEvent({phase, category, name, id, threadID, timestampInSeconds, flags, args}); } } @@ -201,16 +214,17 @@ namespace gpgmm { DebugLog() << "Flushed " << mergedBuffer.size() << " events to disk."; } - std::vector* EventTraceWriter::GetOrCreateBufferFromTLS() { + ScopedTraceBufferInTLS* EventTraceWriter::GetOrCreateBufferFromTLS( + std::shared_ptr writer) { thread_local std::unique_ptr bufferInTLS; if (bufferInTLS == nullptr) { - bufferInTLS.reset(new ScopedTraceBufferInTLS(this)); + bufferInTLS.reset(new ScopedTraceBufferInTLS(std::move(writer))); std::lock_guard mutex(mMutex); mBufferPerThread[std::this_thread::get_id()] = bufferInTLS.get(); } ASSERT(bufferInTLS != nullptr); - return bufferInTLS->GetBuffer(); + return bufferInTLS.get(); } void EventTraceWriter::FlushAndRemoveBufferEntry(std::vector* buffer) { @@ -227,9 +241,9 @@ namespace gpgmm { mUnmergedBuffer.clear(); for (auto& bufferOfThread : mBufferPerThread) { - std::vector* bufferToMerge = bufferOfThread.second->GetBuffer(); - mergedBuffer.insert(mergedBuffer.end(), bufferToMerge->begin(), bufferToMerge->end()); - bufferToMerge->clear(); + std::vector bufferToMerge = bufferOfThread.second->GetAndClearBuffer(); + mergedBuffer.insert(mergedBuffer.end(), bufferToMerge.begin(), bufferToMerge.end()); + bufferToMerge.clear(); } return mergedBuffer; } @@ -239,7 +253,7 @@ namespace gpgmm { size_t numOfEvents = 0; numOfEvents += mUnmergedBuffer.size(); for (auto& bufferOfThread : mBufferPerThread) { - numOfEvents += bufferOfThread.second->GetBuffer()->size(); + numOfEvents += bufferOfThread.second->GetBufferSize(); } return numOfEvents; } diff --git a/src/gpgmm/common/EventTraceWriter.h b/src/gpgmm/common/EventTraceWriter.h index 7520a3eb9..d8c855875 100644 --- a/src/gpgmm/common/EventTraceWriter.h +++ b/src/gpgmm/common/EventTraceWriter.h @@ -32,11 +32,10 @@ namespace gpgmm { EventTraceWriter(); ~EventTraceWriter(); - void SetConfiguration(const std::string& traceFile, - const TraceEventPhase& ignoreMask, - bool flushOnDestruct); + void SetConfiguration(const std::string& traceFile, const TraceEventPhase& ignoreMask); - void EnqueueTraceEvent(char phase, + void EnqueueTraceEvent(std::shared_ptr writer, + char phase, TraceEventCategory category, const char* name, uint64_t id, @@ -50,14 +49,13 @@ namespace gpgmm { size_t GetQueuedEventsForTesting() const; private: - std::vector* GetOrCreateBufferFromTLS(); + ScopedTraceBufferInTLS* GetOrCreateBufferFromTLS(std::shared_ptr writer); std::vector MergeAndClearBuffers(); std::string mTraceFile; - std::unique_ptr mPlatformTime; - TraceEventPhase mIgnoreMask; - bool mFlushOnDestruct = true; + + std::unique_ptr mPlatformTime; mutable std::mutex mMutex; diff --git a/src/gpgmm/common/TraceEvent.cpp b/src/gpgmm/common/TraceEvent.cpp index 990151791..ec94cc72e 100644 --- a/src/gpgmm/common/TraceEvent.cpp +++ b/src/gpgmm/common/TraceEvent.cpp @@ -22,26 +22,24 @@ namespace gpgmm { - static std::unique_ptr gEventTrace; + static std::shared_ptr gEventTrace; static std::mutex mMutex; static EventTraceWriter* GetInstance() { std::lock_guard lock(mMutex); if (gEventTrace == nullptr) { - gEventTrace = std::make_unique(); + gEventTrace = std::make_shared(); } return gEventTrace.get(); } - void StartupEventTrace(const std::string& traceFile, - const TraceEventPhase& ignoreMask, - bool flushOnDestruct) { + void StartupEventTrace(const std::string& traceFile, const TraceEventPhase& ignoreMask) { #if defined(GPGMM_DISABLE_TRACING) gpgmm::WarningLog() << "Event tracing enabled but unable to record due to GPGMM_DISABLE_TRACING."; #endif - GetInstance()->SetConfiguration(traceFile, ignoreMask, flushOnDestruct); + GetInstance()->SetConfiguration(traceFile, ignoreMask); TRACE_EVENT_METADATA1(TraceEventCategory::Metadata, "thread_name", "name", "GPGMM_MainThread"); } @@ -90,7 +88,7 @@ namespace gpgmm { uint32_t flags, const JSONDict& args) { if (IsEventTraceEnabled()) { - GetInstance()->EnqueueTraceEvent(phase, category, name, id, flags, args); + GetInstance()->EnqueueTraceEvent(gEventTrace, phase, category, name, id, flags, args); } } } // namespace gpgmm diff --git a/src/gpgmm/common/TraceEvent.h b/src/gpgmm/common/TraceEvent.h index 835f47d68..fccf4926e 100644 --- a/src/gpgmm/common/TraceEvent.h +++ b/src/gpgmm/common/TraceEvent.h @@ -176,8 +176,7 @@ namespace gpgmm { class PlatformTime; void StartupEventTrace(const std::string& traceFile, - const TraceEventPhase& ignoreMask, - bool flushOnDestruct); + const TraceEventPhase& ignoreMask); void FlushEventTraceToDisk(); diff --git a/src/gpgmm/d3d12/ResidencyManagerD3D12.cpp b/src/gpgmm/d3d12/ResidencyManagerD3D12.cpp index dd6af6015..b9da06b27 100644 --- a/src/gpgmm/d3d12/ResidencyManagerD3D12.cpp +++ b/src/gpgmm/d3d12/ResidencyManagerD3D12.cpp @@ -164,8 +164,7 @@ namespace gpgmm::d3d12 { if (descriptor.RecordOptions.Flags != EVENT_RECORD_FLAG_NONE) { StartupEventTrace(descriptor.RecordOptions.TraceFile, - static_cast(~descriptor.RecordOptions.Flags | 0), - descriptor.RecordOptions.EventScope & EVENT_RECORD_SCOPE_PER_PROCESS); + static_cast(~descriptor.RecordOptions.Flags | 0)); SetEventMessageLevel(GetLogSeverity(descriptor.RecordOptions.MinMessageLevel)); } diff --git a/src/gpgmm/d3d12/ResourceAllocatorD3D12.cpp b/src/gpgmm/d3d12/ResourceAllocatorD3D12.cpp index 37da8bc2d..bba197ede 100644 --- a/src/gpgmm/d3d12/ResourceAllocatorD3D12.cpp +++ b/src/gpgmm/d3d12/ResourceAllocatorD3D12.cpp @@ -342,10 +342,8 @@ namespace gpgmm::d3d12 { if (pResidencyManager == nullptr && newDescriptor.RecordOptions.Flags != EVENT_RECORD_FLAG_NONE) { - StartupEventTrace( - allocatorDescriptor.RecordOptions.TraceFile, - static_cast(~newDescriptor.RecordOptions.Flags | 0), - allocatorDescriptor.RecordOptions.EventScope & EVENT_RECORD_SCOPE_PER_PROCESS); + StartupEventTrace(allocatorDescriptor.RecordOptions.TraceFile, + static_cast(~newDescriptor.RecordOptions.Flags | 0)); SetEventMessageLevel(GetLogSeverity(newDescriptor.RecordOptions.MinMessageLevel)); } else { diff --git a/src/tests/unittests/EventTraceWriterTests.cpp b/src/tests/unittests/EventTraceWriterTests.cpp index 0b4c4212a..46e8b8d4a 100644 --- a/src/tests/unittests/EventTraceWriterTests.cpp +++ b/src/tests/unittests/EventTraceWriterTests.cpp @@ -26,7 +26,7 @@ using namespace gpgmm; class EventTraceWriterTests : public testing::Test { public: void SetUp() override { - StartupEventTrace(kDummyTrace, TraceEventPhase::None, /*flushOnDestruct*/ true); + StartupEventTrace(kDummyTrace, TraceEventPhase::None); } void TearDown() override {