Skip to content

Commit

Permalink
fix off by one
Browse files Browse the repository at this point in the history
  • Loading branch information
kingoflolz committed Dec 4, 2021
1 parent 628f44d commit d2c2f59
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion mesh_transformer/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def nucleaus_filter(logits, top_p=0.9, top_k=None):
indices_range = jnp.arange(len(sorted_indices[0]))
indices_range = jnp.stack([indices_range] * len(sorted_indices), axis=0)

sorted_indices_to_remove = jnp.where(indices_range > top_k, sorted_indices, 0)
sorted_indices_to_remove = jnp.where(indices_range >= top_k, sorted_indices, 0)

_, indices_to_remove = jax.lax.sort_key_val(sorted_indices, sorted_indices_to_remove)

Expand Down

0 comments on commit d2c2f59

Please sign in to comment.