From 1251d735bab4486c1dac53a17ede0ced9b9b6577 Mon Sep 17 00:00:00 2001 From: Alex Mikheev Date: Mon, 12 Feb 2018 12:38:18 +0200 Subject: [PATCH] oshmem: scoll: fixes strided alltoall Signed-off-by: Alex Mikheev (cherry picked from commit cca67a69ea022c24ee4e3fb9ab90d97e8348833e) --- oshmem/mca/scoll/basic/scoll_basic_alltoall.c | 206 ++++++++++++------ 1 file changed, 137 insertions(+), 69 deletions(-) diff --git a/oshmem/mca/scoll/basic/scoll_basic_alltoall.c b/oshmem/mca/scoll/basic/scoll_basic_alltoall.c index cc97a05f21b..7a085ac113d 100644 --- a/oshmem/mca/scoll/basic/scoll_basic_alltoall.c +++ b/oshmem/mca/scoll/basic/scoll_basic_alltoall.c @@ -19,13 +19,19 @@ #include "oshmem/mca/scoll/base/base.h" #include "scoll_basic.h" -static int _algorithm_simple(struct oshmem_group_t *group, - void *target, - const void *source, - ptrdiff_t dst, ptrdiff_t sst, - size_t nelems, - size_t element_size, - long *pSync); +static int a2a_alg_simple(struct oshmem_group_t *group, + void *target, + const void *source, + size_t nelems, + size_t element_size); + +static int a2as_alg_simple(struct oshmem_group_t *group, + void *target, + const void *source, + ptrdiff_t dst, ptrdiff_t sst, + size_t nelems, + size_t element_size); + int mca_scoll_basic_alltoall(struct oshmem_group_t *group, void *target, @@ -36,88 +42,150 @@ int mca_scoll_basic_alltoall(struct oshmem_group_t *group, long *pSync, int alg) { - int rc = OSHMEM_SUCCESS; + int rc; + int i; /* Arguments validation */ if (!group) { SCOLL_ERROR("Active set (group) of PE is not defined"); - rc = OSHMEM_ERR_BAD_PARAM; + return OSHMEM_ERR_BAD_PARAM; } /* Check if this PE is part of the group */ - if ((rc == OSHMEM_SUCCESS) && oshmem_proc_group_is_member(group)) { - int i = 0; - - if (pSync) { - rc = _algorithm_simple(group, - target, - source, - dst, - sst, - nelems, - element_size, - pSync); - } else { - SCOLL_ERROR("Incorrect argument pSync"); - rc = OSHMEM_ERR_BAD_PARAM; - } - - /* Restore initial values */ - SCOLL_VERBOSE(12, - "PE#%d Restore special synchronization array", - group->my_pe); - for (i = 0; pSync && (i < _SHMEM_ALLTOALL_SYNC_SIZE); i++) { - pSync[i] = _SHMEM_SYNC_VALUE; - } + if (!oshmem_proc_group_is_member(group)) { + return OSHMEM_SUCCESS; } - return rc; -} + if (!pSync) { + SCOLL_ERROR("Incorrect argument pSync"); + return OSHMEM_ERR_BAD_PARAM; + } -static int _algorithm_simple(struct oshmem_group_t *group, - void *target, - const void *source, - ptrdiff_t tst, ptrdiff_t sst, - size_t nelems, - size_t element_size, - long *pSync) -{ - int rc = OSHMEM_SUCCESS; - int pe_cur; - int i; - int j; - int k; + if ((sst == 1) && (dst == 1)) { + rc = a2a_alg_simple(group, target, source, nelems, element_size); + } else { + rc = a2as_alg_simple(group, target, source, dst, sst, nelems, + element_size); + } - SCOLL_VERBOSE(14, - "[#%d] send data to all PE in the group", - group->my_pe); - j = oshmem_proc_group_find_id(group, group->my_pe); - for (i = 0; i < group->proc_count; i++) { - /* index permutation for better distribution of traffic */ - k = (((j)+(i))%(group->proc_count)); - pe_cur = oshmem_proc_pe(group->proc_array[k]); - rc = MCA_SPML_CALL(put( - (void *)((char *)target + j * tst * nelems * element_size), - nelems * element_size, - (void *)((char *)source + i * sst * nelems * element_size), - pe_cur)); - if (OSHMEM_SUCCESS != rc) { - break; - } + if (rc != OSHMEM_SUCCESS) { + return rc; } + /* fence (which currently acts as quiet) is needed * because scoll level barrier does not guarantee put completion */ MCA_SPML_CALL(fence()); /* Wait for operation completion */ - if (rc == OSHMEM_SUCCESS) { - SCOLL_VERBOSE(14, "[#%d] Wait for operation completion", group->my_pe); - rc = BARRIER_FUNC(group, - (pSync + 1), - SCOLL_DEFAULT_ALG); + SCOLL_VERBOSE(14, "[#%d] Wait for operation completion", group->my_pe); + rc = BARRIER_FUNC(group, pSync + 1, SCOLL_DEFAULT_ALG); + + /* Restore initial values */ + SCOLL_VERBOSE(12, "PE#%d Restore special synchronization array", + group->my_pe); + + for (i = 0; pSync && (i < _SHMEM_ALLTOALL_SYNC_SIZE); i++) { + pSync[i] = _SHMEM_SYNC_VALUE; } return rc; } + +static inline void * +get_stride_elem(const void *base, ptrdiff_t sst, size_t nelems, size_t elem_size, + int block_idx, int elem_idx) +{ + /* + * j th block starts at: nelems * element_size * sst * j + * offset of the l th element in the block is: element_size * sst * l + */ + return (char *)base + elem_size * sst * (nelems * block_idx + elem_idx); +} + +static inline int +get_dst_pe(struct oshmem_group_t *group, int src_blk_idx, int dst_blk_idx) +{ + int dst_grp_pe; + + /* index permutation for better distribution of traffic */ + dst_grp_pe = (dst_blk_idx + src_blk_idx) % group->proc_count; + + /* convert to the global pe */ + return oshmem_proc_pe(group->proc_array[dst_grp_pe]); +} + +static int a2as_alg_simple(struct oshmem_group_t *group, + void *target, + const void *source, + ptrdiff_t tst, ptrdiff_t sst, + size_t nelems, + size_t element_size) +{ + int rc; + int dst_pe; + int src_blk_idx; + int dst_blk_idx; + size_t elem_idx; + + SCOLL_VERBOSE(14, + "[#%d] send data to all PE in the group", + group->my_pe); + + dst_blk_idx = oshmem_proc_group_find_id(group, group->my_pe); + + for (src_blk_idx = 0; src_blk_idx < group->proc_count; src_blk_idx++) { + + dst_pe = get_dst_pe(group, src_blk_idx, dst_blk_idx); + for (elem_idx = 0; elem_idx < nelems; elem_idx++) { + rc = MCA_SPML_CALL(put( + get_stride_elem(target, tst, nelems, element_size, + dst_blk_idx, elem_idx), + element_size, + get_stride_elem(source, sst, nelems, element_size, + src_blk_idx, elem_idx), + dst_pe)); + if (OSHMEM_SUCCESS != rc) { + return rc; + } + } + } + return OSHMEM_SUCCESS; +} + +static int a2a_alg_simple(struct oshmem_group_t *group, + void *target, + const void *source, + size_t nelems, + size_t element_size) +{ + int rc; + int dst_pe; + int src_blk_idx; + int dst_blk_idx; + void *dst_blk; + + SCOLL_VERBOSE(14, + "[#%d] send data to all PE in the group", + group->my_pe); + + dst_blk_idx = oshmem_proc_group_find_id(group, group->my_pe); + + /* block start at stride 1 first elem */ + dst_blk = get_stride_elem(target, 1, nelems, element_size, dst_blk_idx, 0); + + for (src_blk_idx = 0; src_blk_idx < group->proc_count; src_blk_idx++) { + + dst_pe = get_dst_pe(group, src_blk_idx, dst_blk_idx); + rc = MCA_SPML_CALL(put(dst_blk, + nelems * element_size, + get_stride_elem(source, 1, nelems, + element_size, src_blk_idx, 0), + dst_pe)); + if (OSHMEM_SUCCESS != rc) { + return rc; + } + } + return OSHMEM_SUCCESS; +}