Implement shortest-path routing for inference #362
Conversation
# This is a pessimistic estimate that assumes that we'll use all blocks hosted by this server,
# which is not always true. This is okay since false positives are more costly than false negatives here.
return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left
Perhaps the servers should report cache tokens*layers to the DHT so we can avoid estimates? In other words, multiply whatever is reported by the number of blocks hosted.
IIRC, it is a relatively new feature that can be modified without much pain.
It's already reported this way. I've updated the comment to improve clarity.
@@ -48,6 +48,7 @@ install_requires =
    sentencepiece>=0.1.99
    peft@git+https://github.com/huggingface/peft@5884bdbea49e5e71e2cd06ecfa484bb635063735
    safetensors>=0.3.1
    Dijkstar>=2.6.0
nit: perhaps we should set <3.0.0 to protect against interface changes
Here, I use only the simplest graph interface, so I think it's unlikely to change, and our code can be quickly updated if it does.
end_index: int,
*,
cache_tokens_needed: Optional[int],
default_inference_rps: float = 300,  # If inference RPS unknown
Nit: consider moving these defaults to sequence manager config so that they are accessible without editing the source code
These are really low-level constants that are unlikely to be changed by a user, so I wouldn't pollute the config with them (it already has lots of knobs).
raise MissingBlocksError(missing_blocks)

client_server_rtts = self.ping_aggregator.to_dict()
logger.info(f"Client-server RTTs: {client_server_rtts}")
Nit: do we want to print it to every client? If not, consider logger.debug
Sure! I use them for debugging; I'll convert all logger.info() calls to logger.debug() before merging.
inference_rps = span.server_info.inference_rps
if inference_rps is None:
    inference_rps = default_inference_rps
graph.add_edge((span.peer_id, block_idx), (span.peer_id, block_idx + 1), 1 / inference_rps)
nit: do you, by chance, have some mockup performance numbers for graph construction & pathfinding for simulated larger graphs? If no, please remind me to run them.
If we allow switching to another server before reaching the end of the current one, we get an O(N^2) graph over servers. For N = 50 servers holding blocks 0..80, building the graph and finding the shortest path takes ~0.3 sec and scales as O(N^2).
I'll forbid such switching: this makes the graph much smaller (it takes ~0.03 sec now) and scales much better.
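For a sense of scale, here is a minimal sketch (not the actual Petals code; all peer IDs, RTTs, and RPS values are made up) of how a per-block routing graph can be built and solved with Dijkstar's `Graph` and `find_path`. Handover edges are added only at span boundaries, mirroring the "no mid-span switching" decision above.

```python
# A minimal sketch of per-block routing with Dijkstar; numbers below are illustrative.
from dijkstar import Graph, find_path

# Hypothetical spans: peer_id -> (first hosted block, last hosted block + 1, inference RPS, client RTT in sec)
spans = {
    "serverA": (0, 4, 300.0, 0.05),
    "serverB": (0, 2, 800.0, 0.01),
    "serverC": (2, 4, 800.0, 0.01),
}
server_to_server_rtt = 0.002  # hypothetical latency for handing over between servers

graph = Graph()
for peer_id, (start, end, rps, client_rtt) in spans.items():
    # The client can enter a server at its first hosted block and leave after its last one
    graph.add_edge("client", (peer_id, start), client_rtt / 2)
    graph.add_edge((peer_id, end), "exit", client_rtt / 2)
    # Running one block on this server costs 1 / inference_rps seconds
    for block_idx in range(start, end):
        graph.add_edge((peer_id, block_idx), (peer_id, block_idx + 1), 1 / rps)

# Handover edges only at span boundaries (no mid-span switching), which keeps the graph small
for a, (_, end_a, _, _) in spans.items():
    for b, (start_b, _, _, _) in spans.items():
        if a != b and end_a == start_b:
            graph.add_edge((a, end_a), (b, start_b), server_to_server_rtt)

path = find_path(graph, "client", "exit")  # edge weights are used directly as costs
print(path.nodes)       # e.g., ['client', ('serverB', 0), ..., ('serverC', 4), 'exit']
print(path.total_cost)  # estimated time in seconds for this route
```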
This PR:

- Adds shortest-path routing for inference. We build a graph with client-server and server-server latencies and compute costs, as well as empirically measured overheads. For client-server latencies, we ping the possible first and last servers of a sequence in `SequenceManager.update()`. We penalize servers that may not have enough cache for our request (a toy version of this check is sketched after this list). This uses info added to the DHT in "Share more info about a server in DHT" #355, "Make a server ping next servers" #356, and "Report inference, forward, and network RPS separately" #358.
- Makes a server ping neighboring servers in addition to the next ones. This gives us an opportunity to switch servers even before we use all blocks of the current one (e.g., because a neighboring server is faster). This feature is not enabled yet, since it increases the graph size for N servers to O(N^2), but we may enable it if needed.
- Fixes a `SequenceManager` bug with the first `update()`. Previously, this update was likely to produce incorrect information and cause `MissingBlocksError`s until the next update happened.
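For intuition, here is a toy version of the cache-capacity check mentioned above. The dataclasses and the handling of missing values are my simplifications for illustration; the actual check is the one shown in the diff earlier in this thread and uses Petals' real span/server info structures.

```python
# Toy illustration of the pessimistic cache check from this PR; the dataclasses below
# are simplified stand-ins, not the actual Petals data structures.
from dataclasses import dataclass
from typing import Optional

@dataclass
class ServerInfo:
    cache_tokens_left: Optional[int]  # reported to the DHT, already multiplied by the number of hosted blocks

@dataclass
class Span:
    start: int  # first hosted block
    end: int    # last hosted block + 1
    server_info: ServerInfo

    @property
    def length(self) -> int:
        return self.end - self.start

def span_has_enough_cache(span: Span, cache_tokens_needed: Optional[int]) -> bool:
    if cache_tokens_needed is None or span.server_info.cache_tokens_left is None:
        return True  # not enough info to penalize this server (assumption for this sketch)
    # Pessimistic estimate that assumes we'll use all blocks hosted by this server,
    # which is not always true. False positives are more costly than false negatives here.
    return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left

span = Span(start=0, end=8, server_info=ServerInfo(cache_tokens_left=20_000))
print(span_has_enough_cache(span, cache_tokens_needed=2048))  # False: 2048 * 2 * 8 = 32768 > 20000
```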