Skip to content

Commit

Permalink
Generate: PT's top_p enforces min_tokens_to_keep when it is 1
Browse files Browse the repository at this point in the history
  • Loading branch information
gante committed Jun 9, 2023
1 parent 03585f3 commit be10092
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
2 changes: 2 additions & 0 deletions src/transformers/generation/flax_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 0):
raise ValueError(f"`min_tokens_to_keep` has to be a non-negative integer, but is {min_tokens_to_keep}")

self.top_p = top_p
self.filter_value = filter_value
Expand Down
7 changes: 4 additions & 3 deletions src/transformers/generation/logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens
top_p = float(top_p)
if top_p < 0 or top_p > 1.0:
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 0):
raise ValueError(f"`min_tokens_to_keep` has to be a non-negative integer, but is {min_tokens_to_keep}")

self.top_p = top_p
self.filter_value = filter_value
Expand All @@ -266,9 +268,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to

# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
-        if self.min_tokens_to_keep > 1:
-            # Keep at least min_tokens_to_keep
-            sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
+        # Keep at least min_tokens_to_keep
+        sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0

# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/generation/tf_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ class TFTopPLogitsWarper(TFLogitsWarper):
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 0):
raise ValueError(f"`min_tokens_to_keep` has to be a non-negative integer, but is {min_tokens_to_keep}")

self.top_p = top_p
self.filter_value = filter_value
Expand Down

0 comments on commit be10092

Please sign in to comment.