Update collective implementation. (#10152)
* Update collective implementation.

- Clean up resources during `Finalize` to avoid handling threads in the destructor.
- Calculate the allgather segment size automatically from the buffer size (see the sketch below).
- Use a simple allgather for small allreduces (payloads smaller than the number of workers).
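
A minimal sketch (an illustration, not the XGBoost API itself) of what the automatic size calculation means: each of the `world` workers owns one contiguous segment of the gathered buffer, so the per-rank segment size can be derived from the total buffer length rather than passed in by the caller.

```cpp
#include <cstddef>  // for size_t
#include <cstdint>  // for int32_t

// Hypothetical helper for illustration: derive the per-rank segment size
// from the total buffer. Rounds down; under this convention the last rank's
// segment absorbs the remainder.
std::size_t SegmentBytes(std::size_t total_bytes, std::int32_t world) {
  return total_bytes / static_cast<std::size_t>(world);
}
// Example: 10 bytes across 3 workers -> segments of 3, 3, and 4 bytes.
```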
trivialfis authored Mar 30, 2024
1 parent 230010d commit 8bad677
Showing 31 changed files with 233 additions and 127 deletions.
10 changes: 3 additions & 7 deletions plugin/federated/federated_coll.cc
@@ -89,19 +89,15 @@ Coll *FederatedColl::MakeCUDAVar() {

[[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
std::int32_t root) {
if (comm.Rank() == root) {
return BroadcastImpl(comm, &this->sequence_number_, data, root);
} else {
return BroadcastImpl(comm, &this->sequence_number_, data, root);
}
return BroadcastImpl(comm, &this->sequence_number_, data, root);
}

[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
std::int64_t size) {
[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data) {
using namespace federated; // NOLINT
auto fed = dynamic_cast<FederatedComm const *>(&comm);
CHECK(fed);
auto stub = fed->Handle();
auto size = data.size_bytes() / comm.World();

auto offset = comm.Rank() * size;
auto segment = data.subspan(offset, size);
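A hedged usage sketch of the new signature: the caller passes only the full buffer, and the implementation derives each rank's segment as `data.size_bytes() / comm.World()`. The names `coll`, `comm`, `local`, and `local_nbytes` below are assumptions for illustration, not part of the commit.

```cpp
#include <algorithm>  // for copy_n
#include <cstdint>    // for int8_t
#include <vector>     // for vector

// Sketch only; assumes the XGBoost collective headers, an initialized
// FederatedColl `coll` and Comm `comm`, and a local payload `local` of
// `local_nbytes` bytes on every rank.
std::vector<std::int8_t> buf(static_cast<std::size_t>(comm.World()) * local_nbytes);
// Each rank writes its own segment before the call; Allgather fills the rest.
std::copy_n(local.data(), local_nbytes, buf.data() + comm.Rank() * local_nbytes);
auto rc = coll.Allgather(comm, common::Span<std::int8_t>{buf.data(), buf.size()});
CHECK(rc.OK()) << rc.Report();
```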
5 changes: 2 additions & 3 deletions plugin/federated/federated_coll.cu
@@ -53,8 +53,7 @@ Coll *FederatedColl::MakeCUDAVar() {
};
}

[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
std::int64_t size) {
[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data) {
auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
CHECK(cufed);
std::vector<std::int8_t> h_data(data.size());
@@ -63,7 +62,7 @@ Coll *FederatedColl::MakeCUDAVar() {
return GetCUDAResult(
cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
} << [&] {
return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()}, size);
return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()});
} << [&] {
return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
cudaMemcpyHostToDevice, cufed->Stream()));
5 changes: 2 additions & 3 deletions plugin/federated/federated_coll.cuh
@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost contributors
* Copyright 2023-2024, XGBoost contributors
*/
#include "../../src/collective/comm.h" // for Comm, Coll
#include "federated_coll.h" // for FederatedColl
@@ -16,8 +16,7 @@ class CUDAFederatedColl : public Coll {
ArrayInterfaceHandler::Type type, Op op) override;
[[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
std::int32_t root) override;
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
std::int64_t size) override;
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data) override;
[[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments,
8 changes: 2 additions & 6 deletions plugin/federated/federated_coll.h
@@ -1,12 +1,9 @@
/**
* Copyright 2023, XGBoost contributors
* Copyright 2023-2024, XGBoost contributors
*/
#pragma once
#include "../../src/collective/coll.h" // for Coll
#include "../../src/collective/comm.h" // for Comm
#include "../../src/common/io.h" // for ReadAll
#include "../../src/common/json_utils.h" // for OptionalArg
#include "xgboost/json.h" // for Json

namespace xgboost::collective {
class FederatedColl : public Coll {
@@ -20,8 +17,7 @@ class FederatedColl : public Coll {
ArrayInterfaceHandler::Type type, Op op) override;
[[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
std::int32_t root) override;
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
std::int64_t) override;
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data) override;
[[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments,
3 changes: 1 addition & 2 deletions plugin/federated/federated_comm.cuh
@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#pragma once

@@ -9,7 +9,6 @@
#include "../../src/common/device_helpers.cuh" // for CUDAStreamView
#include "federated_comm.h" // for FederatedComm
#include "xgboost/context.h" // for Context
#include "xgboost/logging.h"

namespace xgboost::collective {
class CUDAFederatedComm : public FederatedComm {
15 changes: 13 additions & 2 deletions plugin/federated/federated_comm.h
@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost contributors
* Copyright 2023-2024, XGBoost contributors
*/
#pragma once

@@ -11,7 +11,6 @@
#include <string> // for string

#include "../../src/collective/comm.h" // for HostComm
#include "../../src/common/json_utils.h" // for OptionalArg
#include "xgboost/json.h"

namespace xgboost::collective {
@@ -51,6 +50,10 @@ class FederatedComm : public HostComm {
std::int32_t rank) {
this->Init(host, port, world, rank, {}, {}, {});
}
[[nodiscard]] Result Shutdown() final {
this->ResetState();
return Success();
}
~FederatedComm() override { stub_.reset(); }

[[nodiscard]] std::shared_ptr<Channel> Chan(std::int32_t) const override {
@@ -65,5 +68,13 @@
[[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }

[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
/**
* @brief Get a string ID for the current process.
*/
[[nodiscard]] Result ProcessorName(std::string* out) const final {
auto rank = this->Rank();
*out = "rank:" + std::to_string(rank);
return Success();
};
};
} // namespace xgboost::collective
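For illustration, the new `ProcessorName` override simply encodes the rank, so the worker with rank 2 reports `rank:2`. A hedged caller sketch, assuming an initialized `FederatedComm comm`:

```cpp
#include <iostream>  // for cout
#include <string>    // for string

std::string name;
auto rc = comm.ProcessorName(&name);
if (rc.OK()) {
  std::cout << "Running on " << name << '\n';  // e.g. "rank:2"
}
```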
8 changes: 2 additions & 6 deletions plugin/federated/federated_server.h
@@ -1,22 +1,18 @@
/**
* Copyright 2022-2023, XGBoost contributors
* Copyright 2022-2024, XGBoost contributors
*/
#pragma once

#include <federated.old.grpc.pb.h>

#include <cstdint> // for int32_t
#include <future> // for future

#include "../../src/collective/in_memory_handler.h"
#include "../../src/collective/tracker.h" // for Tracker
#include "xgboost/collective/result.h" // for Result

namespace xgboost::federated {
class FederatedService final : public Federated::Service {
public:
explicit FederatedService(std::int32_t world_size)
: handler_{static_cast<std::size_t>(world_size)} {}
explicit FederatedService(std::int32_t world_size) : handler_{world_size} {}

grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
AllgatherReply* reply) override;
3 changes: 1 addition & 2 deletions plugin/federated/federated_tracker.h
@@ -17,8 +17,7 @@ namespace xgboost::collective {
namespace federated {
class FederatedService final : public Federated::Service {
public:
explicit FederatedService(std::int32_t world_size)
: handler_{static_cast<std::size_t>(world_size)} {}
explicit FederatedService(std::int32_t world_size) : handler_{world_size} {}

grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
AllgatherReply* reply) override;
6 changes: 3 additions & 3 deletions src/c_api/c_api.cc
@@ -694,9 +694,9 @@ XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void
common::Span<T>{cast_d_ptr, static_cast<typename common::Span<T>::index_type>(size)},
{size}, DeviceOrd::CPU());
CHECK(t.CContiguous());
Json interface{linalg::ArrayInterface(t)};
CHECK(ArrayInterface<1>{interface}.is_contiguous);
str = Json::Dump(interface);
Json iface{linalg::ArrayInterface(t)};
CHECK(ArrayInterface<1>{iface}.is_contiguous);
str = Json::Dump(iface);
return str;
};

27 changes: 20 additions & 7 deletions src/c_api/coll_c_api.cc
@@ -1,15 +1,15 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include <chrono> // for seconds
#include <cstddef> // for size_t
#include <future> // for future
#include <memory> // for unique_ptr
#include <string> // for string
#include <type_traits> // for is_same_v, remove_pointer_t
#include <utility> // for pair

#include "../collective/tracker.h" // for RabitTracker
#include "../common/timer.h" // for Timer
#include "c_api_error.h" // for API_BEGIN
#include "xgboost/c_api.h"
#include "xgboost/collective/result.h" // for Result
@@ -40,17 +40,27 @@ struct CollAPIEntry {
};
using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;

void WaitImpl(TrackerHandleT *ptr) {
std::chrono::seconds wait_for{100};
void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
constexpr std::int64_t kDft{60};
std::chrono::seconds wait_for{timeout.count() != 0 ? std::min(kDft, timeout.count()) : kDft};

common::Timer timer;
timer.Start();

auto fut = ptr->second;
while (fut.valid()) {
auto res = fut.wait_for(wait_for);
CHECK(res != std::future_status::deferred);

if (res == std::future_status::ready) {
auto const &rc = ptr->second.get();
CHECK(rc.OK()) << rc.Report();
collective::SafeColl(rc);
break;
}

if (timer.Duration() > timeout && timeout.count() != 0) {
collective::SafeColl(collective::Fail("Timeout waiting for the tracker."));
}
}
}
} // namespace
@@ -106,14 +116,17 @@ XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) {
auto *ptr = GetTrackerHandle(handle);
xgboost_CHECK_C_ARG_PTR(config);
auto jconfig = Json::Load(StringView{config});
WaitImpl(ptr);
// Internally, 0 indicates no timeout, which is the default since we don't want to
// interrupt the model training.
auto timeout = OptionalArg<Integer>(jconfig, "timeout", std::int64_t{0});
WaitImpl(ptr, std::chrono::seconds{timeout});
API_END();
}

XGB_DLL int XGTrackerFree(TrackerHandle handle) {
API_BEGIN();
auto *ptr = GetTrackerHandle(handle);
WaitImpl(ptr);
WaitImpl(ptr, ptr->first->Timeout());
delete ptr;
API_END();
}
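
The new wait logic polls the tracker future in bounded slices rather than blocking once: it wakes at most every `kDft` (60) seconds, and once a nonzero timeout has elapsed it fails instead of waiting forever. A standalone sketch of that pattern (not the XGBoost code itself):

```cpp
#include <chrono>     // for seconds, steady_clock
#include <future>     // for future, future_status
#include <stdexcept>  // for runtime_error

// Poll `fut` in bounded slices; a zero timeout means wait indefinitely.
void WaitWithTimeout(std::future<void>& fut, std::chrono::seconds timeout) {
  using Clock = std::chrono::steady_clock;
  constexpr std::chrono::seconds kSlice{60};
  auto const begin = Clock::now();
  while (fut.valid()) {
    if (fut.wait_for(kSlice) == std::future_status::ready) {
      fut.get();  // propagate the tracker's result (or exception)
      return;
    }
    if (timeout.count() != 0 && Clock::now() - begin > timeout) {
      throw std::runtime_error("Timeout waiting for the tracker.");
    }
  }
}
```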
21 changes: 14 additions & 7 deletions src/collective/allgather.cc
@@ -1,12 +1,13 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include "allgather.h"

#include <algorithm> // for min, copy_n, fill_n
#include <cstddef> // for size_t
#include <cstdint> // for int8_t, int32_t, int64_t
#include <memory> // for shared_ptr
#include <utility> // for move

#include "broadcast.h"
#include "comm.h" // for Comm, Channel
@@ -29,16 +30,20 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
auto rc = Success() << [&] {
auto send_rank = (rank + world - r + worker_off) % world;
auto send_off = send_rank * segment_size;
send_off = std::min(send_off, data.size_bytes());
auto send_seg = data.subspan(send_off, std::min(segment_size, data.size_bytes() - send_off));
bool is_last_segment = send_rank == (world - 1);
auto send_nbytes = is_last_segment ? (data.size_bytes() - send_off) : segment_size;
auto send_seg = data.subspan(send_off, send_nbytes);
return next_ch->SendAll(send_seg.data(), send_seg.size_bytes());
} << [&] {
auto recv_rank = (rank + world - r - 1 + worker_off) % world;
auto recv_off = recv_rank * segment_size;
recv_off = std::min(recv_off, data.size_bytes());
auto recv_seg = data.subspan(recv_off, std::min(segment_size, data.size_bytes() - recv_off));
bool is_last_segment = recv_rank == (world - 1);
auto recv_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : segment_size;
auto recv_seg = data.subspan(recv_off, recv_nbytes);
return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
} << [&] { return prev_ch->Block(); };
} << [&] {
return prev_ch->Block();
};
if (!rc.OK()) {
return rc;
}
@@ -91,7 +96,9 @@ namespace detail {
auto recv_size = sizes[recv_rank];
auto recv_seg = erased_result.subspan(recv_off, recv_size);
return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
} << [&] { return prev_ch->Block(); };
} << [&] {
return prev_ch->Block();
};
if (!rc.OK()) {
return rc;
}
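The change above replaces the clamped `std::min` arithmetic with an explicit last-segment rule: every segment is `segment_size` bytes except the one owned by rank `world - 1`, which also carries the round-down remainder. A small sketch of that computation, under the same convention (hypothetical helper, for illustration):

```cpp
#include <cstddef>  // for size_t
#include <cstdint>  // for int32_t
#include <utility>  // for pair

// Offset and byte count of rank `r`'s segment when the last rank absorbs
// the remainder left by rounding down.
std::pair<std::size_t, std::size_t> Segment(std::size_t total_bytes,
                                            std::int32_t world, std::int32_t r) {
  auto seg = total_bytes / static_cast<std::size_t>(world);
  auto off = seg * static_cast<std::size_t>(r);
  auto n_bytes = (r == world - 1) ? (total_bytes - off) : seg;
  return {off, n_bytes};
}
```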
22 changes: 13 additions & 9 deletions src/collective/allgather.h
@@ -1,25 +1,27 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#pragma once
#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include <memory> // for shared_ptr
#include <numeric> // for accumulate
#include <string> // for string
#include <type_traits> // for remove_cv_t
#include <vector> // for vector

#include "../common/type.h" // for EraseType
#include "../common/type.h" // for EraseType
#include "comm.h" // for Comm, Channel
#include "comm_group.h" // for CommGroup
#include "xgboost/collective/result.h" // for Result
#include "xgboost/linalg.h"
#include "xgboost/span.h" // for Span
#include "xgboost/linalg.h" // for MakeVec
#include "xgboost/span.h" // for Span

namespace xgboost::collective {
namespace cpu_impl {
/**
* @param worker_off Segment offset. For example, if the rank 2 worker specifies
* worker_off = 1, then it owns the third segment.
* worker_off = 1, then it owns the third segment (2 + 1).
*/
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data,
std::size_t segment_size, std::int32_t worker_off,
@@ -51,8 +53,10 @@ inline void AllgatherVOffset(common::Span<std::int64_t const> sizes,
} // namespace detail

template <typename T>
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
auto n_bytes = sizeof(T) * size;
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data) {
// This function is also used for ring allreduce, hence we allow the last segment to be
// larger due to round-down.
auto n_bytes_per_segment = data.size_bytes() / comm.World();
auto erased = common::EraseType(data);

auto rank = comm.Rank();
@@ -61,7 +65,7 @@ template <typename T>

auto prev_ch = comm.Chan(prev);
auto next_ch = comm.Chan(next);
auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes, 0, prev_ch, next_ch);
auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes_per_segment, 0, prev_ch, next_ch);
if (!rc.OK()) {
return rc;
}
@@ -76,7 +80,7 @@ template <typename T>

std::vector<std::int64_t> sizes(world, 0);
sizes[rank] = data.size_bytes();
auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()}, 1);
auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()});
if (!rc.OK()) {
return rc;
}
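A hedged sketch of the two-phase pattern visible in the `AllgatherV` hunk above: first ring-allgather the per-rank byte counts (one `int64_t` slot per rank, so no explicit size argument is needed any more), after which every rank can compute each segment's offset by prefix sum. `my_nbytes` is a hypothetical local payload size.

```cpp
#include <cstdint>  // for int64_t
#include <vector>   // for vector

// Sketch; assumes a connected Comm `comm` and this rank's payload size.
std::vector<std::int64_t> sizes(comm.World(), 0);
sizes[comm.Rank()] = my_nbytes;
auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()});
// After the call, sizes[r] holds rank r's payload; prefix-sum for offsets.
```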