Skip to content

Commit

Permalink
small allreduce.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Mar 25, 2024
1 parent 9a625aa commit 3f702cf
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/c_api/coll_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <utility> // for pair

#include "../collective/allgather.h" // for Allgather
#include "../collective/allreduce.h" // for Allreduce
#include "../collective/broadcast.h" // for Broadcast
#include "../collective/comm.h" // for GlobalCommGroup
#include "../collective/communicator-inl.h" // for GetRank
Expand Down
3 changes: 1 addition & 2 deletions src/collective/allgather.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include <vector> // for vector

#include "../common/type.h" // for EraseType
#include "allreduce.h" // for Allreduce
#include "comm.h" // for Comm, Channel
#include "comm_group.h" // for CommGroup
#include "xgboost/collective/result.h" // for Result
Expand All @@ -22,7 +21,7 @@ namespace xgboost::collective {
namespace cpu_impl {
/**
* @param worker_off Segment offset. For example, if the rank 2 worker specifies
* worker_off = 1, then it owns the third segment.
* worker_off = 1, then it owns the third segment (2 + 1).
*/
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data,
std::size_t segment_size, std::int32_t worker_off,
Expand Down
39 changes: 39 additions & 0 deletions src/collective/allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,42 @@
#include "xgboost/span.h" // for Span

namespace xgboost::collective::cpu_impl {
namespace {
/**
 * @brief Allreduce implemented as an allgather followed by a local reduction.
 *
 * Each worker copies its input into its own slot (indexed by rank) of a temporary
 * buffer that holds `world` slots. The slots are exchanged with RingAllgather, then
 * slots 1..world-1 are folded into slot 0 with `op`, and slot 0 is copied back into
 * `data`. RingAllreduce dispatches here when the input has fewer elements than workers.
 *
 * @tparam T Element type used to restore the type-erased buffers for the allgather.
 *
 * @param comm Communicator providing the rank, the world size and the channels.
 * @param data Type-erased input, overwritten with the reduced result on success.
 * @param op   Reduction functor invoked as op(slot, accumulator). The result is read
 *             back from the accumulator, so it must write into its second argument.
 *
 * @return The status returned by RingAllgather if it failed, otherwise Success().
 */
template <typename T>
Result RingAllreduceSmall(Comm const& comm, common::Span<std::int8_t> data, Func const& op) {
  auto const rank = comm.Rank();
  auto const world = comm.World();
  auto const n_bytes = data.size_bytes();

  // NOTE(review): neither channel is used below. The lookups are kept because it is not
  // visible from here whether Chan() has side effects (e.g. validating the peer) --
  // confirm and drop them if it does not.
  [[maybe_unused]] auto next_ch = comm.Chan(BootstrapNext(rank, world));
  [[maybe_unused]] auto prev_ch = comm.Chan(BootstrapPrev(rank, world));

  // One slot per worker, zero-initialised.
  std::vector<std::int8_t> buffer(n_bytes * world, 0);
  auto s_buffer = common::Span{buffer.data(), buffer.size()};

  // Place the local input into the slot owned by this worker.
  auto const self_offset = n_bytes * rank;
  auto self = s_buffer.subspan(self_offset, n_bytes);
  std::copy_n(data.data(), n_bytes, self.data());

  // Gather every worker's slot into the local buffer.
  auto typed = common::RestoreType<T>(s_buffer);
  auto rc = RingAllgather(comm, typed, common::RestoreType<T>(data).size());
  if (!rc.OK()) {
    return rc;
  }

  // Slot 0 doubles as the accumulator for the reduction.
  auto first = s_buffer.subspan(0, n_bytes);
  CHECK_EQ(first.size(), data.size());

  for (std::int32_t r = 1; r < world; ++r) {
    auto const slot_offset = n_bytes * r;
    auto slot = s_buffer.subspan(slot_offset, n_bytes);
    op(slot, first);
  }
  std::copy_n(first.data(), first.size(), data.data());

  return Success();
}
}  // namespace

template <typename T>
Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
std::size_t n_bytes_in_seg, Func const& op) {
Expand Down Expand Up @@ -80,6 +116,9 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
auto n = data.size_bytes() / n_bytes_elem;
auto world = comm.World();
auto n_bytes_in_seg = common::DivRoundUp(n, world) * sizeof(T);
if (n < static_cast<decltype(n)>(comm.World())) {
return RingAllreduceSmall<T>(comm, data, op);
}
auto rc = RingScatterReduceTyped<T>(comm, data, n_bytes_in_seg, op);
if (!rc.OK()) {
return rc;
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/collective/test_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ TEST(AllreduceGlobal, Basic) {

TEST(AllreduceGlobal, Small) {
// Test when the data is not large enough to be divided by the number of workers
auto n_workers = 4;
auto n_workers = 8;
TestDistributedGlobal(n_workers, [&]() {
std::uint64_t value{1};
Context ctx;
Expand Down

0 comments on commit 3f702cf

Please sign in to comment.