Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 18 additions & 22 deletions sycl/include/sycl/reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1653,15 +1653,13 @@ void initReduLocalAccs(bool Pow2WG, size_t LID, size_t WGSize,
const std::tuple<ReducerT...> &Reducers,
ReduTupleT<ResultT...> Identities,
std::index_sequence<Is...>) {
std::tie(std::get<Is>(LocalAccs)[LID]...) =
std::make_tuple(std::get<Is>(Reducers).MValue...);
((std::get<Is>(LocalAccs)[LID] = std::get<Is>(Reducers).MValue), ...);

// For work-groups, which size is not power of two, local accessors have
// an additional element with index WGSize that is used by the tree-reduction
// algorithm. Initialize those additional elements with identity values here.
if (!Pow2WG)
std::tie(std::get<Is>(LocalAccs)[WGSize]...) =
std::make_tuple(std::get<Is>(Identities)...);
((std::get<Is>(LocalAccs)[WGSize] = std::get<Is>(Identities)), ...);
}

template <typename... LocalAccT, typename... InputAccT, typename... ResultT,
Expand All @@ -1679,28 +1677,26 @@ void initReduLocalAccs(bool UniformPow2WG, size_t LID, size_t GID,
// give any impact into the final partial sums during the tree-reduction
// algorithm work.
if (UniformPow2WG || GID < NWorkItems)
std::tie(std::get<Is>(LocalAccs)[LID]...) =
std::make_tuple(std::get<Is>(InputAccs)[GID]...);
((std::get<Is>(LocalAccs)[LID] = std::get<Is>(InputAccs)[GID]), ...);
else
std::tie(std::get<Is>(LocalAccs)[LID]...) =
std::make_tuple(std::get<Is>(Identities)...);
((std::get<Is>(LocalAccs)[LID] = std::get<Is>(Identities)), ...);

// For work-groups, which size is not power of two, local accessors have
// an additional element with index WGSize that is used by the tree-reduction
// algorithm. Initialize those additional elements with identity values here.
if (!UniformPow2WG)
std::tie(std::get<Is>(LocalAccs)[WGSize]...) =
std::make_tuple(std::get<Is>(Identities)...);
((std::get<Is>(LocalAccs)[WGSize] = std::get<Is>(Identities)), ...);
}

template <typename... LocalAccT, typename... BOPsT, size_t... Is>
void reduceReduLocalAccs(size_t IndexA, size_t IndexB,
ReduTupleT<LocalAccT...> LocalAccs,
ReduTupleT<BOPsT...> BOPs,
std::index_sequence<Is...>) {
std::tie(std::get<Is>(LocalAccs)[IndexA]...) =
std::make_tuple((std::get<Is>(BOPs)(std::get<Is>(LocalAccs)[IndexA],
std::get<Is>(LocalAccs)[IndexB]))...);
auto ProcessOne = [=](auto &LocalAcc, auto &BOp) {
LocalAcc[IndexA] = BOp(LocalAcc[IndexA], LocalAcc[IndexB]);
};
(ProcessOne(std::get<Is>(LocalAccs), std::get<Is>(BOPs)), ...);
}

template <typename... Reductions, typename... OutAccT, typename... LocalAccT,
Expand All @@ -1713,23 +1709,23 @@ void writeReduSumsToOutAccs(
std::index_sequence<Is...>) {
// Add the initial value of user's variable to the final result.
if (IsOneWG)
std::tie(std::get<Is>(LocalAccs)[0]...) = std::make_tuple(std::get<Is>(
BOPs)(std::get<Is>(LocalAccs)[0], IsInitializeToIdentity[Is]
? std::get<Is>(IdentityVals)
: std::get<Is>(OutAccs)[0])...);
((std::get<Is>(LocalAccs)[0] = std::get<Is>(BOPs)(
std::get<Is>(LocalAccs)[0], IsInitializeToIdentity[Is]
? std::get<Is>(IdentityVals)
: std::get<Is>(OutAccs)[0])),
...);

if (Pow2WG) {
// The partial sums for the work-group are stored in 0-th elements of local
// accessors. Simply write those sums to output accessors.
std::tie(std::get<Is>(OutAccs)[OutAccIndex]...) =
std::make_tuple(std::get<Is>(LocalAccs)[0]...);
((std::get<Is>(OutAccs)[OutAccIndex] = std::get<Is>(LocalAccs)[0]), ...);
} else {
// Each of local accessors keeps two partial sums: in 0-th and WGsize-th
// elements. Combine them into final partial sums and write to output
// accessors.
std::tie(std::get<Is>(OutAccs)[OutAccIndex]...) =
std::make_tuple(std::get<Is>(BOPs)(std::get<Is>(LocalAccs)[0],
std::get<Is>(LocalAccs)[WGSize])...);
((std::get<Is>(OutAccs)[OutAccIndex] = std::get<Is>(BOPs)(
std::get<Is>(LocalAccs)[0], std::get<Is>(LocalAccs)[WGSize])),
...);
}
}

Expand Down