From 4289364941ffab7e964f26b2d0cb45f29a92ecfb Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 29 Mar 2023 08:11:09 +0000 Subject: [PATCH 01/12] add merge_lora utility function --- src/peft/__init__.py | 1 + src/peft/utils/__init__.py | 1 + src/peft/utils/other.py | 31 +++++++++++++++++++++++ tests/test_peft_model.py | 51 ++++++++++++++++++++++++++++++++++---- 4 files changed, 79 insertions(+), 5 deletions(-) diff --git a/src/peft/__init__.py b/src/peft/__init__.py index e141347f1e..0e61a61c82 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -50,4 +50,5 @@ prepare_model_for_int8_training, set_peft_model_state_dict, shift_tokens_right, + merge_lora, ) diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py index dd949c0387..decfdc4281 100644 --- a/src/peft/utils/__init__.py +++ b/src/peft/utils/__init__.py @@ -26,5 +26,6 @@ prepare_model_for_int8_training, shift_tokens_right, transpose, + merge_lora, ) from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index 132b033484..69711dd8e4 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -15,6 +15,8 @@ import torch +import peft + # needed for prefix-tuning of bloom model def bloom_model_postprocess_past_key_value(past_key_values): @@ -157,3 +159,32 @@ def lambda_policy_fn(module): def transpose(weight, fan_in_fan_out): return weight.T if fan_in_fan_out else weight + + +def merge_lora(model): + r""" + This method merges the LoRa layers into the base model. This is needed if someone wants to use the base model as a + standalone model. + + Args: + model, (`PeftModel`): + The input PeftModel that needs to be merged + """ + if not isinstance(model, peft.PeftModel): + raise ValueError("The input model should be a PeftModel") + + if model.peft_config.peft_type != "LORA": + raise ValueError("The input model should be a LORA model") + + model.eval() + + key_list = [key for key, _ in model.base_model.model.named_modules() if "lora" not in key] + for key in key_list: + parent, target, target_name = model.base_model._get_submodules(key) + if isinstance(target, peft.tuners.lora.Linear): + bias = target.bias is not None + new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias) + model.base_model._replace_module(parent, target_name, new_module, target) + + model = model.base_model.model + return model diff --git a/tests/test_peft_model.py b/tests/test_peft_model.py index 2ca4895dd6..3310e288bb 100644 --- a/tests/test_peft_model.py +++ b/tests/test_peft_model.py @@ -24,6 +24,7 @@ PeftModel, get_peft_model, get_peft_model_state_dict, + merge_lora, prepare_model_for_int8_training, ) @@ -32,7 +33,7 @@ # This has to be in the order: model_id, lora_kwargs, prefix_tuning_kwargs, prompt_encoder_kwargs, prompt_tuning_kwargs PEFT_MODELS_TO_TEST = [ - ("hf-internal-testing/tiny-random-OPTForCausalLM", {"target_modules": ["q_proj", "v_proj"]}, {}, {}, {}), + ("hf-internal-testing/tiny-random-OPTForCausalLM", {"bias": "all"}, {}, {}, {}), ] @@ -48,10 +49,6 @@ class PeftModelTester(unittest.TestCase, PeftTestMixin): We use parametrized.expand for debugging purposes to test each model individually. 
""" - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_MODELS_TO_TEST)) - def test_attributes_parametrized(self, test_name, model_id, config_cls, config_kwargs): - self._test_model_attr(model_id, config_cls, config_kwargs) - def _test_model_attr(self, model_id, config_cls, config_kwargs): model = AutoModelForCausalLM.from_pretrained(model_id) config = config_cls( @@ -64,6 +61,10 @@ def _test_model_attr(self, model_id, config_cls, config_kwargs): self.assertTrue(hasattr(model, "from_pretrained")) self.assertTrue(hasattr(model, "push_to_hub")) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_MODELS_TO_TEST)) + def test_attributes_parametrized(self, test_name, model_id, config_cls, config_kwargs): + self._test_model_attr(model_id, config_cls, config_kwargs) + def _test_prepare_for_training(self, model_id, config_cls, config_kwargs): model = AutoModelForCausalLM.from_pretrained(model_id).to(self.torch_device) config = config_cls( @@ -154,3 +155,43 @@ def _test_save_pretrained(self, model_id, config_cls, config_kwargs): @parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_MODELS_TO_TEST)) def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs): self._test_save_pretrained(model_id, config_cls, config_kwargs) + + def _test_merge_layers(self, model_id, config_cls, config_kwargs): + model = AutoModelForCausalLM.from_pretrained(model_id) + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + model = get_peft_model(model, config) + model = model.to(self.torch_device) + + if config.peft_type != "LORA": + with self.assertRaises(ValueError): + merged_model = merge_lora(model) + else: + dummy_input = torch.LongTensor([[1, 2, 3, 2, 1]]).to(self.torch_device) + logits_lora = model(dummy_input)[0] + + model = merge_lora(model) + + logits_merged = model(dummy_input)[0] + + transformers_model = AutoModelForCausalLM.from_pretrained(model_id).to(self.torch_device) + + logits_transformers = transformers_model(dummy_input)[0] + + self.assertTrue(torch.allclose(logits_lora, logits_merged, atol=1e-3, rtol=1e-3)) + self.assertFalse(torch.allclose(logits_merged, logits_transformers, atol=1e-3, rtol=1e-3)) + + with tempfile.TemporaryDirectory() as tmp_dirname: + merged_model.save_pretrained(tmp_dirname) + + model_from_pretrained = AutoModelForCausalLM.from_pretrained(tmp_dirname) + + logits_merged_from_pretrained = model_from_pretrained(dummy_input)[0] + + self.assertTrue(torch.allclose(logits_merged, logits_merged_from_pretrained, atol=1e-3, rtol=1e-3)) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_MODELS_TO_TEST)) + def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): + self._test_merge_layers(model_id, config_cls, config_kwargs) From 2cf4c99d8c9dd6ba74b710486a93f9b4ea8463a2 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Wed, 29 Mar 2023 08:15:31 +0000 Subject: [PATCH 02/12] forward contrib credits from original script From 01bd43e9400d489dae791ac21b5a7fbd7db7ce3b Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 29 Mar 2023 09:36:35 +0000 Subject: [PATCH 03/12] some changes --- src/peft/tuners/lora.py | 10 +++++++++- src/peft/utils/other.py | 6 ++++++ tests/test_peft_model.py | 8 +++++--- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/peft/tuners/lora.py b/src/peft/tuners/lora.py index 0f65cbf55c..f505d78f52 100644 --- a/src/peft/tuners/lora.py +++ b/src/peft/tuners/lora.py @@ -82,6 +82,10 @@ class LoraConfig(PeftConfig): "the 
final layer `classifier/score` are randomly initialized and as such need to be trainable and saved." }, ) + init_lora_weights: bool = field( + default=True, + metadata={"help": "Whether to initialize the weights of the Lora layers."}, + ) def __post_init__(self): self.peft_type = PeftType.LORA @@ -135,6 +139,7 @@ def _find_and_replace(self): "fan_in_fan_out": self.peft_config.fan_in_fan_out, "merge_weights": (self.peft_config.merge_weights or self.peft_config.inference_mode) and not is_hf_device_map_available, + "init_lora_weights": self.peft_config.init_lora_weights, } key_list = [key for key, _ in self.model.named_modules()] for key in key_list: @@ -297,6 +302,8 @@ def __init__( merge_weights: bool = True, **kwargs, ): + init_lora_weights = kwargs.pop("init_lora_weights", True) + nn.Linear.__init__(self, in_features, out_features, **kwargs) LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights) @@ -308,7 +315,8 @@ def __init__( self.scaling = self.lora_alpha / self.r # Freezing the pre-trained weight matrix self.weight.requires_grad = False - self.reset_parameters() + if init_lora_weights: + self.reset_parameters() if fan_in_fan_out: self.weight.data = self.weight.data.T diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index 69711dd8e4..a4098eb5a3 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -184,6 +184,12 @@ def merge_lora(model): if isinstance(target, peft.tuners.lora.Linear): bias = target.bias is not None new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias) + + # manually merge if not merged + if not target.merged: + target.merge_weights = True + target.train(False) + model.base_model._replace_module(parent, target_name, new_module, target) model = model.base_model.model diff --git a/tests/test_peft_model.py b/tests/test_peft_model.py index 3310e288bb..8b970ad32e 100644 --- a/tests/test_peft_model.py +++ b/tests/test_peft_model.py @@ -33,7 +33,8 @@ # This has to be in the order: model_id, lora_kwargs, prefix_tuning_kwargs, prompt_encoder_kwargs, prompt_tuning_kwargs PEFT_MODELS_TO_TEST = [ - ("hf-internal-testing/tiny-random-OPTForCausalLM", {"bias": "all"}, {}, {}, {}), + ("hf-internal-testing/tiny-random-OPTForCausalLM", {"init_lora_weights":False}, {}, {}, {}), + ("hf-internal-testing/tiny-random-OPTForCausalLM", {"merge_weights":True}, {}, {}, {}), ] @@ -170,6 +171,7 @@ def _test_merge_layers(self, model_id, config_cls, config_kwargs): merged_model = merge_lora(model) else: dummy_input = torch.LongTensor([[1, 2, 3, 2, 1]]).to(self.torch_device) + model.eval() logits_lora = model(dummy_input)[0] model = merge_lora(model) @@ -184,9 +186,9 @@ def _test_merge_layers(self, model_id, config_cls, config_kwargs): self.assertFalse(torch.allclose(logits_merged, logits_transformers, atol=1e-3, rtol=1e-3)) with tempfile.TemporaryDirectory() as tmp_dirname: - merged_model.save_pretrained(tmp_dirname) + model.save_pretrained(tmp_dirname) - model_from_pretrained = AutoModelForCausalLM.from_pretrained(tmp_dirname) + model_from_pretrained = AutoModelForCausalLM.from_pretrained(tmp_dirname).to(self.torch_device) logits_merged_from_pretrained = model_from_pretrained(dummy_input)[0] From 2dc3610516a225c8adeb45bfadbbf2ff14cf6513 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 29 Mar 2023 09:36:49 +0000 Subject: [PATCH 04/12] make style --- tests/test_peft_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_peft_model.py 
b/tests/test_peft_model.py index 8b970ad32e..c1ab700a0a 100644 --- a/tests/test_peft_model.py +++ b/tests/test_peft_model.py @@ -33,8 +33,8 @@ # This has to be in the order: model_id, lora_kwargs, prefix_tuning_kwargs, prompt_encoder_kwargs, prompt_tuning_kwargs PEFT_MODELS_TO_TEST = [ - ("hf-internal-testing/tiny-random-OPTForCausalLM", {"init_lora_weights":False}, {}, {}, {}), - ("hf-internal-testing/tiny-random-OPTForCausalLM", {"merge_weights":True}, {}, {}, {}), + ("hf-internal-testing/tiny-random-OPTForCausalLM", {"init_lora_weights": False}, {}, {}, {}), + ("hf-internal-testing/tiny-random-OPTForCausalLM", {"merge_weights": True}, {}, {}, {}), ] @@ -168,7 +168,7 @@ def _test_merge_layers(self, model_id, config_cls, config_kwargs): if config.peft_type != "LORA": with self.assertRaises(ValueError): - merged_model = merge_lora(model) + merge_lora(model) else: dummy_input = torch.LongTensor([[1, 2, 3, 2, 1]]).to(self.torch_device) model.eval() From 2726d51f7e6db5353dbccdf37f87034bed23047e Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 29 Mar 2023 17:34:41 +0000 Subject: [PATCH 05/12] fix tets --- src/peft/tuners/lora.py | 6 +++++- tests/test_peft_model.py | 41 +++++++++++++++++++++++----------------- tests/testing_common.py | 38 +++++++++++++++---------------------- 3 files changed, 44 insertions(+), 41 deletions(-) diff --git a/src/peft/tuners/lora.py b/src/peft/tuners/lora.py index f505d78f52..e26f655f1c 100644 --- a/src/peft/tuners/lora.py +++ b/src/peft/tuners/lora.py @@ -383,6 +383,8 @@ def __init__( merge_weights: bool = True, **kwargs, ): + init_lora_weights = kwargs.pop("init_lora_weights", True) + nn.Linear.__init__(self, in_features, out_features, **kwargs) LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights) if out_features % len(enable_lora) != 0: @@ -406,7 +408,9 @@ def __init__( self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view(len(enable_lora), -1) self.lora_ind[enable_lora, :] = True self.lora_ind = self.lora_ind.view(-1) - self.reset_parameters() + + if init_lora_weights: + self.reset_parameters() if fan_in_fan_out: self.weight.data = self.weight.data.T diff --git a/tests/test_peft_model.py b/tests/test_peft_model.py index f5d6f04623..30a6311d86 100644 --- a/tests/test_peft_model.py +++ b/tests/test_peft_model.py @@ -31,19 +31,20 @@ from .testing_common import PeftTestConfigManager -# This has to be in the order: model_id, lora_kwargs, prefix_tuning_kwargs, prompt_encoder_kwargs, prompt_tuning_kwargs +# This has to be in the order: model_id, lora_kwargs, prefix_tuning_kwargs, prompt_encoder_kwargs, PEFT_DECODER_MODELS_TO_TEST = [ - # ("HuggingFaceM4/tiny-random-LlamaForCausalLM", {}, {}, {}, {}), wait until the next `transformers` release - ("hf-internal-testing/tiny-random-OPTForCausalLM", {"init_lora_weights": False}, {}, {}, {}), - ("hf-internal-testing/tiny-random-OPTForCausalLM", {"merge_weights": True}, {}, {}, {}), - ("hf-internal-testing/tiny-random-OPTForCausalLM", {}, {}, {}, {}), - ("hf-internal-testing/tiny-random-GPTNeoXForCausalLM", {}, {}, {}, {}), - ("hf-internal-testing/tiny-random-GPT2LMHeadModel", {}, {}, {}, {}), - ("hf-internal-testing/tiny-random-BloomForCausalLM", {}, {}, {}, {}), - ("hf-internal-testing/tiny-random-gpt_neo", {}, {}, {}, {}), - ("hf-internal-testing/tiny-random-GPTJForCausalLM", {}, {}, {}, {}), + "hf-internal-testing/tiny-random-OPTForCausalLM", + "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", + 
"hf-internal-testing/tiny-random-GPT2LMHeadModel", + "hf-internal-testing/tiny-random-BloomForCausalLM", + "hf-internal-testing/tiny-random-gpt_neo", + "hf-internal-testing/tiny-random-GPTJForCausalLM", ] +FULL_GRID = { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, +} + class PeftTestMixin: torch_device = "cuda" if torch.cuda.is_available() else "cpu" @@ -68,8 +69,8 @@ def _test_model_attr(self, model_id, config_cls, config_kwargs): self.assertTrue(hasattr(model, "save_pretrained")) self.assertTrue(hasattr(model, "from_pretrained")) self.assertTrue(hasattr(model, "push_to_hub")) - - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_DECODER_MODELS_TO_TEST)) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def test_attributes_parametrized(self, test_name, model_id, config_cls, config_kwargs): self._test_model_attr(model_id, config_cls, config_kwargs) @@ -114,7 +115,7 @@ def make_inputs_require_grad(module, input, output): self.assertTrue(dummy_output.requires_grad) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_DECODER_MODELS_TO_TEST)) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls, config_kwargs): self._test_prepare_for_training(model_id, config_cls, config_kwargs) @@ -160,7 +161,7 @@ def _test_save_pretrained(self, model_id, config_cls, config_kwargs): # check if `config.json` is not present self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json"))) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_DECODER_MODELS_TO_TEST)) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs): self._test_save_pretrained(model_id, config_cls, config_kwargs) @@ -201,11 +202,10 @@ def _test_merge_layers(self, model_id, config_cls, config_kwargs): self.assertTrue(torch.allclose(logits_merged, logits_merged_from_pretrained, atol=1e-3, rtol=1e-3)) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_MODELS_TO_TEST)) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): self._test_merge_layers(model_id, config_cls, config_kwargs) - def _test_generate(self, model_id, config_cls, config_kwargs): model = AutoModelForCausalLM.from_pretrained(model_id) config = config_cls( @@ -225,6 +225,13 @@ def _test_generate(self, model_id, config_cls, config_kwargs): # check if `generate` raises an error if no positional arguments are passed _ = model.generate(input_ids, attention_mask=attention_mask) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(PEFT_DECODER_MODELS_TO_TEST)) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters( + { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, + "lora_kwargs": {"init_lora_weights": [False], "merge_lora": [False, True]}, + } + ) + ) def test_generate(self, test_name, model_id, config_cls, config_kwargs): self._test_generate(model_id, config_cls, config_kwargs) diff --git a/tests/testing_common.py b/tests/testing_common.py index 96c0fdb476..bea93a1c45 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -71,34 +71,26 @@ def __getitem__(self, key, *args, **kwargs): return super().__getitem__(key, *args, **kwargs) - def get_grid_parameters(self, model_list): + def get_grid_parameters(self, grid_parameters): r""" 
Returns a list of all possible combinations of the parameters in the config classes. """ - grid_parameters = [] - for model_tuple in model_list: - model_id, lora_kwargs, prefix_tuning_kwargs, prompt_encoder_kwargs, prompt_tuning_kwargs = model_tuple + generated_tests = [] + model_list = grid_parameters["model_ids"] + + for model_id in model_list: for key, value in self.items(): - peft_method = value[1].copy() - if key == "lora": - # update value[1] if necessary - if lora_kwargs is not None: - peft_method.update(lora_kwargs) - elif key == "prefix_tuning": - # update value[1] if necessary - if prefix_tuning_kwargs is not None: - peft_method.update(prefix_tuning_kwargs) - elif key == "prompt_encoder": - # update value[1] if necessary - if prompt_encoder_kwargs is not None: - peft_method.update(prompt_encoder_kwargs) - else: - # update value[1] if necessary - if prompt_tuning_kwargs is not None: - peft_method.update(prompt_tuning_kwargs) - grid_parameters.append((f"test_{model_id}_{key}", model_id, value[0], peft_method)) + peft_methods = [value[1].copy()] + + if "{}_kwargs".format(key) in grid_parameters: + for current_key, current_value in grid_parameters[f"{key}_kwargs"].items(): + for kwarg in current_value: + peft_methods.append(value[1].copy().update({current_key: kwarg})) + + for peft_method in peft_methods: + generated_tests.append((f"test_{model_id}_{key}", model_id, value[0], peft_method)) - return grid_parameters + return generated_tests PeftTestConfigManager = ClassInstantier(CLASSES_MAPPING) From 70fa615dfdfdcb363d5c4b2ca456168ae86e552d Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 29 Mar 2023 18:09:33 +0000 Subject: [PATCH 06/12] finally fix tests --- src/peft/utils/other.py | 6 +++++- tests/test_peft_model.py | 20 ++++++++++---------- tests/testing_common.py | 33 +++++++++++++++++++++++++++------ 3 files changed, 42 insertions(+), 17 deletions(-) diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index a4098eb5a3..c04930def4 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -176,12 +176,15 @@ def merge_lora(model): if model.peft_config.peft_type != "LORA": raise ValueError("The input model should be a LORA model") + if model.config.model_type == "gpt2": + raise ValueError("GPT2 models are not supported for merging LORA layers") + model.eval() key_list = [key for key, _ in model.base_model.model.named_modules() if "lora" not in key] for key in key_list: parent, target, target_name = model.base_model._get_submodules(key) - if isinstance(target, peft.tuners.lora.Linear): + if isinstance(target, (peft.tuners.lora.Linear, peft.tuners.lora.MergedLinear)): bias = target.bias is not None new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias) @@ -191,6 +194,7 @@ def merge_lora(model): target.train(False) model.base_model._replace_module(parent, target_name, new_module, target) + # elif isinstance(target, peft.tuners.lora.MergedLinear): model = model.base_model.model return model diff --git a/tests/test_peft_model.py b/tests/test_peft_model.py index 30a6311d86..672a69e3e2 100644 --- a/tests/test_peft_model.py +++ b/tests/test_peft_model.py @@ -174,7 +174,7 @@ def _test_merge_layers(self, model_id, config_cls, config_kwargs): model = get_peft_model(model, config) model = model.to(self.torch_device) - if config.peft_type != "LORA": + if config.peft_type != "LORA" or model.config.model_type == "gpt2": with self.assertRaises(ValueError): merge_lora(model) else: @@ -202,7 +202,14 @@ def _test_merge_layers(self, model_id, 
config_cls, config_kwargs): self.assertTrue(torch.allclose(logits_merged, logits_merged_from_pretrained, atol=1e-3, rtol=1e-3)) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters( + { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, + "lora_kwargs": {"init_lora_weights": [False], "merge_weights": [False, True]}, + }, + ) + ) def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): self._test_merge_layers(model_id, config_cls, config_kwargs) @@ -225,13 +232,6 @@ def _test_generate(self, model_id, config_cls, config_kwargs): # check if `generate` raises an error if no positional arguments are passed _ = model.generate(input_ids, attention_mask=attention_mask) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters( - { - "model_ids": PEFT_DECODER_MODELS_TO_TEST, - "lora_kwargs": {"init_lora_weights": [False], "merge_lora": [False, True]}, - } - ) - ) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def test_generate(self, test_name, model_id, config_cls, config_kwargs): self._test_generate(model_id, config_cls, config_kwargs) diff --git a/tests/testing_common.py b/tests/testing_common.py index bea93a1c45..633bb87688 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -71,24 +71,45 @@ def __getitem__(self, key, *args, **kwargs): return super().__getitem__(key, *args, **kwargs) - def get_grid_parameters(self, grid_parameters): + def get_grid_parameters(self, grid_parameters, filter_params_func=None): r""" Returns a list of all possible combinations of the parameters in the config classes. + + Args: + grid_parameters (`dict`): + A dictionary containing the parameters to be tested. There should be at least the key "model_ids" which + contains a list of model ids to be tested. The other keys should be the name of the config class + post-fixed with "_kwargs" and the value should be a dictionary containing the parameters to be tested + for that config class. + filter_params_func (`callable`, `optional`): + A function that takes a list of tuples and returns a list of tuples. This function is used to filter + out the tests that needs for example to be skipped. + + Returns: + generated_tests (`list`): + A list of tuples containing the name of the test, the model id, the config class and the config class + kwargs. 
""" generated_tests = [] model_list = grid_parameters["model_ids"] for model_id in model_list: for key, value in self.items(): - peft_methods = [value[1].copy()] - if "{}_kwargs".format(key) in grid_parameters: + peft_configs = [] + current_peft_config = value[1].copy() for current_key, current_value in grid_parameters[f"{key}_kwargs"].items(): for kwarg in current_value: - peft_methods.append(value[1].copy().update({current_key: kwarg})) + current_peft_config.update({current_key: kwarg}) + peft_configs.append(current_peft_config) + else: + peft_configs = [value[1].copy()] + + for peft_config in peft_configs: + generated_tests.append((f"test_{model_id}_{key}", model_id, value[0], peft_config)) - for peft_method in peft_methods: - generated_tests.append((f"test_{model_id}_{key}", model_id, value[0], peft_method)) + if filter_params_func is not None: + generated_tests = filter_params_func(generated_tests) return generated_tests From 8bb6989f0a063cea0ed945e9909773867fbe99d8 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 29 Mar 2023 21:24:57 +0200 Subject: [PATCH 07/12] Update tests/test_peft_model.py --- tests/test_peft_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_peft_model.py b/tests/test_peft_model.py index 672a69e3e2..2bd89864c2 100644 --- a/tests/test_peft_model.py +++ b/tests/test_peft_model.py @@ -31,7 +31,6 @@ from .testing_common import PeftTestConfigManager -# This has to be in the order: model_id, lora_kwargs, prefix_tuning_kwargs, prompt_encoder_kwargs, PEFT_DECODER_MODELS_TO_TEST = [ "hf-internal-testing/tiny-random-OPTForCausalLM", "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", From c1a876fa44b4e279846f238ec1894924e56189d1 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 30 Mar 2023 10:51:33 +0000 Subject: [PATCH 08/12] adapt from suggestions --- src/peft/__init__.py | 1 - src/peft/tuners/lora.py | 23 +++++++++++++++++++++ src/peft/utils/__init__.py | 1 - src/peft/utils/other.py | 41 -------------------------------------- tests/test_peft_model.py | 10 ++++++---- 5 files changed, 29 insertions(+), 47 deletions(-) diff --git a/src/peft/__init__.py b/src/peft/__init__.py index 0e61a61c82..e141347f1e 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -50,5 +50,4 @@ prepare_model_for_int8_training, set_peft_model_state_dict, shift_tokens_right, - merge_lora, ) diff --git a/src/peft/tuners/lora.py b/src/peft/tuners/lora.py index e26f655f1c..704e58659f 100644 --- a/src/peft/tuners/lora.py +++ b/src/peft/tuners/lora.py @@ -238,6 +238,29 @@ def enable_adapter_layers(self): def disable_adapter_layers(self): self._set_adapter_layers(enabled=False) + def merge_and_unload(self): + r""" + This method merges the LoRa layers into the base model. This is needed if someone wants to use the base model + as a standalone model. 
+ """ + if self.config.model_type == "gpt2": + raise ValueError("GPT2 models are not supported for merging LORA layers") + + key_list = [key for key, _ in self.model.named_modules() if "lora" not in key] + for key in key_list: + parent, target, target_name = self._get_submodules(key) + if isinstance(target, (Linear, MergedLinear)): + bias = target.bias is not None + new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias) + + # manually merge if not merged + if not target.merged: + target.merge_weights = True + target.train(False) + + self._replace_module(parent, target_name, new_module, target) + return self.model + # Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py # and modified to work with PyTorch FSDP diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py index decfdc4281..dd949c0387 100644 --- a/src/peft/utils/__init__.py +++ b/src/peft/utils/__init__.py @@ -26,6 +26,5 @@ prepare_model_for_int8_training, shift_tokens_right, transpose, - merge_lora, ) from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index c04930def4..132b033484 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -15,8 +15,6 @@ import torch -import peft - # needed for prefix-tuning of bloom model def bloom_model_postprocess_past_key_value(past_key_values): @@ -159,42 +157,3 @@ def lambda_policy_fn(module): def transpose(weight, fan_in_fan_out): return weight.T if fan_in_fan_out else weight - - -def merge_lora(model): - r""" - This method merges the LoRa layers into the base model. This is needed if someone wants to use the base model as a - standalone model. - - Args: - model, (`PeftModel`): - The input PeftModel that needs to be merged - """ - if not isinstance(model, peft.PeftModel): - raise ValueError("The input model should be a PeftModel") - - if model.peft_config.peft_type != "LORA": - raise ValueError("The input model should be a LORA model") - - if model.config.model_type == "gpt2": - raise ValueError("GPT2 models are not supported for merging LORA layers") - - model.eval() - - key_list = [key for key, _ in model.base_model.model.named_modules() if "lora" not in key] - for key in key_list: - parent, target, target_name = model.base_model._get_submodules(key) - if isinstance(target, (peft.tuners.lora.Linear, peft.tuners.lora.MergedLinear)): - bias = target.bias is not None - new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias) - - # manually merge if not merged - if not target.merged: - target.merge_weights = True - target.train(False) - - model.base_model._replace_module(parent, target_name, new_module, target) - # elif isinstance(target, peft.tuners.lora.MergedLinear): - - model = model.base_model.model - return model diff --git a/tests/test_peft_model.py b/tests/test_peft_model.py index 2bd89864c2..4280ff3f86 100644 --- a/tests/test_peft_model.py +++ b/tests/test_peft_model.py @@ -24,7 +24,6 @@ PeftModel, get_peft_model, get_peft_model_state_dict, - merge_lora, prepare_model_for_int8_training, ) @@ -173,15 +172,18 @@ def _test_merge_layers(self, model_id, config_cls, config_kwargs): model = get_peft_model(model, config) model = model.to(self.torch_device) - if config.peft_type != "LORA" or model.config.model_type == "gpt2": + if config.peft_type != "LORA": + with self.assertRaises(AttributeError): + model = model.merge_and_unload() + elif model.config.model_type == "gpt2": with 
self.assertRaises(ValueError): - merge_lora(model) + model = model.merge_and_unload() else: dummy_input = torch.LongTensor([[1, 2, 3, 2, 1]]).to(self.torch_device) model.eval() logits_lora = model(dummy_input)[0] - model = merge_lora(model) + model = model.merge_and_unload() logits_merged = model(dummy_input)[0] From 5d372c8ed48ab66491a9858ffe7c30fe97189d9a Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 30 Mar 2023 10:58:00 +0000 Subject: [PATCH 09/12] adapt --- src/peft/tuners/lora.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/peft/tuners/lora.py b/src/peft/tuners/lora.py index 704e58659f..2ff94017fe 100644 --- a/src/peft/tuners/lora.py +++ b/src/peft/tuners/lora.py @@ -255,8 +255,13 @@ def merge_and_unload(self): # manually merge if not merged if not target.merged: - target.merge_weights = True - target.train(False) + # merge weights per: https://arxiv.org/pdf/2106.09685.pdf / page 4 + if target.r > 0: + target.weight.data += ( + transpose(target.lora_B.weight @ target.lora_A.weight, target.fan_in_fan_out) + * target.scaling + ) + target.merged = True self._replace_module(parent, target_name, new_module, target) return self.model From d57932ca0252be46772aca889e347a6ca0320dab Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu, 30 Mar 2023 13:19:43 +0200 Subject: [PATCH 10/12] Update src/peft/tuners/lora.py Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> --- src/peft/tuners/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/peft/tuners/lora.py b/src/peft/tuners/lora.py index 2ff94017fe..77e7babb9a 100644 --- a/src/peft/tuners/lora.py +++ b/src/peft/tuners/lora.py @@ -249,7 +249,7 @@ def merge_and_unload(self): key_list = [key for key, _ in self.model.named_modules() if "lora" not in key] for key in key_list: parent, target, target_name = self._get_submodules(key) - if isinstance(target, (Linear, MergedLinear)): + if isinstance(target, LoraLayer): bias = target.bias is not None new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias) From 02ab38ef77da1851ba28d5beec1769b5ae1b6d5a Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 30 Mar 2023 11:22:24 +0000 Subject: [PATCH 11/12] fix 8bit --- src/peft/tuners/lora.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/peft/tuners/lora.py b/src/peft/tuners/lora.py index 2ff94017fe..d83f29f29a 100644 --- a/src/peft/tuners/lora.py +++ b/src/peft/tuners/lora.py @@ -246,6 +246,9 @@ def merge_and_unload(self): if self.config.model_type == "gpt2": raise ValueError("GPT2 models are not supported for merging LORA layers") + if getattr(self.model, "is_loaded_in_8bit", False): + raise ValueError("Cannot merge LORA layers when the model is loaded in 8-bit mode") + key_list = [key for key, _ in self.model.named_modules() if "lora" not in key] for key in key_list: parent, target, target_name = self._get_submodules(key) From 929d8890fbd86ab0f6264efee21119bad9ddaffd Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu, 30 Mar 2023 13:35:11 +0200 Subject: [PATCH 12/12] Update src/peft/tuners/lora.py Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> --- src/peft/tuners/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/peft/tuners/lora.py b/src/peft/tuners/lora.py index 4acd8ca9ee..47f2c02947 100644 --- a/src/peft/tuners/lora.py +++ b/src/peft/tuners/lora.py @@ 
-263,7 +263,7 @@ def merge_and_unload(self): target.weight.data += ( transpose(target.lora_B.weight @ target.lora_A.weight, target.fan_in_fan_out) * target.scaling - ) + ).to(target.weight.dtype) target.merged = True self._replace_module(parent, target_name, new_module, target)
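
For reference, the update that `merge_and_unload` applies (patches 08–12) is the low-rank merge from the LoRA paper: the frozen weight is replaced by W + (B·A)·(lora_alpha / r), after which the adapter matrices can be discarded. The following standalone sketch, which does not use peft and picks arbitrary dimensions and an arbitrary scaling, illustrates why the merged `nn.Linear` reproduces the unmerged LoRA forward pass.

```python
# Standalone illustration (not peft code) of the identity behind
# `merge_and_unload`: folding the scaled low-rank update B @ A into the
# frozen base weight gives the same outputs as the unmerged LoRA path.
import torch

in_features, out_features, r, lora_alpha = 16, 32, 4, 8  # arbitrary sizes
scaling = lora_alpha / r

base = torch.nn.Linear(in_features, out_features, bias=True)
lora_A = torch.nn.Linear(in_features, r, bias=False)
lora_B = torch.nn.Linear(r, out_features, bias=False)
# Use non-zero adapter weights so the check is non-trivial (peft normally
# zero-initializes lora_B, which is why the tests pass init_lora_weights=False).
torch.nn.init.normal_(lora_A.weight)
torch.nn.init.normal_(lora_B.weight)

x = torch.randn(2, in_features)

# Unmerged LoRA forward: base output plus the scaled low-rank path.
unmerged = base(x) + lora_B(lora_A(x)) * scaling

# Merged forward: weight.data += (lora_B.weight @ lora_A.weight) * scaling,
# as in the final hunk of this series (modulo the fan_in_fan_out transpose).
merged = torch.nn.Linear(in_features, out_features, bias=True)
merged.weight.data = base.weight.data + (lora_B.weight @ lora_A.weight) * scaling
merged.bias.data = base.bias.data

assert torch.allclose(unmerged, merged(x), atol=1e-5)
```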
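
After the series, the merge is exposed as a method on the LoRA tuner rather than through the `merge_lora` utility that the first patch introduced. Below is a minimal sketch of the resulting workflow, mirroring `_test_merge_layers`; it assumes the default `LoraConfig` resolves target modules for the tiny OPT checkpoint as in the test suite, and the output directory name is a placeholder.

```python
# Sketch of the workflow exercised by `_test_merge_layers` after this series;
# the tiny OPT checkpoint is the one used throughout the tests, and
# "opt-merged" is a placeholder output directory.
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
config = LoraConfig(init_lora_weights=False)  # non-zero LoRA weights, as in the test grid
model = get_peft_model(base, config)
model.eval()

dummy_input = torch.LongTensor([[1, 2, 3, 2, 1]])
logits_lora = model(dummy_input)[0]

# Fold the LoRA weights into the base model and drop the adapter modules.
# Raises ValueError for gpt2 (Conv1D layers) and for models loaded in 8-bit.
model = model.merge_and_unload()
logits_merged = model(dummy_input)[0]
assert torch.allclose(logits_lora, logits_merged, atol=1e-3, rtol=1e-3)

# The merged model is a plain transformers model that can be saved and
# reloaded without peft installed.
model.save_pretrained("opt-merged")
reloaded = AutoModelForCausalLM.from_pretrained("opt-merged")
assert torch.allclose(logits_merged, reloaded(dummy_input)[0], atol=1e-3, rtol=1e-3)
```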