86 changes: 86 additions & 0 deletions libc/src/stdio/gpu/vfprintf_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
//===--- GPU helper functions for printf using RPC ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "src/__support/RPC/rpc_client.h"
#include "src/__support/arg_list.h"
#include "src/stdio/gpu/file.h"
#include "src/string/string_utils.h"

#include <stdio.h>

namespace LIBC_NAMESPACE {

template <uint16_t opcode>
LIBC_INLINE int vfprintf_impl(::FILE *__restrict file,
const char *__restrict format, size_t format_size,
va_list vlist) {
uint64_t mask = gpu::get_lane_mask();
rpc::Client::Port port = rpc::client.open<opcode>();

if constexpr (opcode == RPC_PRINTF_TO_STREAM ||
opcode == RPC_PRINTF_TO_STREAM_PACKED) {
port.send([&](rpc::Buffer *buffer) {
buffer->data[0] = reinterpret_cast<uintptr_t>(file);
});
}

size_t args_size = 0;
port.send_n(format, format_size);
port.recv([&](rpc::Buffer *buffer) {
args_size = static_cast<size_t>(buffer->data[0]);
});
port.send_n(vlist, args_size);

uint32_t ret = 0;
for (;;) {
const char *str = nullptr;
port.recv([&](rpc::Buffer *buffer) {
ret = static_cast<uint32_t>(buffer->data[0]);
str = reinterpret_cast<const char *>(buffer->data[1]);
});
// If any lanes have a string argument it needs to be copied back.
if (!gpu::ballot(mask, str))
break;

uint64_t size = str ? internal::string_length(str) + 1 : 0;
port.send_n(str, size);
}

port.close();
return ret;
}

LIBC_INLINE int vfprintf_internal(::FILE *__restrict stream,
const char *__restrict format,
size_t format_size, va_list vlist) {
// The AMDPGU backend uses a packed struct for its varargs. We pass it as a
// separate opcode so the server knows how much to advance the pointers.
#if defined(LIBC_TARGET_ARCH_IS_AMDGPU)
if (stream == stdout)
return vfprintf_impl<RPC_PRINTF_TO_STDOUT_PACKED>(stream, format,
format_size, vlist);
else if (stream == stderr)
return vfprintf_impl<RPC_PRINTF_TO_STDERR_PACKED>(stream, format,
format_size, vlist);
else
return vfprintf_impl<RPC_PRINTF_TO_STREAM_PACKED>(stream, format,
format_size, vlist);
#else
if (stream == stdout)
return vfprintf_impl<RPC_PRINTF_TO_STDOUT>(stream, format, format_size,
vlist);
else if (stream == stderr)
return vfprintf_impl<RPC_PRINTF_TO_STDERR>(stream, format, format_size,
vlist);
else
return vfprintf_impl<RPC_PRINTF_TO_STREAM>(stream, format, format_size,
vlist);
#endif
}

} // namespace LIBC_NAMESPACE
27 changes: 27 additions & 0 deletions libc/src/stdio/gpu/vprintf.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//===-- GPU Implementation of vprintf -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "src/stdio/vprintf.h"

#include "src/__support/CPP/string_view.h"
#include "src/__support/arg_list.h"
#include "src/errno/libc_errno.h"
#include "src/stdio/gpu/vfprintf_utils.h"

#include <stdio.h>

