
Add seq2seq prompt tuning support #519

Merged

Conversation

thomas-schillaci
Contributor

This commit adds prompt tuning and support for the generate method for encoder-decoders.

Using generate for encoder-decoder models with prompt tuning was previously not supported, as you can't call generate with inputs_embeds. I address this issue by computing the encoder_outputs of the input_ids + prompt and passing them to generate.

Also included two example notebooks to showcase this feature.
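For context, a rough sketch of the approach described above (not the exact PR code; helper names like word_embeddings, get_prompt and active_peft_config follow PEFT conventions, and the wiring here is illustrative):

import torch

def generate_with_soft_prompt(peft_model, input_ids, attention_mask, **gen_kwargs):
    batch_size = input_ids.shape[0]

    # Embed the input tokens and prepend the trained soft-prompt embeddings (encoder part only).
    inputs_embeds = peft_model.word_embeddings(input_ids)
    prompts = peft_model.get_prompt(batch_size=batch_size).to(inputs_embeds.dtype)
    num_virtual_tokens = peft_model.active_peft_config.num_virtual_tokens
    inputs_embeds = torch.cat((prompts[:, :num_virtual_tokens], inputs_embeds), dim=1)

    # Extend the attention mask so the virtual tokens are attended to.
    prefix_mask = attention_mask.new_ones(batch_size, num_virtual_tokens)
    attention_mask = torch.cat((prefix_mask, attention_mask), dim=1)

    # Run the encoder once on the prompt-augmented embeddings and hand the resulting
    # encoder_outputs to generate, so generate itself never needs inputs_embeds.
    encoder_outputs = peft_model.base_model.get_encoder()(
        inputs_embeds=inputs_embeds, attention_mask=attention_mask
    )
    return peft_model.base_model.generate(
        encoder_outputs=encoder_outputs, attention_mask=attention_mask, **gen_kwargs
    )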


@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jun 1, 2023

The documentation is not available anymore as the PR was closed or merged.

@thomas-schillaci
Contributor Author

@pacman100 I forgot to run make; this should be good with the latest commit.


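# Fetch the soft prompt, match its dtype, and prepend its encoder portion (the first num_virtual_tokens embeddings) to the input embeddings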
prompts = self.get_prompt(batch_size=batch_size)
prompts = prompts.to(inputs_embeds.dtype)
inputs_embeds = torch.cat((prompts[:, : peft_config.num_virtual_tokens], inputs_embeds), dim=1)
Contributor

Hello @thomas-schillaci ,

Thank you so much for this PR. However, this leads to the prompt tokens not being used for the decoder, which might result in a decrease in model performance. This was the bottleneck that prevented me from doing something similar to this PR.

Contributor Author

Hello @pacman100, thank you for the review.
If I'm correct, prompt tuning only requires prompts on the encoder.

Contributor

I went over the papers again. The Prefix-Tuning paper adds prefixes to both the encoder and the decoder, while the Prompt Tuning paper only seems to add them to the input.

Prompt Tuning Paper:

Given a series of n tokens, {x1, x2, . . . , xn}, the first thing T5 does is embed the tokens, forming a matrix Xe ∈ R^(n×e) where e is the dimension of the embedding space. Our soft-prompts are represented as a parameter Pe ∈ R^(p×e), where p is the length of the prompt. Our prompt is then concatenated to the embedded input, forming a single matrix [Pe; Xe] ∈ R^((p+n)×e) which then flows through the encoder-decoder as normal. Our models are trained to maximize the probability of Y, but only the prompt parameters Pe are updated.

Prefix-Tuning paper:
[Screenshot of the relevant figure from the Prefix-Tuning paper]

Therefore, you are correct. Thank you!

Contributor

@pacman100 pacman100 Jun 16, 2023

Could you please then remove the need for prompt tokens in the decoder for the P-Tuning approach too? This will lead to both of them being supported by generate for seq2seq tasks.

Also, remove point 3 from the caveats section of README.md, as this PR solves it.

  1. For encoder-decoder models, P_TUNING or PROMPT_TUNING doesn't support generate functionality of transformers because generate strictly requires decoder_input_ids but P_TUNING/PROMPT_TUNING appends soft prompt embeddings to input_embeds to create new input_embeds to be given to the model. Therefore, generate doesn't support this yet.

Contributor Author

Awesome! I'll be working on it on Monday 😉

Contributor

@thomas-schillaci @pacman100 Thanks for your valuable contribution. What about point 2, "When using P_TUNING or PROMPT_TUNING with SEQ_2_SEQ task, remember to remove the num_virtual_token virtual prompt predictions from the left side of the model outputs during evaluations."? Could you please also update it?

Contributor

@ZhengxiangShi ZhengxiangShi Jun 16, 2023

@thomas-schillaci @pacman100, thanks for your super useful code!

For the lines encoder_outputs = self.base_model.get_encoder()(inputs_embeds=inputs_embeds) and kwargs["encoder_outputs"] = encoder_outputs, I am not sure that we can call the encoder before updating the attention_mask, which should also be included when computing the encoder_outputs.

Additionally, it may not be necessary to get the encoder_outputs here. self.base_model.generate will handle it later with the method _prepare_encoder_decoder_kwargs_for_generation. We may only need to update the inputs_embeds and attention_mask.
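A minimal sketch of that simplification, assuming the soft prompt has already been prepended to inputs_embeds as in the snippet above (variable names are illustrative, not the exact PR code):

# Instead of running the encoder manually, extend the attention mask to cover the
# virtual tokens and let generate call the encoder itself via
# _prepare_encoder_decoder_kwargs_for_generation.
prefix_mask = attention_mask.new_ones(batch_size, num_virtual_tokens)
kwargs["attention_mask"] = torch.cat((prefix_mask, attention_mask), dim=1)
kwargs["inputs_embeds"] = inputs_embeds  # already has the soft prompt prepended
outputs = self.base_model.generate(**kwargs)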

Please correct me if I am wrong. Thanks a lot.

Contributor

@pacman100 pacman100 left a comment

Thank you @thomas-schillaci for adding this and the efforts therein 🤗! Could you please go over the latest comments and see if those changes too can be incorporated in this PR?

ZhengxiangShi added a commit to ZhengxiangShi/peft that referenced this pull request Jun 16, 2023
@thomas-schillaci
Contributor Author

@pacman100 I have incorporated the changes discussed above, and the suggestions from @ZhengxiangShi, thanks a lot!
Regarding point 2 of the caveats, I don't think it is relevant any longer, so I took the liberty of removing it.

Contributor

@pacman100 pacman100 left a comment

Thank you so much @thomas-schillaci for iterating, LGTM! ✨

It would be great to add @ZhengxiangShi as a co-author on this PR for the work on simplifying generate. After that, I will merge the PR. Thank you!

Thomas SCHILLACI and others added 3 commits June 26, 2023 17:50
Co-authored-by: ZhengxiangShi <michaelszx117@gmail.com>
@thomas-schillaci
Contributor Author

Thank you for the review @pacman100!
I have added @ZhengxiangShi as a co-author, as I incorporated his suggestions.

@pacman100 pacman100 merged commit 0e8932f into huggingface:main Jun 27, 2023
11 checks passed
@thomas-schillaci thomas-schillaci deleted the add-seq2seq-prompt-tuning-support branch June 27, 2023 06:29
@ZhengxiangShi
Contributor

Thanks for your help! @thomas-schillaci @pacman100
