
Prompt tuning significantly slows down code & uses more memory than without? #13

Open
patricks-lab opened this issue Oct 6, 2023 · 0 comments

@patricks-lab

First of all, thanks for the great resource to help people get started with prompt tuning!

I was wondering why prompt tuning uses more memory and time than training without it. To be concrete, I am trying to prompt tune a 12-layer BertForSequenceClassification model for my custom classification task:

model = BertForSequenceClassification.from_pretrained(my_model_path, num_labels=10, output_attentions=False, output_hidden_states=False).to("cuda")

#---freeze all params except prompt---#
for param in model.parameters():
    param.requires_grad = False

#---the only trainable parameter is the soft prompt---#
soft_prompt = True
if soft_prompt:
    softemb = SoftEmbedding(model.get_input_embeddings(),
                            n_tokens=20,
                            initialize_from_vocab=True)
    model.set_input_embeddings(softemb)
    model.cuda()

However, I have noticed that without the soft prompt (soft_prompt = False), my memory usage is ~1.5 GB, while with prompt tuning (soft_prompt = True) it jumps to ~9 GB and training takes twice as long. My hypothesis is that because the soft prompt sits at the very input of the model, backprop has to flow through every layer to reach it, so PyTorch must keep all the intermediate activations needed for those gradient computations, despite requires_grad = False being set on every frozen layer.
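For anyone who wants to reproduce the comparison, here is a minimal sketch of how peak memory can be measured in each configuration (batch_input_ids, batch_attention_mask, and batch_labels are hypothetical stand-ins for an actual batch):

import torch

# reset the CUDA allocator's high-water mark, run one training step,
# then read back the peak allocation
torch.cuda.reset_peak_memory_stats()

outputs = model(input_ids=batch_input_ids.cuda(),
                attention_mask=batch_attention_mask.cuda(),
                labels=batch_labels.cuda())
outputs.loss.backward()

print(f"peak CUDA memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")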

To test whether calculating the soft prompt's gradients causes the jump in memory, I added .clone().detach() to the soft prompt to make it fixed (i.e., not requiring gradients). That is, I replaced this line:

learned_embedding = self.learned_embedding.repeat(input_embedding.size(0), 1, 1)

with:

learned_embedding = self.learned_embedding.clone().detach().repeat(input_embedding.size(0), 1, 1)

And indeed, even with soft_prompt = True, I got only ~1.5 GB of memory usage.
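For context, both variants live inside SoftEmbedding.forward. Paraphrasing the repo's forward (the exact surrounding code may differ slightly), the change looks like this:

def forward(self, tokens):
    # embed the real tokens; the first n_tokens positions are placeholders
    input_embedding = self.wte(tokens[:, self.n_tokens:])

    # original: the soft prompt stays in the autograd graph
    learned_embedding = self.learned_embedding.repeat(input_embedding.size(0), 1, 1)

    # experiment: a detached copy, so no gradient can flow back to the prompt
    # learned_embedding = self.learned_embedding.clone().detach().repeat(
    #     input_embedding.size(0), 1, 1)

    return torch.cat([learned_embedding, input_embedding], 1)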

This suggests that backpropagating to the soft prompt is what takes the extra ~7.5 GB of memory, even though the prompt itself is tiny: mine is only a 20 x 512 matrix, i.e. about 40 KB in fp32.

In light of this issue, does anyone have workarounds that actually deliver the memory and time benefits soft prompts promise? (Something like computing the gradient for the prompt without having to store all the intermediate activations?)
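One direction I considered but have not verified is gradient checkpointing, which recomputes activations during the backward pass instead of caching them all, trading extra compute for memory. Hugging Face models expose this with a single call:

# possible mitigation (untested): recompute intermediate activations in
# the backward pass rather than storing them, which should shrink the
# activation memory that backprop through the frozen layers requires
model.gradient_checkpointing_enable()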
