Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ompi/mca/coll/base/coll_base_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ int ompi_coll_base_reduce_scatter_intra_ring(REDUCESCATTER_ARGS);
/* Reduce_scatter_block */
int ompi_coll_base_reduce_scatter_block_basic(REDUCESCATTERBLOCK_ARGS);
int ompi_coll_base_reduce_scatter_block_intra_recursivedoubling(REDUCESCATTERBLOCK_ARGS);
int ompi_coll_base_reduce_scatter_block_intra_recursivehalving(REDUCESCATTERBLOCK_ARGS);

/* Scan */
int ompi_coll_base_scan_intra_recursivedoubling(SCAN_ARGS);
Expand Down
214 changes: 211 additions & 3 deletions ompi/mca/coll/base/coll_base_reduce_scatter_block.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,15 @@
#include "ompi/datatype/ompi_datatype.h"
#include "ompi/communicator/communicator.h"
#include "ompi/mca/coll/coll.h"
#include "ompi/mca/coll/base/coll_tags.h"
#include "ompi/mca/coll/base/coll_base_functions.h"
#include "ompi/mca/coll/basic/coll_basic.h"
#include "ompi/mca/pml/pml.h"
#include "ompi/op/op.h"
#include "ompi/mca/coll/base/coll_base_functions.h"
#include "coll_tags.h"
#include "coll_base_functions.h"
#include "coll_base_topo.h"
#include "coll_base_util.h"


/*
* ompi_reduce_scatter_block_basic
*
Expand Down Expand Up @@ -303,3 +304,210 @@ ompi_coll_base_reduce_scatter_block_intra_recursivedoubling(
free(tmprecv_raw);
return err;
}

/*
 * ompi_range_sum: weighted size of the index range [a, b].
 *
 * Every index i <= r carries weight 2, every index i > r carries weight 1:
 *   index:  0 1 2 3 4 ... r  r+1 r+2 ...
 *   weight: 2 2 2 2 2 ... 2   1   1  ...
 *
 * Returns the sum of the weights of the indexes a, a+1, ..., b.
 */
static int ompi_range_sum(int a, int b, int r)
{
    /* Number of indexes in [a, b]; each one contributes at least 1. */
    const int span = b - a + 1;
    /* Number of indexes in [a, b] that carry the extra unit of weight. */
    int doubled;

    if (r < a) {
        /* The whole range lies above r: nothing is doubled. */
        doubled = 0;
    } else if (r > b) {
        /* The whole range lies below r: everything is doubled. */
        doubled = span;
    } else {
        /* r falls inside the range: only a..r are doubled. */
        doubled = r - a + 1;
    }
    return span + doubled;
}

/*
 * ompi_coll_base_reduce_scatter_block_intra_recursivehalving
 *
 * Function:    Recursive halving algorithm for reduce_scatter_block
 * Accepts:     Same as MPI_Reduce_scatter_block
 * Returns:     MPI_SUCCESS or error code
 *
 * Description: Implements recursive halving algorithm for MPI_Reduce_scatter_block.
 *              The algorithm can be used by commutative operations only;
 *              for a non-commutative op the call is forwarded to
 *              ompi_coll_base_reduce_scatter_block_basic.
 *              A communicator with a single process just copies the
 *              caller's block from sbuf to rbuf (nothing to do for
 *              MPI_IN_PLACE).
 *
 * Limitations: commutative operations only
 * Memory requirements (per process): 2 * rcount * comm_size * typesize
 */
