Skip to content

Commit

Permalink
coll/han: implement hierarchical scatterv
Browse files Browse the repository at this point in the history
Add scatterv implementation to optimize large-scale communications on
multiple nodes and multiple processes per node, by avoiding high-incast
traffic on the root process.

Because *V collectives do not have equal datatype/count on every
process, it does not natively support message-size based tuning without
an additional global communication.

Similar to scatter, the hierarchical scatterv requires a
temporary buffer and memory copy to handle out-of-order data, or
non-contiguous placement on the send buffer, which results in worse
performance for large messages compared to the linear implementation.

Signed-off-by: Jessie Yang <jiaxiyan@amazon.com>
  • Loading branch information
jiaxiyan authored and wenduwan committed Mar 22, 2024
1 parent 48c125e commit 2152b61
Show file tree
Hide file tree
Showing 9 changed files with 553 additions and 1 deletion.
1 change: 1 addition & 0 deletions ompi/mca/coll/han/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ coll_han_barrier.c \
coll_han_bcast.c \
coll_han_reduce.c \
coll_han_scatter.c \
coll_han_scatterv.c \
coll_han_gather.c \
coll_han_gatherv.c \
coll_han_allreduce.c \
Expand Down
13 changes: 13 additions & 0 deletions ompi/mca/coll/han/coll_han.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ typedef struct mca_coll_han_op_module_name_t {
mca_coll_han_op_up_low_module_name_t gather;
mca_coll_han_op_up_low_module_name_t gatherv;
mca_coll_han_op_up_low_module_name_t scatter;
mca_coll_han_op_up_low_module_name_t scatterv;
} mca_coll_han_op_module_name_t;

/**
Expand Down Expand Up @@ -244,6 +245,10 @@ typedef struct mca_coll_han_component_t {
uint32_t han_scatter_up_module;
/* low level module for scatter */
uint32_t han_scatter_low_module;
/* up level module for scatterv */
uint32_t han_scatterv_up_module;
/* low level module for scatterv */
uint32_t han_scatterv_low_module;
/* name of the modules */
mca_coll_han_op_module_name_t han_op_module_name;
/* whether we need reproducible results
Expand Down Expand Up @@ -287,6 +292,7 @@ typedef struct mca_coll_han_single_collective_fallback_s {
mca_coll_base_module_gatherv_fn_t gatherv;
mca_coll_base_module_reduce_fn_t reduce;
mca_coll_base_module_scatter_fn_t scatter;
mca_coll_base_module_scatterv_fn_t scatterv;
} module_fn;
mca_coll_base_module_t* module;
} mca_coll_han_single_collective_fallback_t;
Expand All @@ -306,6 +312,7 @@ typedef struct mca_coll_han_collectives_fallback_s {
mca_coll_han_single_collective_fallback_t gather;
mca_coll_han_single_collective_fallback_t gatherv;
mca_coll_han_single_collective_fallback_t scatter;
mca_coll_han_single_collective_fallback_t scatterv;
} mca_coll_han_collectives_fallback_t;

/** Coll han module */
Expand Down Expand Up @@ -384,6 +391,8 @@ OBJ_CLASS_DECLARATION(mca_coll_han_module_t);
#define previous_scatter fallback.scatter.module_fn.scatter
#define previous_scatter_module fallback.scatter.module

#define previous_scatterv fallback.scatterv.module_fn.scatterv
#define previous_scatterv_module fallback.scatterv.module