namespace LIBC_NAMESPACE {

LLVM_LIBC_FUNCTION(int, vprintf,
(const char *__restrict format, va_list vlist)) {
cpp::string_view str_view(format);
int ret_val = vfprintf_internal(stdout, format, str_view.size() + 1, vlist);
return ret_val;
}

} // namespace LIBC_NAMESPACE
4 changes: 2 additions & 2 deletions libc/test/integration/src/stdio/gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ add_integration_test(
SUITE
stdio-gpu-integration-tests
SRCS
printf.cpp
printf_test.cpp
DEPENDS
libc.src.gpu.rpc_fprintf
libc.src.stdio.fprintf
libc.src.stdio.fopen
LOADER_ARGS
--threads 32
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
#include "test/IntegrationTest/test.h"

#include "src/__support/GPU/utils.h"
#include "src/gpu/rpc_fprintf.h"
#include "src/stdio/fopen.h"
#include "src/stdio/fprintf.h"

using namespace LIBC_NAMESPACE;

Expand All @@ -20,68 +20,48 @@ TEST_MAIN(int argc, char **argv, char **envp) {
ASSERT_TRUE(file && "failed to open file");
// Check basic printing.
int written = 0;
written = LIBC_NAMESPACE::rpc_fprintf(file, "A simple string\n", nullptr, 0);
written = LIBC_NAMESPACE::fprintf(file, "A simple string\n");
ASSERT_EQ(written, 16);

const char *str = "A simple string\n";
written = LIBC_NAMESPACE::rpc_fprintf(file, "%s", &str, sizeof(void *));
written = LIBC_NAMESPACE::fprintf(file, "%s", str);
ASSERT_EQ(written, 16);

// Check printing a different value with each thread.
uint64_t thread_id = gpu::get_thread_id();
written = LIBC_NAMESPACE::rpc_fprintf(file, "%8ld\n", &thread_id,
sizeof(thread_id));
written = LIBC_NAMESPACE::fprintf(file, "%8ld\n", thread_id);
ASSERT_EQ(written, 9);

struct {
uint32_t x = 1;
char c = 'c';
double f = 1.0;
} args1;
written =
LIBC_NAMESPACE::rpc_fprintf(file, "%d%c%.1f\n", &args1, sizeof(args1));
written = LIBC_NAMESPACE::fprintf(file, "%d%c%.1f\n", 1, 'c', 1.0);
ASSERT_EQ(written, 6);

struct {
uint32_t x = 1;
const char *str = "A simple string\n";
} args2;
written =
LIBC_NAMESPACE::rpc_fprintf(file, "%032b%s\n", &args2, sizeof(args2));
written = LIBC_NAMESPACE::fprintf(file, "%032b%s\n", 1, "A simple string\n");
ASSERT_EQ(written, 49);

// Check that the server correctly handles divergent numbers of arguments.
const char *format = gpu::get_thread_id() % 2 ? "%s" : "%20ld\n";
written = LIBC_NAMESPACE::rpc_fprintf(file, format, &str, sizeof(void *));
written = LIBC_NAMESPACE::fprintf(file, format, str);
ASSERT_EQ(written, gpu::get_thread_id() % 2 ? 16 : 21);

format = gpu::get_thread_id() % 2 ? "%s" : str;
written = LIBC_NAMESPACE::rpc_fprintf(file, format, &str, sizeof(void *));
written = LIBC_NAMESPACE::fprintf(file, format, str);
ASSERT_EQ(written, 16);

// Check that we handle null arguments correctly.
struct {
void *null = nullptr;
} args3;
written = LIBC_NAMESPACE::rpc_fprintf(file, "%p", &args3, sizeof(args3));
written = LIBC_NAMESPACE::fprintf(file, "%p", nullptr);
ASSERT_EQ(written, 9);

#ifndef LIBC_COPT_PRINTF_NO_NULLPTR_CHECKS
written = LIBC_NAMESPACE::rpc_fprintf(file, "%s", &args3, sizeof(args3));
written = LIBC_NAMESPACE::fprintf(file, "%s", nullptr);
ASSERT_EQ(written, 6);
#endif // LIBC_COPT_PRINTF_NO_NULLPTR_CHECKS

// Check for extremely abused variable width arguments
struct {
uint32_t x = 1;
uint32_t y = 2;
double f = 1.0;
} args4;
written = LIBC_NAMESPACE::rpc_fprintf(file, "%**d", &args4, sizeof(args4));
written = LIBC_NAMESPACE::fprintf(file, "%**d", 1, 2, 1.0);
ASSERT_EQ(written, 4);
written = LIBC_NAMESPACE::rpc_fprintf(file, "%**d%6d", &args4, sizeof(args4));
written = LIBC_NAMESPACE::fprintf(file, "%**d%6d", 1, 2, 1.0);
ASSERT_EQ(written, 10);
written = LIBC_NAMESPACE::rpc_fprintf(file, "%**.**f", &args4, sizeof(args4));
written = LIBC_NAMESPACE::fprintf(file, "%**.**f", 1, 2, 1.0);
ASSERT_EQ(written, 7);

return 0;
Expand Down
2 changes: 2 additions & 0 deletions libc/test/src/stdio/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ add_libc_test(
printf_test.cpp
DEPENDS
libc.src.stdio.printf
libc.src.stdio.stdout
)

add_fp_unittest(
Expand Down Expand Up @@ -234,6 +235,7 @@ add_libc_test(
vprintf_test.cpp
DEPENDS
libc.src.stdio.vprintf
libc.src.stdio.stdout
)


Expand Down
47 changes: 39 additions & 8 deletions libc/utils/gpu/server/rpc_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,17 @@ static_assert(sizeof(rpc_buffer_t) == sizeof(rpc::Buffer),
static_assert(RPC_MAXIMUM_PORT_COUNT == rpc::MAX_PORT_COUNT,
"Incorrect maximum port count");

template <uint32_t lane_size> void handle_printf(rpc::Server::Port &port) {
template <bool packed, uint32_t lane_size>
void handle_printf(rpc::Server::Port &port) {
FILE *files[lane_size] = {nullptr};
// Get the appropriate output stream to use.
if (port.get_opcode() == RPC_PRINTF_TO_STREAM)
if (port.get_opcode() == RPC_PRINTF_TO_STREAM ||
port.get_opcode() == RPC_PRINTF_TO_STREAM_PACKED)
port.recv([&](rpc::Buffer *buffer, uint32_t id) {
files[id] = reinterpret_cast<FILE *>(buffer->data[0]);
});
else if (port.get_opcode() == RPC_PRINTF_TO_STDOUT)
else if (port.get_opcode() == RPC_PRINTF_TO_STDOUT ||
port.get_opcode() == RPC_PRINTF_TO_STDOUT_PACKED)
std::fill(files, files + lane_size, stdout);
else
std::fill(files, files + lane_size, stderr);
Expand All @@ -60,6 +63,28 @@ template <uint32_t lane_size> void handle_printf(rpc::Server::Port &port) {
// Recieve the format string and arguments from the client.
port.recv_n(format, format_sizes,
[&](uint64_t size) { return new char[size]; });

// Parse the format string to get the expected size of the buffer.
for (uint32_t lane = 0; lane < lane_size; ++lane) {
if (!format[lane])
continue;

WriteBuffer wb(nullptr, 0);
Writer writer(&wb);

internal::DummyArgList<packed> printf_args;
Parser<internal::DummyArgList<packed> &> parser(
reinterpret_cast<const char *>(format[lane]), printf_args);

for (FormatSection cur_section = parser.get_next_section();
!cur_section.raw_string.empty();
cur_section = parser.get_next_section())
;
args_sizes[lane] = printf_args.read_count();
}
port.send([&](rpc::Buffer *buffer, uint32_t id) {
buffer->data[0] = args_sizes[id];
});
port.recv_n(args, args_sizes, [&](uint64_t size) { return new char[size]; });

// Identify any arguments that are actually pointers to strings on the client.
Expand All @@ -73,8 +98,8 @@ template <uint32_t lane_size> void handle_printf(rpc::Server::Port &port) {
WriteBuffer wb(nullptr, 0);
Writer writer(&wb);

internal::StructArgList printf_args(args[lane], args_sizes[lane]);
Parser<internal::StructArgList> parser(
internal::StructArgList<packed> printf_args(args[lane], args_sizes[lane]);
Parser<internal::StructArgList<packed>> parser(
reinterpret_cast<const char *>(format[lane]), printf_args);

for (FormatSection cur_section = parser.get_next_section();
Expand Down Expand Up @@ -126,8 +151,8 @@ template <uint32_t lane_size> void handle_printf(rpc::Server::Port &port) {
WriteBuffer wb(buffer.get(), buffer_size[lane]);
Writer writer(&wb);

internal::StructArgList printf_args(args[lane], args_sizes[lane]);
Parser<internal::StructArgList> parser(
internal::StructArgList<packed> printf_args(args[lane], args_sizes[lane]);
Parser<internal::StructArgList<packed>> parser(
reinterpret_cast<const char *>(format[lane]), printf_args);

// Parse and print the format string using the arguments we copied from
Expand Down Expand Up @@ -337,10 +362,16 @@ rpc_status_t handle_server_impl(
});
break;
}
case RPC_PRINTF_TO_STREAM_PACKED:
case RPC_PRINTF_TO_STDOUT_PACKED:
case RPC_PRINTF_TO_STDERR_PACKED: {
handle_printf<true, lane_size>(*port);
break;
}
case RPC_PRINTF_TO_STREAM:
case RPC_PRINTF_TO_STDOUT:
case RPC_PRINTF_TO_STDERR: {
handle_printf<lane_size>(*port);
handle_printf<false, lane_size>(*port);
break;
}
case RPC_REMOVE: {
Expand Down