Implement shortest-path routing for inference #362

Merged
borzunov merged 13 commits into main from shortest-path on Jul 18, 2023
Conversation

borzunov (Collaborator) commented on Jul 17, 2023

This PR:

  1. Adds shortest-path routing for inference. We build a graph with client-server and server-server latencies and compute costs, as well as empirically measured overheads (see the sketch after this list). For client-server latencies, we ping possible first and last servers in a sequence in SequenceManager.update(). We penalize servers that may not have enough cache for our request. This uses info added to the DHT in "Share more info about a server in DHT" #355, "Make a server ping next servers" #356, and "Report inference, forward, and network RPS separately" #358.

  2. Makes a server ping neighboring servers in addition to the next ones. This gives clients an opportunity to switch to another server even before using all blocks of the current one (e.g., because a neighboring server is faster). This feature is not enabled yet, since it increases the graph size for N servers to O(N^2), but we may enable it if needed.

  3. Fixes a SequenceManager bug affecting the first update(). Previously, this update was likely to produce incorrect information and cause MissingBlocksErrors until the next update happened.
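
A minimal sketch of the routing idea from item 1 (not the actual Petals code): nodes are (peer_id, block_index) pairs plus virtual client source and sink nodes, edge weights combine network latency with per-block compute time (1 / inference_rps), and Dijkstar finds the cheapest chain of servers. The server names and latencies below are hypothetical.

from dijkstar import Graph, find_path

graph = Graph()

# Client -> possible first servers (half of the measured round-trip time).
graph.add_edge("client", ("serverA", 0), 0.050)
graph.add_edge("client", ("serverB", 0), 0.020)

# Traversing one block on a server costs 1 / inference_rps seconds.
for block_idx in range(3):
    graph.add_edge(("serverA", block_idx), ("serverA", block_idx + 1), 1 / 300)
    graph.add_edge(("serverB", block_idx), ("serverB", block_idx + 1), 1 / 100)

# Possible last servers -> client (again, half of the RTT).
graph.add_edge(("serverA", 3), "client_end", 0.050)
graph.add_edge(("serverB", 3), "client_end", 0.020)

path = find_path(graph, "client", "client_end")
print(path.nodes, path.total_cost)  # cheapest chain of (peer_id, block) hops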

@borzunov marked this pull request as draft on July 17, 2023 09:48
@borzunov changed the title from "Implement shortest-path routing" to "Implement shortest-path routing for inference" on Jul 17, 2023
@borzunov force-pushed the shortest-path branch 5 times, most recently from f05a3d7 to 3173d95 on July 17, 2023 11:16

# This is a pessimistic estimate that assumes that we'll use all blocks hosted by this server,
# which is not always true. This is okay since false positives are more costly than false negatives here.
return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left

Collaborator:

Perhaps the servers should report cache tokens*layers to the DHT so we can avoid estimates? In other words, multiply whatever is reported by the number of blocks hosted.

IIRC, it is a relatively new feature, so it can be modified without pain.

borzunov (Author):

It's already reported this way. I've updated the comment to improve clarity.
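
A worked numeric example of the pessimistic check above, with hypothetical values (the factor of 2 is taken verbatim from the check):

cache_tokens_needed = 512   # cache tokens our request needs per block
span_length = 4             # blocks of this server we may end up using
cache_tokens_left = 5000    # budget reported by the server via the DHT

fits = cache_tokens_needed * 2 * span_length <= cache_tokens_left  # 4096 <= 5000 -> True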

@@ -48,6 +48,7 @@ install_requires =
sentencepiece>=0.1.99
peft@git+https://github.com/huggingface/peft@5884bdbea49e5e71e2cd06ecfa484bb635063735
safetensors>=0.3.1
Dijkstar>=2.6.0

Collaborator:

nit: perhaps we should set <3.0.0 to protect against interface changes

borzunov (Author):

Here, I use only the simplest graph interface - I think it's unlikely to change, and we can quickly update the code if it does.
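
For reference, the basic Dijkstar interface in question, Graph.add_edge() plus find_path() (a minimal sketch, not the PR's code):

from dijkstar import Graph, find_path

graph = Graph()
graph.add_edge("a", "b", 1.0)  # the third argument is the edge cost
graph.add_edge("b", "c", 2.0)
graph.add_edge("a", "c", 5.0)

path = find_path(graph, "a", "c")
print(path.nodes)       # ['a', 'b', 'c']
print(path.total_cost)  # 3.0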

end_index: int,
*,
cache_tokens_needed: Optional[int],
default_inference_rps: float = 300, # If inference RPS unknown

Collaborator:

Nit: consider moving these defaults to sequence manager config so that they are accessible without editing the source code

borzunov (Author):

These are really low-level constants that are unlikely to be changed by a user, so I wouldn't pollute the config with them (it already has lots of knobs).

raise MissingBlocksError(missing_blocks)

client_server_rtts = self.ping_aggregator.to_dict()
logger.info(f"Client-server RTTs: {client_server_rtts}")

Collaborator:

Nit: do we want to print it to every client? If not, consider logger.debug

borzunov (Author):

Sure! I use them for debugging; I'll convert all logger.info() calls to logger.debug() before merging.

inference_rps = span.server_info.inference_rps
if inference_rps is None:
inference_rps = default_inference_rps
graph.add_edge((span.peer_id, block_idx), (span.peer_id, block_idx + 1), 1 / inference_rps)

Collaborator:

nit: do you, by chance, have some mockup performance numbers for graph construction & pathfinding on simulated larger graphs? If not, please remind me to run them.

borzunov (Author):

If we allow switching to another server before reaching the end of the current one, we get an O(N^2) graph for N servers. For N = 50 servers holding blocks 0..80, building the graph and finding the shortest path takes ~0.3 sec, and this scales as O(N^2).

I'll forbid such switching; this makes the graph much smaller (~0.03 sec now) and it scales much better. A sketch of the measurement setup is below.
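
A sketch of how such numbers could be reproduced (not the author's actual benchmark; the per-server RPS values are random placeholders). Without switching edges, the graph has O(N * B) edges for N servers and B blocks; adding server-to-server switching edges at every block yields the O(N^2) regime mentioned above.

import random
import time

from dijkstar import Graph, find_path

N_SERVERS, N_BLOCKS = 50, 80

graph = Graph()
for server in range(N_SERVERS):
    rps = random.uniform(100, 1000)
    graph.add_edge("client", (server, 0), 0.05)       # client -> server, RTT / 2
    for block in range(N_BLOCKS):
        graph.add_edge((server, block), (server, block + 1), 1 / rps)
    graph.add_edge((server, N_BLOCKS), "sink", 0.05)  # server -> client, RTT / 2

start = time.perf_counter()
path = find_path(graph, "client", "sink")
print(f"Pathfinding took {time.perf_counter() - start:.4f} sec, best cost {path.total_cost:.3f}")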

@borzunov marked this pull request as ready for review on July 18, 2023 04:14
@borzunov merged commit 62d9ed5 into main on Jul 18, 2023
7 checks passed
@borzunov deleted the shortest-path branch on July 18, 2023 04:46