From 3f6f6bb31f25b4d6c40b4eb7d6ac766471097fd2 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Wed, 8 Jun 2022 17:42:23 +0100
Subject: [PATCH] TF: Merge PT and TF behavior for Bart when no decoder_input_ids are passed (#17593)

* Merge PT and TF behavior
---
 src/transformers/models/bart/modeling_tf_bart.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py
index 21523e2f817a2..8f8586c791314 100644
--- a/src/transformers/models/bart/modeling_tf_bart.py
+++ b/src/transformers/models/bart/modeling_tf_bart.py
@@ -1073,14 +1073,16 @@ def call(
         **kwargs
     ) -> Union[TFSeq2SeqModelOutput, Tuple[tf.Tensor]]:
 
+        # different to other models, Bart automatically creates decoder_input_ids from
+        # input_ids if no decoder_input_ids are provided
         if decoder_input_ids is None and decoder_inputs_embeds is None:
-            use_cache = False
-
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
+            if input_ids is None:
+                raise ValueError(
+                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
+                    "passed, `input_ids` cannot be `None`. Please pass either "
+                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
+                )
 
-        if decoder_input_ids is None and input_ids is not None:
             decoder_input_ids = shift_tokens_right(
                 input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
             )