Skip to content

Commit

Permalink
coll/HAN: Add support for XHC's "special" Reduce for the low-comm in …
Browse files Browse the repository at this point in the history
…Allreduce

MPI_Reduce in XHC is not complete; it is implemented as a sub-case of
Allreduce, and requires that the rbuf parameter is always present and
appropriately sized on all ranks (not only on the root). This
implementation is disabled by default, in which case Reduce falls back
to another coll component, but it can be manually enabled for a single
operation by invoking it with root=-1, which performs a reduce to rank 0.

Inside HAN's Allreduce, the rbuf parameter restriction is satisfied,
so it's safe to use this partially implemented Reduce.

This patch is temporary (TM) until XHC's Reduce is fully implemented.
It exists because using XHC for the intra-comm (the low-comm) has the
potential to improve Allreduce performance.

Signed-off-by: George Katevenis <gkatev@ics.forth.gr>
  • Loading branch information
gkatev committed Mar 27, 2023
1 parent 17c4dba commit b08f94a
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 34 deletions.
1 change: 1 addition & 0 deletions ompi/mca/coll/han/coll_han.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ struct mca_coll_han_allreduce_args_s {
int seg_count;
int root_up_rank;
int root_low_rank;
int root_reduce_low_rank;
int num_segments;
int cur_seg;
int w_rank;
Expand Down
93 changes: 59 additions & 34 deletions ompi/mca/coll/han/coll_han_allreduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
*
* Copyright (c) 2020 Cisco Systems, Inc. All rights reserved.
* Copyright (c) 2022 IBM Corporation. All rights reserved
* Copyright (c) 2023 Computer Architecture and VLSI Systems (CARV)
* Laboratory, ICS Forth. All rights reserved.
* $COPYRIGHT$
*
* Additional copyrights may follow
Expand All @@ -22,6 +24,7 @@

#include "coll_han.h"
#include "ompi/mca/coll/base/coll_base_functions.h"
#include "ompi/mca/coll/base/coll_base_util.h"
#include "ompi/mca/coll/base/coll_tags.h"
#include "ompi/mca/pml/pml.h"
#include "coll_han_trigger.h"
Expand All @@ -43,6 +46,7 @@ mca_coll_han_set_allreduce_args(mca_coll_han_allreduce_args_t * args,
struct ompi_op_t *op,
int root_up_rank,
int root_low_rank,
int root_reduce_low_rank,
struct ompi_communicator_t *up_comm,
struct ompi_communicator_t *low_comm,
int num_segments,
Expand All @@ -59,6 +63,7 @@ mca_coll_han_set_allreduce_args(mca_coll_han_allreduce_args_t * args,
args->op = op;
args->root_up_rank = root_up_rank;
args->root_low_rank = root_low_rank;
args->root_reduce_low_rank = root_reduce_low_rank;
args->up_comm = up_comm;
args->low_comm = low_comm;
args->num_segments = num_segments;
Expand Down Expand Up @@ -139,15 +144,26 @@ mca_coll_han_allreduce_intra(const void *sbuf,
int low_rank = ompi_comm_rank(low_comm);
int root_up_rank = 0;
int root_low_rank = 0;
int root_reduce_low_rank = 0;

mca_coll_base_avail_coll_t *low_1st_module = (mca_coll_base_avail_coll_t *)
opal_list_get_last(low_comm->c_coll->module_list);

// Invoke XHC's "special" Reduce
if(0 == strcmp(low_1st_module->ac_component_name, "xhc")
&& low_comm->c_coll->coll_reduce_module == low_1st_module->ac_module) {
root_reduce_low_rank = -1;
}

/* Create t0 task for the first segment */
mca_coll_task_t *t0 = OBJ_NEW(mca_coll_task_t);
/* Setup up t0 task arguments */
int *completed = (int *) malloc(sizeof(int));
completed[0] = 0;
mca_coll_han_allreduce_args_t *t = malloc(sizeof(mca_coll_han_allreduce_args_t));
mca_coll_han_set_allreduce_args(t, t0, (char *) sbuf, (char *) rbuf, seg_count, dtype, op,
root_up_rank, root_low_rank, up_comm, low_comm, num_segments, 0,
w_rank, count - (num_segments - 1) * seg_count,
root_up_rank, root_low_rank, root_reduce_low_rank, up_comm,
low_comm, num_segments, 0, w_rank, count - (num_segments - 1) * seg_count,
low_rank != root_low_rank, NULL, completed);
/* Init t0 task */
init_task(t0, mca_coll_han_allreduce_t0_task, (void *) (t));
Expand Down Expand Up @@ -215,18 +231,18 @@ int mca_coll_han_allreduce_t0_task(void *task_args)
if (MPI_IN_PLACE == t->sbuf) {
if (!t->noop) {
t->low_comm->c_coll->coll_reduce(MPI_IN_PLACE, (char *) t->rbuf, t->seg_count, t->dtype,
t->op, t->root_low_rank, t->low_comm,
t->op, t->root_reduce_low_rank, t->low_comm,
t->low_comm->c_coll->coll_reduce_module);
}
else {
t->low_comm->c_coll->coll_reduce((char *) t->rbuf, NULL, t->seg_count, t->dtype,
t->op, t->root_low_rank, t->low_comm,
t->op, t->root_reduce_low_rank, t->low_comm,
t->low_comm->c_coll->coll_reduce_module);
}
}
else {
t->low_comm->c_coll->coll_reduce((char *) t->sbuf, (char *) t->rbuf, t->seg_count, t->dtype,
t->op, t->root_low_rank, t->low_comm,
t->op, t->root_reduce_low_rank, t->low_comm,
t->low_comm->c_coll->coll_reduce_module);
}
return OMPI_SUCCESS;
Expand Down Expand Up @@ -267,21 +283,20 @@ int mca_coll_han_allreduce_t1_task(void *task_args)
if (!t->noop) {
t->low_comm->c_coll->coll_reduce(MPI_IN_PLACE,
(char *) t->rbuf + extent * t->seg_count, tmp_count,
t->dtype, t->op, t->root_low_rank, t->low_comm,
t->dtype, t->op, t->root_reduce_low_rank, t->low_comm,
t->low_comm->c_coll->coll_reduce_module);
} else {
t->low_comm->c_coll->coll_reduce((char *) t->rbuf + extent * t->seg_count,
NULL, tmp_count,
t->dtype, t->op, t->root_low_rank, t->low_comm,
t->dtype, t->op, t->root_reduce_low_rank, t->low_comm,
t->low_comm->c_coll->coll_reduce_module);

}
} else {
t->low_comm->c_coll->coll_reduce((char *) t->sbuf + extent * t->seg_count,
(char *) t->rbuf + extent * t->seg_count, tmp_count,
t->dtype, t->op, t->root_low_rank, t->low_comm,
t->dtype, t->op, t->root_reduce_low_rank, t->low_comm,
t->low_comm->c_coll->coll_reduce_module);
}
}
}
if (!t->noop) {
ompi_request_wait(&ireduce_req, MPI_STATUS_IGNORE);
Expand Down Expand Up @@ -337,25 +352,25 @@ int mca_coll_han_allreduce_t2_task(void *task_args)
tmp_count = t->last_seg_count;
}

if (t->sbuf == MPI_IN_PLACE) {
if (!t->noop) {
t->low_comm->c_coll->coll_reduce(MPI_IN_PLACE,
if (t->sbuf == MPI_IN_PLACE) {
if (!t->noop) {
t->low_comm->c_coll->coll_reduce(MPI_IN_PLACE,
(char *) t->rbuf + 2 * extent * t->seg_count, tmp_count,
t->dtype, t->op, t->root_reduce_low_rank, t->low_comm,
t->low_comm->c_coll->coll_reduce_module);
} else {
t->low_comm->c_coll->coll_reduce((char *) t->rbuf + 2 * extent * t->seg_count,
NULL, tmp_count,
t->dtype, t->op, t->root_reduce_low_rank, t->low_comm,
t->low_comm->c_coll->coll_reduce_module);

}
} else {
t->low_comm->c_coll->coll_reduce((char *) t->sbuf + 2 * extent * t->seg_count,
(char *) t->rbuf + 2 * extent * t->seg_count, tmp_count,
t->dtype, t->op, t->root_low_rank, t->low_comm,
t->low_comm->c_coll->coll_reduce_module);
} else {
t->low_comm->c_coll->coll_reduce((char *) t->rbuf + 2 * extent * t->seg_count,
NULL, tmp_count,
t->dtype, t->op, t->root_low_rank, t->low_comm,
t->dtype, t->op, t->root_reduce_low_rank, t->low_comm,
t->low_comm->c_coll->coll_reduce_module);

}
} else {
t->low_comm->c_coll->coll_reduce((char *) t->sbuf + 2 * extent * t->seg_count,
(char *) t->rbuf + 2 * extent * t->seg_count, tmp_count,
t->dtype, t->op, t->root_low_rank, t->low_comm,
t->low_comm->c_coll->coll_reduce_module);
}
}
}
if (!t->noop && req_count > 0) {
ompi_request_wait_all(req_count, reqs, MPI_STATUSES_IGNORE);
Expand Down Expand Up @@ -421,18 +436,18 @@ int mca_coll_han_allreduce_t3_task(void *task_args)
if (!t->noop) {
t->low_comm->c_coll->coll_reduce(MPI_IN_PLACE,
(char *) t->rbuf + 3 * extent * t->seg_count, tmp_count,
t->dtype, t->op, t->root_low_rank, t->low_comm,
t->dtype, t->op, t->root_reduce_low_rank, t->low_comm,
t->low_comm->c_coll->coll_reduce_module);
} else {
} else {
t->low_comm->c_coll->coll_reduce((char *) t->rbuf + 3 * extent * t->seg_count,
NULL, tmp_count,
t->dtype, t->op, t->root_low_rank, t->low_comm,
t->dtype, t->op, t->root_reduce_low_rank, t->low_comm,
t->low_comm->c_coll->coll_reduce_module);
}
} else {
t->low_comm->c_coll->coll_reduce((char *) t->sbuf + 3 * extent * t->seg_count,
(char *) t->rbuf + 3 * extent * t->seg_count, tmp_count,
t->dtype, t->op, t->root_low_rank, t->low_comm,
t->dtype, t->op, t->root_reduce_low_rank, t->low_comm,
t->low_comm->c_coll->coll_reduce_module);
}
}
Expand Down Expand Up @@ -473,6 +488,7 @@ mca_coll_han_allreduce_intra_simple(const void *sbuf,
ompi_communicator_t *low_comm;
ompi_communicator_t *up_comm;
int root_low_rank = 0;
int root_reduce_low_rank = 0;
int low_rank;
int ret;
mca_coll_han_module_t *han_module = (mca_coll_han_module_t *)module;
Expand Down Expand Up @@ -504,22 +520,31 @@ mca_coll_han_allreduce_intra_simple(const void *sbuf,
up_comm = han_module->sub_comm[INTER_NODE];
low_rank = ompi_comm_rank(low_comm);

mca_coll_base_avail_coll_t *low_1st_module = (mca_coll_base_avail_coll_t *)
opal_list_get_last(low_comm->c_coll->module_list);

// Invoke XHC's "special" Reduce
if(0 == strcmp(low_1st_module->ac_component_name, "xhc")
&& low_comm->c_coll->coll_reduce_module == low_1st_module->ac_module) {
root_reduce_low_rank = -1;
}

/* Low_comm reduce */
if (MPI_IN_PLACE == sbuf) {
if (low_rank == root_low_rank) {
ret = low_comm->c_coll->coll_reduce(MPI_IN_PLACE, (char *)rbuf,
count, dtype, op, root_low_rank,
count, dtype, op, root_reduce_low_rank,
low_comm, low_comm->c_coll->coll_reduce_module);
}
else {
ret = low_comm->c_coll->coll_reduce((char *)rbuf, NULL,
count, dtype, op, root_low_rank,
count, dtype, op, root_reduce_low_rank,
low_comm, low_comm->c_coll->coll_reduce_module);
}
}
else {
ret = low_comm->c_coll->coll_reduce((char *)sbuf, (char *)rbuf,
count, dtype, op, root_low_rank,
count, dtype, op, root_reduce_low_rank,
low_comm, low_comm->c_coll->coll_reduce_module);
}
if (OPAL_UNLIKELY(OMPI_SUCCESS != ret)) {
Expand Down

0 comments on commit b08f94a

Please sign in to comment.