Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 107 additions & 16 deletions ompi/mca/pml/ucx/pml_ucx.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "opal/runtime/opal.h"
#include "opal/mca/pmix/pmix.h"
#include "ompi/message/message.h"
#include "ompi/mca/pml/base/pml_base_bsend.h"
#include "pml_ucx_request.h"

#include <inttypes.h>
Expand Down Expand Up @@ -333,7 +334,7 @@ static void mca_pml_ucx_waitall(void **reqs, size_t *count_p)
ucs_status_t status;
size_t i;

PML_UCX_VERBOSE(2, "waiting for %d disconnect requests", *count_p);
PML_UCX_VERBOSE(2, "waiting for %d disconnect requests", (int)*count_p);
for (i = 0; i < *count_p; ++i) {
do {
opal_progress();
Expand All @@ -343,7 +344,7 @@ static void mca_pml_ucx_waitall(void **reqs, size_t *count_p)
PML_UCX_ERROR("disconnect request failed: %s",
ucs_status_string(status));
}
ucp_request_release(reqs[i]);
ucp_request_free(reqs[i]);
reqs[i] = NULL;
}

Expand Down Expand Up @@ -391,7 +392,7 @@ int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)

proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;

if (num_reqs >= ompi_pml_ucx.num_disconnect) {
if ((int)num_reqs >= ompi_pml_ucx.num_disconnect) {
mca_pml_ucx_waitall(dreqs, &num_reqs);
}
}
Expand Down Expand Up @@ -494,7 +495,7 @@ int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src
PML_UCX_TRACE_RECV("%s", buf, count, datatype, src, tag, comm, "recv");

PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
req = alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size;
req = (char *)alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size;
status = ucp_tag_recv_nbr(ompi_pml_ucx.ucp_worker, buf, count,
mca_pml_ucx_get_datatype(datatype),
ucp_tag, ucp_tag_mask, req);
Expand Down Expand Up @@ -556,15 +557,80 @@ int mca_pml_ucx_isend_init(const void *buf, size_t count, ompi_datatype_t *datat
req->flags = MCA_PML_UCX_REQUEST_FLAG_SEND;
req->buffer = (void *)buf;
req->count = count;
req->datatype = mca_pml_ucx_get_datatype(datatype);
req->tag = PML_UCX_MAKE_SEND_TAG(tag, comm);
req->send.mode = mode;
req->send.ep = ep;
if (MCA_PML_BASE_SEND_BUFFERED == mode) {
req->ompi_datatype = datatype;
OBJ_RETAIN(datatype);
} else {
req->datatype = mca_pml_ucx_get_datatype(datatype);
}

*request = &req->ompi;
return OMPI_SUCCESS;
}

