TF generate refactor - Greedy Search #15562

Merged

Commits (32)
789f2f4
TF generate start refactor
patrickvonplaten Feb 8, 2022
6702ed3
Add tf tests for sample generate
patrickvonplaten Feb 8, 2022
cf6de09
re-organize
patrickvonplaten Feb 9, 2022
7164d93
boom boom
patrickvonplaten Feb 9, 2022
0844c83
Apply suggestions from code review
patrickvonplaten Feb 9, 2022
b5ae041
re-add
patrickvonplaten Feb 9, 2022
5c3f3f5
Merge branch 'tf_generate_refactor' of https://github.com/patrickvonp…
patrickvonplaten Feb 9, 2022
0671b89
add all code
patrickvonplaten Feb 9, 2022
c23bff2
make random greedy pass
patrickvonplaten Feb 10, 2022
7786d18
make encoder-decoder random work
patrickvonplaten Feb 10, 2022
d04530f
further improvements
patrickvonplaten Feb 10, 2022
73090dd
delete bogus file
patrickvonplaten Feb 10, 2022
3db93ff
make gpt2 and t5 tests work
patrickvonplaten Feb 11, 2022
7a7b7ef
finish logits tests
patrickvonplaten Feb 11, 2022
7b1b2cc
correct logits processors
patrickvonplaten Feb 14, 2022
a8cf81e
Merge branch 'master' of https://github.com/huggingface/transformers …
patrickvonplaten Feb 14, 2022
1a9e870
correct past / encoder_outputs drama
patrickvonplaten Feb 14, 2022
385c24f
refactor some methods
patrickvonplaten Feb 14, 2022
bd750ff
another fix
patrickvonplaten Feb 14, 2022
49e33b0
refactor shape_list
patrickvonplaten Feb 14, 2022
4b2460d
fix more shape list
patrickvonplaten Feb 14, 2022
ed5f2ff
import shape
patrickvonplaten Feb 14, 2022
dd1c214
finish docs
patrickvonplaten Feb 14, 2022
0c7d049
fix imports
patrickvonplaten Feb 14, 2022
726355e
make style
patrickvonplaten Feb 14, 2022
6293862
correct tf utils
patrickvonplaten Feb 14, 2022
b2934ee
Fix TFRag as well
patrickvonplaten Feb 14, 2022
4f6d927
Apply Lysandre's and Sylvais suggestions
patrickvonplaten Feb 15, 2022
39c0b65
Update tests/test_generation_tf_logits_process.py
patrickvonplaten Feb 15, 2022
920a991
Update src/transformers/tf_utils.py
patrickvonplaten Feb 15, 2022
4b7d994
remove cpu according to gante
patrickvonplaten Feb 15, 2022
3fbe55b
correct logit processor
patrickvonplaten Feb 15, 2022
291 changes: 291 additions & 0 deletions src/transformers/generation_tf_logits_process.py
@@ -0,0 +1,291 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from abc import ABC
from typing import Iterable, List, Optional

import numpy as np
import tensorflow as tf

from .file_utils import add_start_docstrings
from .utils.logging import get_logger


logger = get_logger(__name__)


TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
Args:
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.

Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.

[What are input IDs?](../glossary#input-ids)
        scores (`tf.Tensor` of shape `(batch_size, config.vocab_size)`):
            Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
            using beam search or log softmax for each vocabulary token when using beam search.
kwargs:
Additional logits processor specific kwargs.

Return:
`tf.Tensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.

"""


class TFLogitsProcessor(ABC):
"""Abstract base class for all logit processors that can be applied during generation."""

@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
"""TF method for processing logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)


class TFLogitsProcessorList(list):
"""
This class can be used to create a list of [`TFLogitsProcessor`] to subsequently process
a `scores` input tensor. This class inherits from list and adds a specific *__call__* method to apply each
[`TFLogitsProcessor`] to the inputs.
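
    Example (a minimal sketch chaining two processors from this file on dummy logits; the token ids and the
    vocabulary size of 10 are made up for illustration):

    ```python
    >>> import tensorflow as tf

    >>> processors = TFLogitsProcessorList(
    ...     [
    ...         TFMinLengthLogitsProcessor(min_length=5, eos_token_id=0),
    ...         TFNoRepeatNGramLogitsProcessor(ngram_size=2),
    ...     ]
    ... )
    >>> input_ids = tf.constant([[2, 3, 2, 3]])  # one partially generated sequence
    >>> scores = tf.random.normal((1, 10))  # dummy next-token logits
    >>> scores = processors(input_ids, scores, cur_len=4)
    ```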
"""

@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int, **kwargs) -> tf.Tensor:
for processor in self:
function_args = inspect.signature(processor.__call__).parameters
if len(function_args) > 3:
                # `input_ids`, `scores` and `cur_len` are passed positionally; any extra arguments must be in `kwargs`
                if not all(arg in kwargs for arg in list(function_args.keys())[3:]):
raise ValueError(
f"Make sure that all the required parameters: {list(function_args.keys())} for "
f"{processor.__class__} are passed to the logits processor."
)
scores = processor(input_ids, scores, cur_len, **kwargs)
else:
scores = processor(input_ids, scores, cur_len)
return scores


class TFMinLengthLogitsProcessor(TFLogitsProcessor):
r"""
[`TFLogitsProcessor`] enforcing a min-length by setting EOS probability to 0.

Args:
min_length (`int`):
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`int`):
The id of the *end-of-sequence* token.
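
    Example (a minimal sketch with dummy logits over a toy vocabulary of 10 tokens; values are illustrative only):

    ```python
    >>> import tensorflow as tf

    >>> processor = TFMinLengthLogitsProcessor(min_length=5, eos_token_id=1)
    >>> input_ids = tf.constant([[0, 4, 7]])  # 3 tokens generated so far
    >>> scores = tf.random.normal((1, 10))  # dummy next-token logits
    >>> scores = processor(input_ids, scores, cur_len=3)  # the eos column (index 1) is now `-inf`
    ```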
"""

def __init__(self, min_length: int, eos_token_id: int):
if not isinstance(min_length, int) or min_length < 0:
raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")

if not isinstance(eos_token_id, int) or eos_token_id < 0:
raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")

self.min_length = min_length
self.eos_token_id = eos_token_id

    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        # create boolean flag to decide if min length penalty should be applied
        apply_penalty = 1 - tf.clip_by_value(cur_len - self.min_length, 0, 1)

        # set the `eos_token_id` column of `scores` to `-inf` for the whole batch as long as the penalty applies
        eos_token_mask = tf.math.equal(tf.range(scores.shape[-1]), self.eos_token_id)
        scores = tf.where(tf.cast(apply_penalty, tf.bool) & eos_token_mask[None, :], float("-inf"), scores)

        return scores


class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
r"""
    [`TFLogitsProcessor`] enforcing an exponential penalty on repeated sequences.

    Args:
        penalty (`float`):
            The parameter for repetition penalty. 1.0 means no penalty. See [this
            paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
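
    Example (a minimal sketch with hand-picked dummy logits so the effect of the penalty is easy to follow):

    ```python
    >>> import tensorflow as tf

    >>> processor = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
    >>> input_ids = tf.constant([[0, 2]])  # tokens 0 and 2 have already been generated
    >>> scores = tf.constant([[1.0, 1.0, -1.0, 1.0]])  # dummy logits over a vocabulary of 4
    >>> scores = processor(input_ids, scores, cur_len=2)
    >>> # token 0: 1.0 / 2.0 = 0.5, token 2: -1.0 * 2.0 = -2.0, all other tokens unchanged
    ```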
"""

def __init__(self, penalty: float):
if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

self.penalty = penalty

    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        # gather the scores of all tokens that were already generated
        prev_input_ids = input_ids[:, :cur_len]
        score = tf.gather(scores, prev_input_ids, axis=1, batch_dims=1)

        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
        score = tf.where(score < 0, score * self.penalty, score / self.penalty)

        # scatter the penalized scores back into the full `scores` tensor
        batch_indices = tf.broadcast_to(tf.range(tf.shape(prev_input_ids)[0])[:, None], tf.shape(prev_input_ids))
        indices = tf.stack([batch_indices, tf.cast(prev_input_ids, batch_indices.dtype)], axis=-1)
        scores = tf.tensor_scatter_nd_update(scores, indices, score)

        return scores


class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
"""
    [`TFLogitsProcessor`] that enforces that specified sequences will never be sampled.

Args:
bad_words_ids (`List[List[int]]`):
List of list of token ids that are not allowed to be generated. In order to get the tokens of the words
that should not appear in the generated text, use `tokenizer(bad_word, add_prefix_space=True).input_ids`.
eos_token_id (`int`):
The id of the *end-of-sequence* token.
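
    Example (a minimal sketch with hand-picked token ids; real ids would come from a tokenizer as described above):

    ```python
    >>> import tensorflow as tf

    >>> # ban token 4 everywhere and token 6 whenever it would follow token 5
    >>> processor = TFNoBadWordsLogitsProcessor(bad_words_ids=[[4], [5, 6]], eos_token_id=0)
    >>> input_ids = tf.constant([[2, 5]])  # the last generated token is 5
    >>> scores = tf.random.normal((1, 10))  # dummy logits over a vocabulary of 10
    >>> scores = processor(input_ids, scores, cur_len=2)  # columns 4 and 6 are now `-inf`
    ```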
"""

    def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int):
        if not isinstance(bad_words_ids, list) or len(bad_words_ids) == 0:
            raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
        if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
            raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")
        if any(
            any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids)
            for bad_word_ids in bad_words_ids
        ):
            raise ValueError(
                f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
            )

        bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
        self.bad_words_id_length_1 = []
        self.bad_words_id_length_greater_than_1 = []
        for word in bad_words_ids:
            if len(word) == 1:
                self.bad_words_id_length_1.append(word[0])
            else:
                self.bad_words_id_length_greater_than_1.append(word)

        self.static_bad_words_mask: Optional[tf.Tensor] = None

        for banned_token_seq in self.bad_words_id_length_greater_than_1:
            if len(banned_token_seq) == 0:
                raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list")

    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        if self.static_bad_words_mask is None and len(self.bad_words_id_length_1) > 0:
            self.static_bad_words_mask = self._calc_static_bad_word_mask(scores)

        # only the tokens generated so far are relevant for banning multi-token bad words
        dynamic_banned_tokens = self._calc_banned_bad_words_ids(input_ids[:, :cur_len].numpy().tolist())
        scores = self._set_scores_to_inf_for_banned_tokens(scores, dynamic_banned_tokens)

        return scores

    def _calc_static_bad_word_mask(self, scores: tf.Tensor) -> tf.Tensor:
        # boolean mask over the vocabulary that is `True` at every single-token bad word
        static_bad_words_mask = np.zeros(scores.shape[-1], dtype=bool)
        static_bad_words_mask[self.bad_words_id_length_1] = True
        return tf.constant(static_bad_words_mask)[None, :]

def _tokens_match(self, prev_tokens: List[int], tokens: List[int]) -> bool:
if len(tokens) == 0:
# if bad word tokens is just one token always ban it
return True
elif len(tokens) > len(prev_tokens):
            # if bad word tokens are longer than prev input_ids they can't be equal
return False
else:
return prev_tokens[-len(tokens) :] == tokens

def _calc_banned_bad_words_ids(self, prev_input_ids: List[List[int]]) -> Iterable[int]:
banned_tokens = []
for prev_input_ids_slice in prev_input_ids:
banned_tokens_slice = []
for banned_token_seq in self.bad_words_id_length_greater_than_1:
if self._tokens_match(prev_input_ids_slice, banned_token_seq[:-1]):
banned_tokens_slice.append(banned_token_seq[-1])

banned_tokens.append(banned_tokens_slice)

return banned_tokens

    def _set_scores_to_inf_for_banned_tokens(self, scores: tf.Tensor, banned_tokens: List[List[int]]) -> tf.Tensor:
        """
        Returns a copy of `scores` where the banned token positions are set to `-inf`. `banned_tokens` is expected to
        be a list of length `batch_size`, holding for every batch index the list of vocabulary positions to ban.

        Args:
            scores: logits distribution of shape (batch size, vocabulary size)
            banned_tokens: list of list of tokens to ban of length (batch_size)
        """
        banned_mask_list = []
        for idx, batch_banned_tokens in enumerate(banned_tokens):
            for token in batch_banned_tokens:
                # Eliminates invalid bad word IDs that are over the vocabulary size.
                if token < scores.shape[-1]:
                    banned_mask_list.append([idx, token])
                else:
                    logger.error(
                        f"An invalid bad word ID is defined: {token}. This ID is not contained in the "
                        f"vocabulary, and is therefore ignored."
                    )

        if not banned_mask_list and self.static_bad_words_mask is None:
            return scores

        if banned_mask_list:
            # A dense mask is scattered from a list of [batch index, vocabulary position] coordinates, e.g. the
            # coordinates [[0, 1], [0, 2], [2, 0]] produce:
            # [ 0 1 1 ]
            # [ 0 0 0 ]
            # [ 1 0 0 ]
            banned_mask = tf.scatter_nd(
                indices=tf.constant(banned_mask_list, dtype=tf.int32),
                updates=tf.ones(len(banned_mask_list), dtype=tf.int32),
                shape=tf.shape(scores),
            )
            banned_mask = tf.cast(banned_mask, tf.bool)

            if self.static_bad_words_mask is not None:
                banned_mask = tf.logical_or(banned_mask, self.static_bad_words_mask)
        else:
            banned_mask = self.static_bad_words_mask

        scores = tf.where(banned_mask, -float("inf"), scores)
        return scores


class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
r"""
    [`TFLogitsProcessor`] that enforces no repetition of n-grams. See
[Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).

Args:
ngram_size (`int`):
All ngrams of size `ngram_size` can only occur once.
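
    Example (a minimal sketch; with `ngram_size=2`, a token is banned if it would recreate a bigram that was already
    generated):

    ```python
    >>> import tensorflow as tf

    >>> processor = TFNoRepeatNGramLogitsProcessor(ngram_size=2)
    >>> input_ids = tf.constant([[3, 8, 3]])  # the bigram (3, 8) was already generated
    >>> scores = tf.random.normal((1, 10))  # dummy logits over a vocabulary of 10
    >>> scores = processor(input_ids, scores, cur_len=3)  # column 8 is now `-inf`, so (3, 8) cannot repeat
    ```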
"""

def __init__(self, ngram_size: int):
if not isinstance(ngram_size, int) or ngram_size <= 0:
raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
self.ngram_size = ngram_size

    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        num_batch_hypotheses, vocab_size = scores.shape

        # for every hypothesis, ban all tokens that would complete an n-gram that was already generated
        banned_tokens_indices_mask = np.zeros((num_batch_hypotheses, vocab_size), dtype=bool)
        if cur_len + 1 >= self.ngram_size:
            prev_input_ids = input_ids[:, :cur_len].numpy().tolist()
            for i, gen_tokens in enumerate(prev_input_ids):
                # the last `ngram_size - 1` generated tokens form the prefix of the next n-gram
                prefix = tuple(gen_tokens[cur_len + 1 - self.ngram_size :])
                for ngram in zip(*[gen_tokens[j:] for j in range(self.ngram_size)]):
                    if tuple(ngram[:-1]) == prefix:
                        banned_tokens_indices_mask[i, ngram[-1]] = True

        scores = tf.where(tf.constant(banned_tokens_indices_mask), -float("inf"), scores)

        return scores