From 22d688faf918a60c78d96c906d610b65a936150b Mon Sep 17 00:00:00 2001
From: Alex Mikheev
Date: Thu, 9 Feb 2017 17:07:34 +0200
Subject: [PATCH] ompi: pml ucx: add support for the buffered send

Signed-off-by: Alex Mikheev
(cherry picked from commit b015c8bb489580a33739ec31baa17291b5eea80b)
---
 ompi/mca/pml/ucx/pml_ucx.c         | 115 ++++++++++++++++++++++++++---
 ompi/mca/pml/ucx/pml_ucx_request.c |  21 +++++-
 ompi/mca/pml/ucx/pml_ucx_request.h |   7 +-
 3 files changed, 127 insertions(+), 16 deletions(-)

diff --git a/ompi/mca/pml/ucx/pml_ucx.c b/ompi/mca/pml/ucx/pml_ucx.c
index 9b0fdf495e0..1ddca8812d0 100644
--- a/ompi/mca/pml/ucx/pml_ucx.c
+++ b/ompi/mca/pml/ucx/pml_ucx.c
@@ -15,6 +15,7 @@
 #include "opal/runtime/opal.h"
 #include "opal/mca/pmix/pmix.h"
 #include "ompi/message/message.h"
+#include "ompi/mca/pml/base/pml_base_bsend.h"
 #include "pml_ucx_request.h"
 
 #include <inttypes.h>
@@ -506,15 +507,80 @@ int mca_pml_ucx_isend_init(const void *buf, size_t count, ompi_datatype_t *datat
     req->flags     = MCA_PML_UCX_REQUEST_FLAG_SEND;
     req->buffer    = (void *)buf;
     req->count     = count;
-    req->datatype  = mca_pml_ucx_get_datatype(datatype);
     req->tag       = PML_UCX_MAKE_SEND_TAG(tag, comm);
     req->send.mode = mode;
     req->send.ep   = ep;
+    if (MCA_PML_BASE_SEND_BUFFERED == mode) {
+        req->ompi_datatype = datatype;
+        OBJ_RETAIN(datatype);
+    } else {
+        req->datatype = mca_pml_ucx_get_datatype(datatype);
+    }
 
     *request = &req->ompi;
     return OMPI_SUCCESS;
 }
 
+static int
+mca_pml_ucx_bsend(ucp_ep_h ep, const void *buf, size_t count,
+                  ompi_datatype_t *datatype, uint64_t pml_tag)
+{
+    ompi_request_t *req;
+    void *packed_data;
+    size_t packed_length;
+    size_t offset;
+    uint32_t iov_count;
+    struct iovec iov;
+    opal_convertor_t opal_conv;
+
+    OBJ_CONSTRUCT(&opal_conv, opal_convertor_t);
+    opal_convertor_copy_and_prepare_for_send(ompi_proc_local_proc->super.proc_convertor,
+                                             &datatype->super, count, buf, 0,
+                                             &opal_conv);
+    opal_convertor_get_packed_size(&opal_conv, &packed_length);
+
+    packed_data = mca_pml_base_bsend_request_alloc_buf(packed_length);
+    if (OPAL_UNLIKELY(NULL == packed_data)) {
+        OBJ_DESTRUCT(&opal_conv);
+        PML_UCX_ERROR("bsend: failed to allocate buffer");
+        return OMPI_ERR_OUT_OF_RESOURCE;
+    }
+
+    iov_count    = 1;
+    iov.iov_base = packed_data;
+    iov.iov_len  = packed_length;
+
+    PML_UCX_VERBOSE(8, "bsend of packed buffer %p len %zu", packed_data, packed_length);
+    offset = 0;
+    opal_convertor_set_position(&opal_conv, &offset);
+    if (0 > opal_convertor_pack(&opal_conv, &iov, &iov_count, &packed_length)) {
+        mca_pml_base_bsend_request_free(packed_data);
+        OBJ_DESTRUCT(&opal_conv);
+        PML_UCX_ERROR("bsend: failed to pack user datatype");
+        return OMPI_ERROR;
+    }
+
+    OBJ_DESTRUCT(&opal_conv);
+
+    req = (ompi_request_t*)ucp_tag_send_nb(ep, packed_data, packed_length,
+                                           ucp_dt_make_contig(1), pml_tag,
+                                           mca_pml_ucx_bsend_completion);
+    if (NULL == req) {
+        /* request was completed in place */
+        mca_pml_base_bsend_request_free(packed_data);
+        return OMPI_SUCCESS;
+    }
+
+    if (OPAL_UNLIKELY(UCS_PTR_IS_ERR(req))) {
+        mca_pml_base_bsend_request_free(packed_data);
+        PML_UCX_ERROR("ucx bsend failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
+        return OMPI_ERROR;
+    }
+
+    req->req_complete_cb_data = packed_data;
+    return OMPI_SUCCESS;
+}
+
 int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
                       int dst, int tag, mca_pml_base_send_mode_t mode,
                       struct ompi_communicator_t* comm,
@@ -523,8 +589,10 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
     ompi_request_t *req;
     ucp_ep_h ep;
 
PML_UCX_TRACE_SEND("isend request *%p", buf, count, datatype, dst, tag, mode, - comm, (void*)request) + PML_UCX_TRACE_SEND("i%ssend request *%p", + buf, count, datatype, dst, tag, mode, comm, + mode == MCA_PML_BASE_SEND_BUFFERED ? "b" : "", + (void*)request) /* TODO special care to sync/buffered send */ @@ -534,6 +602,13 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype, return OMPI_ERROR; } + /* Special care to sync/buffered send */ + if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == mode)) { + *request = &ompi_pml_ucx.completed_send_req; + return mca_pml_ucx_bsend(ep, buf, count, datatype, + PML_UCX_MAKE_SEND_TAG(tag, comm)); + } + req = (ompi_request_t*)ucp_tag_send_nb(ep, buf, count, mca_pml_ucx_get_datatype(datatype), PML_UCX_MAKE_SEND_TAG(tag, comm), @@ -559,9 +634,8 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i ompi_request_t *req; ucp_ep_h ep; - PML_UCX_TRACE_SEND("%s", buf, count, datatype, dst, tag, mode, comm, "send"); - - /* TODO special care to sync/buffered send */ + PML_UCX_TRACE_SEND("%s", buf, count, datatype, dst, tag, mode, comm, + mode == MCA_PML_BASE_SEND_BUFFERED ? "bsend" : "send"); ep = mca_pml_ucx_get_ep(comm, dst); if (OPAL_UNLIKELY(NULL == ep)) { @@ -569,6 +643,12 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i return OMPI_ERROR; } + /* Special care to sync/buffered send */ + if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == mode)) { + return mca_pml_ucx_bsend(ep, buf, count, datatype, + PML_UCX_MAKE_SEND_TAG(tag, comm)); + } + req = (ompi_request_t*)ucp_tag_send_nb(ep, buf, count, mca_pml_ucx_get_datatype(datatype), PML_UCX_MAKE_SEND_TAG(tag, comm), @@ -729,6 +809,7 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests) mca_pml_ucx_persistent_request_t *preq; ompi_request_t *tmp_req; size_t i; + int rc; for (i = 0; i < count; ++i) { preq = (mca_pml_ucx_persistent_request_t *)requests[i]; @@ -743,12 +824,22 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests) mca_pml_ucx_request_reset(&preq->ompi); if (preq->flags & MCA_PML_UCX_REQUEST_FLAG_SEND) { - /* TODO special care to sync/buffered send */ - PML_UCX_VERBOSE(8, "start send request %p", (void*)preq); - tmp_req = (ompi_request_t*)ucp_tag_send_nb(preq->send.ep, preq->buffer, - preq->count, preq->datatype, - preq->tag, - mca_pml_ucx_psend_completion); + if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == preq->send.mode)) { + PML_UCX_VERBOSE(8, "start bsend request %p", (void*)preq); + rc = mca_pml_ucx_bsend(preq->send.ep, preq->buffer, preq->count, + preq->ompi_datatype, preq->tag); + if (OMPI_SUCCESS != rc) { + return rc; + } + /* pretend that we got immediate completion */ + tmp_req = NULL; + } else { + PML_UCX_VERBOSE(8, "start send request %p", (void*)preq); + tmp_req = (ompi_request_t*)ucp_tag_send_nb(preq->send.ep, preq->buffer, + preq->count, preq->datatype, + preq->tag, + mca_pml_ucx_psend_completion); + } } else { PML_UCX_VERBOSE(8, "start recv request %p", (void*)preq); tmp_req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker, diff --git a/ompi/mca/pml/ucx/pml_ucx_request.c b/ompi/mca/pml/ucx/pml_ucx_request.c index d82cdffae0c..6537a4017e6 100644 --- a/ompi/mca/pml/ucx/pml_ucx_request.c +++ b/ompi/mca/pml/ucx/pml_ucx_request.c @@ -24,7 +24,7 @@ static int mca_pml_ucx_request_free(ompi_request_t **rptr) *rptr = MPI_REQUEST_NULL; mca_pml_ucx_request_reset(req); - ucp_request_release(req); + ucp_request_free(req); return OMPI_SUCCESS; } @@ -46,6 +46,18 @@ void 
     ompi_request_complete(req, true);
 }
 
+void mca_pml_ucx_bsend_completion(void *request, ucs_status_t status)
+{
+    ompi_request_t *req = request;
+
+    PML_UCX_VERBOSE(8, "bsend request %p buffer %p completed with status %s", (void*)req,
+                    req->req_complete_cb_data, ucs_status_string(status));
+    mca_pml_base_bsend_request_free(req->req_complete_cb_data);
+    mca_pml_ucx_set_send_status(&req->req_status, status);
+    PML_UCX_ASSERT(!(REQUEST_COMPLETE(req)));
+    mca_pml_ucx_request_free(&req);
+}
+
 void mca_pml_ucx_recv_completion(void *request, ucs_status_t status,
                                  ucp_tag_recv_info_t *info)
 {
@@ -74,7 +86,7 @@ void mca_pml_ucx_persistent_request_complete(mca_pml_ucx_persistent_request_t *p
     ompi_request_complete(&preq->ompi, true);
     mca_pml_ucx_persistent_request_detach(preq, tmp_req);
     mca_pml_ucx_request_reset(tmp_req);
-    ucp_request_release(tmp_req);
+    ucp_request_free(tmp_req);
 }
 
 static inline void mca_pml_ucx_preq_completion(ompi_request_t *tmp_req)
@@ -151,7 +163,10 @@ static int mca_pml_ucx_persistent_request_free(ompi_request_t **rptr)
     preq->ompi.req_state = OMPI_REQUEST_INVALID;
     if (tmp_req != NULL) {
         mca_pml_ucx_persistent_request_detach(preq, tmp_req);
-        ucp_request_release(tmp_req);
+        ucp_request_free(tmp_req);
+    }
+    if (MCA_PML_BASE_SEND_BUFFERED == preq->send.mode) {
+        OBJ_RELEASE(preq->ompi_datatype);
     }
     PML_UCX_FREELIST_RETURN(&ompi_pml_ucx.persistent_reqs, &preq->ompi.super);
     *rptr = MPI_REQUEST_NULL;
diff --git a/ompi/mca/pml/ucx/pml_ucx_request.h b/ompi/mca/pml/ucx/pml_ucx_request.h
index db9f089b4cf..d0d426f441a 100644
--- a/ompi/mca/pml/ucx/pml_ucx_request.h
+++ b/ompi/mca/pml/ucx/pml_ucx_request.h
@@ -96,7 +96,10 @@ struct pml_ucx_persistent_request {
     unsigned flags;
     void *buffer;
     size_t count;
-    ucp_datatype_t datatype;
+    union {
+        ucp_datatype_t datatype;
+        ompi_datatype_t *ompi_datatype;
+    };
     ucp_tag_t tag;
     struct {
         mca_pml_base_send_mode_t mode;
@@ -115,6 +118,8 @@ void mca_pml_ucx_recv_completion(void *request, ucs_status_t status,
 
 void mca_pml_ucx_psend_completion(void *request, ucs_status_t status);
 
+void mca_pml_ucx_bsend_completion(void *request, ucs_status_t status);
+
 void mca_pml_ucx_precv_completion(void *request, ucs_status_t status,
                                   ucp_tag_recv_info_t *info);
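
---
Reviewer note, not part of the patch: the program below is a minimal sketch
that exercises both new code paths on a UCX build of Open MPI. MPI_Bsend()
lands in the new MCA_PML_BASE_SEND_BUFFERED branch of mca_pml_ucx_send(),
and MPI_Bsend_init()/MPI_Start() hits the new buffered branch in
mca_pml_ucx_start(). The file name, tags, and buffer sizing below are
illustrative assumptions only.

/* bsend_test.c -- compile: mpicc bsend_test.c -o bsend_test
 *                 run:     mpirun -np 2 ./bsend_test */
#include <mpi.h>
#include <stdio.h>
#include <stdlib.h>

int main(int argc, char **argv)
{
    int rank;
    int msg[4] = {0, 1, 2, 3};
    int recvbuf[4];
    int bufsize;
    void *attached;
    MPI_Request req;

    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    /* Buffered sends require a user-attached buffer large enough for the
     * packed messages plus per-message overhead. */
    bufsize  = 2 * ((int)sizeof(msg) + MPI_BSEND_OVERHEAD);
    attached = malloc(bufsize);
    MPI_Buffer_attach(attached, bufsize);

    if (rank == 0) {
        /* Plain buffered send: completes locally once the message is packed
         * into the attached buffer (mca_pml_ucx_bsend() in this PML). */
        MPI_Bsend(msg, 4, MPI_INT, 1, 42, MPI_COMM_WORLD);

        /* Persistent buffered send: MPI_Start() takes the new
         * buffered branch in mca_pml_ucx_start(). */
        MPI_Bsend_init(msg, 4, MPI_INT, 1, 43, MPI_COMM_WORLD, &req);
        MPI_Start(&req);
        MPI_Wait(&req, MPI_STATUS_IGNORE);
        MPI_Request_free(&req);
    } else if (rank == 1) {
        MPI_Recv(recvbuf, 4, MPI_INT, 0, 42, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        MPI_Recv(recvbuf, 4, MPI_INT, 0, 43, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        printf("received %d %d %d %d\n",
               recvbuf[0], recvbuf[1], recvbuf[2], recvbuf[3]);
    }

    /* Detach blocks until all pending buffered sends have completed. */
    MPI_Buffer_detach(&attached, &bufsize);
    free(attached);

    MPI_Finalize();
    return 0;
}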