Fix batch_size=1 token_buffer truncation in sampler #963
Open
Cascoopman wants to merge 1 commit into google:main from
Conversation
Fixes google#809

When using the generic sampler with batch_size=1, the token_buffer was being truncated instead of properly populated during the while_loop. This was caused by using Python integers for decoding_step indexing, which can cause JAX tracing issues.

Changes:
- Initialize decoding_step as jnp.int32() in init_sample_state
- Add explicit jnp.asarray() conversion in _sample() and _sample_step() to ensure consistent JAX tracing behavior
- Add test cases for batch_size=1 to prevent regression

The Gemma-specific sampler already had this fix at line 440, but the generic sampler was missing it. This fix aligns the generic sampler with the Gemma sampler implementation.
Summary
Fixes #809
When using the generic sampler with `batch_size=1`, the `token_buffer` was being truncated instead of properly populated during the `while_loop`. This made the sampler unusable when `TRAIN_MICRO_BATCH_SIZE=1`, yet it worked fine for larger batches.

Root Cause
The issue was caused by using Python integers for `decoding_step` indexing inside `jax.lax.while_loop`. When JAX traces through a `while_loop`, using Python integers for dynamic indexing can cause issues with how the traced computation is specialized. With `batch_size=1`, JAX may incorrectly optimize or trace the array update operations (`.at[:, decoding_step + 1].set(...)`).

The Gemma-specific sampler (`tunix/tunix/models/gemma/sampler.py`) already had this fix at line 440, but the generic sampler was missing this conversion.
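To illustrate the pattern, here is a minimal sketch of a decode loop in the spirit of the sampler, not the actual tunix code: the carried step index is initialized with `jnp.asarray(..., dtype=jnp.int32)` so that indexing inside the `while_loop` is always a traced int32 rather than a raw Python int. The "sampled" token is a stand-in (just `step + 1`) and all names are illustrative.

```python
import jax
import jax.numpy as jnp

def decode(num_input_tokens: int, total_len: int = 8) -> jnp.ndarray:
    # Toy stand-in for the sampler's token buffer (batch_size=1).
    token_buffer = jnp.zeros((1, total_len), dtype=jnp.int32)
    # The fix: carry decoding_step as a traced jnp.int32 rather than a
    # Python int, so the dynamic index stays consistent under tracing.
    decoding_step = jnp.asarray(num_input_tokens - 1, dtype=jnp.int32)

    def cond_fn(state):
        step, _ = state
        return step < total_len - 1

    def body_fn(state):
        step, buf = state
        next_token = step + 1  # stand-in for a real sampling step
        buf = buf.at[:, step + 1].set(next_token)
        return step + 1, buf

    _, buffer = jax.lax.while_loop(
        cond_fn, body_fn, (decoding_step, token_buffer)
    )
    return buffer
```

With `num_input_tokens=1`, the loop writes every position after the prompt, so the returned buffer is fully populated rather than truncated.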
Changes

`tunix/generate/sampler.py`:
- Initialize `decoding_step` as `jnp.int32(num_input_tokens - 1)` in `init_sample_state`
- Add explicit `jnp.asarray()` conversion in the `_sample()` method
- Add explicit `jnp.asarray()` conversion in the `_sample_step()` method, with a comment referencing the issue

`tests/generate/sampler_test.py`:
- Add `test_batch_size_one()` test case
- Add `test_batch_size_one_with_echo()` test case

Testing
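The shape of the regression check can be sketched as follows. This is a hedged stand-in, not the actual test code: `fill_buffer` plays the role of the sampler's decode loop, and the assertions mirror what the new tests verify, namely that the buffer comes back fully written at `batch_size=1` just as it does for larger batches.

```python
import jax
import jax.numpy as jnp

def fill_buffer(batch_size: int, total_len: int = 6) -> jnp.ndarray:
    # Toy decode loop: writes position s + 1 on each iteration.
    buf = jnp.zeros((batch_size, total_len), dtype=jnp.int32)
    step = jnp.asarray(0, dtype=jnp.int32)  # traced int32 index (the fix)

    def cond_fn(carry):
        s, _ = carry
        return s < total_len - 1

    def body_fn(carry):
        s, b = carry
        return s + 1, b.at[:, s + 1].set(s + 1)

    _, buf = jax.lax.while_loop(cond_fn, body_fn, (step, buf))
    return buf

# Before the fix, the batch_size=1 path could come back truncated;
# both batch sizes should now yield a fully written buffer.
for bs in (1, 2):
    out = fill_buffer(bs)
    assert out.shape == (bs, 6)
    assert int(out[0, -1]) != 0  # last position was actually written
```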
Related

This fix aligns the generic sampler with the Gemma sampler implementation, which already included this conversion.