Allow adding custom logits processors in the generate method #11413

Closed
wadimiusz opened this issue Apr 24, 2021 · 7 comments
@wadimiusz

wadimiusz commented Apr 24, 2021

🚀 Feature request

Hello,
I'd like to request a new feature in the generate method of the GenerationMixin class from generation_utils. Specifically, I'd like a feature that allows a user to pass custom LogitsProcessors by adding a new argument logit_processors: Optional[LogitsProcessorList] = None to the generate method.

Motivation

I'd like to run generation on a pre-trained model and modify its output logits with a custom function before the search, sampling, or whatever decoding strategy is used. I think this could be a common use case for controlled natural language generation, because one often wants to impose some simple restrictions on the generated logits.

Here is an example of how this could be used:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, LogitsProcessor, LogitsProcessorList

class MyLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Placeholder: modify the next-token scores here, then return them.
        return scores


model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
logit_processors = LogitsProcessorList([MyLogitsProcessor()])
input_ids = tokenizer('This dog is cute', return_tensors='pt').input_ids
# logit_processors is the argument proposed in this issue.
model.generate(input_ids=input_ids, logit_processors=logit_processors)
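
For a more concrete idea of what such a processor might do, here is a minimal sketch that bans a fixed set of token ids. `BanTokensProcessor` is a hypothetical name, and it is written as a plain callable with the same `(input_ids, scores)` signature so that it runs with only torch installed; in practice it would subclass LogitsProcessor as in the example above.

```python
import torch


class BanTokensProcessor:
    """Hypothetical processor: makes the given token ids impossible to pick."""

    def __init__(self, banned_token_ids):
        self.banned_token_ids = list(banned_token_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # scores is assumed to have shape (batch_size, vocab_size).
        # Work on a copy so the caller's tensor is left untouched.
        scores = scores.clone()
        scores[:, self.banned_token_ids] = float("-inf")
        return scores


# Toy check with a vocabulary of 5 tokens and a batch of 2 sequences.
processor = BanTokensProcessor(banned_token_ids=[1, 3])
input_ids = torch.tensor([[0, 2], [4, 0]])
scores = torch.zeros(2, 5)
new_scores = processor(input_ids, scores)
```

Setting a score to negative infinity gives that token zero probability after the softmax, so neither greedy search nor sampling can select it.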

Your contribution

I have no experience in open source, but I can try to help if you need a hand. I think that the general approach to implementing this is to do the following:

  1. Add the logit_processors: Optional[LogitsProcessorList] = None argument to the generate method,
  2. Add the same argument to the _get_logits_processor method of GenerationMixin and add the custom logit processors after all the other logit processors are in place.
  3. Pass the custom logits processors to every call of _get_logits_processor in the generate method.
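
The three steps above could be sketched roughly as follows. This is a toy stand-in, not the actual GenerationMixin code: `build_default_processors`, `get_logits_processor`, `apply_processors`, and the `min_score` option are hypothetical names used only for illustration, and processors are modelled as plain callables over Python lists.

```python
from typing import Callable, List, Optional

# A processor is modelled as any callable taking (input_ids, scores)
# and returning the new scores.
Processor = Callable[[list, list], list]


def build_default_processors(min_score: Optional[float] = None) -> List[Processor]:
    """Hypothetical stand-in for the processors derived from generate() arguments."""
    processors: List[Processor] = []
    if min_score is not None:
        processors.append(lambda input_ids, scores: [max(s, min_score) for s in scores])
    return processors


def get_logits_processor(
    min_score: Optional[float] = None,
    logit_processors: Optional[List[Processor]] = None,
) -> List[Processor]:
    # Built-in processors come first; custom ones are appended afterwards.
    processors = build_default_processors(min_score)
    if logit_processors is not None:
        processors.extend(logit_processors)
    return processors


def apply_processors(processors: List[Processor], input_ids: list, scores: list) -> list:
    # Each processor receives the output of the previous one.
    for processor in processors:
        scores = processor(input_ids, scores)
    return scores


def double(input_ids, scores):
    """Toy custom processor that doubles every score."""
    return [2 * s for s in scores]


processors = get_logits_processor(min_score=0.0, logit_processors=[double])
result = apply_processors(processors, [0, 1], [-1.0, 0.5, 2.0])
```

Because the custom processors run last, they see the scores after the built-in ones have already been applied.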

What do you think?

@wadimiusz
Author

I think I could submit a pull request for this, if I had

  1. feedback on the idea (do you think it makes sense to do that?)
  2. a little help changing existing tests and/or implementing new tests to reflect the change.

Also, maybe the new argument should be Optional[LogitsProcessor] instead of Optional[LogitsProcessorList]. Because LogitsProcessorList is a subclass of LogitsProcessor, this would allow adding both a list of logits processors and a single logits processor.
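
One way to accept both forms without relying on the class hierarchy would be to normalize the argument up front. A minimal sketch, assuming processors are callables: `as_processor_list` and `CallableList` are hypothetical names, and the list check comes first because a list-like container can itself be callable.

```python
def as_processor_list(logit_processors) -> list:
    """Return a plain list whether given None, one processor, or several."""
    if logit_processors is None:
        return []
    # Check for containers first: a list subclass may also define __call__.
    if isinstance(logit_processors, (list, tuple)):
        return list(logit_processors)
    # Anything else is treated as a single processor.
    return [logit_processors]


class CallableList(list):
    """Toy stand-in for a processor list that is itself callable."""

    def __call__(self, input_ids, scores):
        for processor in self:
            scores = processor(input_ids, scores)
        return scores


def identity(input_ids, scores):
    return scores


single = as_processor_list(identity)
several = as_processor_list([identity, identity])
wrapped = as_processor_list(CallableList([identity]))
```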

What do you folks think? Would you accept this pull request (after maybe giving me some tips related to the tests)?

@patrickvonplaten
Contributor

Hey @wadimiusz,

Sorry to only come back to you now! I think in general, I'm fine with such an extension. The only problem I see is that a user could add a custom logits processor that already exists (e.g. a user would create their own LengthPenaltyLogitsProcessor) and also pass length_penalty=... as an argument. But even in this case I guess we could just apply both processors and there shouldn't be a big problem.

=> So I'm ok with this extension. Interested in hearing your thoughts about this @patil-suraj @Narsil
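
If silently applying both processors ever turned out to be confusing, one option would be to detect the overlap and report it. A rough sketch under that assumption: `overlapping_processor_types` is a hypothetical helper, and `RepetitionPenalty` and `BanTokens` are toy classes, not library classes.

```python
from typing import List


class RepetitionPenalty:
    """Toy stand-in for a processor created from a generate() argument."""

    def __call__(self, input_ids, scores):
        return scores


class BanTokens:
    """Toy stand-in for a user-defined processor."""

    def __call__(self, input_ids, scores):
        return scores


def overlapping_processor_types(default_processors, custom_processors) -> List[str]:
    """Names of custom processor types that also appear among the defaults."""
    default_types = {type(p) for p in default_processors}
    return sorted({type(p).__name__ for p in custom_processors if type(p) in default_types})


defaults = [RepetitionPenalty()]
custom = [BanTokens(), RepetitionPenalty()]
overlap = overlapping_processor_types(defaults, custom)
```

The caller could then log a warning for each reported name while still applying both processors.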

@Narsil
Contributor

Narsil commented May 14, 2021

I think it's a very nice idea!

The problem you mention, @patrickvonplaten, will mostly be relevant for power users (those who want to add a LogitsProcessor), so they should be careful about how they use this tool. I guess we could emphasize in the documentation for the generate function that the simpler arguments are preferred for non-advanced usage.

@github-actions

github-actions bot commented Jun 8, 2021

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@ScientiaEtVeritas

@wadimiusz Is there any update on this? I think it would be a great addition.

@wadimiusz
Author

Hi @ScientiaEtVeritas, the feature seems not hard to implement and I think I already have the code somewhere, but it would require nice and thorough tests that I don't have the time to write right now. If you could help me with the tests, we could submit a pull request together :)

@Narsil
Contributor

Narsil commented Nov 26, 2021

There used to be a PR that might be used as a starting point:

#12219

Thanks if you can work on this!
