tf generation utils: remove unused kwargs (#6591)
sshleifer committed Aug 19, 2020
commit 9a86321 · 1 parent: 2a7402c
Showing 1 changed file with 2 additions and 9 deletions.
src/transformers/generation_tf_utils.py (11 changes: 2 additions & 9 deletions)
@@ -284,7 +284,7 @@ def generate(
             pad_token_id = eos_token_id
 
         # current position and vocab size
-        cur_len = shape_list(input_ids)[1]
+        cur_len = shape_list(input_ids)[1]  # unused
         vocab_size = self.config.vocab_size
 
         # set effective batch size and effective batch multiplier according to do_sample
@@ -366,10 +366,8 @@ def generate(
                 repetition_penalty=repetition_penalty,
                 no_repeat_ngram_size=no_repeat_ngram_size,
                 bad_words_ids=bad_words_ids,
-                bos_token_id=bos_token_id,
                 pad_token_id=pad_token_id,
                 eos_token_id=eos_token_id,
-                decoder_start_token_id=decoder_start_token_id,
                 batch_size=effective_batch_size,
                 num_return_sequences=num_return_sequences,
                 length_penalty=length_penalty,
@@ -392,10 +390,8 @@ def generate(
                 repetition_penalty=repetition_penalty,
                 no_repeat_ngram_size=no_repeat_ngram_size,
                 bad_words_ids=bad_words_ids,
-                bos_token_id=bos_token_id,
                 pad_token_id=pad_token_id,
                 eos_token_id=eos_token_id,
-                decoder_start_token_id=decoder_start_token_id,
                 batch_size=effective_batch_size,
                 vocab_size=vocab_size,
                 encoder_outputs=encoder_outputs,
@@ -418,10 +414,8 @@ def _generate_no_beam_search(
         repetition_penalty,
         no_repeat_ngram_size,
         bad_words_ids,
-        bos_token_id,
         pad_token_id,
         eos_token_id,
-        decoder_start_token_id,
         batch_size,
         vocab_size,
         encoder_outputs,
@@ -582,9 +576,7 @@ def _generate_beam_search(
         repetition_penalty,
         no_repeat_ngram_size,
         bad_words_ids,
-        bos_token_id,
         pad_token_id,
-        decoder_start_token_id,
         eos_token_id,
         batch_size,
         num_return_sequences,
@@ -616,6 +608,7 @@ def _generate_beam_search(
 
         # cache compute states
         past = encoder_outputs
+        # to stay similar to torch : past = (encoder_outputs, None) if encoder_outputs is not None else None
 
         # done sentences
        done = [False for _ in range(batch_size)]
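The change itself is mechanical: bos_token_id and decoder_start_token_id were accepted by _generate_no_beam_search and _generate_beam_search but never read in their bodies, so the commit drops them from both signatures and from the two call sites in generate, and also flags cur_len as unused. A minimal sketch of the pattern, with hypothetical simplified names rather than the actual transformers helpers (which take many more arguments):

    # Illustrative stand-ins only, not the library's API.

    # Before: bos_token_id and decoder_start_token_id are accepted but never read.
    def decode_before(input_ids, bos_token_id, pad_token_id, eos_token_id, decoder_start_token_id):
        # The body consults only pad_token_id and eos_token_id.
        return [t for t in input_ids if t not in (pad_token_id, eos_token_id)]

    # After: the signature keeps only what the body uses; callers stop passing the rest.
    def decode_after(input_ids, pad_token_id, eos_token_id):
        return [t for t in input_ids if t not in (pad_token_id, eos_token_id)]

    assert decode_before([0, 5, 1, 2], 0, 1, 2, 0) == decode_after([0, 5, 1, 2], pad_token_id=1, eos_token_id=2)

Removing the parameters and the call-site arguments in the same commit keeps everything consistent: an out-of-date caller now fails fast with a TypeError instead of silently passing values that are ignored.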
