Model warmup support with AOT and endpoint for JetStream #92
Conversation
jetstream/core/orchestrator.py (Outdated)

        true_length=true_length,
    )
    if self.warmup_enabled:
      padded_token_length = token_utils.take_nearest_length(
In general, I feel the warmup code should be outside of the orchestrator. We would like to keep the orchestrator limited to necessary functions (benchmarking or warmup is not a necessary function) and make sure the code stays clean and clear; the logic is already very complex to read right now.
Can you refactor and move warmup out of this class?
Just curious: do you mean we should move the warmup logic into a separate function and invoke it here, or call AOT at a completely different place outside of the orchestrator?
I moved the warmup code out of the orchestrator but kept the check (if self.warmup_enabled) and its corresponding logic because of the following behavior: once model warmup is called, the compiled forms of prefill, insert, and generate are stored in dictionaries keyed by their corresponding bucket lengths. These compiled forms should be called from then on; otherwise the JetStream server will incur compilation times.
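For illustration, the bucket-keyed dispatch described here might look roughly like the following. This is a minimal sketch: the names buckets and compiled_prefill, and the take_nearest_length reimplementation, are assumptions for the example, not the actual JetStream code.

import bisect
from typing import Any, Callable, Dict

# Hypothetical: padded bucket lengths the model was warmed up for.
buckets = [64, 128, 256, 512, 1024]
# Hypothetical: bucket length -> AOT-compiled prefill executable.
compiled_prefill: Dict[int, Callable[..., Any]] = {}

def take_nearest_length(lengths, length):
  # Smallest bucket that fits `length` (falls back to the largest bucket).
  pos = bisect.bisect_left(lengths, length)
  return lengths[min(pos, len(lengths) - 1)]

def prefill(padded_tokens, true_length):
  # Dispatch to the pre-compiled executable so the serving path never
  # triggers XLA compilation.
  bucket = take_nearest_length(buckets, true_length)
  return compiled_prefill[bucket](padded_tokens, true_length)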
Correct me if I'm wrong: we know what type of data prefill or decode should process for warmup. In that case, all of this code can live outside the orchestrator.
@JoeZijunZhou Please also take a look. The orchestrator is already complex; we'd better keep this class to main-function code only, otherwise it will be hard to maintain and refactor in the future.
+1. I feel it's feasible to implement a wrapper for the engines and pass the compiled engines etc. into the driver init here if warmup is on: https://github.com/google/JetStream/blob/main/jetstream/core/server_lib.py#L141. Then we don't need to change the orchestrator or the engine API, keeping the AOT warmup logic decoupled from the existing JetStream core and engine API.
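A minimal sketch of that wiring, with hypothetical names: create_driver, config.get_engines, and the Driver signature are illustrative, not the actual server_lib API.

def create_driver(config, devices, enable_model_warmup=False):
  engines = config.get_engines(devices)  # assumed helper
  if enable_model_warmup:
    # Wrap each engine so AOT compilation happens once at startup and the
    # orchestrator sees the same Engine API (see the WarmedUpEngine sketch
    # further down).
    engines = [WarmedUpEngine(e) for e in engines]
  return orchestrator.Driver(engines=engines)  # assumed signature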
Added two things:
- Logic to bake the warmup into model server startup
- A wrapper engine definition, WarmedUpEngine

Added some extra logic to help facilitate the engine and define the warmup state, since that state is used later in the prefill and generate threads to determine which bucket needs to be called.
@FanhaiLu1 Added the wrapper logic, PTAL and let me know if it looks good! Thanks. We can have a follow-up PR to address the performance degradation that occurs at larger batch sizes.
I feel it's feasible to move the warmup state and its related handling into WarmedUpEngine, WDYT?
+1 for Zijun's comments
Thanks, decoupled the warmup handling per our discussion offline!
Thanks for adding the warmup support! Which VM did you run the test on?
@FanhaiLu1 I tested using v5e-8; I'll add this detail to the PR comment too.
Great work! I feel we could refactor the warmup logic to make it clean and decoupled.
You can probably put the warmup logic into another Engine implementation that takes an instance of Engine:

class WarmedUpEngine(engine_api.Engine):
  def __init__(self, downstream_engine: engine_api.Engine):
    # do compile, set up the dicts that map int to jax Compiled.
  def prefill(self, ...):
    return self.compiled_prefill[seqlen](*args, **kwargs)  # etc.
  # same for insert / generate

Then, in the orchestrator, you only need:

if warmed_up:
  self.engine = WarmedUpEngine(self.engine)

in init, and the rest doesn't need to change.
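Fleshing that sketch out into something runnable might look like the following. This is a hedged sketch: the bucket sizes are assumptions, the compile step is a placeholder, and real JetStream engine methods take more arguments.

from typing import Any, Callable, Dict

class WarmedUpEngine:
  """Wraps an engine and dispatches calls to AOT-compiled executables."""

  def __init__(self, downstream_engine, buckets=(64, 128, 256, 512)):
    self._engine = downstream_engine
    self._buckets = sorted(buckets)
    # Bucket length -> compiled prefill. A real implementation would call
    # jax.jit(...).lower(example_inputs).compile() here per bucket; using
    # the raw method is a placeholder to keep the sketch self-contained.
    self._compiled_prefill: Dict[int, Callable[..., Any]] = {
        b: self._engine.prefill for b in self._buckets
    }

  def prefill(self, padded_tokens, true_length, **kwargs):
    # Route to the executable compiled for the nearest bucket.
    bucket = next(
        (b for b in self._buckets if b >= true_length), self._buckets[-1]
    )
    return self._compiled_prefill[bucket](padded_tokens, true_length, **kwargs)

  def __getattr__(self, name):
    # insert / generate / params etc. fall through to the wrapped engine,
    # so callers see the unchanged Engine API.
    return getattr(self._engine, name)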
Synced with @vivianrwu; a complete refactor of orchestrator.py and server_lib.py is needed in follow-up PRs. Approving this initial PR.
This PR aims to support model warmup on the JetStream server using ahead-of-time (AOT) compilation. The goal is to eliminate compile times after the model server has loaded, allowing the server to serve without extra latency.

This covers using AOT on prefill, generate, and insert, and using their respective compiled executables in the prefill and generate threads when serving. A driver flag, self.warmup_enabled, is added to track this. To enable model warmup, set --enable_model_warmup=True/true when running the MaxEngine or jetstream-pytorch server. This enables a wrapper engine over the original engine that calls the respective AOT-compiled prefill, insert, and generate functions where appropriate in prefill_threads and generate_threads.
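For reference, the JAX mechanism this kind of warmup relies on is the lower/compile AOT path; roughly like this minimal sketch, where prefill_step and the (1, 128) bucket shape are made-up stand-ins rather than the engine's real prefill:

import jax
import jax.numpy as jnp

def prefill_step(tokens):  # stand-in for an engine's prefill
  return jnp.cumsum(tokens, axis=-1)

# Warmup: lower and compile ahead of time for one padded bucket shape.
example = jnp.zeros((1, 128), dtype=jnp.int32)
compiled = jax.jit(prefill_step).lower(example).compile()

# Serving: reuse the executable; inputs must match the compiled shape/dtype.
out = compiled(jnp.ones((1, 128), dtype=jnp.int32))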
Test coverage can be found under test_server.py, which checks whether model warmup was successful. This functionality will be useful for model server readiness checks and pod startup on GKE, denoting that a pod is ready to serve without the extra latency of compilation.
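The shape of such a test might be as follows; this is hypothetical, the actual assertions in test_server.py may differ, FakeEngine is invented for the sketch, and WarmedUpEngine refers to the sketch above:

class FakeEngine:
  def prefill(self, *args, **kwargs):
    return None

def test_model_warmup_populates_compiled_buckets():
  engine = WarmedUpEngine(FakeEngine(), buckets=(64, 128))
  # After warmup, every bucket should have a compiled executable cached.
  assert set(engine._compiled_prefill) == {64, 128}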
This has been validated with MaxText at HEAD (AI-Hypercomputer/maxtext@f8ae413) on GKE. The numbers below cover initial testing on jetstream-pytorch and the latency differences.
Latency difference on the first request after the model server has loaded:
Before AOT model warmup (includes compilation): 34.43s
After AOT model warmup (no further compilation needed): 1.81s
This is an initial implementation of AOT for model warmup. With higher batch sizes in jetstream-pytorch, we observe that the detokenizing generate step and the time to first response after AOT are slower. Latency below is measured from the time the request is sent to the time a response is returned.