diff --git a/Sources/libreprl/libreprl-posix.c b/Sources/libreprl/libreprl-posix.c
index 34905d91d..f523e76eb 100644
--- a/Sources/libreprl/libreprl-posix.c
+++ b/Sources/libreprl/libreprl-posix.c
@@ -23,9 +23,11 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -33,19 +35,44 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
 
+// Compile-time detection of close_range(): on Linux probe for the <linux/close_range.h> header,
+// on FreeBSD check for the CLOSE_RANGE_CLOEXEC flag macro.
+#if defined(__linux__)
+#  ifdef __has_include
+#    if __has_include(<linux/close_range.h>)
+#      define HAS_CLOSE_RANGE_HEADERS
+#    endif
+#  endif
+#elif defined(__FreeBSD__)
+#  ifdef CLOSE_RANGE_CLOEXEC
+#    define HAS_CLOSE_RANGE_HEADERS
+#  endif
+#endif
+
 // Well-known file descriptor numbers for reprl <-> child communication, child process side
+// Make sure you modify reprl_fds[] below if you change these.
 #define REPRL_CHILD_CTRL_IN 100
 #define REPRL_CHILD_CTRL_OUT 101
 #define REPRL_CHILD_DATA_IN 102
 #define REPRL_CHILD_DATA_OUT 103
 
+static const int reprl_fds[] = {
+    REPRL_CHILD_CTRL_IN,
+    REPRL_CHILD_CTRL_OUT,
+    REPRL_CHILD_DATA_IN,
+    REPRL_CHILD_DATA_OUT
+};
+
 /// Maximum timeout in microseconds. Mostly just limited by the fact that the timeout in milliseconds has to fit into a 32-bit integer.
 #define REPRL_MAX_TIMEOUT_IN_MICROSECONDS ((uint64_t)(INT_MAX) * 1000)
 
+#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0]))
+
 static size_t min(size_t x, size_t y) {
     return x < y ? x : y;
 }
@@ -79,6 +106,116 @@ static void free_string_array(char** arr)
     free(arr);
 }
 
