Enabling users to provide their own stopping_criteria + logits_processor to generate. #12219

Closed

Conversation

Narsil
Contributor

@Narsil Narsil commented Jun 17, 2021

What does this PR do?

Fixes #12118

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@Narsil
Contributor Author

Narsil commented Jul 17, 2021

@patrickvonplaten (Not urgent, get some rest :))

@patrickvonplaten
Contributor

Sorry for the late reply here @Narsil - I'm happy with the PR I think :-) If we could add a test that would be great

@Narsil Narsil force-pushed the add_logits_stopping_to_generate branch from 463a1ba to 5e5db45 Compare August 31, 2021 13:44
@Narsil Narsil changed the title [WIP] Enabling users to provide their own stopping_criteria + logits_processor to generate. Enabling users to provide their own stopping_criteria + logits_processor to generate. Aug 31, 2021
@Narsil
Contributor Author

Narsil commented Sep 1, 2021

@patrickvonplaten Should I merge this ?

Contributor

@patil-suraj patil-suraj left a comment

Thank you for adding this :)

@patrickvonplaten do you want to take a look again ?

) -> StoppingCriteriaList:
stopping_criteria = StoppingCriteriaList()
if stopping_criteria is None:
stopping_criteria = StoppingCriteriaList()
if max_length is not None:
Contributor

max_length always defaults to 20 -> so if someone passes a stopping_criteria list then there are two stopping criteria no?

Contributor

This is not good, no? E.g. if the list contains a max-length criterion of 30, generation will still stop at 20. So IMO, if someone passes a stopping_criteria list, we should check for each default criterion whether its class already exists in the list, and only add it if it does not. This means that the priority is as follows:

1st priority: stopping_criteria
2nd priority: directly passing max_length
3rd priority: using max_length of config
4th priority: using default max_length

=> think the same should hold true for logits_processor.

Think we should not merge as it is right now
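
The priority order proposed in this comment can be sketched in plain Python. This is only an illustration under stated assumptions: `MaxLengthCriteria` here is a small stand-in class rather than the library's, and `resolve_max_length`, `config_max_length` and `DEFAULT_MAX_LENGTH` are hypothetical names that do not appear in the PR.

```python
from typing import List, Optional

# Assumption: mirrors the "max_length always defaults to 20" remark above.
DEFAULT_MAX_LENGTH = 20


class MaxLengthCriteria:
    """Stand-in for a max-length stopping criterion (not the library class)."""

    def __init__(self, max_length: int):
        self.max_length = max_length


def resolve_max_length(
    stopping_criteria: Optional[List[object]] = None,
    max_length: Optional[int] = None,
    config_max_length: Optional[int] = None,
) -> int:
    """Hypothetical helper: pick the effective max length by priority."""
    # 1st priority: a max-length criterion inside the user's stopping_criteria
    for criterion in stopping_criteria or []:
        if isinstance(criterion, MaxLengthCriteria):
            return criterion.max_length
    # 2nd priority: max_length passed directly
    if max_length is not None:
        return max_length
    # 3rd priority: max_length stored in the model config
    if config_max_length is not None:
        return config_max_length
    # 4th priority: the default
    return DEFAULT_MAX_LENGTH
```

With this ordering, a user-supplied criterion of 30 is no longer cut short by a default of 20.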

Contributor

@patrickvonplaten patrickvonplaten left a comment

Think the workflow is not optimal at the moment -> see comment here: https://github.com/huggingface/transformers/pull/12219/files#r705307868

Keen to hear your opinion @Narsil

@Narsil
Contributor Author

Narsil commented Sep 9, 2021

I think we shouldn't check anything. If you defined something, we pass it as-is IMO. It's a power-user feature; the doc specifically mentions this:

https://github.com/huggingface/transformers/pull/12219/files#diff-b7601d397d5d60326ce61a9c91beaa2afa026014141052b32b07e1d044fbbe17R801

@Narsil
Contributor Author

Narsil commented Sep 10, 2021

But I'm also happy to drop the PR; the issue didn't seem to get much traction.
If we're scared of introducing a new range of bugs and hard-to-understand behavior, maybe let's just drop it.

@patrickvonplaten
Contributor

I think it would be nice to merge the PR, but it just doesn't make much sense to me that a default, always-defined value like max_length=20 would overwrite something that's passed via the logits_processor. So instead of dropping the PR, we can just ensure that a passed logits_processor and stopping_criteria have priority, which is intuitive and sensible to me.

@Narsil
Contributor Author

Narsil commented Sep 15, 2021

So you think we should do

if logits_processor is None:
    logits_processor = self._get_logits_processor(...)

instead?

Makes sense.

@huggingface huggingface deleted a comment from github-actions bot Oct 11, 2021
@github-actions

github-actions bot commented Nov 4, 2021

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Nov 13, 2021
@patrickvonplaten
Contributor

Leaving it as closed for now - reopening in case the community expresses interest in this PR again...

@lvwerra lvwerra reopened this Nov 29, 2021
Contributor Author

@Narsil Narsil left a comment

Can't approve since this is my PR, but it LGTM.

if isinstance(stopping_criterion, MaxLengthCriteria):
if max_length is not None:
warnings.warn(
"A stopping criteria of type `MaxLengthCriteria` as well as `max_length` was passed to `generate`. The `MaxLengthCriteria` will be used.",
Contributor Author

That's not necessarily true.

On line 889, we can define max_length as self.config.max_length even if it's not user defined.
At least we override here, which makes the behavior what a user would expect IMO (respecting their own defined things).

Member

Oh that's right, I missed that one. I think the stopping_criteria check should be moved before that. Otherwise you will always get a warning if both max_length and max_new_tokens are None, which would be the standard case when using a custom MaxLengthCriteria. What do you think?

@patrickvonplaten
Contributor

Thanks a lot for taking this over @lvwerra ! Let me know if you need any help with the remaining tests

if stopping_criteria is None:
stopping_criteria = StoppingCriteriaList()
max_length_in_criteria = any([isinstance(criteria, MaxLengthCriteria) for criteria in stopping_criteria])
if max_length is not None and not max_length_in_criteria:
Contributor

I'm fine with this I think -> it means that stopping_criteria is always more important than max_length which sounds good to me

Contributor

Actually, thinking more about it, I think it would be cleaner to just have the following logic:

If stopping_criteria is provided, then only this list is used and nothing else. As a first step, this spares us some nasty nested use cases. IMO we could just check if len(stopping_criteria) > 0, and if this is the case we don't even call the function _get_stopping_criteria. IMO someone who uses that functionality understands generate() quite well and doesn't need much magic under the hood.
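
The "provided list wins, nothing else is built" rule can be sketched as follows. This is a simplified illustration, not the PR's code: `MaxLengthCriteria` is a stand-in class, and `build_default_criteria` and `select_stopping_criteria` are hypothetical helpers standing in for `_get_stopping_criteria` and its call site.

```python
from typing import List, Optional


class MaxLengthCriteria:
    """Stand-in for a max-length stopping criterion (not the library class)."""

    def __init__(self, max_length: int):
        self.max_length = max_length


def build_default_criteria(max_length: Optional[int]) -> List[object]:
    """Hypothetical stand-in for the function that builds criteria from kwargs."""
    return [MaxLengthCriteria(max_length)] if max_length is not None else []


def select_stopping_criteria(
    user_criteria: Optional[List[object]], max_length: Optional[int]
) -> List[object]:
    # A non-empty user list is used as-is; the defaults are never built.
    if user_criteria is not None and len(user_criteria) > 0:
        return list(user_criteria)
    # Otherwise fall back to whatever the keyword arguments imply.
    return build_default_criteria(max_length)
```

Under this rule, passing `[MaxLengthCriteria(30)]` together with `max_length=20` yields a single criterion with a limit of 30.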

Contributor

What do you think @Narsil ?

Contributor Author

That's an option that is definitely viable.

The core things I think are important:

  • If a user specified something, we need to respect it
  • If something comes as a default it cannot override anything user specified.
  • If user specifications are unclear/unsound, yell loud and clear about what is going on, and what the code is going to do to save the generate call.

Option 1 (current):

model.generate(..., no_repeat_ngram_size=3)

Then I need to add some even more clever functionality:

my_logits_processor = LogitsProcessorList([MyLogits()])
model.generate(..., logits_processor=my_logits_processor, no_repeat_ngram_size=3)

(You can keep the easiness of generate).

Option 2 (logits_processor is a full override):

model.generate(..., no_repeat_ngram_size=3)

becomes

my_logits_processor = LogitsProcessorList([
    NoRepeatNGramLogitsProcessor(3),
    MyLogits(),
])
model.generate(..., logits_processor=my_logits_processor)

Option 1 has the advantage that we can keep some options simpler to use and still add some custom logits processor, if we're careful enough that no non-user defined variable can ever override the logits_processor in a hidden way, then we're good to go. (Only user defined arguments are able to modify). This is ofc a guarantee that might be tricky to keep in the future, so we are taking a risk of silently breaking things. It also makes user code trickier to understand since there is no ONE way to define logits_processor, so it might lead to hard to understand behavior.

Option 2 has the advantage that it has a single point of definition for the logits processor. The disadvantage is that it requires more changes from users the first time they use this variable. We also need to yell quite strongly that we're simply ignoring every other argument, which might not be obvious (we could even crash at this point, since it's really easy to overlook and will definitely yield poor results).

IMHO, both are really fine, we just need to stick to one option. I suggested Option 1 to @lvwerra because I thought that max_length was the only variable that was always defined even if not user supplied. If this assumption is wrong and hard to guarantee for a small list of variables, then I don't think we should stick with Option 1. Option 2 does have drawbacks, as we're pushing complexity back into user code instead of absorbing it like we're doing right now. But it does help separation of concerns.
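
To make the difference between the two options concrete, here is a small sketch in which processors are plain functions over a list of scores. Everything in it is a stand-in chosen for illustration: `scale`, `ban_token`, `run`, `option_1` and `option_2` are hypothetical names and not part of the library.

```python
from typing import Callable, List, Optional

Scores = List[float]
Processor = Callable[[Scores], Scores]


def scale(factor: float) -> Processor:
    """Stand-in for a processor that generate would build from a keyword argument."""
    def apply(scores: Scores) -> Scores:
        return [s / factor for s in scores]
    return apply


def ban_token(token_id: int) -> Processor:
    """Stand-in for a user-defined processor."""
    def apply(scores: Scores) -> Scores:
        out = list(scores)
        out[token_id] = float("-inf")
        return out
    return apply


def run(processors: List[Processor], scores: Scores) -> Scores:
    # Processors are applied in order, each one seeing the previous output.
    for processor in processors:
        scores = processor(scores)
    return scores


def option_1(from_kwargs: List[Processor], user: Optional[List[Processor]]) -> List[Processor]:
    # Option 1: keep everything built from keyword arguments and add the user's on top.
    return list(from_kwargs) + list(user or [])


def option_2(from_kwargs: List[Processor], user: Optional[List[Processor]]) -> List[Processor]:
    # Option 2: a user-supplied list is a full override of the keyword arguments.
    return list(user) if user is not None else list(from_kwargs)
```

With `from_kwargs = [scale(2.0)]` and `user = [ban_token(0)]`, Option 1 both scales and bans, while Option 2 silently drops the scaling, which is the behavior that would need to be documented loudly.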

Contributor

Thanks for the great write-up! Actually there are a bunch of things that are always defined (top_k is always defined e.g.) but then also lots of models always define num_beams in their config alongside other parameters so I would very much prefer option 2 here.

Contributor Author

@lvwerra are you ok switching to Option 2, seems my assumption was incorrect :( ?

max_length: Optional[int],
max_time: Optional[float],
max_new_tokens: Optional[int],
start_length: int,
Contributor

start_length is not needed no?

Contributor

What is max_new_tokens needed for? I think this PR was from quite some time ago where we had this

Member

Yes, I think they come from my attempt to merge the main branch into this one. I'll fix this.

) -> LogitsProcessorList:
"""
This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
:obj:`~transformers.LogitsProcessor` instances used to modify the scores of the language model head.
"""
processors = LogitsProcessorList()

if logits_processor is None:
Contributor

Here I also think we shouldn't even call the function if logits_processor is used. It makes our lives much easier and the design a bit cleaner.

What I don't like about the current design is that:

  • if, say, forced_eos_token_id is provided both as an argument to generate() and via a processor in the passed logits_processor, then at the moment we have two forced_eos_token_id processors in the list, which leads to weird behavior...

A first simple solution is to just not call this function IMO. We could always later adapt it for more advanced functionality if needed.

Member

One thing to keep in mind is that we essentially deactivate a whole bunch of options without much transparency about it so a user would need to know which options correspond to a logits_processor. E.g. I just had to look up if temperature would be affected if I added a logits_processor. Same with stopping_criteria.

The other way around it is more transparent I think if something goes wrong: Ah I passed forced_eos_token_id and a logits_processor doing the same thing, maybe that is not good.

For me the main use-case is to add a custom processor/criteria in addition to what I can already do with the kwargs. Otherwise, why go the extra step of using the arguably more complicated API if you can just pass it as an argument?

What do you think?

Member

That's an argument in favour of @Narsil's option 1.

Contributor

That's a good argument @lvwerra and I fully understand what you mean.

I'm however really concerned about the complexity that option 1 adds for IMO very few use cases. Also, from a backwards-compatibility point of view, it's pretty much impossible to go from option 1 to option 2 in the future if this feature becomes more important, whereas it's much easier to go from option 2 to option 1.

For me the main use-case is to add a custom processor/criteria in addition to what I can already do with the kwargs. Otherwise, why go the extra step of using the arguably more complicated API if you can just pass it as an argument?

Very good point and that's the big drawback of option 2. I believe that people having to use special logits processors are able to create it themselves.

Ok, how about we do something in between option 1 and option 2 that doesn't create a crazy complex logic.

If one passes a logits processor, we do the following:

  1. just create the normal logits_processor that would have been created without passing one.
  2. if any object of the passed logits processor is in the already created logits processor, then we raise an error and tell the user which logits processor was created twice (and ideally which parameter has to be changed for this)
  3. If there is no error make a union of the two lists and we're good to go.
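
The three steps of this compromise can be sketched as a small merge helper. This is a minimal illustration assuming that "the same processor" means "the same class"; `merge_processors` and the three processor classes are hypothetical names used only for this example.

```python
from typing import List


class MinLength:
    """Stand-in for a processor created from generate's keyword arguments."""


class ForcedEOS:
    """Stand-in for a processor that could be created from kwargs or passed by the user."""


class MyCustom:
    """Stand-in for a user-defined processor."""


def merge_processors(default_list: List[object], custom_list: List[object]) -> List[object]:
    # Step 1 happened before the call: default_list was built from the usual arguments.
    # Step 2: refuse to continue if a passed processor duplicates a default one by type.
    for custom in custom_list:
        for default in default_list:
            if type(custom) is type(default):
                raise ValueError(
                    f"A processor of type {type(custom).__name__} was passed, but one "
                    "was also created from the arguments of generate. Pass it only once."
                )
    # Step 3: no clash, so the result is the union of both lists.
    return list(default_list) + list(custom_list)
```

Raising instead of silently picking a winner keeps the "yell loud and clear" property while still letting custom processors be combined with the keyword arguments.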

Member

I like this compromise, what do you think @Narsil?

Contributor

Think we can go for the solution if you want @lvwerra :-)

Contributor Author

We can do that.

@@ -792,6 +806,12 @@ def generate(
crash. Note that using ``remove_invalid_values`` can slow down generation.
synced_gpus (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
logits_processor (:obj:`LogitsProcessorList`, `optional`):
This object is created automatically from other arguments of this function. `logits_processor` is meant
Contributor

Not a huge fan of the docstring here - could we maybe rephrase it a bit. I think the user is mostly interested in what happens when this object is passed and how it can be passed - not really what happens when it is not passed.

So maybe something more like:

Suggested change
This object is created automatically from other arguments of this function. `logits_processor` is meant
If provided `logits_processor` will overwrite all passed arguments that can process logits as well as those saved in the model's config. It can be very useful to enable custom logits processing logic.

We should also note somewhere that this is an experimental feature IMO

Contributor

Keen to hear your input here as well @Narsil. IMO this feature is really not for the "inexperienced" HF user but for the advanced ones who know more or less what happens under the hood in generate() (otherwise why need custom logits processors?). To give this functionality while keeping complexity at a minimum, I think the best first step is to simply say:
If we pass logits_processor or stopping_criteria, it will overwrite everything else...

Contributor Author

That's what I named Option 2 above. While I think it's viable I don't think it's the only way.
Option 1 which is ("we're adding your stuff too without looking") is also perfectly viable. Let's continue discussion above since I think that's where the main point is, no ?

This object is created automatically from other arguments of this function. `logits_processor` is meant
to be used to add another layer with custom logic.
stopping_criteria (:obj:`StoppingCriteriaList`, `optional`):
This object is created automatically from other arguments of this function. `stopping_criteria` is
Contributor

same here

@Narsil
Contributor Author

Narsil commented Dec 27, 2021

Superseded by #14779 (comment)

@Narsil Narsil closed this Dec 27, 2021
Successfully merging this pull request may close these issues.

Passing a custom stopping_criteria list to model.generate() yields a multiple value error for that keyword arg
4 participants