diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py
index 4e76725947445..20b49764fcb77 100644
--- a/src/transformers/generation_tf_utils.py
+++ b/src/transformers/generation_tf_utils.py
@@ -284,7 +284,7 @@ def generate(
             pad_token_id = eos_token_id

         # current position and vocab size
-        cur_len = shape_list(input_ids)[1]
+        cur_len = shape_list(input_ids)[1]  # unused
         vocab_size = self.config.vocab_size

         # set effective batch size and effective batch multiplier according to do_sample
@@ -366,10 +366,8 @@ def generate(
                 repetition_penalty=repetition_penalty,
                 no_repeat_ngram_size=no_repeat_ngram_size,
                 bad_words_ids=bad_words_ids,
-                bos_token_id=bos_token_id,
                 pad_token_id=pad_token_id,
                 eos_token_id=eos_token_id,
-                decoder_start_token_id=decoder_start_token_id,
                 batch_size=effective_batch_size,
                 num_return_sequences=num_return_sequences,
                 length_penalty=length_penalty,
@@ -392,10 +390,8 @@ def generate(
                 repetition_penalty=repetition_penalty,
                 no_repeat_ngram_size=no_repeat_ngram_size,
                 bad_words_ids=bad_words_ids,
-                bos_token_id=bos_token_id,
                 pad_token_id=pad_token_id,
                 eos_token_id=eos_token_id,
-                decoder_start_token_id=decoder_start_token_id,
                 batch_size=effective_batch_size,
                 vocab_size=vocab_size,
                 encoder_outputs=encoder_outputs,
@@ -418,10 +414,8 @@ def _generate_no_beam_search(
         repetition_penalty,
         no_repeat_ngram_size,
         bad_words_ids,
-        bos_token_id,
         pad_token_id,
         eos_token_id,
-        decoder_start_token_id,
         batch_size,
         vocab_size,
         encoder_outputs,
@@ -582,9 +576,7 @@ def _generate_beam_search(
         repetition_penalty,
         no_repeat_ngram_size,
         bad_words_ids,
-        bos_token_id,
         pad_token_id,
-        decoder_start_token_id,
         eos_token_id,
         batch_size,
         num_return_sequences,
@@ -616,6 +608,7 @@ def _generate_beam_search(
         # cache compute states
         past = encoder_outputs
+        # to stay similar to torch : past = (encoder_outputs, None) if encoder_outputs is not None else None

         # done sentences
         done = [False for _ in range(batch_size)]
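
For context, a minimal usage sketch of the public API that sits on top of these helpers. The model name and prompt below are purely illustrative and not part of this change. The generate() signature itself is not modified by this diff; bos_token_id and decoder_start_token_id are only dropped from the internal _generate_no_beam_search / _generate_beam_search helpers, where they appear to be unused:

    from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

    # Hypothetical example: "gpt2" and the prompt are placeholders.
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = TFGPT2LMHeadModel.from_pretrained("gpt2")

    input_ids = tokenizer.encode("Hello, my dog is", return_tensors="tf")

    # Callers can still pass bos_token_id / decoder_start_token_id to generate();
    # this cleanup only stops forwarding them to the internal search helpers.
    output_ids = model.generate(input_ids, max_length=20, num_beams=2)
    print(tokenizer.decode(output_ids[0].numpy().tolist(), skip_special_tokens=True))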