|
12 | 12 |
|
13 | 13 | namespace ucx_cuda { |
14 | 14 |
|
15 | | -static __global__ void memcmp_kernel(const void* s1, const void* s2, |
16 | | - int* result, size_t size) |
| 15 | +static __global__ void |
| 16 | +memcmp_kernel(const void *s1, const void *s2, int *result, size_t size) |
17 | 17 | { |
18 | 18 | unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; |
19 | 19 |
|
20 | 20 | *result = 0; |
21 | 21 | for (size_t i = idx; i < size; i += blockDim.x * gridDim.x) { |
22 | | - if (reinterpret_cast<const uint8_t*>(s1)[i] |
23 | | - != reinterpret_cast<const uint8_t*>(s2)[i]) { |
| 22 | + if (reinterpret_cast<const uint8_t*>(s1)[i] != |
| 23 | + reinterpret_cast<const uint8_t*>(s2)[i]) { |
24 | 24 | *result = 1; |
25 | 25 | break; |
26 | 26 | } |
@@ -116,6 +116,20 @@ ucp_counter_inc_kernel(const kernel_params params, ucs_status_t *status) |
116 | 116 | *status = ucp_device_wait_req(&req); |
117 | 117 | } |
118 | 118 |
|
| 119 | +static __global__ void |
| 120 | +ucp_counter_write_kernel(const kernel_params params, ucs_status_t *status) |
| 121 | +{ |
| 122 | + ucp_device_counter_write(params.counter.address, params.counter.value); |
| 123 | + *status = UCS_OK; |
| 124 | +} |
| 125 | + |
| 126 | +static __global__ void |
| 127 | +ucp_counter_read_kernel(const kernel_params params, ucs_status_t *status) |
| 128 | +{ |
| 129 | + uint64_t value = ucp_device_counter_read(params.counter.address); |
| 130 | + *status = (value == params.counter.value ? UCS_OK : UCS_ERR_IO_ERROR); |
| 131 | +} |
| 132 | + |
119 | 133 | /** |
120 | 134 | * @brief Compares two blocks of device memory. |
121 | 135 | * |
@@ -183,4 +197,18 @@ ucs_status_t launch_ucp_counter_inc(const kernel_params ¶ms) |
183 | 197 | return status.sync_read(); |
184 | 198 | } |
185 | 199 |
|
| 200 | +ucs_status_t launch_ucp_counter_write(const kernel_params ¶ms) |
| 201 | +{ |
| 202 | + device_status_result_ptr status; |
| 203 | + ucp_counter_write_kernel<<<1, 1>>>(params, status.device_ptr()); |
| 204 | + return status.sync_read(); |
| 205 | +} |
| 206 | + |
| 207 | +ucs_status_t launch_ucp_counter_read(const kernel_params ¶ms) |
| 208 | +{ |
| 209 | + device_status_result_ptr status; |
| 210 | + ucp_counter_read_kernel<<<1, 1>>>(params, status.device_ptr()); |
| 211 | + return status.sync_read(); |
| 212 | +} |
| 213 | + |
186 | 214 | } // namespace ucx_cuda |
0 commit comments