diff --git a/llvm/include/llvm/Support/raw_ostream.h b/llvm/include/llvm/Support/raw_ostream.h index 1e01eb9ea19c4..7c8d264afeff2 100644 --- a/llvm/include/llvm/Support/raw_ostream.h +++ b/llvm/include/llvm/Support/raw_ostream.h @@ -16,6 +16,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/DataTypes.h" +#include "llvm/Support/Threading.h" #include #include #include @@ -615,6 +616,8 @@ class raw_fd_stream : public raw_fd_ostream { /// immediately destroyed. raw_fd_stream(StringRef Filename, std::error_code &EC); + raw_fd_stream(int fd, bool shouldClose); + /// This reads the \p Size bytes into a buffer pointed by \p Ptr. /// /// \param Ptr The start of the buffer to hold data to be read. @@ -630,6 +633,54 @@ class raw_fd_stream : public raw_fd_ostream { static bool classof(const raw_ostream *OS); }; +//===----------------------------------------------------------------------===// +// Socket Streams +//===----------------------------------------------------------------------===// + +/// A raw stream for sockets reading/writing + +class raw_socket_stream; + +// Make sure that calls to WSAStartup and WSACleanup are balanced. +#ifdef _WIN32 +class WSABalancer { +public: + WSABalancer(); + ~WSABalancer(); +}; +#endif // _WIN32 + +class ListeningSocket { + int FD; + std::string SocketPath; + ListeningSocket(int SocketFD, StringRef SocketPath); +#ifdef _WIN32 + WSABalancer _; +#endif // _WIN32 + +public: + static Expected createUnix( + StringRef SocketPath, + int MaxBacklog = llvm::hardware_concurrency().compute_thread_count()); + Expected> accept(); + ListeningSocket(ListeningSocket &&LS); + ~ListeningSocket(); +}; +class raw_socket_stream : public raw_fd_stream { + uint64_t current_pos() const override { return 0; } +#ifdef _WIN32 + WSABalancer _; +#endif // _WIN32 + +public: + raw_socket_stream(int SocketFD); + /// Create a \p raw_socket_stream connected to the Unix domain socket at \p + /// SocketPath. + static Expected> + createConnectedUnix(StringRef SocketPath); + ~raw_socket_stream(); +}; + //===----------------------------------------------------------------------===// // Output Stream Adaptors //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Support/CMakeLists.txt b/llvm/lib/Support/CMakeLists.txt index b96d62c7a6224..80854d1d09d98 100644 --- a/llvm/lib/Support/CMakeLists.txt +++ b/llvm/lib/Support/CMakeLists.txt @@ -40,7 +40,7 @@ endif() if( MSVC OR MINGW ) # libuuid required for FOLDERID_Profile usage in lib/Support/Windows/Path.inc. # advapi32 required for CryptAcquireContextW in lib/Support/Windows/Path.inc. - set(system_libs ${system_libs} psapi shell32 ole32 uuid advapi32) + set(system_libs ${system_libs} psapi shell32 ole32 uuid advapi32 Ws2_32) elseif( CMAKE_HOST_UNIX ) if( HAVE_LIBRT ) set(system_libs ${system_libs} rt) diff --git a/llvm/lib/Support/raw_ostream.cpp b/llvm/lib/Support/raw_ostream.cpp index 8908e7b6a150c..a8ffd23fc6e94 100644 --- a/llvm/lib/Support/raw_ostream.cpp +++ b/llvm/lib/Support/raw_ostream.cpp @@ -15,6 +15,7 @@ #include "llvm/Config/config.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Duration.h" +#include "llvm/Support/Error.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/Format.h" @@ -23,11 +24,17 @@ #include "llvm/Support/NativeFormatting.h" #include "llvm/Support/Process.h" #include "llvm/Support/Program.h" +#include "llvm/Support/Threading.h" #include #include #include #include +#ifndef _WIN32 +#include +#include +#endif // _WIN32 + // may provide O_BINARY. #if defined(HAVE_FCNTL_H) # include @@ -58,6 +65,13 @@ #include "llvm/Support/ConvertUTF.h" #include "llvm/Support/Signals.h" #include "llvm/Support/Windows/WindowsSupport.h" +// winsock2.h must be included before afunix.h. Briefly turn off clang-format to +// avoid error. +// clang-format off +#include +#include +// clang-format on +#include #endif using namespace llvm; @@ -644,7 +658,7 @@ raw_fd_ostream::raw_fd_ostream(int fd, bool shouldClose, bool unbuffered, // Check if this is a console device. This is not equivalent to isatty. IsWindowsConsole = ::GetFileType((HANDLE)::_get_osfhandle(fd)) == FILE_TYPE_CHAR; -#endif +#endif // _WIN32 // Get the starting position. off_t loc = ::lseek(FD, 0, SEEK_CUR); @@ -928,6 +942,9 @@ raw_fd_stream::raw_fd_stream(StringRef Filename, std::error_code &EC) EC = std::make_error_code(std::errc::invalid_argument); } +raw_fd_stream::raw_fd_stream(int fd, bool shouldClose) + : raw_fd_ostream(fd, shouldClose, false, OStreamKind::OK_FDStream) {} + ssize_t raw_fd_stream::read(char *Ptr, size_t Size) { assert(get_fd() >= 0 && "File already closed."); ssize_t Ret = ::read(get_fd(), (void *)Ptr, Size); @@ -942,6 +959,145 @@ bool raw_fd_stream::classof(const raw_ostream *OS) { return OS->get_kind() == OStreamKind::OK_FDStream; } +//===----------------------------------------------------------------------===// +// raw_socket_stream +//===----------------------------------------------------------------------===// + +#ifdef _WIN32 +WSABalancer::WSABalancer() { + WSADATA WsaData = {0}; + if (WSAStartup(MAKEWORD(2, 2), &WsaData) != 0) { + llvm::report_fatal_error("WSAStartup failed"); + } +} + +WSABalancer::~WSABalancer() { WSACleanup(); } + +#endif // _WIN32 + +static std::error_code getLastSocketErrorCode() { +#ifdef _WIN32 + return std::error_code(::WSAGetLastError(), std::system_category()); +#else + return std::error_code(errno, std::system_category()); +#endif +} + +ListeningSocket::ListeningSocket(int SocketFD, StringRef SocketPath) + : FD(SocketFD), SocketPath(SocketPath) {} + +ListeningSocket::ListeningSocket(ListeningSocket &&LS) + : FD(LS.FD), SocketPath(LS.SocketPath) { + LS.FD = -1; +} + +Expected ListeningSocket::createUnix(StringRef SocketPath, + int MaxBacklog) { + +#ifdef _WIN32 + WSABalancer _; + SOCKET MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0); + if (MaybeWinsocket == INVALID_SOCKET) { +#else + int MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0); + if (MaybeWinsocket == -1) { +#endif + return llvm::make_error(getLastSocketErrorCode(), + "socket create failed"); + } + + struct sockaddr_un Addr; + memset(&Addr, 0, sizeof(Addr)); + Addr.sun_family = AF_UNIX; + strncpy(Addr.sun_path, SocketPath.str().c_str(), sizeof(Addr.sun_path) - 1); + + if (bind(MaybeWinsocket, (struct sockaddr *)&Addr, sizeof(Addr)) == -1) { + std::error_code Err = getLastSocketErrorCode(); + if (Err == std::errc::address_in_use) + ::close(MaybeWinsocket); + return llvm::make_error(Err, "Bind error"); + } + if (listen(MaybeWinsocket, MaxBacklog) == -1) { + return llvm::make_error(getLastSocketErrorCode(), + "Listen error"); + } + int UnixSocket; +#ifdef _WIN32 + UnixSocket = _open_osfhandle(MaybeWinsocket, 0); +#else + UnixSocket = MaybeWinsocket; +#endif // _WIN32 + ListeningSocket ListenSocket(UnixSocket, SocketPath); + return ListenSocket; +} + +Expected> ListeningSocket::accept() { + int AcceptFD; +#ifdef _WIN32 + SOCKET WinServerSock = _get_osfhandle(FD); + SOCKET WinAcceptSock = ::accept(WinServerSock, NULL, NULL); + AcceptFD = _open_osfhandle(WinAcceptSock, 0); +#else + AcceptFD = ::accept(FD, NULL, NULL); +#endif //_WIN32 + if (AcceptFD == -1) + return llvm::make_error(getLastSocketErrorCode(), + "Accept failed"); + return std::make_unique(AcceptFD); +} + +ListeningSocket::~ListeningSocket() { + if (FD == -1) + return; + ::close(FD); + unlink(SocketPath.c_str()); +} + +static Expected GetSocketFD(StringRef SocketPath) { +#ifdef _WIN32 + SOCKET MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0); + if (MaybeWinsocket == INVALID_SOCKET) { +#else + int MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0); + if (MaybeWinsocket == -1) { +#endif // _WIN32 + return llvm::make_error(getLastSocketErrorCode(), + "Create socket failed"); + } + + struct sockaddr_un Addr; + memset(&Addr, 0, sizeof(Addr)); + Addr.sun_family = AF_UNIX; + strncpy(Addr.sun_path, SocketPath.str().c_str(), sizeof(Addr.sun_path) - 1); + + int status = connect(MaybeWinsocket, (struct sockaddr *)&Addr, sizeof(Addr)); + if (status == -1) { + return llvm::make_error(getLastSocketErrorCode(), + "Connect socket failed"); + } +#ifdef _WIN32 + return _open_osfhandle(MaybeWinsocket, 0); +#else + return MaybeWinsocket; +#endif // _WIN32 +} + +raw_socket_stream::raw_socket_stream(int SocketFD) + : raw_fd_stream(SocketFD, true) {} + +Expected> +raw_socket_stream::createConnectedUnix(StringRef SocketPath) { +#ifdef _WIN32 + WSABalancer _; +#endif // _WIN32 + Expected FD = GetSocketFD(SocketPath); + if (!FD) + return FD.takeError(); + return std::make_unique(*FD); +} + +raw_socket_stream::~raw_socket_stream() {} + //===----------------------------------------------------------------------===// // raw_string_ostream //===----------------------------------------------------------------------===// diff --git a/llvm/unittests/Support/CMakeLists.txt b/llvm/unittests/Support/CMakeLists.txt index e1bf793536b68..df35a7b7f3626 100644 --- a/llvm/unittests/Support/CMakeLists.txt +++ b/llvm/unittests/Support/CMakeLists.txt @@ -103,6 +103,7 @@ add_llvm_unittest(SupportTests raw_ostream_test.cpp raw_pwrite_stream_test.cpp raw_sha1_ostream_test.cpp + raw_socket_stream_test.cpp xxhashTest.cpp DEPENDS diff --git a/llvm/unittests/Support/raw_socket_stream_test.cpp b/llvm/unittests/Support/raw_socket_stream_test.cpp new file mode 100644 index 0000000000000..53eb86ae45d29 --- /dev/null +++ b/llvm/unittests/Support/raw_socket_stream_test.cpp @@ -0,0 +1,52 @@ +#include "llvm/ADT/SmallString.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/FileUtilities.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Testing/Support/Error.h" +#include "gtest/gtest.h" +#include +#include +#include + +using namespace llvm; + +namespace { + +TEST(raw_socket_streamTest, CLIENT_TO_SERVER_AND_SERVER_TO_CLIENT) { + SmallString<100> SocketPath; + llvm::sys::fs::createUniquePath("test_raw_socket_stream.sock", SocketPath, + true); + + char Bytes[8]; + + Expected MaybeServerListener = + ListeningSocket::createUnix(SocketPath); + ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded()); + + ListeningSocket ServerListener = std::move(*MaybeServerListener); + + Expected> MaybeClient = + raw_socket_stream::createConnectedUnix(SocketPath); + ASSERT_THAT_EXPECTED(MaybeClient, llvm::Succeeded()); + + raw_socket_stream &Client = **MaybeClient; + + Expected> MaybeServer = + ServerListener.accept(); + ASSERT_THAT_EXPECTED(MaybeServer, llvm::Succeeded()); + + raw_socket_stream &Server = **MaybeServer; + + Client << "01234567"; + Client.flush(); + + ssize_t BytesRead = Server.read(Bytes, 8); + + std::string string(Bytes, 8); + + ASSERT_EQ(8, BytesRead); + ASSERT_EQ("01234567", string); +} +} // namespace \ No newline at end of file