diff --git a/runtime/src/iree/hal/drivers/cuda2/cuda_device.c b/runtime/src/iree/hal/drivers/cuda2/cuda_device.c index ac519a2f3104..1d0a74a481b7 100644 --- a/runtime/src/iree/hal/drivers/cuda2/cuda_device.c +++ b/runtime/src/iree/hal/drivers/cuda2/cuda_device.c @@ -742,6 +742,11 @@ static iree_status_t iree_hal_cuda2_device_queue_write( return loop_status; } +static void iree_hal_cuda2_device_collect_tracing_context(void* user_data) { + iree_hal_cuda2_tracing_context_collect( + (iree_hal_cuda2_tracing_context_t*)user_data); +} + static iree_status_t iree_hal_cuda2_device_queue_execute( iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, const iree_hal_semaphore_list_t wait_semaphore_list, @@ -753,15 +758,16 @@ static iree_status_t iree_hal_cuda2_device_queue_execute( iree_status_t status = iree_hal_cuda2_pending_queue_actions_enqueue_execution( base_device, device->dispatch_cu_stream, device->callback_cu_stream, - device->pending_queue_actions, wait_semaphore_list, signal_semaphore_list, - command_buffer_count, command_buffers); + device->pending_queue_actions, + iree_hal_cuda2_device_collect_tracing_context, device->tracing_context, + wait_semaphore_list, signal_semaphore_list, command_buffer_count, + command_buffers); if (iree_status_is_ok(status)) { // Try to advance the pending workload queue. status = iree_hal_cuda2_pending_queue_actions_issue( device->pending_queue_actions); } - iree_hal_cuda2_tracing_context_collect(device->tracing_context); IREE_TRACE_ZONE_END(z0); return status; } diff --git a/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.c b/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.c index 4886ebc757db..ab76728f2be3 100644 --- a/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.c +++ b/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.c @@ -49,6 +49,12 @@ typedef struct iree_hal_cuda2_queue_action_t { // Retained to make sure it outlives the current action. iree_hal_cuda2_pending_queue_actions_t* owning_actions; + // The callback to run after completing this action and before freeing + // all resources. + iree_hal_cuda2_pending_action_cleanup_callback_t cleanup_callback; + // User data to pass into the callback. + void* callback_user_data; + iree_hal_cuda2_queue_action_kind_t kind; union { struct { @@ -403,6 +409,8 @@ static void iree_hal_cuda2_free_semaphore_list( iree_status_t iree_hal_cuda2_pending_queue_actions_enqueue_execution( iree_hal_device_t* device, CUstream dispatch_stream, CUstream callback_stream, iree_hal_cuda2_pending_queue_actions_t* actions, + iree_hal_cuda2_pending_action_cleanup_callback_t cleanup_callback, + void* callback_user_data, const iree_hal_semaphore_list_t wait_semaphore_list, const iree_hal_semaphore_list_t signal_semaphore_list, iree_host_size_t command_buffer_count, @@ -417,6 +425,8 @@ iree_status_t iree_hal_cuda2_pending_queue_actions_enqueue_execution( (void**)&action)); action->kind = IREE_HAL_CUDA2_QUEUE_ACTION_TYPE_EXECUTION; + action->cleanup_callback = cleanup_callback; + action->callback_user_data = callback_user_data; action->device = device; action->dispatch_cu_stream = dispatch_stream; action->callback_cu_stream = callback_stream; @@ -604,6 +614,8 @@ static void iree_hal_cuda2_pending_queue_actions_cleanup_execution( iree_allocator_t host_allocator = actions->host_allocator; IREE_TRACE_ZONE_BEGIN(z0); + action->cleanup_callback(action->callback_user_data); + iree_hal_resource_set_free(action->resource_set); iree_hal_cuda2_free_semaphore_list(host_allocator, &action->wait_semaphore_list); diff --git a/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.h b/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.h index 1484c2bda8ff..574d4c39a6ad 100644 --- a/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.h +++ b/runtime/src/iree/hal/drivers/cuda2/pending_queue_actions.h @@ -45,11 +45,20 @@ iree_status_t iree_hal_cuda2_pending_queue_actions_create( // Destroys the pending |actions| queue. void iree_hal_cuda2_pending_queue_actions_destroy(iree_hal_resource_t* actions); +// Callback to execute user code after action completion but before resource +// releasing. +// +// Data behind |user_data| must remain alive before the action is released. +typedef void(IREE_API_PTR* iree_hal_cuda2_pending_action_cleanup_callback_t)( + void* user_data); + // Enqueues the given list of |command_buffers| that waits on // |wait_semaphore_list| and signals |signal_semaphore_lsit|. iree_status_t iree_hal_cuda2_pending_queue_actions_enqueue_execution( iree_hal_device_t* device, CUstream dispatch_stream, CUstream callback_stream, iree_hal_cuda2_pending_queue_actions_t* actions, + iree_hal_cuda2_pending_action_cleanup_callback_t cleanup_callback, + void* callback_user_data, const iree_hal_semaphore_list_t wait_semaphore_list, const iree_hal_semaphore_list_t signal_semaphore_list, iree_host_size_t command_buffer_count,