Skip to content

Commit

Permalink
handle missing pad token
Browse files Browse the repository at this point in the history
  • Loading branch information
gante committed Jun 21, 2022
1 parent abb2c13 commit 6dd8625
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/transformers/generation_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1991,7 +1991,7 @@ def greedy_search(
batch_size, cur_len = input_ids.shape

# initialize `generated` (`input_ids` padded with `pad_token_id`), `finished_sequences`
input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * pad_token_id
input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
generated = tf.concat([input_ids, input_ids_padding], axis=-1)
finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)

Expand Down Expand Up @@ -2249,7 +2249,7 @@ def sample(
batch_size, cur_len = input_ids.shape

# initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * pad_token_id
input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
generated = tf.concat([input_ids, input_ids_padding], axis=-1)
finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)

Expand Down Expand Up @@ -2571,9 +2571,11 @@ def gather_fn(tensor):
batch_size, num_beams, cur_len = input_ids.shape

# per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * pad_token_id
input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * (
pad_token_id or 0
)
running_sequences = tf.concat([input_ids, input_ids_padding], axis=-1)
sequences = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * pad_token_id
sequences = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * (pad_token_id or 0)

# per batch,beam-item state bit indicating if sentence has finished.
is_sent_finished = tf.zeros((batch_size, num_beams), dtype=tf.bool)
Expand Down

0 comments on commit 6dd8625

Please sign in to comment.