Skip to content

Commit 5d40ef7

Browse files
authored
UCP/DEVICE: add method to write to local counter. (#10895)
1 parent 29831d3 commit 5d40ef7

File tree

6 files changed

+104
-5
lines changed

6 files changed

+104
-5
lines changed

src/ucp/api/device/ucp_device_impl.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,28 @@ UCS_F_DEVICE uint64_t ucp_device_counter_read(const void *counter_ptr)
364364
}
365365

366366

367+
/**
368+
* @ingroup UCP_DEVICE
369+
* @brief Write value to the counter memory area.
370+
*
371+
* This function can be used to set counter to a specific value.
372+
*
373+
* The counter memory area must be initialized with the host function
374+
* @ref ucp_device_counter_init.
375+
*
376+
* @tparam level Level of cooperation of the transfer.
377+
* @param [in] counter_ptr Counter memory area.
378+
* @param [in] value Value to write.
379+
*
380+
*/
381+
template<ucs_device_level_t level = UCS_DEVICE_LEVEL_THREAD>
382+
UCS_F_DEVICE void ucp_device_counter_write(void *counter_ptr, uint64_t value)
383+
{
384+
return ucs_device_atomic64_write(
385+
reinterpret_cast<uint64_t*>(counter_ptr), value);
386+
}
387+
388+
367389
/**
368390
* @ingroup UCP_DEVICE
369391
* @brief Progress a device request containing a batch of operations.

src/ucp/core/ucp_device.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,8 @@ static ucs_status_t ucp_device_mem_list_create_handle(
271271
}
272272

273273
if (i == 0) {
274-
ucs_error("failed to select lane");
274+
ucs_error("failed to select lane for local device %s",
275+
ucs_topo_sys_device_get_name(local_sys_dev));
275276
return UCS_ERR_NO_RESOURCE;
276277
}
277278

src/ucs/sys/device_code.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,22 @@ UCS_F_DEVICE uint64_t ucs_device_atomic64_read(const uint64_t *ptr)
4747
}
4848

4949

50+
/*
51+
* Write a 64-bit value to counter global memory address.
52+
*/
53+
UCS_F_DEVICE void ucs_device_atomic64_write(uint64_t *ptr, uint64_t value)
54+
{
55+
#ifdef __NVCC__
56+
asm volatile("st.release.sys.u64 [%0], %1;"
57+
:
58+
: "l"(ptr), "l"(value)
59+
: "memory");
60+
#else
61+
*ptr = value;
62+
#endif
63+
}
64+
65+
5066
/* Helper macro to print a message from a device function including the
5167
* thread and block indices */
5268
#define ucs_device_printf(_title, _fmt, ...) \

test/gtest/ucp/cuda/test_kernels.cu

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212

1313
namespace ucx_cuda {
1414

15-
static __global__ void memcmp_kernel(const void* s1, const void* s2,
16-
int* result, size_t size)
15+
static __global__ void
16+
memcmp_kernel(const void *s1, const void *s2, int *result, size_t size)
1717
{
1818
unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
1919

2020
*result = 0;
2121
for (size_t i = idx; i < size; i += blockDim.x * gridDim.x) {
22-
if (reinterpret_cast<const uint8_t*>(s1)[i]
23-
!= reinterpret_cast<const uint8_t*>(s2)[i]) {
22+
if (reinterpret_cast<const uint8_t*>(s1)[i] !=
23+
reinterpret_cast<const uint8_t*>(s2)[i]) {
2424
*result = 1;
2525
break;
2626
}
@@ -116,6 +116,20 @@ ucp_counter_inc_kernel(const kernel_params params, ucs_status_t *status)
116116
*status = ucp_device_wait_req(&req);
117117
}
118118

119+
static __global__ void
120+
ucp_counter_write_kernel(const kernel_params params, ucs_status_t *status)
121+
{
122+
ucp_device_counter_write(params.counter.address, params.counter.value);
123+
*status = UCS_OK;
124+
}
125+
126+
static __global__ void
127+
ucp_counter_read_kernel(const kernel_params params, ucs_status_t *status)
128+
{
129+
uint64_t value = ucp_device_counter_read(params.counter.address);
130+
*status = (value == params.counter.value ? UCS_OK : UCS_ERR_IO_ERROR);
131+
}
132+
119133
/**
120134
* @brief Compares two blocks of device memory.
121135
*
@@ -183,4 +197,18 @@ ucs_status_t launch_ucp_counter_inc(const kernel_params &params)
183197
return status.sync_read();
184198
}
185199

200+
ucs_status_t launch_ucp_counter_write(const kernel_params &params)
201+
{
202+
device_status_result_ptr status;
203+
ucp_counter_write_kernel<<<1, 1>>>(params, status.device_ptr());
204+
return status.sync_read();
205+
}
206+
207+
ucs_status_t launch_ucp_counter_read(const kernel_params &params)
208+
{
209+
device_status_result_ptr status;
210+
ucp_counter_read_kernel<<<1, 1>>>(params, status.device_ptr());
211+
return status.sync_read();
212+
}
213+
186214
} // namespace ucx_cuda

test/gtest/ucp/cuda/test_kernels.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ struct kernel_params {
2222
} single;
2323
struct {
2424
unsigned mem_list_index;
25+
void* address;
26+
uint64_t value;
2527
uint64_t inc_value;
2628
uint64_t remote_address;
2729
} counter;
@@ -56,6 +58,10 @@ ucs_status_t launch_ucp_put_multi_partial(const kernel_params &params);
5658

5759
ucs_status_t launch_ucp_counter_inc(const kernel_params &params);
5860

61+
ucs_status_t launch_ucp_counter_write(const kernel_params &params);
62+
63+
ucs_status_t launch_ucp_counter_read(const kernel_params &params);
64+
5965
}; // namespace ucx_cuda
6066

6167
#endif

test/gtest/ucp/test_ucp_device.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,32 @@ UCS_TEST_P(test_ucp_device, counter)
350350
EXPECT_EQ(1, list.dst_counter_read(mem_list_index));
351351
}
352352

353+
UCS_TEST_P(test_ucp_device, local_counter)
354+
{
355+
const size_t size = counter_size();
356+
mem_list list(sender(), receiver(), size, 1);
357+
358+
static constexpr unsigned mem_list_index = 0;
359+
list.dst_counter_init(mem_list_index);
360+
361+
// Perform the write
362+
ucx_cuda::kernel_params params;
363+
params.counter.address = list.src_ptr(mem_list_index);
364+
params.counter.value = 1;
365+
ucs_status_t status = ucx_cuda::launch_ucp_counter_write(params);
366+
ASSERT_EQ(UCS_OK, status);
367+
368+
// Check counter value
369+
EXPECT_EQ(true, mem_buffer::compare(&params.counter.value,
370+
list.src_ptr(mem_list_index),
371+
sizeof(params.counter.value),
372+
UCS_MEMORY_TYPE_CUDA));
373+
374+
// Check counter value using device API
375+
status = ucx_cuda::launch_ucp_counter_read(params);
376+
ASSERT_EQ(UCS_OK, status);
377+
}
378+
353379
UCS_TEST_P(test_ucp_device, create_fail)
354380
{
355381
ucp_device_mem_list_handle_h handle = nullptr;

0 commit comments

Comments
 (0)