171 changes: 114 additions & 57 deletions ompi/mca/pml/ucx/pml_ucx.c
@@ -70,8 +70,8 @@ mca_pml_ucx_module_t ompi_pml_ucx = {
1ul << (PML_UCX_TAG_BITS - 1),
1ul << (PML_UCX_CONTEXT_BITS),
},
NULL,
NULL
NULL, /* ucp_context */
NULL /* ucp_worker */
};

static int mca_pml_ucx_send_worker_address(void)
@@ -116,6 +116,7 @@ static int mca_pml_ucx_recv_worker_address(ompi_proc_t *proc,

int mca_pml_ucx_open(void)
{
ucp_context_attr_t attr;
ucp_params_t params;
ucp_config_t *config;
ucs_status_t status;
@@ -128,10 +129,17 @@ int mca_pml_ucx_open(void)
return OMPI_ERROR;
}

/* Initialize UCX context */
params.field_mask = UCP_PARAM_FIELD_FEATURES |
UCP_PARAM_FIELD_REQUEST_SIZE |
UCP_PARAM_FIELD_REQUEST_INIT |
UCP_PARAM_FIELD_REQUEST_CLEANUP |
UCP_PARAM_FIELD_TAG_SENDER_MASK;
params.features = UCP_FEATURE_TAG;
params.request_size = sizeof(ompi_request_t);
params.request_init = mca_pml_ucx_request_init;
params.request_cleanup = mca_pml_ucx_request_cleanup;
params.tag_sender_mask = PML_UCX_SPECIFIC_SOURCE_MASK;

status = ucp_init(&params, config, &ompi_pml_ucx.ucp_context);
ucp_config_release(config);
@@ -140,6 +148,17 @@ int mca_pml_ucx_open(void)
return OMPI_ERROR;
}

/* Query UCX attributes */
attr.field_mask = UCP_ATTR_FIELD_REQUEST_SIZE;
status = ucp_context_query(ompi_pml_ucx.ucp_context, &attr);
if (UCS_OK != status) {
ucp_cleanup(ompi_pml_ucx.ucp_context);
ompi_pml_ucx.ucp_context = NULL;
return OMPI_ERROR;
}

ompi_pml_ucx.request_size = attr.request_size;

return OMPI_SUCCESS;
}
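
[Note] The new UCP_PARAM_FIELD_REQUEST_SIZE/INIT/CLEANUP fields ask UCX to reserve sizeof(ompi_request_t) bytes inside every request it allocates and to run init/cleanup hooks on that region, while ucp_context_query() reports how much space UCX itself needs per request. A minimal sketch of the hook signatures this implies; the hook bodies below are illustrative assumptions, the PR's real ones live in pml_ucx_request.c:

/* Sketch only: UCX invokes these once per request slot it allocates,
 * not once per communication operation. */
static void example_request_init(void *request)
{
    ompi_request_t *req = (ompi_request_t*)request;
    OMPI_REQUEST_INIT(req, false);    /* assumed OMPI helper macro */
}

static void example_request_cleanup(void *request)
{
    ompi_request_t *req = (ompi_request_t*)request;
    OMPI_REQUEST_FINI(req);           /* assumed OMPI helper macro */
}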

@@ -163,7 +182,7 @@ int mca_pml_ucx_init(void)

/* TODO check MPI thread mode */
status = ucp_worker_create(ompi_pml_ucx.ucp_context, UCS_THREAD_MODE_SINGLE,
&ompi_pml_ucx.ucp_worker);
&ompi_pml_ucx.ucp_worker);
if (UCS_OK != status) {
return OMPI_ERROR;
}
@@ -252,6 +271,7 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
{
ucp_address_t *address;
ucs_status_t status;
ompi_proc_t *proc;
size_t addrlen;
ucp_ep_h ep;
size_t i;
@@ -264,47 +284,109 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
}

for (i = 0; i < nprocs; ++i) {
ret = mca_pml_ucx_recv_worker_address(procs[i], &address, &addrlen);
proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs];

ret = mca_pml_ucx_recv_worker_address(proc, &address, &addrlen);
if (ret < 0) {
PML_UCX_ERROR("Failed to receive worker address from proc: %d", procs[i]->super.proc_name.vpid);
PML_UCX_ERROR("Failed to receive worker address from proc: %d",
proc->super.proc_name.vpid);
return ret;
}

if (procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]) {
PML_UCX_VERBOSE(3, "already connected to proc. %d", procs[i]->super.proc_name.vpid);
if (proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]) {
PML_UCX_VERBOSE(3, "already connected to proc. %d", proc->super.proc_name.vpid);
continue;
}

PML_UCX_VERBOSE(2, "connecting to proc. %d", procs[i]->super.proc_name.vpid);
PML_UCX_VERBOSE(2, "connecting to proc. %d", proc->super.proc_name.vpid);
status = ucp_ep_create(ompi_pml_ucx.ucp_worker, address, &ep);
free(address);

if (UCS_OK != status) {
PML_UCX_ERROR("Failed to connect to proc: %d, %s", procs[i]->super.proc_name.vpid,
ucs_status_string(status));
PML_UCX_ERROR("Failed to connect to proc: %d, %s", proc->super.proc_name.vpid,
ucs_status_string(status));
return OMPI_ERROR;
}

procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;
proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;
}

return OMPI_SUCCESS;
}
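
[Note] Indexing with (i + OMPI_PROC_MY_NAME->vpid) % nprocs staggers the wireup order per rank, so all ranks do not try to connect to proc 0 at the same moment. A toy illustration with assumed values:

/* With nprocs = 4, the rank whose vpid is 2 connects in order 2, 3, 0, 1. */
size_t i, nprocs = 4, my_vpid = 2;
for (i = 0; i < nprocs; ++i) {
    printf("%zu ", (i + my_vpid) % nprocs);   /* prints: 2 3 0 1 */
}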

static void mca_pml_ucx_waitall(void **reqs, size_t *count_p)
{
ucs_status_t status;
size_t i;

PML_UCX_VERBOSE(2, "waiting for %d disconnect requests", *count_p);
for (i = 0; i < *count_p; ++i) {
do {
opal_progress();
status = ucp_request_test(reqs[i], NULL);
} while (status == UCS_INPROGRESS);
if (status != UCS_OK) {
PML_UCX_ERROR("disconnect request failed: %s",
ucs_status_string(status));
}
ucp_request_release(reqs[i]);
reqs[i] = NULL;
}

*count_p = 0;
}

int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)
{
ompi_proc_t *proc;
size_t num_reqs, max_reqs;
void *dreq, **dreqs;
ucp_ep_h ep;
size_t i;

max_reqs = ompi_pml_ucx.num_disconnect;
if (max_reqs > nprocs) {
max_reqs = nprocs;
}

dreqs = malloc(sizeof(*dreqs) * max_reqs);
if (dreqs == NULL) {
return OMPI_ERR_OUT_OF_RESOURCE;
}

num_reqs = 0;

for (i = 0; i < nprocs; ++i) {
PML_UCX_VERBOSE(2, "disconnecting from rank %d", procs[i]->super.proc_name.vpid);
ep = procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
if (ep != NULL) {
ucp_ep_destroy(ep);
proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs];
ep = proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
if (ep == NULL) {
continue;
}

PML_UCX_VERBOSE(2, "disconnecting from rank %d", proc->super.proc_name.vpid);
dreq = ucp_disconnect_nb(ep);
if (dreq != NULL) {
if (UCS_PTR_IS_ERR(dreq)) {
PML_UCX_ERROR("ucp_disconnect_nb(%d) failed: %s",
proc->super.proc_name.vpid,
ucs_status_string(UCS_PTR_STATUS(dreq)));
} else {
dreqs[num_reqs++] = dreq;
}
}

proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;

if (num_reqs >= ompi_pml_ucx.num_disconnect) {
mca_pml_ucx_waitall(dreqs, &num_reqs);
}
procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;
}

mca_pml_ucx_waitall(dreqs, &num_reqs);
free(dreqs);

opal_pmix.fence(NULL, 0);

return OMPI_SUCCESS;
}
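
[Note] The loop above caps the number of outstanding disconnects at pml_ucx_num_disconnect by periodically draining them through mca_pml_ucx_waitall(). It also relies on UCX's pointer-status convention for nonblocking calls; a sketch of the three possible outcomes of ucp_disconnect_nb():

void *r = ucp_disconnect_nb(ep);
if (r == NULL) {
    /* completed immediately; nothing to wait for */
} else if (UCS_PTR_IS_ERR(r)) {
    /* failed up front; UCS_PTR_STATUS(r) carries the ucs_status_t */
} else {
    /* in flight; poll r until done, then ucp_request_release(r) */
}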

@@ -321,14 +403,7 @@ int mca_pml_ucx_enable(bool enable)

int mca_pml_ucx_progress(void)
{
static int inprogress = 0;
if (inprogress != 0) {
return 0;
}

++inprogress;
ucp_worker_progress(ompi_pml_ucx.ucp_worker);
--inprogress;
return OMPI_SUCCESS;
}

@@ -393,52 +468,32 @@ int mca_pml_ucx_irecv(void *buf, size_t count, ompi_datatype_t *datatype,
return OMPI_SUCCESS;
}

static void
mca_pml_ucx_blocking_recv_completion(void *request, ucs_status_t status,
ucp_tag_recv_info_t *info)
{
ompi_request_t *req = request;

PML_UCX_VERBOSE(8, "blocking receive request %p completed with status %s tag %"PRIx64" len %zu",
(void*)req, ucs_status_string(status), info->sender_tag,
info->length);

mca_pml_ucx_set_recv_status(&req->req_status, status, info);
PML_UCX_ASSERT( !(REQUEST_COMPLETE(req)));
ompi_request_complete(req,true);
}

int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src,
int tag, struct ompi_communicator_t* comm,
ompi_status_public_t* mpi_status)
{
ucp_tag_t ucp_tag, ucp_tag_mask;
ompi_request_t *req;
ucp_tag_recv_info_t info;
ucs_status_t status;
void *req;

PML_UCX_TRACE_RECV("%s", buf, count, datatype, src, tag, comm, "recv");

PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker, buf, count,
mca_pml_ucx_get_datatype(datatype),
ucp_tag, ucp_tag_mask,
mca_pml_ucx_blocking_recv_completion);
if (UCS_PTR_IS_ERR(req)) {
PML_UCX_ERROR("ucx recv failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
return OMPI_ERROR;
}
req = (char *)alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size;
status = ucp_tag_recv_nbr(ompi_pml_ucx.ucp_worker, buf, count,
mca_pml_ucx_get_datatype(datatype),
ucp_tag, ucp_tag_mask, req);

ucp_worker_progress(ompi_pml_ucx.ucp_worker);
while ( !REQUEST_COMPLETE(req) ) {
for (;;) {
status = ucp_request_test(req, &info);
if (status != UCS_INPROGRESS) {
mca_pml_ucx_set_recv_status_safe(mpi_status, status, &info);
return OMPI_SUCCESS;
}
opal_progress();
}

if (mpi_status != MPI_STATUS_IGNORE) {
*mpi_status = req->req_status;
}

req->req_complete = REQUEST_PENDING;
ucp_request_release(req);
return OMPI_SUCCESS;
}
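
[Note] The _nbr variant of the UCX tag API expects the caller to provide request storage, with at least request_size usable bytes located before the handle it is passed; that is why the alloca() result is advanced by request_size. A sketch of the layout (names illustrative):

/* [ base, base + request_size ) holds UCX's internal request state;
 * 'req' is the handle handed to ucp_tag_recv_nbr()/ucp_request_test(). */
char *base = alloca(ompi_pml_ucx.request_size);
void *req  = base + ompi_pml_ucx.request_size;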

static inline const char *mca_pml_ucx_send_mode_name(mca_pml_base_send_mode_t mode)
@@ -583,6 +638,7 @@ int mca_pml_ucx_iprobe(int src, int tag, struct ompi_communicator_t* comm,
*matched = 1;
mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
} else {
opal_progress();
*matched = 0;
}
return OMPI_SUCCESS;
Expand Down Expand Up @@ -628,7 +684,8 @@ int mca_pml_ucx_improbe(int src, int tag, struct ompi_communicator_t* comm,
PML_UCX_VERBOSE(8, "got message %p (%p)", (void*)*message, (void*)ucp_msg);
*matched = 1;
mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
} else if (UCS_PTR_STATUS(ucp_msg) == UCS_ERR_NO_MESSAGE) {
} else {
opal_progress();
*matched = 0;
}
return OMPI_SUCCESS;
4 changes: 3 additions & 1 deletion ompi/mca/pml/ucx/pml_ucx.h
@@ -40,8 +40,10 @@ struct mca_pml_ucx_module {
/* Requests */
mca_pml_ucx_freelist_t persistent_reqs;
ompi_request_t completed_send_req;
size_t request_size;
int num_disconnect;

/* Convertors pool */
/* Converters pool */
mca_pml_ucx_freelist_t convs;

int priority;
7 changes: 7 additions & 0 deletions ompi/mca/pml/ucx/pml_ucx_component.c
@@ -63,6 +63,13 @@ static int mca_pml_ucx_component_register(void)
MCA_BASE_VAR_SCOPE_LOCAL,
&ompi_pml_ucx.priority);

ompi_pml_ucx.num_disconnect = 1;
(void) mca_base_component_var_register(&mca_pml_ucx_component.pmlm_version, "num_disconnect",
"How many disconnects to do in parallel",
MCA_BASE_VAR_TYPE_INT, NULL, 0, 0,
OPAL_INFO_LVL_3,
MCA_BASE_VAR_SCOPE_LOCAL,
&ompi_pml_ucx.num_disconnect);
return 0;
}
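
[Note] Assuming the usual <framework>_<component>_<name> MCA naming, the new knob should be tunable at launch time, e.g.:

mpirun --mca pml_ucx_num_disconnect 16 ./app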

9 changes: 6 additions & 3 deletions ompi/mca/pml/ucx/pml_ucx_request.h
@@ -34,6 +34,9 @@ enum {
#define PML_UCX_TAG_BITS 24
#define PML_UCX_RANK_BITS 24
#define PML_UCX_CONTEXT_BITS 16
#define PML_UCX_ANY_SOURCE_MASK 0x800000000000fffful
#define PML_UCX_SPECIFIC_SOURCE_MASK 0x800000fffffffffful
#define PML_UCX_TAG_MASK 0x7fffff0000000000ul


#define PML_UCX_MAKE_SEND_TAG(_tag, _comm) \
Expand All @@ -45,16 +48,16 @@ enum {
#define PML_UCX_MAKE_RECV_TAG(_ucp_tag, _ucp_tag_mask, _tag, _src, _comm) \
{ \
if ((_src) == MPI_ANY_SOURCE) { \
_ucp_tag_mask = 0x800000000000fffful; \
_ucp_tag_mask = PML_UCX_ANY_SOURCE_MASK; \
} else { \
_ucp_tag_mask = 0x800000fffffffffful; \
_ucp_tag_mask = PML_UCX_SPECIFIC_SOURCE_MASK; \
} \
\
_ucp_tag = (((uint64_t)(_src) & UCS_MASK(PML_UCX_RANK_BITS)) << PML_UCX_CONTEXT_BITS) | \
(_comm)->c_contextid; \
\
if ((_tag) != MPI_ANY_TAG) { \
_ucp_tag_mask |= 0x7fffff0000000000ul; \
_ucp_tag_mask |= PML_UCX_TAG_MASK; \
_ucp_tag |= ((uint64_t)(_tag)) << (PML_UCX_RANK_BITS + PML_UCX_CONTEXT_BITS); \
} \
}
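
[Note] The named masks spell out the 64-bit UCP tag layout implied by the *_BITS constants: the top 24 bits carry the MPI tag (with bit 63 reserved and always matched, hence the 23-bit PML_UCX_TAG_MASK), the middle 24 bits the source rank, and the low 16 bits the communicator context id:

/* 63 62              40 39               16 15            0
 * +--+------------------+-------------------+--------------+
 * |R |  MPI tag (23)    |  source rank (24) |  ctx id (16) |
 * +--+------------------+-------------------+--------------+
 * R = reserved bit; set in both ANY_SOURCE and SPECIFIC_SOURCE masks.
 */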