
Significant difference in size in lora artifacts between different versions of trl and peft #1287

Closed
debraj135 opened this issue Jan 30, 2024 · 13 comments

@debraj135

For the same peft config and the same base model, my saved LoRA adapter files seem to have changed in size by orders of magnitude compared to what I saw previously:

  1. Previously: a 79 MB adapter_model.bin
    trl '0.7.3.dev0'
    peft '0.5.0'

  2. Currently: a 1 GB adapter safetensors file
    trl '0.7.10'
    peft '0.7.1'

Can anyone please help me understand the difference in behavior?

@debraj135
Author

To add some more context: for point 2 above, passing save_safetensors=False means it no longer saves a safetensors file and instead saves
README.md adapter_config.json adapter_model.bin optimizer.pt rng_state_0.pth rng_state_1.pth rng_state_2.pth rng_state_3.pth scheduler.pt special_tokens_map.json tokenizer.json tokenizer_config.json trainer_state.json training_args.bin
However, the adapter_model.bin is 1.1 GB in size.
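Passing the flag looks roughly like this (a minimal sketch, assuming the run is configured through transformers' TrainingArguments; output_dir is just a placeholder):

from transformers import TrainingArguments

# Sketch: disable safetensors serialization so the trainer writes
# adapter_model.bin instead of adapter_model.safetensors.
training_args = TrainingArguments(
    output_dir="./output",  # placeholder
    save_safetensors=False,
)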

@younesbelkada
Contributor

@debraj135 hmm that's strange: I ran:

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model 

peft_config = LoraConfig(
    r=8
)

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
peft_model = get_peft_model(model, peft_config)
peft_model.save_pretrained("./test-save")

and I can confirm the saved folder is ~1.2 MB, which is expected. Can you try with trl '0.7.3.dev0' and peft==0.8.2 to confirm the bug is from TRL?

@younesbelkada
Contributor

I also ran:

from peft import LoraConfig 
from trl import SFTTrainer
from datasets import load_dataset

peft_config = LoraConfig(
    r=8
)

dataset = load_dataset("imdb", split="train")

trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
    peft_config=peft_config
)
trainer.save_model("test-save")

on TRL main + latest peft (0.8.2) and still do not observe any strange behaviour (the saved model is ~6 MB, which is expected for a 350m model).

@debraj135
Author

debraj135 commented Jan 31, 2024

Thank you for your response @younesbelkada

Here is my skeleton code for saving the adapters

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_8bit=False, load_in_4bit=True
)
device_map = {"": 0}
torch_dtype = torch.bfloat16


model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    quantization_config=quantization_config,
    device_map=device_map,
    trust_remote_code=False,
    torch_dtype=torch_dtype,
    use_auth_token=True,
)

peft_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj",
                    "embed_tokens", "lm_head"],
    task_type="CAUSAL_LM",
)

peft_model = get_peft_model(model, peft_config)
peft_model.save_pretrained("./test-save", safe_serialization=False)

With peft==0.5.0 the adapter_model.bin is 78 MB.
With peft==0.8.1 the adapter_model.bin is 578 MB.
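For reference, I check the size on disk with something like this (just a sketch; the path matches the save_pretrained call above):

from pathlib import Path

# Print the size of the saved adapter weights in MB.
adapter_path = Path("./test-save/adapter_model.bin")
print(f"{adapter_path}: {adapter_path.stat().st_size / 1024**2:.1f} MB")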

Please let me know if you also see the same. Seems to be a peft issue?

Do you have any thoughts about why this seems to be occurring?

@debraj135
Author

debraj135 commented Feb 1, 2024

If you want me to try all the releases between peft 0.5.0 and 0.8.1, please let me know.

I'm not familiar with the changes over time, so I'm not sure whether the above behavior is unexpected.

Update: Thursday, Feb 1st

It appears that this size jump occurs when going from peft 0.6.2 to 0.7.0.

@debraj135
Author

@younesbelkada please do let me know what you think about my observation.

@younesbelkada
Contributor

@debraj135 it is definitely a peft issue. I suspect we're hitting a corner case with the recent refactor of peft layers, since we now store base layers as well. Do you observe the same behaviour if you remove "embed_tokens" and "lm_head" from the LoraConfig?
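Something along these lines (a sketch based on the config you posted, with the two embedding-related entries dropped):

from peft import LoraConfig

# Same LoRA setup, but only the attention/MLP projections are targeted.
peft_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)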

@BenjaminBossan
Member

I wonder if this is related to the recent addition of saving the embedding weights automatically when we detect that they're being trained. Those can be quite big. If possible @debraj135, could you check the state_dict from before vs. after and report the keys that differ between the two?

@debraj135
Author

@younesbelkada @BenjaminBossan I will report back later today once I try what both of you have requested.

@BenjaminBossan can you help me understand the impact of:

the recent addition of saving the embedding weights automatically when we detect that they're being trained

Does this mean that it is not correct to include "embed_tokens", "lm_head" in the target_modules list in peft<0.7.0 ?

@debraj135
Author

Here is what I see

  1. @younesbelkada the same behavior is not observed after removing "embed_tokens", "lm_head"

  2. @BenjaminBossan I used get_peft_model_state_dict to obtain the state_dict associated with LoRA and printed it for only those keys containing "embed_tokens" or "lm_head" (a sketch of this is included after the listings below).

With peft==0.6.2

base_model.model.model.embed_tokens.lora_embedding_A torch.Size([8, 32000])
base_model.model.model.embed_tokens.lora_embedding_B torch.Size([4096, 8])
base_model.model.lm_head.lora_A.weight torch.Size([8, 4096])
base_model.model.lm_head.lora_B.weight torch.Size([32000, 8])

With peft==0.7.0

base_model.model.model.embed_tokens.lora_embedding_A torch.Size([8, 32000])
base_model.model.model.embed_tokens.lora_embedding_B torch.Size([4096, 8])
base_model.model.lm_head.lora_A.weight torch.Size([8, 4096])
base_model.model.lm_head.lora_B.weight torch.Size([32000, 8])
base_model.model.model.embed_tokens.base_layer.weight torch.Size([32000, 4096])
base_model.model.lm_head.base_layer.weight torch.Size([32000, 4096])
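For reference, this is roughly the inspection I ran (a sketch; peft_model comes from the skeleton code above):

from peft import get_peft_model_state_dict

# Filter the adapter state_dict down to the embedding/output keys and print shapes.
state_dict = get_peft_model_state_dict(peft_model)
for key, tensor in state_dict.items():
    if "embed_tokens" in key or "lm_head" in key:
        print(key, tensor.shape)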

I think my question still remains unanswered after these observations:

Does this mean that it is not correct to include "embed_tokens", "lm_head" in the target_modules list in peft<0.7.0 ?

@debraj135
Author

One of my key concerns here is: with peft<=0.6.2, after I finish training and then save the peft model as shown here, is it possible that when I reload it for serving there are layers/weights missing (specifically ones related to embeddings), so that the reloaded model is different from the one trained?

@pacman100
Contributor

Hello @debraj135,

The default behaviour in PEFT is now to save the embedding layers of the base model when they are part of the target modules or when the embedding layers have been resized. This is because the most common scenario in which embedding layers are resized or targeted is when new special tokens are added to the vocab. These embedding layers then have the new token weights initialized randomly, and the LoRA weights are tuned with respect to that specific initialization. Hence, it is important to save that specific initialization of the embedding layers; otherwise, during inference, the resized embedding layers can be initialized with different random weights, the LoRA weights will no longer be in sync, and this will lead to undefined behaviour.

If you don't want to save the embedding layers, pass save_embedding_layers=False to save_pretrained/get_peft_model_state_dict.
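For example (a sketch reusing peft_model and the save path from the skeleton code above):

# Save the LoRA adapter without bundling the base embed_tokens / lm_head weights.
peft_model.save_pretrained(
    "./test-save",
    safe_serialization=False,
    save_embedding_layers=False,
)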


github-actions bot commented Mar 2, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
