Adding FlaxNoRepeatNGramLogitsProcessor (#29677)
* fix issue with logit processor in beam search in Flax

* adding FlaxNoRepeatNGramLogitsProcessor class + unit test

* style correction and code verification

* add FlaxNoRepeatNGramLogitsProcessor to the test_processor_list and test_processor_list_jitted tests

* fix an issue where n-grams were banned only if they appeared exactly once + update description of get_previous_ngrams

* replace non-jit compatible masking of ngrams that are not yet generated with jittable version

* Revert "fix issue with logit processor in beam search in Flax"

This reverts commit 09b70d7.

* add FlaxNoRepeatNGramLogitsProcessor to _get_logits_processor

* change the method of casting to boolean of banned tokens indices

* fix code style

* remove some useless operations + significantly faster computation of update indices using jax.lax.fori_loop

* remove useless loop iterations

* set some variables that were calculated and used multiple times

* fix format
giganttheo committed Apr 2, 2024
1 parent 33288ff commit fed27ff
Showing 4 changed files with 135 additions and 2 deletions.
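For orientation, here is a minimal sketch (not part of the diff) of how the new processor can be exercised directly, mirroring the unit test added in this commit; the token ids, shapes, and logits values are illustrative.

import jax.numpy as jnp
import numpy as np

from transformers.generation import FlaxNoRepeatNGramLogitsProcessor

# Two sequences of length 4 over a toy vocabulary of size 3 (illustrative values).
input_ids = np.array([[1, 1, 2, 1], [0, 1, 0, 1]], dtype="i4")
scores = jnp.zeros((2, 3), dtype=jnp.float32)  # uniform logits

# Ban any 2-gram that already occurred in the prefix.
processor = FlaxNoRepeatNGramLogitsProcessor(ngram_size=2)
filtered = processor(input_ids, scores, cur_len=4)

# Tokens whose addition would repeat a seen 2-gram are now scored -inf.
print(jnp.isinf(filtered))  # [[False  True  True] [ True False False]]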
2 changes: 2 additions & 0 deletions src/transformers/generation/__init__.py
@@ -162,6 +162,7 @@
"FlaxTopKLogitsWarper",
"FlaxTopPLogitsWarper",
"FlaxWhisperTimeStampLogitsProcessor",
"FlaxNoRepeatNGramLogitsProcessor",
]
_import_structure["flax_utils"] = [
"FlaxGenerationMixin",
@@ -294,6 +295,7 @@
FlaxLogitsProcessorList,
FlaxLogitsWarper,
FlaxMinLengthLogitsProcessor,
FlaxNoRepeatNGramLogitsProcessor,
FlaxSuppressTokensAtBeginLogitsProcessor,
FlaxSuppressTokensLogitsProcessor,
FlaxTemperatureLogitsWarper,
87 changes: 87 additions & 0 deletions src/transformers/generation/flax_logits_process.py
@@ -18,6 +18,7 @@
import jax
import jax.lax as lax
import jax.numpy as jnp
from jax.experimental import sparse

from ..utils import add_start_docstrings
from ..utils.logging import get_logger
@@ -455,3 +456,89 @@ def handle_cumulative_probs(logprobs_k, scores_k):
scores = jax.vmap(handle_cumulative_probs)(logprobs, scores)

return scores


class FlaxNoRepeatNGramLogitsProcessor(FlaxLogitsProcessor):
r"""
[`FlaxLogitsProcessor`] that enforces no repetition of n-grams. See
[Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).
Args:
ngram_size (`int`):
All ngrams of size `ngram_size` can only occur once.
"""

def __init__(self, ngram_size: int):
if not isinstance(ngram_size, int) or ngram_size <= 0:
raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
self.ngram_size = ngram_size

def get_previous_ngrams(self, input_ids: jnp.ndarray, vocab_size: int, cur_len: int):
"""
get a matrix of size (batch_size,) + (vocab_size,)*n (for n-grams) that
represents the n-grams that occurred previously.
The BCOO representation allows storing only the few non-zero entries, instead of the full (huge) matrix
"""
batch_size, seq_len = input_ids.shape
# number of n-grams in the whole sequence
seq_ngrams = seq_len - (self.ngram_size - 1)
# number of n-grams in the currently generated sequence
cur_ngrams = cur_len - (self.ngram_size - 1)

def body_fun(i, val):
b = i % batch_size
pos = i // batch_size
return val.at[i].set(
jnp.array(
[
b,
]
+ [jnp.array(input_ids)[b, pos + j] for j in range(self.ngram_size)]
)
)

shape = (batch_size * seq_ngrams, self.ngram_size + 1)
all_update_indices = jax.lax.fori_loop(
0, batch_size * cur_ngrams, body_fun, jnp.zeros(shape, dtype=input_ids.dtype)
)

# ignore the n-grams not yet generated
data = (jnp.arange(batch_size * seq_ngrams) < batch_size * cur_ngrams).astype("float32")

return sparse.BCOO((data, all_update_indices), shape=(batch_size,) + (vocab_size,) * self.ngram_size)

def get_banned_tokens_mask(self, latest_tokens: jnp.ndarray, previous_ngrams) -> jnp.ndarray:
"""
Determines which tokens must be banned given the latest tokens and the previously seen
n-grams.
"""

@sparse.sparsify
@jax.vmap
def inner_fn(latest_tokens, previous_ngrams):
return previous_ngrams[tuple(latest_tokens)]

return sparse.bcoo_todense(inner_fn(latest_tokens, previous_ngrams))

def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
def true_fn():
_, vocab_size = scores.shape
# store the previously seen n-grams
previous_ngrams = self.get_previous_ngrams(input_ids, vocab_size, cur_len)

# get the n-1 last tokens that prefix the n-gram being generated
latest_tokens = jnp.zeros((input_ids.shape[0], self.ngram_size - 1), dtype=input_ids.dtype)
latest_tokens = jax.lax.dynamic_update_slice(
latest_tokens,
jax.lax.dynamic_slice(
input_ids, (0, cur_len - (self.ngram_size - 1)), (input_ids.shape[0], (self.ngram_size - 1))
),
(0, 0),
)

# compute the banned tokens, i.e. all tokens that, when appended to the latest tokens, form an n-gram that was previously generated
banned_tokens_indices_mask = self.get_banned_tokens_mask(latest_tokens, previous_ngrams).astype("bool")
return jnp.where(banned_tokens_indices_mask, -float("inf"), scores)

output = jax.lax.cond((cur_len >= self.ngram_size - 1), true_fn, lambda: scores)
return output
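To illustrate the sparse bookkeeping this class relies on, here is a toy walkthrough (not part of the commit): previously seen n-grams are stored as a BCOO tensor indexed by (batch, token_1, ..., token_n), and looking up the row for the latest n-1 tokens yields the banned continuations. The sequence and vocabulary size below are illustrative.

import jax.numpy as jnp
from jax.experimental import sparse

vocab_size = 3
# 2-grams observed in the sequence [1, 1, 2, 1]: (1, 1), (1, 2) and (2, 1).
indices = jnp.array([[0, 1, 1], [0, 1, 2], [0, 2, 1]])  # (batch, first token, second token)
data = jnp.ones(3, dtype=jnp.float32)
seen = sparse.BCOO((data, indices), shape=(1, vocab_size, vocab_size))

# The latest token is 1, so its row lists the banned continuations: tokens 1 and 2.
banned = sparse.bcoo_todense(seen)[0, 1]
print(banned)  # [0. 1. 1.]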
3 changes: 3 additions & 0 deletions src/transformers/generation/flax_utils.py
@@ -40,6 +40,7 @@
FlaxForceTokensLogitsProcessor,
FlaxLogitsProcessorList,
FlaxMinLengthLogitsProcessor,
FlaxNoRepeatNGramLogitsProcessor,
FlaxSuppressTokensAtBeginLogitsProcessor,
FlaxSuppressTokensLogitsProcessor,
FlaxTemperatureLogitsWarper,
@@ -534,6 +535,8 @@ def _get_logits_processor(
[input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids
]
processors.append(FlaxForceTokensLogitsProcessor(forced_decoder_ids))
if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:
processors.append(FlaxNoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size))
processors = self._merge_criteria_processor_list(processors, logits_processor)

return processors
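With the wiring above, the processor is enabled from the generation config alone. A hypothetical end-to-end sketch follows (the model name, prompt, and lengths are illustrative and not taken from the commit):

from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = FlaxAutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer("summarize: the cat sat on the mat because the cat was tired", return_tensors="np")
# no_repeat_ngram_size > 0 makes _get_logits_processor append a
# FlaxNoRepeatNGramLogitsProcessor, so no 2-gram is generated twice.
outputs = model.generate(inputs["input_ids"], no_repeat_ngram_size=2, max_length=32)
print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))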
45 changes: 43 additions & 2 deletions tests/generation/test_flax_logits_process.py
@@ -33,6 +33,7 @@
FlaxForcedEOSTokenLogitsProcessor,
FlaxLogitsProcessorList,
FlaxMinLengthLogitsProcessor,
FlaxNoRepeatNGramLogitsProcessor,
FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,
@@ -197,6 +198,26 @@ def test_forced_eos_token_logits_processor(self):
scores = logits_processor(input_ids, scores, cur_len=cur_len)
self.assertFalse(jnp.isinf(scores).any())

def test_no_repeat_ngram_dist_processor(self):
vocab_size = 3
batch_size = 2

cur_len = 4
input_ids = np.array([[1, 1, 2, 1], [0, 1, 0, 1]], dtype="i4")
scores = self._get_uniform_logits(batch_size, vocab_size)

no_repeat_proc_2_gram = FlaxNoRepeatNGramLogitsProcessor(2)
no_repeat_proc_3_gram = FlaxNoRepeatNGramLogitsProcessor(3)

filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores, cur_len=cur_len)
filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores, cur_len=cur_len)

# 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch
self.assertListEqual(jnp.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [True, False, False]])

# 3-gram would forbid no token at 1st batch and 1st token (0) at 2nd batch
self.assertListEqual(jnp.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]])

def test_processor_list(self):
batch_size = 4
sequence_length = 10
@@ -216,6 +237,7 @@ def test_processor_list(self):
temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5)
top_k_warp = FlaxTopKLogitsWarper(3)
top_p_warp = FlaxTopPLogitsWarper(0.8)
no_repeat_proc = FlaxNoRepeatNGramLogitsProcessor(2)

# instantiate all logits processors
min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
@@ -231,10 +253,19 @@
scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
scores = no_repeat_proc(input_ids, scores, cur_len=cur_len)

# with processor list
processor = FlaxLogitsProcessorList(
[temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc]
[
temp_dist_warp,
top_k_warp,
top_p_warp,
min_dist_proc,
bos_dist_proc,
eos_dist_proc,
no_repeat_proc,
]
)
scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)

@@ -263,6 +294,7 @@ def test_processor_list_jitted(self):
temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5)
top_k_warp = FlaxTopKLogitsWarper(3)
top_p_warp = FlaxTopPLogitsWarper(0.8)
no_repeat_proc = FlaxNoRepeatNGramLogitsProcessor(2)

# instantiate all logits processors
min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
@@ -279,12 +311,21 @@ def run_no_processor_list(input_ids, scores, cur_len):
scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
scores = no_repeat_proc(input_ids, scores, cur_len=cur_len)
return scores

# with processor list
def run_processor_list(input_ids, scores, cur_len):
processor = FlaxLogitsProcessorList(
[temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc]
[
temp_dist_warp,
top_k_warp,
top_p_warp,
min_dist_proc,
bos_dist_proc,
eos_dist_proc,
no_repeat_proc,
]
)
scores = processor(input_ids, scores, cur_len=cur_len)
return scores
