diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index 3a37c86238d7db..6d3d105cced62e 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -1991,7 +1991,7 @@ def greedy_search( batch_size, cur_len = input_ids.shape # initialize `generated` (`input_ids` padded with `pad_token_id`), `finished_sequences` - input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * pad_token_id + input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0) generated = tf.concat([input_ids, input_ids_padding], axis=-1) finished_sequences = tf.zeros((batch_size,), dtype=tf.bool) @@ -2249,7 +2249,7 @@ def sample( batch_size, cur_len = input_ids.shape # initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences` - input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * pad_token_id + input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0) generated = tf.concat([input_ids, input_ids_padding], axis=-1) finished_sequences = tf.zeros((batch_size,), dtype=tf.bool) @@ -2571,9 +2571,11 @@ def gather_fn(tensor): batch_size, num_beams, cur_len = input_ids.shape # per batch, beam-item holding current token in loop, pre-populated with `pad_token_id` - input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * pad_token_id + input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * ( + pad_token_id or 0 + ) running_sequences = tf.concat([input_ids, input_ids_padding], axis=-1) - sequences = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * pad_token_id + sequences = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * (pad_token_id or 0) # per batch,beam-item state bit indicating if sentence has finished. is_sent_finished = tf.zeros((batch_size, num_beams), dtype=tf.bool)