[libc] Scan the ports more fairly in the RPC server #66680

Merged
merged 1 commit into llvm:main on Sep 26, 2023

Conversation

@jhuber6 jhuber6 commented Sep 18, 2023

Summary:
Currently, we use the RPC server to respond to different ports, each of which
contains a request from some client thread wishing to do work on the server.
The scan starts at zero and continues until it has checked all ports, at which
point it resets. If we find an active port, we service it and then restart the
search from the beginning.

This is bad for two reasons. First, it means that we will always be biased
toward the lower ports. If a thread grabs a high port, it can be stuck for a
very long time until all the other work is done. Second, it means that
the handle_server function can technically run indefinitely as long as
the client keeps pushing new work. Because the OpenMP implementation
uses the user thread to service the kernel, this means that it could be
stalled by another asynchronous device's kernels.

This patch addresses this by making the server restart its scan at the next
port over. This means we will always do a full scan of the ports before
quitting.
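
For illustration, here is a minimal, self-contained sketch of the new scanning behavior. ToyServer, has_request, and handle_all_pending are hypothetical stand-ins for the real rpc::Server, its inbox/outbox state, and the handle_server loop in the diff below; this is not the actual implementation.

#include <array>
#include <cstdint>
#include <optional>

// Simplified stand-in for the real server. 'has_request' models a port whose
// client has posted a request (hypothetical, for illustration only).
struct ToyServer {
  static constexpr uint32_t port_count = 8;
  std::array<bool, port_count> has_request{};

  // Mirrors the patched try_open(start): scan from 'start' to the end instead
  // of always from zero, returning the first port with a pending request.
  std::optional<uint32_t> try_open(uint32_t start = 0) {
    for (uint32_t index = start; index < port_count; ++index)
      if (has_request[index])
        return index;
    return std::nullopt;
  }
};

// Mirrors the caller loop: after servicing a port, resume the scan at the next
// port over, so each call makes at most one pass over the ports before quitting.
void handle_all_pending(ToyServer &server) {
  uint32_t index = 0;
  for (;;) {
    std::optional<uint32_t> port = server.try_open(index);
    if (!port)
      return;                          // Nothing pending at or after 'index'.
    server.has_request[*port] = false; // Stand-in for servicing the request.
    index = *port + 1;                 // Resume after the port we just handled.
  }
}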

llvmbot commented Sep 18, 2023

@llvm/pr-subscribers-libc

Changes

Full diff: https://github.com/llvm/llvm-project/pull/66680.diff

2 Files Affected:

  • (modified) libc/src/__support/RPC/rpc.h (+5-3)
  • (modified) libc/utils/gpu/server/rpc_server.cpp (+14-8)
diff --git a/libc/src/__support/RPC/rpc.h b/libc/src/__support/RPC/rpc.h
index fc95e5edf1c7209..081f4daee904521 100644
--- a/libc/src/__support/RPC/rpc.h
+++ b/libc/src/__support/RPC/rpc.h
@@ -325,6 +325,8 @@ template <bool T, typename S> struct Port {
     return process.packet[index].header.opcode;
   }
 
+  LIBC_INLINE uint16_t get_index() const { return index; }
+
   LIBC_INLINE void close() {
     // The server is passive, if it own the buffer when it closes we need to
     // give ownership back to the client.
@@ -372,7 +374,7 @@ template <uint32_t lane_size> struct Server {
   LIBC_INLINE ~Server() = default;
 
   using Port = rpc::Port<true, Packet<lane_size>>;
-  LIBC_INLINE cpp::optional<Port> try_open();
+  LIBC_INLINE cpp::optional<Port> try_open(uint32_t start = 0);
   LIBC_INLINE Port open();
 
   LIBC_INLINE void reset(uint32_t port_count, void *buffer) {
@@ -560,9 +562,9 @@ template <uint16_t opcode> LIBC_INLINE Client::Port Client::open() {
 template <uint32_t lane_size>
 [[clang::convergent]] LIBC_INLINE
     cpp::optional<typename Server<lane_size>::Port>
-    Server<lane_size>::try_open() {
+    Server<lane_size>::try_open(uint32_t start) {
   // Perform a naive linear scan for a port that has a pending request.
-  for (uint32_t index = 0; index < process.port_count; ++index) {
+  for (uint32_t index = start; index < process.port_count; ++index) {
     uint64_t lane_mask = gpu::get_lane_mask();
     uint32_t in = process.load_inbox(lane_mask, index);
     uint32_t out = process.load_outbox(lane_mask, index);
diff --git a/libc/utils/gpu/server/rpc_server.cpp b/libc/utils/gpu/server/rpc_server.cpp
index c98b9fa46ce058b..02965e4b4a480cc 100644
--- a/libc/utils/gpu/server/rpc_server.cpp
+++ b/libc/utils/gpu/server/rpc_server.cpp
@@ -53,12 +53,13 @@ struct Server {
   }
 
   rpc_status_t handle_server(
-      std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> &callbacks,
-      std::unordered_map<rpc_opcode_t, void *> &callback_data) {
+      const std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> &callbacks,
+      const std::unordered_map<rpc_opcode_t, void *> &callback_data,
+      uint32_t &index) {
     rpc_status_t ret = RPC_STATUS_SUCCESS;
     std::visit(
         [&](auto &server) {
-          ret = handle_server(*server, callbacks, callback_data);
+          ret = handle_server(*server, callbacks, callback_data, index);
         },
         server);
     return ret;
@@ -68,9 +69,10 @@ struct Server {
   template <uint32_t lane_size>
   rpc_status_t handle_server(
       rpc::Server<lane_size> &server,
-      std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> &callbacks,
-      std::unordered_map<rpc_opcode_t, void *> &callback_data) {
-    auto port = server.try_open();
+      const std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> &callbacks,
+      const std::unordered_map<rpc_opcode_t, void *> &callback_data,
+      uint32_t &index) {
+    auto port = server.try_open(index);
     if (!port)
       return RPC_STATUS_SUCCESS;
 
@@ -202,6 +204,9 @@ struct Server {
       (handler->second)(port_ref, data);
     }
     }
+
+    // Increment the index so we start the scan after this port.
+    index = port->get_index() + 1;
     port->close();
     return RPC_STATUS_CONTINUE;
   }
@@ -317,10 +322,11 @@ rpc_status_t rpc_handle_server(uint32_t device_id) {
   if (!state->devices[device_id])
     return RPC_STATUS_ERROR;
 
+  uint32_t index = 0;
   for (;;) {
     auto &device = *state->devices[device_id];
-    rpc_status_t status =
-        device.server.handle_server(device.callbacks, device.callback_data);
+    rpc_status_t status = device.server.handle_server(
+        device.callbacks, device.callback_data, index);
     if (status != RPC_STATUS_CONTINUE)
       return status;
   }

jplehr commented Sep 18, 2023

I believe this makes sense.

// Perform a naive linear scan for a port that has a pending request.
for (uint32_t index = 0; index < process.port_count; ++index) {
for (uint32_t index = start; index < process.port_count; ++index) {
Member

I would have assumed you scan start - end, then 0 - (start-1). Scanning only part of the space is likely not what you want. If process.port_count is a compile-time constant, and a power of two, just do the modulo. Otherwise do if (index >= count) index -= count;
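
A rough sketch of the wrap-around scan being suggested here (find_pending_port and has_request are hypothetical names, not part of the patch): check all port_count ports exactly once per call, starting at start and wrapping back to the beginning.

#include <cstdint>
#include <optional>

// Check every port exactly once, beginning at 'start' and wrapping around.
std::optional<uint32_t> find_pending_port(const bool *has_request,
                                          uint32_t port_count, uint32_t start) {
  for (uint32_t i = 0; i < port_count; ++i) {
    uint32_t index = start + i;
    if (index >= port_count) // Cheap wrap; a power-of-two count could use '%'.
      index -= port_count;
    if (has_request[index])
      return index;
  }
  return std::nullopt;
}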

Contributor Author

I wanted to make sure that handle_server runs a fixed number of times because we use the user thread to check it. This prevents the situation where some other kernel's RPC calls perpetually block a user's thread from finishing, e.g.

while (!kernel.done)
  server.handle_server(); // Could pick up work from other async kernels.

I haven't yet introduced sleeping to the busy-wait in OpenMP; that backoff would probably not want to increase its sleep duration if we had just finished a full scan, or something along those lines.
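
One possible shape of that backoff, sketched purely for illustration (wait_with_backoff, kernel_done, and a handle_server that reports whether it did work are all hypothetical, not part of this patch or the OpenMP runtime): grow the sleep only while passes come back empty, and reset it whenever a pass actually serviced work.

#include <algorithm>
#include <chrono>
#include <thread>

template <typename Done, typename HandleServer>
void wait_with_backoff(Done kernel_done, HandleServer handle_server) {
  std::chrono::microseconds delay{1};
  while (!kernel_done()) {
    // One bounded pass over the ports; assumed to report whether it did work.
    bool did_work = handle_server();
    if (did_work) {
      delay = std::chrono::microseconds{1}; // Work found: stay responsive.
    } else {
      std::this_thread::sleep_for(delay);   // Idle: back off gradually.
      delay = std::min(delay * 2, std::chrono::microseconds{1000});
    }
  }
}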

Collaborator

Agreed with Johannes - start until end is weird on a circular buffer. Suggest we check N ports on each call, where N may as well be the size of the buffer

Member

I wanted to make sure that handle_server runs a fixed amount of times because we use the user thread to check it.

As @JonChesterfield noted, my suggestion does not violate this. Each call should check all ports.

Contributor Author

That's what I'm intending with this logic. When the user calls handle_server, it will guarantee that all ports 1 through N will be checked. If it finds work on a port, it handles it, then picks back up where it left off. If we make each check go 1 through N modulo the port size, then we'll be in the same situation where we don't know whether or not handle_server will run a bounded number of times.

Member

So, I am not seeing this run through all N ports. Let's say the last picked one is N-2; then index is N-1, and this checks only a single port, namely N-1. Before, it always checked N ports, but always starting at 0. What we want is that all N ports are checked, but starting at the last picked index, so you can't starve a port since the one we picked last is checked last in the next run. My proposal was to run N checks, nothing unbounded.

Contributor Author

Look at the greater scope. This isn't handle_server, this is try_open. handle_server passes in start as the last position where it found a port. If we instead always check starting at N, we will then continuously check forever as long as work is still being submitted to the RPC interface faster than it can be consumed. We can potentially do that, but I was just thinking that it would be good to guarantee that this routine, when called by a blocking thread, will return in a specific number of checks.

Member

Ok. I think I see the disconnect. Your scheme doesn't actually provide the guarantee I thought it did, which is why we are talking past each other. In your patch, one call to handle_server will always start to process the ports from 0 to the end. With this patch it ensures termination, as it checks each port once. That is clearly better than the code before, not necessarily because the old code could get stuck, but because it could starve ports. I don't think being stuck performing work is too bad for the host thread. For now I don't see what the benefit would be of returning to the caller of handle_server, just to let them call it again.

That all said, what I thought you were doing was remembering the index of the last port that was handled across calls, such that you'd scan from there the next time. Now, you still always prefer port 0 over port 5 even if you just handled port 0 last time. That said, this patch at least ensures both are handled during a handle_server call. If index were persistent, which is what I thought, port 5 would be prioritized over port 0 if port 0 was handled last time. In that scheme, you need to iterate from 0 to index at the end, since index did not start at 0. Hence my request for the extra iterations.

I still think the persistent index has benefits when it comes to fairness, but I agree this patch in itself does what you describe in a sensible way.

Contributor Author

For now I don't see what the benefit would be to return to the caller of handle_server, just to let them call it again.

My main concern was that if we had the user's thread running the server, it could potentially be blocked for a long time if it keeps picking up some other kernel's work. I could instead do the method where we start at the last known index but always check N ports.

Member

@jdoerfert jdoerfert left a comment

We agreed that the extra state we'd need to "start from where we left off" is not worth it.

LG

@jhuber6 jhuber6 merged commit 1a5d3b6 into llvm:main Sep 26, 2023
2 of 3 checks passed
legrosbuffle pushed a commit to legrosbuffle/llvm-project that referenced this pull request Sep 29, 2023