
fix: prevent division by zero in sampling when temperature is 0.0#573

Open
ChiragTrivedi06 wants to merge 1 commit into google-deepmind:main from ChiragTrivedi06:bugfix/zero-division-error

Conversation

@ChiragTrivedi06

Fix: Numerical Stability Guard for Sampling Methods

Problem

Currently, the RandomSampling, TopkSampling, and TopPSampling classes divide the logits directly by the temperature parameter. When temperature=0.0, this division produces Inf logits, which JAX/XLA then propagates as NaN values downstream, breaking the sampling pipeline.
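A minimal NumPy reproduction of the failure mode (the actual classes divide inside JAX; this sketch only illustrates the arithmetic):

```python
import numpy as np

logits = np.array([1.0, 3.0, 2.0])
temperature = 0.0

# Dividing by a zero temperature yields Inf for every logit; a softmax over
# these values then computes Inf - Inf, which is NaN.
with np.errstate(divide="ignore", invalid="ignore"):
    scaled = logits / temperature
print(scaled)  # [inf inf inf]
```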

Proposed Changes

Instead of a hard switch to a separate Greedy implementation, this PR introduces a threshold-based guard within the sampling classes. This ensures numerical stability while maintaining a unified API for consumers.

Key Implementation Details:

  • Threshold Guard: Added scaled_logits = logits if self.temperature < 1e-6 else logits / self.temperature across all sampling methods.
  • Unified Interface: Avoids forced object swapping; users can now toggle between stochastic and deterministic decoding simply by adjusting the temperature hyperparameter.
  • Top-P/Top-K Compatibility: Ensures that filtering operations (sorting, cumulative distribution) remain stable even at the limit of zero temperature.
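The guard can be sketched as follows in plain NumPy; the class and method names here are illustrative, not the repository's actual API:

```python
import numpy as np


class RandomSampling:
    """Illustrative categorical sampler with the threshold guard."""

    def __init__(self, temperature: float = 1.0):
        self.temperature = temperature

    def sample(self, rng: np.random.Generator, logits: np.ndarray) -> int:
        # Guard: below the threshold, skip the division entirely so that
        # temperature=0.0 can never produce Inf/NaN logits.
        scaled = (
            logits if self.temperature < 1e-6 else logits / self.temperature
        )
        # Numerically stable softmax over the scaled logits.
        probs = np.exp(scaled - scaled.max())
        probs /= probs.sum()
        return int(rng.choice(len(logits), p=probs))
```

The same `temperature < 1e-6` branch is what the PR applies uniformly across the three sampling classes, so callers never need to swap implementations.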

Technical Rationale

  1. Mathematical Continuity: As $\tau \to 0$, categorical sampling naturally converges to a greedy selection. This change ensures the implementation reflects this mathematical property without hitting hardware-level numerical exceptions.
  2. Infrastructure Compatibility: Many evaluation and training loops are designed around a single sampling configuration. Forcing a class change (to Greedy) for $\tau=0$ introduces unnecessary branching logic in downstream code.
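The continuity argument in point 1 can be checked numerically: as the temperature shrinks, the softmax mass concentrates on the argmax, so categorical sampling converges to greedy selection.

```python
import numpy as np


def softmax(logits: np.ndarray, temperature: float) -> np.ndarray:
    z = logits / temperature
    z = z - z.max()  # stabilize the exponentials
    p = np.exp(z)
    return p / p.sum()


logits = np.array([1.0, 3.0, 2.0])
for t in (1.0, 0.1, 0.01):
    # With shrinking temperature, nearly all probability mass moves to
    # the largest logit (index 1), i.e. greedy behavior in the limit.
    print(t, softmax(logits, t).round(4))
```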

Verification

  • Numerical Stability: Verified that $\tau=0.0$ no longer produces NaN logits or XLA errors.
  • Behavioral Consistency: Confirmed that low-temperature sampling correctly prioritizes the most probable tokens (greedy behavior).
  • Cross-Backend Testing: Verified consistent output across CPU, GPU, and TPU.
  • Unit Testing: Updated _sampling_test.py with cases for near-zero and zero temperature.
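An illustrative test in the spirit of those added to _sampling_test.py (the helper name and threshold here are assumptions, mirroring the guard above):

```python
import numpy as np


def scale_logits(logits: np.ndarray, temperature: float,
                 threshold: float = 1e-6) -> np.ndarray:
    """The guard under test: skip the division below the threshold."""
    return logits if temperature < threshold else logits / temperature


def test_zero_and_near_zero_temperature_stay_finite():
    logits = np.array([0.5, -1.0, 2.0])
    for t in (0.0, 1e-9, 1e-6, 0.5, 1.0):
        assert np.all(np.isfinite(scale_logits(logits, t)))
```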

Closes #562

@google-cla

google-cla bot commented Feb 18, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@ChiragTrivedi06 ChiragTrivedi06 force-pushed the bugfix/zero-division-error branch from 76b5cfe to f8dc797 Compare February 18, 2026 12:02
@ChiragTrivedi06 ChiragTrivedi06 force-pushed the bugfix/zero-division-error branch from f8dc797 to 4ea7dcd Compare February 18, 2026 12:04
