Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[coll] Small cleanup to federated comm. #10397

Merged
merged 1 commit into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions plugin/federated/federated_comm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
args.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
auto channel = grpc::CreateCustomChannel(host + ":" + std::to_string(port),
grpc::SslCredentials(options), args);
channel->WaitForConnected(
gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN)));
channel->WaitForConnected(gpr_time_add(
gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(DefaultTimeoutSec(), GPR_TIMESPAN)));
return federated::Federated::NewStub(channel);
}();
}
Expand Down Expand Up @@ -90,8 +90,6 @@ FederatedComm::FederatedComm(std::int32_t retry, std::chrono::seconds timeout, s
auto parsed = common::Split(server_address, ':');
CHECK_EQ(parsed.size(), 2) << "Invalid server address:" << server_address;

CHECK_NE(rank, -1) << "Parameter `federated_rank` is required";
CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required.";
CHECK(!server_address.empty()) << "Parameter `federated_server_address` is required.";

/**
Expand Down
7 changes: 2 additions & 5 deletions plugin/federated/federated_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
#include <federated.grpc.pb.h>
#include <federated.pb.h>

#include <chrono> // for seconds
#include <cstdint> // for int32_t
#include <memory> // for unique_ptr
#include <memory> // for shared_ptr
#include <string> // for string

#include "../../src/collective/comm.h" // for HostComm
Expand Down Expand Up @@ -46,10 +47,6 @@ class FederatedComm : public HostComm {
*/
explicit FederatedComm(std::int32_t retry, std::chrono::seconds timeout, std::string task_id,
Json const& config);
explicit FederatedComm(std::string const& host, std::int32_t port, std::int32_t world,
std::int32_t rank) {
this->Init(host, port, world, rank, {}, {}, {});
}
[[nodiscard]] Result Shutdown() final {
this->ResetState();
return Success();
Expand Down
25 changes: 20 additions & 5 deletions tests/cpp/plugin/federated/test_federated_comm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,35 @@
namespace xgboost::collective {
namespace {
class FederatedCommTest : public SocketTest {};
auto MakeConfig(std::string host, std::int32_t port, std::int32_t world, std::int32_t rank) {
Json config{Object{}};
config["federated_server_address"] = host + ":" + std::to_string(port);
config["federated_world_size"] = Integer{world};
config["federated_rank"] = Integer{rank};
return config;
}
} // namespace

TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) {
auto construct = [] { FederatedComm comm{"localhost", 0, 0, 0}; };
auto config = MakeConfig("localhost", 0, 0, 0);
auto construct = [config] {
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
};
ASSERT_THAT(construct, GMockThrow("Invalid world size"));
}

TEST_F(FederatedCommTest, ThrowOnRankTooSmall) {
auto construct = [] { FederatedComm comm{"localhost", 0, 1, -1}; };
auto config = MakeConfig("localhost", 0, 1, -1);
auto construct = [config] {
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
};
ASSERT_THAT(construct, GMockThrow("Invalid worker rank."));
}

TEST_F(FederatedCommTest, ThrowOnRankTooBig) {
auto construct = [] {
FederatedComm comm{"localhost", 0, 1, 1};
auto config = MakeConfig("localhost", 0, 1, 1);
auto construct = [config] {
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
};
ASSERT_THAT(construct, GMockThrow("Invalid worker rank."));
}
Expand Down Expand Up @@ -68,7 +82,8 @@ TEST_F(FederatedCommTest, GetWorldSizeAndRank) {
}

TEST_F(FederatedCommTest, IsDistributed) {
FederatedComm comm{"localhost", 0, 2, 1};
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "",
MakeConfig("localhost", 0, 2, 1)};
EXPECT_TRUE(comm.IsDistributed());
}

Expand Down
Loading