Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement hierarchical MPI_Gatherv and MPI_Scatterv #12376

Merged
merged 3 commits into from
Mar 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion ompi/mca/coll/han/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ coll_han_barrier.c \
coll_han_bcast.c \
coll_han_reduce.c \
coll_han_scatter.c \
coll_han_scatterv.c \
coll_han_gather.c \
coll_han_gatherv.c \
coll_han_allreduce.c \
coll_han_allgather.c \
coll_han_component.c \
Expand All @@ -31,7 +33,8 @@ coll_han_algorithms.c \
coll_han_dynamic.c \
coll_han_dynamic_file.c \
coll_han_topo.c \
coll_han_subcomms.c
coll_han_subcomms.c \
coll_han_utils.c

# Make the output library in this directory, and name it either
# mca_<type>_<name>.la (for DSO builds) or libmca_<type>_<name>.la
Expand Down
48 changes: 44 additions & 4 deletions ompi/mca/coll/han/coll_han.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
* reserved.
* Copyright (c) 2022 IBM Corporation. All rights reserved
* Copyright (c) 2020-2022 Bull S.A.S. All rights reserved.
* Copyright (c) Amazon.com, Inc. or its affiliates.
* All rights reserved.
* $COPYRIGHT$
*
* Additional copyrights may follow
Expand Down Expand Up @@ -189,7 +191,9 @@ typedef struct mca_coll_han_op_module_name_t {
mca_coll_han_op_up_low_module_name_t allreduce;
mca_coll_han_op_up_low_module_name_t allgather;
mca_coll_han_op_up_low_module_name_t gather;
mca_coll_han_op_up_low_module_name_t gatherv;
mca_coll_han_op_up_low_module_name_t scatter;
mca_coll_han_op_up_low_module_name_t scatterv;
} mca_coll_han_op_module_name_t;

/**
Expand Down Expand Up @@ -233,10 +237,18 @@ typedef struct mca_coll_han_component_t {
uint32_t han_gather_up_module;
/* low level module for gather */
uint32_t han_gather_low_module;
/* up level module for gatherv */
uint32_t han_gatherv_up_module;
/* low level module for gatherv */
uint32_t han_gatherv_low_module;
/* up level module for scatter */
uint32_t han_scatter_up_module;
/* low level module for scatter */
uint32_t han_scatter_low_module;
/* up level module for scatterv */
uint32_t han_scatterv_up_module;
/* low level module for scatterv */
uint32_t han_scatterv_low_module;
/* name of the modules */
mca_coll_han_op_module_name_t han_op_module_name;
/* whether we need reproducible results
Expand Down Expand Up @@ -277,8 +289,10 @@ typedef struct mca_coll_han_single_collective_fallback_s {
mca_coll_base_module_barrier_fn_t barrier;
mca_coll_base_module_bcast_fn_t bcast;
mca_coll_base_module_gather_fn_t gather;
mca_coll_base_module_gatherv_fn_t gatherv;
mca_coll_base_module_reduce_fn_t reduce;
mca_coll_base_module_scatter_fn_t scatter;
mca_coll_base_module_scatterv_fn_t scatterv;
} module_fn;
mca_coll_base_module_t* module;
} mca_coll_han_single_collective_fallback_t;
Expand All @@ -296,7 +310,9 @@ typedef struct mca_coll_han_collectives_fallback_s {
mca_coll_han_single_collective_fallback_t bcast;
mca_coll_han_single_collective_fallback_t reduce;
mca_coll_han_single_collective_fallback_t gather;
mca_coll_han_single_collective_fallback_t gatherv;
mca_coll_han_single_collective_fallback_t scatter;
mca_coll_han_single_collective_fallback_t scatterv;
} mca_coll_han_collectives_fallback_t;

/** Coll han module */
Expand Down Expand Up @@ -369,9 +385,14 @@ OBJ_CLASS_DECLARATION(mca_coll_han_module_t);
#define previous_gather fallback.gather.module_fn.gather
#define previous_gather_module fallback.gather.module

#define previous_gatherv fallback.gatherv.module_fn.gatherv
#define previous_gatherv_module fallback.gatherv.module

#define previous_scatter fallback.scatter.module_fn.scatter
#define previous_scatter_module fallback.scatter.module

