diff --git a/toolbelt/sockets.cc b/toolbelt/sockets.cc index 6157238..f3a8d69 100644 --- a/toolbelt/sockets.cc +++ b/toolbelt/sockets.cc @@ -5,12 +5,14 @@ #include "sockets.h" #include +#include #include #include #include #include #include #include +#include #include #include @@ -329,14 +331,24 @@ static struct sockaddr_un BuildUnixSocketName(const std::string &pathname) { // On Linux we can create it in the abstract namespace which doesn't // consume a pathname. addr.sun_path[0] = '\0'; - memcpy(addr.sun_path + 1, pathname.c_str(), pathname.size()); + memcpy(addr.sun_path + 1, pathname.c_str(), std::min(pathname.size(), sizeof(addr.sun_path) - 2)); #else // Portable uses the file system so it must be a valid path name. - memcpy(addr.sun_path, pathname.c_str(), pathname.size()); + memcpy(addr.sun_path, pathname.c_str(), std::min(pathname.size(), sizeof(addr.sun_path) - 1)); #endif return addr; } +static std::string ExtractUnixSocketNameString(const struct sockaddr_un &addr, socklen_t addrlen) { +#if defined(__linux__) + auto addr_str_len = strnlen(addr.sun_path + 1, addrlen - offsetof(sockaddr_un, sun_path) - 1); + return std::string(addr.sun_path + 1, addr.sun_path + addr_str_len + 1); +#else + auto addr_str_len = strnlen(addr.sun_path, addrlen - offsetof(sockaddr_un, sun_path)); + return std::string(addr.sun_path, addr.sun_path + addr_str_len); +#endif +} + absl::Status UnixSocket::Bind(const std::string &pathname, bool listen) { struct sockaddr_un addr = BuildUnixSocketName(pathname); @@ -381,12 +393,7 @@ absl::StatusOr UnixSocket::Accept(co::Coroutine *c) const { "Failed to obtain bound address for accepted socket: %s", strerror(errno))); } -#ifdef __linux__ - new_socket.bound_address_ = bound.sun_path + 1; -#else - new_socket.bound_address_ = bound.sun_path; - -#endif + new_socket.bound_address_ = ExtractUnixSocketNameString(bound, len); return new_socket; } @@ -531,11 +538,7 @@ absl::StatusOr UnixSocket::GetPeerName() const { return absl::InternalError(absl::StrFormat( "Failed to obtain peer address for socket: %s", strerror(errno))); } -#if defined(__linux__) - return std::string(peer.sun_path + 1); -#else - return std::string(peer.sun_path); -#endif + return ExtractUnixSocketNameString(peer, len); } absl::StatusOr UnixSocket::LocalAddress() const { @@ -550,11 +553,7 @@ absl::StatusOr UnixSocket::LocalAddress() const { return absl::InternalError(absl::StrFormat( "Failed to obtain local address for socket: %s", strerror(errno))); } -#if defined(__linux__) - return std::string(local.sun_path + 1); -#else - return std::string(local.sun_path); -#endif + return ExtractUnixSocketNameString(local, len); } // Network socket. diff --git a/toolbelt/sockets.h b/toolbelt/sockets.h index d951662..3535518 100644 --- a/toolbelt/sockets.h +++ b/toolbelt/sockets.h @@ -291,7 +291,7 @@ template inline H AbslHashValue(H h, const SocketAddress &a) { // This is a general socket initialized with a file descriptor. Subclasses // implement the different socket types. class Socket { -public: +protected: Socket() = default; explicit Socket(int fd, bool connected = false) : fd_(fd), connected_(connected) {} @@ -307,6 +307,7 @@ class Socket { return *this; } ~Socket() {} +public: void Close() { fd_.Close(); @@ -363,12 +364,6 @@ class UnixSocket : public Socket { public: UnixSocket(); explicit UnixSocket(int fd, bool connected = false) : Socket(fd, connected) {} - UnixSocket(UnixSocket &&s) : Socket(std::move(s)) {} - UnixSocket(const UnixSocket &s) = default; - UnixSocket &operator=(const UnixSocket &s) = default; - UnixSocket &operator=(UnixSocket &&s) = default; - - ~UnixSocket() = default; absl::Status Bind(const std::string &pathname, bool listen); absl::Status Connect(const std::string &pathname); @@ -391,16 +386,11 @@ class UnixSocket : public Socket { // A socket for communication across the network. This is the base // class for UDP and TCP sockets. class NetworkSocket : public Socket { -public: +protected: NetworkSocket() = default; explicit NetworkSocket(int fd, bool connected = false) : Socket(fd, connected) {} - NetworkSocket(const NetworkSocket &s) - : Socket(s), bound_address_(s.bound_address_) {} - NetworkSocket(NetworkSocket &&s) - : Socket(std::move(s)), bound_address_(std::move(s.bound_address_)) {} - ~NetworkSocket() = default; - NetworkSocket &operator=(const NetworkSocket &s) = default; +public: absl::Status Connect(const InetAddress &addr); @@ -419,10 +409,6 @@ class UDPSocket : public NetworkSocket { UDPSocket(); explicit UDPSocket(int fd, bool connected = false) : NetworkSocket(fd, connected) {} - UDPSocket(const UDPSocket &) = default; - UDPSocket(UDPSocket &&s) : NetworkSocket(std::move(s)) {} - ~UDPSocket() = default; - UDPSocket &operator=(const UDPSocket &s) = default; absl::Status Bind(const InetAddress &addr); @@ -448,10 +434,6 @@ class TCPSocket : public NetworkSocket { TCPSocket(); explicit TCPSocket(int fd, bool connected = false) : NetworkSocket(fd, connected) {} - TCPSocket(const TCPSocket &) = default; - TCPSocket(TCPSocket &&s) : NetworkSocket(std::move(s)) {} - ~TCPSocket() = default; - TCPSocket &operator=(const TCPSocket &s) = default; absl::Status Bind(const InetAddress &addr, bool listen); @@ -467,10 +449,7 @@ class VirtualStreamSocket : public Socket { VirtualStreamSocket(); explicit VirtualStreamSocket(int fd, bool connected = false) : Socket(fd, connected) {} - VirtualStreamSocket(const VirtualStreamSocket &) = default; - VirtualStreamSocket(VirtualStreamSocket &&s) : Socket(std::move(s)) {} - ~VirtualStreamSocket() = default; - VirtualStreamSocket &operator=(const VirtualStreamSocket &s) = default; + absl::Status Connect(const VirtualAddress &addr); absl::Status Bind(const VirtualAddress &addr, bool listen); @@ -496,6 +475,7 @@ class StreamSocket { StreamSocket(StreamSocket &&s) = default; ~StreamSocket() = default; StreamSocket &operator=(const StreamSocket &s) = default; + StreamSocket &operator=(StreamSocket &&s) = default; // Binders for TCP, Virtual, and Unix sockets. //