diff --git a/ompi/mca/coll/acoll/coll_acoll_allgather.c b/ompi/mca/coll/acoll/coll_acoll_allgather.c index bb7ccf5bbf8..f0e2110402f 100644 --- a/ompi/mca/coll/acoll/coll_acoll_allgather.c +++ b/ompi/mca/coll/acoll/coll_acoll_allgather.c @@ -268,6 +268,9 @@ static inline int mca_coll_acoll_allgather_intra(const void *sbuf, int scount, data_blk_size[0] = bcount * (num_sgs - 2) + last_subgrp_rcnt; blk_ofst[0] = bcount; } else if (sg_id == num_sgs - 1) { + if (last_subgrp_size < 2) { + return err; + } num_data_blks = 1; data_blk_size[0] = bcount * (num_sgs - 1); blk_ofst[0] = 0; @@ -329,8 +332,7 @@ int mca_coll_acoll_allgather(const void *sbuf, int scount, struct ompi_datatype_ int i; int err; int size; - int rank, adj_rank; - int num_sgs; + int rank; int sg_size, log2_sg_size; int num_nodes, node_start, node_end, node_id; int node_size, last_node_size; @@ -388,7 +390,9 @@ int mca_coll_acoll_allgather(const void *sbuf, int scount, struct ompi_datatype_ if (size <= 2) { intra_comm = comm; } else { - assert(subc->local_r_comm != NULL); + if (num_nodes > 1) { + assert(subc->local_r_comm != NULL); + } intra_comm = num_nodes == 1 ? comm : subc->local_r_comm; } err = mca_coll_acoll_allgather_intra(sbuf, scount, sdtype, local_rbuf, rcount, rdtype, @@ -454,12 +458,14 @@ int mca_coll_acoll_allgather(const void *sbuf, int scount, struct ompi_datatype_ } /* End of if inter leader */ /* Do intra node broadcast */ - num_sgs = (node_size + sg_size - 1) >> log2_sg_size; if (node_id == 0) { num_data_blks = 1; data_blk_size[0] = bcount * (num_nodes - 2) + last_subgrp_rcnt; blk_ofst[0] = bcount; } else if (node_id == num_nodes - 1) { + if (last_node_size < 2) { + return err; + } num_data_blks = 1; data_blk_size[0] = bcount * (num_nodes - 1); blk_ofst[0] = 0; diff --git a/ompi/mca/coll/acoll/coll_acoll_barrier.c b/ompi/mca/coll/acoll/coll_acoll_barrier.c index a07900d7503..06a9ef070d7 100644 --- a/ompi/mca/coll/acoll/coll_acoll_barrier.c +++ b/ompi/mca/coll/acoll/coll_acoll_barrier.c @@ -125,7 +125,6 @@ static int mca_coll_acoll_barrier_send_subc(struct ompi_communicator_t *comm, int mca_coll_acoll_barrier_intra(struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { int size, ssize, bsize; - int srank; int err = MPI_SUCCESS; int nreqs = 0; ompi_request_t **reqs; @@ -141,6 +140,9 @@ int mca_coll_acoll_barrier_intra(struct ompi_communicator_t *comm, mca_coll_base subc = &acoll_module->subc[cid]; size = ompi_comm_size(comm); + if (size == 1) { + return err; + } if (!subc->initialized && size > 1) { err = mca_coll_acoll_comm_split_init(comm, acoll_module, 0); if (MPI_SUCCESS != err) { diff --git a/ompi/mca/coll/acoll/coll_acoll_bcast.c b/ompi/mca/coll/acoll/coll_acoll_bcast.c index 38eb185695d..39bbb73f56e 100644 --- a/ompi/mca/coll/acoll/coll_acoll_bcast.c +++ b/ompi/mca/coll/acoll/coll_acoll_bcast.c @@ -37,7 +37,7 @@ static int bcast_binomial(void *buff, int count, struct ompi_datatype_t *datatyp struct ompi_communicator_t *comm, ompi_request_t **preq, int *nreqs, int world_rank) { - int msb_pos, sub_rank, peer, err; + int msb_pos, sub_rank, peer, err = MPI_SUCCESS; int size, rank, dim; int i, mask; @@ -83,7 +83,7 @@ static int bcast_flat_tree(void *buff, int count, struct ompi_datatype_t *dataty int world_rank) { int peer; - int err; + int err = MPI_SUCCESS; int rank = ompi_comm_rank(comm); int size = ompi_comm_size(comm); diff --git a/ompi/mca/coll/acoll/coll_acoll_gather.c b/ompi/mca/coll/acoll/coll_acoll_gather.c index f8a0ecc319b..31897f43777 100644 --- a/ompi/mca/coll/acoll/coll_acoll_gather.c +++ b/ompi/mca/coll/acoll/coll_acoll_gather.c @@ -43,17 +43,16 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty int i, err, rank, size; char *wkg = NULL, *workbuf = NULL; MPI_Status status; - MPI_Aint incr, extent, lb; MPI_Aint sextent, sgap = 0, ssize; - MPI_Aint rextent, rgap = 0, rsize; + MPI_Aint rextent; int total_recv = 0; int sg_cnt, node_cnt; int cur_sg, root_sg; int cur_node, root_node; int is_base, is_local_root; int startr, endr, inc; - int startn, endn, incn; - int num_nodes, node_id; + int startn, endn; + int num_nodes; mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; coll_acoll_reserve_mem_t *reserve_mem_gather = &(acoll_module->reserve_mem_s); @@ -70,17 +69,13 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty num_nodes = 1; } - ompi_datatype_get_extent(rdtype, &lb, &extent); - incr = extent * (ptrdiff_t) rcount; - - /* Setup root for reveive */ + /* Setup root for receive */ if (rank == root) { ompi_datatype_type_extent(rdtype, &rextent); - rsize = opal_datatype_span(&rdtype->super, (int64_t) rcount * size, &rgap); /* Just use the recv buffer */ wkg = (char *) rbuf; if (sbuf != MPI_IN_PLACE) { - MPI_Aint root_ofst = extent * (ptrdiff_t) (rcount * root); + MPI_Aint root_ofst = rextent * (ptrdiff_t) (rcount * root); err = ompi_datatype_sndrcv((void *) sbuf, scount, sdtype, wkg + (ptrdiff_t) root_ofst, rcount, rdtype); if (MPI_SUCCESS != err) { @@ -100,7 +95,7 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty is_local_root = (rank % node_cnt == 0) && (cur_node != root_node); startn = (rank / node_cnt) * node_cnt; - if (is_base || (rank == root)) { + if (is_base) { int64_t buf_size = is_local_root ? (int64_t) scount * node_cnt : (int64_t) scount * sg_cnt; ompi_datatype_type_extent(sdtype, &sextent); ssize = opal_datatype_span(&sdtype->super, buf_size, &sgap); @@ -111,7 +106,7 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty return OMPI_ERR_OUT_OF_RESOURCE; } wkg = workbuf - sgap; - tmprecv = wkg + extent * (ptrdiff_t) (rcount * (rank - startr)); + tmprecv = wkg + sextent * (ptrdiff_t) (rcount * (rank - startr)); /* local copy to workbuf */ err = ompi_datatype_sndrcv((void *) sbuf, scount, sdtype, tmprecv, scount, sdtype); if (MPI_SUCCESS != err) { @@ -123,7 +118,7 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty rcount = scount; rextent = sextent; total_recv = rcount; - } else { + } else if (rank != root) { wkg = (char *) sbuf; total_recv = scount; } @@ -141,9 +136,9 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty continue; } if (rank == root) { - tmprecv = wkg + extent * (ptrdiff_t) (rcount * i); + tmprecv = wkg + rextent * (ptrdiff_t) (rcount * i); } else { - tmprecv = wkg + extent * (ptrdiff_t) (rcount * (i - startr)); + tmprecv = wkg + rextent * (ptrdiff_t) (rcount * (i - startr)); } err = MCA_PML_CALL( recv(tmprecv, rcount, rdtype, i, MCA_COLL_BASE_TAG_GATHER, comm, &status)); @@ -161,10 +156,9 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty if (endn > size) { endn = size; } - incn = (rank == root) ? ((root != startn) ? 0 : sg_cnt) : sg_cnt; if (sg_cnt < size) { int local_root = (root_node == cur_node) ? root : startn; - for (i = startn + incn; i < endn; i += sg_cnt) { + for (i = startn; i < endn; i += sg_cnt) { int i_sg = i / sg_cnt; if ((rank != local_root) && (rank == i) && is_base) { err = MCA_PML_CALL(send(workbuf - sgap, total_recv, sdtype, local_root, @@ -173,7 +167,7 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty } if ((rank == local_root) && (rank != i) && (i_sg != root_sg)) { int recv_amt = (i + sg_cnt > size) ? rcount * (size - i) : rcount * sg_cnt; - MPI_Aint rcv_ofst = extent * (ptrdiff_t) (rcount * (i - startn)); + MPI_Aint rcv_ofst = rextent * (ptrdiff_t) (rcount * (i - startn)); err = MCA_PML_CALL(recv(wkg + (ptrdiff_t) rcv_ofst, recv_amt, rdtype, i, MCA_COLL_BASE_TAG_GATHER, comm, &status)); @@ -189,7 +183,7 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty } /* All local roots ranks send to root */ - if (node_cnt < size) { + if (node_cnt < size && num_nodes > 1) { for (i = 0; i < size; i += node_cnt) { int i_node = i / node_cnt; if ((rank != root) && (rank == i) && is_base) { @@ -199,7 +193,7 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty } if ((rank == root) && (rank != i) && (i_node != root_node)) { int recv_amt = (i + node_cnt > size) ? rcount * (size - i) : rcount * node_cnt; - MPI_Aint rcv_ofst = extent * (ptrdiff_t) (rcount * i); + MPI_Aint rcv_ofst = rextent * (ptrdiff_t) (rcount * i); err = MCA_PML_CALL(recv((char *) rbuf + (ptrdiff_t) rcv_ofst, recv_amt, rdtype, i, MCA_COLL_BASE_TAG_GATHER, comm, &status)); diff --git a/ompi/mca/coll/acoll/coll_acoll_module.c b/ompi/mca/coll/acoll/coll_acoll_module.c index d35a9db94df..9a60d086931 100644 --- a/ompi/mca/coll/acoll/coll_acoll_module.c +++ b/ompi/mca/coll/acoll/coll_acoll_module.c @@ -41,6 +41,15 @@ mca_coll_base_module_t *mca_coll_acoll_comm_query(struct ompi_communicator_t *co return NULL; } + if (OMPI_COMM_IS_INTER(comm)) { + *priority = 0; + return NULL; + } + if (OMPI_COMM_IS_INTRA(comm) && ompi_comm_size(comm) < 2) { + *priority = 0; + return NULL; + } + *priority = mca_coll_acoll_priority; /* Set topology params */ diff --git a/ompi/mca/coll/acoll/coll_acoll_reduce.c b/ompi/mca/coll/acoll/coll_acoll_reduce.c index 08afa893ee7..82082bb9681 100644 --- a/ompi/mca/coll/acoll/coll_acoll_reduce.c +++ b/ompi/mca/coll/acoll/coll_acoll_reduce.c @@ -382,11 +382,11 @@ int mca_coll_acoll_reduce_intra(const void *sbuf, void *rbuf, int count, module); } else { return ompi_coll_base_reduce_intra_binomial(sbuf, rbuf, count, dtype, op, - root, comm, module, 0, 0); + root, comm, module, 0, 0); } #else return ompi_coll_base_reduce_intra_binomial(sbuf, rbuf, count, dtype, op, root, - comm, module, 0, 0); + comm, module, 0, 0); #endif } } else { diff --git a/ompi/mca/coll/acoll/coll_acoll_utils.h b/ompi/mca/coll/acoll/coll_acoll_utils.h index a5223a77f91..5832c067ac5 100644 --- a/ompi/mca/coll/acoll/coll_acoll_utils.h +++ b/ompi/mca/coll/acoll/coll_acoll_utils.h @@ -262,6 +262,9 @@ static inline int mca_coll_acoll_comm_split_init(ompi_communicator_t *comm, mca_coll_base_module_allreduce_fn_t coll_allreduce_org = (comm)->c_coll->coll_allreduce; mca_coll_base_module_allgather_fn_t coll_allgather_org = (comm)->c_coll->coll_allgather; mca_coll_base_module_bcast_fn_t coll_bcast_org = (comm)->c_coll->coll_bcast; + mca_coll_base_module_allreduce_fn_t coll_allreduce_loc, coll_allreduce_soc; + mca_coll_base_module_allgather_fn_t coll_allgather_loc, coll_allgather_soc; + mca_coll_base_module_bcast_fn_t coll_bcast_loc, coll_bcast_soc; coll_acoll_subcomms_t *subc; int err; int size = ompi_comm_size(comm); @@ -362,6 +365,21 @@ static inline int mca_coll_acoll_comm_split_init(ompi_communicator_t *comm, subc->base_root[MCA_COLL_ACOLL_L3CACHE][i] = -1; subc->base_root[MCA_COLL_ACOLL_NUMA][i] = -1; } + /* Store original collectives for local and socket comms */ + coll_allreduce_loc = (subc->local_comm)->c_coll->coll_allreduce; + coll_allgather_loc = (subc->local_comm)->c_coll->coll_allgather; + coll_bcast_loc = (subc->local_comm)->c_coll->coll_bcast; + (subc->local_comm)->c_coll->coll_allgather = ompi_coll_base_allgather_intra_ring; + (subc->local_comm)->c_coll->coll_allreduce + = ompi_coll_base_allreduce_intra_recursivedoubling; + (subc->local_comm)->c_coll->coll_bcast = ompi_coll_base_bcast_intra_basic_linear; + coll_allreduce_soc = (subc->socket_comm)->c_coll->coll_allreduce; + coll_allgather_soc = (subc->socket_comm)->c_coll->coll_allgather; + coll_bcast_soc = (subc->socket_comm)->c_coll->coll_bcast; + (subc->socket_comm)->c_coll->coll_allgather = ompi_coll_base_allgather_intra_ring; + (subc->socket_comm)->c_coll->coll_allreduce + = ompi_coll_base_allreduce_intra_recursivedoubling; + (subc->socket_comm)->c_coll->coll_bcast = ompi_coll_base_bcast_intra_basic_linear; } /* Further subcommunicators based on root */ @@ -519,6 +537,14 @@ static inline int mca_coll_acoll_comm_split_init(ompi_communicator_t *comm, } } + /* Restore originals for local and socket comms */ + (subc->local_comm)->c_coll->coll_allreduce = coll_allreduce_loc; + (subc->local_comm)->c_coll->coll_allgather = coll_allgather_loc; + (subc->local_comm)->c_coll->coll_bcast = coll_bcast_loc; + (subc->socket_comm)->c_coll->coll_allreduce = coll_allreduce_soc; + (subc->socket_comm)->c_coll->coll_allgather = coll_allgather_soc; + (subc->socket_comm)->c_coll->coll_bcast = coll_bcast_soc; + /* For collectives where order is important (like gather, allgather), * split based on ranks. This is optimal for global communicators with * equal split among nodes, but suboptimal for other cases. @@ -590,6 +616,7 @@ static inline int coll_acoll_init(mca_coll_base_module_t *module, ompi_communica data = (coll_acoll_data_t *) malloc(sizeof(coll_acoll_data_t)); if (NULL == data) { line = __LINE__; + ret = OMPI_ERR_OUT_OF_RESOURCE; goto error_hndl; } size = ompi_comm_size(comm); @@ -601,6 +628,7 @@ static inline int coll_acoll_init(mca_coll_base_module_t *module, ompi_communica data->scratch = (char *) malloc(subc->xpmem_buf_size); if (NULL == data->scratch) { line = __LINE__; + ret = OMPI_ERR_OUT_OF_RESOURCE; goto error_hndl; } } else { @@ -611,41 +639,49 @@ static inline int coll_acoll_init(mca_coll_base_module_t *module, ompi_communica data->allseg_id = (xpmem_segid_t *) malloc(sizeof(xpmem_segid_t) * size); if (NULL == data->allseg_id) { line = __LINE__; + ret = OMPI_ERR_OUT_OF_RESOURCE; goto error_hndl; } data->all_apid = (xpmem_apid_t *) malloc(sizeof(xpmem_apid_t) * size); if (NULL == data->all_apid) { line = __LINE__; + ret = OMPI_ERR_OUT_OF_RESOURCE; goto error_hndl; } data->allshm_sbuf = (void **) malloc(sizeof(void *) * size); if (NULL == data->allshm_sbuf) { line = __LINE__; + ret = OMPI_ERR_OUT_OF_RESOURCE; goto error_hndl; } data->allshm_rbuf = (void **) malloc(sizeof(void *) * size); if (NULL == data->allshm_rbuf) { line = __LINE__; + ret = OMPI_ERR_OUT_OF_RESOURCE; goto error_hndl; } data->xpmem_saddr = (void **) malloc(sizeof(void *) * size); if (NULL == data->xpmem_saddr) { line = __LINE__; + ret = OMPI_ERR_OUT_OF_RESOURCE; goto error_hndl; } data->xpmem_raddr = (void **) malloc(sizeof(void *) * size); if (NULL == data->xpmem_raddr) { line = __LINE__; + ret = OMPI_ERR_OUT_OF_RESOURCE; goto error_hndl; } data->rcache = (mca_rcache_base_module_t **) malloc(sizeof(mca_rcache_base_module_t *) * size); if (NULL == data->rcache) { line = __LINE__; + ret = OMPI_ERR_OUT_OF_RESOURCE; goto error_hndl; } seg_id = xpmem_make(0, XPMEM_MAXADDR_SIZE, XPMEM_PERMIT_MODE, (void *) 0666); if (seg_id == -1) { line = __LINE__; + ret = -1; goto error_hndl; }