Add seq2seq prompt tuning support #519
Conversation
@pacman100 I forgot to run …
prompts = self.get_prompt(batch_size=batch_size)
prompts = prompts.to(inputs_embeds.dtype)
inputs_embeds = torch.cat((prompts[:, : peft_config.num_virtual_tokens], inputs_embeds), dim=1)
Hello @thomas-schillaci,
Thank you so much for this PR. This leads to not using prompt tokens for the decoder, which might result in a decrease in model performance. This is the bottleneck that had kept me from doing something similar to this PR.
Hello @pacman100, thank you for the review.
If I'm not mistaken, prompt tuning only requires prompts on the encoder.
I went over the papers again. The Prefix-Tuning paper adds prefixes to both the encoder and the decoder, while the Prompt Tuning paper only adds them to the encoder input.
Prompt Tuning Paper:
Given a series of n tokens, {x_1, x_2, ..., x_n}, the first thing T5 does is embed the tokens, forming a matrix X_e ∈ ℝ^{n×e} where e is the dimension of the embedding space. Our soft-prompts are represented as a parameter P_e ∈ ℝ^{p×e}, where p is the length of the prompt. Our prompt is then concatenated to the embedded input forming a single matrix [P_e; X_e] ∈ ℝ^{(p+n)×e} which then flows through the encoder-decoder as normal. Our models are trained to maximize the probability of Y, but only the prompt parameters P_e are updated.
Therefore, you are correct. Thank you!
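For illustration, the concatenation [P_e; X_e] described in the quote is a single torch.cat along the sequence dimension. A minimal PyTorch sketch, with shapes and names following the paper's notation rather than this repository's code:

```python
import torch

n, p, e = 8, 4, 512  # input length, prompt length, embedding dimension
X_e = torch.randn(n, e)                      # embedded input tokens (frozen model's embeddings)
P_e = torch.randn(p, e, requires_grad=True)  # trainable soft prompt

# [P_e; X_e] has shape (p + n, e); only P_e receives gradient updates.
prompted_input = torch.cat((P_e, X_e), dim=0)
assert prompted_input.shape == (p + n, e)
```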
Could you please then remove the need for prompt tokens in the decoder for the P-Tuning approach too? This will lead to both of these being supported by generate for seq2seq tasks.
Also, remove point 3 from the caveats section of README.md, as this PR solves it:
- For encoder-decoder models, P_TUNING or PROMPT_TUNING doesn't support generate functionality of transformers because generate strictly requires decoder_input_ids but P_TUNING/PROMPT_TUNING appends soft prompt embeddings to input_embeds to create new input_embeds to be given to the model. Therefore, generate doesn't support this yet.
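For context, the workaround this PR describes can be sketched roughly as follows, assuming a T5-style checkpoint and a reasonably recent transformers version; the random soft prompt and the variable names are illustrative, not the PR's exact code:

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
tokenizer = AutoTokenizer.from_pretrained("t5-small")

num_virtual_tokens = 20
inputs = tokenizer("translate English to German: Hello", return_tensors="pt")

# Prepend a (here random) soft prompt to the embedded input_ids.
soft_prompt = torch.randn(1, num_virtual_tokens, model.config.d_model)
token_embeds = model.get_input_embeddings()(inputs["input_ids"])
inputs_embeds = torch.cat((soft_prompt, token_embeds), dim=1)

# Pre-compute encoder outputs from the prompt-augmented embeddings and
# hand them to generate(), sidestepping its need for input_ids.
encoder_outputs = model.get_encoder()(inputs_embeds=inputs_embeds)
output_ids = model.generate(encoder_outputs=encoder_outputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

Note that a comment further down in this conversation points out that the attention_mask should be extended and passed along as well whenever padding is involved.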
Awesome! I'll be working on it on Monday 😉
@thomas-schillaci @pacman100 Thanks for your valuable contribution. What about point 2, "When using P_TUNING or PROMPT_TUNING with SEQ_2_SEQ task, remember to remove the num_virtual_token virtual prompt predictions from the left side of the model outputs during evaluations."? Could you please also update it?
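For reference, that caveat boils down to slicing the first num_virtual_tokens positions off the output logits before decoding or computing metrics. A hypothetical sketch with dummy shapes, not code from this PR:

```python
import torch

num_virtual_tokens = 20  # i.e. peft_config.num_virtual_tokens
batch, seq_len, vocab = 4, 50, 32128

# Model outputs contain num_virtual_tokens extra positions on the left.
logits = torch.randn(batch, num_virtual_tokens + seq_len, vocab)

# Drop the virtual-prompt predictions before evaluation.
logits = logits[:, num_virtual_tokens:, :]
assert logits.shape == (batch, seq_len, vocab)
```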
@thomas-schillaci @pacman100, thanks for your super useful code!
Regarding the lines encoder_outputs = self.base_model.get_encoder()(inputs_embeds=inputs_embeds) and kwargs["encoder_outputs"] = encoder_outputs, I am not sure we can call the encoder before updating the attention_mask, which should also be included when computing the encoder_outputs.
Additionally, it may not be necessary to get the encoder_outputs here. self.base_model.generate will handle it later with the method _prepare_encoder_decoder_kwargs_for_generation. We may only need to update the inputs_embeds and attention_mask.
Please correct me if I am wrong. Thanks a lot.
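Concretely, the simplification suggested here would look something like the following sketch, assuming the installed transformers version accepts inputs_embeds in generate for encoder-decoder models; the random soft prompt is again illustrative:

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
tokenizer = AutoTokenizer.from_pretrained("t5-small")

num_virtual_tokens = 20
inputs = tokenizer("translate English to German: Hello", return_tensors="pt")

soft_prompt = torch.randn(1, num_virtual_tokens, model.config.d_model)
token_embeds = model.get_input_embeddings()(inputs["input_ids"])
inputs_embeds = torch.cat((soft_prompt, token_embeds), dim=1)

# Extend the attention mask with ones for the virtual tokens; generate()
# then runs the encoder itself (internally via
# _prepare_encoder_decoder_kwargs_for_generation).
prompt_mask = torch.ones(1, num_virtual_tokens, dtype=inputs["attention_mask"].dtype)
attention_mask = torch.cat((prompt_mask, inputs["attention_mask"]), dim=1)

output_ids = model.generate(
    inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=20
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```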
Thank you @thomas-schillaci for adding this and the efforts therein 🤗! Could you please go over the latest comments and see if those changes can be incorporated in this PR too?
@pacman100 I have incorporated the changes discussed above, and the suggestions from @ZhengxiangShi, thanks a lot!
Thank you so much @thomas-schillaci for iterating, LGTM! ✨
It would be great to add @ZhengxiangShi as a co-author on this PR with respect to simplifying generate. After that, I will merge the PR. Thank you!
Co-authored-by: ZhengxiangShi <michaelszx117@gmail.com>
Thank you for the review @pacman100!
Thanks for your help! @thomas-schillaci @pacman100
This commit adds prompt tuning and support for the generate method for encoder-decoders. Using generate for encoder-decoder models with prompt tuning was previously not supported, as you can't use generate with inputs_embeds. I address this issue by generating the encoder_outputs of the input_ids + prompt, and passing it to generate.
Also included are two example notebooks to showcase this feature.
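For readers arriving at this merged PR, the net effect is that something like the following now works end to end. A hedged sketch using the public PEFT API; exact behavior depends on your peft and transformers versions:

```python
from peft import PromptTuningConfig, TaskType, get_peft_model
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

base_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
tokenizer = AutoTokenizer.from_pretrained("t5-small")

peft_config = PromptTuningConfig(task_type=TaskType.SEQ_2_SEQ_LM, num_virtual_tokens=20)
model = get_peft_model(base_model, peft_config)

inputs = tokenizer("translate English to German: Hello", return_tensors="pt")
# generate for prompt-tuned seq2seq models is what this PR enables.
output_ids = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=20,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```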