
Issues when switching between multiple LoRA adapters #1802

Closed

JhonDan1999 opened this issue May 26, 2024 · 8 comments

@JhonDan1999

I am experiencing the same issues when switching between multiple adapters, despite following the documentation and checking #1315 (@BenjaminBossan). I do not observe any change in the model’s behavior when switching between the adapters.

base_model.add_adapter(PeftConfig.from_pretrained(peft_model_output_dir1), adapter_name="adapter_1")
base_model.add_adapter(PeftConfig.from_pretrained(peft_model_output_dir2), adapter_name="adapter_2")

When I switch between the adapters using the set_adapter method, there is no observable change in the model’s behavior. The outputs remain the same, regardless of which adapter is active.

I suspect that the set_adapter method does not actually activate the specified adapter correctly. Instead, I notice a change in behavior only when I merge the adapter with the base model.
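By merging I mean roughly the following (a sketch with hypothetical paths, using PEFT's merge_and_unload):

from peft import PeftModel
from transformers import AutoModelForCausalLM

# hypothetical base model id and adapter path
base_model = AutoModelForCausalLM.from_pretrained("base-model-id")
model = PeftModel.from_pretrained(base_model, "path/to/adapter_1")

# fold the LoRA weights into the base weights; only after this do I see a change in the outputs
merged_model = model.merge_and_unload()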

The documentation does not mention the need to perform a merge when switching adapters. Additionally, the methods add_adapter, set_adapter, and enable_adapters do not appear to have any effect.

Please provide clarification on how to correctly switch between adapters.

@BenjaminBossan
Member

This should not happen. Could you please share some code to reproduce this error?

If you're on the latest PEFT version, you can also run model.get_model_status() and model.get_layer_status() to help troubleshoot the issue. You could share those outputs as well.
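For example, roughly (assuming model is a PeftModel on a recent PEFT release):

# assuming `model` is a PeftModel on a recent PEFT release
model_status = model.get_model_status()   # one summary entry for the whole model
layer_status = model.get_layer_status()   # one entry per adapted layer
print(model_status)
for status in layer_status:
    print(status)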

@JhonDan1999
Author

Thank you for the prompt response.

Here is the code and output:

[Screenshot 2024-05-27 at 4:48:11 PM: code and output]

The LoRAs appear in the model when I print the model layers, like this:

[Screenshot 2024-05-27 at 4:50:51 PM: model layers showing the LoRA modules]

When I run model.get_layer_status(), it gives me this error:

[Screenshot 2024-05-27 at 4:53:29 PM: error traceback from model.get_layer_status()]

@BenjaminBossan
Member

Hey, could you please paste the code as text, otherwise I'd have to copy everything by hand if I want to reproduce :) Also, if you call base_model.add_adapter(...), you're adding a fresh, untrained LoRA adapter, which by default is a no-op (unless init_lora_weights=False in the config). Are you aware of that? Did you intend to call load_adapter instead?
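To illustrate the difference, a rough sketch (hypothetical model id and adapter path, using the transformers-level API):

from peft import LoraConfig
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("base-model-id")  # hypothetical id

# add_adapter attaches a fresh, randomly initialized LoRA; with the default
# init_lora_weights it starts out as a no-op at inference time
base_model.add_adapter(LoraConfig(r=8), adapter_name="fresh_adapter")

# load_adapter loads trained LoRA weights from a local path or the Hub
base_model.load_adapter("path/to/trained_adapter", adapter_name="trained_adapter")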

@JhonDan1999
Author

No, I was not aware of that ("adding a fresh, untrained LoRA adapter"). I have now changed the code based on this information, but I get the same behavior, as you can see:

[Screenshot 2024-05-27 at 5:49:21 PM: outputs after the change]

Here is the code as text:

import torch
from peft import PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

config_lora1 = PeftConfig.from_pretrained(peft_adapter1_output_dir, init_lora_weights=False)
config_lora2 = PeftConfig.from_pretrained(peft_adapter2_output_dir, init_lora_weights=False)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

base_model = AutoModelForCausalLM.from_pretrained(
    config_lora1.base_model_name_or_path,
    quantization_config=bnb_config,
    device_map='auto',
    torch_dtype=torch.bfloat16,
    use_cache=False,
)
tokenizer = AutoTokenizer.from_pretrained(config_lora1.base_model_name_or_path, trust_remote_code=True)

input_prompt = "Hello, how are you"
inputs = tokenizer(input_prompt, return_tensors='pt').to("cuda")

# Generate text with base model
output_base = base_model.generate(**inputs, max_new_tokens=50, pad_token_id=tokenizer.eos_token_id)
print("Base Model Output:")
print(tokenizer.decode(output_base[0]))

base_model.load_adapter(peft_adapter1_output_dir, adapter_name="adapter_1")
base_model.load_adapter(peft_adapter2_output_dir, adapter_name="adapter_2")

# base_model.disable_adapters()
# Generate text with adapter 1
base_model.set_adapter("adapter_1")
output_adapter1 = base_model.generate(**inputs, max_new_tokens=50, pad_token_id=tokenizer.eos_token_id)
print("\nAdapter 1 Output:")
print(tokenizer.decode(output_adapter1[0], skip_special_tokens=True))

# Generate text with adapter 2
base_model.set_adapter("adapter_2")
output_adapter2 = base_model.generate(**inputs, max_new_tokens=50, pad_token_id=tokenizer.eos_token_id)
print("\nAdapter 2 Output:")
print(tokenizer.decode(output_adapter2[0], skip_special_tokens=True))

@BenjaminBossan
Member

Okay, thanks for trying. Since you use some private adapters, I can't really reproduce, unless you can share your adapters.

One thing to try out would be to use PEFT to load the adapter, not transformers:

from peft import PeftModel

# instead of base_model.load_adapter
model = PeftModel.from_pretrained(base_model, peft_adapter1_output_dir, adapter_name="adapter_1")
model.load_adapter(peft_adapter2_output_dir, adapter_name="adapter_2")

Regarding the error with get_layer_status: You could do:

from peft import get_model_status, get_layer_status

# after base_model.load_adapter
get_layer_status(base_model)
get_model_status(base_model)

and paste the results here.

@JhonDan1999
Author

JhonDan1999 commented May 27, 2024

Thank you for your prompt responses.

I tried what you mentioned; here is what I got (the same behaviour):

Base Model Output:
<s> Hello, how are you? I hope you’re having a great day. Today I want

Adapter 1 Output:
Hello, how are you? I hope you’re having a great day. Today I want

Adapter 2 Output:
Hello, how are you? I hope you’re having a great day. Today I want

For this part:

from peft import PeftModel

# instead of base_model.load_adapter
model = PeftModel.from_pretrained(base_model, peft_adapter1_output_dir, adapter_name="adapter_1")
model.load_adapter(peft_adapter2_output_dir, adapter_name="adapter_2")

from peft import get_model_status, get_layer_status

# after base_model.load_adapter
get_layer_status(model)
get_model_status(model)

It gave me this output:

TunerModelStatus(base_model_type='MistralForCausalLM', adapter_model_type='LoraModel', peft_types={'adapter_1': 'LORA', 'adapter_2': 'LORA'}, trainable_params=167772160, total_params=7745048576, num_adapter_layers=224, enabled=True, active_adapters=['adapter_1'], merged_adapters=[], requires_grad={'adapter_1': True, 'adapter_2': False}, available_adapters=['adapter_1', 'adapter_2'])

Setting aside my private adapters: does this approach work for you with any adapter you can access?

@BenjaminBossan
Member

Setting aside my private adapters: does this approach work for you with any adapter you can access?

Yes, it's working; here is a simple test:

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.manual_seed(0)
model_id = "facebook/opt-125m"

model = AutoModelForCausalLM.from_pretrained(model_id)
input = torch.tensor([[1, 2, 3, 4, 5]])
output_base = model(input).logits
print("Base model output:")
print(output_base[0, :3, :5])

# create a PEFT model with 2 adapters and save it
config = LoraConfig(r=8, init_lora_weights=False)
model = get_peft_model(model, config, adapter_name="adapter1")
model.add_adapter("adapter2", config)
model.save_pretrained("/tmp/issue-1802")

# load the model again
del model
model = AutoModelForCausalLM.from_pretrained("/tmp/issue-1802/adapter1", adapter_name="adapter1")
model.load_adapter("/tmp/issue-1802/adapter2", "adapter2")

model.set_adapter("adapter1")
output_adapter1 = model(input).logits
print("Model output after loading adapter1:")
print(output_adapter1[0, :3, :5])

model.set_adapter("adapter2")
output_adapter2 = model(input).logits
print("Model output after setting adapter2:")
print(output_adapter2[0, :3, :5])

This prints:

Base model output:
tensor([[-3.9463, -3.9443,  3.2428, -3.9522,  5.4978],
        [-3.7805, -3.7759,  5.7177, -3.7743,  4.9581],
        [ 2.1029,  2.1002,  1.9693,  2.0843,  3.4022]],
       grad_fn=<SliceBackward0>)
Model output after loading adapter1:
tensor([[-1.4193, -1.4301,  2.9313, -1.4266,  6.9664],
        [-4.2108, -4.2206,  3.1630, -4.2111,  5.8416],
        [-3.3278, -3.3351,  0.0213, -3.3350,  5.0806]],
       grad_fn=<SliceBackward0>)
Model output after setting adapter2:
tensor([[-3.8967, -3.8936,  5.2991, -3.9072,  4.7403],
        [-5.8532, -5.8452,  7.2219, -5.8641,  4.4519],
        [-4.6259, -4.6206,  4.6002, -4.6405,  5.2777]],
       grad_fn=<SliceBackward0>)

Note that when you want to compare model outputs, looking at the generated tokens is not reliable. When the difference in logits is small, the generated tokens can be the same, even if the outputs are different. Therefore, it's better to check the logits directly, as in my example.
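For example, continuing from the snippet above (where output_base, output_adapter1, and output_adapter2 are the logit tensors), something like:

import torch

# maximum absolute difference between the logits; a clearly nonzero value means
# the adapter does change the model, even if the generated tokens happen to match
print((output_base - output_adapter1).abs().max())
print((output_adapter1 - output_adapter2).abs().max())

# or as a boolean check with a tolerance
print(torch.allclose(output_adapter1, output_adapter2, atol=1e-6))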


github-actions bot commented Jul 6, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
