Chenta/avoid thread local #13003

Merged · 9 commits · Sep 20, 2022

Changes from 6 commits
3 changes: 2 additions & 1 deletion include/onnxruntime/core/framework/stream_handles.h
@@ -47,6 +47,7 @@ struct Stream {
     return {};
   };
   virtual void Flush(){};
+  virtual Status CleanUpOnRunEnd() = 0;
 };

 namespace synchronize {
@@ -91,4 +92,4 @@ class IStreamCommandHandleRegistry {
   virtual void RegisterCreateStreamFn(const OrtDevice::DeviceType device_type, CreateStreamFn f) = 0;
 };

-}
+}  // namespace onnxruntime
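`CleanUpOnRunEnd` is added as a pure virtual, so every execution provider's `Stream` subclass must now supply an end-of-run hook. A minimal sketch of what a subclass might look like, assuming `Stream` exposes a `(StreamHandle, OrtDevice)` constructor as in this header; the `PooledBufferStream` class and its `deferred_buffers_` member are hypothetical, for illustration only:

```cpp
#include <memory>
#include <vector>

#include "core/framework/stream_handles.h"

namespace onnxruntime {

// Hypothetical stream that keeps host buffers alive until the run ends.
class PooledBufferStream : public Stream {
 public:
  PooledBufferStream(StreamHandle h, const OrtDevice& device) : Stream(h, device) {}

  // Keep a buffer alive until CleanUpOnRunEnd is called.
  void DeferRelease(std::unique_ptr<char[]> buffer) {
    deferred_buffers_.push_back(std::move(buffer));
  }

  // New hook: called once when the run finishes.
  Status CleanUpOnRunEnd() override {
    deferred_buffers_.clear();  // device work for this run has completed
    return Status::OK();
  }

 private:
  std::vector<std::unique_ptr<char[]>> deferred_buffers_;
};

}  // namespace onnxruntime
```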
51 changes: 25 additions & 26 deletions onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc
@@ -257,34 +257,33 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
                                          disable_compact_memory);
   auto workspace_buffer = GetScratchBuffer<void>(workSpaceSize, OrtStream(context));
   ORT_RETURN_IF_ERROR(LaunchLongformerAttentionKernel(
-                          device_prop,
-                          cublas,
-                          stream,
-                          reinterpret_cast<const CudaT*>(gemm_buffer.get()),
-                          reinterpret_cast<const CudaT*>(bias->Data<T>()),
-                          reinterpret_cast<const CudaT*>(attention_mask->Data<T>()),
-                          reinterpret_cast<const CudaT*>(global_gemm_buffer),
-                          reinterpret_cast<const CudaT*>(global_bias->Data<T>()),
-                          global_attention_mask->Data<int>(),
-                          global_index_buffer.get(),
-                          batch_global_num_buffer.get(),
-                          pinned_buffer.get(),
-                          workspace_buffer.get(),
-                          output->MutableData<T>(),
-                          batch_size,
-                          sequence_length,
-                          num_heads_,
-                          head_size,
-                          window_,
-                          max_num_global,
-                          element_size,
-                          disable_compact_memory,
-                          use_merged_qkv_weights,
-                          use_half4_))
-  ;
+      device_prop,
+      cublas,
+      stream,
+      reinterpret_cast<const CudaT*>(gemm_buffer.get()),
+      reinterpret_cast<const CudaT*>(bias->Data<T>()),
+      reinterpret_cast<const CudaT*>(attention_mask->Data<T>()),
+      reinterpret_cast<const CudaT*>(global_gemm_buffer),
+      reinterpret_cast<const CudaT*>(global_bias->Data<T>()),
+      global_attention_mask->Data<int>(),
+      global_index_buffer.get(),
+      batch_global_num_buffer.get(),
+      pinned_buffer.get(),
+      workspace_buffer.get(),
+      output->MutableData<T>(),
+      batch_size,
+      sequence_length,
+      num_heads_,
+      head_size,
+      window_,
+      max_num_global,
+      element_size,
+      disable_compact_memory,
+      use_merged_qkv_weights,
+      use_half4_));

   // Defer release of pinned memory since cudaStreamSynchronize is not used here and the kernel needs access to the buffer.
-  this->AddDeferredReleaseCPUPtr(pinned_buffer.release(), GetCudaStreamFromContext(context));
+  this->AddDeferredReleaseCPUPtr(pinned_buffer.release(), context->GetComputeStream());

   return Status::OK();
 }
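The deferred release above exists because the kernel launch is asynchronous: freeing the pinned host buffer as soon as the launch returns could pull it out from under an in-flight kernel. A minimal sketch of the underlying pattern in plain CUDA (`cudaLaunchHostFunc` is the stock CUDA runtime API; the helper names are illustrative, not ORT's implementation):

```cpp
#include <cuda_runtime.h>

// Host callback: runs after all work queued on the stream before it.
static void CUDART_CB FreePinned(void* pinned) {
  cudaFreeHost(pinned);
}

// Queue the release behind the kernels that still read `pinned`,
// instead of blocking the host with cudaStreamSynchronize.
static cudaError_t DeferredReleasePinned(void* pinned, cudaStream_t stream) {
  return cudaLaunchHostFunc(stream, FreePinned, pinned);
}
```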
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/allocation_planner.cc
@@ -2000,7 +2000,7 @@ class PlannerImpl {
       onnxruntime::ProviderType exec_provider_name = node->GetExecutionProviderType();
       const IExecutionProvider* ep = execution_providers.Get(exec_provider_name);
       auto& node_device_mem_location = ep->GetAllocator(0, OrtMemType::OrtMemTypeDefault)->Info();
-      ORT_ENFORCE(execution_plan[node_stream_map_[node_index]]->device_ == node_device_mem_location.device);
+      ORT_ENFORCE(execution_plan[node_stream_map_[node_index]]->device_.Type() == node_device_mem_location.device.Type());
     }
   }
   // 4. set notification owners
11 changes: 11 additions & 0 deletions onnxruntime/core/framework/execution_context.cc
@@ -18,9 +18,15 @@ class DeviceStreamCollectionImpl {
   }

   virtual ~DeviceStreamCollectionImpl() {
+  }
+
+  Status CleanUp() {
     for (auto& device_stream : device_streams_) {
       if (device_stream) {
+        ORT_RETURN_IF_ERROR(device_stream->CleanUpOnRunEnd());
+#ifndef ENABLE_TRAINING
         device_stream->Flush();
+#endif
       }
     }
     // only clean up the streams that are owned by the current context
@@ -41,6 +47,7 @@ class DeviceStreamCollectionImpl {
         }
       }
     }
+    return Status::OK();
   }

   void SetDeviceStream(size_t idx, std::unique_ptr<Stream> stream) {
@@ -89,6 +96,10 @@ size_t DeviceStreamCollection::NumStreams() const {
   return impl_->NumStreams();
 }

+Status DeviceStreamCollection::CleanUp() {
+  return impl_->CleanUp();
+}
+
 ExecutionContext::ExecutionContext(const SessionState& sess_state,
                                    int32_t num_streams,
                                    std::vector<size_t> notification_owners,
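Moving the teardown out of the destructor and into a `Status`-returning `CleanUp()` lets a failed `CleanUpOnRunEnd` propagate to the caller; a destructor has no way to report it. A minimal sketch of the idiom with stand-in types (illustrative only, not ORT code):

```cpp
#include <string>

// Stand-in for onnxruntime::common::Status.
struct Status {
  std::string error;
  bool IsOK() const { return error.empty(); }
  static Status OK() { return {}; }
};

class StreamLike {
 public:
  // Fallible teardown: callers can observe and propagate the result.
  Status CleanUp() {
    return flushed_ ? Status::OK() : Status{"flush failed before release"};
  }
  // The destructor stays trivial; it cannot return a Status.
  ~StreamLike() = default;

 private:
  bool flushed_ = true;
};
```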
1 change: 1 addition & 0 deletions onnxruntime/core/framework/execution_context.h
@@ -29,6 +29,7 @@ class DeviceStreamCollection {
   void SetDeviceStream(size_t, Stream* stream);
   const std::vector<Stream*>& GetStreams() const;
   size_t NumStreams() const;
+  Status CleanUp();

  private:
   std::unique_ptr<DeviceStreamCollectionImpl> impl_;
11 changes: 9 additions & 2 deletions onnxruntime/core/framework/partial_graph_execution_state.cc
@@ -38,9 +38,16 @@ ProgramRegion& PartialGraphExecutionState::GetProgramRegions(const SessionState&
   return program_regions_.back();
 }

-DeviceStreamCollection& PartialGraphExecutionState::GetDeviceStreamCollection(size_t num_streams, const SessionState& session_state) {
+PartialGraphExecutionState::~PartialGraphExecutionState() {
+  if (device_stream_deleter_ && device_stream_collection_) {
+    device_stream_deleter_(device_stream_collection_.release());
+  }
+}
+
+DeviceStreamCollection& PartialGraphExecutionState::GetDeviceStreamCollection(const SessionState& session_state) {
   if (device_stream_collection_ == nullptr) {
-    device_stream_collection_ = std::make_unique<DeviceStreamCollection>(num_streams, session_state);
+    device_stream_collection_ = session_state.AcquireDeviceStreamCollection();
+    device_stream_deleter_ = [&](DeviceStreamCollection* ptr) { session_state.RecycleDeviceStreamCollection(ptr); };
   }
   return *device_stream_collection_;
 }
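The raw-pointer-plus-deleter pair above can also be expressed as a `std::unique_ptr` with a custom deleter, which ties the recycle step to ownership directly. A sketch of that alternative shape (generic, not ORT code); either way the deleter captures the pool, so the handle must not outlive it:

```cpp
#include <functional>
#include <memory>
#include <vector>

struct Collection {};  // stand-in for DeviceStreamCollection

class Pool {
 public:
  using Handle = std::unique_ptr<Collection, std::function<void(Collection*)>>;

  // Hand out a pooled object; it returns itself on destruction.
  Handle Acquire() {
    Collection* c;
    if (free_.empty()) {
      c = new Collection();
    } else {
      c = free_.back().release();
      free_.pop_back();
    }
    return Handle(c, [this](Collection* p) { Recycle(p); });
  }

 private:
  void Recycle(Collection* p) { free_.emplace_back(p); }
  std::vector<std::unique_ptr<Collection>> free_;
};
```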
7 changes: 4 additions & 3 deletions onnxruntime/core/framework/partial_graph_execution_state.h
@@ -17,10 +17,10 @@ class DeviceStreamCollection;

 struct PartialGraphExecutionState {
  public:
-  PartialGraphExecutionState() : execution_context_(nullptr), device_stream_collection_(nullptr) {
+  PartialGraphExecutionState() : execution_context_(nullptr), device_stream_collection_(nullptr), device_stream_deleter_(nullptr) {
   }

-  ~PartialGraphExecutionState() = default;
+  ~PartialGraphExecutionState();

   void SetProgramCounterStart(size_t start) { program_counter_start_ = start; }
   void SetProgramCounterEnd(size_t end) { program_counter_end_ = end; }
@@ -36,7 +36,7 @@ struct PartialGraphExecutionState {
                              const SessionState& session_state,
                              const logging::Logger& sess_logger,
                              const DeviceStreamCollection& device_streams);
-  DeviceStreamCollection& GetDeviceStreamCollection(size_t num_streams, const SessionState& session_state);
+  DeviceStreamCollection& GetDeviceStreamCollection(const SessionState& session_state);

  private:
   std::unique_ptr<ExecutionContext> execution_context_;
@@ -45,6 +45,7 @@ struct PartialGraphExecutionState {

   std::vector<ProgramRegion> program_regions_;
   std::unique_ptr<DeviceStreamCollection> device_stream_collection_;
+  std::function<void(DeviceStreamCollection*)> device_stream_deleter_;
 };
 }  // namespace onnxruntime
 #endif
37 changes: 37 additions & 0 deletions onnxruntime/core/framework/session_state.cc
@@ -1705,4 +1705,41 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
   return Status::OK();
 }

+static void BindToDeviceStream(const SequentialExecutionPlan& execution_plan,
+                               DeviceStreamCollection& device_stream_map,
+                               IStreamCommandHandleRegistry& stream_handle_registry) {
+  for (size_t i = 0; i < execution_plan.execution_plan.size(); ++i) {
+    auto& logic_stream = execution_plan.execution_plan[i];
+    if (logic_stream->steps_.size() > 0) {
+      auto create_stream_fn = stream_handle_registry.GetCreateStreamFn(logic_stream->device_.Type());
+      if (create_stream_fn) {
+        auto device_stream = create_stream_fn(logic_stream->device_);
+        device_stream_map.SetDeviceStream(i, std::move(device_stream));
+      } else {
+        device_stream_map.SetDeviceStream(i, nullptr);
+      }
+    } else {
+      device_stream_map.SetDeviceStream(i, nullptr);
+    }
+  }
+}
+
+std::unique_ptr<DeviceStreamCollection> SessionState::AcquireDeviceStreamCollection() const {
+  std::lock_guard<onnxruntime::OrtMutex> lock(device_stream_pool_mutex_);
+  if (!device_stream_pool_.empty()) {
+    auto device_stream = std::move(device_stream_pool_.back());
+    device_stream_pool_.pop_back();
+    return device_stream;
+  } else {
+    auto device_stream = std::make_unique<DeviceStreamCollection>(this->GetExecutionPlan()->execution_plan.size(), *this);
+    BindToDeviceStream(*this->GetExecutionPlan(), *device_stream, *stream_handles_registry_);
+    return device_stream;
+  }
+}
+
+void SessionState::RecycleDeviceStreamCollection(DeviceStreamCollection* device_stream_collection_ptr) const {
+  std::lock_guard<onnxruntime::OrtMutex> lock(device_stream_pool_mutex_);
+  device_stream_pool_.emplace_back(std::unique_ptr<DeviceStreamCollection>(device_stream_collection_ptr));
+}
+
 }  // namespace onnxruntime
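This pool is what replaces the previous thread-local stream storage: any thread can check a collection out for the duration of a run and return it afterwards, so collections are shared across threads rather than pinned to one. A sketch of the intended round-trip (the caller and `ExecuteWith` are hypothetical, for illustration only):

```cpp
// Hypothetical per-run usage of the pool added above.
Status RunOnce(const SessionState& session_state) {
  std::unique_ptr<DeviceStreamCollection> streams =
      session_state.AcquireDeviceStreamCollection();
  Status run_status = ExecuteWith(*streams);   // assumed executor entry point
  Status cleanup_status = streams->CleanUp();  // per-run stream cleanup
  session_state.RecycleDeviceStreamCollection(streams.release());
  return run_status.IsOK() ? cleanup_status : run_status;
}
```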
12 changes: 11 additions & 1 deletion onnxruntime/core/framework/session_state.h
@@ -18,6 +18,7 @@
 #include "core/framework/callback.h"
 #include "core/framework/data_transfer_manager.h"
 #include "core/framework/execution_providers.h"
+#include "core/framework/execution_context.h"
 #include "core/framework/feeds_fetches_manager.h"
 #include "core/framework/framework_common.h"
 #include "core/framework/prepacked_weights_container.h"
@@ -60,6 +61,7 @@ class OpKernel;
 class NodeIndexInfo;
 struct SequentialExecutionPlan;
 struct MemoryPatternGroup;
+class DeviceStreamCollection;
 #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
 class MemoryInfo;
 #endif
@@ -335,8 +337,12 @@ class SessionState {
     return subgraph_session_states_;
   }

+  std::unique_ptr<DeviceStreamCollection> AcquireDeviceStreamCollection() const;
+
+  void RecycleDeviceStreamCollection(DeviceStreamCollection* device_stream_collection) const;
 #ifdef DEBUG_NODE_INPUTS_OUTPUTS
-  void IncrementGraphExecutionCounter() {
+  void
+  IncrementGraphExecutionCounter() {
     ++graph_executions_counter_;
   }

@@ -563,6 +569,10 @@ class SessionState {
   size_t graph_executions_counter_ = 0;
 #endif
   std::unique_ptr<IStreamCommandHandleRegistry> stream_handles_registry_;
+
+  // lock for the device stream pool
+  mutable OrtMutex device_stream_pool_mutex_;
+  mutable std::vector<std::unique_ptr<DeviceStreamCollection>> device_stream_pool_;
 };

 }  // namespace onnxruntime
65 changes: 52 additions & 13 deletions onnxruntime/core/framework/utils.cc
@@ -553,18 +553,58 @@ static common::Status CopyOutputsAcrossDevices(const SessionState& session_state
   return Status::OK();
 }

-static common::Status ExecuteGraphImpl(const SessionState& session_state,
-                                       const FeedsFetchesManager& feeds_fetches_manager,
-                                       gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
-                                       const InlinedHashMap<size_t, IExecutor::CustomAllocator>& fetch_allocators,
-                                       ExecutionMode execution_mode, const bool* terminate_flag,
-                                       const logging::Logger& logger, const bool only_execute_path_to_fetches = false,
-                                       Stream* parent_stream = nullptr) {
+struct DeviceStreamCollectionHolder {
+  DeviceStreamCollectionHolder(const SessionState& session_state) : session_state_(session_state),
+                                                                    p_(session_state.AcquireDeviceStreamCollection()) {
+  }
+
+  ~DeviceStreamCollectionHolder() {
+    session_state_.RecycleDeviceStreamCollection(p_.release());
+  }
+
+  const SessionState& session_state_;
+  std::unique_ptr<DeviceStreamCollection> p_;
+};
+
+static void UpdateWithParentStream(DeviceStreamCollection& device_stream_collection,
+                                   Stream* parent_stream) {
+  if (parent_stream) {
+    // TODO: in theory, we should make the current subgraph's streams depend on the parent stream.
+    // But in the current code structure that causes issues with resource sharing and stream
+    // lifetime, and it may also add stream-sync cost in the single-stream case.
+    // In the first phase, just put all of the subgraph execution on the parent stream.
+    for (size_t i = 0; i < device_stream_collection.NumStreams(); ++i) {
+      auto* stream = device_stream_collection.GetStreams()[i];
+      if (stream) {
+        // If the current logic stream is not on the same device as the parent stream,
+        // and the EP instance does have async streams (unlike an EP such as CPU),
+        // throw an error: we don't have the code to set up that dependency yet.
+        if (stream->device != parent_stream->device) {
+          ORT_THROW("Subgraph has nodes running on device: ", stream->device.Type(),
+                    " while parent graph node running on device: ", parent_stream->device.Type(),
+                    ", this is not supported yet.");
+        }
+        device_stream_collection.SetDeviceStream(i, parent_stream);
+      }
+    }
+  }
+}
+
+static common::Status
+ExecuteGraphImpl(const SessionState& session_state,
+                 const FeedsFetchesManager& feeds_fetches_manager,
+                 gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
+                 const InlinedHashMap<size_t, IExecutor::CustomAllocator>& fetch_allocators,
+                 ExecutionMode execution_mode, const bool* terminate_flag,
+                 const logging::Logger& logger, const bool only_execute_path_to_fetches = false,
+                 Stream* parent_stream = nullptr) {
   const auto& feeds_fetches_info = feeds_fetches_manager.GetFeedsFetchesInfo();
   const auto& device_copy_checks = feeds_fetches_manager.GetDeviceCopyChecks();
   auto* execution_plan = session_state.GetExecutionPlan();

-  DeviceStreamCollection device_stream_collection(execution_plan->execution_plan.size(), session_state);
+  DeviceStreamCollectionHolder device_stream_collection_holder(session_state);
+  DeviceStreamCollection& device_stream_collection = *device_stream_collection_holder.p_;
+  UpdateWithParentStream(device_stream_collection, parent_stream);

   bool is_subgraph = session_state.GetGraphViewer().ParentNode() != nullptr;
   // in the following two cases, we execute the workload on the main thread:
@@ -577,7 +617,6 @@ static common::Status ExecuteGraphImpl(const SessionState& session_state,
     single_thread_mode = true;
 #endif

-  ORT_ENFORCE(BindToDeviceStream(parent_stream, *execution_plan, device_stream_collection, session_state.GetStreamHandleRegistryInstance()).IsOK());
   // see if we can skip copies due to the types of execution providers available
   if (device_copy_checks.status == DeviceCopyCheck::NoCopy) {
     // no device copies are needed, so simply execute
@@ -666,8 +705,8 @@ static common::Status ExecuteGraphImpl(const SessionState& session_state,
       ORT_RETURN_IF_ERROR(CopyOutputsAcrossDevices(session_state, *p_fetches, fetches, fetch_copy_info, fetches_streams));
     }
   }
-
-  return Status::OK();
+  // clean up streams at the end of the run
+  return device_stream_collection.CleanUp();
 }

 common::Status ExecuteGraph(const SessionState& session_state,
@@ -703,9 +742,9 @@ common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetch
   bool single_thread_mode = true;

   auto* execution_plan = session_state.GetExecutionPlan();
-  DeviceStreamCollection& device_stream_collection = state.GetDeviceStreamCollection(execution_plan->execution_plan.size(), session_state);
+  DeviceStreamCollection& device_stream_collection = state.GetDeviceStreamCollection(session_state);
+  UpdateWithParentStream(device_stream_collection, parent_stream);

-  ORT_ENFORCE(BindToDeviceStream(parent_stream, *execution_plan, device_stream_collection, session_state.GetStreamHandleRegistryInstance()).IsOK());
   // see if we can skip copies due to the types of execution providers available
   if (device_copy_checks.status == DeviceCopyCheck::NoCopy) {
     // no device copies are needed, so simply execute
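`UpdateWithParentStream` sidesteps cross-stream dependency tracking by putting subgraph work directly on the parent's stream; within one stream, FIFO ordering does the synchronization for free. A tiny CUDA-level sketch of why that is sufficient (illustrative only, not ORT code):

```cpp
#include <cuda_runtime.h>

__global__ void SubgraphNodeKernel() {}  // stand-in for a subgraph node

// Launching subgraph work on the parent's stream serializes it behind the
// parent graph's kernels automatically; separate streams would need explicit
// cudaEventRecord/cudaStreamWaitEvent pairs to get the same ordering.
void RunSubgraphOnParentStream(cudaStream_t parent_stream) {
  SubgraphNodeKernel<<<1, 1, 0, parent_stream>>>();
}
```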