/* macro to correctly load a fallback collective module */
#define HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, COLL) \
Expand All @@ -403,6 +412,7 @@ OBJ_CLASS_DECLARATION(mca_coll_han_module_t);
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, barrier); \
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, bcast); \
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, scatter); \
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, scatterv); \
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, gather); \
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, gatherv); \
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, reduce); \
Expand Down Expand Up @@ -495,6 +505,9 @@ mca_coll_han_reduce_intra_dynamic(REDUCE_BASE_ARGS,
int
mca_coll_han_scatter_intra_dynamic(SCATTER_BASE_ARGS,
mca_coll_base_module_t *module);
int
mca_coll_han_scatterv_intra_dynamic(SCATTERV_BASE_ARGS,
mca_coll_base_module_t *module);

int mca_coll_han_barrier_intra_simple(struct ompi_communicator_t *comm,
mca_coll_base_module_t *module);
Expand Down
4 changes: 4 additions & 0 deletions ompi/mca/coll/han/coll_han_algorithms.c
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ mca_coll_han_algorithm_value_t* mca_coll_han_available_algorithms[COLLCOUNT] =
{"simple", (fnptr_t) &mca_coll_han_scatter_intra_simple}, // 2-level
{ 0 }
},
[SCATTERV] = (mca_coll_han_algorithm_value_t[]){
{"intra", (fnptr_t) &mca_coll_han_scatterv_intra}, // 2-level
{ 0 }
},
[GATHER] = (mca_coll_han_algorithm_value_t[]){
{"intra", (fnptr_t) &mca_coll_han_gather_intra}, // 2-level
{"simple", (fnptr_t) &mca_coll_han_gather_intra_simple}, // 2-level
Expand Down
10 changes: 10 additions & 0 deletions ompi/mca/coll/han/coll_han_algorithms.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,16 @@ mca_coll_han_scatter_intra_simple(const void *sbuf, int scount,
struct ompi_communicator_t *comm,
mca_coll_base_module_t * module);

/* Scatterv */
int
mca_coll_han_scatterv_intra(const void *sbuf, const int *scounts,
const int *displs, struct ompi_datatype_t *sdtype,
void *rbuf, int rcount,
struct ompi_datatype_t *rdtype,
int root,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module);

/* Gather */
int
mca_coll_han_gather_intra(const void *sbuf, int scount,
Expand Down
17 changes: 17 additions & 0 deletions ompi/mca/coll/han/coll_han_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ static int han_close(void)
free(mca_coll_han_component.han_op_module_name.scatter.han_op_low_module_name);
mca_coll_han_component.han_op_module_name.scatter.han_op_low_module_name = NULL;

free(mca_coll_han_component.han_op_module_name.scatterv.han_op_up_module_name);
mca_coll_han_component.han_op_module_name.scatterv.han_op_up_module_name = NULL;
free(mca_coll_han_component.han_op_module_name.scatterv.han_op_low_module_name);
mca_coll_han_component.han_op_module_name.scatterv.han_op_low_module_name = NULL;

return OMPI_SUCCESS;
}

Expand Down Expand Up @@ -373,6 +378,18 @@ static int han_register(void)
OPAL_INFO_LVL_9, &cs->han_scatter_low_module,
&cs->han_op_module_name.scatter.han_op_low_module_name);

cs->han_scatterv_up_module = 0;
(void) mca_coll_han_query_module_from_mca(c, "scatterv_up_module",
"up level module for scatterv, 0 basic",
OPAL_INFO_LVL_9, &cs->han_scatterv_up_module,
&cs->han_op_module_name.scatterv.han_op_up_module_name);

cs->han_scatterv_low_module = 0;
(void) mca_coll_han_query_module_from_mca(c, "scatterv_low_module",
"low level module for scatterv, 0 basic",
OPAL_INFO_LVL_9, &cs->han_scatterv_low_module,
&cs->han_op_module_name.scatterv.han_op_low_module_name);

cs->han_reproducible = 0;
(void) mca_base_component_var_register(c, "reproducible",
"whether we need reproducible results "
Expand Down
111 changes: 111 additions & 0 deletions ompi/mca/coll/han/coll_han_dynamic.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ bool mca_coll_han_is_coll_dynamic_implemented(COLLTYPE_T coll_id)
case GATHERV:
case REDUCE:
case SCATTER:
case SCATTERV:
return true;
default:
return false;
Expand Down Expand Up @@ -1397,3 +1398,113 @@ mca_coll_han_scatter_intra_dynamic(const void *sbuf, int scount,
root, comm,
sub_module);
}


