[Generation] Fix bug for manual decoder_input_ids + warning message (#9472)

* up

* improve style
patrickvonplaten committed Jan 8, 2021
1 parent 9e1ea84 commit 79bbcc5
Showing 1 changed file with 14 additions and 8 deletions.
src/transformers/generation_utils.py (14 additions, 8 deletions)
@@ -379,12 +379,8 @@ def _prepare_encoder_decoder_kwargs_for_generation(
         return model_kwargs
 
     def _prepare_decoder_input_ids_for_generation(
-        self, input_ids: torch.LongTensor, decoder_start_token_id: int = None, bos_token_id: int = None, **model_kwargs
+        self, input_ids: torch.LongTensor, decoder_start_token_id: int = None, bos_token_id: int = None
     ) -> torch.LongTensor:
-
-        if "decoder_input_ids" in model_kwargs:
-            return model_kwargs["decoder_input_ids"]
-
         decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
         decoder_input_ids = (
             torch.ones((input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device)
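
Note: with `**model_kwargs` and the early return removed, this helper now only ever builds the default start tensor. A minimal standalone sketch of that behavior (hypothetical function name, lifted outside the class):

    import torch

    def prepare_decoder_input_ids(input_ids: torch.LongTensor, decoder_start_token_id: int) -> torch.LongTensor:
        # One decoder start token per batch row: shape (batch_size, 1),
        # matching the dtype and device of the encoder input_ids.
        return (
            torch.ones((input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device)
            * decoder_start_token_id
        )

    input_ids = torch.tensor([[5, 6, 7], [8, 9, 10]])
    print(prepare_decoder_input_ids(input_ids, decoder_start_token_id=0))
    # tensor([[0],
    #         [0]])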
@@ -837,13 +833,23 @@ def generate(
             model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
 
             # set input_ids as decoder_input_ids
-            input_ids = self._prepare_decoder_input_ids_for_generation(
-                input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id, **model_kwargs
-            )
+            if "decoder_input_ids" in model_kwargs:
+                input_ids = model_kwargs.pop("decoder_input_ids")
+            else:
+                input_ids = self._prepare_decoder_input_ids_for_generation(
+                    input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id
+                )
 
             if "encoder_outputs" not in model_kwargs or not isinstance(model_kwargs["encoder_outputs"], ModelOutput):
                 raise ValueError("Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`.")
 
+        if input_ids.shape[-1] >= max_length:
+            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+            logger.warning(
+                f"Input length of {input_ids_string} is {input_ids.shape[-1]}, but ``max_length`` is set to {max_length}."
+                "This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``."
+            )
+
         # determine generation mode
         is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False
         is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True
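
With this fix, a caller-supplied `decoder_input_ids` is popped from `model_kwargs` and used as the decoder prompt, rather than being consumed and then forwarded a second time through `model_kwargs`. A usage sketch of what the fix enables (assumes a seq2seq checkpoint such as `t5-small`; the prompt text is illustrative):

    import torch
    from transformers import T5ForConditionalGeneration, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    input_ids = tokenizer("translate English to German: How are you?", return_tensors="pt").input_ids
    # Manually seed the decoder with its start token (T5 uses the pad token).
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
    outputs = model.generate(input_ids, decoder_input_ids=decoder_input_ids)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))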
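The new warning covers prompts that are already `max_length` tokens or longer, a case that previously proceeded silently and returned the prompt unmodified. A small repro sketch (hypothetical prompt; the exact token count depends on the tokenizer):

    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    # Prompt longer than max_length below, so generation has no room to add tokens.
    input_ids = tokenizer("one two three four five six seven eight", return_tensors="pt").input_ids
    model.generate(input_ids, max_length=5)
    # Expected log: "Input length of input_ids is <n>, but ``max_length`` is set to 5. ..."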
