Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion ompi/mca/pml/ucx/pml_ucx.c
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src
return OMPI_ERROR;
}

ucp_worker_progress(ompi_pml_ucx.ucp_worker);
while (!req->req_complete) {
opal_progress();
}
Expand Down Expand Up @@ -492,10 +493,11 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i
mca_pml_ucx_get_datatype(datatype),
PML_UCX_MAKE_SEND_TAG(tag, comm),
mca_pml_ucx_send_completion);
if (req == NULL) {
if (OPAL_LIKELY(req == NULL)) {
return OMPI_SUCCESS;
} else if (!UCS_PTR_IS_ERR(req)) {
PML_UCX_VERBOSE(8, "got request %p", (void*)req);
ucp_worker_progress(ompi_pml_ucx.ucp_worker);
ompi_request_wait(&req, MPI_STATUS_IGNORE);
return OMPI_SUCCESS;
} else {
Expand Down Expand Up @@ -698,6 +700,7 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
PML_UCX_VERBOSE(8, "temporary request %p will complete persistent request %p",
(void*)tmp_req, (void*)preq);
tmp_req->req_complete_cb_data = preq;
preq->tmp_req = tmp_req;
}
OPAL_THREAD_UNLOCK(&ompi_request_lock);
} else {
Expand Down
48 changes: 40 additions & 8 deletions ompi/mca/pml/ucx/pml_ucx_request.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ static int mca_pml_ucx_request_free(ompi_request_t **rptr)
return OMPI_SUCCESS;
}

/*
 * Cancel callback for regular PML UCX requests.
 *
 * Passes the request to ucp_request_cancel() on the PML's UCP worker.
 * The `flag` argument is not consulted, and the function reports
 * OMPI_SUCCESS unconditionally.
 */
static int mca_pml_ucx_request_cancel(ompi_request_t *req, int flag)
{
    const int rc = OMPI_SUCCESS;

    ucp_request_cancel(ompi_pml_ucx.ucp_worker, req);

    return rc;
}

void mca_pml_ucx_send_completion(void *request, ucs_status_t status)
{
ompi_request_t *req = request;
Expand Down Expand Up @@ -55,12 +61,19 @@ void mca_pml_ucx_recv_completion(void *request, ucs_status_t status,
OPAL_THREAD_UNLOCK(&ompi_request_lock);
}

void mca_pml_ucx_persistent_requset_complete(mca_pml_ucx_persistent_request_t *preq,
/*
 * Break the two-way link between a persistent request and its temporary
 * request: the persistent request forgets the temporary one, and the
 * temporary request's completion-callback data is cleared.
 */
static void mca_pml_ucx_persistent_request_detach(mca_pml_ucx_persistent_request_t *preq,
                                                  ompi_request_t *tmp_req)
{
    preq->tmp_req                 = NULL;
    tmp_req->req_complete_cb_data = NULL;
}

/*
 * Complete a persistent request from its temporary request.
 *
 * Copies the temporary request's status into the persistent request,
 * marks the persistent request complete, unlinks the two requests, then
 * resets the temporary request and releases it back to UCP.
 */
void mca_pml_ucx_persistent_request_complete(mca_pml_ucx_persistent_request_t *preq,
ompi_request_t *tmp_req)
{
/* Propagate the result before signalling completion. */
preq->ompi.req_status = tmp_req->req_status;
ompi_request_complete(&preq->ompi, true);
/* NOTE(review): this store is repeated by
 * mca_pml_ucx_persistent_request_detach() on the next line; it looks like
 * the pre-change line of the diff shown next to its replacement - confirm. */
tmp_req->req_complete_cb_data = NULL;
mca_pml_ucx_persistent_request_detach(preq, tmp_req);
mca_pml_ucx_request_reset(tmp_req);
ucp_request_release(tmp_req);
}
Expand All @@ -73,7 +86,8 @@ static inline void mca_pml_ucx_preq_completion(ompi_request_t *tmp_req)
ompi_request_complete(tmp_req, false);
preq = (mca_pml_ucx_persistent_request_t*)tmp_req->req_complete_cb_data;
if (preq != NULL) {
mca_pml_ucx_persistent_requset_complete(preq, tmp_req);
PML_UCX_ASSERT(preq->tmp_req != NULL);
mca_pml_ucx_persistent_request_complete(preq, tmp_req);
}
OPAL_THREAD_UNLOCK(&ompi_request_lock);
}
Expand Down Expand Up @@ -120,7 +134,8 @@ void mca_pml_ucx_request_init(void *request)
ompi_request_t* ompi_req = request;
OBJ_CONSTRUCT(ompi_req, ompi_request_t);
mca_pml_ucx_request_init_common(ompi_req, false, OMPI_REQUEST_ACTIVE,
mca_pml_ucx_request_free, NULL);
mca_pml_ucx_request_free,
mca_pml_ucx_request_cancel);
}

