First of all, thanks for the great resource to help people get started with prompt tuning!
I was wondering why prompt tuning uses more memory and time than training without it. To be concrete, I am trying to prompt tune a 12-layer BertForSequenceClassification model for my custom classification task:
```python
model = BertForSequenceClassification.from_pretrained(
    my_model_path,
    num_labels=10,  # note: the kwarg is `num_labels`, not `labels`
    output_attentions=False,
    output_hidden_states=False,
).to("cuda")

# ---freeze all params except the prompt---
for param in model.parameters():
    param.requires_grad = False

# ---the only trainable parameter is the soft prompt---
soft_prompt = True
if soft_prompt:
    softemb = SoftEmbedding(model.get_input_embeddings(),
                            n_tokens=20,
                            initialize_from_vocab=True)
    model.set_input_embeddings(softemb)
    model.cuda()  # move the newly inserted embedding module to the GPU as well
```
However, I have noticed that without the soft prompt (setting soft_prompt = False), my memory usage is ~1.5 GB, while with prompt tuning (soft_prompt = True) it jumps to 9 GB and training takes twice as long. My hypothesis is that since the soft prompt sits at the very input of the model, autograd must backpropagate through every layer to reach the prompt's gradients, and therefore has to store all of the intermediate activations, despite requires_grad = False being set on all the intermediate layers.
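This behavior can be reproduced in isolation. Below is a toy stack of frozen linear layers standing in for the BERT encoder (the shapes are hypothetical, chosen only for illustration): even with every parameter frozen, autograd still builds a graph and retains each layer's activations whenever the *input* requires grad, because those activations are needed to compute the gradient with respect to the input.

```python
import torch
import torch.nn as nn

# Toy stand-in for a 12-layer encoder; every parameter is frozen.
layers = nn.Sequential(*[nn.Linear(512, 512) for _ in range(12)])
for p in layers.parameters():
    p.requires_grad = False

# The "soft prompt" is the only leaf tensor that requires grad.
prompt = torch.randn(20, 512, requires_grad=True)

out = layers(prompt).sum()
# Even though every Linear is frozen, the output carries a grad_fn:
# autograd kept each layer's input alive to compute d(out)/d(prompt).
print(out.grad_fn is not None)

out.backward()
print(prompt.grad.shape)
```

So freezing parameters only skips the *weight* gradients; it does not free the activation memory that the backward pass to the prompt still needs.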
To test whether calculating the soft prompt's gradients causes the jump in memory, I added .clone().detach() to the soft prompt to make it fixed (not requiring gradients), that is, replacing this line in soft_embedding.py (line 53 at commit 6c6d31a) with:

```python
learned_embedding = self.learned_embedding.clone().detach().repeat(input_embedding.size(0), 1, 1)
```

And indeed, even with soft_prompt = True, I got only ~1.5 GB of memory usage.
This suggests that backpropagating to the soft prompt is somehow responsible for the extra 7.5 GB of memory, even though the prompt itself is at most a few megabytes (mine is only a 20 × 512 matrix).
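To put numbers on that, the prompt's own storage is tiny compared to the observed gap, which points at activation memory rather than the prompt's parameters or gradients:

```python
# Soft prompt: 20 tokens × 512 dims, fp32 (4 bytes per element).
prompt_bytes = 20 * 512 * 4
print(prompt_bytes)  # 40960 bytes, i.e. ~40 KB

# Even doubling it to account for the gradient buffer, this is
# nowhere near 7.5 GB; the gap must be the stored activations
# of the frozen intermediate layers.
print(2 * prompt_bytes / 2**20)  # well under 1 MB
```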
In light of this issue, does anyone have potential workarounds that preserve the memory and time benefits of soft prompts? (Something like calculating the gradients for the prompt without needing to store all the intermediate activations?)
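One partial workaround (my own suggestion, not something from this repo) is activation checkpointing: instead of storing every layer's activations for the backward pass, store only segment boundaries and recompute the rest during backward, trading extra compute for memory. A minimal sketch with torch.utils.checkpoint, assuming a recent PyTorch and reusing the hypothetical frozen stack from above:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Hypothetical frozen stack standing in for the BERT encoder.
layers = nn.Sequential(*[nn.Linear(512, 512) for _ in range(12)])
for p in layers.parameters():
    p.requires_grad = False

prompt = torch.randn(20, 512, requires_grad=True)  # the soft prompt

# Split the 12 layers into 4 checkpointed segments: only the segment
# boundaries are stored, and inner activations are recomputed in
# backward, cutting activation memory roughly by the segment count.
out = checkpoint_sequential(layers, 4, prompt, use_reentrant=False)
out.sum().backward()
print(prompt.grad.shape)
```

For the actual Hugging Face model, model.gradient_checkpointing_enable() applies the same idea to the BERT encoder without manual surgery. It won't reach the no-gradient 1.5 GB baseline, since backprop to the prompt still has to traverse all layers, but it should substantially shrink the activation footprint.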