fix the issue when eager mode jax is triggered in undesired places #837

Merged
copybara-service[bot] merged 1 commit into google:main from precur-ai:avoid_eager_mode
Dec 5, 2025

@Hanjun-Dai
Contributor

The existing implementation triggers JAX eager mode, which compiles many small pieces of code on demand (so that they can run on TPU) at every training iteration. This causes several issues:

Before the fix:

  • We see a lot of complaints from wandb: WARNING Tried to log to step 0 that is less than the current step xxx. Steps must be monotonically increasing, so this data will be ignored. See https://wandb.me/define-metric to log data out of order.
    This happens because the JAX compilation messages are also monitored and logged to wandb. JAX doesn't know the training step, so the default step 0 is used. That's why many of these warnings are printed during training.

  • TPU utilization: eager mode, especially the recompilation at every iteration, takes time, which slows training down and lowers TPU utilization.

Screenshot 2025-12-04 at 2 07 00 AM
  • Too many cached binaries will probably be created before all the different shapes have been compiled.

After the fix:

  • The wandb warning is gone.

  • TPU utilization is higher, and training should run faster.

Screenshot 2025-12-04 at 2 04 34 AM
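
To illustrate the recompilation problem described above, here is a minimal sketch. It is not the code changed in this PR. It assumes that padding is done on the host with NumPy and that lengths are rounded up to a power of two, so that a jitted function only sees a few distinct shapes. `pad_to_bucket`, `count_non_pad`, and `run_batches` are hypothetical names made up for this example; `next_power_of_2` only mirrors the signature visible in the diff hunk, and its body here is an assumption.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Records the input shape every time JAX traces (and compiles) the function.
trace_log = []


def next_power_of_2(x: int) -> int:
  # Assumed implementation; only the signature appears in the diff.
  return 1 if x <= 1 else 1 << (x - 1).bit_length()


def pad_to_bucket(ids, pad_value: int = 0) -> np.ndarray:
  # Hypothetical helper: pad on the host with NumPy, so no jnp op is
  # dispatched eagerly just to build the padded array.
  ids = np.asarray(ids, dtype=np.int32)
  target = next_power_of_2(ids.shape[0])
  out = np.full((target,), pad_value, dtype=np.int32)
  out[: ids.shape[0]] = ids
  return out


@jax.jit
def count_non_pad(padded):
  # This Python side effect runs only while tracing, i.e. once per new
  # input shape/dtype, not on every call.
  trace_log.append(padded.shape)
  return jnp.sum(padded != 0)


def run_batches(lengths):
  # Feed prompts of varying length through the jitted function.
  results = []
  for n in lengths:
    prompt = np.ones((n,), dtype=np.int32)
    results.append(int(count_non_pad(pad_to_bucket(prompt))))
  return results


counts = run_batches([3, 5, 6, 7, 9, 12])
print("non-pad counts:", counts)
print("shapes traced:", trace_log)
```

With six different prompt lengths, the jitted function is traced only for the padded shapes (4, 8, and 16 here) rather than once per length. Without bucketing, or with `jnp` operations called outside `jax.jit` on arrays of varying shape, each new shape would need its own compilation.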

Checklist

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and all unit tests pass.
  • I have added all appropriate doc-strings/documentation.
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have signed the Contributor License Agreement.
  • I have followed Contribution Guidelines.

@@ -88,12 +89,12 @@ def next_power_of_2(x: int) -> int:


def pad_to_length(
Collaborator

There are a couple of other call sites of this utility function. E.g., the vllm_sampler and sglang_jax_sampler wrap prompt ids as JAX arrays. Can you update them accordingly? The other call sites, such as dpo_trainer, can be left as they are. We need to refactor them to use generate/utils.py.

Collaborator

Actually, let me just merge this PR; it's very important to us and I'd like to land it ASAP. I will fix the rest internally.

@copybara-service copybara-service bot merged commit ae70dbb into google:main Dec 5, 2025
8 checks passed

2 participants