diff --git a/ompi/mca/pml/ucx/pml_ucx.c b/ompi/mca/pml/ucx/pml_ucx.c
index 780c61abca2..59885790eec 100644
--- a/ompi/mca/pml/ucx/pml_ucx.c
+++ b/ompi/mca/pml/ucx/pml_ucx.c
@@ -70,8 +70,8 @@ mca_pml_ucx_module_t ompi_pml_ucx = {
         1ul << (PML_UCX_TAG_BITS - 1),
         1ul << (PML_UCX_CONTEXT_BITS),
     },
-    NULL,
-    NULL
+    NULL,   /* ucp_context */
+    NULL    /* ucp_worker */
 };
 
 static int mca_pml_ucx_send_worker_address(void)
@@ -116,6 +116,7 @@ static int mca_pml_ucx_recv_worker_address(ompi_proc_t *proc,
 
 int mca_pml_ucx_open(void)
 {
+    ucp_context_attr_t attr;
     ucp_params_t params;
     ucp_config_t *config;
     ucs_status_t status;
@@ -128,10 +129,17 @@ int mca_pml_ucx_open(void)
         return OMPI_ERROR;
     }
 
+    /* Initialize UCX context */
+    params.field_mask        = UCP_PARAM_FIELD_FEATURES |
+                               UCP_PARAM_FIELD_REQUEST_SIZE |
+                               UCP_PARAM_FIELD_REQUEST_INIT |
+                               UCP_PARAM_FIELD_REQUEST_CLEANUP |
+                               UCP_PARAM_FIELD_TAG_SENDER_MASK;
     params.features          = UCP_FEATURE_TAG;
     params.request_size      = sizeof(ompi_request_t);
     params.request_init      = mca_pml_ucx_request_init;
     params.request_cleanup   = mca_pml_ucx_request_cleanup;
+    params.tag_sender_mask   = PML_UCX_SPECIFIC_SOURCE_MASK;
 
     status = ucp_init(&params, config, &ompi_pml_ucx.ucp_context);
     ucp_config_release(config);
@@ -140,6 +148,17 @@ int mca_pml_ucx_open(void)
         return OMPI_ERROR;
     }
 
+    /* Query UCX attributes */
+    attr.field_mask = UCP_ATTR_FIELD_REQUEST_SIZE;
+    status = ucp_context_query(ompi_pml_ucx.ucp_context, &attr);
+    if (UCS_OK != status) {
+        ucp_cleanup(ompi_pml_ucx.ucp_context);
+        ompi_pml_ucx.ucp_context = NULL;
+        return OMPI_ERROR;
+    }
+
+    ompi_pml_ucx.request_size = attr.request_size;
+
     return OMPI_SUCCESS;
 }
 
@@ -163,7 +182,7 @@ int mca_pml_ucx_init(void)
 
     /* TODO check MPI thread mode */
     status = ucp_worker_create(ompi_pml_ucx.ucp_context, UCS_THREAD_MODE_SINGLE,
-                              &ompi_pml_ucx.ucp_worker);
+                               &ompi_pml_ucx.ucp_worker);
     if (UCS_OK != status) {
         return OMPI_ERROR;
     }
@@ -252,6 +271,7 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
 {
     ucp_address_t *address;
     ucs_status_t status;
+    ompi_proc_t *proc;
     size_t addrlen;
     ucp_ep_h ep;
     size_t i;
@@ -264,47 +284,109 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
     }
 
     for (i = 0; i < nprocs; ++i) {
-        ret = mca_pml_ucx_recv_worker_address(procs[i], &address, &addrlen);
+        proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs];
+
+        ret = mca_pml_ucx_recv_worker_address(proc, &address, &addrlen);
         if (ret < 0) {
-            PML_UCX_ERROR("Failed to receive worker address from proc: %d", procs[i]->super.proc_name.vpid);
+            PML_UCX_ERROR("Failed to receive worker address from proc: %d",
+                          proc->super.proc_name.vpid);
             return ret;
         }
 
-        if (procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]) {
-            PML_UCX_VERBOSE(3, "already connected to proc. %d", procs[i]->super.proc_name.vpid);
+        if (proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]) {
+            PML_UCX_VERBOSE(3, "already connected to proc. %d", proc->super.proc_name.vpid);
             continue;
         }
 
-        PML_UCX_VERBOSE(2, "connecting to proc. %d", procs[i]->super.proc_name.vpid);
+        PML_UCX_VERBOSE(2, "connecting to proc. %d", proc->super.proc_name.vpid);
         status = ucp_ep_create(ompi_pml_ucx.ucp_worker, address, &ep);
         free(address);
 
         if (UCS_OK != status) {
-            PML_UCX_ERROR("Failed to connect to proc: %d, %s", procs[i]->super.proc_name.vpid,
-                          ucs_status_string(status));
+            PML_UCX_ERROR("Failed to connect to proc: %d, %s", proc->super.proc_name.vpid,
+                          ucs_status_string(status));
             return OMPI_ERROR;
         }
 
-        procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;
+        proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;
     }
 
     return OMPI_SUCCESS;
 }
 
+static void mca_pml_ucx_waitall(void **reqs, size_t *count_p)
+{
+    ucs_status_t status;
+    size_t i;
+
+    PML_UCX_VERBOSE(2, "waiting for %zu disconnect requests", *count_p);
+    for (i = 0; i < *count_p; ++i) {
+        do {
+            opal_progress();
+            status = ucp_request_test(reqs[i], NULL);
+        } while (status == UCS_INPROGRESS);
+        if (status != UCS_OK) {
+            PML_UCX_ERROR("disconnect request failed: %s",
+                          ucs_status_string(status));
+        }
+        ucp_request_release(reqs[i]);
+        reqs[i] = NULL;
+    }
+
+    *count_p = 0;
+}
+
 int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)
 {
+    ompi_proc_t *proc;
+    size_t num_reqs, max_reqs;
+    void *dreq, **dreqs;
     ucp_ep_h ep;
     size_t i;
 
+    max_reqs = ompi_pml_ucx.num_disconnect;
+    if (max_reqs > nprocs) {
+        max_reqs = nprocs;
+    }
+
+    dreqs = malloc(sizeof(*dreqs) * max_reqs);
+    if (dreqs == NULL) {
+        return OMPI_ERR_OUT_OF_RESOURCE;
+    }
+
+    num_reqs = 0;
+
     for (i = 0; i < nprocs; ++i) {
-        PML_UCX_VERBOSE(2, "disconnecting from rank %d", procs[i]->super.proc_name.vpid);
-        ep = procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
-        if (ep != NULL) {
-            ucp_ep_destroy(ep);
+        proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs];
+        ep = proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
+        if (ep == NULL) {
+            continue;
+        }
+
+        PML_UCX_VERBOSE(2, "disconnecting from rank %d", proc->super.proc_name.vpid);
+        dreq = ucp_disconnect_nb(ep);
+        if (dreq != NULL) {
+            if (UCS_PTR_IS_ERR(dreq)) {
+                PML_UCX_ERROR("ucp_disconnect_nb(%d) failed: %s",
+                              proc->super.proc_name.vpid,
+                              ucs_status_string(UCS_PTR_STATUS(dreq)));
+            } else {
+                dreqs[num_reqs++] = dreq;
+            }
+        }
+
+        proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;
+
+        if (num_reqs >= (size_t)ompi_pml_ucx.num_disconnect) {
+            mca_pml_ucx_waitall(dreqs, &num_reqs);
         }
-        procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;
     }
+
+    mca_pml_ucx_waitall(dreqs, &num_reqs);
+    free(dreqs);
+
+    opal_pmix.fence(NULL, 0);
+
     return OMPI_SUCCESS;
 }
 
@@ -321,14 +403,7 @@ int mca_pml_ucx_enable(bool enable)
 
 int mca_pml_ucx_progress(void)
 {
-    static int inprogress = 0;
-    if (inprogress != 0) {
-        return 0;
-    }
-
-    ++inprogress;
     ucp_worker_progress(ompi_pml_ucx.ucp_worker);
-    --inprogress;
     return OMPI_SUCCESS;
 }
 
@@ -393,52 +468,32 @@ int mca_pml_ucx_irecv(void *buf, size_t count, ompi_datatype_t *datatype,
     return OMPI_SUCCESS;
 }
 
-static void
-mca_pml_ucx_blocking_recv_completion(void *request, ucs_status_t status,
-                                     ucp_tag_recv_info_t *info)
-{
-    ompi_request_t *req = request;
-
-    PML_UCX_VERBOSE(8, "blocking receive request %p completed with status %s tag %"PRIx64" len %zu",
-                    (void*)req, ucs_status_string(status), info->sender_tag,
-                    info->length);
-
-    mca_pml_ucx_set_recv_status(&req->req_status, status, info);
-    PML_UCX_ASSERT( !(REQUEST_COMPLETE(req)));
-    ompi_request_complete(req,true);
-}
-
 int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src,
                      int tag, struct ompi_communicator_t* comm,
                      ompi_status_public_t* mpi_status)
 {
     ucp_tag_t ucp_tag, ucp_tag_mask;
-    ompi_request_t *req;
+    ucp_tag_recv_info_t info;
+    ucs_status_t status;
+    void *req;
 
     PML_UCX_TRACE_RECV("%s", buf, count, datatype, src, tag, comm, "recv");
 
     PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
-    req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker, buf, count,
-                                           mca_pml_ucx_get_datatype(datatype),
-                                           ucp_tag, ucp_tag_mask,
-                                           mca_pml_ucx_blocking_recv_completion);
-    if (UCS_PTR_IS_ERR(req)) {
-        PML_UCX_ERROR("ucx recv failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
-        return OMPI_ERROR;
-    }
+    req = (char *)alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size;
+    status = ucp_tag_recv_nbr(ompi_pml_ucx.ucp_worker, buf, count,
+                              mca_pml_ucx_get_datatype(datatype),
+                              ucp_tag, ucp_tag_mask, req);
 
     ucp_worker_progress(ompi_pml_ucx.ucp_worker);
-    while ( !REQUEST_COMPLETE(req) ) {
+    for (;;) {
+        status = ucp_request_test(req, &info);
+        if (status != UCS_INPROGRESS) {
+            mca_pml_ucx_set_recv_status_safe(mpi_status, status, &info);
+            return OMPI_SUCCESS;
+        }
         opal_progress();
     }
-
-    if (mpi_status != MPI_STATUS_IGNORE) {
-        *mpi_status = req->req_status;
-    }
-
-    req->req_complete = REQUEST_PENDING;
-    ucp_request_release(req);
-    return OMPI_SUCCESS;
 }
 
 static inline const char *mca_pml_ucx_send_mode_name(mca_pml_base_send_mode_t mode)
@@ -583,6 +638,7 @@ int mca_pml_ucx_iprobe(int src, int tag, struct ompi_communicator_t* comm,
         *matched = 1;
         mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
     } else {
+        opal_progress();
         *matched = 0;
     }
     return OMPI_SUCCESS;
@@ -628,7 +684,8 @@ int mca_pml_ucx_improbe(int src, int tag, struct ompi_communicator_t* comm,
         PML_UCX_VERBOSE(8, "got message %p (%p)", (void*)*message, (void*)ucp_msg);
         *matched = 1;
         mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
-    } else if (UCS_PTR_STATUS(ucp_msg) == UCS_ERR_NO_MESSAGE) {
+    } else {
+        opal_progress();
         *matched = 0;
     }
     return OMPI_SUCCESS;
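
The receive path above deserves a note: mca_pml_ucx_recv() no longer allocates a UCX request object at all. It reserves ompi_pml_ucx.request_size bytes on the stack (the size obtained from ucp_context_query() in mca_pml_ucx_open()) and passes ucp_tag_recv_nbr() a pointer just past that reservation, then polls the request in place. A minimal sketch of the same trick in isolation, assuming a contiguous byte buffer; blocking_recv and its parameters are illustrative names, not part of the patch:

    #include <ucp/api/ucp.h>
    #include <alloca.h>    /* glibc/BSD; other platforms differ */

    /* Illustrative only: receive `count` bytes, blocking, without a
     * heap-allocated request. `reqsize` comes from ucp_context_query(). */
    static ucs_status_t blocking_recv(ucp_worker_h worker, size_t reqsize,
                                      void *buf, size_t count,
                                      ucp_tag_t tag, ucp_tag_t tag_mask)
    {
        ucp_tag_recv_info_t info;
        ucs_status_t status;
        /* ucp_tag_recv_nbr() expects `reqsize` bytes of space *before* req */
        void *req = (char *)alloca(reqsize) + reqsize;

        status = ucp_tag_recv_nbr(worker, buf, count, ucp_dt_make_contig(1),
                                  tag, tag_mask, req);
        if (UCS_STATUS_IS_ERR(status)) {
            return status;        /* the receive could not be posted */
        }

        do {
            ucp_worker_progress(worker);
            status = ucp_request_test(req, &info);
        } while (status == UCS_INPROGRESS);

        return status;            /* UCS_OK or the completion error */
    }

Because the request lives on the caller's stack it must not outlive the call, which the in-place polling loop guarantees; a receive that completes asynchronously (MPI_Irecv) still goes through the callback-based ucp_tag_recv_nb().
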
%d", proc->super.proc_name.vpid); status = ucp_ep_create(ompi_pml_ucx.ucp_worker, address, &ep); free(address); if (UCS_OK != status) { - PML_UCX_ERROR("Failed to connect to proc: %d, %s", procs[i]->super.proc_name.vpid, - ucs_status_string(status)); + PML_UCX_ERROR("Failed to connect to proc: %d, %s", proc->super.proc_name.vpid, + ucs_status_string(status)); return OMPI_ERROR; } - procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep; + proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep; } return OMPI_SUCCESS; } +static void mca_pml_ucx_waitall(void **reqs, size_t *count_p) +{ + ucs_status_t status; + size_t i; + + PML_UCX_VERBOSE(2, "waiting for %d disconnect requests", *count_p); + for (i = 0; i < *count_p; ++i) { + do { + opal_progress(); + status = ucp_request_test(reqs[i], NULL); + } while (status == UCS_INPROGRESS); + if (status != UCS_OK) { + PML_UCX_ERROR("disconnect request failed: %s", + ucs_status_string(status)); + } + ucp_request_release(reqs[i]); + reqs[i] = NULL; + } + + *count_p = 0; +} + int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs) { + ompi_proc_t *proc; + size_t num_reqs, max_reqs; + void *dreq, **dreqs; ucp_ep_h ep; size_t i; + max_reqs = ompi_pml_ucx.num_disconnect; + if (max_reqs > nprocs) { + max_reqs = nprocs; + } + + dreqs = malloc(sizeof(*dreqs) * max_reqs); + if (dreqs == NULL) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + + num_reqs = 0; + for (i = 0; i < nprocs; ++i) { - PML_UCX_VERBOSE(2, "disconnecting from rank %d", procs[i]->super.proc_name.vpid); - ep = procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]; - if (ep != NULL) { - ucp_ep_destroy(ep); + proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs]; + ep = proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]; + if (ep == NULL) { + continue; + } + + PML_UCX_VERBOSE(2, "disconnecting from rank %d", proc->super.proc_name.vpid); + dreq = ucp_disconnect_nb(ep); + if (dreq != NULL) { + if (UCS_PTR_IS_ERR(dreq)) { + PML_UCX_ERROR("ucp_disconnect_nb(%d) failed: %s", + proc->super.proc_name.vpid, + ucs_status_string(UCS_PTR_STATUS(dreq))); + } else { + dreqs[num_reqs++] = dreq; + } + } + + proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL; + + if (num_reqs >= ompi_pml_ucx.num_disconnect) { + mca_pml_ucx_waitall(dreqs, &num_reqs); } - procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL; } + + mca_pml_ucx_waitall(dreqs, &num_reqs); + free(dreqs); + opal_pmix.fence(NULL, 0); + return OMPI_SUCCESS; } @@ -321,14 +403,7 @@ int mca_pml_ucx_enable(bool enable) int mca_pml_ucx_progress(void) { - static int inprogress = 0; - if (inprogress != 0) { - return 0; - } - - ++inprogress; ucp_worker_progress(ompi_pml_ucx.ucp_worker); - --inprogress; return OMPI_SUCCESS; } @@ -393,52 +468,32 @@ int mca_pml_ucx_irecv(void *buf, size_t count, ompi_datatype_t *datatype, return OMPI_SUCCESS; } -static void -mca_pml_ucx_blocking_recv_completion(void *request, ucs_status_t status, - ucp_tag_recv_info_t *info) -{ - ompi_request_t *req = request; - - PML_UCX_VERBOSE(8, "blocking receive request %p completed with status %s tag %"PRIx64" len %zu", - (void*)req, ucs_status_string(status), info->sender_tag, - info->length); - - mca_pml_ucx_set_recv_status(&req->req_status, status, info); - PML_UCX_ASSERT( !(REQUEST_COMPLETE(req))); - ompi_request_complete(req,true); -} - int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src, int tag, struct ompi_communicator_t* comm, ompi_status_public_t* mpi_status) { ucp_tag_t ucp_tag, ucp_tag_mask; - ompi_request_t *req; + ucp_tag_recv_info_t 
diff --git a/ompi/mca/pml/ucx/pml_ucx_request.h b/ompi/mca/pml/ucx/pml_ucx_request.h
index db9f089b4cf..3176eb60f0a 100644
--- a/ompi/mca/pml/ucx/pml_ucx_request.h
+++ b/ompi/mca/pml/ucx/pml_ucx_request.h
@@ -34,6 +34,9 @@ enum {
 #define PML_UCX_TAG_BITS             24
 #define PML_UCX_RANK_BITS            24
 #define PML_UCX_CONTEXT_BITS         16
+#define PML_UCX_ANY_SOURCE_MASK      0x800000000000fffful
+#define PML_UCX_SPECIFIC_SOURCE_MASK 0x800000fffffffffful
+#define PML_UCX_TAG_MASK             0x7fffff0000000000ul
 
 
 #define PML_UCX_MAKE_SEND_TAG(_tag, _comm) \
@@ -45,16 +48,16 @@ enum {
 #define PML_UCX_MAKE_RECV_TAG(_ucp_tag, _ucp_tag_mask, _tag, _src, _comm) \
     { \
         if ((_src) == MPI_ANY_SOURCE) { \
-            _ucp_tag_mask = 0x800000000000fffful; \
+            _ucp_tag_mask = PML_UCX_ANY_SOURCE_MASK; \
         } else { \
-            _ucp_tag_mask = 0x800000fffffffffful; \
+            _ucp_tag_mask = PML_UCX_SPECIFIC_SOURCE_MASK; \
         } \
         \
         _ucp_tag = (((uint64_t)(_src) & UCS_MASK(PML_UCX_RANK_BITS)) << PML_UCX_CONTEXT_BITS) | \
                    (_comm)->c_contextid; \
         \
         if ((_tag) != MPI_ANY_TAG) { \
-            _ucp_tag_mask |= 0x7fffff0000000000ul; \
+            _ucp_tag_mask |= PML_UCX_TAG_MASK; \
             _ucp_tag      |= ((uint64_t)(_tag)) << (PML_UCX_RANK_BITS + PML_UCX_CONTEXT_BITS); \
         } \
     }
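
Naming the three masks also makes the 64-bit matching scheme auditable. The layout below is my reading of the constants, not text from the patch:

    /*
     *  bit: 63 | 62 .............. 40 | 39 ............... 16 | 15 ...... 0
     *      +---+----------------------+-----------------------+------------+
     *      | r |  MPI tag (23 bits)   | source rank (24 bits) | context id |
     *      +---+----------------------+-----------------------+------------+
     *
     * PML_UCX_ANY_SOURCE_MASK      matches "r" and the context id
     * PML_UCX_SPECIFIC_SOURCE_MASK matches "r", the rank and the context id
     * PML_UCX_TAG_MASK             adds the 23 usable tag bits
     *
     * Bit 63 ("r") is never set by PML_UCX_MAKE_SEND_TAG, since the module
     * advertises a maximum tag of 1ul << (PML_UCX_TAG_BITS - 1); matching on
     * it in both source masks keeps the top bit reserved. Note that
     * PML_UCX_SPECIFIC_SOURCE_MASK doubles as params.tag_sender_mask in
     * mca_pml_ucx_open(), telling UCX which tag bits identify the sender.
     */
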
diff --git a/oshmem/mca/spml/ucx/spml_ucx.c b/oshmem/mca/spml/ucx/spml_ucx.c
index 79c81f1c930..426814ec797 100644
--- a/oshmem/mca/spml/ucx/spml_ucx.c
+++ b/oshmem/mca/spml/ucx/spml_ucx.c
@@ -65,7 +65,13 @@ mca_spml_ucx_t mca_spml_ucx = {
         mca_spml_ucx_rmkey_unpack,
         mca_spml_ucx_rmkey_free,
         (void*)&mca_spml_ucx
-    }
+    },
+
+    NULL, /* ucp_context */
+    NULL, /* ucp_worker */
+    NULL, /* ucp_peers */
+    0,    /* using_mem_hooks */
+    1     /* num_disconnect */
 };
 
 int mca_spml_ucx_enable(bool enable)
@@ -80,10 +86,36 @@ int mca_spml_ucx_enable(bool enable)
     return OSHMEM_SUCCESS;
 }
 
+
+static void mca_spml_ucx_waitall(void **reqs, size_t *count_p)
+{
+    ucs_status_t status;
+    size_t i;
+
+    SPML_VERBOSE(10, "waiting for %zu disconnect requests", *count_p);
+    for (i = 0; i < *count_p; ++i) {
+        do {
+            opal_progress();
+            status = ucp_request_test(reqs[i], NULL);
+        } while (status == UCS_INPROGRESS);
+        if (status != UCS_OK) {
+            SPML_ERROR("disconnect request failed: %s",
+                       ucs_status_string(status));
+        }
+        ucp_request_release(reqs[i]);
+        reqs[i] = NULL;
+    }
+
+    *count_p = 0;
+}
+
 int mca_spml_ucx_del_procs(ompi_proc_t** procs, size_t nprocs)
 {
-    size_t i, n;
     int my_rank = oshmem_my_proc_id();
+    size_t num_reqs, max_reqs;
+    void *dreq, **dreqs;
+    ucp_ep_h ep;
+    size_t i, n;
 
     oshmem_shmem_barrier();
 
@@ -91,12 +124,45 @@ int mca_spml_ucx_del_procs(ompi_proc_t** procs, size_t nprocs)
         return OSHMEM_SUCCESS;
     }
 
-    for (n = 0; n < nprocs; n++) {
-        i = (my_rank + n) % nprocs;
-        if (mca_spml_ucx.ucp_peers[i].ucp_conn) {
-            ucp_ep_destroy(mca_spml_ucx.ucp_peers[i].ucp_conn);
-        }
-    }
+    max_reqs = mca_spml_ucx.num_disconnect;
+    if (max_reqs > nprocs) {
+        max_reqs = nprocs;
+    }
+
+    dreqs = malloc(sizeof(*dreqs) * max_reqs);
+    if (dreqs == NULL) {
+        return OSHMEM_ERR_OUT_OF_RESOURCE;
+    }
+
+    num_reqs = 0;
+
+    for (i = 0; i < nprocs; ++i) {
+        n  = (i + my_rank) % nprocs;
+        ep = mca_spml_ucx.ucp_peers[n].ucp_conn;
+        if (ep == NULL) {
+            continue;
+        }
+
+        SPML_VERBOSE(10, "disconnecting from peer %zu", n);
+        dreq = ucp_disconnect_nb(ep);
+        if (dreq != NULL) {
+            if (UCS_PTR_IS_ERR(dreq)) {
+                SPML_ERROR("ucp_disconnect_nb(%zu) failed: %s", n,
+                           ucs_status_string(UCS_PTR_STATUS(dreq)));
+            } else {
+                dreqs[num_reqs++] = dreq;
+            }
+        }
+
+        mca_spml_ucx.ucp_peers[n].ucp_conn = NULL;
+
+        if (num_reqs >= (size_t)mca_spml_ucx.num_disconnect) {
+            mca_spml_ucx_waitall(dreqs, &num_reqs);
+        }
+    }
+
+    mca_spml_ucx_waitall(dreqs, &num_reqs);
+    free(dreqs);
 
     free(mca_spml_ucx.ucp_peers);
     return OSHMEM_SUCCESS;
diff --git a/oshmem/mca/spml/ucx/spml_ucx.h b/oshmem/mca/spml/ucx/spml_ucx.h
index 6887ecac904..615bb30c77e 100644
--- a/oshmem/mca/spml/ucx/spml_ucx.h
+++ b/oshmem/mca/spml/ucx/spml_ucx.h
@@ -50,6 +50,7 @@ struct mca_spml_ucx {
     ucp_context_h ucp_context;
     ucp_worker_h ucp_worker;
     ucp_peer_t *ucp_peers;
+    int num_disconnect;
 
     int priority; /* component priority */
     bool enabled;
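
This is the same teardown pattern as in pml_ucx.c: start from a peer index rotated by my_rank so that all ranks do not disconnect from the same victim at once, replace the synchronous ucp_ep_destroy() with ucp_disconnect_nb(), and never keep more than num_disconnect requests in flight. Stripped of OSHMEM specifics, the pattern reduces to the sketch below (disconnect_all and wait_and_release are hypothetical names; a valid worker and endpoint array are assumed):

    #include <stdlib.h>
    #include <ucp/api/ucp.h>

    /* Hypothetical helper: poll one request to completion, then release it. */
    static void wait_and_release(ucp_worker_h worker, void *req)
    {
        while (ucp_request_test(req, NULL) == UCS_INPROGRESS) {
            ucp_worker_progress(worker);
        }
        ucp_request_release(req);
    }

    /* Hypothetical driver: disconnect n endpoints, keeping at most `window`
     * (>= 1) disconnect requests outstanding at any moment. */
    static void disconnect_all(ucp_worker_h worker, ucp_ep_h *eps, size_t n,
                               size_t window)
    {
        void **reqs = malloc(window * sizeof(*reqs));
        size_t num = 0, i, j;

        if (reqs == NULL) {
            return;
        }

        for (i = 0; i < n; ++i) {
            void *req = ucp_disconnect_nb(eps[i]);
            if ((req != NULL) && !UCS_PTR_IS_ERR(req)) {
                reqs[num++] = req;          /* still in flight: track it */
            }
            if (num == window) {            /* window is full: drain it */
                for (j = 0; j < num; ++j) {
                    wait_and_release(worker, reqs[j]);
                }
                num = 0;
            }
        }
        for (j = 0; j < num; ++j) {         /* drain the remainder */
            wait_and_release(worker, reqs[j]);
        }
        free(reqs);
    }

The window bound is what makes this safe at scale: memory for pending disconnect requests stays constant regardless of nprocs, while requests inside a window still overlap their transport-level flushes.
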
diff --git a/oshmem/mca/spml/ucx/spml_ucx_component.c b/oshmem/mca/spml/ucx/spml_ucx_component.c
index 22fae41e295..687ce06a746 100644
--- a/oshmem/mca/spml/ucx/spml_ucx_component.c
+++ b/oshmem/mca/spml/ucx/spml_ucx_component.c
@@ -97,6 +97,10 @@ static int mca_spml_ucx_component_register(void)
                                     "[integer] ucx priority",
                                     &mca_spml_ucx.priority);
 
+    mca_spml_ucx_param_register_int("num_disconnect", 1,
+                                    "How many disconnects to do in parallel",
+                                    &mca_spml_ucx.num_disconnect);
+
     return OSHMEM_SUCCESS;
 }
 
@@ -118,7 +122,8 @@ static int mca_spml_ucx_component_open(void)
     }
 
     memset(&params, 0, sizeof(params));
-    params.features = UCP_FEATURE_RMA|UCP_FEATURE_AMO32|UCP_FEATURE_AMO64;
+    params.field_mask = UCP_PARAM_FIELD_FEATURES;
+    params.features   = UCP_FEATURE_RMA|UCP_FEATURE_AMO32|UCP_FEATURE_AMO64;
 
     err = ucp_init(&params, ucp_config, &mca_spml_ucx.ucp_context);
     ucp_config_release(ucp_config);
@@ -131,7 +136,10 @@ static int mca_spml_ucx_component_open(void)
 
 static int mca_spml_ucx_component_close(void)
 {
-    ucp_cleanup(mca_spml_ucx.ucp_context);
+    if (mca_spml_ucx.ucp_context) {
+        ucp_cleanup(mca_spml_ucx.ucp_context);
+        mca_spml_ucx.ucp_context = NULL;
+    }
     return OSHMEM_SUCCESS;
 }
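
The component_open() change mirrors the pml_ucx one: ucp_params_t is an extensible structure, and UCX reads only the fields whose bits are set in field_mask, so every field that is assigned must also be declared there or it is silently ignored. A minimal sketch of the idiom (init_ctx is an illustrative name):

    #include <string.h>
    #include <ucp/api/ucp.h>

    /* Illustrative only: create a UCP context the field_mask way. */
    static ucs_status_t init_ctx(ucp_context_h *ctx_p)
    {
        ucp_params_t params;
        ucp_config_t *config;
        ucs_status_t status;

        status = ucp_config_read(NULL, NULL, &config);
        if (status != UCS_OK) {
            return status;
        }

        memset(&params, 0, sizeof(params));
        params.field_mask = UCP_PARAM_FIELD_FEATURES;   /* declare what we set */
        params.features   = UCP_FEATURE_RMA |
                            UCP_FEATURE_AMO32 |
                            UCP_FEATURE_AMO64;

        status = ucp_init(&params, config, ctx_p);      /* also checks API version */
        ucp_config_release(config);
        return status;
    }

The guarded ucp_cleanup() in component_close() pairs naturally with this: if ucp_init() never succeeded, ucp_context stays NULL and close degrades to a no-op instead of crashing.
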