#define previous_scatterv fallback.scatterv.module_fn.scatterv
#define previous_scatterv_module fallback.scatterv.module

/* macro to correctly load a fallback collective module */
#define HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, COLL) \
Expand All @@ -391,7 +412,9 @@ OBJ_CLASS_DECLARATION(mca_coll_han_module_t);
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, barrier); \
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, bcast); \
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, scatter); \
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, scatterv); \
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, gather); \
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, gatherv); \
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, reduce); \
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, allreduce); \
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, allgather); \
Expand Down Expand Up @@ -432,11 +455,16 @@ int *mca_coll_han_topo_init(struct ompi_communicator_t *comm, mca_coll_han_modul

/* Utils */
static inline void
mca_coll_han_get_ranks(int *vranks, int root, int low_size,
int *root_low_rank, int *root_up_rank)
mca_coll_han_get_ranks(int *vranks, int w_rank, int low_size,
int *low_rank, int *up_rank)
{
*root_up_rank = vranks[root] / low_size;
*root_low_rank = vranks[root] % low_size;
if (up_rank) {
*up_rank = vranks[w_rank] / low_size;
}

if (low_rank) {
*low_rank = vranks[w_rank] % low_size;
}
}

const char* mca_coll_han_topo_lvl_to_str(TOPO_LVL_T topo_lvl);
Expand Down Expand Up @@ -469,11 +497,17 @@ int
mca_coll_han_gather_intra_dynamic(GATHER_BASE_ARGS,
mca_coll_base_module_t *module);
int
mca_coll_han_gatherv_intra_dynamic(GATHERV_BASE_ARGS,
mca_coll_base_module_t *module);
int
mca_coll_han_reduce_intra_dynamic(REDUCE_BASE_ARGS,
mca_coll_base_module_t *module);
int
mca_coll_han_scatter_intra_dynamic(SCATTER_BASE_ARGS,
mca_coll_base_module_t *module);
int
mca_coll_han_scatterv_intra_dynamic(SCATTERV_BASE_ARGS,
mca_coll_base_module_t *module);

int mca_coll_han_barrier_intra_simple(struct ompi_communicator_t *comm,
mca_coll_base_module_t *module);
Expand All @@ -486,4 +520,10 @@ ompi_coll_han_reorder_gather(const void *sbuf,
struct ompi_communicator_t *comm,
int * topo);

size_t
coll_han_utils_gcd(const size_t *numerators, const size_t size);

int
coll_han_utils_create_contiguous_datatype(size_t count, const ompi_datatype_t *oldType,
ompi_datatype_t **newType);
#endif /* MCA_COLL_HAN_EXPORT_H */
8 changes: 8 additions & 0 deletions ompi/mca/coll/han/coll_han_algorithms.c
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,19 @@ mca_coll_han_algorithm_value_t* mca_coll_han_available_algorithms[COLLCOUNT] =
{"simple", (fnptr_t) &mca_coll_han_scatter_intra_simple}, // 2-level
{ 0 }
},
[SCATTERV] = (mca_coll_han_algorithm_value_t[]){
{"intra", (fnptr_t) &mca_coll_han_scatterv_intra}, // 2-level
{ 0 }
},
[GATHER] = (mca_coll_han_algorithm_value_t[]){
{"intra", (fnptr_t) &mca_coll_han_gather_intra}, // 2-level
{"simple", (fnptr_t) &mca_coll_han_gather_intra_simple}, // 2-level
{ 0 }
},
[GATHERV] = (mca_coll_han_algorithm_value_t[]){
{"intra", (fnptr_t) &mca_coll_han_gatherv_intra}, // 2-level
{ 0 }
},
[ALLGATHER] = (mca_coll_han_algorithm_value_t[]){
{"intra", (fnptr_t)&mca_coll_han_allgather_intra}, // 2-level
{"simple", (fnptr_t)&mca_coll_han_allgather_intra_simple}, // 2-level
Expand Down
17 changes: 17 additions & 0 deletions ompi/mca/coll/han/coll_han_algorithms.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,16 @@ mca_coll_han_scatter_intra_simple(const void *sbuf, int scount,
struct ompi_communicator_t *comm,
mca_coll_base_module_t * module);

/* Scatterv */
int
mca_coll_han_scatterv_intra(const void *sbuf, const int *scounts,
const int *displs, struct ompi_datatype_t *sdtype,
void *rbuf, int rcount,
struct ompi_datatype_t *rdtype,
int root,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module);

/* Gather */
int
mca_coll_han_gather_intra(const void *sbuf, int scount,
Expand All @@ -176,6 +186,13 @@ mca_coll_han_gather_intra_simple(const void *sbuf, int scount,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module);

/* Gatherv */
int
mca_coll_han_gatherv_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
void *rbuf, const int *rcounts, const int *displs,
struct ompi_datatype_t *rdtype, int root,
struct ompi_communicator_t *comm, mca_coll_base_module_t *module);

/* Allgather */
int
mca_coll_han_allgather_intra(const void *sbuf, int scount,
Expand Down
34 changes: 34 additions & 0 deletions ompi/mca/coll/han/coll_han_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,21 @@ static int han_close(void)
free(mca_coll_han_component.han_op_module_name.gather.han_op_low_module_name);
mca_coll_han_component.han_op_module_name.gather.han_op_low_module_name = NULL;

free(mca_coll_han_component.han_op_module_name.gatherv.han_op_up_module_name);
mca_coll_han_component.han_op_module_name.gatherv.han_op_up_module_name = NULL;
free(mca_coll_han_component.han_op_module_name.gatherv.han_op_low_module_name);
mca_coll_han_component.han_op_module_name.gatherv.han_op_low_module_name = NULL;

free(mca_coll_han_component.han_op_module_name.scatter.han_op_up_module_name);
mca_coll_han_component.han_op_module_name.scatter.han_op_up_module_name = NULL;
free(mca_coll_han_component.han_op_module_name.scatter.han_op_low_module_name);
mca_coll_han_component.han_op_module_name.scatter.han_op_low_module_name = NULL;

free(mca_coll_han_component.han_op_module_name.scatterv.han_op_up_module_name);
mca_coll_han_component.han_op_module_name.scatterv.han_op_up_module_name = NULL;
free(mca_coll_han_component.han_op_module_name.scatterv.han_op_low_module_name);
mca_coll_han_component.han_op_module_name.scatterv.han_op_low_module_name = NULL;

return OMPI_SUCCESS;
}

Expand Down Expand Up @@ -344,6 +354,18 @@ static int han_register(void)
OPAL_INFO_LVL_9, &cs->han_gather_low_module,
&cs->han_op_module_name.gather.han_op_low_module_name);

