Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom stopping_criteria and logits_processor to generate #14779

Merged
merged 11 commits into from Dec 21, 2021

Conversation

lvwerra
Copy link
Member

@lvwerra lvwerra commented Dec 15, 2021

What does this PR do?

This PR are continues the work and discussions from #12219 with a fresh start.

It integrates the custom stopping_criteria and logits_processor with the following logic:

  • the default stopping_criteria/logits_processor are created from the arguments and model's config
  • if additional, custom stopping_criteria and logits_processor are passed to generate, they are compared with the default list
    • if there is an overlap between the two lists (e.g. a MaxLengthCriteria in both lists) an error is thrown
    • if there is no overlap the two lists are merged

Fixes #12118

Copy link
Contributor

@Narsil Narsil left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good and clean.

Only suggestion would be making it slightly more obvious in the docstring that those arguments are intended for power users.

@@ -573,6 +573,7 @@ def _get_logits_processor(
num_beam_groups: int,
diversity_penalty: float,
remove_invalid_values: bool,
logits_processor: Optional[StoppingCriteriaList],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you meant LogitsProcessorList

@lvwerra
Copy link
Member Author

lvwerra commented Dec 20, 2021

@Narsil thanks for your feedback and catching that copy-paste error! I added something to the docstring and fixed the error.

@patrickvonplaten any comments?

def _merge_criteria_processor_list(
self, default_list: StoppingCriteriaList, custom_list: StoppingCriteriaList
) -> StoppingCriteriaList:
if custom_list is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can the custom_list be None? I think it can only be empty no? Wouldn't it be better to do:

if len(custom_list) == 0:
    return default_list

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default in generate is None but I can change it to an empty list if you think that is more consistent.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An empty StoppingCriteriaList and LogitsProcessorList to be more precise.

@@ -1469,6 +1473,14 @@ def generate(
conditioned on the previously generated tokens :obj:`inputs_ids` and the batch ID :obj:`batch_id`. This
argument is useful for constrained generation conditioned on the prefix, as described in
`Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
logits_processor (:obj:`LogitsProcessorList`, `optional`):
Custom logits processors that complement the default logits processors built from arguments and a
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the doc-string!

for custom in custom_list:
if type(custom) is type(default):
raise ValueError(
f"A custom stopping criteria or logits processor of type {type(custom)} was passed to `generate` which is already created with an argument or model config."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f"A custom stopping criteria or logits processor of type {type(custom)} was passed to `generate` which is already created with an argument or model config."
f"A custom stopping criteria or logits processor of type {type(custom)} was passed "
"to `generate` which is already created with an argument or model config."

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, since this error won't be easy to solve for the user can we try to give her/him as much guidance as possible? we could a) Figure out if it's a stopping criteria or a logits processor from the type. And then b) IMO we should display as much information as possible. E.g.

object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to `generate`, but it has already been created with the values {default}. {default} has been created by passing the corresponding arguments to generate or by the config's default values."

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something like this maybe, there is probably a better way to phrase it

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually what we could do in a second step is to enforce that the init names of the processors have to correspond 1-to-1 to the inputs of generate() and then we can actually pass the arguments that need to set to None in generate to allow this. Maybe a bit overkill in a first step, but we could do this if we encounter issues about it

Copy link
Member Author

@lvwerra lvwerra Dec 21, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I refined the error message to include a little bit more detail. Let me know if you think it is clearer now.

I agree with your last point: if a lot of people have issues we can make it more strict.

return criteria

def _merge_criteria_processor_list(
self, default_list: StoppingCriteriaList, custom_list: StoppingCriteriaList
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self, default_list: StoppingCriteriaList, custom_list: StoppingCriteriaList
self, default_list: List[Any], custom_list: List[Any]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are not of type List. Union[LogitsProcessorList, StoppingCriteriaList] is the correct thing (you can define a new type too to make it less verbose maybe)

GenerateList = Union[LogitsProcessorList, StoppingCriteriaList]

    def _merge(... default_list: GenerateList) -> GenerateList:
       ....

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added Union[LogitsProcessorList, StoppingCriteriaList]. Should be more careful with type hints in the future. I chose the verbose option because it is quite readable that way.


def _merge_criteria_processor_list(
self, default_list: StoppingCriteriaList, custom_list: StoppingCriteriaList
) -> StoppingCriteriaList:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
) -> StoppingCriteriaList:
) -> List[Any]:

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above discussion.

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work! Thanks a lot for the feature.

Feel free to merge whenever :-)

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
@lvwerra lvwerra changed the title Custom criteria logits processor generate Add custom stopping_criteria and logits_processor to generate Dec 21, 2021
@lvwerra lvwerra merged commit 5722d05 into master Dec 21, 2021
@lvwerra lvwerra deleted the custom-criteria-logits-processor-generate branch December 21, 2021 15:47
Albertobegue pushed a commit to Albertobegue/transformers that referenced this pull request Jan 27, 2022
…uggingface#14779)

* add custom `stopping_criteria` and `logits_processor` to `generate`

* add tests for custom `stopping_criteria` and `logits_processor`

* fix typo in RAG

* address reviewer comments

* improve custom logits processor/stopping criteria error message

* fix types in merge function signature

* change default for custom list from `None` to empty list

* fix rag generate

* add string split suggestion

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Passing a custom stopping_criteria list to model.generate() yields a multiple value error for that keyword arg
3 participants