
Commit

Revert "[Generate] Facilitate PyTorch generate using ModelOutputs (h…
Browse files Browse the repository at this point in the history
…uggingface#6735)"

This reverts commit bf4184e.
fabiocapsouza committed Nov 15, 2020
1 parent 2506f12 commit 1fc4fd3
Showing 20 changed files with 260 additions and 394 deletions.
9 changes: 4 additions & 5 deletions docs/source/model_doc/encoderdecoder.rst
@@ -1,13 +1,12 @@
Encoder Decoder Models
------------------------

The :class:`~transformers.EncoderDecoderModel` can be used to initialize a sequence-to-sequence model with any pre-trained autoencoding model as the encoder and any pre-trained autoregressive model as the decoder.
This class can wrap an encoder model, such as ``BertModel``, and a decoder model with a language modeling head, such as ``BertForMaskedLM``, into an encoder-decoder model.

The effectiveness of initializing sequence-to-sequence models with pre-trained checkpoints for sequence generation tasks was shown in `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks <https://arxiv.org/abs/1907.12461>`__ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
The ``EncoderDecoderModel`` class allows you to instantiate an encoder-decoder model using the ``from_encoder_decoder_pretrained`` class method, which takes a pretrained encoder and a pretrained decoder model as input.
An ``EncoderDecoderModel`` is saved with the standard ``save_pretrained()`` method and can be loaded again with the standard ``from_pretrained()`` method.

After such an :class:`~transformers.EncoderDecoderModel` has been trained / fine-tuned, it can be saved / loaded just like any other model (see the Examples for more information).

An application of this architecture could be to leverage two pre-trained :obj:`transformers.BertModel` models as the encoder and decoder for a summarization model as was shown in: `Text Summarization with Pretrained Encoders <https://arxiv.org/abs/1910.13461>`_ by Yang Liu and Mirella Lapata.
An application of this architecture could be *summarization* using two pretrained Bert models as is shown in the paper: `Text Summarization with Pretrained Encoders <https://arxiv.org/abs/1910.13461>`_ by Yang Liu and Mirella Lapata.
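
A minimal sketch of this workflow (the ``bert-base-uncased`` checkpoints are an assumption; any autoencoding encoder and autoregressive decoder pair works the same way)::

    from transformers import EncoderDecoderModel

    # initialize a bert2bert model from two pretrained BERT checkpoints
    model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

    # save and reload with the standard methods
    model.save_pretrained("bert2bert")
    model = EncoderDecoderModel.from_pretrained("bert2bert")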


``EncoderDecoderConfig``
70 changes: 36 additions & 34 deletions src/transformers/generation_utils.py
@@ -20,7 +20,6 @@
from torch import Tensor
from torch.nn import functional as F

from .file_utils import ModelOutput
from .utils import logging


@@ -47,6 +46,14 @@ def adjust_logits_during_generation(self, logits, **kwargs):
"""
return logits

def _use_cache(self, outputs, use_cache):
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
if len(outputs) <= 1 or use_cache is False:
return False
if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
return False
return True

def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
"""
Enforce the repetition penalty (from the `CTRL paper <https://arxiv.org/abs/1909.05858>`__).
@@ -130,7 +137,7 @@ def generate(
attention_mask: Optional[torch.LongTensor] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
**model_kwargs
**model_specific_kwargs
) -> torch.LongTensor:
r"""
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
@@ -201,7 +208,7 @@ def generate(
use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
model_kwargs:
model_specific_kwargs:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
Return:
@@ -393,7 +400,7 @@ def generate(

# get encoder and store encoder outputs
encoder = self.get_encoder()
encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask)

# Expand input ids if num_beams > 1 or num_return_sequences > 1
if num_return_sequences > 1 or num_beams > 1:
@@ -421,8 +428,8 @@
cur_len = 1

assert (
batch_size == encoder_outputs.last_hidden_state.shape[0]
), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} "
batch_size == encoder_outputs[0].shape[0]
), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "

# expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
expanded_batch_idxs = (
@@ -432,16 +439,11 @@
.view(-1)
.to(input_ids.device)
)

# expand encoder_outputs
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
0, expanded_batch_idxs
)

# save encoder_outputs in `model_kwargs`
model_kwargs["encoder_outputs"] = encoder_outputs
encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])

else:
encoder_outputs = None
cur_len = input_ids.shape[-1]
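
As a standalone illustration of the ``index_select`` expansion in the encoder-decoder branch above, a toy sketch with arbitrary sizes (the real code also folds an effective batch multiplier for sampling into the repeat count):

    import torch

    batch_size, num_beams, seq_len, hidden = 2, 3, 4, 8
    last_hidden_state = torch.randn(batch_size, seq_len, hidden)

    expanded_batch_idxs = (
        torch.arange(batch_size).view(-1, 1).repeat(1, num_beams).view(-1)
    )  # tensor([0, 0, 0, 1, 1, 1])

    # each example's encoder states are repeated once per beam
    expanded = last_hidden_state.index_select(0, expanded_batch_idxs)
    print(expanded.shape)  # torch.Size([6, 4, 8]) == (batch_size * num_beams, seq_len, hidden)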

assert (
@@ -469,9 +471,10 @@
length_penalty=length_penalty,
num_beams=num_beams,
vocab_size=vocab_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
model_kwargs=model_kwargs,
model_specific_kwargs=model_specific_kwargs,
)
else:
output = self._generate_no_beam_search(
@@ -489,9 +492,10 @@
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
batch_size=effective_batch_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
model_kwargs=model_kwargs,
model_specific_kwargs=model_specific_kwargs,
)

return output
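
A hedged usage sketch of this tuple-based ``generate()`` path; the ``gpt2`` checkpoint and prompt are assumptions, and any extra keyword arguments would be forwarded to the model's ``forward`` as ``model_specific_kwargs``:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    input_ids = tokenizer("The generate method", return_tensors="pt").input_ids

    # use_cache=True lets the decoding loop thread `past` through successive forward passes
    output_ids = model.generate(input_ids, max_length=20, num_beams=4, use_cache=True)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))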
@@ -512,9 +516,10 @@ def _generate_no_beam_search(
pad_token_id,
eos_token_id,
batch_size,
encoder_outputs,
attention_mask,
use_cache,
model_kwargs,
model_specific_kwargs,
):
"""Generate sequences for each example without beam search (num_beams == 1).
All returned sequences are generated independently.
@@ -523,14 +528,15 @@ def _generate_no_beam_search(
unfinished_sents = input_ids.new(batch_size).fill_(1)
sent_lengths = input_ids.new(batch_size).fill_(max_length)

past = None
past = (encoder_outputs, None) if encoder_outputs is not None else None

while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
)

outputs = self(**model_inputs, return_dict=True)
next_token_logits = outputs.logits[:, -1, :]
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]

scores = self.postprocess_next_token_scores(
scores=next_token_logits,
@@ -547,10 +553,8 @@ def _generate_no_beam_search(
)

# if model has past, then set the past variable to speed up decoding
if "past_key_values" in outputs:
past = outputs.past_key_values
elif "mems" in outputs:
past = outputs.mems
if self._use_cache(outputs, use_cache):
past = outputs[1]

if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
@@ -617,9 +621,10 @@ def _generate_beam_search(
length_penalty,
num_beams,
vocab_size,
encoder_outputs,
attention_mask,
use_cache,
model_kwargs,
model_specific_kwargs,
):
"""Generate sequences for each example with beam search."""

@@ -638,24 +643,21 @@ def _generate_beam_search(
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)

# cache compute states
past = None
past = (encoder_outputs, None) if encoder_outputs is not None else None

# done sentences
done = [False for _ in range(batch_size)]

while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
)
outputs = self(**model_inputs, return_dict=True) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)

# if model has past, then set the past variable to speed up decoding
if "past_key_values" in outputs:
past = outputs.past_key_values
elif "mems" in outputs:
past = outputs.mems

if self._use_cache(outputs, use_cache):
past = outputs[1]
if self.config.is_encoder_decoder and do_sample is False:
# TODO (PVP) still a bit hacky here - there might be a better solution
next_token_logits = self.adjust_logits_during_generation(