Add custom `stopping_criteria` and `logits_processor` to `generate` #14779
Looks good and clean.
Only suggestion would be making it slightly more obvious in the docstring that those arguments are intended for power users.
src/transformers/generation_utils.py
Outdated
@@ -573,6 +573,7 @@ def _get_logits_processor(
    num_beam_groups: int,
    diversity_penalty: float,
    remove_invalid_values: bool,
    logits_processor: Optional[StoppingCriteriaList],
I think you meant LogitsProcessorList
@Narsil thanks for your feedback and for catching that copy-paste error! I added something to the docstring and fixed the error. @patrickvonplaten any comments?
src/transformers/generation_utils.py
Outdated
def _merge_criteria_processor_list(
    self, default_list: StoppingCriteriaList, custom_list: StoppingCriteriaList
) -> StoppingCriteriaList:
    if custom_list is None:
How can `custom_list` be `None`? I think it can only be empty, no? Wouldn't it be better to do:
if len(custom_list) == 0:
    return default_list
The default in `generate` is `None`, but I can change it to an empty list if you think that is more consistent.
An empty `StoppingCriteriaList` and `LogitsProcessorList`, to be more precise.
@@ -1469,6 +1473,14 @@ def generate(
        conditioned on the previously generated tokens :obj:`inputs_ids` and the batch ID :obj:`batch_id`. This
        argument is useful for constrained generation conditioned on the prefix, as described in
        `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
    logits_processor (:obj:`LogitsProcessorList`, `optional`):
        Custom logits processors that complement the default logits processors built from arguments and a
I like the doc-string!
src/transformers/generation_utils.py
Outdated
for custom in custom_list:
    if type(custom) is type(default):
        raise ValueError(
            f"A custom stopping criteria or logits processor of type {type(custom)} was passed to `generate` which is already created with an argument or model config."
- f"A custom stopping criteria or logits processor of type {type(custom)} was passed to `generate` which is already created with an argument or model config."
+ f"A custom stopping criteria or logits processor of type {type(custom)} was passed "
+ "to `generate` which is already created with an argument or model config."
Also, since this error won't be easy for the user to solve, can we try to give them as much guidance as possible? We could a) figure out from the type whether it's a stopping criteria or a logits processor, and then b) IMO we should display as much information as possible. E.g.
object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to `generate`, but it has already been created with the values {default}. {default} has been created by passing the corresponding arguments to generate or by the config's default values."
Something like this maybe, there is probably a better way to phrase it
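For illustration, the suggested message could be assembled roughly as below. This is only a sketch under assumptions: the classes are minimal stand-ins defined here (not the library's own), and `duplicate_error_message` is a hypothetical helper name.

```python
class StoppingCriteria:
    """Stand-in base class for stopping criteria (assumption, not the library class)."""


class LogitsProcessor:
    """Stand-in base class for logits processors (assumption, not the library class)."""


class MaxLengthCriteria(StoppingCriteria):
    def __init__(self, max_length: int):
        self.max_length = max_length

    def __repr__(self) -> str:
        # A readable repr makes the "with values ..." part of the message useful.
        return f"MaxLengthCriteria(max_length={self.max_length})"


def duplicate_error_message(custom, default) -> str:
    """Hypothetical helper: build a detailed message for a duplicated entry."""
    # a) Work out from the type which kind of object was duplicated.
    object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
    # b) Show both the custom values and the values of the default it collides with.
    return (
        f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to "
        f"`generate`, but it has already been created with the values {default}. {default} has been "
        "created by passing the corresponding arguments to generate or by the config's default values."
    )


message = duplicate_error_message(MaxLengthCriteria(10), MaxLengthCriteria(20))
print(message)
```

Splitting the message over several string literals also keeps each source line short, which is what the earlier suggestion was about.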
Actually, what we could do in a second step is to enforce that the init names of the processors correspond 1-to-1 to the inputs of `generate()`, and then we can actually pass the arguments that need to be set to `None` in `generate` to allow this. Maybe a bit overkill as a first step, but we could do this if we encounter issues about it.
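A rough sketch of that second step, assuming the 1-to-1 naming convention were enforced: the `__init__` parameter names of a duplicated processor could be read with `inspect.signature` and reported as the `generate` arguments to set to `None`. The class and the helper `generate_args_to_unset` are hypothetical stand-ins, not existing library code.

```python
import inspect
from typing import List


class MinLengthProcessor:
    """Stand-in processor whose init names are assumed to mirror `generate` arguments."""

    def __init__(self, min_length: int, eos_token_id: int):
        self.min_length = min_length
        self.eos_token_id = eos_token_id


def generate_args_to_unset(processor: object) -> List[str]:
    """Hypothetical helper: list the init parameter names of a processor."""
    parameters = inspect.signature(type(processor).__init__).parameters
    # Drop `self`; the remaining names are assumed to match `generate` arguments.
    return [name for name in parameters if name != "self"]


names = generate_args_to_unset(MinLengthProcessor(min_length=5, eos_token_id=2))
print(f"Set these `generate` arguments to None: {names}")
```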
I refined the error message to include a little bit more detail. Let me know if you think it is clearer now.
I agree with your last point: if a lot of people have issues we can make it more strict.
src/transformers/generation_utils.py
Outdated
    return criteria

def _merge_criteria_processor_list(
    self, default_list: StoppingCriteriaList, custom_list: StoppingCriteriaList
- self, default_list: StoppingCriteriaList, custom_list: StoppingCriteriaList
+ self, default_list: List[Any], custom_list: List[Any]
They are not of type `List`. `Union[LogitsProcessorList, StoppingCriteriaList]` is the correct thing (you can define a new type too, to make it less verbose maybe):
GenerateList = Union[LogitsProcessorList, StoppingCriteriaList]
def _merge(... default_list: GenerateList) -> GenerateList:
    ....
I have added `Union[LogitsProcessorList, StoppingCriteriaList]`. I should be more careful with type hints in the future. I chose the verbose option because it is quite readable that way.
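For reference, the less verbose alias variant discussed above might look like the following sketch. The two list classes are stand-ins defined here as plain `list` subclasses (an assumption made for this example), and `merge_lists` is a hypothetical name.

```python
from typing import Union


class LogitsProcessorList(list):
    """Stand-in container for logits processors (assumption for this sketch)."""


class StoppingCriteriaList(list):
    """Stand-in container for stopping criteria (assumption for this sketch)."""


# One alias keeps the signature short while still naming both accepted types.
GenerateList = Union[LogitsProcessorList, StoppingCriteriaList]


def merge_lists(default_list: GenerateList, custom_list: GenerateList) -> GenerateList:
    """Hypothetical merge that preserves the container type of the default list."""
    merged = type(default_list)(default_list)
    merged.extend(custom_list)
    return merged


merged = merge_lists(StoppingCriteriaList(["a"]), StoppingCriteriaList(["b"]))
print(type(merged).__name__, list(merged))
```

Spelling out the `Union` in the signature, as was done in the PR, trades a little verbosity for not having to look up the alias.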
src/transformers/generation_utils.py
Outdated
def _merge_criteria_processor_list(
    self, default_list: StoppingCriteriaList, custom_list: StoppingCriteriaList
) -> StoppingCriteriaList:
- ) -> StoppingCriteriaList:
+ ) -> List[Any]:
See above discussion.
Great work! Thanks a lot for the feature.
Feel free to merge whenever :-)
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
`stopping_criteria` and `logits_processor` to `generate` …uggingface#14779)
* add custom `stopping_criteria` and `logits_processor` to `generate`
* add tests for custom `stopping_criteria` and `logits_processor`
* fix typo in RAG
* address reviewer comments
* improve custom logits processor/stopping criteria error message
* fix types in merge function signature
* change default for custom list from `None` to empty list
* fix rag generate
* add string split suggestion
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
What does this PR do?
This PR continues the work and discussions from #12219 with a fresh start.
It integrates the custom `stopping_criteria` and `logits_processor` with the following logic:
* The default `stopping_criteria`/`logits_processor` are created from the arguments and the model's config.
* If custom `stopping_criteria` and `logits_processor` are passed to `generate`, they are compared with the default list.
* If an object of the same type appears in both lists (e.g. `MaxLengthCriteria` in both lists), an error is thrown.
Fixes #12118
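
The merge logic described above can be sketched in a few lines. This is a simplified illustration, not the merged implementation: the criteria classes are stand-ins defined here, plain Python lists replace the library's list types, and the function returns a new list, which is an assumption of this sketch.

```python
from typing import List


class MaxLengthCriteria:
    """Stand-in stopping criterion (not the library class)."""

    def __init__(self, max_length: int):
        self.max_length = max_length


class MaxTimeCriteria:
    """Second stand-in, used to show a custom entry that does not conflict."""

    def __init__(self, max_time: float):
        self.max_time = max_time


def merge_criteria_processor_list(default_list: List[object], custom_list: List[object]) -> List[object]:
    """Combine defaults with custom entries, rejecting a type present in both."""
    if len(custom_list) == 0:
        return default_list
    for default in default_list:
        for custom in custom_list:
            # Exact type match: the same class was built from arguments/config
            # and also passed in by the caller.
            if type(custom) is type(default):
                raise ValueError(
                    f"A custom object of type {type(custom)} was passed to `generate`, "
                    "but it has already been created from an argument or the model config."
                )
    return default_list + list(custom_list)


defaults = [MaxLengthCriteria(max_length=20)]
merged = merge_criteria_processor_list(defaults, [MaxTimeCriteria(max_time=1.0)])
print([type(entry).__name__ for entry in merged])
```

Passing a second `MaxLengthCriteria` as a custom entry in this sketch raises the `ValueError`, which mirrors the `MaxLengthCriteria` case mentioned in the description.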