Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[coll] Implement shutdown for tracker and comm. #10208

Merged
merged 7 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions include/xgboost/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1555,7 +1555,7 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);

/**
* @brief Get the arguments needed for running workers. This should be called after
* XGTrackerRun() and XGTrackerWait()
* XGTrackerRun().
*
* @param handle The handle to the tracker.
* @param args The arguments returned as a JSON document.
Expand All @@ -1565,28 +1565,32 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);
XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args);

/**
* @brief Run the tracker.
* @brief Start the tracker. The tracker runs in the background and this function returns
* once the tracker is started.
*
* @param handle The handle to the tracker.
* @param config Unused at the moment, preserved for the future.
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGTrackerRun(TrackerHandle handle);
XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *config);

/**
* @brief Wait for the tracker to finish, should be called after XGTrackerRun().
* @brief Wait for the tracker to finish, should be called after XGTrackerRun(). This
* function will block until the tracker task is finished or timeout is reached.
*
* @param handle The handle to the tracker.
* @param config JSON encoded configuration. No argument is required yet, preserved for
* the future.
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config);
XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config);

/**
* @brief Free a tracker instance. XGTrackerWait() is called internally. If the tracker
* cannot close properly, manual interruption is required.
* @brief Free a tracker instance. This should be called after XGTrackerWaitFor(). If the
* tracker is not properly waited for, this function will shut down all connections with
* the tracker, potentially leading to undefined behavior.
*
* @param handle The handle to the tracker.
*
Expand Down
81 changes: 64 additions & 17 deletions include/xgboost/collective/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,21 @@ inline std::int32_t CloseSocket(SocketT fd) {
#endif
}

/**
 * @brief Shut down both the read and the write direction of a socket.
 *
 * @param fd The socket descriptor.
 *
 * @return 0 when shutdown() succeeds, or when it fails with an error code that is
 *         tolerated here (WSANOTINITIALISED on Windows, ENOTCONN elsewhere). Otherwise
 *         the return value of shutdown() is forwarded unchanged.
 */
inline std::int32_t ShutdownSocket(SocketT fd) {
#if defined(_WIN32)
  auto const status = shutdown(fd, SD_BOTH);
  auto const tolerated = WSANOTINITIALISED;
#else
  auto const status = shutdown(fd, SHUT_RDWR);
  auto const tolerated = ENOTCONN;
#endif
  if (status == 0) {
    return status;
  }
  // Only consult the last error after a failed call.
  return LastError() == tolerated ? 0 : status;
}

inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) {
#ifdef _WIN32
return errsv == WSAEWOULDBLOCK;
Expand Down Expand Up @@ -499,36 +514,49 @@ class TCPSocket {
*/
[[nodiscard]] HandleT const &Handle() const { return handle_; }
/**
* \brief Listen to incoming requests. Should be called after bind.
* @brief Listen to incoming requests. Should be called after bind.
*/
void Listen(std::int32_t backlog = 16) { xgboost_CHECK_SYS_CALL(listen(handle_, backlog), 0); }
[[nodiscard]] Result Listen(std::int32_t backlog = 16) {
  // Report a failed listen() through Result, together with the system error code.
  auto const status = listen(handle_, backlog);
  if (status == 0) {
    return Success();
  }
  return system::FailWithCode("Failed to listen.");
}
/**
* \brief Bind socket to INADDR_ANY, return the port selected by the OS.
* @brief Bind socket to INADDR_ANY, return the port selected by the OS.
*/
[[nodiscard]] in_port_t BindHost() {
[[nodiscard]] Result BindHost(std::int32_t* p_out) {
// Use int32 instead of in_port_t for consistency. We take port as parameter from
// users using other languages, the port is usually stored and passed around as int.
if (Domain() == SockDomain::kV6) {
auto addr = SockAddrV6::InaddrAny();
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
xgboost_CHECK_SYS_CALL(
bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
return system::FailWithCode("bind failed.");
}

sockaddr_in6 res_addr;
socklen_t addrlen = sizeof(res_addr);
xgboost_CHECK_SYS_CALL(
getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
return ntohs(res_addr.sin6_port);
if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
return system::FailWithCode("getsockname failed.");
}
*p_out = ntohs(res_addr.sin6_port);
} else {
auto addr = SockAddrV4::InaddrAny();
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
xgboost_CHECK_SYS_CALL(
bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
return system::FailWithCode("bind failed.");
}

sockaddr_in res_addr;
socklen_t addrlen = sizeof(res_addr);
xgboost_CHECK_SYS_CALL(
getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
return ntohs(res_addr.sin_port);
if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
return system::FailWithCode("getsockname failed.");
}
*p_out = ntohs(res_addr.sin_port);
}

return Success();
}

[[nodiscard]] auto Port() const {
Expand Down Expand Up @@ -641,13 +669,13 @@ class TCPSocket {
*/
std::size_t Send(StringView str);
/**
* \brief Receive string, format is matched with the Python socket wrapper in RABIT.
* @brief Receive string, format is matched with the Python socket wrapper in RABIT.
*/
std::size_t Recv(std::string *p_str);
[[nodiscard]] Result Recv(std::string *p_str);
/**
* @brief Close the socket, called automatically in destructor if the socket is not closed.
*/
Result Close() {
[[nodiscard]] Result Close() {
if (InvalidSocket() != handle_) {
auto rc = system::CloseSocket(handle_);
#if defined(_WIN32)
Expand All @@ -664,6 +692,25 @@ class TCPSocket {
}
return Success();
}
/**
 * @brief Call shutdown on the socket. A socket that is already closed is treated as a
 *        success and left untouched.
 *
 * @return Success, or a failure carrying the system error code when shutdown fails.
 */
[[nodiscard]] Result Shutdown() {
  if (!this->IsClosed()) {
    auto const status = system::ShutdownSocket(this->Handle());
#if defined(_WIN32)
    // Windows reports an error when shutting down a socket that is not connected;
    // this is not considered a failure.
    if (status == -1 && system::LastError() == WSAENOTCONN) {
      return Success();
    }
#endif
    if (status != 0) {
      return system::FailWithCode("Failed to shutdown socket.");
    }
  }
  return Success();
}

/**
* \brief Create a TCP socket on specified domain.
Expand Down
6 changes: 3 additions & 3 deletions plugin/federated/federated_tracker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,14 @@ Result FederatedTracker::Shutdown() {

[[nodiscard]] Json FederatedTracker::WorkerArgs() const {
auto rc = this->WaitUntilReady();
CHECK(rc.OK()) << rc.Report();
SafeColl(rc);

std::string host;
rc = GetHostAddress(&host);
CHECK(rc.OK());
Json args{Object{}};
args["DMLC_TRACKER_URI"] = String{host};
args["DMLC_TRACKER_PORT"] = this->Port();
args["dmlc_tracker_uri"] = String{host};
args["dmlc_tracker_port"] = this->Port();
return args;
}
} // namespace xgboost::collective
24 changes: 22 additions & 2 deletions rabit/include/rabit/internal/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,24 @@ std::enable_if_t<std::is_integral_v<E>, xgboost::collective::Result> PollError(E
if ((revents & POLLNVAL) != 0) {
return xgboost::system::FailWithCode("Invalid polling request.");
}
if ((revents & POLLHUP) != 0) {
// Excerpt from the Linux manual:
//
// Note that when reading from a channel such as a pipe or a stream socket, this event
// merely indicates that the peer closed its end of the channel. Subsequent reads from
// the channel will return 0 (end of file) only after all outstanding data in the
// channel has been consumed.
//
// We don't usually have a barrier for exiting workers, it's normal to have one end
// exit while the other is still reading data.
return xgboost::collective::Success();
}
#if defined(POLLRDHUP)
// Linux only flag
if ((revents & POLLRDHUP) != 0) {
return xgboost::system::FailWithCode("Poll hung up on the other end.");
}
#endif // defined(POLLRDHUP)
return xgboost::collective::Success();
}

Expand Down Expand Up @@ -179,9 +197,11 @@ struct PollHelper {
}
std::int32_t ret = PollImpl(fdset.data(), fdset.size(), timeout);
if (ret == 0) {
return xgboost::collective::Fail("Poll timeout.", std::make_error_code(std::errc::timed_out));
return xgboost::collective::Fail(
"Poll timeout:" + std::to_string(timeout.count()) + " seconds.",
std::make_error_code(std::errc::timed_out));
} else if (ret < 0) {
return xgboost::system::FailWithCode("Poll failed.");
return xgboost::system::FailWithCode("Poll failed, nfds:" + std::to_string(fdset.size()));
}

for (auto& pfd : fdset) {
Expand Down
24 changes: 12 additions & 12 deletions rabit/src/allreduce_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ bool AllreduceBase::Shutdown() {
try {
for (auto &all_link : all_links) {
if (!all_link.sock.IsClosed()) {
all_link.sock.Close();
SafeColl(all_link.sock.Close());
}
}
all_links.clear();
Expand All @@ -146,7 +146,7 @@ bool AllreduceBase::Shutdown() {
LOG(FATAL) << rc.Report();
}
tracker.Send(xgboost::StringView{"shutdown"});
tracker.Close();
SafeColl(tracker.Close());
xgboost::system::SocketFinalize();
return true;
} catch (std::exception const &e) {
Expand All @@ -167,7 +167,7 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {

tracker.Send(xgboost::StringView{"print"});
tracker.Send(xgboost::StringView{msg});
tracker.Close();
SafeColl(tracker.Close());
}

// util to parse data with unit suffix
Expand Down Expand Up @@ -332,15 +332,15 @@ void AllreduceBase::SetParam(const char *name, const char *val) {

auto sock_listen{xgboost::collective::TCPSocket::Create(tracker.Domain())};
// create listening socket
int port = sock_listen.BindHost();
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
sock_listen.Listen();
std::int32_t port{0};
SafeColl(sock_listen.BindHost(&port));
SafeColl(sock_listen.Listen());

// get number of to connect and number of to accept nodes from tracker
int num_conn, num_accept, num_error = 1;
do {
for (auto & all_link : all_links) {
all_link.sock.Close();
SafeColl(all_link.sock.Close());
}
// tracker construct goodset
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
Expand All @@ -352,15 +352,15 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
LinkRecord r;
int hport, hrank;
std::string hname;
tracker.Recv(&hname);
SafeColl(tracker.Recv(&hname));
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
// connect to peer
if (!xgboost::collective::Connect(xgboost::StringView{hname}, hport, connect_retry,
timeout_sec, &r.sock)
.OK()) {
num_error += 1;
r.sock.Close();
SafeColl(r.sock.Close());
continue;
}
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
Expand All @@ -386,7 +386,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
// send back socket listening port to tracker
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14");
// close connection to tracker
tracker.Close();
SafeColl(tracker.Close());

// listen to incoming links
for (int i = 0; i < num_accept; ++i) {
Expand All @@ -408,7 +408,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
}
if (!match) all_links.emplace_back(std::move(r));
}
sock_listen.Close();
SafeColl(sock_listen.Close());

this->parent_index = -1;
// setup tree links and ring structure
Expand Down Expand Up @@ -635,7 +635,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);

if (len == 0) {
links[parent_index].sock.Close();
SafeColl(links[parent_index].sock.Close());
return ReportError(&links[parent_index], kRecvZeroLen);
}
if (len != -1) {
Expand Down
4 changes: 2 additions & 2 deletions rabit/src/allreduce_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ class AllreduceBase : public IEngine {
ssize_t len = sock.Recv(buffer_head + offset, nmax);
// length equals 0, remote disconnected
if (len == 0) {
sock.Close(); return kRecvZeroLen;
SafeColl(sock.Close()); return kRecvZeroLen;
}
if (len == -1) return Errno2Return();
size_read += static_cast<size_t>(len);
Expand All @@ -289,7 +289,7 @@ class AllreduceBase : public IEngine {
ssize_t len = sock.Recv(p + size_read, max_size - size_read);
// length equals 0, remote disconnected
if (len == 0) {
sock.Close(); return kRecvZeroLen;
SafeColl(sock.Close()); return kRecvZeroLen;
}
if (len == -1) return Errno2Return();
size_read += static_cast<size_t>(len);
Expand Down
Loading
Loading