static int
mca_pml_ucx_bsend(ucp_ep_h ep, const void *buf, size_t count,
ompi_datatype_t *datatype, uint64_t pml_tag)
{
ompi_request_t *req;
void *packed_data;
size_t packed_length;
size_t offset;
uint32_t iov_count;
struct iovec iov;
opal_convertor_t opal_conv;

OBJ_CONSTRUCT(&opal_conv, opal_convertor_t);
opal_convertor_copy_and_prepare_for_recv(ompi_proc_local_proc->super.proc_convertor,
&datatype->super, count, buf, 0,
&opal_conv);
opal_convertor_get_packed_size(&opal_conv, &packed_length);

packed_data = mca_pml_base_bsend_request_alloc_buf(packed_length);
if (OPAL_UNLIKELY(NULL == packed_data)) {
OBJ_DESTRUCT(&opal_conv);
PML_UCX_ERROR("bsend: failed to allocate buffer");
return OMPI_ERR_OUT_OF_RESOURCE;
}

iov_count = 1;
iov.iov_base = packed_data;
iov.iov_len = packed_length;

PML_UCX_VERBOSE(8, "bsend of packed buffer %p len %d\n", packed_data, packed_length);
offset = 0;
opal_convertor_set_position(&opal_conv, &offset);
if (0 > opal_convertor_pack(&opal_conv, &iov, &iov_count, &packed_length)) {
mca_pml_base_bsend_request_free(packed_data);
OBJ_DESTRUCT(&opal_conv);
PML_UCX_ERROR("bsend: failed to pack user datatype");
return OMPI_ERROR;
}

OBJ_DESTRUCT(&opal_conv);

req = (ompi_request_t*)ucp_tag_send_nb(ep, packed_data, packed_length,
ucp_dt_make_contig(1), pml_tag,
mca_pml_ucx_bsend_completion);
if (NULL == req) {
/* request was completed in place */
mca_pml_base_bsend_request_free(packed_data);
return OMPI_SUCCESS;
}

if (OPAL_UNLIKELY(UCS_PTR_IS_ERR(req))) {
mca_pml_base_bsend_request_free(packed_data);
PML_UCX_ERROR("ucx bsend failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
return OMPI_ERROR;
}

req->req_complete_cb_data = packed_data;
return OMPI_SUCCESS;
}

int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
int dst, int tag, mca_pml_base_send_mode_t mode,
struct ompi_communicator_t* comm,
Expand All @@ -573,8 +639,10 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
ompi_request_t *req;
ucp_ep_h ep;

PML_UCX_TRACE_SEND("isend request *%p", buf, count, datatype, dst, tag, mode,
comm, (void*)request)
PML_UCX_TRACE_SEND("i%ssend request *%p",
buf, count, datatype, dst, tag, mode, comm,
mode == MCA_PML_BASE_SEND_BUFFERED ? "b" : "",
(void*)request)

/* TODO special care to sync/buffered send */

Expand All @@ -584,6 +652,13 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
return OMPI_ERROR;
}

/* Special care to sync/buffered send */
if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == mode)) {
*request = &ompi_pml_ucx.completed_send_req;
return mca_pml_ucx_bsend(ep, buf, count, datatype,
PML_UCX_MAKE_SEND_TAG(tag, comm));
}

req = (ompi_request_t*)ucp_tag_send_nb(ep, buf, count,
mca_pml_ucx_get_datatype(datatype),
PML_UCX_MAKE_SEND_TAG(tag, comm),
Expand All @@ -609,16 +684,21 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i
ompi_request_t *req;
ucp_ep_h ep;

PML_UCX_TRACE_SEND("%s", buf, count, datatype, dst, tag, mode, comm, "send");

/* TODO special care to sync/buffered send */
PML_UCX_TRACE_SEND("%s", buf, count, datatype, dst, tag, mode, comm,
mode == MCA_PML_BASE_SEND_BUFFERED ? "bsend" : "send");

ep = mca_pml_ucx_get_ep(comm, dst);
if (OPAL_UNLIKELY(NULL == ep)) {
PML_UCX_ERROR("Failed to get ep for rank %d", dst);
return OMPI_ERROR;
}

/* Special care to sync/buffered send */
if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == mode)) {
return mca_pml_ucx_bsend(ep, buf, count, datatype,
PML_UCX_MAKE_SEND_TAG(tag, comm));
}

