From 707f7eb181e1191a64268be902407dacd5e605c7 Mon Sep 17 00:00:00 2001 From: Silviu Oprea Date: Fri, 1 Oct 2021 18:36:57 +0100 Subject: [PATCH] Bart: check if decoder_inputs_embeds is set (#13800) In BartForConditionalGeneration.forward, if labels are provided, decoder_input_ids are set to the labels shifted to the right. This is problematic: if decoder_inputs_embeds is also set, the call to self.model, which eventually gets to BartDecoder.forward, will raise an error. The fix is quite simple, similar to what is there already in BartModel.forward. Namely, we should not compute decoder_input_ids if decoder_inputs_embeds is provided. Co-authored-by: Silviu Vlad Oprea --- src/transformers/models/bart/modeling_bart.py | 2 +- .../models/bigbird_pegasus/modeling_bigbird_pegasus.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 08315557d20941..318b084b715bfe 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1291,7 +1291,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: - if decoder_input_ids is None: + if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id ) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 684e14af8eaec6..736f6126024cae 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -2501,7 +2501,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: - if decoder_input_ids is None: + if decoder_input_ids is None 
and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id )