Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 55 additions & 18 deletions ompi/mca/osc/ucx/osc_ucx_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,20 @@
memcpy(((char*)(_dst)) + (_off), _src, _len); \
(_off) += (_len);

/* Mutex acquired and released by _osc_ucx_init_lock()/_osc_ucx_init_unlock();
 * those helpers only touch it when mca_osc_ucx_component.enable_mpi_threads
 * is set. */
opal_mutex_t mca_osc_service_mutex = OPAL_MUTEX_STATIC_INIT;
/*
 * Acquire mca_osc_service_mutex.
 *
 * The mutex is taken only when mca_osc_ucx_component.enable_mpi_threads is
 * set; otherwise the call returns without doing anything.
 */
static void _osc_ucx_init_lock(void)
{
    if (!mca_osc_ucx_component.enable_mpi_threads) {
        return;
    }

    opal_mutex_lock(&mca_osc_service_mutex);
}
/*
 * Release mca_osc_service_mutex.
 *
 * The mutex is released only when mca_osc_ucx_component.enable_mpi_threads
 * is set; otherwise the call returns without doing anything.
 */
static void _osc_ucx_init_unlock(void)
{
    if (!mca_osc_ucx_component.enable_mpi_threads) {
        return;
    }

    opal_mutex_unlock(&mca_osc_service_mutex);
}

/* Forward declarations of file-local (static) component functions. */
static int component_open(void);
static int component_register(void);
static int component_init(bool enable_progress_threads, bool enable_mpi_threads);
Expand Down Expand Up @@ -254,6 +268,9 @@ static void ompi_osc_ucx_unregister_progress()
{
int ret;

/* May be called concurrently - protect */
_osc_ucx_init_lock();

mca_osc_ucx_component.num_modules--;
OSC_UCX_ASSERT(mca_osc_ucx_component.num_modules >= 0);
if (0 == mca_osc_ucx_component.num_modules) {
Expand All @@ -262,6 +279,8 @@ static void ompi_osc_ucx_unregister_progress()
OSC_UCX_VERBOSE(1, "opal_progress_unregister failed: %d", ret);
}
}

_osc_ucx_init_unlock();
}

static int component_select(struct ompi_win_t *win, void **base, size_t size, int disp_unit,
Expand Down Expand Up @@ -295,6 +314,8 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
return OMPI_ERR_NOT_SUPPORTED;
}

_osc_ucx_init_lock();

if (mca_osc_ucx_component.env_initialized == false) {
ucp_config_t *config = NULL;
ucp_params_t context_params;
Expand All @@ -304,7 +325,8 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
status = ucp_config_read("MPI", NULL, &config);
if (UCS_OK != status) {
OSC_UCX_VERBOSE(1, "ucp_config_read failed: %d", status);
return OMPI_ERROR;
ret = OMPI_ERROR;
goto select_unlock;
}

OBJ_CONSTRUCT(&mca_osc_ucx_component.requests, opal_free_list_t);
Expand All @@ -315,7 +337,7 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
0, 0, 8, 0, 8, NULL, 0, NULL, NULL, NULL);
if (OMPI_SUCCESS != ret) {
OSC_UCX_VERBOSE(1, "opal_free_list_init failed: %d", ret);
goto error;
goto select_unlock;
}

/* initialize UCP context */
Expand All @@ -337,7 +359,7 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
if (UCS_OK != status) {
OSC_UCX_VERBOSE(1, "ucp_init failed: %d", status);
ret = OMPI_ERROR;
goto error;
goto select_unlock;
}

assert(mca_osc_ucx_component.ucp_worker == NULL);
Expand All @@ -349,29 +371,53 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
&(mca_osc_ucx_component.ucp_worker));
if (UCS_OK != status) {
OSC_UCX_VERBOSE(1, "ucp_worker_create failed: %d", status);
ret = OMPI_ERROR;
goto error_nomem;
ret = OMPI_ERR_TEMP_OUT_OF_RESOURCE;
goto select_unlock;
}

/* query UCP worker attributes */
worker_attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE;
status = ucp_worker_query(mca_osc_ucx_component.ucp_worker, &worker_attr);
if (UCS_OK != status) {
OSC_UCX_VERBOSE(1, "ucp_worker_query failed: %d", status);
ret = OMPI_ERROR;
goto error_nomem;
ret = OMPI_ERR_TEMP_OUT_OF_RESOURCE;
goto select_unlock;
}

if (mca_osc_ucx_component.enable_mpi_threads == true &&
worker_attr.thread_mode != UCS_THREAD_MODE_MULTI) {
OSC_UCX_VERBOSE(1, "ucx does not support multithreading");
ret = OMPI_ERROR;
goto error_nomem;
ret = OMPI_ERR_TEMP_OUT_OF_RESOURCE;
goto select_unlock;
}

mca_osc_ucx_component.env_initialized = true;
env_initialized = true;
}

mca_osc_ucx_component.num_modules++;

OSC_UCX_ASSERT(mca_osc_ucx_component.num_modules > 0);
if (1 == mca_osc_ucx_component.num_modules) {
ret = opal_progress_register(progress_callback);
if (OMPI_SUCCESS != ret) {
OSC_UCX_VERBOSE(1, "opal_progress_register failed: %d", ret);
goto select_unlock;
}
}

select_unlock:
_osc_ucx_init_unlock();
switch(ret) {
case OMPI_SUCCESS:
break;
case OMPI_ERROR:
goto error;
case OMPI_ERR_TEMP_OUT_OF_RESOURCE:
goto error_nomem;
default:
goto error;
}

/* create module structure */
module = (ompi_osc_ucx_module_t *)calloc(1, sizeof(ompi_osc_ucx_module_t));
Expand All @@ -380,7 +426,6 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
goto error_nomem;
}

mca_osc_ucx_component.num_modules++;

/* fill in the function pointer part */
memcpy(module, &ompi_osc_ucx_module_template, sizeof(ompi_osc_base_module_t));
Expand Down Expand Up @@ -648,14 +693,6 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
goto error;
}

OSC_UCX_ASSERT(mca_osc_ucx_component.num_modules > 0);
if (1 == mca_osc_ucx_component.num_modules) {
ret = opal_progress_register(progress_callback);
if (OMPI_SUCCESS != ret) {
OSC_UCX_VERBOSE(1, "opal_progress_register failed: %d", ret);
goto error;
}
}
return ret;

error:
Expand Down