cs->han_gatherv_up_module = 0;
(void) mca_coll_han_query_module_from_mca(c, "gatherv_up_module",
"up level module for gatherv, 0 basic",
OPAL_INFO_LVL_9, &cs->han_gatherv_up_module,
&cs->han_op_module_name.gatherv.han_op_up_module_name);

cs->han_gatherv_low_module = 0;
(void) mca_coll_han_query_module_from_mca(c, "gatherv_low_module",
"low level module for gatherv, 0 basic",
OPAL_INFO_LVL_9, &cs->han_gatherv_low_module,
&cs->han_op_module_name.gatherv.han_op_low_module_name);

cs->han_scatter_up_module = 0;
(void) mca_coll_han_query_module_from_mca(c, "scatter_up_module",
"up level module for scatter, 0 libnbc, 1 adapt",
Expand All @@ -356,6 +378,18 @@ static int han_register(void)
OPAL_INFO_LVL_9, &cs->han_scatter_low_module,
&cs->han_op_module_name.scatter.han_op_low_module_name);

cs->han_scatterv_up_module = 0;
(void) mca_coll_han_query_module_from_mca(c, "scatterv_up_module",
"up level module for scatterv, 0 basic",
OPAL_INFO_LVL_9, &cs->han_scatterv_up_module,
&cs->han_op_module_name.scatterv.han_op_up_module_name);

cs->han_scatterv_low_module = 0;
(void) mca_coll_han_query_module_from_mca(c, "scatterv_low_module",
"low level module for scatterv, 0 basic",
OPAL_INFO_LVL_9, &cs->han_scatterv_low_module,
&cs->han_op_module_name.scatterv.han_op_low_module_name);

cs->han_reproducible = 0;
(void) mca_base_component_var_register(c, "reproducible",
"whether we need reproducible results "
Expand Down
Loading
Loading