Skip to content

Commit

Permalink
TF: Merge PT and TF behavior for Bart when no decoder_input_ids are p…
Browse files Browse the repository at this point in the history
…assed (huggingface#17593)

* Merge PT and TF behavior
  • Loading branch information
gante authored and elusenji committed Jun 12, 2022
1 parent 47723e6 commit 3f6f6bb
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions src/transformers/models/bart/modeling_tf_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,14 +1073,16 @@ def call(
**kwargs
) -> Union[TFSeq2SeqModelOutput, Tuple[tf.Tensor]]:

# different to other models, Bart automatically creates decoder_input_ids from
# input_ids if no decoder_input_ids are provided
if decoder_input_ids is None and decoder_inputs_embeds is None:
use_cache = False

output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if input_ids is None:
raise ValueError(
"If no `decoder_input_ids` or `decoder_inputs_embeds` are "
"passed, `input_ids` cannot be `None`. Please pass either "
"`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
)

if decoder_input_ids is None and input_ids is not None:
decoder_input_ids = shift_tokens_right(
input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
)
Expand Down

0 comments on commit 3f6f6bb

Please sign in to comment.