int
ompi_coll_base_reduce_scatter_block_intra_recursivehalving(
    const void *sbuf, void *rbuf, int rcount, struct ompi_datatype_t *dtype,
    struct ompi_op_t *op, struct ompi_communicator_t *comm,
    mca_coll_base_module_t *module)
{
    char *tmprecv_raw = NULL, *tmpbuf_raw = NULL, *tmprecv, *tmpbuf;
    ptrdiff_t span, gap, totalcount, extent;
    int err = MPI_SUCCESS;
    int comm_size = ompi_comm_size(comm);
    int rank = ompi_comm_rank(comm);

    OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
                 "coll:base:reduce_scatter_block_intra_recursivehalving: rank %d/%d",
                 rank, comm_size));
    if (rcount == 0)
        return MPI_SUCCESS;

    if (comm_size < 2) {
        /*
         * Single process: the reduction of one contribution is that
         * contribution itself, so the result block is the input block.
         * With MPI_IN_PLACE the data already lives in rbuf.
         */
        if (MPI_IN_PLACE != sbuf) {
            err = ompi_datatype_copy_content_same_ddt(dtype, rcount, rbuf, (char *)sbuf);
        }
        return err;
    }

    if (!ompi_op_is_commute(op)) {
        OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
                     "coll:base:reduce_scatter_block_intra_recursivehalving: rank %d/%d "
                     "switching to basic reduce_scatter_block", rank, comm_size));
        return ompi_coll_base_reduce_scatter_block_basic(sbuf, rbuf, rcount, dtype,
                                                         op, comm, module);
    }

    /* Widen before multiplying so the product is not computed in int */
    totalcount = (ptrdiff_t)comm_size * rcount;
    ompi_datatype_type_extent(dtype, &extent);
    span = opal_datatype_span(&dtype->super, totalcount, &gap);
    tmpbuf_raw = malloc(span);
    tmprecv_raw = malloc(span);
    if (NULL == tmpbuf_raw || NULL == tmprecv_raw) {
        err = OMPI_ERR_OUT_OF_RESOURCE;
        goto cleanup_and_return;
    }
    /* Shift the working pointers by the gap reported by opal_datatype_span */
    tmpbuf = tmpbuf_raw - gap;
    tmprecv = tmprecv_raw - gap;

    /* Stage the full input vector (comm_size blocks of rcount) in tmpbuf */
    if (sbuf != MPI_IN_PLACE) {
        err = ompi_datatype_copy_content_same_ddt(dtype, totalcount, tmpbuf, (char *)sbuf);
        if (MPI_SUCCESS != err) { goto cleanup_and_return; }
    } else {
        err = ompi_datatype_copy_content_same_ddt(dtype, totalcount, tmpbuf, rbuf);
        if (MPI_SUCCESS != err) { goto cleanup_and_return; }
    }

    /*
     * Step 1. Reduce the number of processes to the nearest lower power of two
     * p' = 2^{\floor{\log_2 p}} by removing r = p - p' processes.
     * In the first 2r processes (ranks 0 to 2r - 1), all the even ranks send
     * the input vector to their neighbor (rank + 1) and all the odd ranks recv
     * the input vector and perform local reduction.
     * The odd ranks (0 to 2r - 1) contain the reduction with the input
     * vector on their neighbors (the even ranks). The first r odd
     * processes and the p - 2r last processes are renumbered from
     * 0 to 2^{\floor{\log_2 p}} - 1. Even ranks do not participate in the
     * rest of the algorithm.
     */

    /* Find nearest power-of-two less than or equal to comm_size */
    int nprocs_pof2 = opal_next_poweroftwo(comm_size);
    nprocs_pof2 >>= 1;
    int nprocs_rem = comm_size - nprocs_pof2;

    int vrank = -1;
    if (rank < 2 * nprocs_rem) {
        if ((rank % 2) == 0) {
            /* Even process */
            err = MCA_PML_CALL(send(tmpbuf, totalcount, dtype, rank + 1,
                                    MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK,
                                    MCA_PML_BASE_SEND_STANDARD, comm));
            if (OMPI_SUCCESS != err) { goto cleanup_and_return; }
            /* This process does not participate in the rest of the algorithm */
            vrank = -1;
        } else {
            /* Odd process */
            err = MCA_PML_CALL(recv(tmprecv, totalcount, dtype, rank - 1,
                                    MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK,
                                    comm, MPI_STATUS_IGNORE));
            if (OMPI_SUCCESS != err) { goto cleanup_and_return; }
            ompi_op_reduce(op, tmprecv, tmpbuf, totalcount, dtype);
            /* Adjust rank to be the bottom "remain" ranks */
            vrank = rank / 2;
        }
    } else {
        /* Adjust rank to show that the bottom "even remain" ranks dropped out */
        vrank = rank - nprocs_rem;
    }

    if (vrank != -1) {
        /*
         * Step 2. Recursive vector halving. We have p' = 2^{\floor{\log_2 p}}
         * power-of-two number of processes with new ranks (vrank) and partial
         * result in tmpbuf.
         * All processes then compute the reduction between the local
         * buffer and the received buffer. In the next \log_2(p') - 1 steps, the
         * buffers are recursively halved. At the end, each of the p' processes
         * has 1 / p' of the total reduction result.
         */
        int send_index = 0, recv_index = 0, last_index = nprocs_pof2;
        for (int mask = nprocs_pof2 >> 1; mask > 0; mask >>= 1) {
            int vpeer = vrank ^ mask;
            /* Map the virtual peer back to its real rank (inverse of Step 1) */
            int peer = (vpeer < nprocs_rem) ? vpeer * 2 + 1 : vpeer + nprocs_rem;

            /*
             * Calculate the recv_count and send_count because the
             * even-numbered processes who no longer participate will
             * have their result calculated by the process to their
             * right (rank + 1).
             * Counts are widened before the multiplication by rcount so
             * the product is not computed in int.
             */
            ptrdiff_t send_count = 0, recv_count = 0;
            if (vrank < vpeer) {
                /* Send the right half of the buffer, recv the left half */
                send_index = recv_index + mask;
                send_count = (ptrdiff_t)rcount *
                    ompi_range_sum(send_index, last_index - 1, nprocs_rem - 1);
                recv_count = (ptrdiff_t)rcount *
                    ompi_range_sum(recv_index, send_index - 1, nprocs_rem - 1);
            } else {
                /* Send the left half of the buffer, recv the right half */
                recv_index = send_index + mask;
                send_count = (ptrdiff_t)rcount *
                    ompi_range_sum(send_index, recv_index - 1, nprocs_rem - 1);
                recv_count = (ptrdiff_t)rcount *
                    ompi_range_sum(recv_index, last_index - 1, nprocs_rem - 1);
            }
            /*
             * Element displacement of virtual block i inside the full vector:
             * the first nprocs_rem virtual blocks each span two real blocks.
             */
            ptrdiff_t rdispl = (ptrdiff_t)rcount * ((recv_index <= nprocs_rem - 1) ?
                               2 * recv_index : nprocs_rem + recv_index);
            ptrdiff_t sdispl = (ptrdiff_t)rcount * ((send_index <= nprocs_rem - 1) ?
                               2 * send_index : nprocs_rem + send_index);
            struct ompi_request_t *request = NULL;

            if (recv_count > 0) {
                err = MCA_PML_CALL(irecv(tmprecv + rdispl * extent, recv_count,
                                         dtype, peer, MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK,
                                         comm, &request));
                if (OMPI_SUCCESS != err) { goto cleanup_and_return; }
            }
            if (send_count > 0) {
                err = MCA_PML_CALL(send(tmpbuf + sdispl * extent, send_count,
                                        dtype, peer, MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK,
                                        MCA_PML_BASE_SEND_STANDARD,
                                        comm));
                /*
                 * NOTE(review): on failure here the irecv posted above may
                 * still be outstanding while cleanup frees tmprecv - confirm
                 * whether the request needs to be completed/cancelled first.
                 */
                if (OMPI_SUCCESS != err) { goto cleanup_and_return; }
            }
            if (recv_count > 0) {
                err = ompi_request_wait(&request, MPI_STATUS_IGNORE);
                if (OMPI_SUCCESS != err) { goto cleanup_and_return; }
                ompi_op_reduce(op, tmprecv + rdispl * extent,
                               tmpbuf + rdispl * extent, recv_count, dtype);
            }
            /* Continue halving inside the part that was just received */
            send_index = recv_index;
            last_index = recv_index + mask;
        }
        /* Deliver this rank's own block of the result */
        err = ompi_datatype_copy_content_same_ddt(dtype, rcount, rbuf,
                                                  tmpbuf + (ptrdiff_t)rank * rcount * extent);
        if (MPI_SUCCESS != err) { goto cleanup_and_return; }
    }

    /* Step 3. Send the result to excluded even ranks */
    if (rank < 2 * nprocs_rem) {
        if ((rank % 2) == 0) {
            /* Even process */
            err = MCA_PML_CALL(recv(rbuf, rcount, dtype, rank + 1,
                                    MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK, comm,
                                    MPI_STATUS_IGNORE));
            if (OMPI_SUCCESS != err) { goto cleanup_and_return; }
        } else {
            /* Odd process: forward the block computed on behalf of rank - 1 */
            err = MCA_PML_CALL(send(tmpbuf + (ptrdiff_t)(rank - 1) * rcount * extent,
                                    rcount, dtype, rank - 1,
                                    MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK,
                                    MCA_PML_BASE_SEND_STANDARD, comm));
            if (MPI_SUCCESS != err) { goto cleanup_and_return; }
        }
    }

cleanup_and_return:
    /* free(NULL) is a no-op, so no guards are needed */
    free(tmpbuf_raw);
    free(tmprecv_raw);
    return err;
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ static mca_base_var_enum_value_t reduce_scatter_block_algorithms[] = {
{0, "ignore"},
{1, "basic"},
{2, "recursive_doubling"},
{3, "recursive_halving"},
{0, NULL}
};

Expand Down Expand Up @@ -125,6 +126,8 @@ int ompi_coll_tuned_reduce_scatter_block_intra_do_this(const void *sbuf, void *r
dtype, op, comm, module);
case (2): return ompi_coll_base_reduce_scatter_block_intra_recursivedoubling(sbuf, rbuf, rcount,
dtype, op, comm, module);
case (3): return ompi_coll_base_reduce_scatter_block_intra_recursivehalving(sbuf, rbuf, rcount,
dtype, op, comm, module);
} /* switch */
OPAL_OUTPUT((ompi_coll_tuned_stream, "coll:tuned:reduce_scatter_block_intra_do_this attempt to select algorithm %d when only 0-%d is valid?",
algorithm, ompi_coll_tuned_forced_max_algorithms[REDUCESCATTERBLOCK]));
Expand Down