Multi Adapter support #263
Conversation
Might have breaking changes
The documentation is not available anymore as the PR was closed or merged.

Would this work for inference using LoRAs that were trained using train_dreambooth.py?
Awesome work! 🔥 This feature is really great.

I left a couple of open questions, just to confirm whether I have understood some changes correctly. My main concern was whether this introduces breaking changes with respect to adapters that are already on the Hub, but I think you are dealing with that correctly using `ModulesToSaveWrapper`.

IMO we just need to figure out why some tests are failing; they're related to merging layers for some models. I made a suggestion below (I think the tests don't handle the `merge_weights=False` case correctly).

Also I would like to hear your thoughts on dropping `MergedLinear` in favor of a single `Linear` class!

Let's also introduce slow tests (based on the snippet you shared); this can be done in #256.
src/peft/tuners/lora.py
Outdated
```python
if not self.merge_weights:
    warnings.warn("Nothing to merge. Set merge_weights to True to enable merging.")
    return
```
Maybe that could explain why the tests are failing. We should always call `merge_weights=False` (i.e. not test the case where `merge_weights=False`) --> or maybe call `.eval()` in case `merge_weights=True`.
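For context, a minimal pure-Python sketch of what LoRA weight merging (the step guarded by `merge_weights` in the suggestion above) actually computes: `W_merged = W + (B @ A) * (lora_alpha / r)`. The matrices and scaling value here are made-up toy numbers; the real layer does this with torch tensors in place.

```python
# Toy illustration of the LoRA merge: W_merged = W + scaling * (B @ A).
# All values are hypothetical; the real implementation uses torch tensors.
def matmul(B, A):
    return [[sum(B[i][k] * A[k][j] for k in range(len(A)))
             for j in range(len(A[0]))] for i in range(len(B))]

W = [[1.0, 0.0], [0.0, 1.0]]  # frozen base weight (2x2)
A = [[0.5, 0.5]]              # down-projection, rank r=1 (1x2)
B = [[2.0], [0.0]]            # up-projection (2x1)
scaling = 1.0                 # lora_alpha / r

delta = matmul(B, A)          # rank-1 update BA
W_merged = [[W[i][j] + scaling * delta[i][j] for j in range(2)]
            for i in range(2)]
print(W_merged)  # [[2.0, 1.0], [0.0, 1.0]]
```

Once merged, the adapter's contribution is baked into the base weight, which is why merging is irreversible unless the delta is subtracted back out.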
src/peft/utils/other.py
Outdated
```python
class ModulesToSaveWrapper(torch.nn.Module):
    def __init__(self, module_to_save, adapter_name):
        super().__init__()
        self.original_module = module_to_save
        self.modules_to_save = torch.nn.ModuleDict({})
        self.update(adapter_name)
        self.active_adapter = adapter_name

    def update(self, adapter_name):
        self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)}))

    def forward(self, *args, **kwargs):
        if self.active_adapter not in self.modules_to_save:
            return self.original_module(*args, **kwargs)
        return self.modules_to_save[self.active_adapter](*args, **kwargs)


def _get_submodules(model, key):
    parent = model.get_submodule(".".join(key.split(".")[:-1]))
    target_name = key.split(".")[-1]
    target = model.get_submodule(key)
    return parent, target, target_name


def _set_trainable(model, adapter_name):
    key_list = [key for key, _ in model.named_modules()]
    for key in key_list:
        target_module_found = any(key.endswith(target_key) for target_key in model.modules_to_save)
        if target_module_found:
            parent, target, target_name = _get_submodules(model, key)
            if isinstance(target, ModulesToSaveWrapper):
                target.update(adapter_name)
            else:
                for param in target.parameters():
                    param.requires_grad = True
                setattr(parent, target_name, ModulesToSaveWrapper(target, adapter_name))


def _set_adapter(model, adapter_name):
    for module in model.modules():
        if isinstance(module, ModulesToSaveWrapper):
            module.active_adapter = adapter_name
```
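The core idea of the wrapper above can be shown with a torch-free sketch: each adapter name gets its own deep copy of the wrapped module, and calls are dispatched to the copy belonging to the active adapter. The `ModulesToSaveWrapperSketch` and `Head` classes below are hypothetical stand-ins, not the real PEFT classes (which subclass `torch.nn.Module`).

```python
# Torch-free sketch of the ModulesToSaveWrapper dispatch pattern:
# one deep copy of the module per adapter, routed by active_adapter.
import copy

class ModulesToSaveWrapperSketch:
    def __init__(self, module_to_save, adapter_name):
        self.original_module = module_to_save
        self.modules_to_save = {}
        self.update(adapter_name)
        self.active_adapter = adapter_name

    def update(self, adapter_name):
        # Each adapter gets its own trainable copy of the original module.
        self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module)

    def __call__(self, x):
        if self.active_adapter not in self.modules_to_save:
            return self.original_module(x)
        return self.modules_to_save[self.active_adapter](x)

class Head:
    """Hypothetical task head; stands in for a classifier/regression head."""
    def __init__(self, bias):
        self.bias = bias
    def __call__(self, x):
        return x + self.bias

wrapper = ModulesToSaveWrapperSketch(Head(bias=1), "default")
wrapper.update("french")                          # register a second adapter
wrapper.modules_to_save["french"].bias = 100      # "fine-tune" only the copy

wrapper.active_adapter = "default"
print(wrapper(0))  # 1
wrapper.active_adapter = "french"
print(wrapper(0))  # 100
```

Because each checkpoint's extra trainable layers live in their own copy, switching `active_adapter` swaps the head without touching the original module.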
I see, this avoids the breaking change I believe, can you confirm? 🙏
This avoids breaking changes when there are additional trainable layers, such as a classifier head/regression head on top of the model for tasks like `AutoModelForSequenceClassification`, or a TRL reward model head. These are also saved along with the adapter weights, and if each checkpoint has its own additional trainable layers, this makes sure they are properly being called.
Awesome, this is great then!
```diff
@@ -571,7 +645,8 @@ def forward(
         return_dict=None,
         **kwargs,
     ):
-        if not isinstance(self.peft_config, PromptLearningConfig):
+        peft_config = self.active_peft_config
```
IMO it should be documented somewhere that the way of retrieving the peft config has changed; it is now `active_peft_config` rather than `peft_config`.
Looking great! Let's address the conversion script + example script/notebook in a follow-up PR 🔥 Thanks for explaining in detail the approaches you took.

Great timing on this! Tested it and it's working swimmingly ^^ 🤗

How can I upload just the LoRA and keep the base model as-is? I want to reduce memory usage.

this is nice
What does this PR do?
Usage
- When calling `PeftModel.from_pretrained`, you can give the adapter a name using the `adapter_name` parameter; otherwise the default adapter name `default` is used.
- Load additional adapters using the `load_adapter()` method of `PeftModel`, e.g., `model.load_adapter(peft_model_path, adapter_name)`.
- Switch between adapters using the `set_adapter()` method of `PeftModel`, e.g., `model.set_adapter(adapter_name)`.
- Disable adapters using the `disable_adapter()` context manager, e.g., `with model.disable_adapter()`.
- Merge the adapter into the base model using the `merge_and_unload()` method, e.g., `model = model.merge_and_unload()`.

Link to colab notebook: https://colab.research.google.com/drive/1vrVg8G7AIdCM9qpcfZya0B7OEsPSAxbO?usp=sharing
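The control flow described in the usage list can be sketched with a toy stand-in model. The `ToyPeftModel` class below is a hypothetical illustration only: its method names mirror the PR's `PeftModel` API (`load_adapter`, `set_adapter`, `disable_adapter`), but its "weights" are plain floats rather than real adapter checkpoints.

```python
# Hypothetical toy model illustrating the multi-adapter control flow:
# a named registry of adapters, one active at a time, with a context
# manager that temporarily falls back to the bare base model.
from contextlib import contextmanager

class ToyPeftModel:
    def __init__(self, adapter_name="default"):
        self.adapters = {adapter_name: 0.0}  # adapter_name -> toy "weights"
        self.active_adapter = adapter_name
        self._enabled = True

    def load_adapter(self, weights, adapter_name):
        # Register an additional adapter under its own name.
        self.adapters[adapter_name] = weights

    def set_adapter(self, adapter_name):
        # Switch which adapter is applied on forward.
        self.active_adapter = adapter_name

    @contextmanager
    def disable_adapter(self):
        # Temporarily run the base model without any adapter.
        self._enabled = False
        try:
            yield
        finally:
            self._enabled = True

    def forward(self, x):
        if not self._enabled:
            return x  # base model only
        return x + self.adapters[self.active_adapter]

model = ToyPeftModel()
model.load_adapter(weights=5.0, adapter_name="task_b")
model.set_adapter("task_b")
print(model.forward(1.0))      # 6.0 (base + task_b adapter)
with model.disable_adapter():
    print(model.forward(1.0))  # 1.0 (base model only)
```

The key design point the PR introduces is the same as in this sketch: adapters are keyed by name in a registry, so loading a new one never mutates the base weights, and switching or disabling is just a lookup change.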