Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 61 additions & 15 deletions ompi/mca/osc/ucx/osc_ucx_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@
memcpy(((char*)(_dst)) + (_off), _src, _len); \
(_off) += (_len);

opal_mutex_t mca_osc_service_mutex = OPAL_MUTEX_STATIC_INIT;
static void _osc_ucx_init_lock(void)
{
if(mca_osc_ucx_component.enable_mpi_threads) {
opal_mutex_lock(&mca_osc_service_mutex);
}
}
static void _osc_ucx_init_unlock(void)
{
if(mca_osc_ucx_component.enable_mpi_threads) {
opal_mutex_unlock(&mca_osc_service_mutex);
}
}


static int component_open(void);
static int component_register(void);
static int component_init(bool enable_progress_threads, bool enable_mpi_threads);
Expand Down Expand Up @@ -141,6 +156,9 @@ static int component_init(bool enable_progress_threads, bool enable_mpi_threads)

static int component_finalize(void) {
opal_common_ucx_mca_deregister();
if (mca_osc_ucx_component.env_initialized) {
opal_common_ucx_wpool_finalize(mca_osc_ucx_component.wpool);
}
opal_common_ucx_wpool_free(mca_osc_ucx_component.wpool);
return OMPI_SUCCESS;
}
Expand Down Expand Up @@ -189,6 +207,9 @@ static void ompi_osc_ucx_unregister_progress()
{
int ret;

/* May be called concurrently - protect */
_osc_ucx_init_lock();

mca_osc_ucx_component.num_modules--;
OSC_UCX_ASSERT(mca_osc_ucx_component.num_modules >= 0);
if (0 == mca_osc_ucx_component.num_modules) {
Expand All @@ -197,6 +218,8 @@ static void ompi_osc_ucx_unregister_progress()
OSC_UCX_VERBOSE(1, "opal_progress_unregister failed: %d", ret);
}
}

_osc_ucx_init_unlock();
}

static int component_select(struct ompi_win_t *win, void **base, size_t size, int disp_unit,
Expand All @@ -223,7 +246,14 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
return OMPI_ERR_NOT_SUPPORTED;
}

/* May be called concurrently - protect */
_osc_ucx_init_lock();

if (mca_osc_ucx_component.env_initialized == false) {
/* Lazy initialization of the global state.
* As not all of the MPI applications are using One-Sided functionality
* we don't want to initialize in the component_init()
*/

OBJ_CONSTRUCT(&mca_osc_ucx_component.requests, opal_free_list_t);
ret = opal_free_list_init (&mca_osc_ucx_component.requests,
Expand All @@ -233,30 +263,52 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
0, 0, 8, 0, 8, NULL, 0, NULL, NULL, NULL);
if (OMPI_SUCCESS != ret) {
OSC_UCX_VERBOSE(1, "opal_free_list_init failed: %d", ret);
goto error;
goto select_unlock;
}

ret = opal_common_ucx_wpool_init(mca_osc_ucx_component.wpool,
ompi_proc_world_size(),
mca_osc_ucx_component.enable_mpi_threads);
if (OMPI_SUCCESS != ret) {
OSC_UCX_VERBOSE(1, "opal_common_ucx_wpool_init failed: %d", ret);
goto error;
goto select_unlock;
}

/* Make sure that all memory updates performed above are globally
* observable before (mca_osc_ucx_component.env_initialized = true)
*/
mca_osc_ucx_component.env_initialized = true;
env_initialized = true;
}

/* Account for the number of active "modules" = MPI windows */
mca_osc_ucx_component.num_modules++;

/* If this is the first window to be registered - register the progress
* callback
*/
OSC_UCX_ASSERT(mca_osc_ucx_component.num_modules > 0);
if (1 == mca_osc_ucx_component.num_modules) {
ret = opal_progress_register(progress_callback);
if (OMPI_SUCCESS != ret) {
OSC_UCX_VERBOSE(1, "opal_progress_register failed: %d", ret);
goto error;
}
}

select_unlock:
_osc_ucx_init_unlock();
if (ret) {
goto error;
}

/* create module structure */
module = (ompi_osc_ucx_module_t *)calloc(1, sizeof(ompi_osc_ucx_module_t));
if (module == NULL) {
ret = OMPI_ERR_TEMP_OUT_OF_RESOURCE;
goto error_nomem;
}

mca_osc_ucx_component.num_modules++;

/* fill in the function pointer part */
memcpy(module, &ompi_osc_ucx_module_template, sizeof(ompi_osc_base_module_t));

Expand Down Expand Up @@ -410,19 +462,15 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
goto error;
}

OSC_UCX_ASSERT(mca_osc_ucx_component.num_modules > 0);
if (1 == mca_osc_ucx_component.num_modules) {
ret = opal_progress_register(progress_callback);
if (OMPI_SUCCESS != ret) {
OSC_UCX_VERBOSE(1, "opal_progress_register failed: %d", ret);
goto error;
}
}
return ret;

error:
error:
if (module->disp_units) free(module->disp_units);
if (module->comm) ompi_comm_free(&module->comm);
/* We update the modules count and (if need) registering a callback right
* prior to memory allocation for the module.
* So we use it as an indirect sign here
*/
if (module) {
free(module);
ompi_osc_ucx_unregister_progress();
Expand Down Expand Up @@ -575,8 +623,6 @@ int ompi_osc_ucx_free(struct ompi_win_t *win) {
}

opal_common_ucx_wpctx_release(module->ctx);

opal_common_ucx_wpool_finalize(mca_osc_ucx_component.wpool);

if (module->disp_units) free(module->disp_units);
ompi_comm_free(&module->comm);
Expand Down