LORA finetuning gradients are scaled by an unknown constant factor #1893
Comments
There is really a lot going on in your notebook, so it's hard for me to tell where this constant is coming from. Therefore, I created the most simple version of this problem, and it revealed that for SGD with lr=1, the parameter update is exactly equal to the gradient:

```python
import copy

import torch
from torch import nn

from peft import get_peft_model, LoraConfig


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(10, 5)

    def forward(self, x):
        return self.lin(x)


torch.manual_seed(0)
x = torch.randn(8, 10)

# without LoRA
torch.manual_seed(0)
model = MyModule()
sd = copy.deepcopy(model).state_dict()
sgd = torch.optim.SGD(model.parameters(), lr=1, momentum=0, maximize=False, nesterov=False, weight_decay=0)

# train
sgd.zero_grad()
out = model(x)
loss = out.sum()
loss.backward()
sgd.step()

# compare
sd2 = model.state_dict()
grad = model.lin.weight.grad
torch.testing.assert_close(sd['lin.weight'] - sd2['lin.weight'], grad)

# with LoRA
torch.manual_seed(0)
model = MyModule()
config = LoraConfig(target_modules=["lin"], init_lora_weights=False)
model = get_peft_model(model, config)
sd = copy.deepcopy(model).state_dict()
sgd = torch.optim.SGD(model.parameters(), lr=1, momentum=0, maximize=False, nesterov=False, weight_decay=0)

# train
sgd.zero_grad()
out = model(x)
loss = out.sum()
loss.backward()
sgd.step()

# compare
sd2 = model.state_dict()
assert model.base_model.lin.base_layer.weight.grad is None
grad = model.base_model.lin.lora_A["default"].weight.grad
torch.testing.assert_close(sd['base_model.model.lin.lora_A.default.weight'] - sd2['base_model.model.lin.lora_A.default.weight'], grad)
grad = model.base_model.lin.lora_B["default"].weight.grad
torch.testing.assert_close(sd['base_model.model.lin.lora_B.default.weight'] - sd2['base_model.model.lin.lora_B.default.weight'], grad)
```

This doesn't solve your issue, but it shows that the discrepancy is unlikely to come from the LoRA implementation. To investigate further, here are some ideas:
Thanks for your input! I was able to find the cause of the bug. I was registering a parameter hook as below, but for some reason, the gradient observed by such a hook is different from the one in the state dictionary. All checks pass when using the state dictionary!
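For reference, here is a minimal, self-contained sketch of that kind of check (not the original hook from the notebook; the layer and dictionary names are illustrative), comparing the gradient observed via `Tensor.register_hook` against the weight delta that plain SGD with lr=1 actually applies:

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)
lin = nn.Linear(10, 5)
observed = {}

def save_grad(grad):
    # store a copy of the gradient as the hook sees it; returning None leaves the grad unchanged
    observed["weight"] = grad.detach().clone()

# Tensor.register_hook fires during backward with the gradient w.r.t. this parameter
lin.weight.register_hook(save_grad)

sgd = torch.optim.SGD(lin.parameters(), lr=1, momentum=0, weight_decay=0)
before = copy.deepcopy(lin.state_dict())

sgd.zero_grad()
lin(torch.randn(8, 10)).sum().backward()
sgd.step()

after = lin.state_dict()
# With lr=1 and vanilla SGD, the applied update equals the hook-observed gradient;
# any mismatch here would reproduce the discrepancy described above.
torch.testing.assert_close(before["weight"] - after["weight"], observed["weight"])
```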
System Info
torch: 2.3.0+cu121
transformers: 4.41.2
peft: 0.11.1
datasets: 2.20.0
Who can help?
@BenjaminBossan @sayakpaul
Information
Tasks
An officially supported task in the examples folder
Reproduction
You can run the following Colab notebook: https://colab.research.google.com/drive/1lgFyKZaZ3ySXWRcfImsry92X7dhrVgZz?usp=sharing
There are two sections in the linked Colab notebook.
Expected behavior
I'm trying to integrate the peft library in our framework, but I am running into an unexplained behavior when performing LoRA finetuning. I've noticed that an unidentified factor is scaling the gradients before they are used to update the weights in each optimization step.

For example, when using the SGD optimizer with parameters {lr: 1.0, maximize: False, momentum: 0, nesterov: False, weight_decay: 0.0} and a constant learning rate scheduler, you would expect the weights to be updated as follows at each step:

$$w_{t+1} = w_t - \mathrm{lr} \cdot \nabla_w \mathcal{L} = w_t - \nabla_w \mathcal{L}$$

However, the weights are instead updated as follows (note the $c$ constant factor):

$$w_{t+1} = w_t - c \cdot \nabla_w \mathcal{L}$$

Where does $c$ come from, and what is its formula? With rank=lora_alpha=16, I'd expect a scaling of $16/16=1.0$. I have already looked through the code and printed any scaling constants, such as this one: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py#L122, which is always 1.0 as expected. I have also checked, and the learning rate at each optimizer stage is 1.0 as I've set it.
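For context on the expected value of 1.0: in the standard (non-rsLoRA) LoRA formulation, the adapter output is multiplied by the scaling constant $\alpha / r$ (this is the `scaling` value referenced by the link above), so the modified forward pass is

$$h = W_0 x + \frac{\alpha}{r} \, B A x,$$

which indeed gives a scaling of $16/16 = 1.0$ when `lora_alpha = r = 16`.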