 
 #include "pml_ucx.h"
 
+#include "opal/memoryhooks/memory.h"
 #include "opal/runtime/opal.h"
 #include "ompi/runtime/ompi_module_exchange.h"
 #include "ompi/message/message.h"
 #include "pml_ucx_request.h"
 
+#include <ucm/api/ucm.h>
 #include <inttypes.h>
 
 
@@ -67,8 +69,9 @@ mca_pml_ucx_module_t ompi_pml_ucx = {
         1ul << (PML_UCX_TAG_BITS - 1),
         1ul << (PML_UCX_CONTEXT_BITS),
     },
-    NULL,
-    NULL
+    0,    /* using_mem_hooks */
+    NULL, /* ucp_context */
+    NULL  /* ucp_worker */
 };
 
 static int mca_pml_ucx_send_worker_address(void)
@@ -110,24 +113,51 @@ static int mca_pml_ucx_recv_worker_address(ompi_proc_t *proc,
     return ret;
 }
 
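+/* Release callback registered with OPAL's memory hooks: when memory is
+ * returned to the OS, report the unmapped range to UCM so UCX can invalidate
+ * any registrations covering it. */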
+static void mca_pml_ucx_mem_release_cb(void *buf, size_t length, void *cbdata,
+                                       bool from_alloc)
+{
+    ucm_vm_munmap(buf, length);
+}
+
 int mca_pml_ucx_open(void)
 {
+    ucp_context_attr_t attr;
     ucp_params_t params;
     ucp_config_t *config;
     ucs_status_t status;
 
     PML_UCX_VERBOSE(1, "mca_pml_ucx_open");
 
+    /* Set memory hooks */
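+    /* Rely on OPAL's hooks only if it can intercept both free() and munmap();
+     * UCM is then told to expect externally reported VM_UNMAPPED events
+     * (delivered via mca_pml_ucx_mem_release_cb) instead of installing its
+     * own allocator hooks. */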
+    if ((OPAL_MEMORY_FREE_SUPPORT | OPAL_MEMORY_MUNMAP_SUPPORT) ==
+        ((OPAL_MEMORY_FREE_SUPPORT | OPAL_MEMORY_MUNMAP_SUPPORT) &
+         opal_mem_hooks_support_level()))
+    {
+        PML_UCX_VERBOSE(1, "using opal memory hooks");
+        ucm_set_external_event(UCM_EVENT_VM_UNMAPPED);
+        ompi_pml_ucx.using_mem_hooks = 1;
+    } else {
+        PML_UCX_VERBOSE(1, "not using opal memory hooks");
+        ompi_pml_ucx.using_mem_hooks = 0;
+    }
+
     /* Read options */
     status = ucp_config_read("MPI", NULL, &config);
     if (UCS_OK != status) {
         return OMPI_ERROR;
     }
 
+    /* Initialize UCX context */
+    params.field_mask      = UCP_PARAM_FIELD_FEATURES |
+                             UCP_PARAM_FIELD_REQUEST_SIZE |
+                             UCP_PARAM_FIELD_REQUEST_INIT |
+                             UCP_PARAM_FIELD_REQUEST_CLEANUP |
+                             UCP_PARAM_FIELD_TAG_SENDER_MASK;
     params.features        = UCP_FEATURE_TAG;
     params.request_size    = sizeof(ompi_request_t);
     params.request_init    = mca_pml_ucx_request_init;
     params.request_cleanup = mca_pml_ucx_request_cleanup;
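+    /* Tell UCP which tag bits uniquely identify the sender rank (the same
+     * bits PML_UCX_MAKE_SEND_TAG encodes the source into). */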
+    params.tag_sender_mask = PML_UCX_SPECIFIC_SOURCE_MASK;
 
     status = ucp_init(&params, config, &ompi_pml_ucx.ucp_context);
     ucp_config_release(config);
@@ -136,6 +166,17 @@ int mca_pml_ucx_open(void)
         return OMPI_ERROR;
     }
 
+    /* Query UCX attributes */
+    attr.field_mask = UCP_ATTR_FIELD_REQUEST_SIZE;
+    status = ucp_context_query(ompi_pml_ucx.ucp_context, &attr);
+    if (UCS_OK != status) {
+        ucp_cleanup(ompi_pml_ucx.ucp_context);
+        ompi_pml_ucx.ucp_context = NULL;
+        return OMPI_ERROR;
+    }
+
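+    /* Remember how much space UCP needs in front of a request pointer; the
+     * blocking receive path (mca_pml_ucx_recv) allocates such requests on
+     * the stack. */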
+    ompi_pml_ucx.request_size = attr.request_size;
+
     return OMPI_SUCCESS;
 }
 
@@ -157,9 +198,13 @@ int mca_pml_ucx_init(void)
 
     PML_UCX_VERBOSE(1, "mca_pml_ucx_init");
 
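+    /* Start receiving OPAL memory-release notifications; the callback is
+     * unregistered again in mca_pml_ucx_cleanup(). */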
+    if (ompi_pml_ucx.using_mem_hooks) {
+        opal_mem_hooks_register_release(mca_pml_ucx_mem_release_cb, NULL);
+    }
+
     /* TODO check MPI thread mode */
     status = ucp_worker_create(ompi_pml_ucx.ucp_context, UCS_THREAD_MODE_SINGLE,
-                                &ompi_pml_ucx.ucp_worker);
+                               &ompi_pml_ucx.ucp_worker);
     if (UCS_OK != status) {
         return OMPI_ERROR;
     }
@@ -203,13 +248,18 @@ int mca_pml_ucx_cleanup(void)
         ompi_pml_ucx.ucp_worker = NULL;
     }
 
+    if (ompi_pml_ucx.using_mem_hooks) {
+        opal_mem_hooks_unregister_release(mca_pml_ucx_mem_release_cb);
+    }
+
     return OMPI_SUCCESS;
 }
 
 int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
 {
     ucp_address_t *address;
     ucs_status_t status;
+    ompi_proc_t *proc;
     size_t addrlen;
     ucp_ep_h ep;
     size_t i;
@@ -222,44 +272,106 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
     }
 
     for (i = 0; i < nprocs; ++i) {
-        ret = mca_pml_ucx_recv_worker_address(procs[i], &address, &addrlen);
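+        /* Start at our own rank and wrap around, spreading connection setup
+         * across peers instead of having every rank contact rank 0 first. */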
+        proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs];
+
+        ret = mca_pml_ucx_recv_worker_address(proc, &address, &addrlen);
         if (ret < 0) {
+            PML_UCX_ERROR("Failed to receive worker address from proc: %d",
+                          proc->proc_name.vpid);
             return ret;
         }
 
-        if (procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]) {
-            PML_UCX_VERBOSE(3, "already connected to proc. %d", procs[i]->proc_name.vpid);
+        if (proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]) {
+            PML_UCX_VERBOSE(3, "already connected to proc. %d", proc->proc_name.vpid);
             continue;
         }
 
-        PML_UCX_VERBOSE(2, "connecting to proc. %d", procs[i]->proc_name.vpid);
+        PML_UCX_VERBOSE(2, "connecting to proc. %d", proc->proc_name.vpid);
         status = ucp_ep_create(ompi_pml_ucx.ucp_worker, address, &ep);
         free(address);
 
         if (UCS_OK != status) {
-            PML_UCX_ERROR("Failed to connect");
+            PML_UCX_ERROR("Failed to connect to proc: %d, %s", proc->proc_name.vpid,
+                          ucs_status_string(status));
             return OMPI_ERROR;
         }
 
-        procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;
+        proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;
     }
 
     return OMPI_SUCCESS;
 }
 
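+/* Progress the worker until every pending disconnect request in reqs[] has
+ * completed, then release the requests and reset the count. */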
+static void mca_pml_ucx_waitall(void **reqs, size_t *count_p)
+{
+    ucs_status_t status;
+    size_t i;
+
+    PML_UCX_VERBOSE(2, "waiting for %zu disconnect requests", *count_p);
+    for (i = 0; i < *count_p; ++i) {
+        do {
+            opal_progress();
+            status = ucp_request_test(reqs[i], NULL);
+        } while (status == UCS_INPROGRESS);
+        if (status != UCS_OK) {
+            PML_UCX_ERROR("disconnect request failed: %s",
+                          ucs_status_string(status));
+        }
+        ucp_request_release(reqs[i]);
+        reqs[i] = NULL;
+    }
+
+    *count_p = 0;
+}
+
 int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)
 {
+    ompi_proc_t *proc;
+    size_t num_reqs, max_reqs;
+    void *dreq, **dreqs;
     ucp_ep_h ep;
     size_t i;
 
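+    /* Keep at most num_disconnect non-blocking disconnects outstanding at a
+     * time; full batches are drained by mca_pml_ucx_waitall() below. */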
+    max_reqs = ompi_pml_ucx.num_disconnect;
+    if (max_reqs > nprocs) {
+        max_reqs = nprocs;
+    }
+
+    dreqs = malloc(sizeof(*dreqs) * max_reqs);
+    if (dreqs == NULL) {
+        return OMPI_ERR_OUT_OF_RESOURCE;
+    }
+
+    num_reqs = 0;
+
     for (i = 0; i < nprocs; ++i) {
-        PML_UCX_VERBOSE(2, "disconnecting from rank %d", procs[i]->proc_name.vpid);
-        ep = procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
-        if (ep != NULL) {
-            ucp_ep_destroy(ep);
+        proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs];
+        ep = proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
+        if (ep == NULL) {
+            continue;
+        }
+
+        PML_UCX_VERBOSE(2, "disconnecting from rank %d", proc->proc_name.vpid);
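+        /* ucp_disconnect_nb() returns NULL if the disconnect completed
+         * immediately, an error pointer on failure, or a request handle
+         * that completes later. */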
+        dreq = ucp_disconnect_nb(ep);
+        if (dreq != NULL) {
+            if (UCS_PTR_IS_ERR(dreq)) {
+                PML_UCX_ERROR("ucp_disconnect_nb(%d) failed: %s",
+                              proc->proc_name.vpid,
+                              ucs_status_string(UCS_PTR_STATUS(dreq)));
+            } else {
+                dreqs[num_reqs++] = dreq;
+            }
+        }
+
+        proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;
+
+        if (num_reqs >= ompi_pml_ucx.num_disconnect) {
+            mca_pml_ucx_waitall(dreqs, &num_reqs);
         }
-        procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;
     }
+
+    mca_pml_ucx_waitall(dreqs, &num_reqs);
+    free(dreqs);
     return OMPI_SUCCESS;
 }
 
@@ -276,14 +388,7 @@ int mca_pml_ucx_enable(bool enable)
 
 int mca_pml_ucx_progress(void)
 {
-    static int inprogress = 0;
-    if (inprogress != 0) {
-        return 0;
-    }
-
-    ++inprogress;
     ucp_worker_progress(ompi_pml_ucx.ucp_worker);
-    --inprogress;
     return OMPI_SUCCESS;
 }
 
@@ -348,53 +453,32 @@ int mca_pml_ucx_irecv(void *buf, size_t count, ompi_datatype_t *datatype,
     return OMPI_SUCCESS;
 }
 
-static void
-mca_pml_ucx_blocking_recv_completion(void *request, ucs_status_t status,
-                                     ucp_tag_recv_info_t *info)
-{
-    ompi_request_t *req = request;
-
-    PML_UCX_VERBOSE(8, "blocking receive request %p completed with status %s tag %"PRIx64" len %zu",
-                    (void*)req, ucs_status_string(status), info->sender_tag,
-                    info->length);
-
-    OPAL_THREAD_LOCK(&ompi_request_lock);
-    mca_pml_ucx_set_recv_status(&req->req_status, status, info);
-    PML_UCX_ASSERT(!req->req_complete);
-    req->req_complete = true;
-    OPAL_THREAD_UNLOCK(&ompi_request_lock);
-}
-
 int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src,
                      int tag, struct ompi_communicator_t *comm,
                      ompi_status_public_t *mpi_status)
 {
     ucp_tag_t ucp_tag, ucp_tag_mask;
-    ompi_request_t *req;
+    ucp_tag_recv_info_t info;
+    ucs_status_t status;
+    void *req;
 
     PML_UCX_TRACE_RECV("%s", buf, count, datatype, src, tag, comm, "recv");
 
     PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
-    req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker, buf, count,
-                                           mca_pml_ucx_get_datatype(datatype),
-                                           ucp_tag, ucp_tag_mask,
-                                           mca_pml_ucx_blocking_recv_completion);
-    if (UCS_PTR_IS_ERR(req)) {
-        PML_UCX_ERROR("ucx recv failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
-        return OMPI_ERROR;
-    }
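+    /* ucp_tag_recv_nbr() expects at least ompi_pml_ucx.request_size bytes of
+     * space *before* the request pointer it is given, so carve the request
+     * out of the stack and pass a pointer to the end of the allocation. */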
+    req    = alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size;
+    status = ucp_tag_recv_nbr(ompi_pml_ucx.ucp_worker, buf, count,
+                              mca_pml_ucx_get_datatype(datatype),
+                              ucp_tag, ucp_tag_mask, req);
 
-    while (!req->req_complete) {
+    ucp_worker_progress(ompi_pml_ucx.ucp_worker);
+    for (;;) {
+        status = ucp_request_test(req, &info);
+        if (status != UCS_INPROGRESS) {
+            mca_pml_ucx_set_recv_status_safe(mpi_status, status, &info);
+            return OMPI_SUCCESS;
+        }
         opal_progress();
     }
-
-    if (mpi_status != MPI_STATUS_IGNORE) {
-        *mpi_status = req->req_status;
-    }
-
-    req->req_complete = false;
-    ucp_request_release(req);
-    return OMPI_SUCCESS;
 }
 
 static inline const char *mca_pml_ucx_send_mode_name(mca_pml_base_send_mode_t mode)
@@ -490,10 +574,11 @@ int mca_pml_ucx_send(void *buf, size_t count, ompi_datatype_t *datatype, int dst
                             mca_pml_ucx_get_datatype(datatype),
                             PML_UCX_MAKE_SEND_TAG(tag, comm),
                             mca_pml_ucx_send_completion);
-    if (req == NULL) {
+    if (OPAL_LIKELY(req == NULL)) {
         return OMPI_SUCCESS;
     } else if (!UCS_PTR_IS_ERR(req)) {
         PML_UCX_VERBOSE(8, "got request %p", (void*)req);
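+        /* Progressing the worker here gives the pending send a chance to
+         * complete before blocking in ompi_request_wait(). */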
+        ucp_worker_progress(ompi_pml_ucx.ucp_worker);
         ompi_request_wait(&req, MPI_STATUS_IGNORE);
         return OMPI_SUCCESS;
     } else {
@@ -518,6 +603,7 @@ int mca_pml_ucx_iprobe(int src, int tag, struct ompi_communicator_t* comm,
         *matched = 1;
         mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
     } else {
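+        /* Nothing matched; progressing here lets a later probe observe newly
+         * arrived messages. */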
+        opal_progress();
         *matched = 0;
     }
     return OMPI_SUCCESS;
@@ -563,7 +649,8 @@ int mca_pml_ucx_improbe(int src, int tag, struct ompi_communicator_t* comm,
         PML_UCX_VERBOSE(8, "got message %p (%p)", (void*)*message, (void*)ucp_msg);
         *matched = 1;
         mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
-    } else if (UCS_PTR_STATUS(ucp_msg) == UCS_ERR_NO_MESSAGE) {
+    } else {
+        opal_progress();
         *matched = 0;
     }
     return OMPI_SUCCESS;
@@ -696,6 +783,7 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
             PML_UCX_VERBOSE(8, "temporary request %p will complete persistent request %p",
                             (void*)tmp_req, (void*)preq);
             tmp_req->req_complete_cb_data = preq;
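+            /* Link back so the persistent request can locate its in-flight
+             * temporary request. */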
+            preq->tmp_req = tmp_req;
         }
         OPAL_THREAD_UNLOCK(&ompi_request_lock);
     } else {