Skip to content

Commit

Permalink
[coll] Implement shutdown for tracker and comm. (#10208)
Browse files Browse the repository at this point in the history
- Force shutdown the tracker.
- Implement shutdown notice for error handling thread in comm.
  • Loading branch information
trivialfis committed Apr 19, 2024
1 parent 8fb05c8 commit 3fbb221
Show file tree
Hide file tree
Showing 24 changed files with 553 additions and 199 deletions.
18 changes: 11 additions & 7 deletions include/xgboost/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1555,7 +1555,7 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);

/**
* @brief Get the arguments needed for running workers. This should be called after
* XGTrackerRun() and XGTrackerWait()
* XGTrackerRun().
*
* @param handle The handle to the tracker.
* @param args The arguments returned as a JSON document.
Expand All @@ -1565,28 +1565,32 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);
XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args);

/**
* @brief Run the tracker.
* @brief Start the tracker. The tracker runs in the background and this function returns
* once the tracker is started.
*
* @param handle The handle to the tracker.
* @param config Unused at the moment, preserved for the future.
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGTrackerRun(TrackerHandle handle);
XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *config);

/**
* @brief Wait for the tracker to finish, should be called after XGTrackerRun().
* @brief Wait for the tracker to finish, should be called after XGTrackerRun(). This
* function will block until the tracker task is finished or timeout is reached.
*
* @param handle The handle to the tracker.
* @param config JSON encoded configuration. No argument is required yet, preserved for
* the future.
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config);
XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config);

/**
* @brief Free a tracker instance. XGTrackerWait() is called internally. If the tracker
* cannot close properly, manual interruption is required.
* @brief Free a tracker instance. This should be called after XGTrackerWaitFor(). If the
* tracker has not been waited for, this function will shut down all connections with
* the tracker, potentially leading to undefined behavior.
*
* @param handle The handle to the tracker.
*
Expand Down
81 changes: 64 additions & 17 deletions include/xgboost/collective/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,21 @@ inline std::int32_t CloseSocket(SocketT fd) {
#endif
}

/**
 * @brief Shut down both directions of a socket.
 *
 * A failed call is still reported as success (0) when the last error is
 * WSANOTINITIALISED on Windows or ENOTCONN on other platforms; every other
 * outcome returns the raw result of the underlying shutdown() call.
 *
 * @param fd The socket handle to shut down.
 *
 * @return 0 on success or for the tolerated errors above, otherwise the value
 *         returned by shutdown().
 */
inline std::int32_t ShutdownSocket(SocketT fd) {
#if defined(_WIN32)
  auto status = shutdown(fd, SD_BOTH);
  // Only query the last error when the call actually failed.
  bool const tolerated = (status != 0) && (LastError() == WSANOTINITIALISED);
#else
  auto status = shutdown(fd, SHUT_RDWR);
  // Only query the last error when the call actually failed.
  bool const tolerated = (status != 0) && (LastError() == ENOTCONN);
#endif
  if (tolerated) {
    return 0;
  }
  return status;
}

inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) {
#ifdef _WIN32
return errsv == WSAEWOULDBLOCK;
Expand Down Expand Up @@ -499,36 +514,49 @@ class TCPSocket {
*/
[[nodiscard]] HandleT const &Handle() const { return handle_; }
/**
* \brief Listen to incoming requests. Should be called after bind.
* @brief Listen to incoming requests. Should be called after bind.
*/
void Listen(std::int32_t backlog = 16) { xgboost_CHECK_SYS_CALL(listen(handle_, backlog), 0); }
[[nodiscard]] Result Listen(std::int32_t backlog = 16) {
  // listen() reports success with 0; anything else is turned into a failure
  // result through system::FailWithCode.
  bool const listening = (listen(handle_, backlog) == 0);
  if (listening) {
    return Success();
  }
  return system::FailWithCode("Failed to listen.");
}
/**
* \brief Bind socket to INADDR_ANY, return the port selected by the OS.
* @brief Bind socket to INADDR_ANY, return the port selected by the OS.
*/
[[nodiscard]] in_port_t BindHost() {
[[nodiscard]] Result BindHost(std::int32_t* p_out) {
// Use int32 instead of in_port_t for consistency. We take port as parameter from
// users using other languages, the port is usually stored and passed around as int.
if (Domain() == SockDomain::kV6) {
auto addr = SockAddrV6::InaddrAny();
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
xgboost_CHECK_SYS_CALL(
bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
return system::FailWithCode("bind failed.");
}

sockaddr_in6 res_addr;
socklen_t addrlen = sizeof(res_addr);
xgboost_CHECK_SYS_CALL(
getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
return ntohs(res_addr.sin6_port);
if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
return system::FailWithCode("getsockname failed.");
}
*p_out = ntohs(res_addr.sin6_port);
} else {
auto addr = SockAddrV4::InaddrAny();
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
xgboost_CHECK_SYS_CALL(
bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
return system::FailWithCode("bind failed.");
}

sockaddr_in res_addr;
socklen_t addrlen = sizeof(res_addr);
xgboost_CHECK_SYS_CALL(
getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
return ntohs(res_addr.sin_port);
if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
return system::FailWithCode("getsockname failed.");
}
*p_out = ntohs(res_addr.sin_port);
}

return Success();
}

[[nodiscard]] auto Port() const {
Expand Down Expand Up @@ -641,13 +669,13 @@ class TCPSocket {
*/
std::size_t Send(StringView str);
/**
* \brief Receive string, format is matched with the Python socket wrapper in RABIT.
* @brief Receive string, format is matched with the Python socket wrapper in RABIT.
*/
std::size_t Recv(std::string *p_str);
[[nodiscard]] Result Recv(std::string *p_str);
/**
* @brief Close the socket, called automatically in destructor if the socket is not closed.
*/
Result Close() {
[[nodiscard]] Result Close() {
if (InvalidSocket() != handle_) {
auto rc = system::CloseSocket(handle_);
#if defined(_WIN32)
Expand All @@ -664,6 +692,25 @@ class TCPSocket {
}
return Success();
}
/**
 * @brief Shut down the socket through system::ShutdownSocket.
 *
 * Nothing is done when the socket is already closed.
 *
 * @return Success when the socket is closed or the shutdown call succeeds;
 *         otherwise the failure produced by system::FailWithCode.
 */
[[nodiscard]] Result Shutdown() {
  if (!this->IsClosed()) {
    auto status = system::ShutdownSocket(this->Handle());
#if defined(_WIN32)
    // Windows cannot shutdown a socket if it's not connected.
    bool const not_connected = (status == -1) && (system::LastError() == WSAENOTCONN);
    if (not_connected) {
      return Success();
    }
#endif
    if (status != 0) {
      return system::FailWithCode("Failed to shutdown socket.");
    }
  }
  return Success();
}

/**
* \brief Create a TCP socket on specified domain.
Expand Down
6 changes: 3 additions & 3 deletions plugin/federated/federated_tracker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,14 @@ Result FederatedTracker::Shutdown() {

[[nodiscard]] Json FederatedTracker::WorkerArgs() const {
auto rc = this->WaitUntilReady();
CHECK(rc.OK()) << rc.Report();
SafeColl(rc);

std::string host;
rc = GetHostAddress(&host);
CHECK(rc.OK());
Json args{Object{}};
args["DMLC_TRACKER_URI"] = String{host};
args["DMLC_TRACKER_PORT"] = this->Port();
args["dmlc_tracker_uri"] = String{host};
args["dmlc_tracker_port"] = this->Port();
return args;
}
} // namespace xgboost::collective
24 changes: 22 additions & 2 deletions rabit/include/rabit/internal/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,24 @@ std::enable_if_t<std::is_integral_v<E>, xgboost::collective::Result> PollError(E
if ((revents & POLLNVAL) != 0) {
return xgboost::system::FailWithCode("Invalid polling request.");
}
if ((revents & POLLHUP) != 0) {
// Excerpt from the Linux manual:
//
// Note that when reading from a channel such as a pipe or a stream socket, this event
// merely indicates that the peer closed its end of the channel. Subsequent reads from
// the channel will return 0 (end of file) only after all outstanding data in the
// channel has been consumed.
//
// We don't usually have a barrier for exiting workers, it's normal to have one end
// exit while the other still reading data.
return xgboost::collective::Success();
}
#if defined(POLLRDHUP)
// Linux only flag
if ((revents & POLLRDHUP) != 0) {
return xgboost::system::FailWithCode("Poll hung up on the other end.");
}
#endif // defined(POLLRDHUP)
return xgboost::collective::Success();
}

Expand Down Expand Up @@ -179,9 +197,11 @@ struct PollHelper {
}
std::int32_t ret = PollImpl(fdset.data(), fdset.size(), timeout);
if (ret == 0) {
return xgboost::collective::Fail("Poll timeout.", std::make_error_code(std::errc::timed_out));
return xgboost::collective::Fail(
"Poll timeout:" + std::to_string(timeout.count()) + " seconds.",
std::make_error_code(std::errc::timed_out));
} else if (ret < 0) {
return xgboost::system::FailWithCode("Poll failed.");
return xgboost::system::FailWithCode("Poll failed, nfds:" + std::to_string(fdset.size()));
}

for (auto& pfd : fdset) {
Expand Down
24 changes: 12 additions & 12 deletions rabit/src/allreduce_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ bool AllreduceBase::Shutdown() {
try {
for (auto &all_link : all_links) {
if (!all_link.sock.IsClosed()) {
all_link.sock.Close();
SafeColl(all_link.sock.Close());
}
}
all_links.clear();
Expand All @@ -146,7 +146,7 @@ bool AllreduceBase::Shutdown() {
LOG(FATAL) << rc.Report();
}
tracker.Send(xgboost::StringView{"shutdown"});
tracker.Close();
SafeColl(tracker.Close());
xgboost::system::SocketFinalize();
return true;
} catch (std::exception const &e) {
Expand All @@ -167,7 +167,7 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {

tracker.Send(xgboost::StringView{"print"});
tracker.Send(xgboost::StringView{msg});
tracker.Close();
SafeColl(tracker.Close());
}

// util to parse data with unit suffix
Expand Down Expand Up @@ -332,15 +332,15 @@ void AllreduceBase::SetParam(const char *name, const char *val) {

auto sock_listen{xgboost::collective::TCPSocket::Create(tracker.Domain())};
// create listening socket
int port = sock_listen.BindHost();
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
sock_listen.Listen();
std::int32_t port{0};
SafeColl(sock_listen.BindHost(&port));
SafeColl(sock_listen.Listen());

// get number of to connect and number of to accept nodes from tracker
int num_conn, num_accept, num_error = 1;
do {
for (auto & all_link : all_links) {
all_link.sock.Close();
SafeColl(all_link.sock.Close());
}
// tracker construct goodset
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
Expand All @@ -352,15 +352,15 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
LinkRecord r;
int hport, hrank;
std::string hname;
tracker.Recv(&hname);
SafeColl(tracker.Recv(&hname));
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
// connect to peer
if (!xgboost::collective::Connect(xgboost::StringView{hname}, hport, connect_retry,
timeout_sec, &r.sock)
.OK()) {
num_error += 1;
r.sock.Close();
SafeColl(r.sock.Close());
continue;
}
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
Expand All @@ -386,7 +386,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
// send back socket listening port to tracker
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14");
// close connection to tracker
tracker.Close();
SafeColl(tracker.Close());

// listen to incoming links
for (int i = 0; i < num_accept; ++i) {
Expand All @@ -408,7 +408,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
}
if (!match) all_links.emplace_back(std::move(r));
}
sock_listen.Close();
SafeColl(sock_listen.Close());

this->parent_index = -1;
// setup tree links and ring structure
Expand Down Expand Up @@ -635,7 +635,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);

if (len == 0) {
links[parent_index].sock.Close();
SafeColl(links[parent_index].sock.Close());
return ReportError(&links[parent_index], kRecvZeroLen);
}
if (len != -1) {
Expand Down
4 changes: 2 additions & 2 deletions rabit/src/allreduce_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ class AllreduceBase : public IEngine {
ssize_t len = sock.Recv(buffer_head + offset, nmax);
// length equals 0, remote disconnected
if (len == 0) {
sock.Close(); return kRecvZeroLen;
SafeColl(sock.Close()); return kRecvZeroLen;
}
if (len == -1) return Errno2Return();
size_read += static_cast<size_t>(len);
Expand All @@ -289,7 +289,7 @@ class AllreduceBase : public IEngine {
ssize_t len = sock.Recv(p + size_read, max_size - size_read);
// length equals 0, remote disconnected
if (len == 0) {
sock.Close(); return kRecvZeroLen;
SafeColl(sock.Close()); return kRecvZeroLen;
}
if (len == -1) return Errno2Return();
size_read += static_cast<size_t>(len);
Expand Down
Loading

0 comments on commit 3fbb221

Please sign in to comment.