@@ -80,10 +80,9 @@ struct ucp_perf_cuda_params {
8080 ucp_device_mem_list_handle_h mem_list;
8181 size_t length;
8282 unsigned *indices;
83- void **addresses ;
84- uint64_t *remote_addresses ;
83+ size_t *local_offsets ;
84+ size_t *remote_offsets ;
8585 size_t *lengths;
86- uint64_t counter_remote;
8786 uint64_t *counter_send;
8887 uint64_t *counter_recv;
8988 ucp_device_flags_t flags;
@@ -102,8 +101,8 @@ public:
102101 {
103102 ucp_device_mem_list_release (m_params.mem_list );
104103 CUDA_CALL_WARN (cudaFree, m_params.indices );
105- CUDA_CALL_WARN (cudaFree, m_params.addresses );
106- CUDA_CALL_WARN (cudaFree, m_params.remote_addresses );
104+ CUDA_CALL_WARN (cudaFree, m_params.local_offsets );
105+ CUDA_CALL_WARN (cudaFree, m_params.remote_offsets );
107106 CUDA_CALL_WARN (cudaFree, m_params.lengths );
108107 }
109108
@@ -113,13 +112,23 @@ private:
113112 void init_mem_list (const ucx_perf_context_t &perf)
114113 {
115114 /* +1 for the counter */
116- size_t count = perf.params .msg_size_cnt + 1 ;
115+ size_t count = perf.params .msg_size_cnt + 1 ;
116+ size_t offset = 0 ;
117117 ucp_device_mem_list_elem_t elems[count];
118+
118119 for (size_t i = 0 ; i < count; ++i) {
119120 elems[i].field_mask = UCP_DEVICE_MEM_LIST_ELEM_FIELD_MEMH |
120- UCP_DEVICE_MEM_LIST_ELEM_FIELD_RKEY;
121+ UCP_DEVICE_MEM_LIST_ELEM_FIELD_RKEY |
122+ UCP_DEVICE_MEM_LIST_ELEM_FIELD_LOCAL_ADDR |
123+ UCP_DEVICE_MEM_LIST_ELEM_FIELD_REMOTE_ADDR |
124+ UCP_DEVICE_MEM_LIST_ELEM_FIELD_LENGTH;
121125 elems[i].memh = perf.ucp .send_memh ;
122126 elems[i].rkey = perf.ucp .rkey ;
127+ elems[i].local_addr = UCS_PTR_BYTE_OFFSET (perf.send_buffer , offset);
128+ elems[i].remote_addr = perf.ucp .remote_addr + offset;
129+ elems[i].length = (i == count - 1 ) ? ONESIDED_SIGNAL_SIZE :
130+ perf.params .msg_size_list [i];
131+ offset += elems[i].length ;
123132 }
124133
125134 ucp_device_mem_list_params_t params;
@@ -140,33 +149,30 @@ private:
140149 void init_elements (const ucx_perf_context_t &perf)
141150 {
142151 /* +1 for the counter */
143- size_t count = perf.params .msg_size_cnt + 1 ;
152+ size_t count = perf.params .msg_size_cnt + 1 ;
153+ size_t offset = 0 ;
144154
145155 std::vector<unsigned > indices (count);
146- std::vector<void *> addresses (count);
147- std::vector<uint64_t > remote_addresses (count);
156+ std::vector<size_t > local_offsets (count, 0 );
157+ std::vector<size_t > remote_offsets (count, 0 );
148158 std::vector<size_t > lengths (count);
149- for (unsigned i = 0 , offset = 0 ; i < count; ++i) {
150- indices[i] = i;
151- addresses[i] = (char *)perf.send_buffer + offset;
152- remote_addresses[i] = perf.ucp .remote_addr + offset;
153- lengths[i] = (i == count - 1 ) ? ONESIDED_SIGNAL_SIZE :
154- perf.params .msg_size_list [i];
155- offset += lengths[i];
159+
160+ for (unsigned i = 0 ; i < count; ++i) {
161+ indices[i] = i;
162+ lengths[i] = (i == count - 1 ) ? ONESIDED_SIGNAL_SIZE :
163+ perf.params .msg_size_list [i];
164+ offset += lengths[i];
156165 }
157166
158167 device_clone (&m_params.indices , indices.data (), count);
159- device_clone (&m_params.addresses , addresses .data (), count);
160- device_clone (&m_params.remote_addresses , remote_addresses .data (), count);
168+ device_clone (&m_params.local_offsets , local_offsets .data (), count);
169+ device_clone (&m_params.remote_offsets , remote_offsets .data (), count);
161170 device_clone (&m_params.lengths , lengths.data (), count);
162171 }
163172
164173 void init_counters (const ucx_perf_context_t &perf)
165174 {
166175 m_params.length = ucx_perf_get_message_size (&perf.params );
167- m_params.counter_remote = (uint64_t )ucx_perf_cuda_get_sn (
168- (void *)perf.ucp .remote_addr ,
169- m_params.length );
170176 m_params.counter_send = ucx_perf_cuda_get_sn (perf.send_buffer ,
171177 m_params.length );
172178 m_params.counter_recv = ucx_perf_cuda_get_sn (perf.recv_buffer ,
@@ -195,28 +201,20 @@ ucp_perf_cuda_send_nbx(ucp_perf_cuda_params ¶ms, ucx_perf_counter_t idx,
195201 /* TODO: Change to ucp_device_counter_write */
196202 *params.counter_send = idx + 1 ;
197203 return ucp_device_put_single<level>(params.mem_list , params.indices [0 ],
198- params. addresses [ 0 ] ,
199- params.remote_addresses [ 0 ],
200- params. length + ONESIDED_SIGNAL_SIZE,
201- params.flags , &req);
204+ 0 , 0 ,
205+ params.length +
206+ ONESIDED_SIGNAL_SIZE,
207+ 0 , params.flags , &req);
202208 case UCX_PERF_CMD_PUT_MULTI:
203- return ucp_device_put_multi<level>(params.mem_list , params.addresses ,
204- params.remote_addresses ,
205- params.lengths , 1 ,
206- params.counter_remote , params.flags ,
209+ return ucp_device_put_multi<level>(params.mem_list , 1 , 0 , params.flags ,
207210 &req);
208- case UCX_PERF_CMD_PUT_PARTIAL:{
211+ case UCX_PERF_CMD_PUT_PARTIAL: {
209212 unsigned counter_index = params.mem_list ->mem_list_length - 1 ;
210- return ucp_device_put_multi_partial<level>(params.mem_list ,
211- params.indices ,
212- counter_index,
213- params.addresses ,
214- params.remote_addresses ,
215- params.lengths ,
216- counter_index, 1 ,
217- params.counter_remote ,
218- params.flags , &req);
219- }
213+ return ucp_device_put_multi_partial<level>(
214+ params.mem_list , params.indices , counter_index,
215+ params.local_offsets , params.remote_offsets , params.lengths ,
216+ counter_index, 1 , 0 , 0 , params.flags , &req);
217+ }
220218 }
221219
222220 return UCS_ERR_INVALID_PARAM;
0 commit comments