void mca_pml_ucx_request_cleanup(void *request)
Expand All @@ -133,18 +148,35 @@ void mca_pml_ucx_request_cleanup(void *request)

/*
 * Free callback for persistent requests.
 *
 * Marks the request invalid, unlinks and releases a still-attached
 * temporary request (if any), returns the persistent request to the
 * persistent_reqs free list, and sets the caller's handle to
 * MPI_REQUEST_NULL. Returns OMPI_SUCCESS.
 *
 * NOTE(review): both `req` and `preq` are initialised from *rptr, and the
 * invalidate + free-list-return sequence appears twice (once via `preq`,
 * once via `req`). This looks like the before/after lines of a diff shown
 * together rather than intended code - confirm which lines are current.
 */
static int mca_pml_ucx_persistent_request_free(ompi_request_t **rptr)
{
mca_pml_ucx_persistent_request_t* req = (mca_pml_ucx_persistent_request_t*)*rptr;
mca_pml_ucx_persistent_request_t* preq = (mca_pml_ucx_persistent_request_t*)*rptr;
ompi_request_t *tmp_req = preq->tmp_req;

preq->ompi.req_state = OMPI_REQUEST_INVALID;
/* A temporary request may still be linked; unlink it before releasing it. */
if (tmp_req != NULL) {
mca_pml_ucx_persistent_request_detach(preq, tmp_req);
ucp_request_release(tmp_req);
}
PML_UCX_FREELIST_RETURN(&ompi_pml_ucx.persistent_reqs, &preq->ompi.super);
*rptr = MPI_REQUEST_NULL;
/* NOTE(review): the two statements below repeat the work already done
 * through `preq` above - presumably the removed side of the diff; verify. */
req->ompi.req_state = OMPI_REQUEST_INVALID;
PML_UCX_FREELIST_RETURN(&ompi_pml_ucx.persistent_reqs, &req->ompi.super);
return OMPI_SUCCESS;
}

/*
 * Cancel callback for persistent requests.
 *
 * When the persistent request has a temporary request attached, that
 * temporary request is handed to ucp_request_cancel() on the PML's UCP
 * worker; otherwise nothing is done. The `flag` argument is not consulted.
 * OMPI_SUCCESS is returned in both cases.
 */
static int mca_pml_ucx_persistent_request_cancel(ompi_request_t *req, int flag)
{
    ompi_request_t *attached = ((mca_pml_ucx_persistent_request_t*)req)->tmp_req;

    if (attached == NULL) {
        return OMPI_SUCCESS;
    }

    ucp_request_cancel(ompi_pml_ucx.ucp_worker, attached);
    return OMPI_SUCCESS;
}

/*
 * Constructor for persistent request objects: runs the common request
 * initialisation on the embedded ompi request (starting in the
 * OMPI_REQUEST_INACTIVE state, with mca_pml_ucx_persistent_request_free
 * as the free callback) and clears the tmp_req link.
 *
 * NOTE(review): two argument tails for mca_pml_ucx_request_init_common()
 * appear below - one ending in NULL, one naming
 * mca_pml_ucx_persistent_request_cancel. This looks like the before/after
 * lines of a diff shown together - confirm which one is current.
 */
static void mca_pml_ucx_persisternt_request_construct(mca_pml_ucx_persistent_request_t* req)
{
mca_pml_ucx_request_init_common(&req->ompi, true, OMPI_REQUEST_INACTIVE,
mca_pml_ucx_persistent_request_free, NULL);
mca_pml_ucx_persistent_request_free,
mca_pml_ucx_persistent_request_cancel);
req->tmp_req = NULL;
}

static void mca_pml_ucx_persisternt_request_destruct(mca_pml_ucx_persistent_request_t* req)
Expand Down
1 change: 1 addition & 0 deletions ompi/mca/pml/ucx/pml_ucx_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ enum {

struct pml_ucx_persistent_request {
ompi_request_t ompi;
ompi_request_t *tmp_req;
unsigned flags;
void *buffer;
size_t count;
Expand Down