Skip to content

Commit

Permalink
[cuda] Add some comment and format
Browse files Browse the repository at this point in the history
  • Loading branch information
antiagainst committed Nov 7, 2023
1 parent 127e101 commit 82e99f6
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
7 changes: 6 additions & 1 deletion experimental/cuda2/graph_command_buffer.c
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ static iree_status_t iree_hal_cuda2_graph_command_buffer_push_descriptor_set(
iree_hal_buffer_allocated_buffer(binding->buffer));
iree_device_size_t offset = iree_hal_buffer_byte_offset(binding->buffer);
device_ptr = device_buffer + offset + binding->offset;
};
}
current_bindings[binding->binding] = device_ptr;
}

Expand Down Expand Up @@ -665,6 +665,7 @@ static iree_status_t iree_hal_cuda2_graph_command_buffer_dispatch(
iree_host_size_t set_count =
iree_hal_cuda2_pipeline_layout_descriptor_set_count(kernel_info.layout);
for (iree_host_size_t i = 0; i < set_count; ++i) {
// TODO: cache this information in the kernel info to avoid recomputation.
iree_host_size_t binding_count =
iree_hal_cuda2_descriptor_set_layout_binding_count(
iree_hal_cuda2_pipeline_layout_descriptor_set_layout(
Expand All @@ -678,6 +679,10 @@ static iree_status_t iree_hal_cuda2_graph_command_buffer_dispatch(
// Append the push constants to the kernel arguments.
iree_host_size_t base_index =
iree_hal_cuda2_pipeline_layout_push_constant_index(kernel_info.layout);
// As commented in the above, what each kernel parameter points to is a
// CUdeviceptr, which as the size of a pointer on the target machine. we are
// just storing a 32-bit value for the push constant here instead. So we must
// process one element each type, for 64-bit machines.
for (iree_host_size_t i = 0; i < push_constant_count; i++) {
*((uint32_t*)params_ptr[base_index + i]) =
command_buffer->push_constants[i];
Expand Down
7 changes: 6 additions & 1 deletion experimental/cuda2/stream_command_buffer.c
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ static iree_status_t iree_hal_cuda2_stream_command_buffer_push_descriptor_set(
iree_hal_buffer_allocated_buffer(binding->buffer));
iree_device_size_t offset = iree_hal_buffer_byte_offset(binding->buffer);
device_ptr = device_buffer + offset + binding->offset;
};
}
current_bindings[binding->binding] = device_ptr;
}

Expand Down Expand Up @@ -571,6 +571,7 @@ static iree_status_t iree_hal_cuda2_stream_command_buffer_dispatch(
iree_host_size_t set_count =
iree_hal_cuda2_pipeline_layout_descriptor_set_count(kernel_info.layout);
for (iree_host_size_t i = 0; i < set_count; ++i) {
// TODO: cache this information in the kernel info to avoid recomputation.
iree_host_size_t binding_count =
iree_hal_cuda2_descriptor_set_layout_binding_count(
iree_hal_cuda2_pipeline_layout_descriptor_set_layout(
Expand All @@ -584,6 +585,10 @@ static iree_status_t iree_hal_cuda2_stream_command_buffer_dispatch(
// Append the push constants to the kernel arguments.
iree_host_size_t base_index =
iree_hal_cuda2_pipeline_layout_push_constant_index(kernel_info.layout);
// As commented in the above, what each kernel parameter points to is a
// CUdeviceptr, which as the size of a pointer on the target machine. we are
// just storing a 32-bit value for the push constant here instead. So we must
// process one element each type, for 64-bit machines.
for (iree_host_size_t i = 0; i < push_constant_count; i++) {
*((uint32_t*)params_ptr[base_index + i]) =
command_buffer->push_constants[i];
Expand Down

0 comments on commit 82e99f6

Please sign in to comment.