[utils] add merge_lora utility function #227

Merged
merged 16 commits on Mar 30, 2023
Changes from 15 commits
47 changes: 45 additions & 2 deletions src/peft/tuners/lora.py
@@ -82,6 +82,10 @@ class LoraConfig(PeftConfig):
"the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved."
},
)
init_lora_weights: bool = field(
default=True,
metadata={"help": "Whether to initialize the weights of the Lora layers."},
)

def __post_init__(self):
self.peft_type = PeftType.LORA
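
For context, `init_lora_weights` is a new `LoraConfig` field that gets forwarded to every LoRA layer through `_find_and_replace` (next hunk). A minimal, hypothetical sketch of how a caller might set it; the other field values are placeholders, only `init_lora_weights` is the flag added in this PR:

```python
from peft import LoraConfig

# Placeholder values for illustration; init_lora_weights is the new flag.
# Leaving it at its default (True) keeps the previous behaviour of calling
# reset_parameters() on each LoRA layer.
config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.0,
    init_lora_weights=False,  # skip reset_parameters(); keep nn.Linear's default init
)
```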
@@ -135,6 +139,7 @@ def _find_and_replace(self):
"fan_in_fan_out": self.peft_config.fan_in_fan_out,
"merge_weights": (self.peft_config.merge_weights or self.peft_config.inference_mode)
and not is_hf_device_map_available,
"init_lora_weights": self.peft_config.init_lora_weights,
}
key_list = [key for key, _ in self.model.named_modules()]
for key in key_list:
@@ -233,6 +238,37 @@ def enable_adapter_layers(self):
def disable_adapter_layers(self):
self._set_adapter_layers(enabled=False)

def merge_and_unload(self):
Contributor:

This would not work when the model is loaded in 8bit. An assertion error for the same would be helpful.

Collaborator (Author):

Thanks for the pointer! added a check for that

r"""
This method merges the LoRA layers into the base model. This is needed if someone wants to use the base model
as a standalone model.
"""
if self.config.model_type == "gpt2":
raise ValueError("GPT2 models are not supported for merging LORA layers")

if getattr(self.model, "is_loaded_in_8bit", False):
raise ValueError("Cannot merge LORA layers when the model is loaded in 8-bit mode")

key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
for key in key_list:
parent, target, target_name = self._get_submodules(key)
if isinstance(target, LoraLayer):
bias = target.bias is not None
new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias)

# manually merge if not merged
if not target.merged:
# merge weights per: https://arxiv.org/pdf/2106.09685.pdf / page 4
if target.r > 0:
target.weight.data += (
transpose(target.lora_B.weight @ target.lora_A.weight, target.fan_in_fan_out)
* target.scaling
)
target.merged = True

self._replace_module(parent, target_name, new_module, target)
return self.model


# Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# and modified to work with PyTorch FSDP
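
For reference, the merge above folds the low-rank update into the frozen weight as `W_merged = W + (lora_alpha / r) * (B @ A)` (transposed when `fan_in_fan_out` is set), per the LoRA paper. Below is a hedged usage sketch that mirrors the call pattern used in the tests of this PR; the model id is one of the tiny test checkpoints and the output directory is a placeholder:

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Any LoRA-supported causal LM works; GPT-2 is rejected by the new check above.
base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
config = LoraConfig(r=8, lora_alpha=16, init_lora_weights=False)
peft_model = get_peft_model(base, config)

# ... fine-tuning of the LoRA weights would normally happen here ...

# Fold the LoRA weights into the base weights and get back a plain transformers
# model that can be saved and reloaded without peft installed.
merged = peft_model.merge_and_unload()
merged.save_pretrained("tiny-opt-merged")  # placeholder output path
```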
@@ -297,6 +333,8 @@ def __init__(
merge_weights: bool = True,
**kwargs,
):
init_lora_weights = kwargs.pop("init_lora_weights", True)

nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)

@@ -308,7 +346,8 @@ def __init__(
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
if init_lora_weights:
self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.T

@@ -375,6 +414,8 @@ def __init__(
merge_weights: bool = True,
**kwargs,
):
init_lora_weights = kwargs.pop("init_lora_weights", True)

nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
if out_features % len(enable_lora) != 0:
@@ -398,7 +439,9 @@ def __init__(
self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view(len(enable_lora), -1)
self.lora_ind[enable_lora, :] = True
self.lora_ind = self.lora_ind.view(-1)
self.reset_parameters()

if init_lora_weights:
self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.T

83 changes: 68 additions & 15 deletions tests/test_peft_model.py
@@ -30,17 +30,19 @@
from .testing_common import PeftTestConfigManager


# This has to be in the order: model_id, lora_kwargs, prefix_tuning_kwargs, prompt_encoder_kwargs, prompt_tuning_kwargs
PEFT_DECODER_MODELS_TO_TEST = [
# ("HuggingFaceM4/tiny-random-LlamaForCausalLM", {}, {}, {}, {}), wait until the next `transformers` release
("hf-internal-testing/tiny-random-OPTForCausalLM", {}, {}, {}, {}),
("hf-internal-testing/tiny-random-GPTNeoXForCausalLM", {}, {}, {}, {}),
("hf-internal-testing/tiny-random-GPT2LMHeadModel", {}, {}, {}, {}),
("hf-internal-testing/tiny-random-BloomForCausalLM", {}, {}, {}, {}),
("hf-internal-testing/tiny-random-gpt_neo", {}, {}, {}, {}),
("hf-internal-testing/tiny-random-GPTJForCausalLM", {}, {}, {}, {}),
"hf-internal-testing/tiny-random-OPTForCausalLM",
"hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
"hf-internal-testing/tiny-random-GPT2LMHeadModel",
"hf-internal-testing/tiny-random-BloomForCausalLM",
"hf-internal-testing/tiny-random-gpt_neo",
"hf-internal-testing/tiny-random-GPTJForCausalLM",
]

FULL_GRID = {
"model_ids": PEFT_DECODER_MODELS_TO_TEST,
}


class PeftTestMixin:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -54,10 +56,6 @@ class PeftModelTester(unittest.TestCase, PeftTestMixin):
We use parametrized.expand for debugging purposes to test each model individually.
"""

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_DECODER_MODELS_TO_TEST))
def test_attributes_parametrized(self, test_name, model_id, config_cls, config_kwargs):
self._test_model_attr(model_id, config_cls, config_kwargs)

def _test_model_attr(self, model_id, config_cls, config_kwargs):
model = AutoModelForCausalLM.from_pretrained(model_id)
config = config_cls(
@@ -70,6 +68,10 @@ def _test_model_attr(self, model_id, config_cls, config_kwargs):
self.assertTrue(hasattr(model, "from_pretrained"))
self.assertTrue(hasattr(model, "push_to_hub"))

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_attributes_parametrized(self, test_name, model_id, config_cls, config_kwargs):
self._test_model_attr(model_id, config_cls, config_kwargs)

def _test_prepare_for_training(self, model_id, config_cls, config_kwargs):
model = AutoModelForCausalLM.from_pretrained(model_id).to(self.torch_device)
config = config_cls(
@@ -111,7 +113,7 @@ def make_inputs_require_grad(module, input, output):

self.assertTrue(dummy_output.requires_grad)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_DECODER_MODELS_TO_TEST))
@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls, config_kwargs):
self._test_prepare_for_training(model_id, config_cls, config_kwargs)

@@ -157,10 +159,61 @@ def _test_save_pretrained(self, model_id, config_cls, config_kwargs):
# check if `config.json` is not present
self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_DECODER_MODELS_TO_TEST))
@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs):
self._test_save_pretrained(model_id, config_cls, config_kwargs)

def _test_merge_layers(self, model_id, config_cls, config_kwargs):
model = AutoModelForCausalLM.from_pretrained(model_id)
config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
)
model = get_peft_model(model, config)
model = model.to(self.torch_device)

if config.peft_type != "LORA":
with self.assertRaises(AttributeError):
model = model.merge_and_unload()
elif model.config.model_type == "gpt2":
with self.assertRaises(ValueError):
model = model.merge_and_unload()
else:
dummy_input = torch.LongTensor([[1, 2, 3, 2, 1]]).to(self.torch_device)
model.eval()
logits_lora = model(dummy_input)[0]

model = model.merge_and_unload()

logits_merged = model(dummy_input)[0]

transformers_model = AutoModelForCausalLM.from_pretrained(model_id).to(self.torch_device)

logits_transformers = transformers_model(dummy_input)[0]

self.assertTrue(torch.allclose(logits_lora, logits_merged, atol=1e-3, rtol=1e-3))
self.assertFalse(torch.allclose(logits_merged, logits_transformers, atol=1e-3, rtol=1e-3))

with tempfile.TemporaryDirectory() as tmp_dirname:
model.save_pretrained(tmp_dirname)

model_from_pretrained = AutoModelForCausalLM.from_pretrained(tmp_dirname).to(self.torch_device)

logits_merged_from_pretrained = model_from_pretrained(dummy_input)[0]

self.assertTrue(torch.allclose(logits_merged, logits_merged_from_pretrained, atol=1e-3, rtol=1e-3))

@parameterized.expand(
PeftTestConfigManager.get_grid_parameters(
{
"model_ids": PEFT_DECODER_MODELS_TO_TEST,
"lora_kwargs": {"init_lora_weights": [False], "merge_weights": [False, True]},
},
)
)
def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):
self._test_merge_layers(model_id, config_cls, config_kwargs)
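
A note on the grid above: with the default initialization (in loralib-style LoRA, `reset_parameters()` zeroes `lora_B`), a freshly created adapter contributes `B @ A == 0`, so the merged model would be numerically identical to the base model and the `assertFalse(torch.allclose(logits_merged, logits_transformers, ...))` check would fail. Passing `init_lora_weights=False` keeps `nn.Linear`'s default (non-zero) initialization, so the adapter actually perturbs the outputs. A small standalone sketch of that difference; the zero-init of `lora_B` is assumed from loralib and is not shown in this diff:

```python
import torch
import torch.nn as nn

r, in_features, out_features = 4, 16, 16
lora_A = nn.Linear(in_features, r, bias=False)
lora_B = nn.Linear(r, out_features, bias=False)

# Default LoRA init (assumed, as in loralib): lora_B is zeroed, so the update is a no-op.
nn.init.zeros_(lora_B.weight)
assert torch.count_nonzero(lora_B.weight @ lora_A.weight) == 0

# init_lora_weights=False skips reset_parameters(), leaving nn.Linear's
# non-zero default init, so merging visibly changes the base weights.
nn.init.kaiming_uniform_(lora_B.weight, a=5 ** 0.5)
assert torch.count_nonzero(lora_B.weight @ lora_A.weight) > 0
```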

def _test_generate(self, model_id, config_cls, config_kwargs):
model = AutoModelForCausalLM.from_pretrained(model_id)
config = config_cls(
Expand All @@ -180,6 +233,6 @@ def _test_generate(self, model_id, config_cls, config_kwargs):
# check if `generate` raises an error if no positional arguments are passed
_ = model.generate(input_ids, attention_mask=attention_mask)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_DECODER_MODELS_TO_TEST))
@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_generate(self, test_name, model_id, config_cls, config_kwargs):
self._test_generate(model_id, config_cls, config_kwargs)
57 changes: 35 additions & 22 deletions tests/testing_common.py
@@ -71,34 +71,47 @@ def __getitem__(self, key, *args, **kwargs):

return super().__getitem__(key, *args, **kwargs)

def get_grid_parameters(self, model_list):
def get_grid_parameters(self, grid_parameters, filter_params_func=None):
r"""
Returns a list of all possible combinations of the parameters in the config classes.

Args:
grid_parameters (`dict`):
A dictionary containing the parameters to be tested. There should be at least the key "model_ids", which
contains a list of model ids to be tested. The other keys should be the name of the config class
suffixed with "_kwargs", and the value should be a dictionary mapping each parameter to the list of values
to be tested for that config class.
filter_params_func (`callable`, `optional`):
A function that takes a list of tuples and returns a list of tuples. It is used to filter out
tests that need to be skipped, for example.

Returns:
generated_tests (`list`):
A list of tuples containing the name of the test, the model id, the config class and the config class
kwargs.
"""
grid_parameters = []
for model_tuple in model_list:
model_id, lora_kwargs, prefix_tuning_kwargs, prompt_encoder_kwargs, prompt_tuning_kwargs = model_tuple
generated_tests = []
model_list = grid_parameters["model_ids"]

for model_id in model_list:
for key, value in self.items():
peft_method = value[1].copy()
if key == "lora":
# update value[1] if necessary
if lora_kwargs is not None:
peft_method.update(lora_kwargs)
elif key == "prefix_tuning":
# update value[1] if necessary
if prefix_tuning_kwargs is not None:
peft_method.update(prefix_tuning_kwargs)
elif key == "prompt_encoder":
# update value[1] if necessary
if prompt_encoder_kwargs is not None:
peft_method.update(prompt_encoder_kwargs)
if "{}_kwargs".format(key) in grid_parameters:
peft_configs = []
current_peft_config = value[1].copy()
for current_key, current_value in grid_parameters[f"{key}_kwargs"].items():
for kwarg in current_value:
current_peft_config.update({current_key: kwarg})
peft_configs.append(current_peft_config)
else:
# update value[1] if necessary
if prompt_tuning_kwargs is not None:
peft_method.update(prompt_tuning_kwargs)
grid_parameters.append((f"test_{model_id}_{key}", model_id, value[0], peft_method))
peft_configs = [value[1].copy()]

for peft_config in peft_configs:
generated_tests.append((f"test_{model_id}_{key}", model_id, value[0], peft_config))

if filter_params_func is not None:
generated_tests = filter_params_func(generated_tests)

return grid_parameters
return generated_tests


PeftTestConfigManager = ClassInstantier(CLASSES_MAPPING)
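
A hedged sketch of how the reworked `get_grid_parameters` might be called directly; the grid mirrors the `test_merge_layers` parametrization in `tests/test_peft_model.py` above, and `skip_gpt2` is a hypothetical `filter_params_func`:

```python
# Hypothetical grid; the keys follow the docstring above ("model_ids" plus
# "<config>_kwargs" entries whose values list the parameter values to sweep).
# PeftTestConfigManager is the instance defined just above.
grid = {
    "model_ids": ["hf-internal-testing/tiny-random-OPTForCausalLM"],
    "lora_kwargs": {"init_lora_weights": [False], "merge_weights": [False, True]},
}

def skip_gpt2(tests):
    # Example filter_params_func: drop parametrizations whose model id is a GPT-2 checkpoint.
    return [t for t in tests if "GPT2" not in t[1]]

cases = PeftTestConfigManager.get_grid_parameters(grid, filter_params_func=skip_gpt2)
for test_name, model_id, config_cls, config_kwargs in cases:
    print(test_name, config_cls.__name__, config_kwargs)
```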