diff --git a/docs/source/model_doc/encoderdecoder.rst b/docs/source/model_doc/encoderdecoder.rst index a63b6044a2c8a4..f3105d9131c512 100644 --- a/docs/source/model_doc/encoderdecoder.rst +++ b/docs/source/model_doc/encoderdecoder.rst @@ -1,13 +1,12 @@ Encoder Decoder Models ------------------------ -The :class:`~transformers.EncoderDecoderModel` can be used to initialize a sequence-to-sequence model with any pre-trained autoencoding model as the encoder and any pre-trained autoregressive model as the decoder. +This class can wrap an encoder model, such as ``BertModel`` and a decoder modeling with a language modeling head, such as ``BertForMaskedLM`` into a encoder-decoder model. -The effectiveness of initializing sequence-to-sequence models with pre-trained checkpoints for sequence generation tasks was shown in `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks `__ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. +The ``EncoderDecoderModel`` class allows to instantiate a encoder decoder model using the ``from_encoder_decoder_pretrain`` class method taking a pretrained encoder and pretrained decoder model as an input. +The ``EncoderDecoderModel`` is saved using the standard ``save_pretrained()`` method and can also again be loaded using the standard ``from_pretrained()`` method. -After such an :class:`~transformers.EncoderDecoderModel` has been trained / fine-tuned, it can be saved / loaded just like any other models (see Examples for more information). - -An application of this architecture could be to leverage two pre-trained :obj:`transformers.BertModel` models as the encoder and decoder for a summarization model as was shown in: `Text Summarization with Pretrained Encoders `_ by Yang Liu and Mirella Lapata. +An application of this architecture could be *summarization* using two pretrained Bert models as is shown in the paper: `Text Summarization with Pretrained Encoders `_ by Yang Liu and Mirella Lapata. ``EncoderDecoderConfig`` diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 638bb3b12e6dc4..302fad2fc49ef0 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -20,7 +20,6 @@ from torch import Tensor from torch.nn import functional as F -from .file_utils import ModelOutput from .utils import logging @@ -47,6 +46,14 @@ def adjust_logits_during_generation(self, logits, **kwargs): """ return logits + def _use_cache(self, outputs, use_cache): + """During generation, decide whether to pass the `past` variable to the next forward pass.""" + if len(outputs) <= 1 or use_cache is False: + return False + if hasattr(self.config, "mem_len") and self.config.mem_len == 0: + return False + return True + def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty): """ Enforce the repetition penalty (from the `CTRL paper `__). @@ -130,7 +137,7 @@ def generate( attention_mask: Optional[torch.LongTensor] = None, decoder_start_token_id: Optional[int] = None, use_cache: Optional[bool] = None, - **model_kwargs + **model_specific_kwargs ) -> torch.LongTensor: r""" Generates sequences for models with a language modeling head. The method currently supports greedy decoding, @@ -201,7 +208,7 @@ def generate( use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding. - model_kwargs: + model_specific_kwargs: Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. Return: @@ -393,7 +400,7 @@ def generate( # get encoder and store encoder outputs encoder = self.get_encoder() - encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True) + encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask) # Expand input ids if num_beams > 1 or num_return_sequences > 1 if num_return_sequences > 1 or num_beams > 1: @@ -421,8 +428,8 @@ def generate( cur_len = 1 assert ( - batch_size == encoder_outputs.last_hidden_state.shape[0] - ), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} " + batch_size == encoder_outputs[0].shape[0] + ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} " # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1) expanded_batch_idxs = ( @@ -432,16 +439,11 @@ def generate( .view(-1) .to(input_ids.device) ) - # expand encoder_outputs - encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( - 0, expanded_batch_idxs - ) - - # save encoder_outputs in `model_kwargs` - model_kwargs["encoder_outputs"] = encoder_outputs + encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:]) else: + encoder_outputs = None cur_len = input_ids.shape[-1] assert ( @@ -469,9 +471,10 @@ def generate( length_penalty=length_penalty, num_beams=num_beams, vocab_size=vocab_size, + encoder_outputs=encoder_outputs, attention_mask=attention_mask, use_cache=use_cache, - model_kwargs=model_kwargs, + model_specific_kwargs=model_specific_kwargs, ) else: output = self._generate_no_beam_search( @@ -489,9 +492,10 @@ def generate( pad_token_id=pad_token_id, eos_token_id=eos_token_id, batch_size=effective_batch_size, + encoder_outputs=encoder_outputs, attention_mask=attention_mask, use_cache=use_cache, - model_kwargs=model_kwargs, + model_specific_kwargs=model_specific_kwargs, ) return output @@ -512,9 +516,10 @@ def _generate_no_beam_search( pad_token_id, eos_token_id, batch_size, + encoder_outputs, attention_mask, use_cache, - model_kwargs, + model_specific_kwargs, ): """Generate sequences for each example without beam search (num_beams == 1). All returned sequence are generated independantly. @@ -523,14 +528,15 @@ def _generate_no_beam_search( unfinished_sents = input_ids.new(batch_size).fill_(1) sent_lengths = input_ids.new(batch_size).fill_(max_length) - past = None + past = (encoder_outputs, None) if encoder_outputs is not None else None + while cur_len < max_length: model_inputs = self.prepare_inputs_for_generation( - input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs + input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs ) - outputs = self(**model_inputs, return_dict=True) - next_token_logits = outputs.logits[:, -1, :] + outputs = self(**model_inputs) + next_token_logits = outputs[0][:, -1, :] scores = self.postprocess_next_token_scores( scores=next_token_logits, @@ -547,10 +553,8 @@ def _generate_no_beam_search( ) # if model has past, then set the past variable to speed up decoding - if "past_key_values" in outputs: - past = outputs.past_key_values - elif "mems" in outputs: - past = outputs.mems + if self._use_cache(outputs, use_cache): + past = outputs[1] if do_sample: # Temperature (higher temperature => more likely to sample low probability tokens) @@ -617,9 +621,10 @@ def _generate_beam_search( length_penalty, num_beams, vocab_size, + encoder_outputs, attention_mask, use_cache, - model_kwargs, + model_specific_kwargs, ): """Generate sequences for each example with beam search.""" @@ -638,24 +643,21 @@ def _generate_beam_search( beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,) # cache compute states - past = None + past = (encoder_outputs, None) if encoder_outputs is not None else None # done sentences done = [False for _ in range(batch_size)] while cur_len < max_length: model_inputs = self.prepare_inputs_for_generation( - input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs + input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs ) - outputs = self(**model_inputs, return_dict=True) # (batch_size * num_beams, cur_len, vocab_size) - next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size) + outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size) + next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size) # if model has past, then set the past variable to speed up decoding - if "past_key_values" in outputs: - past = outputs.past_key_values - elif "mems" in outputs: - past = outputs.mems - + if self._use_cache(outputs, use_cache): + past = outputs[1] if self.config.is_encoder_decoder and do_sample is False: # TODO (PVP) still a bit hacky here - there might be a better solution next_token_logits = self.adjust_logits_during_generation( diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index 4122d3aa9d563d..45b40554cd9f5d 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -111,15 +111,15 @@ Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default. If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify. See diagram 1 in the paper for more info on the default strategy - past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains pre-computed key and value hidden-states of the attention blocks. Can be used to speed up decoding. - If ``past_key_values`` are used, the user can optionally input only the last + If ``decoder_past_key_value_states`` are used, the user can optionally input only the last ``decoder_input_ids`` (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): - If `use_cache` is True, ``past_key_values`` are returned and can be used to speed up decoding (see - ``past_key_values``). + If `use_cache` is True, ``decoder_past_key_values`` are returned and can be used to speed up decoding (see + ``decoder_past_key_values``). output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`): If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail. output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`): @@ -502,7 +502,7 @@ def forward( encoder_padding_mask, decoder_padding_mask, decoder_causal_mask, - past_key_values=None, + decoder_past_key_values=None, use_cache=False, output_attentions=False, output_hidden_states=False, @@ -519,7 +519,7 @@ def forward( encoder_hidden_states: output from the encoder, used for encoder-side attention encoder_padding_mask: for ignoring pad tokens - past_key_values (dict or None): dictionary used for storing state during generation + decoder_past_key_values (dict or None): dictionary used for storing state during generation Returns: BaseModelOutputWithPast or tuple: @@ -530,16 +530,10 @@ def forward( """ if "decoder_cached_states" in unused: warnings.warn( - "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", + "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.", FutureWarning, ) - past_key_values = unused.pop("decoder_cached_states") - if "decoder_past_key_values" in unused: - warnings.warn( - "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", - FutureWarning, - ) - past_key_values = unused.pop("decoder_past_key_values") + decoder_past_key_values = unused.pop("decoder_cached_states") # check attention mask and invert if encoder_padding_mask is not None: @@ -574,7 +568,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): continue - layer_state = past_key_values[idx] if past_key_values is not None else None + layer_state = decoder_past_key_values[idx] if decoder_past_key_values is not None else None x, layer_self_attn, layer_past = decoder_layer( x, @@ -600,7 +594,10 @@ def forward( x = x.transpose(0, 1) encoder_hidden_states = encoder_hidden_states.transpose(0, 1) - next_cache = next_decoder_cache if use_cache else None + if use_cache: + next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache) + else: + next_cache = None if not return_dict: return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -872,19 +869,13 @@ def forward( decoder_input_ids=None, encoder_outputs: Optional[Tuple] = None, decoder_attention_mask=None, - past_key_values=None, + decoder_past_key_values=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, **kwargs, ): - if "decoder_past_key_values" in kwargs: - warnings.warn( - "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", - FutureWarning, - ) - past_key_values = kwargs.pop("decoder_past_key_values") if decoder_input_ids is None: use_cache = False @@ -933,7 +924,7 @@ def forward( attention_mask, decoder_padding_mask, decoder_causal_mask=causal_mask, - past_key_values=past_key_values, + decoder_past_key_values=decoder_past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -945,7 +936,7 @@ def forward( return Seq2SeqModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, + decoder_past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, @@ -1003,7 +994,7 @@ def forward( encoder_outputs=None, decoder_input_ids=None, decoder_attention_mask=None, - past_key_values=None, + decoder_past_key_values=None, labels=None, use_cache=None, output_attentions=None, @@ -1046,16 +1037,10 @@ def forward( labels = unused.pop("lm_labels") if "decoder_cached_states" in unused: warnings.warn( - "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", + "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.", FutureWarning, ) - past_key_values = unused.pop("decoder_cached_states") - if "decoder_past_key_values" in unused: - warnings.warn( - "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", - FutureWarning, - ) - past_key_values = unused.pop("decoder_past_key_values") + decoder_past_key_values = unused.pop("decoder_cached_states") return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: @@ -1069,7 +1054,7 @@ def forward( decoder_input_ids=decoder_input_ids, encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, - past_key_values=past_key_values, + decoder_past_key_values=decoder_past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -1090,7 +1075,7 @@ def forward( return Seq2SeqLMOutput( loss=masked_lm_loss, logits=lm_logits, - past_key_values=outputs.past_key_values, + decoder_past_key_values=outputs.decoder_past_key_values, decoder_hidden_states=outputs.decoder_hidden_states, decoder_attentions=outputs.decoder_attentions, encoder_last_hidden_state=outputs.encoder_last_hidden_state, @@ -1098,13 +1083,14 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs - ): + def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs): + assert past is not None, "past has to be defined for encoder_outputs" + + encoder_outputs, decoder_past_key_values = past return { "input_ids": None, # encoder_outputs is defined. input_ids not needed "encoder_outputs": encoder_outputs, - "past_key_values": past, + "decoder_past_key_values": decoder_past_key_values, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) @@ -1123,14 +1109,20 @@ def _force_token_ids_generation(self, scores, token_id) -> None: @staticmethod def _reorder_cache(past, beam_idx): + ((enc_out, enc_mask), decoder_past_key_values) = past reordered_past = [] - for layer_past in past: + for layer_past in decoder_past_key_values: # get the correct batch idx from decoder layer's batch dim for cross and self-attn layer_past_new = { attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items() } reordered_past.append(layer_past_new) - return reordered_past + + new_enc_out = enc_out if enc_out is None else enc_out.index_select(0, beam_idx) + new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(0, beam_idx) + + past = ((new_enc_out, new_enc_mask), reordered_past) + return past def get_encoder(self): return self.model.encoder @@ -1216,7 +1208,7 @@ def forward( return Seq2SeqSequenceClassifierOutput( loss=loss, logits=logits, - past_key_values=outputs.past_key_values, + decoder_past_key_values=outputs.decoder_past_key_values, decoder_hidden_states=outputs.decoder_hidden_states, decoder_attentions=outputs.decoder_attentions, encoder_last_hidden_state=outputs.encoder_last_hidden_state, @@ -1324,7 +1316,7 @@ def forward( loss=total_loss, start_logits=start_logits, end_logits=end_logits, - past_key_values=outputs.past_key_values, + decoder_past_key_values=outputs.decoder_past_key_values, decoder_hidden_states=outputs.decoder_hidden_states, decoder_attentions=outputs.decoder_attentions, encoder_last_hidden_state=outputs.encoder_last_hidden_state, diff --git a/src/transformers/modeling_encoder_decoder.py b/src/transformers/modeling_encoder_decoder.py index 343c65321ab96e..b737fa779133ce 100644 --- a/src/transformers/modeling_encoder_decoder.py +++ b/src/transformers/modeling_encoder_decoder.py @@ -19,79 +19,13 @@ from .configuration_encoder_decoder import EncoderDecoderConfig from .configuration_utils import PretrainedConfig -from .file_utils import add_start_docstrings, add_start_docstrings_to_callable, replace_return_docstrings -from .modeling_outputs import Seq2SeqLMOutput from .modeling_utils import PreTrainedModel from .utils import logging logger = logging.get_logger(__name__) -_CONFIG_FOR_DOC = "EncoderDecoderConfig" - -ENCODER_DECODER_START_DOCSTRING = r""" - This class can be used to inialize a sequence-to-sequnece model with any pretrained autoencoding model as the encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via :meth:`~transformers.AutoModel.from_pretrained` function and the decoder is loaded via :meth:`~transformers.AutoModelForCausalLM.from_pretrained` function. - Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream generative task, *i.e.* summarization. - - The effectiveness of initializing sequence-to-sequence models with pre-trained checkpoints for sequence generation tasks was shown in `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks `__ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. - Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. - - After such an Encoder Decoder model has been trained / fine-tuned, it can be saved / loaded just like any other models (see Examples for more information). - - This model is a PyTorch `torch.nn.Module `__ sub-class. Use it as a - regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. - - Parameters: - config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the configuration. - Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. -""" - -ENCODER_DECODER_INPUTS_DOCSTRING = r""" - Args: - input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary for the encoder. - Indices can be obtained using :class:`~transformers.PretrainedTokenizer`. - See :meth:`~transformers.PreTrainedTokenizer.encode` and - :meth:`~transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. - inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): - Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert :obj:`input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): - Mask to avoid performing attention on padding token indices for the encoder. - Mask values selected in ``[0, 1]``: - ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. - encoder_outputs (:obj:`tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`): - This tuple must consist of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) - `last_hidden_state` (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`) is a tensor of hidden-states at the output of the last layer of the encoder. - Used in the cross-attention of the decoder. - decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`): - Provide for sequence to sequence training to the decoder. - Indices can be obtained using :class:`transformers.PretrainedTokenizer`. - See :func:`transformers.PreTrainedTokenizer.encode` and - :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. - decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`): - Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default. - decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): - Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): - Labels for computing the masked language modeling loss for the decoder. - Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) - Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels - in ``[0, ..., config.vocab_size]`` - return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`): - If set to ``True``, the model will return a :class:`~transformers.file_utils.Seq2SeqLMOutput` instead of a - plain tuple. - kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors: - - Without a prefix which will be input as ``**encoder_kwargs`` for the encoder forward function. - - With a `decoder_` prefix which will be input as ``**decoder_kwargs`` for the decoder forward function. -""" - - -@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING) + class EncoderDecoderModel(PreTrainedModel): r""" :class:`~transformers.EncoderDecoder` is a generic model class that will be @@ -272,8 +206,6 @@ def from_encoder_decoder_pretrained( config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) return cls(encoder=encoder, decoder=decoder, config=config) - @add_start_docstrings_to_callable(ENCODER_DECODER_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids=None, @@ -284,11 +216,47 @@ def forward( decoder_attention_mask=None, decoder_inputs_embeds=None, labels=None, - return_dict=None, **kwargs, ): - r""" - Returns: + + """ + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary for the encoder. + Indices can be obtained using :class:`transformers.PretrainedTokenizer`. + See :func:`transformers.PreTrainedTokenizer.encode` and + :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Mask to avoid performing attention on padding token indices for the encoder. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`): + Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`) + `last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder. + Used in the cross-attention of the decoder. + decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`): + Provide for sequence to sequence training to the decoder. + Indices can be obtained using :class:`transformers.PretrainedTokenizer`. + See :func:`transformers.PreTrainedTokenizer.encode` and + :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. + decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`): + Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default. + decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): + Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Labels for computing the masked language modeling loss for the decoder. + Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) + Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels + in ``[0, ..., config.vocab_size]`` + kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors: + - Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function. + - With a `decoder_` prefix which will be input as `**decoder_kwargs` for the decoder forward function. Examples:: @@ -296,25 +264,19 @@ def forward( >>> import torch >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') - >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints + >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert >>> # forward >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 >>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids) >>> # training - >>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids, return_dict=True) - >>> loss, logits = outputs.loss, outputs.logits - - >>> # save and load from pretrained - >>> model.save_pretrained("bert2bert") - >>> model = EncoderDecoderModel.from_pretrained("bert2bert") + >>> loss, outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids)[:2] >>> # generation >>> generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.pad_token_id) """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")} @@ -327,7 +289,7 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, - return_dict=return_dict, + return_dict=False, **kwargs_encoder, ) @@ -341,28 +303,23 @@ def forward( encoder_hidden_states=hidden_states, encoder_attention_mask=attention_mask, labels=labels, - return_dict=return_dict, + return_dict=False, **kwargs_decoder, ) # TODO(PVP): currently it is not possible to use `past` - if not return_dict: - return decoder_outputs + encoder_outputs - - return Seq2SeqLMOutput( - loss=decoder_outputs.loss, - logits=decoder_outputs.logits, - past_key_values=None, # TODO(PVP) - need to implement cache for BERT, etc... before this works - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - + # with the encoder/decoder framework -> should be implemented return decoder_outputs + encoder_outputs - def prepare_inputs_for_generation(self, input_ids, past, attention_mask, encoder_outputs, **kwargs): + def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs): + assert past is not None, "past has to be defined for encoder_outputs" + + # first step + if type(past) is tuple: + encoder_outputs, _ = past + else: + encoder_outputs = (past,) + decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids) decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None input_dict = { @@ -378,7 +335,7 @@ def prepare_inputs_for_generation(self, input_ids, past, attention_mask, encoder input_dict["decoder_use_cache"] = decoder_inputs["use_cache"] if "past_key_values" in decoder_inputs: - input_dict["past_key_values"] = decoder_inputs["past_key_values"] + input_dict["decoder_past_key_values"] = decoder_inputs["past_key_values"] return input_dict diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py index 727a3a87c3a87a..1d4ceb0e2f9a42 100644 --- a/src/transformers/modeling_gpt2.py +++ b/src/transformers/modeling_gpt2.py @@ -353,11 +353,11 @@ class GPT2DoubleHeadsModelOutput(ModelOutput): Base class for outputs of models predicting if two sentences are consecutive or not. Args: - loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided): + lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided): Language modeling loss. mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided): Multiple choice classification loss. - logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`): + lm_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`): Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). @@ -380,9 +380,9 @@ class GPT2DoubleHeadsModelOutput(ModelOutput): heads. """ - loss: Optional[torch.FloatTensor] = None + lm_loss: Optional[torch.FloatTensor] = None mc_loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None + lm_logits: torch.FloatTensor = None mc_logits: torch.FloatTensor = None past_key_values: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None @@ -777,17 +777,6 @@ def __init__(self, config): def get_output_embeddings(self): return self.lm_head - def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): - # only last token for inputs_ids if past is defined in kwargs - if past: - input_ids = input_ids[:, -1].unsqueeze(-1) - - return { - "input_ids": input_ids, - "past_key_values": past, - "use_cache": kwargs.get("use_cache"), - } - @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) def forward( @@ -904,9 +893,9 @@ def forward( return ((lm_loss,) + output) if lm_loss is not None else output return GPT2DoubleHeadsModelOutput( - loss=lm_loss, + lm_loss=lm_loss, mc_loss=mc_loss, - logits=lm_logits, + lm_logits=lm_logits, mc_logits=mc_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, diff --git a/src/transformers/modeling_openai.py b/src/transformers/modeling_openai.py index e62d13455d58fa..1920880b288f34 100644 --- a/src/transformers/modeling_openai.py +++ b/src/transformers/modeling_openai.py @@ -300,11 +300,11 @@ class OpenAIGPTDoubleHeadsModelOutput(ModelOutput): Base class for outputs of models predicting if two sentences are consecutive or not. Args: - loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided): + lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided): Language modeling loss. mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided): Multiple choice classification loss. - logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`): + lm_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`): Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). @@ -321,9 +321,9 @@ class OpenAIGPTDoubleHeadsModelOutput(ModelOutput): heads. """ - loss: Optional[torch.FloatTensor] = None + lm_loss: Optional[torch.FloatTensor] = None mc_loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None + lm_logits: torch.FloatTensor = None mc_logits: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None @@ -713,9 +713,9 @@ def forward( return ((lm_loss,) + output) if lm_loss is not None else output return OpenAIGPTDoubleHeadsModelOutput( - loss=lm_loss, + lm_loss=lm_loss, mc_loss=mc_loss, - logits=lm_logits, + lm_logits=lm_logits, mc_logits=mc_logits, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, diff --git a/src/transformers/modeling_outputs.py b/src/transformers/modeling_outputs.py index e6a4b0ed8be197..1c36dc2d81ac4a 100644 --- a/src/transformers/modeling_outputs.py +++ b/src/transformers/modeling_outputs.py @@ -109,13 +109,13 @@ class Seq2SeqModelOutput(ModelOutput): last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the decoder of the model. - If ``past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output. - past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + If ``decoder_past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output. + decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`). Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be - used (see ``past_key_values`` input) to speed up sequential decoding. + used (see ``decoder_past_key_values`` input) to speed up sequential decoding. decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. @@ -143,7 +143,7 @@ class Seq2SeqModelOutput(ModelOutput): """ last_hidden_state: torch.FloatTensor - past_key_values: Optional[List[torch.FloatTensor]] = None + decoder_past_key_values: Optional[List[torch.FloatTensor]] = None decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_last_hidden_state: Optional[torch.FloatTensor] = None @@ -255,12 +255,12 @@ class Seq2SeqLMOutput(ModelOutput): Languaged modeling loss. logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`). Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be - used (see ``past_key_values`` input) to speed up sequential decoding. + used (see ``decoder_past_key_values`` input) to speed up sequential decoding. decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. @@ -289,7 +289,7 @@ class Seq2SeqLMOutput(ModelOutput): loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None - past_key_values: Optional[List[torch.FloatTensor]] = None + decoder_past_key_values: Optional[List[torch.FloatTensor]] = None decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_last_hidden_state: Optional[torch.FloatTensor] = None @@ -365,12 +365,12 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput): Classification (or regression if config.num_labels==1) loss. logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`): Classification (or regression if config.num_labels==1) scores (before SoftMax). - past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`). Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be - used (see ``past_key_values`` input) to speed up sequential decoding. + used (see ``decoder_past_key_values`` input) to speed up sequential decoding. decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. @@ -399,7 +399,7 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput): loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None - past_key_values: Optional[List[torch.FloatTensor]] = None + decoder_past_key_values: Optional[List[torch.FloatTensor]] = None decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_last_hidden_state: Optional[torch.FloatTensor] = None @@ -511,12 +511,12 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput): Span-start scores (before SoftMax). end_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`): Span-end scores (before SoftMax). - past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`). Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be - used (see ``past_key_values`` input) to speed up sequential decoding. + used (see ``decoder_past_key_values`` input) to speed up sequential decoding. decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. @@ -546,7 +546,7 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput): loss: Optional[torch.FloatTensor] = None start_logits: torch.FloatTensor = None end_logits: torch.FloatTensor = None - past_key_values: Optional[List[torch.FloatTensor]] = None + decoder_past_key_values: Optional[List[torch.FloatTensor]] = None decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_last_hidden_state: Optional[torch.FloatTensor] = None diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py index 463d9f471e9b00..a8ae72d0b2219c 100644 --- a/src/transformers/modeling_t5.py +++ b/src/transformers/modeling_t5.py @@ -838,27 +838,27 @@ def forward( Used in the cross-attention of the decoder. decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`): Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation. - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + If `decoder_past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_values`). To know more on how to prepare :obj:`decoder_input_ids` for pre-training take a look at `T5 Training <./t5.html#training>`__. If decoder_input_ids and decoder_inputs_embeds are both None, decoder_input_ids takes the value of input_ids. decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`): Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default. - past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + decoder_past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains pre-computed key and value hidden-states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` + If `decoder_past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` instead of all `decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): - If `use_cache` is True, `past_key_values` are returned and can be used to speed up decoding (see `past_key_values`). + If `use_cache` is True, `decoder_past_key_values` are returned and can be used to speed up decoding (see `decoder_past_key_values`). inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation. - If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `past_key_values`). + If `decoder_past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `decoder_past_key_values`). This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. If decoder_input_ids and decoder_inputs_embeds are both None, decoder_inputs_embeds takes the value of inputs_embeds. @@ -928,7 +928,7 @@ def forward( encoder_outputs=None, decoder_input_ids=None, decoder_attention_mask=None, - past_key_values=None, + decoder_past_key_values=None, use_cache=None, inputs_embeds=None, decoder_inputs_embeds=None, @@ -955,16 +955,10 @@ def forward( """ if "decoder_past_key_value_states" in kwargs: warnings.warn( - "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", + "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.", FutureWarning, ) - past_key_values = kwargs.pop("decoder_past_key_value_states") - if "decoder_past_key_values" in kwargs: - warnings.warn( - "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", - FutureWarning, - ) - past_key_values = kwargs.pop("decoder_past_key_values") + decoder_past_key_values = kwargs.pop("decoder_past_key_value_states") assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." use_cache = use_cache if use_cache is not None else self.config.use_cache @@ -998,7 +992,7 @@ def forward( # If decoding with past key value states, only the last tokens # should be given as an input - if past_key_values is not None: + if decoder_past_key_values is not None: if decoder_input_ids is not None: decoder_input_ids = decoder_input_ids[:, -1:] if decoder_inputs_embeds is not None: @@ -1009,7 +1003,7 @@ def forward( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, inputs_embeds=decoder_inputs_embeds, - past_key_value_states=past_key_values, + past_key_value_states=decoder_past_key_values, encoder_hidden_states=hidden_states, encoder_attention_mask=attention_mask, head_mask=head_mask, @@ -1019,12 +1013,15 @@ def forward( return_dict=return_dict, ) + past = (encoder_outputs, decoder_outputs[1]) if use_cache is True else None if not return_dict: + if past is not None: + decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:] return decoder_outputs + encoder_outputs return Seq2SeqModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, + decoder_past_key_values=past, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, @@ -1083,7 +1080,7 @@ def forward( encoder_outputs=None, decoder_input_ids=None, decoder_attention_mask=None, - past_key_values=None, + decoder_past_key_values=None, use_cache=None, labels=None, inputs_embeds=None, @@ -1130,16 +1127,10 @@ def forward( labels = kwargs.pop("lm_labels") if "decoder_past_key_value_states" in kwargs: warnings.warn( - "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", - FutureWarning, - ) - past_key_values = kwargs.pop("decoder_past_key_value_states") - if "decoder_past_key_values" in kwargs: - warnings.warn( - "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", + "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.", FutureWarning, ) - past_key_values = kwargs.pop("decoder_past_key_values") + decoder_past_key_values = kwargs.pop("decoder_past_key_value_states") assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." use_cache = use_cache if use_cache is not None else self.config.use_cache @@ -1172,7 +1163,7 @@ def forward( # If decoding with past key value states, only the last tokens # should be given as an input - if past_key_values is not None: + if decoder_past_key_values is not None: assert labels is None, "Decoder should not use cached key value states when training." if decoder_input_ids is not None: decoder_input_ids = decoder_input_ids[:, -1:] @@ -1184,7 +1175,7 @@ def forward( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, inputs_embeds=decoder_inputs_embeds, - past_key_value_states=past_key_values, + past_key_value_states=decoder_past_key_values, encoder_hidden_states=hidden_states, encoder_attention_mask=attention_mask, head_mask=head_mask, @@ -1206,14 +1197,17 @@ def forward( loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + past = (encoder_outputs, decoder_outputs[1]) if use_cache is True else None if not return_dict: + if past is not None: + decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:] output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs return ((loss,) + output) if loss is not None else output return Seq2SeqLMOutput( loss=loss, logits=lm_logits, - past_key_values=decoder_outputs.past_key_values, + decoder_past_key_values=past, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, @@ -1221,10 +1215,14 @@ def forward( encoder_attentions=encoder_outputs.attentions, ) - def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs): + def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, **kwargs): + assert past is not None, "past has to be defined for encoder_outputs" + + encoder_outputs, decoder_past_key_values = past + return { "decoder_input_ids": input_ids, - "past_key_values": past, + "decoder_past_key_values": decoder_past_key_values, "encoder_outputs": encoder_outputs, "attention_mask": attention_mask, "use_cache": use_cache, @@ -1233,12 +1231,14 @@ def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cac def _reorder_cache(self, past, beam_idx): # if decoder past is not included in output # speedy decoding is disabled and no need to reorder - if past is None: + if past[1] is None: logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") return past + decoder_past = past[1] + past = (past[0],) reordered_decoder_past = () - for layer_past_states in past: + for layer_past_states in decoder_past: # get the correct batch idx from layer past batch dim # batch dim of `past` is at 2nd position reordered_layer_past_states = () @@ -1252,4 +1252,4 @@ def _reorder_cache(self, past, beam_idx): assert len(reordered_layer_past_states) == len(layer_past_states) reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) - return reordered_decoder_past + return past + (reordered_decoder_past,) diff --git a/src/transformers/modeling_tf_gpt2.py b/src/transformers/modeling_tf_gpt2.py index 439e2906bc48ca..e603643c252312 100644 --- a/src/transformers/modeling_tf_gpt2.py +++ b/src/transformers/modeling_tf_gpt2.py @@ -431,7 +431,7 @@ class TFGPT2DoubleHeadsModelOutput(ModelOutput): Base class for outputs of models predicting if two sentences are consecutive or not. Args: - logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`): + lm_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). mc_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`): Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). @@ -454,7 +454,7 @@ class TFGPT2DoubleHeadsModelOutput(ModelOutput): heads. """ - logits: tf.Tensor = None + lm_logits: tf.Tensor = None mc_logits: tf.Tensor = None past_key_values: Optional[List[tf.Tensor]] = None hidden_states: Optional[Tuple[tf.Tensor]] = None @@ -794,7 +794,7 @@ def call( return (lm_logits, mc_logits) + transformer_outputs[1:] return TFGPT2DoubleHeadsModelOutput( - logits=lm_logits, + lm_logits=lm_logits, mc_logits=mc_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, diff --git a/src/transformers/modeling_tf_openai.py b/src/transformers/modeling_tf_openai.py index 058596845753a9..49ca4de86c5145 100644 --- a/src/transformers/modeling_tf_openai.py +++ b/src/transformers/modeling_tf_openai.py @@ -394,7 +394,7 @@ class TFOpenAIGPTDoubleHeadsModelOutput(ModelOutput): Base class for outputs of models predicting if two sentences are consecutive or not. Args: - logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`): + lm_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). mc_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`): Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). @@ -411,7 +411,7 @@ class TFOpenAIGPTDoubleHeadsModelOutput(ModelOutput): heads. """ - logits: tf.Tensor = None + lm_logits: tf.Tensor = None mc_logits: tf.Tensor = None hidden_states: Optional[Tuple[tf.Tensor]] = None attentions: Optional[Tuple[tf.Tensor]] = None @@ -719,7 +719,7 @@ def call( return (lm_logits, mc_logits) + transformer_outputs[1:] return TFOpenAIGPTDoubleHeadsModelOutput( - logits=lm_logits, + lm_logits=lm_logits, mc_logits=mc_logits, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, diff --git a/src/transformers/modeling_tf_outputs.py b/src/transformers/modeling_tf_outputs.py index d0914b6ddf375f..8d61a175723ef6 100644 --- a/src/transformers/modeling_tf_outputs.py +++ b/src/transformers/modeling_tf_outputs.py @@ -113,13 +113,13 @@ class TFSeq2SeqModelOutput(ModelOutput): last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the decoder of the model. - If ``past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output. - past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + If ``decoder_past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output. + decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`). Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be - used (see ``past_key_values`` input) to speed up sequential decoding. + used (see ``decoder_past_key_values`` input) to speed up sequential decoding. decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. @@ -147,7 +147,7 @@ class TFSeq2SeqModelOutput(ModelOutput): """ last_hidden_state: tf.Tensor = None - past_key_values: Optional[List[tf.Tensor]] = None + decoder_past_key_values: Optional[List[tf.Tensor]] = None decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None decoder_attentions: Optional[Tuple[tf.Tensor]] = None encoder_last_hidden_state: Optional[tf.Tensor] = None @@ -259,12 +259,12 @@ class TFSeq2SeqLMOutput(ModelOutput): Languaged modeling loss. logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`). Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be - used (see ``past_key_values`` input) to speed up sequential decoding. + used (see ``decoder_past_key_values`` input) to speed up sequential decoding. decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. @@ -293,7 +293,7 @@ class TFSeq2SeqLMOutput(ModelOutput): loss: Optional[tf.Tensor] = None logits: tf.Tensor = None - past_key_values: Optional[List[tf.Tensor]] = None + decoder_past_key_values: Optional[List[tf.Tensor]] = None decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None decoder_attentions: Optional[Tuple[tf.Tensor]] = None encoder_last_hidden_state: Optional[tf.Tensor] = None @@ -366,12 +366,12 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput): Classification (or regression if config.num_labels==1) loss. logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`): Classification (or regression if config.num_labels==1) scores (before SoftMax). - past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`). Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be - used (see ``past_key_values`` input) to speed up sequential decoding. + used (see ``decoder_past_key_values`` input) to speed up sequential decoding. decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. @@ -400,7 +400,7 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput): loss: Optional[tf.Tensor] = None logits: tf.Tensor = None - past_key_values: Optional[List[tf.Tensor]] = None + decoder_past_key_values: Optional[List[tf.Tensor]] = None decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None decoder_attentions: Optional[Tuple[tf.Tensor]] = None encoder_last_hidden_state: Optional[tf.Tensor] = None @@ -512,12 +512,12 @@ class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput): Span-start scores (before SoftMax). end_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`): Span-end scores (before SoftMax). - past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`). Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be - used (see ``past_key_values`` input) to speed up sequential decoding. + used (see ``decoder_past_key_values`` input) to speed up sequential decoding. decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. @@ -547,7 +547,7 @@ class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput): loss: Optional[tf.Tensor] = None start_logits: tf.Tensor = None end_logits: tf.Tensor = None - past_key_values: Optional[List[tf.Tensor]] = None + decoder_past_key_values: Optional[List[tf.Tensor]] = None decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None decoder_attentions: Optional[Tuple[tf.Tensor]] = None encoder_last_hidden_state: Optional[tf.Tensor] = None diff --git a/src/transformers/modeling_tf_t5.py b/src/transformers/modeling_tf_t5.py index 9b451c8ff2730b..6a4379c0f66bdd 100644 --- a/src/transformers/modeling_tf_t5.py +++ b/src/transformers/modeling_tf_t5.py @@ -437,15 +437,15 @@ def call( ): if past_key_value_state is not None: - assert self.is_decoder, "Only decoder can use `past_key_values`" - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + assert self.is_decoder, "Only decoder can use `past_key_value_states`" + expected_num_past_key_value_states = 2 if encoder_hidden_states is None else 4 error_message = "There should be {} past states. 2 (past / key) for self attention.{} Got {} past key / value states".format( - expected_num_past_key_values, - "2 (past / key) for cross attention" if expected_num_past_key_values == 4 else "", + expected_num_past_key_value_states, + "2 (past / key) for cross attention" if expected_num_past_key_value_states == 4 else "", len(past_key_value_state), ) - assert len(past_key_value_state) == expected_num_past_key_values, error_message + assert len(past_key_value_state) == expected_num_past_key_value_states, error_message self_attn_past_key_value_state = past_key_value_state[:2] cross_attn_past_key_value_state = past_key_value_state[2:] @@ -586,12 +586,11 @@ def call( encoder_attention_mask=None, inputs_embeds=None, head_mask=None, - past_key_values=None, + past_key_value_states=None, use_cache=None, output_attentions=None, output_hidden_states=None, training=False, - **kwargs, ): if isinstance(inputs, (tuple, list)): input_ids = inputs[0] @@ -600,7 +599,7 @@ def call( encoder_attention_mask = inputs[3] if len(inputs) > 3 else encoder_attention_mask inputs_embeds = inputs[4] if len(inputs) > 4 else inputs_embeds head_mask = inputs[5] if len(inputs) > 5 else head_mask - past_key_values = inputs[6] if len(inputs) > 6 else past_key_values + past_key_value_states = inputs[6] if len(inputs) > 6 else past_key_value_states use_cache = inputs[7] if len(inputs) > 7 else use_cache output_attentions = inputs[8] if len(inputs) > 8 else output_attentions output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states @@ -612,26 +611,13 @@ def call( encoder_attention_mask = inputs.get("encoder_attention_mask", encoder_attention_mask) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) head_mask = inputs.get("head_mask", head_mask) - past_key_values = inputs.get("past_key_values", past_key_values) + past_key_value_states = inputs.get("past_key_value_states", past_key_value_states) use_cache = inputs.get("use_cache", use_cache) output_attentions = inputs.get("output_attentions", output_attentions) output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) assert len(inputs) <= 10, "Too many inputs." - - if "past_key_value_states" in inputs: - warnings.warn( - "The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", - FutureWarning, - ) - past_key_values = inputs.pop("past_key_value_states") else: input_ids = inputs - if "past_key_value_states" in kwargs: - warnings.warn( - "The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", - FutureWarning, - ) - past_key_values = kwargs.pop("past_key_value_states") output_attentions = output_attentions if output_attentions is not None else self.output_attentions output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states @@ -653,13 +639,13 @@ def call( batch_size, seq_length = input_shape - if past_key_values is not None: + if past_key_value_states is not None: assert seq_length == 1, "Input shape is {}, but should be {} when using past_key_value_sates".format( input_shape, (batch_size, 1) ) # required mask seq length can be calculated via length of past # key value states and seq_length = 1 for the last token - mask_seq_length = shape_list(past_key_values[0][0])[2] + seq_length + mask_seq_length = shape_list(past_key_value_states[0][0])[2] + seq_length else: mask_seq_length = seq_length @@ -669,9 +655,9 @@ def call( encoder_seq_length = shape_list(encoder_hidden_states)[1] encoder_attention_mask = tf.fill((batch_size, encoder_seq_length), 1) - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) + # initialize past_key_value_states with `None` if past does not exist + if past_key_value_states is None: + past_key_value_states = [None] * len(self.block) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. @@ -691,7 +677,7 @@ def call( ) causal_mask = tf.cast(causal_mask, dtype=tf.float32) extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] - if past_key_values[0] is not None: + if past_key_value_states[0] is not None: extended_attention_mask = extended_attention_mask[:, :, -1:, :] else: extended_attention_mask = attention_mask[:, None, None, :] @@ -740,7 +726,7 @@ def call( hidden_states = self.dropout(inputs_embeds, training=training) - for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_values)): + for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -892,7 +878,7 @@ def _shift_right(self, input_ids): :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`): Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation. - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + If `decoder_past_key_value_states` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_value_states`). attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: @@ -903,13 +889,13 @@ def _shift_right(self, input_ids): Used in the cross-attention of the decoder. decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`): Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default. - past_key_values (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + decoder_past_key_value_states (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains pre-computed key and value hidden-states of the attention blocks. Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` + If `decoder_past_key_value_states` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): - If `use_cache` is True, `past_key_values` are returned and can be used to speed up decoding (see `past_key_values`). + If `use_cache` is True, `decoder_past_key_value_states` are returned and can be used to speed up decoding (see `decoder_past_key_value_states`). inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): Optionally, instead of passing :obj:`inputs` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `inputs` indices into associated vectors @@ -983,7 +969,7 @@ def call( encoder_outputs=None, inputs_embeds=None, head_mask=None, - past_key_values=None, + decoder_past_key_value_states=None, decoder_input_ids=None, decoder_attention_mask=None, decoder_inputs_embeds=None, @@ -992,7 +978,6 @@ def call( output_hidden_states=None, return_dict=None, training=False, - **kwargs, ): r""" Returns: @@ -1014,7 +999,7 @@ def call( encoder_outputs = inputs[2] if len(inputs) > 2 else encoder_outputs inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds head_mask = inputs[4] if len(inputs) > 4 else head_mask - past_key_values = inputs[5] if len(inputs) > 5 else past_key_values + decoder_past_key_value_states = inputs[5] if len(inputs) > 5 else decoder_past_key_value_states decoder_input_ids = inputs[6] if len(inputs) > 6 else decoder_input_ids decoder_attention_mask = inputs[7] if len(inputs) > 7 else decoder_attention_mask decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds @@ -1032,7 +1017,7 @@ def call( encoder_outputs = inputs.get("encoder_outputs", encoder_outputs) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) head_mask = inputs.get("head_mask", head_mask) - past_key_values = inputs.get("past_key_values", past_key_values) + decoder_past_key_value_states = inputs.get("past_key_value_states", decoder_past_key_value_states) decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids) decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask) decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds) @@ -1041,23 +1026,9 @@ def call( output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) return_dict = inputs.get("return_dict", return_dict) assert len(inputs) <= 13, "Too many inputs." - - if "past_key_value_states" in inputs: - warnings.warn( - "The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", - FutureWarning, - ) - past_key_values = inputs.pop("past_key_value_states") else: input_ids = inputs - if "past_key_value_states" in kwargs: - warnings.warn( - "The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", - FutureWarning, - ) - past_key_values = kwargs.pop("past_key_value_states") - use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.return_dict @@ -1083,7 +1054,7 @@ def call( # If decoding with past key value states, only the last tokens # should be given as an input - if past_key_values is not None: + if decoder_past_key_value_states is not None: if decoder_input_ids is not None: decoder_input_ids = decoder_input_ids[:, -1:] if decoder_inputs_embeds is not None: @@ -1098,7 +1069,7 @@ def call( attention_mask, decoder_inputs_embeds, head_mask, - past_key_values, + decoder_past_key_value_states, use_cache, output_attentions, output_hidden_states, @@ -1132,7 +1103,7 @@ def call( return TFSeq2SeqModelOutput( last_hidden_state=decoder_outputs[0], - past_key_values=past, + decoder_past_key_values=past, decoder_hidden_states=decoder_outputs[2], decoder_attentions=decoder_outputs[3], encoder_last_hidden_state=encoder_outputs[0], @@ -1193,7 +1164,7 @@ def call( encoder_outputs=None, inputs_embeds=None, head_mask=None, - past_key_values=None, + decoder_past_key_value_states=None, decoder_input_ids=None, decoder_attention_mask=None, decoder_inputs_embeds=None, @@ -1203,7 +1174,6 @@ def call( return_dict=None, labels=None, training=False, - **kwargs, ): r""" labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): @@ -1234,7 +1204,7 @@ def call( encoder_outputs = inputs[2] if len(inputs) > 2 else encoder_outputs inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds head_mask = inputs[4] if len(inputs) > 4 else head_mask - past_key_values = inputs[5] if len(inputs) > 5 else past_key_values + decoder_past_key_value_states = inputs[5] if len(inputs) > 5 else decoder_past_key_value_states decoder_input_ids = inputs[6] if len(inputs) > 6 else decoder_input_ids decoder_attention_mask = inputs[7] if len(inputs) > 7 else decoder_attention_mask decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds @@ -1253,7 +1223,7 @@ def call( encoder_outputs = inputs.get("encoder_outputs", encoder_outputs) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) head_mask = inputs.get("head_mask", head_mask) - past_key_values = inputs.get("past_key_values", past_key_values) + decoder_past_key_value_states = inputs.get("past_key_value_states", decoder_past_key_value_states) decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids) decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask) decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds) @@ -1263,23 +1233,9 @@ def call( return_dict = inputs.get("return_dict", return_dict) labels = inputs.get("labels", labels) assert len(inputs) <= 14, "Too many inputs." - - if "past_key_value_states" in inputs: - warnings.warn( - "The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", - FutureWarning, - ) - past_key_values = inputs.pop("past_key_value_states") else: input_ids = inputs - if "past_key_value_states" in kwargs: - warnings.warn( - "The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", - FutureWarning, - ) - past_key_values = kwargs.pop("past_key_value_states") - use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.return_dict @@ -1310,7 +1266,7 @@ def call( # If decoding with past key value states, only the last tokens # should be given as an input - if past_key_values is not None: + if decoder_past_key_value_states is not None: if decoder_input_ids is not None: decoder_input_ids = decoder_input_ids[:, -1:] if decoder_inputs_embeds is not None: @@ -1325,7 +1281,7 @@ def call( attention_mask, decoder_inputs_embeds, head_mask, - past_key_values, + decoder_past_key_value_states, use_cache, output_attentions, output_hidden_states, @@ -1368,7 +1324,7 @@ def call( return TFSeq2SeqLMOutput( loss=loss, logits=logits, - past_key_values=past, + decoder_past_key_values=past, decoder_hidden_states=decoder_outputs[2], decoder_attentions=decoder_outputs[3], encoder_last_hidden_state=encoder_outputs[0], @@ -1381,14 +1337,14 @@ def prepare_inputs_for_generation(self, inputs, past, attention_mask, use_cache, # first step if len(past) < 2: - encoder_outputs, past_key_values = past, None + encoder_outputs, decoder_past_key_value_states = past, None else: - encoder_outputs, past_key_values = past[0], past[1] + encoder_outputs, decoder_past_key_value_states = past[0], past[1] return { "inputs": None, # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy "decoder_input_ids": inputs, # inputs are the decoder_input_ids - "past_key_values": past_key_values, + "decoder_past_key_value_states": decoder_past_key_value_states, "encoder_outputs": encoder_outputs, "attention_mask": attention_mask, "use_cache": use_cache, diff --git a/src/transformers/modeling_transfo_xl.py b/src/transformers/modeling_transfo_xl.py index 9b0e276e2b5bd0..c57be4afd37dbf 100644 --- a/src/transformers/modeling_transfo_xl.py +++ b/src/transformers/modeling_transfo_xl.py @@ -661,15 +661,6 @@ class TransfoXLLMHeadModelOutput(ModelOutput): hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None - @property - def logits(self): - # prediciton scores are the output of the adaptive softmax, see - # the file `modeling_transfo_xl_utilities`. Since the adaptive - # softmax returns the log softmax value, `self.prediciton_scores` - # are strictly speaking not exactly `logits`, but behave the same - # way logits do. - return self.prediction_scores - TRANSFO_XL_START_DOCSTRING = r""" diff --git a/tests/test_modeling_encoder_decoder.py b/tests/test_modeling_encoder_decoder.py index 8aefee7f85de3f..3af9fbc9c7edbe 100644 --- a/tests/test_modeling_encoder_decoder.py +++ b/tests/test_modeling_encoder_decoder.py @@ -33,7 +33,6 @@ from transformers import ( BertLMHeadModel, BertModel, - BertTokenizer, EncoderDecoderConfig, EncoderDecoderModel, GPT2LMHeadModel, @@ -129,11 +128,10 @@ def check_encoder_decoder_model_from_pretrained( decoder_config, decoder_input_ids, decoder_attention_mask, - return_dict, **kwargs ): encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) - kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model, "return_dict": return_dict} + kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model} enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs) enc_dec_model.to(torch_device) outputs_encoder_decoder = enc_dec_model( @@ -363,11 +361,7 @@ def test_encoder_decoder_model_from_pretrained_configs(self): def test_encoder_decoder_model_from_pretrained(self): input_ids_dict = self.prepare_config_and_inputs() - self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=False) - - def test_encoder_decoder_model_from_pretrained_return_dict(self): - input_ids_dict = self.prepare_config_and_inputs() - self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=True) + self.check_encoder_decoder_model_from_pretrained(**input_ids_dict) def test_save_and_load_from_pretrained(self): input_ids_dict = self.prepare_config_and_inputs() @@ -472,22 +466,6 @@ def prepare_config_and_inputs(self): "labels": decoder_token_labels, } - @slow - def test_bert2bert_summarization(self): - model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16") - model.to(torch_device) - tokenizer = BertTokenizer.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16") - - ARTICLE = """(CNN)Sigma Alpha Epsilon is under fire for a video showing party-bound fraternity members singing a racist chant. SAE's national chapter suspended the students, but University of Oklahoma President David Boren took it a step further, saying the university's affiliation with the fraternity is permanently done. The news is shocking, but it's not the first time SAE has faced controversy. SAE was founded March 9, 1856, at the University of Alabama, five years before the American Civil War, according to the fraternity website. When the war began, the group had fewer than 400 members, of which "369 went to war for the Confederate States and seven for the Union Army," the website says. The fraternity now boasts more than 200,000 living alumni, along with about 15,000 undergraduates populating 219 chapters and 20 "colonies" seeking full membership at universities. SAE has had to work hard to change recently after a string of member deaths, many blamed on the hazing of new recruits, SAE national President Bradley Cohen wrote in a message on the fraternity's website. The fraternity's website lists more than 130 chapters cited or suspended for "health and safety incidents" since 2010. At least 30 of the incidents involved hazing, and dozens more involved alcohol. However, the list is missing numerous incidents from recent months. Among them, according to various media outlets: Yale University banned the SAEs from campus activities last month after members allegedly tried to interfere with a sexual misconduct investigation connected to an initiation rite. Stanford University in December suspended SAE housing privileges after finding sorority members attending a fraternity function were subjected to graphic sexual content. And Johns Hopkins University in November suspended the fraternity for underage drinking. "The media has labeled us as the 'nation's deadliest fraternity,' " Cohen said. In 2011, for example, a student died while being coerced into excessive alcohol consumption, according to a lawsuit. SAE's previous insurer dumped the fraternity. "As a result, we are paying Lloyd's of London the highest insurance rates in the Greek-letter world," Cohen said. Universities have turned down SAE's attempts to open new chapters, and the fraternity had to close 12 in 18 months over hazing incidents.""" - - EXPECTED_SUMMARY = """sae was founded in 1856, five years before the civil war. the fraternity has had to work hard to change recently. the university of oklahoma president says the university's affiliation with the fraternity is permanently done. the sae has had a string of members in recent months.""" - - input_ids = tokenizer(ARTICLE, return_tensors="pt").input_ids.to(torch_device) - output_ids = model.generate(input_ids) - summary = tokenizer.decode(output_ids[0], skip_special_tokens=True) - - self.assertEqual(summary, EXPECTED_SUMMARY) - class RoBertaEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase): def get_encoder_decoder_model(self, config, decoder_config): diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index dcb0faefe4e391..17e0a6bc48d3b7 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -289,9 +289,9 @@ def create_and_check_double_lm_head_model( } result = model(**inputs) - self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.lm_loss.shape, ()) self.parent.assertEqual( - result.logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size) + result.lm_logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size) ) self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices)) @@ -324,7 +324,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else () all_generative_model_classes = ( - (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else () + (GPT2LMHeadModel,) if is_torch_available() else () ) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly test_missing_keys = False diff --git a/tests/test_modeling_openai.py b/tests/test_modeling_openai.py index 92a0335cda74bd..1014e1eea4a12b 100644 --- a/tests/test_modeling_openai.py +++ b/tests/test_modeling_openai.py @@ -131,8 +131,8 @@ def create_and_check_double_lm_head_model(self, config, input_ids, head_mask, to model.eval() result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) - self.parent.assertEqual(result.loss.shape, ()) - self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + self.parent.assertEqual(result.lm_loss.shape, ()) + self.parent.assertEqual(result.lm_logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py index 4c411d84492a53..fef623807ca192 100644 --- a/tests/test_modeling_t5.py +++ b/tests/test_modeling_t5.py @@ -159,15 +159,17 @@ def create_and_check_model( ) result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) decoder_output = result.last_hidden_state - decoder_past = result.past_key_values + decoder_past = result.decoder_past_key_values encoder_output = result.encoder_last_hidden_state self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size)) self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size)) - # There should be `num_layers` key value embeddings stored in decoder_past - self.parent.assertEqual(len(decoder_past), config.num_layers) - # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple - self.parent.assertEqual(len(decoder_past[0]), 4) + self.parent.assertEqual(len(decoder_past), 2) + self.parent.assertTrue(torch.all(decoder_past[0][0] == encoder_output)) + # There should be `num_layers` key value embeddings stored in decoder_past[1] + self.parent.assertEqual(len(decoder_past[1]), config.num_layers) + # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple + self.parent.assertEqual(len(decoder_past[1][0]), 4) def create_and_check_with_lm_head( self, diff --git a/tests/test_modeling_tf_gpt2.py b/tests/test_modeling_tf_gpt2.py index 4cd20be25e609b..41b973719eae67 100644 --- a/tests/test_modeling_tf_gpt2.py +++ b/tests/test_modeling_tf_gpt2.py @@ -238,7 +238,7 @@ def create_and_check_gpt2_double_head( } result = model(inputs) self.parent.assertEqual( - result.logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size) + result.lm_logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size) ) self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices)) diff --git a/tests/test_modeling_tf_openai.py b/tests/test_modeling_tf_openai.py index 6e57db2d39f446..e3bd82dae23a68 100644 --- a/tests/test_modeling_tf_openai.py +++ b/tests/test_modeling_tf_openai.py @@ -151,7 +151,7 @@ def create_and_check_openai_gpt_double_head( } result = model(inputs) self.parent.assertEqual( - result.logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size) + result.lm_logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size) ) self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices)) diff --git a/tests/test_modeling_tf_t5.py b/tests/test_modeling_tf_t5.py index 7c50bd15c5328b..eb575f5131e9bd 100644 --- a/tests/test_modeling_tf_t5.py +++ b/tests/test_modeling_tf_t5.py @@ -96,7 +96,7 @@ def create_and_check_t5_model(self, config, input_ids, input_mask, token_labels) result = model(input_ids, decoder_attention_mask=input_mask, decoder_input_ids=input_ids) decoder_output = result.last_hidden_state - decoder_past = result.past_key_values + decoder_past = result.decoder_past_key_values encoder_output = result.encoder_last_hidden_state self.parent.assertListEqual(list(encoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size]) self.parent.assertListEqual(list(decoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])