From 87e6eb867f225b95bbf4a1422d22f71cce3f48f2 Mon Sep 17 00:00:00 2001 From: Jessie Yang Date: Fri, 5 Apr 2024 14:52:05 -0700 Subject: [PATCH] mca/coll: Add any radix k for alltoall bruck algorithm This method extends ompi_coll_base_alltoall_intra_bruck to handle any radix k. Signed-off-by: Jessie Yang --- ompi/mca/coll/base/coll_base_alltoall.c | 187 ++++++++++++++---- ompi/mca/coll/base/coll_base_functions.h | 2 +- .../coll/tuned/coll_tuned_alltoall_decision.c | 2 +- .../coll/tuned/coll_tuned_decision_fixed.c | 3 +- 4 files changed, 152 insertions(+), 42 deletions(-) diff --git a/ompi/mca/coll/base/coll_base_alltoall.c b/ompi/mca/coll/base/coll_base_alltoall.c index f5650b4312b..217d952d471 100644 --- a/ompi/mca/coll/base/coll_base_alltoall.c +++ b/ompi/mca/coll/base/coll_base_alltoall.c @@ -235,19 +235,96 @@ int ompi_coll_base_alltoall_intra_pairwise(const void *sbuf, int scount, return err; } - -int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, int scount, - struct ompi_datatype_t *sdtype, - void* rbuf, int rcount, - struct ompi_datatype_t *rdtype, - struct ompi_communicator_t *comm, - mca_coll_base_module_t *module) +/* + * + * Function: ompi_coll_base_alltoall_intra_k_bruck using O(logk(N)) steps + * Accepts: Same arguments as MPI_Alltoall + * Returns: MPI_SUCCESS or error code + * + * Description: This method extends ompi_coll_base_alltoall_intra_bruck to handle any + * radix k(k >= 2). + * + * Example on 6 ranks with k = 4 + * # 0 1 2 3 4 5 + * [00] [10] [20] [30] [40] [50] + * [01] [11] [21] [31] [41] [51] + * [02] [12] [22] [32] [42] [52] + * [03] [13] [23] [33] [43] [53] + * [04] [14] [24] [34] [44] [54] + * [05] [15] [25] [35] [45] [55] + * After local rotation + * # 0 1 2 3 4 5 + * [00] [11] [22] [33] [44] [55] + * [01] [12] [23] [34] [45] [50] + * [02] [13] [24] [35] [40] [51] + * [03] [14] [25] [30] [41] [52] + * [04] [15] [20] [31] [42] [53] + * [05] [10] [21] [32] [43] [54] + * Phase 0: send message to (rank + k^0 * i), receive message from (rank - k^0 * i) + * send the data block whose least significant bit is i in base k representation + * for i between [1, k-1] + * i = 1: send the data block at offset i, i + k to (rank + 1) + * # 0 1 2 3 4 5 + * [00] [11] [22] [33] [44] [55] + * [50] [01] [12] [23] [34] [45] + * [02] [13] [24] [35] [40] [51] + * [03] [14] [25] [30] [41] [52] + * [04] [15] [20] [31] [42] [53] + * [54] [05] [10] [21] [32] [43] + * i = 2: send the data block at offset i to (rank + 2) + * # 0 1 2 3 4 5 + * [00] [11] [22] [33] [44] [55] + * [50] [01] [12] [23] [34] [45] + * [40] [51] [02] [13] [24] [35] + * [03] [14] [25] [30] [41] [52] + * [04] [15] [20] [31] [42] [53] + * [54] [05] [10] [21] [32] [43] + * i = 3: send the data block at offset i to (rank + 3) + * # 0 1 2 3 4 5 + * [00] [11] [22] [33] [44] [55] + * [50] [01] [12] [23] [34] [45] + * [40] [51] [02] [13] [24] [35] + * [30] [41] [52] [03] [14] [25] + * [04] [15] [20] [31] [42] [53] + * [54] [05] [10] [21] [32] [43] + * Phase 1: send message to (rank + k^1 * i), receive message from (rank - k^1 * i) + * send the data block whose second bit is i in base k representation + * for i between [1, k-1] + * i = 1: send the data block at offset k with size of min(k, size-i*k)=2 to (rank + 4) + * # 0 1 2 3 4 5 + * [00] [11] [22] [33] [44] [55] + * [50] [01] [12] [23] [34] [45] + * [40] [51] [02] [13] [24] [35] + * [30] [41] [52] [03] [14] [25] + * [20] [31] [42] [53] [04] [15] + * [10] [21] [32] [43] [54] [05] + * i = 2: nothing is to be sent + * i = 3: nothing is to be sent + * After local inverse rotation + * # 0 1 2 3 4 5 + * [00] [01] [02] [03] [04] [05] + * [10] [11] [12] [13] [14] [15] + * [20] [21] [22] [23] [24] [25] + * [30] [31] [32] [33] [34] [35] + * [40] [41] [42] [43] [44] [45] + * [50] [51] [52] [53] [54] [55] + * +*/ +int ompi_coll_base_alltoall_intra_k_bruck(const void *sbuf, int scount, + struct ompi_datatype_t *sdtype, + void* rbuf, int rcount, + struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module, + int radix) { - int i, line = -1, rank, size, err = 0; + int i, j, line = -1, rank, size, err = 0; int sendto, recvfrom, distance, *displs = NULL; char *tmpbuf = NULL, *tmpbuf_free = NULL; ptrdiff_t sext, rext, span, gap = 0; struct ompi_datatype_t *new_ddt; + ompi_request_t **reqs; + int num_reqs, max_reqs = 0; if (MPI_IN_PLACE == sbuf) { return mca_coll_base_alltoall_intra_basic_inplace (rbuf, rcount, rdtype, @@ -257,8 +334,12 @@ int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, int scount, size = ompi_comm_size(comm); rank = ompi_comm_rank(comm); + if (radix < 2) { + line = __LINE__; err = -1; goto err_hndl; + } + OPAL_OUTPUT((ompi_coll_base_framework.framework_output, - "coll:base:alltoall_intra_bruck rank %d", rank)); + "coll:base:alltoall_intra_k_bruck radix %d rank %d", radix, rank)); err = ompi_datatype_type_extent (sdtype, &sext); if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; } @@ -297,42 +378,57 @@ int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, int scount, } /* perform communication step */ - for (distance = 1; distance < size; distance<<=1) { - - sendto = (rank + distance) % size; - recvfrom = (rank - distance + size) % size; - - new_ddt = ompi_datatype_create((1 + size/distance) * (2 + rdtype->super.desc.used)); + max_reqs = 2 * (radix - 1); + reqs = ompi_coll_base_comm_get_reqs(module->base_data, max_reqs); + for (distance = 1; distance < size; distance *= radix) { + num_reqs = 0; + for (i = 1; i < radix; i++) { - /* Create datatype describing data sent/received */ - for (i = distance; i < size; i += 2*distance) { - int nblocks = distance; - if (i + distance >= size) { - nblocks = size - i; + if (distance * i >= size) { + break; } - ompi_datatype_add(new_ddt, rdtype, rcount * nblocks, - i * rcount * rext, rext); - } - /* Commit the new datatype */ - err = ompi_datatype_commit(&new_ddt); - if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; } + sendto = (rank + distance * i) % size; + recvfrom = (rank - distance * i + size) % size; - /* Sendreceive */ - err = ompi_coll_base_sendrecv ( tmpbuf, 1, new_ddt, sendto, - MCA_COLL_BASE_TAG_ALLTOALL, - rbuf, 1, new_ddt, recvfrom, - MCA_COLL_BASE_TAG_ALLTOALL, - comm, MPI_STATUS_IGNORE, rank ); - if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; } + new_ddt = ompi_datatype_create((1 + size/distance) * (2 + rdtype->super.desc.used)); - /* Copy back new data from recvbuf to tmpbuf */ - err = ompi_datatype_copy_content_same_ddt(new_ddt, 1,tmpbuf, (char *) rbuf); - if (err < 0) { line = __LINE__; err = -1; goto err_hndl; } + /* Create datatype describing data sent/received */ + for (j = i * distance; j < size; j += radix * distance) { + int nblocks = distance; + if (j + distance >= size) { + nblocks = size - j; + } + ompi_datatype_add(new_ddt, rdtype, rcount * nblocks, + j * rcount * rext, rext); + } - /* free ddt */ - err = ompi_datatype_destroy(&new_ddt); - if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; } + /* Commit the new datatype */ + err = ompi_datatype_commit(&new_ddt); + if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; } + + err = MCA_PML_CALL(irecv(rbuf, 1, new_ddt, recvfrom, + MCA_COLL_BASE_TAG_ALLTOALL, + comm, + &reqs[num_reqs++])); + if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; } + err = MCA_PML_CALL(isend(tmpbuf, 1, new_ddt, sendto, + MCA_COLL_BASE_TAG_ALLTOALL, + MCA_PML_BASE_SEND_STANDARD, + comm, + &reqs[num_reqs++])); + if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; } + + /* Copy back new data from recvbuf to tmpbuf */ + err = ompi_datatype_copy_content_same_ddt(new_ddt, 1, tmpbuf, (char *) rbuf); + if (err < 0) { line = __LINE__; err = -1; goto err_hndl; } + + /* free ddt */ + err = ompi_datatype_destroy(&new_ddt); + if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; } + } + err = ompi_request_wait_all(num_reqs, reqs, MPI_STATUSES_IGNORE); + if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; } } /* end of for (distance = 1... */ /* Step 3 - local rotation - */ @@ -349,6 +445,19 @@ int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, int scount, return OMPI_SUCCESS; err_hndl: + if( NULL != reqs ) { + if (MPI_ERR_IN_STATUS == err) { + for( num_reqs = 0; num_reqs < max_reqs; num_reqs++ ) { + if (MPI_REQUEST_NULL == reqs[num_reqs]) continue; + if (MPI_ERR_PENDING == reqs[num_reqs]->req_status.MPI_ERROR) continue; + if (reqs[num_reqs]->req_status.MPI_ERROR != MPI_SUCCESS) { + err = reqs[num_reqs]->req_status.MPI_ERROR; + break; + } + } + } + ompi_coll_base_free_reqs(reqs, max_reqs); + } OPAL_OUTPUT((ompi_coll_base_framework.framework_output, "%s:%4d\tError occurred %d, rank %2d", __FILE__, line, err, rank)); diff --git a/ompi/mca/coll/base/coll_base_functions.h b/ompi/mca/coll/base/coll_base_functions.h index 935ae796eb1..dc184687664 100644 --- a/ompi/mca/coll/base/coll_base_functions.h +++ b/ompi/mca/coll/base/coll_base_functions.h @@ -215,7 +215,7 @@ int ompi_coll_base_allreduce_intra_allgather_reduce(ALLREDUCE_ARGS); /* AlltoAll */ int ompi_coll_base_alltoall_intra_pairwise(ALLTOALL_ARGS); -int ompi_coll_base_alltoall_intra_bruck(ALLTOALL_ARGS); +int ompi_coll_base_alltoall_intra_k_bruck(ALLTOALL_ARGS, int radix); int ompi_coll_base_alltoall_intra_basic_linear(ALLTOALL_ARGS); int ompi_coll_base_alltoall_intra_linear_sync(ALLTOALL_ARGS, int max_requests); int ompi_coll_base_alltoall_intra_two_procs(ALLTOALL_ARGS); diff --git a/ompi/mca/coll/tuned/coll_tuned_alltoall_decision.c b/ompi/mca/coll/tuned/coll_tuned_alltoall_decision.c index 487f9da4fde..aa524e31acc 100644 --- a/ompi/mca/coll/tuned/coll_tuned_alltoall_decision.c +++ b/ompi/mca/coll/tuned/coll_tuned_alltoall_decision.c @@ -173,7 +173,7 @@ int ompi_coll_tuned_alltoall_intra_do_this(const void *sbuf, int scount, case (2): return ompi_coll_base_alltoall_intra_pairwise(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, module); case (3): - return ompi_coll_base_alltoall_intra_bruck(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, module); + return ompi_coll_base_alltoall_intra_k_bruck(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, module, faninout); case (4): return ompi_coll_base_alltoall_intra_linear_sync(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, module, max_requests); case (5): diff --git a/ompi/mca/coll/tuned/coll_tuned_decision_fixed.c b/ompi/mca/coll/tuned/coll_tuned_decision_fixed.c index 35b31806ecf..217ffbee8d6 100644 --- a/ompi/mca/coll/tuned/coll_tuned_decision_fixed.c +++ b/ompi/mca/coll/tuned/coll_tuned_decision_fixed.c @@ -404,10 +404,11 @@ int ompi_coll_tuned_alltoall_intra_dec_fixed(const void *sbuf, int scount, } } + int faninout = 2; return ompi_coll_tuned_alltoall_intra_do_this (sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, module, - alg, 0, 0, ompi_coll_tuned_alltoall_max_requests); + alg, faninout, 0, ompi_coll_tuned_alltoall_max_requests); } /*