Fix the issue where JAX eager mode is triggered in undesired places #837
Merged
copybara-service[bot] merged 1 commit into google:main on Dec 5, 2025
wang2yn84 reviewed on Dec 4, 2025
@@ -88,12 +89,12 @@ def next_power_of_2(x: int) -> int:
def pad_to_length(
Collaborator
There are a couple of other call sites of this utility function. For example, vllm_sampler and sglang_jax_sampler wrap the prompt IDs as a JAX array. Can you update them accordingly? The other call sites, such as dpo_trainer, can be left as they are; we need to refactor them to use generate/utils.py.
Collaborator
Actually, let me just merge this PR. It's very important to us and I'd like to land it ASAP. I will fix the rest internally.
The existing implementation triggers JAX eager mode, which compiles many small pieces of code on demand (so that they can run on the TPU) at every training iteration. This causes several issues:
Before the fix:
We see a lot of complaints from wandb: WARNING Tried to log to step 0 that is less than the current step xxx. Steps must be monotonically increasing, so this data will be ignored. See https://wandb.me/define-metric to log data out of order.
This is because the JAX compilation messages are also monitored and logged to wandb. JAX doesn't know the training step, so the default step 0 is used, which is why many of these warnings are printed during training.
TPU utilization: eager mode, especially the recompilation at every iteration, takes time, which slows training and lowers TPU utilization.
After the fix:
The wandb warning is gone.
TPU utilization is higher, and training should run faster.
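To illustrate the kind of change described above, here is a minimal sketch, not the actual diff from this PR. It assumes the fix amounts to keeping prompt padding on the host with NumPy, so that no individual JAX ops are dispatched (and compiled) eagerly for each prompt. The helper names pad_to_length_host and bucket_length, and their parameters, are hypothetical and are not the signatures of pad_to_length or next_power_of_2 in the repository.

```python
import numpy as np


def bucket_length(n: int) -> int:
    """Round n up to the next power of two (hypothetical helper).

    Padding to a small set of bucket sizes keeps the padded shapes stable,
    so a jitted function downstream sees few distinct input shapes.
    """
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def pad_to_length_host(x, target_length, pad_value=0, left=False):
    """Pad a 1-D sequence to target_length using NumPy only.

    Assumption: doing this on the host avoids dispatching small JAX ops
    outside of jit. Inputs longer than target_length are returned unchanged.
    """
    x = np.asarray(x)
    pad_amount = target_length - x.shape[0]
    if pad_amount <= 0:
        return x
    padding = np.full((pad_amount,), pad_value, dtype=x.dtype)
    parts = [padding, x] if left else [x, padding]
    return np.concatenate(parts)


# Usage: the prompt stays a NumPy array until it reaches the jitted code.
prompt_ids = [5, 6, 7]
padded = pad_to_length_host(
    prompt_ids, bucket_length(len(prompt_ids)), pad_value=0, left=True
)
```

Under this assumption, a call site that currently wraps the prompt IDs as a JAX array before padding would instead pass a list or NumPy array and convert to a device array once, at the boundary of the jitted function.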
Checklist