/*
* Scatterv selector:
* On a sub-communicator, checks the stored rules to find the module to use
* On the global communicator, calls the han collective implementation, or
* calls the correct module if fallback mechanism is activated
*/
int
mca_coll_han_scatterv_intra_dynamic(const void *sbuf, const int *scounts,
const int *displs, struct ompi_datatype_t *sdtype,
void *rbuf, int rcount,
struct ompi_datatype_t *rdtype,
int root,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module)
{
mca_coll_han_module_t *han_module = (mca_coll_han_module_t*) module;
TOPO_LVL_T topo_lvl = han_module->topologic_level;
mca_coll_base_module_scatterv_fn_t scatterv;
mca_coll_base_module_t *sub_module;
int rank, verbosity = 0;

if (!han_module->enabled) {
return han_module->previous_scatterv(sbuf, scounts, displs, sdtype, rbuf, rcount, rdtype,
root, comm, han_module->previous_scatterv_module);
}

/* v collectives do not support message-size based dynamic rules */
sub_module = get_module(SCATTERV,
MCA_COLL_HAN_ANY_MESSAGE_SIZE,
comm,
han_module);

/* First errors are always printed by rank 0 */
rank = ompi_comm_rank(comm);
if( (0 == rank) && (han_module->dynamic_errors < mca_coll_han_component.max_dynamic_errors) ) {
verbosity = 30;
}

if(NULL == sub_module) {
/*
* No valid collective module from dynamic rules
* nor from mca parameter
*/
han_module->dynamic_errors++;
opal_output_verbose(verbosity, mca_coll_han_component.han_output,
"coll:han:mca_coll_han_scatterv_intra_dynamic "
"HAN did not find any valid module for collective %d (%s) "
"with topological level %d (%s) on communicator (%s/%s). "
"Please check dynamic file/mca parameters\n",
SCATTERV, mca_coll_base_colltype_to_str(SCATTERV),
topo_lvl, mca_coll_han_topo_lvl_to_str(topo_lvl),
ompi_comm_print_cid(comm), comm->c_name);
OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output,
"HAN/SCATTERV: No module found for the sub-communicator. "
"Falling back to another component\n"));
scatterv = han_module->previous_scatterv;
sub_module = han_module->previous_scatterv_module;
} else if (NULL == sub_module->coll_scatterv) {
/*
* No valid collective from dynamic rules
* nor from mca parameter
*/
han_module->dynamic_errors++;
opal_output_verbose(verbosity, mca_coll_han_component.han_output,
"coll:han:mca_coll_han_scatterv_intra_dynamic "
"HAN found valid module for collective %d (%s) "
"with topological level %d (%s) on communicator (%s/%s) "
"but this module cannot handle this collective. "
"Please check dynamic file/mca parameters\n",
SCATTERV, mca_coll_base_colltype_to_str(SCATTERV),
topo_lvl, mca_coll_han_topo_lvl_to_str(topo_lvl),
ompi_comm_print_cid(comm), comm->c_name);
OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output,
"HAN/SCATTERV: the module found for the sub-"
"communicator cannot handle the SCATTERV operation. "
"Falling back to another component\n"));
scatterv = han_module->previous_scatterv;
sub_module = han_module->previous_scatterv_module;
} else if (GLOBAL_COMMUNICATOR == topo_lvl && sub_module == module) {
/*
* No fallback mechanism activated for this configuration
* sub_module is valid
* sub_module->coll_scatterv is valid and point to this function
* Call han topological collective algorithm
*/
int algorithm_id = get_algorithm(SCATTERV,
MCA_COLL_HAN_ANY_MESSAGE_SIZE,
comm,
han_module);
scatterv = (mca_coll_base_module_scatterv_fn_t)mca_coll_han_algorithm_id_to_fn(SCATTERV, algorithm_id);
if (NULL == scatterv) { /* default behaviour */
scatterv = mca_coll_han_scatterv_intra;
}
} else {
/*
* If we get here:
* sub_module is valid
* sub_module->coll_scatterv is valid
* They point to the collective to use, according to the dynamic rules
* Selector's job is done, call the collective
*/
scatterv = sub_module->coll_scatterv;
}

