Skip to content

Commit

Permalink
Mirror change to CUDA driver
Browse files Browse the repository at this point in the history
  • Loading branch information
sogartar committed Apr 12, 2024
1 parent 7f6e004 commit e99cbd2
Show file tree
Hide file tree
Showing 9 changed files with 266 additions and 63 deletions.
7 changes: 4 additions & 3 deletions runtime/src/iree/hal/drivers/cuda/cuda_device.c
Original file line number Diff line number Diff line change
Expand Up @@ -543,9 +543,10 @@ static iree_status_t iree_hal_cuda_device_create_command_buffer(
switch (device->params.command_buffer_mode) {
case IREE_HAL_CUDA_COMMAND_BUFFER_MODE_GRAPH:
return iree_hal_cuda_graph_command_buffer_create(
base_device, device->cuda_symbols, device->cu_context, mode,
command_categories, queue_affinity, binding_capacity,
&device->block_pool, device->host_allocator, out_command_buffer);
base_device, device->cuda_symbols, device->tracing_context,
device->cu_context, mode, command_categories, queue_affinity,
binding_capacity, &device->block_pool, device->host_allocator,
out_command_buffer);
case IREE_HAL_CUDA_COMMAND_BUFFER_MODE_STREAM:
return iree_hal_deferred_command_buffer_create(
base_device, mode, command_categories, binding_capacity,
Expand Down
2 changes: 2 additions & 0 deletions runtime/src/iree/hal/drivers/cuda/cuda_dynamic_symbol_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ IREE_CU_PFN_DECL(cuEventSynchronize, CUevent)
IREE_CU_PFN_DECL(cuGetProcAddress, const char*, void**, int, cuuint64_t)
IREE_CU_PFN_DECL(cuGraphAddEmptyNode, CUgraphNode*, CUgraph, const CUgraphNode*,
size_t)
IREE_CU_PFN_DECL(cuGraphAddEventRecordNode, CUgraphNode*, CUgraph,
const CUgraphNode*, size_t, CUevent)
IREE_CU_PFN_DECL(cuGraphAddMemcpyNode, CUgraphNode*, CUgraph,
const CUgraphNode*, size_t, const CUDA_MEMCPY3D*, CUcontext)
IREE_CU_PFN_DECL(cuGraphAddMemsetNode, CUgraphNode*, CUgraph,
Expand Down
159 changes: 140 additions & 19 deletions runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "iree/hal/drivers/cuda/cuda_status_util.h"
#include "iree/hal/drivers/cuda/native_executable.h"
#include "iree/hal/drivers/cuda/pipeline_layout.h"
#include "iree/hal/drivers/cuda/tracing.h"
#include "iree/hal/utils/collective_batch.h"
#include "iree/hal/utils/resource_set.h"

Expand All @@ -30,6 +31,9 @@ typedef struct iree_hal_cuda_graph_command_buffer_t {
iree_allocator_t host_allocator;
const iree_hal_cuda_dynamic_symbols_t* symbols;

// Per-stream CUDA tracing context.
iree_hal_cuda_tracing_context_t* tracing_context;

// A resource set to maintain references to all resources used within the
// command buffer.
iree_hal_resource_set_t* resource_set;
Expand Down Expand Up @@ -65,15 +69,95 @@ typedef struct iree_hal_cuda_graph_command_buffer_t {
static const iree_hal_command_buffer_vtable_t
iree_hal_cuda_graph_command_buffer_vtable;

static iree_status_t
iree_hal_cuda_graph_command_buffer_execution_barrier_internal(
iree_hal_cuda_graph_command_buffer_t* command_buffer);

// Downcasts a generic HAL command buffer handle to the CUDA graph command
// buffer type. IREE_HAL_ASSERT_TYPE is given this implementation's vtable to
// check |base_value| against before the pointer is reinterpreted.
static iree_hal_cuda_graph_command_buffer_t*
iree_hal_cuda_graph_command_buffer_cast(iree_hal_command_buffer_t* base_value) {
  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda_graph_command_buffer_vtable);
  iree_hal_cuda_graph_command_buffer_t* graph_command_buffer =
      (iree_hal_cuda_graph_command_buffer_t*)base_value;
  return graph_command_buffer;
}

#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION_DEVICE

// Records the beginning of a device tracing zone as a node in the CUDA graph
// that |command_buffer| is building.
//
// If any nodes have been recorded since the last barrier, a barrier is inserted
// first so that the tracing node is ordered after them. The tracing node then
// becomes the current barrier node so that subsequently recorded work depends
// on it.
//
// |file_name|, |line|, |function_name| and |name| (with their lengths) are
// forwarded unchanged to IREE_CUDA_GRAPH_TRACE_ZONE_BEGIN_EXTERNAL.
static void iree_cuda_graph_command_buffer_trace_zone_begin_external(
    iree_hal_cuda_graph_command_buffer_t* command_buffer, const char* file_name,
    size_t file_name_length, uint32_t line, const char* function_name,
    size_t function_name_length, const char* name, size_t name_length) {
  // Make sure there are no new nodes after the last barrier.
  // Work should start after the event.
  if (IREE_UNLIKELY(command_buffer->graph_node_count != 0)) {
    // This function returns void and cannot propagate a failure. Consume the
    // status explicitly instead of dropping it on the floor so that a failed
    // barrier does not leave an unreleased status behind; tracing stays
    // best-effort.
    IREE_IGNORE_ERROR(
        iree_hal_cuda_graph_command_buffer_execution_barrier_internal(
            command_buffer));
  }

  // Reserve the next slot in the pending node list for the tracing node.
  CUgraphNode* tracing_event_node =
      &command_buffer->cu_graph_nodes[command_buffer->graph_node_count++];
  // Depend on the current barrier node only if one exists.
  size_t dependency_count = command_buffer->cu_barrier_node ? 1 : 0;
  IREE_CUDA_GRAPH_TRACE_ZONE_BEGIN_EXTERNAL(
      command_buffer->tracing_context, tracing_event_node,
      command_buffer->cu_graph, &command_buffer->cu_barrier_node,
      dependency_count, file_name, file_name_length, line, function_name,
      function_name_length, name, name_length);

  // Move the barrier forward to make sure that the tracing event is recorded
  // before work starts.
  // Downstream operations will wait on the tracing node.
  command_buffer->cu_barrier_node = *tracing_event_node;
}

// Records the end of a device tracing zone as a node in the CUDA graph that
// |command_buffer| is building.
//
// If any nodes have been recorded since the last barrier, a barrier is inserted
// first so that the tracing node is ordered after them. The tracing node then
// becomes the current barrier node so that subsequently recorded work depends
// on it.
static void iree_cuda_graph_command_buffer_trace_zone_end(
    iree_hal_cuda_graph_command_buffer_t* command_buffer) {
  // Make sure there are no new nodes after the last barrier.
  // Prior work should end before the tracing event is recorded.
  if (IREE_UNLIKELY(command_buffer->graph_node_count != 0)) {
    // This function returns void and cannot propagate a failure. Consume the
    // status explicitly instead of dropping it on the floor so that a failed
    // barrier does not leave an unreleased status behind; tracing stays
    // best-effort.
    IREE_IGNORE_ERROR(
        iree_hal_cuda_graph_command_buffer_execution_barrier_internal(
            command_buffer));
  }

  // Reserve the next slot in the pending node list for the tracing node.
  CUgraphNode* tracing_event_node =
      &command_buffer->cu_graph_nodes[command_buffer->graph_node_count++];
  size_t dependency_count = command_buffer->cu_barrier_node ? 1 : 0;
  IREE_ASSERT_GT(dependency_count, 0,
                 "ending a zone should at least depend on the beginning");
  IREE_CUDA_GRAPH_TRACE_ZONE_END(command_buffer->tracing_context,
                                 tracing_event_node, command_buffer->cu_graph,
                                 &command_buffer->cu_barrier_node,
                                 dependency_count);

  // We need to wait on the tracing end before other work starts.
  // GPU tracing zones are first-in, last-out.
  command_buffer->cu_barrier_node = *tracing_event_node;
}

// Begins a device tracing zone with a caller-provided source location and
// name. Forwards every argument unchanged to
// iree_cuda_graph_command_buffer_trace_zone_begin_external.
#define IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_BEGIN_EXTERNAL( \
command_buffer, file_name, file_name_length, line, function_name, \
function_name_length, name, name_length) \
iree_cuda_graph_command_buffer_trace_zone_begin_external( \
command_buffer, file_name, file_name_length, line, function_name, \
function_name_length, name, name_length)
// Begins a device tracing zone identified only by the enclosing function's
// name (__FUNCTION__); no file name, line, or custom zone name is supplied.
#define IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_BEGIN(command_buffer) \
IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_BEGIN_EXTERNAL( \
command_buffer, /*file_name=*/NULL, 0, /*line=*/0, __FUNCTION__, \
strlen(__FUNCTION__), /*name=*/NULL, 0)
// Ends a device tracing zone by calling
// iree_cuda_graph_command_buffer_trace_zone_end.
#define IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_END(command_buffer) \
iree_cuda_graph_command_buffer_trace_zone_end(command_buffer)

#else // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION_DEVICE

// Device instrumentation is disabled: the tracing macros below expand to
// nothing, so their arguments are never evaluated.
#define IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_BEGIN_EXTERNAL( \
command_buffer, file_name, file_name_length, line, function_name, \
function_name_length, name, name_length)
#define IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_BEGIN(command_buffer)
#define IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_END(command_buffer)

#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION_DEVICE

iree_status_t iree_hal_cuda_graph_command_buffer_create(
iree_hal_device_t* device,
const iree_hal_cuda_dynamic_symbols_t* cuda_symbols, CUcontext context,
const iree_hal_cuda_dynamic_symbols_t* cuda_symbols,
iree_hal_cuda_tracing_context_t* tracing_context, CUcontext context,
iree_hal_command_buffer_mode_t mode,
iree_hal_command_category_t command_categories,
iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity,
Expand Down Expand Up @@ -101,6 +185,7 @@ iree_status_t iree_hal_cuda_graph_command_buffer_create(
&iree_hal_cuda_graph_command_buffer_vtable, &command_buffer->base);
command_buffer->host_allocator = host_allocator;
command_buffer->symbols = cuda_symbols;
command_buffer->tracing_context = tracing_context;
iree_arena_initialize(block_pool, &command_buffer->arena);
command_buffer->cu_context = context;
command_buffer->cu_graph = NULL;
Expand Down Expand Up @@ -227,6 +312,8 @@ static iree_status_t iree_hal_cuda_graph_command_buffer_begin(
command_buffer->symbols,
cuGraphCreate(&command_buffer->cu_graph, /*flags=*/0), "cuGraphCreate");

IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_BEGIN(command_buffer);

return iree_ok_status();
}

Expand All @@ -239,6 +326,8 @@ static iree_status_t iree_hal_cuda_graph_command_buffer_end(
IREE_RETURN_IF_ERROR(
iree_hal_cuda_graph_command_buffer_flush_collectives(command_buffer));

IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_END(command_buffer);

// Reset state used during recording.
command_buffer->cu_barrier_node = NULL;
command_buffer->graph_node_count = 0;
Expand Down Expand Up @@ -267,27 +356,27 @@ static void iree_hal_cuda_graph_command_buffer_begin_debug_group(
iree_hal_command_buffer_t* base_command_buffer, iree_string_view_t label,
iree_hal_label_color_t label_color,
const iree_hal_label_location_t* location) {
// TODO(benvanik): tracy event stack.
iree_hal_cuda_graph_command_buffer_t* command_buffer =
iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);

(void)command_buffer;
IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_BEGIN_EXTERNAL(
command_buffer, location ? location->file.data : NULL,
location ? location->file.size : 0, location ? location->line : 0,
/*func_name=*/NULL, 0, label.data, label.size);
}

static void iree_hal_cuda_graph_command_buffer_end_debug_group(
iree_hal_command_buffer_t* base_command_buffer) {
// TODO(benvanik): tracy event stack.
}

static iree_status_t iree_hal_cuda_graph_command_buffer_execution_barrier(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_execution_stage_t source_stage_mask,
iree_hal_execution_stage_t target_stage_mask,
iree_hal_execution_barrier_flags_t flags,
iree_host_size_t memory_barrier_count,
const iree_hal_memory_barrier_t* memory_barriers,
iree_host_size_t buffer_barrier_count,
const iree_hal_buffer_barrier_t* buffer_barriers) {
iree_hal_cuda_graph_command_buffer_t* command_buffer =
iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
(void)command_buffer;
IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_END(command_buffer);
}

static iree_status_t
iree_hal_cuda_graph_command_buffer_execution_barrier_internal(
iree_hal_cuda_graph_command_buffer_t* command_buffer) {
IREE_RETURN_IF_ERROR(
iree_hal_cuda_graph_command_buffer_flush_collectives(command_buffer));

Expand All @@ -298,23 +387,42 @@ static iree_status_t iree_hal_cuda_graph_command_buffer_execution_barrier(
if (IREE_LIKELY(command_buffer->graph_node_count == 1)) {
command_buffer->cu_barrier_node = command_buffer->cu_graph_nodes[0];
command_buffer->graph_node_count = 0;
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}

IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
z0, command_buffer->symbols,
IREE_CUDA_RETURN_IF_ERROR(
command_buffer->symbols,
cuGraphAddEmptyNode(
&command_buffer->cu_barrier_node, command_buffer->cu_graph,
command_buffer->cu_graph_nodes, command_buffer->graph_node_count),
"cuGraphAddEmptyNode");

command_buffer->graph_node_count = 0;

IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}

// Records an execution barrier on the graph command buffer.
//
// Wraps iree_hal_cuda_graph_command_buffer_execution_barrier_internal in a
// host trace zone and returns its status. The stage masks, |flags| and the
// memory/buffer barrier lists are accepted but not consulted here.
static iree_status_t iree_hal_cuda_graph_command_buffer_execution_barrier(
    iree_hal_command_buffer_t* base_command_buffer,
    iree_hal_execution_stage_t source_stage_mask,
    iree_hal_execution_stage_t target_stage_mask,
    iree_hal_execution_barrier_flags_t flags,
    iree_host_size_t memory_barrier_count,
    const iree_hal_memory_barrier_t* memory_barriers,
    iree_host_size_t buffer_barrier_count,
    const iree_hal_buffer_barrier_t* buffer_barriers) {
  iree_hal_cuda_graph_command_buffer_t* graph_command_buffer =
      iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);

  IREE_TRACE_ZONE_BEGIN(z0);
  iree_status_t barrier_status =
      iree_hal_cuda_graph_command_buffer_execution_barrier_internal(
          graph_command_buffer);
  IREE_TRACE_ZONE_END(z0);

  return barrier_status;
}

static iree_status_t iree_hal_cuda_graph_command_buffer_signal_event(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
iree_hal_execution_stage_t source_stage_mask) {
Expand Down Expand Up @@ -376,6 +484,7 @@ static iree_status_t iree_hal_cuda_graph_command_buffer_fill_buffer(
iree_hal_cuda_graph_command_buffer_t* command_buffer =
iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_BEGIN(command_buffer);

IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_cuda_graph_command_buffer_flush_collectives(command_buffer));
Expand Down Expand Up @@ -412,6 +521,7 @@ static iree_status_t iree_hal_cuda_graph_command_buffer_fill_buffer(
dependency_count, &params, command_buffer->cu_context),
"cuGraphAddMemsetNode");

IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_END(command_buffer);
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
Expand All @@ -423,6 +533,7 @@ static iree_status_t iree_hal_cuda_graph_command_buffer_update_buffer(
iree_hal_cuda_graph_command_buffer_t* command_buffer =
iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_BEGIN(command_buffer);

IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_cuda_graph_command_buffer_flush_collectives(command_buffer));
Expand Down Expand Up @@ -471,6 +582,7 @@ static iree_status_t iree_hal_cuda_graph_command_buffer_update_buffer(
dependency_count, &params, command_buffer->cu_context),
"cuGraphAddMemcpyNode");

IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_END(command_buffer);
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
Expand All @@ -483,6 +595,7 @@ static iree_status_t iree_hal_cuda_graph_command_buffer_copy_buffer(
iree_hal_cuda_graph_command_buffer_t* command_buffer =
iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_BEGIN(command_buffer);

IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_cuda_graph_command_buffer_flush_collectives(command_buffer));
Expand Down Expand Up @@ -526,6 +639,7 @@ static iree_status_t iree_hal_cuda_graph_command_buffer_copy_buffer(
dependency_count, &params, command_buffer->cu_context),
"cuGraphAddMemcpyNode");

IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_END(command_buffer);
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
Expand Down Expand Up @@ -612,6 +726,12 @@ static iree_status_t iree_hal_cuda_graph_command_buffer_dispatch(
z0, iree_hal_cuda_native_executable_entry_point_kernel_info(
executable, entry_point, &kernel_info));

IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_BEGIN_EXTERNAL(
command_buffer, kernel_info.source_filename.data,
kernel_info.source_filename.size, kernel_info.source_line,
kernel_info.function_name.data, kernel_info.function_name.size,
/*name=*/NULL, 0);

IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1,
&executable));
Expand Down Expand Up @@ -709,6 +829,7 @@ static iree_status_t iree_hal_cuda_graph_command_buffer_dispatch(
dependency_count, &params),
"cuGraphAddKernelNode");

IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_END(command_buffer);
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
Expand Down
4 changes: 3 additions & 1 deletion runtime/src/iree/hal/drivers/cuda/graph_command_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ extern "C" {
#endif // __cplusplus

typedef struct iree_arena_block_pool_t iree_arena_block_pool_t;
typedef struct iree_hal_cuda_tracing_context_t iree_hal_cuda_tracing_context_t;

// Creates a command buffer that records into a CUDA graph.
//
Expand All @@ -25,7 +26,8 @@ typedef struct iree_arena_block_pool_t iree_arena_block_pool_t;
// buffers that use it.
iree_status_t iree_hal_cuda_graph_command_buffer_create(
iree_hal_device_t* device,
const iree_hal_cuda_dynamic_symbols_t* cuda_symbols, CUcontext context,
const iree_hal_cuda_dynamic_symbols_t* cuda_symbols,
iree_hal_cuda_tracing_context_t* tracing_context, CUcontext context,
iree_hal_command_buffer_mode_t mode,
iree_hal_command_category_t command_categories,
iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity,
Expand Down
4 changes: 2 additions & 2 deletions runtime/src/iree/hal/drivers/cuda/nccl_channel.c
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ iree_status_t iree_hal_cuda_nccl_submit_batch(
iree_hal_collective_batch_entry_t* entry = &batch->entries[i];
iree_string_view_t collective_str =
iree_hal_collective_op_format(&entry->op, &string_temp);
IREE_CUDA_TRACE_ZONE_BEGIN_EXTERNAL(
IREE_CUDA_STREAM_TRACE_ZONE_BEGIN_EXTERNAL(
tracing_context, stream, __FILE__, strlen(__FILE__), (uint32_t)__LINE__,
__FUNCTION__, strlen(__FUNCTION__), collective_str.data,
collective_str.size);
Expand All @@ -577,7 +577,7 @@ iree_status_t iree_hal_cuda_nccl_submit_batch(
// order doesn't matter so long as we end the right number of zones.
IREE_TRACE({
for (iree_host_size_t i = 0; i < batch->count; ++i) {
IREE_CUDA_TRACE_ZONE_END(tracing_context, stream);
IREE_CUDA_STREAM_TRACE_ZONE_END(tracing_context, stream);
}
});

Expand Down
5 changes: 5 additions & 0 deletions runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,10 @@ static iree_status_t iree_hal_cuda_pending_queue_actions_issue_execution(
}

// Then launch all command buffers to the dispatch stream.
IREE_TRACE_ZONE_BEGIN(dispatch_command_buffers);
IREE_TRACE_ZONE_APPEND_TEXT(dispatch_command_buffers,
" dispatch_command_buffers",
strlen(" dispatch_command_buffers"));
for (iree_host_size_t i = 0; i < action->payload.command_buffers.count; ++i) {
iree_hal_command_buffer_t* command_buffer =
action->payload.command_buffers.ptr[i];
Expand Down Expand Up @@ -622,6 +626,7 @@ static iree_status_t iree_hal_cuda_pending_queue_actions_issue_execution(
iree_hal_buffer_binding_table_empty()));
}
}
IREE_TRACE_ZONE_END(dispatch_command_buffers);

// Last record CUevent signals in the dispatch stream.
for (iree_host_size_t i = 0; i < action->signal_semaphore_list.count; ++i) {
Expand Down

0 comments on commit e99cbd2

Please sign in to comment.