From 18e5bdbec5b12ad395bfb2a30223c78d74a9c158 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Tue, 24 Dec 2019 17:18:05 +0100 Subject: [PATCH] fix repetition penalty error in modeling_utils.py --- src/transformers/modeling_utils.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 8413aad595d7a..2698816c66417 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -728,7 +728,11 @@ def _generate_no_beam_search( if repetition_penalty != 1.0: for i in range(batch_size): for previous_tokens in set(input_ids[i].tolist()): - next_token_logits[i, previous_tokens] /= repetition_penalty + # if score < 0 then repetition penalty has to multiplied to reduce the previous token probability + if next_token_logits[i, previous_tokens] < 0: + next_token_logits[i, previous_tokens] *= repetition_penalty + else: + next_token_logits[i, previous_tokens] /= repetition_penalty if do_sample: # Temperature (higher temperature => more likely to sample low probability tokens) @@ -807,7 +811,11 @@ def _generate_beam_search( if repetition_penalty != 1.0: for i in range(batch_size * num_beams): for previous_tokens in set(input_ids[i].tolist()): - scores[i, previous_tokens] /= repetition_penalty + # if score < 0 then repetition penalty has to multiplied to reduce the previous token probability + if scores[i, previous_tokens] < 0: + scores[i, previous_tokens] *= repetition_penalty + else: + scores[i, previous_tokens] /= repetition_penalty if do_sample: # Temperature (higher temperature => more likely to sample low probability tokens)