148 changes: 148 additions & 0 deletions libc/utils/gpu/server/rpc_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
#include "llvmlibc_rpc_server.h"

#include "src/__support/RPC/rpc.h"
#include "src/__support/arg_list.h"
#include "src/stdio/printf_core/converter.h"
#include "src/stdio/printf_core/parser.h"
#include "src/stdio/printf_core/writer.h"

#include "src/stdio/gpu/file.h"
#include <algorithm>
#include <atomic>
#include <cstdio>
#include <cstring>
Expand All @@ -25,13 +31,149 @@
#include <vector>

using namespace LIBC_NAMESPACE;
using namespace LIBC_NAMESPACE::printf_core;

static_assert(sizeof(rpc_buffer_t) == sizeof(rpc::Buffer),
"Buffer size mismatch");

static_assert(RPC_MAXIMUM_PORT_COUNT == rpc::MAX_PORT_COUNT,
"Incorrect maximum port count");

template <uint32_t lane_size> void handle_printf(rpc::Server::Port &port) {
FILE *files[lane_size] = {nullptr};
// Get the appropriate output stream to use.
if (port.get_opcode() == RPC_PRINTF_TO_STREAM)
port.recv([&](rpc::Buffer *buffer, uint32_t id) {
files[id] = reinterpret_cast<FILE *>(buffer->data[0]);
});
else if (port.get_opcode() == RPC_PRINTF_TO_STDOUT)
std::fill(files, files + lane_size, stdout);
else
std::fill(files, files + lane_size, stderr);

uint64_t format_sizes[lane_size] = {0};
void *format[lane_size] = {nullptr};

uint64_t args_sizes[lane_size] = {0};
void *args[lane_size] = {nullptr};

// Recieve the format string and arguments from the client.
port.recv_n(format, format_sizes,
[&](uint64_t size) { return new char[size]; });
port.recv_n(args, args_sizes, [&](uint64_t size) { return new char[size]; });

// Identify any arguments that are actually pointers to strings on the client.
// Additionally we want to determine how much buffer space we need to print.
std::vector<void *> strs_to_copy[lane_size];
int buffer_size[lane_size] = {0};
for (uint32_t lane = 0; lane < lane_size; ++lane) {
if (!format[lane])
continue;

WriteBuffer wb(nullptr, 0);
Writer writer(&wb);

internal::StructArgList printf_args(args[lane], args_sizes[lane]);
Parser<internal::StructArgList> parser(
reinterpret_cast<const char *>(format[lane]), printf_args);

for (FormatSection cur_section = parser.get_next_section();
!cur_section.raw_string.empty();
cur_section = parser.get_next_section()) {
if (cur_section.has_conv && cur_section.conv_name == 's' &&
cur_section.conv_val_ptr) {
strs_to_copy[lane].emplace_back(cur_section.conv_val_ptr);
} else if (cur_section.has_conv) {
// Ignore conversion errors for the first pass.
convert(&writer, cur_section);
} else {
writer.write(cur_section.raw_string);
}
}
buffer_size[lane] = writer.get_chars_written();
}

// Recieve any strings from the client and push them into a buffer.
std::vector<void *> copied_strs[lane_size];
while (std::any_of(std::begin(strs_to_copy), std::end(strs_to_copy),
[](const auto &v) { return !v.empty() && v.back(); })) {
port.send([&](rpc::Buffer *buffer, uint32_t id) {
void *ptr = !strs_to_copy[id].empty() ? strs_to_copy[id].back() : nullptr;
buffer->data[1] = reinterpret_cast<uintptr_t>(ptr);
if (!strs_to_copy[id].empty())
strs_to_copy[id].pop_back();
});
uint64_t str_sizes[lane_size] = {0};
void *strs[lane_size] = {nullptr};
port.recv_n(strs, str_sizes, [](uint64_t size) { return new char[size]; });
for (uint32_t lane = 0; lane < lane_size; ++lane) {
if (!strs[lane])
continue;

copied_strs[lane].emplace_back(strs[lane]);
buffer_size[lane] += str_sizes[lane];
}
}

// Perform the final formatting and printing using the LLVM C library printf.
int results[lane_size] = {0};
std::vector<void *> to_be_deleted;
for (uint32_t lane = 0; lane < lane_size; ++lane) {
if (!format[lane])
continue;

std::unique_ptr<char[]> buffer(new char[buffer_size[lane]]);
WriteBuffer wb(buffer.get(), buffer_size[lane]);
Writer writer(&wb);

internal::StructArgList printf_args(args[lane], args_sizes[lane]);
Parser<internal::StructArgList> parser(
reinterpret_cast<const char *>(format[lane]), printf_args);

// Parse and print the format string using the arguments we copied from
// the client.
int ret = 0;
for (FormatSection cur_section = parser.get_next_section();
!cur_section.raw_string.empty();
cur_section = parser.get_next_section()) {
// If this argument was a string we use the memory buffer we copied from
// the client by replacing the raw pointer with the copied one.
if (cur_section.has_conv && cur_section.conv_name == 's') {
if (!copied_strs[lane].empty()) {
cur_section.conv_val_ptr = copied_strs[lane].back();
to_be_deleted.push_back(copied_strs[lane].back());
copied_strs[lane].pop_back();
} else {
cur_section.conv_val_ptr = nullptr;
}
}
if (cur_section.has_conv) {
ret = convert(&writer, cur_section);
if (ret == -1)
break;
} else {
writer.write(cur_section.raw_string);
}
}

results[lane] =
fwrite(buffer.get(), 1, writer.get_chars_written(), files[lane]);
if (results[lane] != writer.get_chars_written() || ret == -1)
results[lane] = -1;
}

// Send the final return value and signal completion by setting the string
// argument to null.
port.send([&](rpc::Buffer *buffer, uint32_t id) {
buffer->data[0] = static_cast<uint64_t>(results[id]);
buffer->data[1] = reinterpret_cast<uintptr_t>(nullptr);
delete[] reinterpret_cast<char *>(format[id]);
delete[] reinterpret_cast<char *>(args[id]);
});
for (void *ptr : to_be_deleted)
delete[] reinterpret_cast<char *>(ptr);
}

template <uint32_t lane_size>
rpc_status_t handle_server_impl(
rpc::Server &server,
Expand Down Expand Up @@ -195,6 +337,12 @@ rpc_status_t handle_server_impl(
});
break;
}
case RPC_PRINTF_TO_STREAM:
case RPC_PRINTF_TO_STDOUT:
case RPC_PRINTF_TO_STDERR: {
handle_printf<lane_size>(*port);
break;
}
case RPC_NOOP: {
port->recv([](rpc::Buffer *) {});
break;
Expand Down