+/// Returns getdtablesize(), aborting the process if it reports an error.
+static int getdtablesize_or_crash(void) {
+    const int tablesize = getdtablesize();
+    if (tablesize < 0) {
+        fprintf(stderr, "getdtablesize() failed: %s. This likely means the system is borked.\n",
+                strerror(errno));
+        abort();
+    }
+
+    return tablesize;
+}
+
+/// Returns true if close_range() was detected at compile time and the running kernel is new
+/// enough to provide it (the check below requires Linux 5.9 or later).
+static bool system_supports_close_range(void) {
+#if defined(HAS_CLOSE_RANGE_HEADERS) && defined(__linux__)
+    struct utsname buffer;
+    if (uname(&buffer) != 0) {
+        return false;
+    }
+
+    // Only major.minor are needed, so accept release strings without a patch level as well.
+    int major = 0, minor = 0;
+    if (sscanf(buffer.release, "%d.%d", &major, &minor) != 2) {
+        return false;
+    }
+
+    return major > 5 || (major == 5 && minor >= 9);
+#else
+    // TODO: Technically, FreeBSD does support close_range, but I don't need support for it right
+    // now, so leaving this unimplemented. Don't have a platform to test on.
+    // https://man.freebsd.org/cgi/man.cgi?close_range(2)
+    return false;
+#endif
+}
+
+/// qsort() comparator ordering ints ascending without the overflow risk of plain subtraction.
+static int fd_qsort_compare(const void* a, const void* b) {
+    int fd_a = *(const int*)a;
+    int fd_b = *(const int*)b;
+    return (fd_a > fd_b) - (fd_a < fd_b);
+}
+
+/// Fast path which uses close_range() to close ranges of fds.
+static void close_all_non_reprl_fds_fast(void) {
+#ifdef HAS_CLOSE_RANGE_HEADERS
+    // Unfortunately we cannot trust the reprl_fds array to be sorted since an accidental edit to
+    // reprl_fds could have broken that assumption. So let's create a new sorted array. It's cheap
+    // anyways. The fd table size is appended as the exclusive upper bound of the last range.
+    int sorted_reprl_fds[ARRAY_SIZE(reprl_fds) + 1];
+    memcpy(sorted_reprl_fds, reprl_fds, sizeof(reprl_fds));
+    sorted_reprl_fds[ARRAY_SIZE(reprl_fds)] = getdtablesize_or_crash();
+    qsort(sorted_reprl_fds, ARRAY_SIZE(sorted_reprl_fds), sizeof(int), fd_qsort_compare);
+
+    // Cool, now we will iterate the sorted fds in ranges and close everything in between.
+    // Start at 3 to skip the well-known stdin, stdout, stderr fds. fd 3 itself must be closed, so
+    // it is the start of the first range rather than an entry in the array of fds to keep.
+    int start_fd = 3;
+    for (size_t i = 0; i < ARRAY_SIZE(sorted_reprl_fds); i++) {
+        const int end_fd = sorted_reprl_fds[i];
+        if (start_fd < end_fd) {
+            // Close the range [start_fd, end_fd). close_range() can still fail at runtime (e.g.
+            // when blocked by a sandbox), so fall back to closing each fd individually rather
+            // than leaking the range to the child.
+            if (close_range(start_fd, end_fd - 1, 0) != 0) {
+                for (int fd = start_fd; fd < end_fd; fd++) {
+                    close(fd);
+                }
+            }
+        }
+        // Never move backwards, otherwise an entry below 3 would make us close the stdio fds.
+        if (end_fd >= start_fd) {
+            start_fd = end_fd + 1;
+        }
+    }
+#else
+    fprintf(stderr, "close_all_non_reprl_fds_fast() called on a system which does not support "
+                    "close_range(). This is likely a programming bug.\n");
+    abort();
+#endif
+}
+
+/// Fallback path which makes a close() syscall for each non-REPRL fd.
+static void close_all_non_reprl_fds_slow(void) {
+    const int tablesize = getdtablesize_or_crash();
+
+    for (int i = 3; i < tablesize; i++) {
+        bool is_reprl_fd = false;
+        for (size_t j = 0; j < ARRAY_SIZE(reprl_fds); j++) {
+            if (i == reprl_fds[j]) {
+                is_reprl_fd = true;
+                break;
+            }
+        }
+
+        if (!is_reprl_fd) {
+            close(i);
+        }
+    }
+}
+
+/// Close all file descriptors except the well-known REPRL and stdio fds.
+static void close_all_non_reprl_fds(void) {
+    if (system_supports_close_range()) {
+        close_all_non_reprl_fds_fast();
+    } else {
+        close_all_non_reprl_fds_slow();
+    }
+}
+
 // A unidirectional communication channel for larger amounts of data, up to a maximum size (REPRL_MAX_DATA_SIZE).
 // Implemented as a (RAM-backed) file for which the file descriptor is shared with the child process and which is mapped into our address space.
 struct data_channel {
@@ -250,13 +387,7 @@ static int reprl_spawn_child(struct reprl_context* ctx)
         close(devnull);
 
         // close all other FDs. We try to use FD_CLOEXEC everywhere, but let's be extra sure we don't leak any fds to the child.
-        int tablesize = getdtablesize();
-        for (int i = 3; i < tablesize; i++) {
-            if (i == REPRL_CHILD_CTRL_IN || i == REPRL_CHILD_CTRL_OUT || i == REPRL_CHILD_DATA_IN || i == REPRL_CHILD_DATA_OUT) {
-                continue;
-            }
-            close(i);
-        }
+        close_all_non_reprl_fds();
 
         execve(ctx->argv[0], ctx->argv, ctx->envp);
 