[DOCS] Add NoRepeatNGramLogitsProcessor Example for LogitsProcessor class #25186

Merged 16 commits on Aug 7, 2023
72 changes: 68 additions & 4 deletions src/transformers/generation/logits_process.py
@@ -438,17 +438,53 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:


def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
    """
    Assume ngram_size=2 and prev_input_ids=tensor([[list of generated tokens]]). The output of the generated ngrams
    looks like this: {(generated word #1,): [tokenized list of next words observed], (generated word #2,): [tokenized
    list of next words observed]}.

    Args:
        ngram_size (`int`):
            The size of the n-grams that should only occur once.
        prev_input_ids (`torch.Tensor`):
            A tensor containing the tokenized input for each hypothesis.
        num_hypos (`int`):
            The number of hypotheses for which n-grams need to be generated.

    Returns:
        generated_ngrams (`dict`):
            Dictionary of generated ngrams.
    """
    # Initialize an empty dictionary for each hypothesis (index) in the range of num_hypos
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        # Loop through each n-gram of size ngram_size in the list of tokens (gen_tokens)
        for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
    return generated_ngrams
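For intuition, here is a toy call showing the structure `_get_ngrams` returns (the token IDs below are invented for illustration):

```py
>>> import torch

>>> # One hypothesis whose generated tokens are [5, 7, 5, 7, 9]; track bigrams (ngram_size=2)
>>> _get_ngrams(ngram_size=2, prev_input_ids=torch.tensor([[5, 7, 5, 7, 9]]), num_hypos=1)
[{(5,): [7, 7], (7,): [5, 9]}]
```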


def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
    """
    Determines the banned tokens for the current hypothesis based on previously generated n-grams.

    Args:
        banned_ngrams (`dict`):
            A dictionary containing previously generated n-grams for each hypothesis.
        prev_input_ids (`torch.Tensor`):
            A tensor containing the tokenized input for each hypothesis in the current batch.
        ngram_size (`int`):
            The size of the n-grams that should only occur once.
        cur_len (`int`):
            The current length of the token sequences for which the n-grams are being checked.

    Returns:
        List of tokens that are banned.
    """
    # Before decoding the next token, prevent decoding of ngrams that have already appeared
    start_idx = cur_len + 1 - ngram_size
    ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
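Continuing the toy bigram dictionary from above, a sketch of the lookup (this assumes, per the docstring, that the banned list is read from the key formed by the last `ngram_size - 1` generated tokens):

```py
>>> import torch

>>> banned_ngrams = {(5,): [7, 7], (7,): [5, 9]}
>>> # The sequence ends in token 7, so generating 5 or 9 next would repeat a seen bigram
>>> _get_generated_ngrams(banned_ngrams, torch.tensor([5, 7, 5, 7]), ngram_size=2, cur_len=4)
[5, 9]
```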
@@ -462,9 +498,7 @@ def _calc_banned_ngram_tokens(
    if cur_len + 1 < ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]
    generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)
    banned_tokens = [
        _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
        for hypo_idx in range(num_hypos)
    ]
    return banned_tokens
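End to end, the two helpers combine as follows on a toy batch (token IDs invented for illustration):

```py
>>> import torch

>>> # Hypothesis 0 has seen the bigram (1, 2) and ends in token 1, so token 2 is banned;
>>> # hypothesis 1 ends in token 5, which starts no previously seen bigram
>>> prev_input_ids = torch.tensor([[1, 2, 1], [3, 4, 5]])
>>> _calc_banned_ngram_tokens(ngram_size=2, prev_input_ids=prev_input_ids, num_hypos=2, cur_len=3)
[[2], []]
```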
@@ -474,12 +508,43 @@

class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    r"""
    N-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
    sentence: "She runs fast", the bi-grams (n = 2) would be ("she", "runs") and ("runs", "fast"). In text
    generation, avoiding repetitions of word sequences provides a more diverse output. This [`LogitsProcessor`]
    enforces no repetition of n-grams by setting the scores of banned tokens to negative infinity, which eliminates
    those tokens from consideration when further processing the scores. See
    [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).

    <Tip>

    Use n-gram penalties with care. For instance, penalizing 2-grams (bigrams) in an article about the city of New
    York might lead to undesirable outcomes where the city's name appears only once in the entire text.
    [Reference](https://huggingface.co/blog/how-to-generate)

    </Tip>

    Args:
        ngram_size (`int`):
            All n-grams of size `ngram_size` can only occur once.

    Examples:

    ```py
    >>> from transformers import GPT2Tokenizer, AutoModelForCausalLM

    >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
    >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    >>> inputs = tokenizer(["I enjoy playing football"], return_tensors="pt")

    >>> output = model.generate(**inputs, max_length=50)
    >>> print(tokenizer.decode(output[0], skip_special_tokens=True))
    I enjoy playing football on the weekends, but I'm not a big fan of the idea of playing in the middle of the night. I'm not a big fan of the idea of playing in the middle of the night. I'm not a big

    >>> # Now let's restrict repetition by passing `no_repeat_ngram_size` to `model.generate`. This should stop the repetitions in the output.
    >>> output = model.generate(**inputs, max_length=50, no_repeat_ngram_size=2)
    >>> print(tokenizer.decode(output[0], skip_special_tokens=True))
    I enjoy playing football on the weekends, but I'm not a big fan of the idea of playing in the middle of a game. I think it's a bit of an overreaction to the fact that we're playing a team that's playing
    ```
    """

    def __init__(self, ngram_size: int):
@@ -491,7 +556,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        num_batch_hypotheses = scores.shape[0]
        cur_len = input_ids.shape[-1]
        banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len)
        for i, banned_tokens in enumerate(banned_batch_tokens):
            scores[i, banned_tokens] = -float("inf")
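Because the processor is just a callable over `(input_ids, scores)`, it can also be exercised directly, outside of `generate`. A minimal sketch with dummy tensors (token IDs and vocabulary size invented for illustration):

```py
>>> import torch
>>> from transformers import LogitsProcessorList, NoRepeatNGramLogitsProcessor

>>> processors = LogitsProcessorList([NoRepeatNGramLogitsProcessor(ngram_size=2)])
>>> # One hypothesis that already produced the bigram (1, 2) and currently ends in token 1,
>>> # so decoding token 2 next would repeat that bigram and its score is pushed to -inf
>>> input_ids = torch.tensor([[1, 2, 1]])
>>> scores = torch.zeros(1, 5)  # dummy vocabulary of 5 tokens
>>> processors(input_ids, scores)[0, 2]
tensor(-inf)
```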
