
Fix batch_size=1 token_buffer truncation in sampler#963

Open
Cascoopman wants to merge 1 commit into google:main from Cascoopman:fix/batch-size-1-sampler-issue-809

Conversation

@Cascoopman

Summary

Fixes #809

When using the generic sampler with batch_size=1, the token_buffer was being truncated instead of properly populated during the while_loop. This made the sampler unusable when TRAIN_MICRO_BATCH_SIZE=1, yet it worked fine for larger batches.

Root Cause

The issue was caused by using Python integers for decoding_step indexing inside jax.lax.while_loop. When JAX traces a while_loop, Python integers used for dynamic indexing can leave the traced computation incorrectly specialized. With batch_size=1, JAX may mis-trace or over-optimize the array update operations (.at[:, decoding_step + 1].set(...)).
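The pattern can be illustrated with a minimal, self-contained while_loop. This is a sketch mirroring the fix described here, not the actual sampler code; sample_loop, its shapes, and the written token values are invented for illustration:

```python
import jax
import jax.numpy as jnp

def sample_loop(num_steps: int, batch_size: int) -> jnp.ndarray:
    # token_buffer stand-in: one row per batch element, num_steps + 1 slots.
    buffer = jnp.zeros((batch_size, num_steps + 1), dtype=jnp.int32)

    def cond(state):
        step, _ = state
        return step < num_steps

    def body(state):
        step, buf = state
        # Keep the dynamic index a traced int32 (not a Python int)
        # so the .at[...].set(...) update is staged correctly.
        step = jnp.asarray(step, dtype=jnp.int32)
        buf = buf.at[:, step + 1].set(step + 1)
        return step + 1, buf

    # Initialize the loop counter as a JAX int32, mirroring the fix.
    init_step = jnp.int32(0)
    _, buf = jax.lax.while_loop(cond, body, (init_step, buffer))
    return buf
```

With this shape of loop, the buffer is populated column by column for any batch size, including batch_size=1.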

The Gemma-specific sampler (tunix/tunix/models/gemma/sampler.py) already had this fix at line 440:

decoding_step = jnp.asarray(sampler_state.decoding_step, dtype=jnp.int32)

But the generic sampler was missing this conversion.

Changes

  1. tunix/generate/sampler.py:

    • Initialize decoding_step as jnp.int32(num_input_tokens - 1) in init_sample_state
    • Add explicit jnp.asarray() conversion in _sample() method
    • Add explicit jnp.asarray() conversion in _sample_step() method with a comment referencing the issue
  2. tests/generate/sampler_test.py:

    • Added test_batch_size_one() test case
    • Added test_batch_size_one_with_echo() test case

Testing

  • Ran all existing sampler tests; all pass
  • Ran the new batch_size=1 tests; all pass
  • Verified the fix with a standalone reproduction script exercising while_loop behavior

Related

This fix aligns the generic sampler with the Gemma sampler implementation which already had this fix.

Fixes google#809

When using the generic sampler with batch_size=1, the token_buffer
was being truncated instead of properly populated during the while_loop.
This was caused by using Python integers for decoding_step indexing,
which can cause JAX tracing issues.

Changes:
- Initialize decoding_step as jnp.int32() in init_sample_state
- Add explicit jnp.asarray() conversion in _sample() and _sample_step()
  to ensure consistent JAX tracing behavior
- Add test cases for batch_size=1 to prevent regression

The Gemma-specific sampler already had this fix at line 440, but the
generic sampler was missing it. This fix aligns the generic sampler
with the Gemma sampler implementation.


Development

Successfully merging this pull request may close these issues.

Single-item batches truncate sampler token_buffer instead of populating it