return scatterv(sbuf, scounts, displs, sdtype,
rbuf, rcount, rdtype,
root, comm, sub_module);
}
7 changes: 6 additions & 1 deletion ompi/mca/coll/han/coll_han_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ static void han_module_clear(mca_coll_han_module_t *han_module)
CLEAN_PREV_COLL(han_module, gather);
CLEAN_PREV_COLL(han_module, gatherv);
CLEAN_PREV_COLL(han_module, scatter);
CLEAN_PREV_COLL(han_module, scatterv);

han_module->reproducible_reduce = NULL;
han_module->reproducible_reduce_module = NULL;
Expand Down Expand Up @@ -152,6 +153,7 @@ mca_coll_han_module_destruct(mca_coll_han_module_t * module)
OBJ_RELEASE_IF_NOT_NULL(module->previous_gatherv_module);
OBJ_RELEASE_IF_NOT_NULL(module->previous_reduce_module);
OBJ_RELEASE_IF_NOT_NULL(module->previous_scatter_module);
OBJ_RELEASE_IF_NOT_NULL(module->previous_scatterv_module);

han_module_clear(module);
}
Expand Down Expand Up @@ -254,7 +256,7 @@ mca_coll_han_comm_query(struct ompi_communicator_t * comm, int *priority)
han_module->super.coll_exscan = NULL;
han_module->super.coll_reduce_scatter = NULL;
han_module->super.coll_scan = NULL;
han_module->super.coll_scatterv = NULL;
han_module->super.coll_scatterv = mca_coll_han_scatterv_intra_dynamic;
han_module->super.coll_barrier = mca_coll_han_barrier_intra_dynamic;
han_module->super.coll_scatter = mca_coll_han_scatter_intra_dynamic;
han_module->super.coll_reduce = mca_coll_han_reduce_intra_dynamic;
Expand Down Expand Up @@ -316,6 +318,7 @@ han_module_enable(mca_coll_base_module_t * module,
HAN_SAVE_PREV_COLL_API(gatherv);
HAN_SAVE_PREV_COLL_API(reduce);
HAN_SAVE_PREV_COLL_API(scatter);
HAN_SAVE_PREV_COLL_API(scatterv);

/* set reproducible algos */
mca_coll_han_reduce_reproducible_decision(comm, module);
Expand All @@ -332,6 +335,7 @@ han_module_enable(mca_coll_base_module_t * module,
OBJ_RELEASE_IF_NOT_NULL(han_module->previous_gatherv_module);
OBJ_RELEASE_IF_NOT_NULL(han_module->previous_reduce_module);
OBJ_RELEASE_IF_NOT_NULL(han_module->previous_scatter_module);
OBJ_RELEASE_IF_NOT_NULL(han_module->previous_scatterv_module);

return OMPI_ERROR;
}
Expand All @@ -354,6 +358,7 @@ mca_coll_han_module_disable(mca_coll_base_module_t * module,
OBJ_RELEASE_IF_NOT_NULL(han_module->previous_gatherv_module);
OBJ_RELEASE_IF_NOT_NULL(han_module->previous_reduce_module);
OBJ_RELEASE_IF_NOT_NULL(han_module->previous_scatter_module);
OBJ_RELEASE_IF_NOT_NULL(han_module->previous_scatterv_module);

han_module_clear(han_module);

Expand Down
Loading

0 comments on commit 2152b61

Please sign in to comment.