diff --git a/ompi/request/grequest.c b/ompi/request/grequest.c index 6125d134a9c..4d4211fcecf 100644 --- a/ompi/request/grequest.c +++ b/ompi/request/grequest.c @@ -24,33 +24,55 @@ #include "ompi/request/grequest.h" #include "ompi/mpi/fortran/base/fint_2_int.h" -/** +/* * Internal function to specialize the call to the user provided free_fn * for generalized requests. * @return The return value of the user specified callback or MPI_SUCCESS. */ -static inline int ompi_grequest_internal_free(ompi_grequest_t* greq) +static inline int ompi_grequest_invoke_free(ompi_grequest_t* greq) { - int rc = MPI_SUCCESS; + int rc = OMPI_SUCCESS; + MPI_Fint ierr; + if (NULL != greq->greq_free.c_free) { + if (greq->greq_funcs_are_c) { + rc = greq->greq_free.c_free(greq->greq_state); + } else { + greq->greq_free.f_free((MPI_Aint*)greq->greq_state, &ierr); + rc = OMPI_FINT_2_INT(ierr); + } /* We were already putting query_fn()'s return value into * status.MPI_ERROR but for MPI_{Wait,Test}*. If there's a * free callback to invoke, the standard says to use the * return value from free_fn() callback, too. */ - if (greq->greq_funcs_are_c) { - greq->greq_base.req_status.MPI_ERROR = - greq->greq_free.c_free(greq->greq_state); - } else { - MPI_Fint ierr; - greq->greq_free.f_free((MPI_Aint*)greq->greq_state, &ierr); - greq->greq_base.req_status.MPI_ERROR = OMPI_FINT_2_INT(ierr); + if (OMPI_SUCCESS != rc) { + greq->greq_base.req_status.MPI_ERROR = rc; } - rc = greq->greq_base.req_status.MPI_ERROR; } return rc; } + +/* + * Internal function to dispatch the call to the user provided free_fn + * for generalized requests. The freeing code executes as soon as both + * wait/free and complete have occured. + * @return The return value of the user specified callback or MPI_SUCCESS. + */ +static inline int ompi_grequest_internal_free(ompi_grequest_t* greq) +{ + int rc = OMPI_SUCCESS; + + if (REQUEST_COMPLETE(&greq->greq_base) && greq->greq_user_freed) { + rc = ompi_grequest_invoke_free(greq); + /* The free_fn() callback should be invoked only once. */ + if (NULL != greq->greq_free.c_free) + greq->greq_free.c_free = NULL; + } + return rc; +} + /* * See the comment in the grequest destructor for the weird semantics * here. If the request has been marked complete via a call to @@ -66,14 +88,10 @@ static int ompi_grequest_free(ompi_request_t** req) ompi_grequest_t* greq = (ompi_grequest_t*)*req; int rc = OMPI_SUCCESS; - if( greq->greq_user_freed ) { - return OMPI_ERR_OUT_OF_RESOURCE; - } greq->greq_user_freed = true; - if( REQUEST_COMPLETE(*req) ) { - rc = ompi_grequest_internal_free(greq); - } - if (OMPI_SUCCESS == rc ) { + rc = ompi_grequest_internal_free(greq); + + if (OMPI_SUCCESS == rc) { OBJ_RELEASE(*req); *req = MPI_REQUEST_NULL; } @@ -180,7 +198,7 @@ int ompi_grequest_start( ompi_request_t** request) { ompi_grequest_t *greq = OBJ_NEW(ompi_grequest_t); - if(greq == NULL) { + if (greq == NULL) { return OMPI_ERR_OUT_OF_RESOURCE; } /* We call RETAIN here specifically to increase the refcount to 2. @@ -211,13 +229,14 @@ int ompi_grequest_start( int ompi_grequest_complete(ompi_request_t *req) { ompi_grequest_t* greq = (ompi_grequest_t*)req; + bool greq_release = !REQUEST_COMPLETE(req); int rc; rc = ompi_request_complete(req, true); - if( greq->greq_user_freed ) { + if (OMPI_SUCCESS == rc && greq_release) { rc = ompi_grequest_internal_free(greq); + OBJ_RELEASE(req); } - OBJ_RELEASE(req); return rc; } @@ -237,6 +256,11 @@ int ompi_grequest_invoke_query(ompi_request_t *request, int rc = OMPI_SUCCESS; ompi_grequest_t *g = (ompi_grequest_t*) request; + /* MPI mandates that query_fn must be called after the request is + * completed. Make sure the caller does not break the contract. + */ + assert( REQUEST_COMPLETE(request) ); + /* MPI-3 mandates that the return value from the query function * (i.e., the int return value from the C function or the ierr * argument from the Fortran function) must be returned to the @@ -268,9 +292,8 @@ int ompi_grequest_invoke_query(ompi_request_t *request, rc = OMPI_FINT_2_INT(ierr); } } - if( MPI_SUCCESS != rc ) { + if (OMPI_SUCCESS != rc) { status->MPI_ERROR = rc; } return rc; } -