Skip to content

Commit

Permalink
Update to use getaddrinfo and set to allow only from a certain node.
Browse files Browse the repository at this point in the history
  • Loading branch information
coldav committed May 1, 2024
1 parent adf7ac5 commit db1cdd2
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 43 deletions.
2 changes: 1 addition & 1 deletion hal/hal_remote/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@

set(HAL_SOURCE
${CMAKE_CURRENT_SOURCE_DIR}/include/hal_remote/hal_transmitter.h
${CMAKE_CURRENT_SOURCE_DIR}/source/hal_socket_transmitter.cpp
)

if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
list (APPEND HAL_SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/include/hal_remote/hal_socket_transmitter.h)
list (APPEND HAL_SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/source/hal_socket_transmitter.cpp)
endif()

add_library(
Expand Down
38 changes: 29 additions & 9 deletions hal/hal_remote/include/hal_remote/hal_socket_transmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,29 +20,39 @@
#include <hal_remote/hal_transmitter.h>
#include <netinet/in.h>

#include <string>

namespace hal {

/// @brief A very simple socket based version of a hal_transmitter
/// @note This supports both client and server mode, and the sure should use
/// start_server() or make_connection() as appropriate. This supports the
/// required port being 0 for a server allows it to find a free port. For
/// some operations error_code enum will be used, but for `send` and `receive`
/// start_server() or make_connection() as appropriate. This does not support
/// the required port being 0 to allow a server to find a free port. For some
/// operations the `error_code` enum will be used, but for `send` and `receive`
/// these are derived functions, so get_last_error() can be used.
///
/// We highly recommend using port forwarding and a user process with this to
/// reduce any security risk.
class hal_socket_transmitter : public hal_transmitter {
public:
hal_socket_transmitter(uint16_t port = 0) : port_requested(port) {}
/// @note the default port allows us to create the
hal_socket_transmitter(uint16_t port = 0, const char *node = "127.0.0.1")
: port_requested(port), node(node) {}
~hal_socket_transmitter();

enum error_code {
success,
socket_failed,
port_0_requested,
bind_failed,
connect_failed,
connection_closed,
listen_failed,
accept_failed,
send_error,
recv_error,
getsockname_failed
getsockname_failed,
getaddrinfo_failed
};

/// @brief set port we wish to request on. This must be done before any calls
Expand All @@ -52,6 +62,12 @@ class hal_socket_transmitter : public hal_transmitter {
/// in some cases to default it initially and then set it later.
void set_port(uint16_t port) { port_requested = port; }

/// @brief set node we wish to limit connections from.
/// @param node_in we wish to limit connections from
/// @note This duplicates the constructor argument but makes it easier
/// in some cases to default it initially and then set it later.
void set_node(const char *node_in) { node = node_in; }

/// @brief Start the server end
/// @param print_port optionally print out that we are listening on a
/// particular port
Expand Down Expand Up @@ -84,7 +100,10 @@ class hal_socket_transmitter : public hal_transmitter {
bool receive(void *data, uint32_t size) override;

/// @brief Send `size` bytes of `data` with an optional flush
bool send(const void *data, uint32_t size, bool flush);
bool send(const void *data, uint32_t size, bool flush) override;

/// @brief Attempt to shut the connection down gracefully.
void shutdown();

private:
/// @brief connect to the remote server
Expand All @@ -106,12 +125,13 @@ class hal_socket_transmitter : public hal_transmitter {

uint16_t port_requested;
uint16_t current_port = 0;
sockaddr_in server_address;
int sock = 0;
int fd_to_use = 0;
sockaddr server_address;
int sock = -1;
int fd_to_use = -1;
bool setup_connection_done = false;
error_code last_error = error_code::success;
bool is_connected = false;
std::string node;
};
} // namespace hal
#endif
125 changes: 92 additions & 33 deletions hal/hal_remote/source/hal_socket_transmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,27 @@

#include <hal_remote/hal_socket_transmitter.h>
#include <hal_remote/hal_transmitter.h>
#include <netdb.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <cstdio>
#include <cstring>
#include <string>

namespace hal {
hal_socket_transmitter::~hal_socket_transmitter() {
if (sock) {
close(sock);
}
if (fd_to_use && fd_to_use != sock) {
close(fd_to_use);
}
shutdown();
}

hal_socket_transmitter::error_code hal_socket_transmitter::start_server(
bool print_port) {
if (!setup_connection_done) {
hal_socket_transmitter::error_code res = setup_connection(true);
if (res != 0) {
const hal_socket_transmitter::error_code res = setup_connection(true);
if (res != hal_socket_transmitter::success) {
last_error = res;
return res;
}
setup_connection_done = true;
Expand All @@ -45,8 +46,8 @@ hal_socket_transmitter::error_code hal_socket_transmitter::start_server(
return last_error;
}
if (print_port) {
printf("Listening on port %d\n", get_port());
fflush(stdout);
(void)printf("Listening on port %d\n", get_port());
(void)fflush(stdout);
}
if (accept() == -1) {
last_error = hal_socket_transmitter::accept_failed;
Expand Down Expand Up @@ -76,12 +77,12 @@ bool hal_socket_transmitter::receive(void *data, uint32_t size) {

// repeatedly recv until `size` bytes is read.
do {
int res = recv(fd_to_use, ((char *)data) + offset, data_to_read, 0);
const int res = recv(fd_to_use, ((char *)data) + offset, data_to_read, 0);
// If recv returns 0, this indicates the connection has been dropped
// It's not an error as such but we are not able to continue
if (res == 0) {
is_connected = false;
fd_to_use = 0;
fd_to_use = -1;
last_error = hal_socket_transmitter::connection_closed;
return false;
}
Expand All @@ -108,8 +109,8 @@ bool hal_socket_transmitter::send(const void *data, uint32_t size, bool flush) {
}

hal_socket_transmitter::error_code hal_socket_transmitter::connect() {
int res = ::connect(sock, (struct sockaddr *)&server_address,
sizeof(server_address));
const int res = ::connect(sock, (struct sockaddr *)&server_address,
sizeof(server_address));
if (res != -1) {
fd_to_use = sock;
is_connected = true;
Expand All @@ -121,43 +122,101 @@ hal_socket_transmitter::error_code hal_socket_transmitter::connect() {

hal_socket_transmitter::error_code hal_socket_transmitter::setup_connection(
bool server) {
// @return 0 if success, otherwise -1 and errno set
sock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
server_address.sin_family = AF_INET;
server_address.sin_port = htons(port_requested);
server_address.sin_addr.s_addr = INADDR_ANY;
if (server) {
if (int res = bind(sock, (struct sockaddr *)&server_address,
sizeof(server_address)) != 0) {
return hal_socket_transmitter::bind_failed;
// don't support port 0
if (port_requested == 0) {
if (debug_enabled()) {
(void)fprintf(stderr, "port 0 requested. This is disallowed\n");
}
return hal_socket_transmitter::port_0_requested;
}
sock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);

struct sockaddr_in local_addr;
socklen_t len = sizeof(local_addr);
if (int res =
getsockname(sock, (struct sockaddr *)&local_addr, &len) != 0) {
return hal_socket_transmitter::getsockname_failed;
struct addrinfo hints;
struct addrinfo *result;
int sfd, s;
std::memset(&hints, 0, sizeof(struct addrinfo));
hints.ai_family = AF_UNSPEC; /* Allow IPv4 or IPv6 */
hints.ai_socktype = SOCK_STREAM;
hints.ai_flags = AI_NUMERICSERV;
hints.ai_protocol = IPPROTO_TCP;

const std::string port_str = std::to_string(port_requested);
s = getaddrinfo(node.c_str(), port_str.c_str(), &hints, &result);
// Use getaddrinfo on the node and port. This may return one or
// more entries we can bind to in priority order. We take the first
// one that matches.
if (s != 0) {
if (debug_enabled()) {
(void)fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(s));
}
current_port = ntohs(local_addr.sin_port);
return hal_socket_transmitter::getaddrinfo_failed;
} else {
current_port = port_requested;
bool bind_failed = false;
bool socket_failed = false;
for (auto *rp = result; rp != NULL; rp = rp->ai_next) {
bind_failed = false;
socket_failed = false;
sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
if (sock == -1) {
socket_failed = true;
continue;
}
server_address = *rp->ai_addr;
if (server) {
if (const int res = bind(sock, rp->ai_addr, rp->ai_addrlen) != 0) {
close(sock);
sock = -1;
bind_failed = true;
continue;
}

struct sockaddr_in local_addr;
socklen_t len = sizeof(local_addr);
if (const int res =
getsockname(sock, (struct sockaddr *)&local_addr, &len) != 0) {
return hal_socket_transmitter::getsockname_failed;
}
current_port = ntohs(local_addr.sin_port);
}
}
if (sock == -1) {
if (bind_failed) {
return hal_socket_transmitter::bind_failed;
} else if (socket_failed) {
return hal_socket_transmitter::socket_failed;
}
}
if (!server) {
current_port = port_requested;
}
}
setup_connection_done = true;
return hal_socket_transmitter::success;
}

int hal_socket_transmitter::listen() {
if (int res = ::listen(sock, 1) != 0) {
if (const int res = ::listen(sock, 1) != 0) {
return res;
}
return 0;
}
int hal_socket_transmitter::accept() {
int res = ::accept(sock, nullptr, nullptr);
const int res = ::accept(sock, nullptr, nullptr);
if (res != -1) {
fd_to_use = res;
}
return res;
}

void hal_socket_transmitter::shutdown() {
if (fd_to_use != -1) {
close(fd_to_use);
}
if (sock != -1 && sock != fd_to_use) {
close(sock);
}
fd_to_use = -1;
sock = -1;
}

} // namespace hal

0 comments on commit db1cdd2

Please sign in to comment.