Commit ae23727

Author: rhc54
Merge pull request #2313 from yosefe/topic/v1.10-ucx-fixes
Topic/v1.10 ucx fixes
2 parents 62ada06 + 3c9b324 · commit ae23727

File tree: 8 files changed, +317 −79 lines


ompi/mca/pml/ucx/pml_ucx.c

Lines changed: 146 additions & 58 deletions
@@ -9,11 +9,13 @@
 
 #include "pml_ucx.h"
 
+#include "opal/memoryhooks/memory.h"
 #include "opal/runtime/opal.h"
 #include "ompi/runtime/ompi_module_exchange.h"
 #include "ompi/message/message.h"
 #include "pml_ucx_request.h"
 
+#include <ucm/api/ucm.h>
 #include <inttypes.h>
 
 
@@ -67,8 +69,9 @@ mca_pml_ucx_module_t ompi_pml_ucx = {
         1ul << (PML_UCX_TAG_BITS - 1),
         1ul << (PML_UCX_CONTEXT_BITS),
     },
-    NULL,
-    NULL
+    0,    /* using_mem_hooks */
+    NULL, /* ucp_context */
+    NULL  /* ucp_worker */
 };
 
 static int mca_pml_ucx_send_worker_address(void)
@@ -110,24 +113,51 @@ static int mca_pml_ucx_recv_worker_address(ompi_proc_t *proc,
     return ret;
 }
 
+static void mca_pml_ucx_mem_release_cb(void *buf, size_t length, void *cbdata,
+                                       bool from_alloc)
+{
+    ucm_vm_munmap(buf, length);
+}
+
 int mca_pml_ucx_open(void)
 {
+    ucp_context_attr_t attr;
     ucp_params_t params;
     ucp_config_t *config;
     ucs_status_t status;
 
     PML_UCX_VERBOSE(1, "mca_pml_ucx_open");
 
+    /* Set memory hooks */
+    if ((OPAL_MEMORY_FREE_SUPPORT | OPAL_MEMORY_MUNMAP_SUPPORT) ==
+        ((OPAL_MEMORY_FREE_SUPPORT | OPAL_MEMORY_MUNMAP_SUPPORT) &
+         opal_mem_hooks_support_level()))
+    {
+        PML_UCX_VERBOSE(1, "using opal memory hooks");
+        ucm_set_external_event(UCM_EVENT_VM_UNMAPPED);
+        ompi_pml_ucx.using_mem_hooks = 1;
+    } else {
+        PML_UCX_VERBOSE(1, "not using opal memory hooks");
+        ompi_pml_ucx.using_mem_hooks = 0;
+    }
+
     /* Read options */
     status = ucp_config_read("MPI", NULL, &config);
     if (UCS_OK != status) {
        return OMPI_ERROR;
     }
 
+    /* Initialize UCX context */
+    params.field_mask      = UCP_PARAM_FIELD_FEATURES |
+                             UCP_PARAM_FIELD_REQUEST_SIZE |
+                             UCP_PARAM_FIELD_REQUEST_INIT |
+                             UCP_PARAM_FIELD_REQUEST_CLEANUP |
+                             UCP_PARAM_FIELD_TAG_SENDER_MASK;
     params.features        = UCP_FEATURE_TAG;
     params.request_size    = sizeof(ompi_request_t);
     params.request_init    = mca_pml_ucx_request_init;
     params.request_cleanup = mca_pml_ucx_request_cleanup;
+    params.tag_sender_mask = PML_UCX_SPECIFIC_SOURCE_MASK;
 
     status = ucp_init(&params, config, &ompi_pml_ucx.ucp_context);
     ucp_config_release(config);
@@ -136,6 +166,17 @@ int mca_pml_ucx_open(void)
         return OMPI_ERROR;
     }
 
+    /* Query UCX attributes */
+    attr.field_mask = UCP_ATTR_FIELD_REQUEST_SIZE;
+    status = ucp_context_query(ompi_pml_ucx.ucp_context, &attr);
+    if (UCS_OK != status) {
+        ucp_cleanup(ompi_pml_ucx.ucp_context);
+        ompi_pml_ucx.ucp_context = NULL;
+        return OMPI_ERROR;
+    }
+
+    ompi_pml_ucx.request_size = attr.request_size;
+
     return OMPI_SUCCESS;
 }
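
The request-size handshake above is worth seeing in isolation: ucp_init() reserves sizeof(ompi_request_t) bytes of caller space in every UCX request, and only ucp_context_query() can then report the total request size, which the new blocking-receive path further down depends on. A minimal standalone sketch of that pattern, assuming a UCX of this vintage; the 64-byte reservation and variable names are illustrative, not from this commit:

#include <ucp/api/ucp.h>
#include <stdio.h>
#include <stdlib.h>

int main(void)
{
    ucp_context_attr_t attr;
    ucp_params_t params;
    ucp_config_t *config;
    ucp_context_h context;
    ucs_status_t status;

    if (ucp_config_read(NULL, NULL, &config) != UCS_OK) {
        return EXIT_FAILURE;
    }

    params.field_mask   = UCP_PARAM_FIELD_FEATURES |
                          UCP_PARAM_FIELD_REQUEST_SIZE;
    params.features     = UCP_FEATURE_TAG;
    params.request_size = 64;    /* caller-private bytes inside each request */

    status = ucp_init(&params, config, &context);
    ucp_config_release(config);
    if (status != UCS_OK) {
        return EXIT_FAILURE;
    }

    /* The total size (UCX internals + reservation) is only known after
     * init; it is what the "..._nbr" calls need the caller to allocate. */
    attr.field_mask = UCP_ATTR_FIELD_REQUEST_SIZE;
    if (ucp_context_query(context, &attr) == UCS_OK) {
        printf("UCX request size: %zu bytes\n", attr.request_size);
    }

    ucp_cleanup(context);
    return EXIT_SUCCESS;
}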

@@ -157,9 +198,13 @@ int mca_pml_ucx_init(void)
 
     PML_UCX_VERBOSE(1, "mca_pml_ucx_init");
 
+    if (ompi_pml_ucx.using_mem_hooks) {
+        opal_mem_hooks_register_release(mca_pml_ucx_mem_release_cb, NULL);
+    }
+
     /* TODO check MPI thread mode */
     status = ucp_worker_create(ompi_pml_ucx.ucp_context, UCS_THREAD_MODE_SINGLE,
-                              &ompi_pml_ucx.ucp_worker);
+                               &ompi_pml_ucx.ucp_worker);
     if (UCS_OK != status) {
         return OMPI_ERROR;
     }
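
Registration here pairs with the support-level probe in mca_pml_ucx_open() above: only when opal can intercept both free() and munmap() does the PML tell UCX to expect external unmap events instead of installing its own hooks. A sketch of that handshake, compilable only inside the Open MPI tree; setup_mem_hooks and my_release_cb are illustrative names, not part of this commit:

#include "opal/memoryhooks/memory.h"
#include <ucm/api/ucm.h>
#include <stdbool.h>

static void my_release_cb(void *buf, size_t length, void *cbdata,
                          bool from_alloc)
{
    /* Forward the unmap to UCX so it can flush registration-cache
     * entries that cover this range. */
    ucm_vm_munmap(buf, length);
}

static int setup_mem_hooks(void)
{
    int need = OPAL_MEMORY_FREE_SUPPORT | OPAL_MEMORY_MUNMAP_SUPPORT;

    if ((opal_mem_hooks_support_level() & need) != need) {
        return 0;                    /* let UCX keep its internal hooks */
    }

    ucm_set_external_event(UCM_EVENT_VM_UNMAPPED);  /* opal delivers events */
    opal_mem_hooks_register_release(my_release_cb, NULL);
    return 1;
}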
@@ -203,13 +248,18 @@ int mca_pml_ucx_cleanup(void)
         ompi_pml_ucx.ucp_worker = NULL;
     }
 
+    if (ompi_pml_ucx.using_mem_hooks) {
+        opal_mem_hooks_unregister_release(mca_pml_ucx_mem_release_cb);
+    }
+
     return OMPI_SUCCESS;
 }
 
 int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
 {
     ucp_address_t *address;
     ucs_status_t status;
+    ompi_proc_t *proc;
     size_t addrlen;
     ucp_ep_h ep;
     size_t i;
@@ -222,44 +272,106 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
     }
 
     for (i = 0; i < nprocs; ++i) {
-        ret = mca_pml_ucx_recv_worker_address(procs[i], &address, &addrlen);
+        proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs];
+
+        ret = mca_pml_ucx_recv_worker_address(proc, &address, &addrlen);
         if (ret < 0) {
+            PML_UCX_ERROR("Failed to receive worker address from proc: %d",
+                          proc->proc_name.vpid);
             return ret;
         }
 
-        if (procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]) {
-            PML_UCX_VERBOSE(3, "already connected to proc. %d", procs[i]->proc_name.vpid);
+        if (proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]) {
+            PML_UCX_VERBOSE(3, "already connected to proc. %d", proc->proc_name.vpid);
             continue;
         }
 
-        PML_UCX_VERBOSE(2, "connecting to proc. %d", procs[i]->proc_name.vpid);
+        PML_UCX_VERBOSE(2, "connecting to proc. %d", proc->proc_name.vpid);
         status = ucp_ep_create(ompi_pml_ucx.ucp_worker, address, &ep);
         free(address);
 
         if (UCS_OK != status) {
-            PML_UCX_ERROR("Failed to connect");
+            PML_UCX_ERROR("Failed to connect to proc: %d, %s", proc->proc_name.vpid,
+                          ucs_status_string(status));
             return OMPI_ERROR;
         }
 
-        procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;
+        proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;
     }
 
     return OMPI_SUCCESS;
 }
 
+static void mca_pml_ucx_waitall(void **reqs, size_t *count_p)
+{
+    ucs_status_t status;
+    size_t i;
+
+    PML_UCX_VERBOSE(2, "waiting for %d disconnect requests", *count_p);
+    for (i = 0; i < *count_p; ++i) {
+        do {
+            opal_progress();
+            status = ucp_request_test(reqs[i], NULL);
+        } while (status == UCS_INPROGRESS);
+        if (status != UCS_OK) {
+            PML_UCX_ERROR("disconnect request failed: %s",
+                          ucs_status_string(status));
+        }
+        ucp_request_release(reqs[i]);
+        reqs[i] = NULL;
+    }
+
+    *count_p = 0;
+}
+
 int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)
 {
+    ompi_proc_t *proc;
+    size_t num_reqs, max_reqs;
+    void *dreq, **dreqs;
     ucp_ep_h ep;
     size_t i;
 
+    max_reqs = ompi_pml_ucx.num_disconnect;
+    if (max_reqs > nprocs) {
+        max_reqs = nprocs;
+    }
+
+    dreqs = malloc(sizeof(*dreqs) * max_reqs);
+    if (dreqs == NULL) {
+        return OMPI_ERR_OUT_OF_RESOURCE;
+    }
+
+    num_reqs = 0;
+
     for (i = 0; i < nprocs; ++i) {
-        PML_UCX_VERBOSE(2, "disconnecting from rank %d", procs[i]->proc_name.vpid);
-        ep = procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
-        if (ep != NULL) {
-            ucp_ep_destroy(ep);
+        proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs];
+        ep = proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
+        if (ep == NULL) {
+            continue;
+        }
+
+        PML_UCX_VERBOSE(2, "disconnecting from rank %d", proc->proc_name.vpid);
+        dreq = ucp_disconnect_nb(ep);
+        if (dreq != NULL) {
+            if (UCS_PTR_IS_ERR(dreq)) {
+                PML_UCX_ERROR("ucp_disconnect_nb(%d) failed: %s",
+                              proc->proc_name.vpid,
+                              ucs_status_string(UCS_PTR_STATUS(dreq)));
+            } else {
+                dreqs[num_reqs++] = dreq;
+            }
+        }
+
+        proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;
+
+        if (num_reqs >= ompi_pml_ucx.num_disconnect) {
+            mca_pml_ucx_waitall(dreqs, &num_reqs);
         }
-        procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;
     }
+
+    mca_pml_ucx_waitall(dreqs, &num_reqs);
+    free(dreqs);
     return OMPI_SUCCESS;
 }
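
Both loops above now index peers as procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs] instead of procs[i], so every rank starts at a different peer and rank 0 is not hit by all connection (or disconnection) traffic at once. The rotation in isolation (visit_peers is an illustrative name):

#include <stdio.h>

/* Each rank walks all peers exactly once, starting at its own index. */
static void visit_peers(int my_rank, int nprocs)
{
    for (int i = 0; i < nprocs; ++i) {
        int peer = (i + my_rank) % nprocs;
        printf("rank %d -> peer %d\n", my_rank, peer);
    }
}

int main(void)
{
    visit_peers(1, 4);    /* visits 1, 2, 3, 0 */
    visit_peers(3, 4);    /* visits 3, 0, 1, 2 */
    return 0;
}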

@@ -276,14 +388,7 @@ int mca_pml_ucx_enable(bool enable)
 
 int mca_pml_ucx_progress(void)
 {
-    static int inprogress = 0;
-    if (inprogress != 0) {
-        return 0;
-    }
-
-    ++inprogress;
     ucp_worker_progress(ompi_pml_ucx.ucp_worker);
-    --inprogress;
     return OMPI_SUCCESS;
 }

@@ -348,53 +453,32 @@ int mca_pml_ucx_irecv(void *buf, size_t count, ompi_datatype_t *datatype,
     return OMPI_SUCCESS;
 }
 
-static void
-mca_pml_ucx_blocking_recv_completion(void *request, ucs_status_t status,
-                                     ucp_tag_recv_info_t *info)
-{
-    ompi_request_t *req = request;
-
-    PML_UCX_VERBOSE(8, "blocking receive request %p completed with status %s tag %"PRIx64" len %zu",
-                    (void*)req, ucs_status_string(status), info->sender_tag,
-                    info->length);
-
-    OPAL_THREAD_LOCK(&ompi_request_lock);
-    mca_pml_ucx_set_recv_status(&req->req_status, status, info);
-    PML_UCX_ASSERT(!req->req_complete);
-    req->req_complete = true;
-    OPAL_THREAD_UNLOCK(&ompi_request_lock);
-}
-
 int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src,
                      int tag, struct ompi_communicator_t* comm,
                      ompi_status_public_t* mpi_status)
 {
     ucp_tag_t ucp_tag, ucp_tag_mask;
-    ompi_request_t *req;
+    ucp_tag_recv_info_t info;
+    ucs_status_t status;
+    void *req;
 
     PML_UCX_TRACE_RECV("%s", buf, count, datatype, src, tag, comm, "recv");
 
     PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
-    req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker, buf, count,
-                                           mca_pml_ucx_get_datatype(datatype),
-                                           ucp_tag, ucp_tag_mask,
-                                           mca_pml_ucx_blocking_recv_completion);
-    if (UCS_PTR_IS_ERR(req)) {
-        PML_UCX_ERROR("ucx recv failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
-        return OMPI_ERROR;
-    }
+    req = alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size;
+    status = ucp_tag_recv_nbr(ompi_pml_ucx.ucp_worker, buf, count,
+                              mca_pml_ucx_get_datatype(datatype),
+                              ucp_tag, ucp_tag_mask, req);
 
-    while (!req->req_complete) {
+    ucp_worker_progress(ompi_pml_ucx.ucp_worker);
+    for (;;) {
+        status = ucp_request_test(req, &info);
+        if (status != UCS_INPROGRESS) {
+            mca_pml_ucx_set_recv_status_safe(mpi_status, status, &info);
+            return OMPI_SUCCESS;
+        }
         opal_progress();
     }
-
-    if (mpi_status != MPI_STATUS_IGNORE) {
-        *mpi_status = req->req_status;
-    }
-
-    req->req_complete = false;
-    ucp_request_release(req);
-    return OMPI_SUCCESS;
 }
 
 static inline const char *mca_pml_ucx_send_mode_name(mca_pml_base_send_mode_t mode)
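
The rewritten mca_pml_ucx_recv() is built on ucp_tag_recv_nbr(), the no-allocation variant: the caller supplies the request storage, and UCX expects at least request_size bytes of space before the pointer it is handed, which is why the code above passes alloca(request_size) + request_size. A hedged sketch of that calling convention, assuming the ucp_request_test() API of this UCX era; blocking_recv is an illustrative wrapper, not from this commit:

#include <ucp/api/ucp.h>
#include <alloca.h>

static ucs_status_t blocking_recv(ucp_worker_h worker, size_t request_size,
                                  void *buf, size_t count, ucp_datatype_t dt,
                                  ucp_tag_t tag, ucp_tag_t tag_mask)
{
    ucp_tag_recv_info_t info;
    ucs_status_t status;
    void *req;

    /* Reserve request storage on the stack; UCX writes into the bytes
     * *before* the handle, so pass a pointer to the end of the block. */
    req = (char *)alloca(request_size) + request_size;

    status = ucp_tag_recv_nbr(worker, buf, count, dt, tag, tag_mask, req);
    if (UCS_STATUS_IS_ERR(status)) {
        return status;
    }

    /* No completion callback with the _nbr flavor: poll until done. */
    do {
        ucp_worker_progress(worker);
        status = ucp_request_test(req, &info);
    } while (status == UCS_INPROGRESS);

    return status;    /* UCS_OK, or the receive error */
}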
@@ -490,10 +574,11 @@ int mca_pml_ucx_send(void *buf, size_t count, ompi_datatype_t *datatype, int dst
                            mca_pml_ucx_get_datatype(datatype),
                            PML_UCX_MAKE_SEND_TAG(tag, comm),
                            mca_pml_ucx_send_completion);
-    if (req == NULL) {
+    if (OPAL_LIKELY(req == NULL)) {
         return OMPI_SUCCESS;
     } else if (!UCS_PTR_IS_ERR(req)) {
         PML_UCX_VERBOSE(8, "got request %p", (void*)req);
+        ucp_worker_progress(ompi_pml_ucx.ucp_worker);
         ompi_request_wait(&req, MPI_STATUS_IGNORE);
         return OMPI_SUCCESS;
     } else {
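
OPAL_LIKELY marks inline completion (req == NULL, an eager send that finished immediately) as the hot path. It reduces to GCC's __builtin_expect; a minimal equivalent, with MY_LIKELY as an illustrative stand-in for the OPAL macro:

/* Branch-prediction hints: tell the compiler which way a test usually
 * goes so it lays out the fast path as the fall-through. */
#define MY_LIKELY(x)   __builtin_expect(!!(x), 1)
#define MY_UNLIKELY(x) __builtin_expect(!!(x), 0)

int send_and_maybe_wait(void *req)
{
    if (MY_LIKELY(req == NULL)) {
        return 0;    /* fast path: completed inline, nothing to wait on */
    }
    /* slow path: progress the worker and wait on the request ... */
    return 1;
}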
@@ -518,6 +603,7 @@ int mca_pml_ucx_iprobe(int src, int tag, struct ompi_communicator_t* comm,
         *matched = 1;
         mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
     } else {
+        opal_progress();
         *matched = 0;
     }
     return OMPI_SUCCESS;
@@ -563,7 +649,8 @@ int mca_pml_ucx_improbe(int src, int tag, struct ompi_communicator_t* comm,
         PML_UCX_VERBOSE(8, "got message %p (%p)", (void*)*message, (void*)ucp_msg);
         *matched = 1;
         mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
-    } else if (UCS_PTR_STATUS(ucp_msg) == UCS_ERR_NO_MESSAGE) {
+    } else {
+        opal_progress();
         *matched = 0;
     }
     return OMPI_SUCCESS;
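
The added opal_progress() calls target applications that spin on probes: before this change, a loop of unmatched MPI_Iprobe()/MPI_Improbe() calls did not drive the progress engine, so the probed-for message could be deferred indefinitely. A standard MPI program exercising exactly that pattern (run with two ranks):

#include <mpi.h>
#include <stdio.h>

int main(int argc, char **argv)
{
    int rank, flag = 0;
    MPI_Status status;

    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    if (rank == 0) {
        int payload = 42;
        MPI_Send(&payload, 1, MPI_INT, 1, 0, MPI_COMM_WORLD);
    } else if (rank == 1) {
        /* Each failed probe must still make progress, or this spins forever. */
        while (!flag) {
            MPI_Iprobe(0, 0, MPI_COMM_WORLD, &flag, &status);
        }
        int payload;
        MPI_Recv(&payload, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, &status);
        printf("got %d\n", payload);
    }

    MPI_Finalize();
    return 0;
}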
@@ -696,6 +783,7 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
             PML_UCX_VERBOSE(8, "temporary request %p will complete persistent request %p",
                             (void*)tmp_req, (void*)preq);
             tmp_req->req_complete_cb_data = preq;
+            preq->tmp_req = tmp_req;
         }
         OPAL_THREAD_UNLOCK(&ompi_request_lock);
     } else {
ompi/mca/pml/ucx/pml_ucx.h

Lines changed: 4 additions & 1 deletion
@@ -34,14 +34,17 @@ struct mca_pml_ucx_module {
     mca_pml_base_module_t super;
 
     /* UCX global objects */
+    int using_mem_hooks;
     ucp_context_h ucp_context;
     ucp_worker_h ucp_worker;
 
     /* Requests */
     mca_pml_ucx_freelist_t persistent_reqs;
     ompi_request_t completed_send_req;
+    size_t request_size;
+    int num_disconnect;
 
-    /* Convertors pool */
+    /* Converters pool */
     mca_pml_ucx_freelist_t convs;
 
     int priority;
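
The new num_disconnect field caps how many non-blocking disconnects mca_pml_ucx_del_procs() keeps in flight, so teardown memory stays O(num_disconnect) rather than O(nprocs); its MCA-parameter registration presumably lives in one of the six changed files not shown on this page. A standalone sketch of the batching scheme, with disconnect_nb and wait_all as illustrative stand-ins for ucp_disconnect_nb() and mca_pml_ucx_waitall():

#include <stdlib.h>

typedef void *req_t;

req_t disconnect_nb(int peer);            /* stand-in for ucp_disconnect_nb() */
void  wait_all(req_t *reqs, size_t *n);   /* progress until all done; *n = 0  */

/* Assumes max_outstanding >= 1. */
int disconnect_all(int nprocs, size_t max_outstanding)
{
    size_t cap = (max_outstanding < (size_t)nprocs) ? max_outstanding
                                                    : (size_t)nprocs;
    size_t num_reqs = 0;
    req_t *reqs = malloc(sizeof(*reqs) * cap);

    if (reqs == NULL) {
        return -1;
    }

    for (int peer = 0; peer < nprocs; ++peer) {
        req_t r = disconnect_nb(peer);
        if (r != NULL) {                  /* NULL means completed inline */
            reqs[num_reqs++] = r;
        }
        if (num_reqs >= max_outstanding) {
            wait_all(reqs, &num_reqs);    /* drain the full batch */
        }
    }

    wait_all(reqs, &num_reqs);            /* drain the stragglers */
    free(reqs);
    return 0;
}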

0 commit comments