
Multi Adapter support #263

Merged 39 commits on Apr 6, 2023
Conversation

pacman100 (Contributor)

What does this PR do?

  1. Adds multi-adapter training and inference support.
  2. Fixes #211 (How to use multiple lora at the same time for text generation?), #208 (Is it possible to "unload" the PEFT LoRA weights after mutating the base model with PeftModel.from_pretrained?), two comments on #133 (adds multiple adapters to a peft model), and https://gist.github.com/philschmid/821c5317d144250feef517aecd390b98

Usage

  1. While loading the first adapter via PeftModel.from_pretrained, you can give it a name using the **adapter_name** parameter. Otherwise, the default adapter name "default" is used.
  2. To load another adapter, use the **load_adapter()** method of PeftModel, e.g., model.load_adapter(peft_model_path, adapter_name).
  3. To switch between adapters, use the **set_adapter()** method of PeftModel, e.g., model.set_adapter(adapter_name).
  4. To disable adapters, use the **disable_adapter()** context manager, e.g., with model.disable_adapter().
  5. Specific to the LoRA method: to merge the currently active adapter into the base model weights and remove the injected LoRA modules, so that you get back the transformers base model with the LoRA weights added, use the **merge_and_unload()** method, e.g., model = model.merge_and_unload() (see the sketch after the example below).

from peft import PeftModel
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig

model_name = "decapoda-research/llama-7b-hf"
tokenizer = LlamaTokenizer.from_pretrained(model_name)
model = LlamaForCausalLM.from_pretrained(
    model_name,
    load_in_8bit=True,
    device_map="auto",
    use_auth_token=True
)
model = PeftModel.from_pretrained(model, "tloen/alpaca-lora-7b", adapter_name="eng_alpaca")
model.load_adapter("22h/cabrita-lora-v0-1", adapter_name="portuguese_alpaca")

model.set_adapter("eng_alpaca")
instruction = "Tell me about alpacas."
print(evaluate(instruction))  # `evaluate` is the generation helper defined in the Colab notebook linked below

model.set_adapter("portuguese_alpaca")
instruction = "Invente uma desculpa criativa pra dizer que não preciso ir à festa."
print(evaluate(instruction))

with model.disable_adapter():
    instruction = "Invente uma desculpa criativa pra dizer que não preciso ir à festa."
    print(evaluate(instruction))
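
To illustrate step 5, here is a minimal sketch of merge_and_unload() (reusing model_name from the snippet above; this assumes the base model is loaded in fp16 rather than 8-bit, since merging into 8-bit weights is not supported, and the output path is just a placeholder):

import torch
from transformers import LlamaForCausalLM
from peft import PeftModel

base_model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
lora_model = PeftModel.from_pretrained(base_model, "tloen/alpaca-lora-7b", adapter_name="eng_alpaca")

# Fold the active adapter's LoRA weights into the base weights and remove the injected LoRA modules,
# getting back a plain transformers model.
merged_model = lora_model.merge_and_unload()
merged_model.save_pretrained("llama-7b-eng-alpaca-merged")  # placeholder output path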

Link to colab notebook: https://colab.research.google.com/drive/1vrVg8G7AIdCM9qpcfZya0B7OEsPSAxbO?usp=sharing

@pacman100 pacman100 marked this pull request as ready for review April 4, 2023 22:03
@pacman100 pacman100 changed the title Smangrul/multi lora support Multi Adapter support Apr 4, 2023
@HuggingFaceDocBuilderDev commented Apr 5, 2023

The documentation is not available anymore as the PR was closed or merged.

@Dentoty commented Apr 5, 2023

Would this work for inference using LoRAs that were trained with train_dreambooth.py?

@younesbelkada (Contributor) left a comment

Awesome work! 🔥 This feature is really great.
I left a couple of open questions, just to confirm whether I have understood some changes correctly. My main concern was whether this introduces breaking changes with respect to adapters that are already on the Hub, but I think you are handling that correctly with ModulesToSaveWrapper.
IMO we just need to figure out why some tests are failing; they're related to merging layers for some models. I made a suggestion below (I think the tests don't handle the case merge_weights=False correctly).
Also, I would like to hear your thoughts on dropping MergedLinear in favor of a single Linear class!
Let's also introduce slow tests (based on the snippet you shared); this can be done in #256.

Comment on lines 439 to 441
if not self.merge_weights:
warnings.warn("Nothing to merge. Set merge_weights to True to enable merging.")
return
Contributor

Maybe that could explain why the tests are failing: we should always use merge_weights=True (i.e. not test the case where merge_weights=False), or maybe call .eval() in case merge_weights=True.

Comment on lines 113 to 154
class ModulesToSaveWrapper(torch.nn.Module):
    def __init__(self, module_to_save, adapter_name):
        super().__init__()
        self.original_module = module_to_save
        self.modules_to_save = torch.nn.ModuleDict({})
        self.update(adapter_name)
        self.active_adapter = adapter_name

    def update(self, adapter_name):
        self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)}))

    def forward(self, *args, **kwargs):
        if self.active_adapter not in self.modules_to_save:
            return self.original_module(*args, **kwargs)
        return self.modules_to_save[self.active_adapter](*args, **kwargs)


def _get_submodules(model, key):
    parent = model.get_submodule(".".join(key.split(".")[:-1]))
    target_name = key.split(".")[-1]
    target = model.get_submodule(key)
    return parent, target, target_name


def _set_trainable(model, adapter_name):
    key_list = [key for key, _ in model.named_modules()]
    for key in key_list:
        target_module_found = any(key.endswith(target_key) for target_key in model.modules_to_save)
        if target_module_found:
            parent, target, target_name = _get_submodules(model, key)
            if isinstance(target, ModulesToSaveWrapper):
                target.update(adapter_name)
            else:
                for param in target.parameters():
                    param.requires_grad = True
                setattr(parent, target_name, ModulesToSaveWrapper(target, adapter_name))


def _set_adapter(model, adapter_name):
    for module in model.modules():
        if isinstance(module, ModulesToSaveWrapper):
            module.active_adapter = adapter_name
Contributor

I see, this avoids the breaking change I believe, can you confirm? 🙏

Contributor Author

This avoids breaking changes when there are additional trainable layers on top of the model, such as a classifier/regression head for tasks like AutoModelForSequenceClassification, or a TRL reward model head. These layers are also saved along with the adapter weights, and when each checkpoint has its own additional trainable layers, this makes sure the correct ones are called.
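
As a minimal sketch of what this means in practice (the model ID and adapter paths below are placeholders, not from this PR): each adapter checkpoint brings its own copy of the classification head, stored under its adapter name inside ModulesToSaveWrapper, so set_adapter also switches which head forward() routes through.

from transformers import AutoModelForSequenceClassification
from peft import PeftModel

base = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2)  # placeholder base model
model = PeftModel.from_pretrained(base, "user/sentiment-lora", adapter_name="sentiment")  # placeholder adapter
model.load_adapter("user/topic-lora", adapter_name="topic")  # placeholder adapter with its own head

model.set_adapter("sentiment")  # forward() uses modules_to_save["sentiment"] for the classifier
model.set_adapter("topic")      # ...and now modules_to_save["topic"]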

Contributor

Awesome, this is great then!

@@ -571,7 +645,8 @@ def forward(
        return_dict=None,
        **kwargs,
    ):
-       if not isinstance(self.peft_config, PromptLearningConfig):
+       peft_config = self.active_peft_config
Contributor

IMO it should be documented somewhere that the way of retrieving the peft config has changed: it is now active_peft_config rather than peft_config.
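
For reference, a minimal sketch of the difference (attribute names as introduced in this PR):

all_configs = model.peft_config       # now a dict mapping adapter names to their configs
config = model.active_peft_config     # the config of the currently active adapter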

@younesbelkada (Contributor) left a comment

Looking great! Let's address the conversion script + example script / notebook in a follow-up PR 🔥 Thanks for explaining the approaches you took in detail.

@lxe commented Apr 7, 2023

Great timing on this! Tested it and it's working swimmingly ^^ 🤗

@llllllim commented Jan 2, 2024

How can I load just the LoRA and keep the base model as is? I want to reduce memory usage.

@HyperdustLab

This is nice.
