Enabling users to provide their own `stopping_criteria` + `logits_processor` to `generate` (#12219)
Changes from all commits: bfc203b, 5e5db45, 896542c, 9e011cd, a6ccdf8, 5301c58, 4e2d741, 3b117b2, cf4ce59, 58da8c5, fd282ad
```diff
@@ -573,13 +573,16 @@ def _get_logits_processor(
         num_beam_groups: int,
         diversity_penalty: float,
         remove_invalid_values: bool,
+        logits_processor: Optional[LogitsProcessorList],
     ) -> LogitsProcessorList:
         """
         This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
         :obj:`~transformers.LogitsProcessor` instances used to modify the scores of the language model head.
         """
-        processors = LogitsProcessorList()
+        if logits_processor is None:
+            processors = LogitsProcessorList()
+        else:
+            processors = logits_processor
         # init warp parameters
         repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
         no_repeat_ngram_size = (
```
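To make the new `logits_processor` argument concrete, here is a minimal sketch of the kind of custom processor a user could now inject. `BlockTokenLogitsProcessor` and the blocked token id are illustrative assumptions, not code from this PR:

```python
import torch
from transformers import LogitsProcessor, LogitsProcessorList


class BlockTokenLogitsProcessor(LogitsProcessor):
    """Hypothetical processor: sets one token's score to -inf so it is never generated."""

    def __init__(self, blocked_token_id: int):
        self.blocked_token_id = blocked_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores[:, self.blocked_token_id] = -float("inf")
        return scores


# With this change, `_get_logits_processor` starts from this list when it is
# passed and appends the built-in processors to it, instead of always starting empty.
custom_processors = LogitsProcessorList([BlockTokenLogitsProcessor(blocked_token_id=50256)])
```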
```diff
@@ -638,9 +641,18 @@ def _get_logits_processor(
         processors.append(InfNanRemoveLogitsProcessor())
         return processors

-    def _get_stopping_criteria(self, max_length: Optional[int], max_time: Optional[float]) -> StoppingCriteriaList:
-        stopping_criteria = StoppingCriteriaList()
-        if max_length is not None:
+    def _get_stopping_criteria(
+        self,
+        max_length: Optional[int],
+        max_time: Optional[float],
+        max_new_tokens: Optional[int],
+        start_length: int,
+        stopping_criteria: Optional[StoppingCriteriaList],
+    ) -> StoppingCriteriaList:
+        if stopping_criteria is None:
+            stopping_criteria = StoppingCriteriaList()
+        max_length_in_criteria = any([isinstance(criteria, MaxLengthCriteria) for criteria in stopping_criteria])
+        if max_length is not None and not max_length_in_criteria:
             stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
         if max_time is not None:
             stopping_criteria.append(MaxTimeCriteria(max_time=max_time))
```

**Review thread on `start_length`:**

Comment: `start_length` is not needed, no?

Comment: What is `start_length`?

Reply: Yes, I think they come from my attempt to merge the main branch into this one. I'll fix this.

**Review thread on the `max_length_in_criteria` check:**

Comment: I'm fine with this I think -> it means that …

Comment: Actually, thinking more about it, I think it would be cleaner to just have the following logic: if …

Comment: What do you think @Narsil?

Comment: That's an option that is definitely viable. The core things I think are important: …

Option 1 (current):

```python
model.generate(..., no_repeat_n_tokens=3)
```

If I need to add some even more clever functionality:

```python
mylogitsProcessor = LogitsProcessorList([MyLogits()])
model.generate(..., logits_processor=mylogitsProcessor, no_repeat_n_tokens=3)
```

(You can keep the easiness of …)

Option 2 (`logits_processor` is a full override):

```python
model.generate(..., no_repeat_n_tokens=3)
```

becomes

```python
my_logits_processor = LogitsProcessorList([
    NoRepeatLogitsProcessor(3),
    MyLogits(),
])
model.generate(..., logits_processor=my_logits_processor)
```

Option 1 has the advantage that we can keep some options simpler to use and still add some custom logits processor, if we're careful enough that no non-user-defined variable can ever override the … Option 2 has the advantage that it has a single point of definition of logits processors. The disadvantage is that it requires more changes from the user the first time they use this variable. We also need to yell quite strongly that we're simply ignoring every other argument, which might not be obvious (we could even crash at this point, since it's a really easy thing to overlook and will definitely yield poor results). IMHO, both are really fine, we just need to stick to one option. I suggested Option 1 to @lvwerra because I thought that …

Reply: Thanks for the great write-up! Actually there are a bunch of things that are always defined (`top_k` is always defined, e.g.) but then also lots of models always define …

Reply: @lvwerra are you ok switching to Option 2? It seems my assumption was incorrect :(
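Analogously, a user-defined criterion could be passed through the reworked `_get_stopping_criteria`. The following is a hypothetical sketch, not code from the PR; `StopOnTokenCriteria` and the token id are illustrative:

```python
import torch
from transformers import StoppingCriteria, StoppingCriteriaList


class StopOnTokenCriteria(StoppingCriteria):
    """Hypothetical criterion: stop once every sequence ends with a given token id."""

    def __init__(self, stop_token_id: int):
        self.stop_token_id = stop_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Returning True terminates generation; checks the last generated token of each sequence.
        return bool((input_ids[:, -1] == self.stop_token_id).all())


custom_criteria = StoppingCriteriaList([StopOnTokenCriteria(stop_token_id=50256)])
```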
```diff
@@ -674,6 +686,8 @@ def generate(
         num_beam_groups: Optional[int] = None,
         diversity_penalty: Optional[float] = None,
         prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_scores: Optional[bool] = None,
```
```diff
@@ -792,6 +806,12 @@ def generate(
                 crash. Note that using ``remove_invalid_values`` can slow down generation.
             synced_gpus (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+            logits_processor (:obj:`LogitsProcessorList`, `optional`):
+                This object is created automatically from other arguments of this function. `logits_processor` is
+                meant to be used to add another layer with custom logic.
+            stopping_criteria (:obj:`StoppingCriteriaList`, `optional`):
+                This object is created automatically from other arguments of this function. `stopping_criteria` is
+                meant to be used to add another layer with custom logic from your own code.
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
```

**Review thread on the `logits_processor` docstring:**

Comment: Not a huge fan of the docstring here - could we maybe rephrase it a bit? I think the user is mostly interested in what happens when this object is passed and how it can be passed, not really what happens when it is not passed. So maybe something more like: … We should also note somewhere that this is an experimental feature IMO.

Comment: Keen to hear your input here as well @Narsil. IMO this feature is really not for the "inexperienced" HF user but for the advanced ones that know more or less what happens under the hood in …

Reply: That's what I named …

**Review thread on the `stopping_criteria` docstring:**

Comment: Same here.
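Putting the two new keyword arguments together, a call under this API might look as follows. This is a sketch reusing the hypothetical `custom_processors` and `custom_criteria` lists from the earlier snippets:

```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids
outputs = model.generate(
    input_ids,
    max_length=30,
    logits_processor=custom_processors,  # extra layer on top of the built-in processors
    stopping_criteria=custom_criteria,   # extra layer on top of MaxLengthCriteria / MaxTimeCriteria
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```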
```diff
@@ -871,6 +891,17 @@ def generate(
         >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
         """

+        # if `MaxLengthCriteria` exists it overwrites `max_length`
+        if stopping_criteria is not None:
+            for stopping_criterion in stopping_criteria:
+                if isinstance(stopping_criterion, MaxLengthCriteria):
+                    if max_length is not None:
+                        warnings.warn(
+                            "A stopping criteria of type `MaxLengthCriteria` as well as `max_length` was passed to `generate`. The `MaxLengthCriteria` will be used.",
+                            UserWarning,
+                        )
+                    max_length = stopping_criterion.max_length
+
         num_beams = num_beams if num_beams is not None else self.config.num_beams
         num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
         do_sample = do_sample if do_sample is not None else self.config.do_sample
```

**Review thread on the warning message:**

Comment: That's not necessarily true. On line 889, we can define …

Reply: Oh, that's right, I missed that one. I think the …
```diff
@@ -992,9 +1023,17 @@
             num_beam_groups=num_beam_groups,
             diversity_penalty=diversity_penalty,
             remove_invalid_values=remove_invalid_values,
+            logits_processor=logits_processor,
         )

-        stopping_criteria = self._get_stopping_criteria(max_length=max_length, max_time=max_time)
+        cur_len = input_ids.shape[-1]
+        stopping_criteria = self._get_stopping_criteria(
+            max_length=max_length,
+            max_time=max_time,
+            max_new_tokens=max_new_tokens,
+            start_length=cur_len,
+            stopping_criteria=stopping_criteria,
+        )

         if is_greedy_gen_mode:
             if num_return_sequences > 1:
```
**General review discussion:**

Comment: Here I also think we shouldn't even call the function if `logits_processor` is used. It makes our lives much easier and the design a bit cleaner. What I don't like about the current design is that if `forced_eos_token_id` is provided in `generate()` and as an input at the moment, then we have two `forced_eos_token_id` processors in the list, which leads to weird behavior... A first simple solution is to just not call this function IMO. We could always later adapt it for more advanced functionality if needed.

Comment: One thing to keep in mind is that we essentially deactivate a whole bunch of options without much transparency about it, so a user would need to know which options correspond to a `logits_processor`. E.g., I just had to look up whether `temperature` would be affected if I added a `logits_processor`. Same with `stopping_criteria`. The other way around it is more transparent, I think, if something goes wrong: ah, I passed `forced_eos_token_id` and a `logits_processor` doing the same thing, maybe that is not good. For me the main use case is to add a custom processor/criterion in addition to what I can already do with the kwargs. Otherwise, why go the extra step of using the arguably more complicated API if you can just pass it as an argument? What do you think?

Comment: That's an argument in favour of @Narsil's option 1.

Comment: That's a good argument @lvwerra and I fully understand what you mean. I'm however really concerned about the complexity that option 1 adds for, IMO, very few use cases. Also, from a backward-compatibility point of view, it's pretty much impossible to go from option 1 to option 2 in the future if this feature becomes more important, whereas it's much easier to go from option 2 to option 1 in the future.

Very good point, and that's the big drawback of option 2. I believe that people having to use special logits processors are able to create them themselves.

OK, how about we do something in between option 1 and option 2 that doesn't create a crazy complex logic? If one passes a logits processor, we do the following: …

Comment: I like this compromise, what do you think @Narsil?

Comment: Think we can go for that solution if you want @lvwerra :-)

Comment: We can do that.
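The agreed compromise is only partially visible here, since the list after "we do the following:" is cut off in the source. Below is a hedged sketch of one way such merge logic could look: build the default list from the keyword arguments, then merge in the user's list and raise on duplicates so nothing is silently overridden. `merge_processor_lists` is an illustrative name, not the PR's final code:

```python
from transformers import LogitsProcessorList


def merge_processor_lists(
    default_list: LogitsProcessorList, custom_list: LogitsProcessorList
) -> LogitsProcessorList:
    """Merge user-supplied processors into the default list, rejecting duplicates by type."""
    for default in default_list:
        for custom in custom_list:
            if type(custom) is type(default):
                raise ValueError(
                    f"A custom {type(custom).__name__} was passed to `generate`, but one is also "
                    "created implicitly from the `generate` arguments or the model config. "
                    "Remove one of the two to avoid conflicting behavior."
                )
    # No conflicts: the custom processors run after the implicitly created ones.
    default_list.extend(custom_list)
    return default_list
```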