6 changes: 4 additions & 2 deletions libc/test/integration/startup/gpu/rpc_stream_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ static void test_stream() {

inline_memcpy(send_ptr, str, send_size);
ASSERT_TRUE(inline_memcmp(send_ptr, str, send_size) == 0 && "Data mismatch");
rpc::Client::Port port = rpc::client.open<RPC_TEST_STREAM>();
LIBC_NAMESPACE::rpc::Client::Port port =
LIBC_NAMESPACE::rpc::client.open<RPC_TEST_STREAM>();
port.send_n(send_ptr, send_size);
port.recv_n(&recv_ptr, &recv_size,
[](uint64_t size) { return malloc(size); });
Expand Down Expand Up @@ -77,7 +78,8 @@ static void test_divergent() {
inline_memcpy(buffer, &data[offset], offset);
ASSERT_TRUE(inline_memcmp(buffer, &data[offset], offset) == 0 &&
"Data mismatch");
rpc::Client::Port port = rpc::client.open<RPC_TEST_STREAM>();
LIBC_NAMESPACE::rpc::Client::Port port =
LIBC_NAMESPACE::rpc::client.open<RPC_TEST_STREAM>();
port.send_n(buffer, offset);
inline_memset(buffer, offset, 0);
port.recv_n(&recv_ptr, &recv_size, [&](uint64_t) { return buffer; });
Expand Down
14 changes: 9 additions & 5 deletions libc/test/integration/startup/gpu/rpc_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ static void test_add_simple() {
10 + 10 * gpu::get_thread_id() + 10 * gpu::get_block_id();
uint64_t cnt = 0;
for (uint32_t i = 0; i < num_additions; ++i) {
rpc::Client::Port port = rpc::client.open<RPC_TEST_INCREMENT>();
LIBC_NAMESPACE::rpc::Client::Port port =
LIBC_NAMESPACE::rpc::client.open<RPC_TEST_INCREMENT>();
port.send_and_recv(
[=](rpc::Buffer *buffer, uint32_t) {
[=](LIBC_NAMESPACE::rpc::Buffer *buffer, uint32_t) {
reinterpret_cast<uint64_t *>(buffer->data)[0] = cnt;
},
[&](rpc::Buffer *buffer, uint32_t) {
[&](LIBC_NAMESPACE::rpc::Buffer *buffer, uint32_t) {
cnt = reinterpret_cast<uint64_t *>(buffer->data)[0];
});
port.close();
Expand All @@ -33,8 +34,11 @@ static void test_add_simple() {

// Test to ensure that the RPC mechanism doesn't hang on divergence.
static void test_noop(uint8_t data) {
rpc::Client::Port port = rpc::client.open<RPC_NOOP>();
port.send([=](rpc::Buffer *buffer, uint32_t) { buffer->data[0] = data; });
LIBC_NAMESPACE::rpc::Client::Port port =
LIBC_NAMESPACE::rpc::client.open<RPC_NOOP>();
port.send([=](LIBC_NAMESPACE::rpc::Buffer *buffer, uint32_t) {
buffer->data[0] = data;
});
port.close();
}

Expand Down
52 changes: 36 additions & 16 deletions libc/utils/gpu/server/rpc_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@
// Make sure these are included first so they don't conflict with the system.
#include <limits.h>

#include "shared/rpc.h"

#include "llvmlibc_rpc_server.h"

#include "src/__support/RPC/rpc.h"
#include "include/llvm-libc-types/rpc_opcodes_t.h"
#include "src/__support/arg_list.h"
#include "src/stdio/printf_core/converter.h"
#include "src/stdio/printf_core/parser.h"
#include "src/stdio/printf_core/writer.h"

#include "src/stdio/gpu/file.h"
#include <algorithm>
#include <atomic>
#include <cstdio>
Expand Down Expand Up @@ -53,6 +54,26 @@ struct TempStorage {
};
} // namespace

enum Stream {
File = 0,
Stdin = 1,
Stdout = 2,
Stderr = 3,
};

// Get the associated stream out of an encoded number.
LIBC_INLINE ::FILE *to_stream(uintptr_t f) {
::FILE *stream = reinterpret_cast<FILE *>(f & ~0x3ull);
Stream type = static_cast<Stream>(f & 0x3ull);
if (type == Stdin)
return stdin;
if (type == Stdout)
return stdout;
if (type == Stderr)
return stderr;
return stream;
}

template <bool packed, uint32_t lane_size>
static void handle_printf(rpc::Server::Port &port, TempStorage &temp_storage) {
FILE *files[lane_size] = {nullptr};
Expand Down Expand Up @@ -260,7 +281,7 @@ rpc_status_t handle_server_impl(
port->recv([&](rpc::Buffer *buffer, uint32_t id) {
data[id] = temp_storage.alloc(buffer->data[0]);
sizes[id] =
fread(data[id], 1, buffer->data[0], file::to_stream(buffer->data[1]));
fread(data[id], 1, buffer->data[0], to_stream(buffer->data[1]));
});
port->send_n(data, sizes);
port->send([&](rpc::Buffer *buffer, uint32_t id) {
Expand All @@ -273,9 +294,8 @@ rpc_status_t handle_server_impl(
void *data[lane_size] = {nullptr};
port->recv([&](rpc::Buffer *buffer, uint32_t id) {
data[id] = temp_storage.alloc(buffer->data[0]);
const char *str =
fgets(reinterpret_cast<char *>(data[id]), buffer->data[0],
file::to_stream(buffer->data[1]));
const char *str = fgets(reinterpret_cast<char *>(data[id]),
buffer->data[0], to_stream(buffer->data[1]));
sizes[id] = !str ? 0 : std::strlen(str) + 1;
});
port->send_n(data, sizes);
Expand Down Expand Up @@ -335,46 +355,46 @@ rpc_status_t handle_server_impl(
}
case RPC_FEOF: {
port->recv_and_send([](rpc::Buffer *buffer, uint32_t) {
buffer->data[0] = feof(file::to_stream(buffer->data[0]));
buffer->data[0] = feof(to_stream(buffer->data[0]));
});
break;
}
case RPC_FERROR: {
port->recv_and_send([](rpc::Buffer *buffer, uint32_t) {
buffer->data[0] = ferror(file::to_stream(buffer->data[0]));
buffer->data[0] = ferror(to_stream(buffer->data[0]));
});
break;
}
case RPC_CLEARERR: {
port->recv_and_send([](rpc::Buffer *buffer, uint32_t) {
clearerr(file::to_stream(buffer->data[0]));
clearerr(to_stream(buffer->data[0]));
});
break;
}
case RPC_FSEEK: {
port->recv_and_send([](rpc::Buffer *buffer, uint32_t) {
buffer->data[0] = fseek(file::to_stream(buffer->data[0]),
static_cast<long>(buffer->data[1]),
static_cast<int>(buffer->data[2]));
buffer->data[0] =
fseek(to_stream(buffer->data[0]), static_cast<long>(buffer->data[1]),
static_cast<int>(buffer->data[2]));
});
break;
}
case RPC_FTELL: {
port->recv_and_send([](rpc::Buffer *buffer, uint32_t) {
buffer->data[0] = ftell(file::to_stream(buffer->data[0]));
buffer->data[0] = ftell(to_stream(buffer->data[0]));
});
break;
}
case RPC_FFLUSH: {
port->recv_and_send([](rpc::Buffer *buffer, uint32_t) {
buffer->data[0] = fflush(file::to_stream(buffer->data[0]));
buffer->data[0] = fflush(to_stream(buffer->data[0]));
});
break;
}
case RPC_UNGETC: {
port->recv_and_send([](rpc::Buffer *buffer, uint32_t) {
buffer->data[0] = ungetc(static_cast<int>(buffer->data[0]),
file::to_stream(buffer->data[1]));
buffer->data[0] =
ungetc(static_cast<int>(buffer->data[0]), to_stream(buffer->data[1]));
});
break;
}
Expand Down