req = (ompi_request_t*)ucp_tag_send_nb(ep, buf, count,
mca_pml_ucx_get_datatype(datatype),
PML_UCX_MAKE_SEND_TAG(tag, comm),
Expand Down Expand Up @@ -781,6 +861,7 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
mca_pml_ucx_persistent_request_t *preq;
ompi_request_t *tmp_req;
size_t i;
int rc;

for (i = 0; i < count; ++i) {
preq = (mca_pml_ucx_persistent_request_t *)requests[i];
Expand All @@ -795,12 +876,22 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
mca_pml_ucx_request_reset(&preq->ompi);

if (preq->flags & MCA_PML_UCX_REQUEST_FLAG_SEND) {
/* TODO special care to sync/buffered send */
PML_UCX_VERBOSE(8, "start send request %p", (void*)preq);
tmp_req = (ompi_request_t*)ucp_tag_send_nb(preq->send.ep, preq->buffer,
preq->count, preq->datatype,
preq->tag,
mca_pml_ucx_psend_completion);
if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == preq->send.mode)) {
PML_UCX_VERBOSE(8, "start bsend request %p", (void*)preq);
rc = mca_pml_ucx_bsend(preq->send.ep, preq->buffer, preq->count,
preq->ompi_datatype, preq->tag);
if (OMPI_SUCCESS != rc) {
return rc;
}
/* pretend that we got immediate completion */
tmp_req = NULL;
} else {
PML_UCX_VERBOSE(8, "start send request %p", (void*)preq);
tmp_req = (ompi_request_t*)ucp_tag_send_nb(preq->send.ep, preq->buffer,
preq->count, preq->datatype,
preq->tag,
mca_pml_ucx_psend_completion);
}
} else {
PML_UCX_VERBOSE(8, "start recv request %p", (void*)preq);
tmp_req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker,
Expand Down
21 changes: 18 additions & 3 deletions ompi/mca/pml/ucx/pml_ucx_request.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ static int mca_pml_ucx_request_free(ompi_request_t **rptr)

*rptr = MPI_REQUEST_NULL;
mca_pml_ucx_request_reset(req);
ucp_request_release(req);
ucp_request_free(req);
return OMPI_SUCCESS;
}

Expand All @@ -46,6 +46,18 @@ void mca_pml_ucx_send_completion(void *request, ucs_status_t status)
ompi_request_complete(req, true);
}

void mca_pml_ucx_bsend_completion(void *request, ucs_status_t status)
{
ompi_request_t *req = request;

PML_UCX_VERBOSE(8, "bsend request %p buffer %p completed with status %s", (void*)req,
req->req_complete_cb_data, ucs_status_string(status));
mca_pml_base_bsend_request_free(req->req_complete_cb_data);
mca_pml_ucx_set_send_status(&req->req_status, status);
PML_UCX_ASSERT( !(REQUEST_COMPLETE(req)));
mca_pml_ucx_request_free(&req);
}

void mca_pml_ucx_recv_completion(void *request, ucs_status_t status,
ucp_tag_recv_info_t *info)
{
Expand Down Expand Up @@ -74,7 +86,7 @@ void mca_pml_ucx_persistent_request_complete(mca_pml_ucx_persistent_request_t *p
ompi_request_complete(&preq->ompi, true);
mca_pml_ucx_persistent_request_detach(preq, tmp_req);
mca_pml_ucx_request_reset(tmp_req);
ucp_request_release(tmp_req);
ucp_request_free(tmp_req);
}

static inline void mca_pml_ucx_preq_completion(ompi_request_t *tmp_req)
Expand Down Expand Up @@ -151,7 +163,10 @@ static int mca_pml_ucx_persistent_request_free(ompi_request_t **rptr)
preq->ompi.req_state = OMPI_REQUEST_INVALID;
if (tmp_req != NULL) {
mca_pml_ucx_persistent_request_detach(preq, tmp_req);
ucp_request_release(tmp_req);
ucp_request_free(tmp_req);
}
if (MCA_PML_BASE_SEND_BUFFERED == preq->send.mode) {
OBJ_RELEASE(preq->ompi_datatype);
}
PML_UCX_FREELIST_RETURN(&ompi_pml_ucx.persistent_reqs, &preq->ompi.super);
*rptr = MPI_REQUEST_NULL;
Expand Down
7 changes: 6 additions & 1 deletion ompi/mca/pml/ucx/pml_ucx_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,10 @@ struct pml_ucx_persistent_request {
unsigned flags;
void *buffer;
size_t count;
ucp_datatype_t datatype;
union {
ucp_datatype_t datatype;
ompi_datatype_t *ompi_datatype;
};
ucp_tag_t tag;
struct {
mca_pml_base_send_mode_t mode;
Expand All @@ -118,6 +121,8 @@ void mca_pml_ucx_recv_completion(void *request, ucs_status_t status,

void mca_pml_ucx_psend_completion(void *request, ucs_status_t status);

void mca_pml_ucx_bsend_completion(void *request, ucs_status_t status);

void mca_pml_ucx_precv_completion(void *request, ucs_status_t status,
ucp_tag_recv_info_t *info);

Expand Down