
Possible to build a LoRA that doesn't inject into the transformer? #1523

Closed
AngledLuffa opened this issue Mar 3, 2024 · 37 comments

@AngledLuffa

Feature request

Is it possible to build a LoRA that doesn't inject into the transformer? This would allow reusing the same base transformer with multiple adapters in the same process while saving on GPU memory (probably at the expense of some speed).

Motivation

We've started using PEFT with LoRA for tasks such as sentiment analysis and constituency parsing in Stanza, and one thing we found is that there are currently no memory savings compared to using a fully finetuned transformer.

For example, if the transformer loaded for sentiment analysis takes 3GB, with no finetuning we can reuse the same transformer weights when constituency parsing, making for a total of 3GB plus the prediction heads of the models. If we use fully FT transformers, obviously that increases to 6GB assuming those are our only two tasks.

PEFT with LoRA uses inject_adapter_in_model to update the model with the As and Bs, AFAIK, meaning that loading those two models still takes 6GB. If we could have a version of the transformer which does inference with the As and Bs not injected, but wrapping the base transformer's tensors, this would almost certainly be noticeably slower but would allow for a much smaller memory footprint.

Thanks for the extremely useful library, BTW

Your contribution

I probably don't have much time to investigate this in the next couple months, but in the long term it is something I could attempt with some guidance on where to look

@BenjaminBossan
Member

It would be possible to build a LoRA adapter that works purely through forward hooks on the base model, but that's a big departure from how we implement this right now, and it would also not be as flexible.
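For reference, a minimal conceptual sketch of a hook-based LoRA delta on a single nn.Linear could look like the following. This is not PEFT's implementation; the rank, scaling, and initialization are placeholder assumptions.

import torch

def attach_lora_hook(linear: torch.nn.Linear, r: int = 8, alpha: int = 16):
    # low-rank factors kept outside the module; B starts at zero, so the hook is
    # initially a no-op, matching the usual LoRA initialization
    A = torch.nn.Parameter(torch.randn(r, linear.in_features) * 0.01)
    B = torch.nn.Parameter(torch.zeros(linear.out_features, r))
    scaling = alpha / r

    def hook(module, inputs, output):
        x = inputs[0]
        # returning a value from a forward hook replaces the module's output
        return output + (x @ A.T @ B.T) * scaling

    handle = linear.register_forward_hook(hook)
    return handle, (A, B)  # keep the handle so the hook can be removed later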

Regarding your issue, do you know that you can disable the LoRA-adapters completely and then the model behaves like the pure base model? This should allow you to avoid loading the model twice:

with model.disable_adapter():
    # do inference on base model

@AngledLuffa
Author

Regarding your issue, do you know that you can disable the LoRA-adapters completely and then the model behaves like the pure base model?

Interesting, I did not know that. However, wouldn’t that give different weights to the prediction head, probably resulting in errors since it wasn't trained to recognize those weights?

What I was hoping for was something where each additional use of peft lora only incurred a space cost of |A|+|B|, the same size as the saved weights, but it sounds like such a thing currently doesn't exist, is that correct? If it makes it any easier, I would not be expecting to train from that state, just run inference.

Thanks for the fast reply!

@BenjaminBossan
Member

wouldn’t that give different weights to the prediction head, probably resulting in errors since it wasn't trained to recognize those weights?

I think I don't quite get your problem yet. My initial understanding was that you need both the base model and the LoRA-augmented model for inference and want to avoid loading the base model twice. Are the prediction heads trained with or without LoRA? If you train them with LoRA, you have probably added them to modules_to_save, which wraps them in a ModulesToSaveWrapper. This one also works with `disable_adapter`.
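For illustration, a config along these lines would train and save the head together with the LoRA weights. This is just a sketch; "classifier" is the head name in ElectraForSequenceClassification, other architectures may name it differently.

from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, get_peft_model

base_model = AutoModelForSequenceClassification.from_pretrained(
    "google/electra-large-discriminator", num_labels=3
)
# listing the head in modules_to_save wraps it in a ModulesToSaveWrapper so it is
# trained and stored alongside the LoRA weights
config = LoraConfig(
    r=16,
    target_modules=["query", "value"],
    modules_to_save=["classifier"],
)
peft_model = get_peft_model(base_model, config)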

What I was hoping for was something where each additional use of peft lora only incurred a space cost of |A|+|B|, the same size as the saved weights, but it sounds like such a thing currently doesn't exist, is that correct?

I'm not sure I understand. When you load the base model, it takes 3GB as you mentioned. When you load LoRA on top, that should increase memory only by a tiny amount. My earlier understanding was that the additional memory comes from loading the base model twice.
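If it helps to verify this, comparing allocated GPU memory before and after attaching an adapter is a quick sanity check. A rough sketch (exact numbers depend on the config):

import torch
from transformers import AutoModel
from peft import LoraConfig, get_peft_model

model = AutoModel.from_pretrained("google/electra-large-discriminator").cuda()
before = torch.cuda.memory_allocated()
model = get_peft_model(model, LoraConfig(r=64, target_modules=["query", "value"]))
after = torch.cuda.memory_allocated()
print(f"extra memory from adding LoRA: {(after - before) / 1e6:.1f} MB")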

@AngledLuffa
Author

I'm not sure I understand. When you load the base model, it takes 3GB as you mentioned. When you load LoRA on top, that should increase memory only by a tiny amount.

We might be misunderstanding on our end, but our belief was that the underlying model gets changed by loading LoRA. So if we load two separate peft models on top of the same transformer, the two loadings of peft models will clobber each other:

from transformers import AutoModel
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict

original_bert_model = AutoModel.from_pretrained(model_name)
sentiment_bert_model = get_peft_model(original_bert_model, sentiment_peft_config)
set_peft_model_state_dict(sentiment_bert_model, sentiment_lora_state_dict)

constituency_bert_model = get_peft_model(original_bert_model, constituency_peft_config)
set_peft_model_state_dict(constituency_bert_model, constituency_lora_state_dict)

We checked that the weights in original_bert_model were changed after calling set_peft_model_state_dict the first time, which led us to believe that the models would overwrite each other in such a code snippet

@BenjaminBossan
Member

BenjaminBossan commented Mar 5, 2024

LoRA adapters do indeed mutate the base model, but it doesn't "clobber" it. We took care of making it so that the base model itself can still make predictions as if LoRA was not there when LoRA is being disabled, and different LoRA adapters do play nicely with each other.

Note, however, that I would load the adapters a bit differently than you do; I'm not sure if your code would work correctly:

from peft import PeftModel

base_model = AutoModel...
# load adapter 0, it is automatically the active adapter
peft_model = PeftModel.from_pretrained(base_model, <path-to-adapter0>, adapter_name=<name0>)
# load adapter 1, adapter 0 is still active
peft_model.load_adapter(<path-to-adapter1>, adapter_name=<name1>)
# activate adapter 1, deactivate adapter 0
peft_model.set_adapter(<name1>)

@AngledLuffa
Author

We took care of making it so that the base model itself can still make predictions as if LoRA was not there when LoRA is being disabled

The bolded part ("when LoRA is being disabled") is a little worrisome - does that mean that base_model in my example won't get back the same results unless the peft models are explicitly disabled? We have been treating the annotators such as POS, sentiment, parser, etc. as separate entities which had no idea there were other annotators out there that might affect the transformer they are using.

I can verify this myself later tonight, although I don't have time to do so this early afternoon.

@AngledLuffa
Author

Can confirm there is an issue where our software loads the transformer and then has that transformer overwritten when connecting it with LoRA using peft.

We have a POS model and a sentiment model which both use electra-large, except the sentiment model uses a peft wrapper to get better results. We weren't able to figure out how to get better POS tags with peft, but that's a story for another day.

If I load the POS model by itself, it gets an overall score of 93.92 on the EWT UD dataset (those exact details aren't super relevant).

If I load both the POS model and the Sentiment model, using two different invocations of electra-large, it still gets that score.

If instead I load the POS model, then reuse the electra-large transformer to load the Sentiment model, the score drops to 92.12. Our current release uses the first scheme to keep the results consistent, but that results in 6GB of transformer instead of 3GB.

In terms of separation of concerns, I'd love to have a mechanism where the POS model can load its transformer, the Sentiment model can reuse the same transformer and wrap it in peft, and the two models wouldn't have to do anything specific to their transformer before using it.

@BenjaminBossan
Member

Just so I understand correctly, the same base model (electra large) is used by the POS model and the sentiment model. You load LoRA weights on top of the sentiment model. In that case, indeed, it is expected that the POS model also behaves differently because PEFT will modify the base model by attaching the adapters.

If you want to avoid having a copy of the base model in memory, I would recommend checking if disabling the adapters on the sentiment model restores the performance of the POS model.

@AngledLuffa
Author

It's actually giving me an error, saying that there is no adapter loaded:

  File "/home/john/stanza/stanza/pipeline/sentiment_processor.py", line 55, in _set_up_model
    trainer = Trainer.load(filename=filename,
  File "/home/john/stanza/stanza/models/classifiers/trainer.py", line 176, in load
    model.bert_model.disable_adapters()
  File "/usr/local/lib/python3.9/site-packages/transformers/integrations/peft.py", line 322, in disable_adapters
    raise ValueError("No adapter loaded. Please load an adapter first.")

I saw your earlier suggestion of using PeftModel.from_pretrained and peft_model.load_adapter as a mechanism for creating peft models from the transformers, but currently that would be a pretty invasive change on our end, since we have been shipping single files with both the model head and the finetuned transformer / peft weights. Hence the use of set_peft_model_state_dict

@BenjaminBossan
Member

Did you call disable_adapters after you loaded the PEFT adapter? If yes, it would be strange for that error to occur and I would assume that it is somehow related to how you load the adapter.

Also, as I'm reading your code again, how does the constituency_bert_model fit the picture? Or is that unrelated?

@AngledLuffa
Author

Yes, definitely. This crash happens if I do this:

set_peft_model_state_dict(model.bert_model, model_params['bert_lora'])
model.bert_model.disable_adapters()

I'm not giving it a name as part of the set... call. Does that affect whether or not the model thinks it has an adapter attached to it? As I mentioned earlier, I was using get_peft_model to build the peft-wrapped transformer. I tried this instead:

model.bert_model.add_adapter(lora_config, adapter_name="adapter_1")
model.bert_model.disable_adapters()

This works, actually, and we get back the expected scores! This is a bit challenging to work with, though, for a couple reasons.

The simplest problem is, what if we have two Sentiment models or some other model that uses a peft adapter? We've tried it with POS, although we haven't yet found settings that consistently improve the POS scores, and there are definitely situations where we need multiple POS models built from the same transformer. They can't all be named adapter_name="pos"...

There's also the problem of thread safety and needing to modify the transformer's state (currently active adapter) in order to use it. That might actually be a problem if the proposal to remove the GIL is implemented in Python.

I know there's a trick to call copy.deepcopy on a transformer where we could then presumably add separate peft adapters for each model we load, but that does cause an increase in GPU usage. One very satisfactory outcome would be a way to copy the state of the transformer such that there are multiple transformer objects which each use the same underlying weights but know about different sets of adapters. Does such a thing exist?

What I'm getting at is that add_adapter or load_adapter appear to be necessary to organize the adapters on the transformer object, but it doesn't fully address the issue we're running into of wanting to keep each model's state separate from the other model states. If nothing else, we'd run into the same problem we have now as soon as we start loading multiple models of the same type which each want their own adapter.

Also, as I'm reading your code again, how does the constituency_bert_model fit the picture? Or is that unrelated?

constituency_bert_model would be the same transformer with its own adapter. This is actually a great example of the problem of needing multiple adapters for the same model. We found that we get roughly the same results with a peft-adapted transformer as we do with the non-peft version, but now the model size we distribute is 200MB instead of 1.5GB. So that's a huge improvement, thank you.

A possible issue here is that there are multiple applications for having several constituency models loaded at once, for example using those models in an ensemble to get better results. Again we would run into the limitation of how to name the adapters, such as if the first constituency parser loads its peft adapter with the name "constituency", then the second constituency parser would need to load it with the name "constituency-2" etc. Would that be something where the constituency parser itself knows that there have been N previously loaded models? Would the caller need to keep track of that? Either way, that seems a lot more complicated and less clean than simply having a fresh transformer object (hopefully sharing the underlying weights to save GPU space) which doesn't need to know anything about previously loaded adapters.

@BenjaminBossan
Member

You can load in multiple adapters and give them any name you want (usually using the add_adapter method in transformers). Then you can choose the active adapter with the set_adapter method.
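A sketch of that flow using the transformers integration; the adapter names and configs here are placeholders:

from transformers import AutoModel
from peft import LoraConfig

base = AutoModel.from_pretrained("google/electra-large-discriminator")
base.add_adapter(LoraConfig(r=8, target_modules=["query", "value"]), adapter_name="sentiment")
base.add_adapter(LoraConfig(r=8, target_modules=["query", "value"]), adapter_name="constituency")

base.set_adapter("sentiment")     # only the sentiment adapter affects the forward pass
# ... run sentiment inference ...
base.set_adapter("constituency")  # switch: now only the constituency adapter is active
# ... run constituency inference ...
base.disable_adapters()           # back to the plain base model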

There's also the problem of thread safety and needing to modify the transformer's state (currently active adapter) in order to use it. That might actually be a problem if the proposal to remove the GIL is implemented in Python.

That's still far in the future and will be opt-in.

@AngledLuffa
Author

You can load in multiple adapters and give them any name you want (usually using the add_adapter method in transformers).

The thread safety concerns may be further off than we'd like, but the unique names solution doesn't really address the issue of situations where we need more than one POS model or more than one constituency model. We could enforce uniqueness of names somehow, but I was hoping to not add more to the global state than necessary. I suppose a random 10 letter name would almost never repeat and wouldn't require any global state at all...
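For what it's worth, one way to get effectively collision-free names without any global state, along the lines speculated above:

import uuid

adapter_name = f"constituency-{uuid.uuid4().hex[:10]}"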

It would also be a little weird for each model to have the workflow of first turn on their own adapter, then run inference, then turn off their adapter (the latter step being necessary so that any model still using the raw version of the transformer doesn't have to figure out if there even are any adapters to turn off).

Currently what does appear to work for loading time, but not GPU memory, is to load the transformer into memory and then clone it N times for each of the adapters we need. Then we get a situation where one annotator doesn't have to care at all how many other annotators of different types or the same type exist, but at the cost of increased GPU memory usage.

If it's not currently a feature, is it at all feasible to make a wrapper which is just a transformer and a single adapter, where that adapter does not affect inference for other users of the underlying transformer? If that's not anywhere on the current project roadmap, is it something where you'd consider merging a PR that implemented a feature like that?

@BenjaminBossan
Member

We could enforce uniqueness of names somehow, but I was hoping to not add more to the global state than necessary. I suppose a random 10 letter name would almost never repeat and wouldn't require any global state at all...

I don't get this point. You should know which adapters exist beforehand, so you can just choose some static names like "sentiment", "pos0", "pos1", or not?

It would also be a little weird for each model to have the workflow of first turn on their own adapter, then run inference, then turn off their adapter (the latter step being necessary so that any model still using the raw version of the transformer doesn't have to figure out if there even are any adapters to turn off).

Yes, I agree that it's not super convenient and, if you have to switch each time a new sample comes in, this could add some overhead. As you mention, the alternative would be to have a copy of the model for each adapter in memory. I think the feature we planned to add in #903 would have helped in your situation, but we didn't pursue it further.

If it's not currently a feature, is it at all feasible to make a wrapper which is just a transformer and a single adapter, where that adapter does not affect inference for other users of the underlying transformer? If that's not anywhere on the current project roadmap, is it something where you'd consider merging a PR that implemented a feature like that?

Unfortunately, that's not easy to achieve. The underlying PyTorch model does not have sufficient flexibility (say, via hooks) to add all the features that we need without mutating the model itself. It would probably be possible to have what you asked when focusing on only the subset of features that you need, but you'd have to build that yourself.

@AngledLuffa
Author

I don't get this point. You should know which adapters exist beforehand, so you can just choose some static names like "sentiment", "pos0", "pos1", or not?

Not necessarily true. We provide a way for people to make an annotation pipeline, and we have no control over how many times the user makes a new pipeline with the same base models without keeping some form of global state.

Still, we did implement a cache system for not excessively loading the same word vectors or transformer too many times. Perhaps that would be a reasonable place to keep track of which annotator names have been used in the past as well.

[switching]... this could add some overhead

How much overhead? Is it rewriting tensors or just flipping a pointer?

@BenjaminBossan
Member

How much overhead? Is it rewriting tensors or just flipping a pointer?

It's flipping a flag, but on each module, so we have to iterate through all the modules to do this. It should still be cheap compared to inference, but just a heads up to not do it excessively.

@BenjaminBossan
Member

We merged #1558, a follow up to #903. Maybe you can take a look if it could be interesting for your use case.

@AngledLuffa
Author

AngledLuffa commented Mar 19, 2024

Thanks for the shout out! It looks like a useful feature if running things in small batches, or if we ran the adapters on a per-sentence basis rather than a per-annotator basis.

Actually, in a single Pipeline we generally run inference with a runtime life cycle of:

  • load all the models at once, including transformers and their adapters (currently adapters are put on a copy of a transformer)
  • run a Pipeline one annotator at a time on one or more documents at a time, so first the POS, then the depparse, then the NER, then the constituencies, etc etc. On a document, the sentences are batched, so a POS annotation could conceivably be far more than a single batch worth of Bert or Electra if the document is long enough

The other main use case is multiple of the same type of model at the same time, where admittedly the switching becomes more frequent between batches. Still, I would think that given your description above, 50 sentences and then switching to another wrapper would only be slightly more expensive than 50 sentences by themselves.

So I'm coming around to the idea of, per Pipeline, there would be one copy of the transformer used by that Pipeline, and each annotator would know to set the adapter for themselves and not worry about what any other annotator did. As previously whinged, that's not thread-safe, but maybe there will be another solution available by the time thread safety is an issue.

There is one issue I have though - how to figure out if a transformer even has adapters attached? As I mentioned above, if I do this to a transformer with no adapters, it throws an exception:

            bert_model.disable_adapters()

Is there a way to check first if there are adapters? Or maybe just not have it throw that exception? I don't think there is a downside to having a model that has no adapters just ignore a call to disable_adapters

@BenjaminBossan
Member

Is there a way to check first if there are adapters?

You can check if bert_model._hf_peft_config_loaded is True. Otherwise, there's also no harm in using try ... except in this case.
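A small helper along these lines (the function name is ours) captures that check:

def disable_adapters_if_any(bert_model) -> None:
    # only meaningful for models managed through the transformers PEFT integration
    if getattr(bert_model, "_hf_peft_config_loaded", False):
        bert_model.disable_adapters()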

@AngledLuffa
Author

AngledLuffa commented Apr 5, 2024

Does calling the forward pass on a transformer respect the active adapter? If not, how do I go about getting back the same values (the transformer used as a featurizer) once there is an adapter loaded? I would have expected the following little program to output a few different weights, but it outputs the same ones each time.

If it's not clear from the script where the mismatched expectations are occurring, I can point to the model files in question (they're both available on HF under Stanford's Stanza repos, FWIW)

Is there something I need to do differently to ensure that I get weights with the transformer's adapter activated in this case?

import torch
from transformers import AutoModel, AutoTokenizer
from peft import LoraConfig, set_peft_model_state_dict

model_name = "google/electra-large-discriminator"
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

tokenized = tokenizer([["This", "is", "a", "test"]], padding="longest", is_split_into_words=True, return_offsets_mapping=False, return_attention_mask=True)

#print(tokenized)

attention_tensor = torch.tensor(tokenized['attention_mask'])
id_tensor = torch.tensor(tokenized['input_ids'])

def print_output():
    features = model(id_tensor, attention_mask=attention_tensor, output_hidden_states=True)
    features = torch.stack(features.hidden_states)
    print(torch.linalg.norm(features).item())

#print(attention_tensor)
#print(id_tensor)

print(type(model))

# consider this the baseline
print_output()

peft_config = LoraConfig(inference_mode=False,
                         r=64,
                         target_modules="query,value,output.dense,intermediate.dense".split(","),
                         lora_alpha=128,
                         lora_dropout=0.1,
                         modules_to_save="",
                         bias="none")
model.add_adapter(peft_config, adapter_name="sentiment")
model.enable_adapters()

# shouldn't this change?  a random adapter is now loaded on the bert model
# but it doesn't change anything
print_output()

model.set_adapter("sentiment")

# activated the adapter just in case that was the reason
# still doesn't change
print_output()

filename = "/home/john/stanza_resources/en/sentiment/sstplus_electra-large.pt"
checkpoint = torch.load(filename, lambda storage, loc: storage)
model_params = checkpoint['params']
set_peft_model_state_dict(model, model_params['bert_lora'], adapter_name="sentiment")

# now we've loaded an actual adapter with weights on top of this.  it absolutely should be changing
assert model.active_adapters() == ['sentiment']
print_output()

filename = "/home/john/stanza_resources/de/constituency/spmrl_german-nlp-electra.pt"
checkpoint = torch.load(filename, lambda storage, loc: storage)
# the file format is different because i'm stupid
# but still, this really ought to be changing the output
model.add_adapter(peft_config, adapter_name="constituency")
set_peft_model_state_dict(model, checkpoint['bert_lora'], adapter_name="constituency")
model.set_adapter("constituency")
assert model.active_adapters() == ['constituency']
print_output()

OUTPUT

<class 'transformers.models.electra.modeling_electra.ElectraModel'>
177.09336853027344
177.09336853027344
177.09336853027344
177.09336853027344
177.09336853027344

@AngledLuffa
Author

My thinking here was that I could load several adapters onto the same transformer and switch between them to get the needed encoding for each task, but the switching isn't actually doing what I'd expect. However, if I don't use an adapter name and just leave it to be the default adapter, then it works for one adapter. Doesn't let me switch between multiple adapters, though

@AngledLuffa
Author

I do note that I never called get_peft_model in this version of my life. Is that necessary?

@AngledLuffa
Author

In terms of calling get_peft_model first with a named adapter, this doesn't work:

pefted = get_peft_model(model, peft_config, "sentiment")
pefted.enable_adapters()

ValueError: No adapter loaded. Please load an adapter first.

This doesn't work either:

pefted = get_peft_model(model, peft_config, "sentiment")
pefted.add_adapter(adapter_name="sentiment", peft_config=peft_config)
pefted.enable_adapters()

ValueError: No adapter loaded. Please load an adapter first.

If I do this:

# model.add_adapter(peft_config, adapter_name="sentiment")
pefted = get_peft_model(model, peft_config, "sentiment")
pefted.add_adapter(adapter_name="sentiment", peft_config=peft_config)
pefted.set_adapter("sentiment")

filename = "/home/john/stanza_resources/en/sentiment/sstplus_electra-large.pt"
checkpoint = torch.load(filename, lambda storage, loc: storage)
model_params = checkpoint['params']
set_peft_model_state_dict(pefted, model_params['bert_lora'], adapter_name="sentiment")

print_output(pefted)
print_output(model)

Now I get seemingly random output from pefted and model, whereas I would hope to get the features the downstream model was trained with. Furthermore, calling disable_adapters() on either model or pefted throws a ValueError.

@AngledLuffa
Author

If I try this instead

filename = "/home/john/stanza_resources/en/sentiment/sstplus_electra-large.pt"
checkpoint = torch.load(filename, lambda storage, loc: storage)
model_params = checkpoint['params']
set_peft_model_state_dict(pefted, model_params['bert_lora'])

print_output(pefted)
print_output(model)

This keeps giving random output as well, and again there's no way to disable the adapters...

but weirdly, my models give the same final results when run on test sets...

Still, is there an example I can use or a way to turn that short script above into something where I can easily switch between either the two adapters or a no-adapter form of the transformer?

@BenjaminBossan
Member

The issue with your code is a combination of a few things: partly using the PEFT API incorrectly, and partly the fact that a fresh LoRA adapter is a no-op by default, so it does not affect the result. Below is some code that shows how to use this correctly, I hope it helps to solve your issue:

import torch
from peft import get_peft_model, TaskType, PeftModel, LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.manual_seed(0)
model_id = "facebook/opt-125m"

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer.encode("This is a test", add_special_tokens=False, return_tensors="pt")
model = AutoModelForCausalLM.from_pretrained(model_id).eval()
outputs = model(inputs, output_hidden_states=True)
print("- base model output")
print(torch.linalg.norm(outputs.hidden_states[-1]).item())


# by setting init_lora_weights to False, we ensure that it's not a no-op
config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights=False)
model = get_peft_model(model, config).eval()
outputs = model(inputs, output_hidden_states=True)
print("- peft model output default adapter")
print(torch.linalg.norm(outputs.hidden_states[-1]).item())

# add another adapter
config = LoraConfig(r=32, lora_alpha=32, task_type=TaskType.CAUSAL_LM, init_lora_weights=False)
model.add_adapter("adapter2", config)
model.set_adapter("adapter2")
outputs = model(inputs, output_hidden_states=True)
print("- peft model output adapter 2")
print(torch.linalg.norm(outputs.hidden_states[-1]).item())

print("saving and loading")

model.save_pretrained("/tmp/issue-1523")
del model

model = AutoModelForCausalLM.from_pretrained(model_id).eval()
outputs = model(inputs, output_hidden_states=True)
print("- loaded model output")
print(torch.linalg.norm(outputs.hidden_states[-1]).item())

model = PeftModel.from_pretrained(model, "/tmp/issue-1523")
outputs = model(inputs, output_hidden_states=True)
print("- loaded peft model output")
print(torch.linalg.norm(outputs.hidden_states[-1]).item())

model.load_adapter("/tmp/issue-1523/adapter2", adapter_name="adapter2")
model.set_adapter("adapter2")
outputs = model(inputs, output_hidden_states=True)
print("- loaded peft model output adapter 2")
print(torch.linalg.norm(outputs.hidden_states[-1]).item())

This gives me:

- base model output
52.065834045410156
- peft model output default adapter
57.894798278808594
- peft model output adapter 2
57.4752082824707
saving and loading
- loaded model output
52.065834045410156
- loaded peft model output
57.894798278808594
- loaded peft model output adapter 2
57.4752082824707

As to the difference between get_peft_model and PeftModel.from_pretrained: The former is for creating new PEFT adapters that you have to train afterwards, the latter is for loading already trained adapters.

Regarding enable_adapters(), you should not need to use that. Use set_adapter to activate a specific, loaded adapter.

@AngledLuffa
Author

Ah, great, I didn't see this with previous versions (or maybe I missed it), but the latest version of the peft integration has the ability to call load_adapter specifically with the state dictionary:

model.load_adapter(adapter_name="sentiment", peft_config=peft_config, adapter_state_dict=lora_params)
model.set_adapter("sentiment")
model.eval()

assert model.active_adapters() == ['sentiment']
print_output()

model.disable_adapters()
print_output()

I will have to check that this allows for training to continue if eval() is not called. If so, that might be exactly what we were looking for earlier. Thanks!

I do have two minor complaints about the interface - not sure how easy it would be to fix at this point, seeing as how these modules are both publicly released. In transformers/integrations/peft.py there is

    def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> None:
    def active_adapters(self) -> List[str]:

whereas in peft/peft_model.py it is

    def add_adapter(self, adapter_name: str, peft_config: PeftConfig) -> None:

    @property
    def active_adapters(self) -> list[str]:

so for add_adapter the order of arguments is different, and for active_adapters, one version is a @property returning a list while the other is a method that returns a list.

@AngledLuffa
Author

Another minor interface complaint: after doing this, I get back the original encoding for the text, which is great. That's exactly what we need. However, the following snippet:

model.load_adapter(adapter_name="sentiment", peft_config=peft_config, adapter_state_dict=lora_params)

# do stuff with this adapter

# later...
model.disable_adapters()
print_output()     # gets back the values of the original transformer, excellent
print(model.active_adapters())    # still prints out the adapter we just turned off!

It would be great to have utility methods on the model which indicated if any peft integrations are currently attached (you mentioned checking _hf_peft_config_loaded, but a method sounds a bit cleaner) and a utility method which indicates if adapters are currently enabled or not.

@AngledLuffa
Author

Here's something weird, but possibly known / not important.

When created with get_peft_model, as you recommended for training, the modules saved via get_peft_model_state_dict have names such as base_model.model.encoder.layer.0.attention.output.dense.lora_A.weight

When loaded with bert_model.load_adapter, though, the modules are saved with names such as encoder.layer.0.attention.output.dense.lora_A.weight

At any rate, I can confirm that they both get loaded as expected when put into a new instance of the transformer w/ peft, and furthermore I can train them after reloading a checkpoint, even if I loaded that checkpoint with load_adapter rather than get_peft_model and set_peft_model_state_dict.

Does that sound like a reasonable approach to take? There's no real reason to have different code paths for loading for training or loading for eval, is there?

@AngledLuffa
Author

AngledLuffa commented Apr 7, 2024

When building an optimizer for a transformer model with peft on it, is there a way to only get the optimizer state for the active adapter? Currently, when I call parameters(), I get back all of the parameters. I suppose filtering those by the expected adapter name might work

Alternatively, a way to remove an adapter might be sufficient for my needs. To explain the problem I'm running into:

  • constituency parsing gets more authentic results if you retag with silver POS tags before training
  • so I load a POS model before running the parser
  • loading the transformer can be expensive (not too expensive, I suppose, but still), so I try to reuse the transformer the POS loaded, if possible
  • if that POS model had a peft adapter, when I train the constituency parser, now the optimizer is being created with both the POS and constituency peft parameters
  • later, if I only want to load in the constituency parser for whatever reason (a unit test, perhaps), loading fails, since the optimizer there is created with only the constituency peft parameters while the saved optimizer state also includes the POS peft parameters

@AngledLuffa
Author

Regarding enable_adapters(), you should not need to use that. Use set_adapter to activate a specific, loaded adapter.

This statement confuses me. If I want to use the original model w/o adapters after having loaded an adapter, I think I need to call disable_adapters to make that happen. At that point, if I want to go back to using an adapter, I have to call enable_adapters to restore them. If there's some other process I should use, please let me know.

However, I will say this doesn't work in the case of creating an adapter via get_peft_model, as far as I can tell. Such a model doesn't have self._hf_peft_config_loaded set on either the PeftModel or the original transformer object. This snippet, for example, throws a ValueError at the end. Also, my understanding is get_peft_model is what I'm supposed to call when I want to train a new peft adapter.

In such a case, is it possible to turn off the peft adapter in any way? I really do think this is a bug - there should be a way to call disable_adapters even when _hf_peft_config_loaded is not set.

from transformers import AutoModel
from peft import LoraConfig, get_peft_model

model_name = "google/electra-large-discriminator"
model = AutoModel.from_pretrained(model_name)

peft_config = LoraConfig(inference_mode=False,
                         r=64,
                         target_modules="query,value,output.dense,intermediate.dense".split(","),
                         lora_alpha=128,
                         lora_dropout=0.1,
                         modules_to_save="",
                         bias="none")

# with or without an adapter_name makes no difference
#adapter_name = "sentiment"
#pefted = get_peft_model(model, peft_config, adapter_name=adapter_name)
adapter_name = "default"
pefted = get_peft_model(model, peft_config)

# both of these are False
print(model._hf_peft_config_loaded)
print(pefted._hf_peft_config_loaded)

# can call this - PeftModel.set_adapter doesn't check _hf_peft_config_loaded
pefted.set_adapter(adapter_name)

# can't call this - the transformers set_adapter function *does* check,
# so this throws a ValueError
model.set_adapter(adapter_name)

@BenjaminBossan
Member

I do have two minor complaints about the interface - not sure how easy it would be to fix at this point, seeing as how these modules are both publicly released. In transformers/integrations/peft.py there is

    def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> None:
    def active_adapters(self) -> List[str]:

whereas in peft/peft_model.py it is

    def add_adapter(self, adapter_name: str, peft_config: PeftConfig) -> None:

    @property
    def active_adapters(self) -> list[str]:

so for add_adapter the order of arguments is different, and for active_adapters, one version is a @property returning a list while the other is a method that returns a list.

Indeed it would be better to have more consistency here, but as you mentioned, we can't change that without breaking existing code, so it is what it is.

It would be great to have utility methods on the model which indicated if any peft integrations are currently attached (you mentioned checking _hf_peft_config_loaded, but a method sounds a bit cleaner) and a utility method which indicates if adapters are currently enabled or not.

Yes, indeed. It's not quite as easy as it sounds, because theoretically, some modules belonging to the adapter could be enabled and others disabled (although the current API does not expose this possibility, users could still do this manually). If I have some time on my hands, I'll think of something.
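As a conceptual sketch only, relying on PEFT internals (BaseTunerLayer and its disable_adapters property) that are not part of the public API and may change between versions, a per-module check could look roughly like this:

from peft.tuners.tuners_utils import BaseTunerLayer

def adapter_layer_states(model):
    # True means that layer's adapters are currently enabled
    return {name: not layer.disable_adapters
            for name, layer in model.named_modules()
            if isinstance(layer, BaseTunerLayer)}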

When created with get_peft_model, as you recommended for training, the modules saved via get_peft_model_state_dict have names such as base_model.model.encoder.layer.0.attention.output.dense.lora_A.weight

When loaded with bert_model.load_adapter, though, the modules are saved with names such as encoder.layer.0.attention.output.dense.lora_A.weight

Using get_peft_model creates a PeftModel instance, which has a couple more PEFT-specific features. Therefore, the state_dict is more nested. Using bert_model.load_adapter, the PEFT adapter is directly injected into the model as is, thus there is no further nesting.
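If you ever need to translate between the two naming schemes, a hypothetical helper along these lines would do it (the prefix string is the one observed above):

def strip_peft_prefix(state_dict, prefix="base_model.model."):
    # drop the PeftModel-level nesting from keys produced by get_peft_model_state_dict
    return {key[len(prefix):] if key.startswith(prefix) else key: value
            for key, value in state_dict.items()}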

When building an optimizer for a transformer model with peft on it, is there a way to only get the optimizer state for the active adapter? Currently, when I call parameters(), I get back all of the parameters. I suppose filtering those by the expected adapter name might work

Indeed, filtering by name is the way to go here. Directly after loading, you could also filter by param.requires_grad, but if you change that (e.g. calling .eval()), this method is not reliable anymore.
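A sketch of that name-based filtering; "sentiment" is a hypothetical adapter name and AdamW merely an example optimizer:

import torch

def adapter_parameters(model, adapter_name):
    for name, param in model.named_parameters():
        if f".{adapter_name}." in name:  # e.g. "...lora_A.sentiment.weight"
            yield param

# optimizer = torch.optim.AdamW(adapter_parameters(peft_model, "sentiment"), lr=1e-4)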

This statement confuses me. If I want to use the original model w/o adapters after having loaded an adapter, I think I need to call disable_adapters to make that happen. At that point, if I want to go back to using an adapter, I have to call enable_adapters to restore them. If there's some other process I should use, please let me know.

When you create the model with get_peft_model, you should use the disable_adapter context manager (I know, confusing names); the adapter is re-enabled after exiting the context manager:

peft_model = get_peft_model(...)
# default adapter is active
with peft_model.disable_adapter():
    # inference without adapters
# now, default adapter is active again

@AngledLuffa
Author

It would be great to have utility methods on the model which indicated if any peft integrations are currently attached (you mentioned checking _hf_peft_config_loaded, but a method sounds a bit cleaner) and a utility method which indicates if adapters are currently enabled or not.

Yes, indeed. It's not quite as easy as it sounds, because theoretically, some modules belonging to the adapter could be enabled and others disabled (although the current API does not expose this possibility, users could still do this manually). If I have some time on my hands, I'll think of something.

Ah, I can see how worrying about users can be a problem. You've treated this user quite nicely so far, at least. Thanks in advance for any progress you can make on adding an interface. My current working solution is to set the active adapter with each new batch, just in case a previous batch used a different annotator and therefore a different adapter. You mentioned this might add some overhead, but it doesn't make a noticeable difference in our annotation speed (14s w/ or w/o this change for the EWT dataset, for example).

When building an optimizer for a transformer model with peft on it

Indeed, filtering by name is the way to go here. Directly after loading, you could also filter by param.requires_grad, but if you change that (e.g. calling .eval()), this method is not reliable anymore.

Gotcha, thanks. The approach I've been working on is to make sure the transformer used for the training is not used for anything else, eg a separate copy from the one used for the POS retagging in the constituency or dependency case. I think that should avoid all such problems.

When you create the model with get_peft_model, you should use the disable_adapter context manager (I know, confusing names); the adapter is re-enabled after exiting the context manager:

Thanks. Is there a downside to setting the model._hf_peft_config_loaded in the get_peft_model version? It would save us some effort to only have one code path for this version. (Although again, I'm trying to avoid mixing multiple adapters in the case where we're training.)

@BenjaminBossan
Member

Is there a downside to setting the model._hf_peft_config_loaded in the get_peft_model version? It would save us some effort to only have one code path for this version.

Not off the top of my head, but I haven't really "mixed" the transformers use of PEFT with the PEFT use.

@AngledLuffa
Author

Thanks again for all your help! I think we're good with using multiple adapters on the same transformer now. If there are any updates to the interface that include shortcuts for checking whether or not any adapters are actually active (you mentioned it being a recursive call for now), that would make things a bit faster for us. Also, if there's ever a way to use a single adapter on a shallow copy of the transformer (eg, no deep copy of the weights), that would also simplify our usage quite a bit.

@BenjaminBossan
Member

If there are any updates to the interface that include shortcuts for checking whether or not any adapters are actually active (you mentioned it being a recursive call for now), that would make things a bit faster for us.

I'm currently working on #1663, which goes somewhat in this direction, but I'm not sure if it 100% fits your use case. Maybe you can take a look.

@AngledLuffa
Author

That does look relevant in terms of making the switching between adapters faster / simpler. Thanks!


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
