Enabling users to provide their own stopping_criteria + logits_processor to generate. #12219

Closed
51 changes: 45 additions & 6 deletions src/transformers/generation_utils.py
@@ -573,13 +573,16 @@ def _get_logits_processor(
        num_beam_groups: int,
        diversity_penalty: float,
        remove_invalid_values: bool,
        logits_processor: Optional[LogitsProcessorList],
    ) -> LogitsProcessorList:
        """
        This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
        :obj:`~transformers.LogitsProcessor` instances used to modify the scores of the language model head.
        """
        processors = LogitsProcessorList()

        if logits_processor is None:
Contributor:

Here I also think we shouldn't even call the function if logits_processor is used. It makes our lives much easier and the design a bit cleaner.

What I don't like about the current design is that:

  • if, say, forced_eos_token_id is provided both in generate() and as an argument at the moment, then we have two forced_eos_token_id processors in the list, which leads to weird behavior...

A first simple solution is to just not call this function at all, IMO. We could always adapt it later for more advanced functionality if needed.
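For concreteness, the guard being proposed could look roughly like this (a sketch only; the helper name and wiring are illustrative, not from this PR):

from typing import Callable, Optional

from transformers import LogitsProcessorList


def resolve_processors(
    user_processors: Optional[LogitsProcessorList],
    build_default: Callable[[], LogitsProcessorList],
) -> LogitsProcessorList:
    # If the user passed their own list, use it verbatim and never call the
    # default factory, so no kwarg-derived processor can be added twice.
    if user_processors is not None:
        return user_processors
    return build_default()


# Usage sketch: the factory is only invoked when no custom list is given.
processors = resolve_processors(None, LogitsProcessorList)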

Member:

One thing to keep in mind is that we essentially deactivate a whole bunch of options without much transparency about it, so a user would need to know which options correspond to a logits_processor. E.g. I just had to look up whether temperature would be affected if I added a logits_processor. Same with stopping_criteria.

The other way around is more transparent, I think, if something goes wrong: "Ah, I passed forced_eos_token_id and a logits_processor doing the same thing, maybe that is not good."

For me the main use-case is to add a custom processor/criteria in addition to what I can already do with the kwargs. Otherwise, why go the extra step of using the arguably more complicated API if you can just pass it as an argument?

What do you think?
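To make that use-case concrete, a custom processor passed alongside the usual kwargs might look like this (a sketch; the processor class and its logic are made up, only the LogitsProcessor/LogitsProcessorList API is real):

import torch
from transformers import LogitsProcessor, LogitsProcessorList


class BanTokenProcessor(LogitsProcessor):
    """Hypothetical custom processor: forbid a single token id."""

    def __init__(self, banned_id: int):
        self.banned_id = banned_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Set the banned token's logit to -inf so it can never be sampled.
        scores[:, self.banned_id] = -float("inf")
        return scores


custom = LogitsProcessorList([BanTokenProcessor(banned_id=42)])
# Under option 1 this would be merged with the kwarg-derived processors:
# outputs = model.generate(input_ids, no_repeat_ngram_size=3, logits_processor=custom)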

Member:

That's an argument in favour of @Narsil's option 1.

Contributor:

That's a good argument @lvwerra and I fully understand what you mean.

I'm however really concerned about the complexity that option 1 adds for, IMO, very few use cases. Also, from a backwards-compatibility point of view, it's pretty much impossible to go from option 1 to option 2 in the future if this feature becomes more important, whereas it's much easier to go from option 2 to option 1.

For me the main use-case is to add a custom processor/criteria in addition to what I can already do with the kwargs. Otherwise, why go the extra step of using the arguably more complicated API if you can just pass it as an argument.

Very good point, and that's the big drawback of option 2. I believe that people who have to use special logits processors are able to create them themselves.

Ok, how about we do something in between option 1 and option 2 that doesn't create crazy complex logic.

If one passes a logits processor, we do the following:

  1. just create the normal logits_processor that would have been created without passing one.
  2. if any object of the passed logits processor is in the already created logits processor, then we raise an error and tell the user which logits processor was created twice (and ideally which parameter has to be changed for this).
  3. If there is no error, make a union of the two lists and we're good to go.
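Steps 1-3 could be implemented roughly like this (a sketch, under the assumption that "created twice" means two processors of the same type; the helper name is made up):

from transformers import LogitsProcessorList


def merge_processor_lists(
    default_list: LogitsProcessorList, custom_list: LogitsProcessorList
) -> LogitsProcessorList:
    # Step 2: raise if a custom processor duplicates a kwarg-derived one.
    default_types = {type(p) for p in default_list}
    for p in custom_list:
        if type(p) in default_types:
            raise ValueError(
                f"A custom {type(p).__name__} was passed to `generate`, but one was "
                "already created from its arguments / the model config. Remove the "
                "corresponding argument or the custom processor."
            )
    # Step 3: union of the two lists.
    merged = LogitsProcessorList(default_list)
    merged.extend(custom_list)
    return merged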

Member:

I like this compromise, what do you think @Narsil?

Contributor:

Think we can go for the solution if you want @lvwerra :-)

Contributor (author):

We can do that.

            processors = LogitsProcessorList()
        else:
            processors = logits_processor
        # init warp parameters
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
        no_repeat_ngram_size = (
@@ -638,9 +641,18 @@ def _get_logits_processor(
        processors.append(InfNanRemoveLogitsProcessor())
        return processors

    def _get_stopping_criteria(self, max_length: Optional[int], max_time: Optional[float]) -> StoppingCriteriaList:
        stopping_criteria = StoppingCriteriaList()
        if max_length is not None:
    def _get_stopping_criteria(
        self,
        max_length: Optional[int],
        max_time: Optional[float],
        max_new_tokens: Optional[int],
        start_length: int,
Contributor:

start_length is not needed, no?

Contributor:

What is max_new_tokens needed for? I think this PR is from quite some time ago, when we still had this.

Member:

Yes, I think they come from my attempt to merge the main branch into this one. I'll fix this.

        stopping_criteria: Optional[StoppingCriteriaList],
    ) -> StoppingCriteriaList:
        if stopping_criteria is None:
            stopping_criteria = StoppingCriteriaList()
        max_length_in_criteria = any([isinstance(criteria, MaxLengthCriteria) for criteria in stopping_criteria])
        if max_length is not None and not max_length_in_criteria:
Contributor:

I'm fine with this, I think -> it means that stopping_criteria always takes precedence over max_length, which sounds good to me.

Contributor:

Actually, thinking more about it, I think it would be cleaner to just have the following logic:

if stopping_criteria is provided, then only this list is used and nothing else. In a first step this lets us avoid some nasty nested use cases. IMO we could just check if len(stopping_criteria) > 0, and if this is the case we don't even call the function _get_stopping_criteria. IMO someone who uses that functionality understands generate() quite well and doesn't need much magic under the hood.
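Spelled out, that full-override logic might look like this (a sketch; only StoppingCriteriaList, MaxLengthCriteria and MaxTimeCriteria are real names, the helper is made up):

from typing import Optional

from transformers import MaxLengthCriteria, MaxTimeCriteria, StoppingCriteriaList


def resolve_stopping_criteria(
    user_criteria: Optional[StoppingCriteriaList],
    max_length: Optional[int],
    max_time: Optional[float],
) -> StoppingCriteriaList:
    # Full-override semantics: a non-empty user list wins outright, and the
    # kwarg-derived criteria (max_length / max_time) are never created.
    if user_criteria is not None and len(user_criteria) > 0:
        return user_criteria
    criteria = StoppingCriteriaList()
    if max_length is not None:
        criteria.append(MaxLengthCriteria(max_length=max_length))
    if max_time is not None:
        criteria.append(MaxTimeCriteria(max_time=max_time))
    return criteria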

Contributor:

What do you think @Narsil ?

Contributor (author):

That's an option that is definitely viable.

The core things I think are important:

  • If a user specified something, we need to respect it
  • If something comes as a default it cannot override anything user specified.
  • If user specifications are unclear/unsound, yell loud and clear about what is going on, and what the code is going to do to save the generation.

Option 1 (current):

model.generate(...., no_repeat_n_tokens=3)

Now say I need to add some even more clever functionality:

my_logits_processor = LogitsProcessorList(MyLogits())
model.generate(...., logits_processor=my_logits_processor, no_repeat_n_tokens=3)

(You can keep the ease of use of generate.)

Option 2 (logits_processor is a full override):

model.generate(..., no_repeat_n_tokens=3)

becomes

my_logits_processor = LogitsProcessorList(
    NoRepeatLogitsProcessor(3),
    MyLogits(),
)
model.generate(..., logits_processor=my_logits_processor)

Option 1 has the advantage that we can keep some options simpler to use and still add some custom logits processors; if we're careful enough that no non-user-defined variable can ever override the logits_processor in a hidden way, then we're good to go (only user-defined arguments are able to modify it). This is of course a guarantee that might be tricky to keep in the future, so we are taking a risk of silently breaking things. It also makes user code trickier to understand, since there is no ONE way to define logits_processor, so it might lead to hard-to-understand behavior.

Option 2 has the advantage that it has a single point of definition of logits processors. The disadvantage is that it requires more changes from users the first time they use this variable. We also need to yell quite strongly that we're simply ignoring every other argument, which might not be obvious (we could even crash at this point, since it's a really easy thing to overlook and will definitely yield poor results).

IMHO, both are really fine, we just need to stick to one option. I suggested Option 1 to @lvwerra because I thought that max_length was the only variable that was always defined even if not user-supplied. If this assumption is wrong and hard to make over a small list of variables, then I don't think we should stick with Option 1. Option 2 does have drawbacks, as we're pushing complexity back into user code instead of absorbing it like we're doing right now. But it does help separation of concerns.

Contributor:

Thanks for the great write-up! Actually there are a bunch of things that are always defined (top_k is always defined, e.g.), and lots of models also always define num_beams in their config alongside other parameters, so I would very much prefer option 2 here.

Contributor (author):

@lvwerra are you ok switching to Option 2? Seems my assumption was incorrect :(

            stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
        if max_time is not None:
            stopping_criteria.append(MaxTimeCriteria(max_time=max_time))
@@ -674,6 +686,8 @@ def generate(
        num_beam_groups: Optional[int] = None,
        diversity_penalty: Optional[float] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
@@ -792,6 +806,12 @@ def generate(
                crash. Note that using ``remove_invalid_values`` can slow down generation.
            synced_gpus (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            logits_processor (:obj:`LogitsProcessorList`, `optional`):
                This object is created automatically from other arguments of this function. `logits_processor` is meant
Contributor:

Not a huge fan of the docstring here - could we maybe rephrase it a bit? I think the user is mostly interested in what happens when this object is passed and how it can be passed - not really in what happens when it is not passed.

So maybe something more like:

Suggested change
This object is created automatically from other arguments of this function. `logits_processor` is meant
If provided, `logits_processor` will overwrite all passed arguments that can process logits as well as those saved in the model's config. It can be very useful to enable custom logits processing logic.

We should also note somewhere that this is an experimental feature IMO

Contributor:

Keen to hear your input here as well @Narsil. IMO this feature is really not for the "inexperienced" HF user but for the advanced ones that know more or less what happens under the hood in generate() (otherwise why would you need custom logits processors?). To give this functionality while keeping complexity at a minimum, I think the best first step is to simply say:
If we pass logits_processor or stopping_criteria it will overwrite everything else...

Contributor (author):

That's what I named Option 2 above. While I think it's viable, I don't think it's the only way.
Option 1 ("we're adding your stuff too without looking") is also perfectly viable. Let's continue the discussion above, since I think that's where the main point is, no?

                to be used to add another layer with custom logic.
            stopping_criteria (:obj:`StoppingCriteriaList`, `optional`):
                This object is created automatically from other arguments of this function. `stopping_criteria` is
Contributor:

same here

                meant to be used to add another layer with custom logic from your own code.

            model_kwargs:
                Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
@@ -871,6 +891,17 @@ def generate(
        >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
        """

        # if `MaxLengthCriteria` exists it overwrites `max_length`
        if stopping_criteria is not None:
            for stopping_criterion in stopping_criteria:
                if isinstance(stopping_criterion, MaxLengthCriteria):
                    if max_length is not None:
                        warnings.warn(
                            "A stopping criteria of type `MaxLengthCriteria` as well as `max_length` was passed to `generate`. The `MaxLengthCriteria` will be used.",
Contributor (author):

That's not necessarily true.

On line 889, we default max_length to self.config.max_length even if it's not user-defined.
At least we override it here, which makes the behavior what a user would expect IMO (respecting their own defined things).

Member:

Oh, that's right, I missed that one. I think the stopping_criteria check should be moved before that. Otherwise you will always get a warning even when both max_length and max_new_tokens are None, which would be the standard case when using a custom MaxLengthCriteria. What do you think?

                            UserWarning,
                        )
                    max_length = stopping_criterion.max_length

        num_beams = num_beams if num_beams is not None else self.config.num_beams
        num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
        do_sample = do_sample if do_sample is not None else self.config.do_sample
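The reordering discussed in the thread above could look roughly like this (a sketch, not the PR's code: resolve a user-passed MaxLengthCriteria before applying the config default, so the warning can only fire on an explicit conflict):

        # Sketch: check the custom criteria first, then apply the config default,
        # so `max_length = self.config.max_length` can no longer trigger the warning.
        if stopping_criteria is not None:
            for stopping_criterion in stopping_criteria:
                if isinstance(stopping_criterion, MaxLengthCriteria):
                    if max_length is not None or max_new_tokens is not None:
                        warnings.warn(
                            "Both a `MaxLengthCriteria` and `max_length`/`max_new_tokens` were "
                            "passed to `generate`; the `MaxLengthCriteria` will be used.",
                            UserWarning,
                        )
                    max_length = stopping_criterion.max_length
        max_length = max_length if max_length is not None else self.config.max_length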
@@ -992,9 +1023,17 @@ def generate(
            num_beam_groups=num_beam_groups,
            diversity_penalty=diversity_penalty,
            remove_invalid_values=remove_invalid_values,
            logits_processor=logits_processor,
        )

        stopping_criteria = self._get_stopping_criteria(max_length=max_length, max_time=max_time)
        cur_len = input_ids.shape[-1]
        stopping_criteria = self._get_stopping_criteria(
            max_length=max_length,
            max_time=max_time,
            max_new_tokens=max_new_tokens,
            start_length=cur_len,
            stopping_criteria=stopping_criteria,
        )

        if is_greedy_gen_mode:
            if num_return_sequences > 1:
11 changes: 11 additions & 0 deletions src/transformers/models/rag/modeling_rag.py
@@ -23,6 +23,8 @@
from ...configuration_utils import PretrainedConfig
from ...file_utils import add_start_docstrings_to_model_forward, replace_return_docstrings
from ...generation_beam_search import BeamSearchScorer
from ...generation_logits_process import LogitsProcessorList
from ...generation_stopping_criteria import StoppingCriteriaList
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import logging
@@ -1375,6 +1377,8 @@ def generate(
        decoder_start_token_id=None,
        n_docs=None,
        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[int] = None,
        remove_invalid_values: Optional[bool] = None,
@@ -1478,6 +1482,12 @@ def generate(
            remove_invalid_values (:obj:`bool`, `optional`):
                Whether to remove possible `nan` and `inf` outputs of the model to prevent the generation method to
                crash. Note that using ``remove_invalid_values`` can slow down generation.
            logits_processor (:obj:`LogitsProcessorList`, `optional`):
                This object is created automatically from other arguments of this function. `logits_processor` is meant
                to be used to add another layer with custom logic.
            stopping_criteria (:obj:`StoppingCriteriaList`, `optional`):
                This object is created automatically from other arguments of this function. `stopping_criteria` is
                meant to be used to add another layer with custom logic from your own code.

        Return:
            :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
@@ -1585,6 +1595,7 @@ def extend_enc_output(tensor, num_beams=None):
            num_beam_groups=num_beam_groups,
            diversity_penalty=diversity_penalty,
            remove_invalid_values=remove_invalid_values,
            logits_processor=logits_processor,
        )

        if num_beams == 1:
32 changes: 32 additions & 0 deletions tests/test_generation_utils.py
@@ -1638,6 +1638,38 @@ def test_beam_search_warning_if_max_length_is_passed(self):
        # BeamSearchScorer max_length should not influence "real" max_length
        self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())

    def test_custom_stopping_criteria_priorities(self):
        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)

        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
        bart_model.config.max_length = 22
        max_length = 33
        stopping_criteria = StoppingCriteriaList()
        stopping_criteria.append(MaxLengthCriteria(max_length=44))
        self.assertEqual(list(bart_model.generate(input_ids).shape), [1, 22])

        self.assertEqual(list(bart_model.generate(input_ids, max_length=max_length).shape), [1, 33])
        self.assertEqual(list(bart_model.generate(input_ids, stopping_criteria=stopping_criteria).shape), [1, 44])
        with self.assertWarns(UserWarning):
            self.assertEqual(
                list(bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=max_length).shape),
                [1, 44],
            )

    def test_custom_logits_processor(self):
        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)

        logits_processor = LogitsProcessorList()
        logits_processor.append(MinLengthLogitsProcessor(min_length=10, eos_token_id=0))
        # XXX: Used to fail with `logits_processor` being defined twice in call arguments
        # https://github.com/huggingface/transformers/issues/12118
        bart_model.generate(input_ids, logits_processor=logits_processor)

    def test_max_new_tokens_encoder_decoder(self):
        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
        bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")