Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 17 additions & 18 deletions toolbelt/sockets.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
#include "sockets.h"

#include <arpa/inet.h>
#include <cstring>
#include <netdb.h>
#include <netinet/in.h>
#include <stdint.h>
#include <stdlib.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>

#include <algorithm>
Expand Down Expand Up @@ -329,14 +331,24 @@ static struct sockaddr_un BuildUnixSocketName(const std::string &pathname) {
// On Linux we can create it in the abstract namespace which doesn't
// consume a pathname.
addr.sun_path[0] = '\0';
memcpy(addr.sun_path + 1, pathname.c_str(), pathname.size());
memcpy(addr.sun_path + 1, pathname.c_str(), std::min(pathname.size(), sizeof(addr.sun_path) - 2));
#else
// Portable uses the file system so it must be a valid path name.
memcpy(addr.sun_path, pathname.c_str(), pathname.size());
memcpy(addr.sun_path, pathname.c_str(), std::min(pathname.size(), sizeof(addr.sun_path) - 1));
#endif
return addr;
}

static std::string ExtractUnixSocketNameString(const struct sockaddr_un &addr, socklen_t addrlen) {
#if defined(__linux__)
auto addr_str_len = strnlen(addr.sun_path + 1, addrlen - offsetof(sockaddr_un, sun_path) - 1);
return std::string(addr.sun_path + 1, addr.sun_path + addr_str_len + 1);
#else
auto addr_str_len = strnlen(addr.sun_path, addrlen - offsetof(sockaddr_un, sun_path));
return std::string(addr.sun_path, addr.sun_path + addr_str_len);
#endif
}

absl::Status UnixSocket::Bind(const std::string &pathname, bool listen) {
struct sockaddr_un addr = BuildUnixSocketName(pathname);

Expand Down Expand Up @@ -381,12 +393,7 @@ absl::StatusOr<UnixSocket> UnixSocket::Accept(co::Coroutine *c) const {
"Failed to obtain bound address for accepted socket: %s",
strerror(errno)));
}
#ifdef __linux__
new_socket.bound_address_ = bound.sun_path + 1;
#else
new_socket.bound_address_ = bound.sun_path;

#endif
new_socket.bound_address_ = ExtractUnixSocketNameString(bound, len);
return new_socket;
}

Expand Down Expand Up @@ -531,11 +538,7 @@ absl::StatusOr<std::string> UnixSocket::GetPeerName() const {
return absl::InternalError(absl::StrFormat(
"Failed to obtain peer address for socket: %s", strerror(errno)));
}
#if defined(__linux__)
return std::string(peer.sun_path + 1);
#else
return std::string(peer.sun_path);
#endif
return ExtractUnixSocketNameString(peer, len);
}

absl::StatusOr<std::string> UnixSocket::LocalAddress() const {
Expand All @@ -550,11 +553,7 @@ absl::StatusOr<std::string> UnixSocket::LocalAddress() const {
return absl::InternalError(absl::StrFormat(
"Failed to obtain local address for socket: %s", strerror(errno)));
}
#if defined(__linux__)
return std::string(local.sun_path + 1);
#else
return std::string(local.sun_path);
#endif
return ExtractUnixSocketNameString(local, len);
}

// Network socket.
Expand Down
32 changes: 6 additions & 26 deletions toolbelt/sockets.h
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ template <typename H> inline H AbslHashValue(H h, const SocketAddress &a) {
// This is a general socket initialized with a file descriptor. Subclasses
// implement the different socket types.
class Socket {
public:
protected:
Socket() = default;
explicit Socket(int fd, bool connected = false)
: fd_(fd), connected_(connected) {}
Expand All @@ -307,6 +307,7 @@ class Socket {
return *this;
}
~Socket() {}
public:

void Close() {
fd_.Close();
Expand Down Expand Up @@ -363,12 +364,6 @@ class UnixSocket : public Socket {
public:
UnixSocket();
explicit UnixSocket(int fd, bool connected = false) : Socket(fd, connected) {}
UnixSocket(UnixSocket &&s) : Socket(std::move(s)) {}
UnixSocket(const UnixSocket &s) = default;
UnixSocket &operator=(const UnixSocket &s) = default;
UnixSocket &operator=(UnixSocket &&s) = default;

~UnixSocket() = default;

absl::Status Bind(const std::string &pathname, bool listen);
absl::Status Connect(const std::string &pathname);
Expand All @@ -391,16 +386,11 @@ class UnixSocket : public Socket {
// A socket for communication across the network. This is the base
// class for UDP and TCP sockets.
class NetworkSocket : public Socket {
public:
protected:
NetworkSocket() = default;
explicit NetworkSocket(int fd, bool connected = false)
: Socket(fd, connected) {}
NetworkSocket(const NetworkSocket &s)
: Socket(s), bound_address_(s.bound_address_) {}
NetworkSocket(NetworkSocket &&s)
: Socket(std::move(s)), bound_address_(std::move(s.bound_address_)) {}
~NetworkSocket() = default;
NetworkSocket &operator=(const NetworkSocket &s) = default;
public:

absl::Status Connect(const InetAddress &addr);

Expand All @@ -419,10 +409,6 @@ class UDPSocket : public NetworkSocket {
UDPSocket();
explicit UDPSocket(int fd, bool connected = false)
: NetworkSocket(fd, connected) {}
UDPSocket(const UDPSocket &) = default;
UDPSocket(UDPSocket &&s) : NetworkSocket(std::move(s)) {}
~UDPSocket() = default;
UDPSocket &operator=(const UDPSocket &s) = default;

absl::Status Bind(const InetAddress &addr);

Expand All @@ -448,10 +434,6 @@ class TCPSocket : public NetworkSocket {
TCPSocket();
explicit TCPSocket(int fd, bool connected = false)
: NetworkSocket(fd, connected) {}
TCPSocket(const TCPSocket &) = default;
TCPSocket(TCPSocket &&s) : NetworkSocket(std::move(s)) {}
~TCPSocket() = default;
TCPSocket &operator=(const TCPSocket &s) = default;

absl::Status Bind(const InetAddress &addr, bool listen);

Expand All @@ -467,10 +449,7 @@ class VirtualStreamSocket : public Socket {
VirtualStreamSocket();
explicit VirtualStreamSocket(int fd, bool connected = false)
: Socket(fd, connected) {}
VirtualStreamSocket(const VirtualStreamSocket &) = default;
VirtualStreamSocket(VirtualStreamSocket &&s) : Socket(std::move(s)) {}
~VirtualStreamSocket() = default;
VirtualStreamSocket &operator=(const VirtualStreamSocket &s) = default;

absl::Status Connect(const VirtualAddress &addr);

absl::Status Bind(const VirtualAddress &addr, bool listen);
Expand All @@ -496,6 +475,7 @@ class StreamSocket {
StreamSocket(StreamSocket &&s) = default;
~StreamSocket() = default;
StreamSocket &operator=(const StreamSocket &s) = default;
StreamSocket &operator=(StreamSocket &&s) = default;

// Binders for TCP, Virtual, and Unix sockets.
//
Expand Down