From d44f39c587be9936636d2a3006175df753b5ea3f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 13 Dec 2024 14:19:09 +0530 Subject: [PATCH 1/5] feat: support unload_lora_weights() for Flux Control. --- src/diffusers/loaders/lora_pipeline.py | 53 ++++++++++++++++++++++ tests/lora/test_lora_layers_flux.py | 62 ++++++++++++++++++++++++++ 2 files changed, 115 insertions(+) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 1445394b8784..545d4ff35a5a 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2283,6 +2283,50 @@ def unload_lora_weights(self): transformer.load_state_dict(transformer._transformer_norm_layers, strict=False) transformer._transformer_norm_layers = None + if getattr(transformer, "_overwritten_params", None) is not None: + print(f"{transformer._overwritten_params.keys()=}") + overwritten_params = transformer._overwritten_params + module_names = set() + + for param_name in overwritten_params: + if param_name.endswith(".weight"): + module_names.add(param_name.replace(".weight", "")) + + for name, module in transformer.named_modules(): + if isinstance(module, torch.nn.Linear) and name in module_names: + module_weight = module.weight.data + module_bias = module.bias.data if module.bias is not None else None + bias = module_bias is not None + + parent_module_name, _, current_module_name = name.rpartition(".") + parent_module = transformer.get_submodule(parent_module_name) + + current_param_weight = overwritten_params[f"{name}.weight"] + in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0] + with torch.device("meta"): + original_module = torch.nn.Linear( + in_features, + out_features, + bias=bias, + device=module_weight.device, + dtype=module_weight.dtype, + ) + + original_module.weight.data.copy_(current_param_weight) + if module_bias is not None: + original_module.bias.data.copy_(overwritten_params[f"{name}.bias"]) + + setattr(parent_module, current_module_name, original_module) + + if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX: + attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name] + new_value = int(current_param_weight.shape[1]) + old_value = getattr(transformer.config, attribute_name) + setattr(transformer.config, attribute_name, new_value) + logger.info( + f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}." + ) + @classmethod def _maybe_expand_transformer_param_shape_or_error_( cls, @@ -2309,6 +2353,7 @@ def _maybe_expand_transformer_param_shape_or_error_( # Expand transformer parameter shapes if they don't match lora has_param_with_shape_update = False + overwritten_params = {} for name, module in transformer.named_modules(): if isinstance(module, torch.nn.Linear): @@ -2371,6 +2416,14 @@ def _maybe_expand_transformer_param_shape_or_error_( setattr(transformer.config, attribute_name, new_value) logger.info(f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}.") + # For `unload_lora_weights()`. 
+ overwritten_params[f"{current_module_name}.weight"] = module_weight + if module_bias is not None: + overwritten_params[f"{current_module_name}.bias"] = module_bias + + if len(overwritten_params) > 0: + transformer._overwritten_params = overwritten_params + return has_param_with_shape_update diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 8142085f981c..b1df5564a550 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -430,6 +430,68 @@ def test_correct_lora_configs_with_different_ranks(self): self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3)) self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3)) + def test_lora_unload_with_parameter_expanded_shapes(self): + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.DEBUG) + + # Change the transformer config to mimic a real use case. + num_channels_without_control = 4 + transformer = FluxTransformer2DModel.from_config( + components["transformer"].config, in_channels=num_channels_without_control + ).to(torch_device) + self.assertTrue( + transformer.config.in_channels == num_channels_without_control, + f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}", + ) + + # This should be initialize with a Flux pipeline variant that doesn't accept `control_image`. + components["transformer"] = transformer + pipe = FluxPipeline(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + control_image = inputs.pop("control_image") + original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] + + control_pipe = self.pipeline_class(**components) + out_features, in_features = control_pipe.transformer.x_embedder.weight.shape + rank = 4 + + dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) + dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight, + "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight, + } + with CaptureLogger(logger) as cap_logger: + control_pipe.load_lora_weights(lora_state_dict, "adapter-1") + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + + inputs["control_image"] = control_image + lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) + self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) + self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) + self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) + + control_pipe.unload_lora_weights() + loaded_pipe = FluxPipeline.from_pipe(control_pipe) + self.assertTrue( + loaded_pipe.transformer.config.in_channels == num_channels_without_control, + f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}", + ) + inputs.pop("control_image") + unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4)) + 
self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4)) + self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features) + self.assertTrue(pipe.transformer.config.in_channels == in_features) + @unittest.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass From c541d741cf0c503f03cc5e90726a6098b7818a6e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 13 Dec 2024 14:23:35 +0530 Subject: [PATCH 2/5] tighten test --- src/diffusers/loaders/lora_pipeline.py | 1 - tests/lora/test_lora_layers_flux.py | 4 ++++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 545d4ff35a5a..c8c7ecdf30b7 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2284,7 +2284,6 @@ def unload_lora_weights(self): transformer._transformer_norm_layers = None if getattr(transformer, "_overwritten_params", None) is not None: - print(f"{transformer._overwritten_params.keys()=}") overwritten_params = transformer._overwritten_params module_names = set() diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index b1df5564a550..ec04e98b0c6d 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -479,6 +479,10 @@ def test_lora_unload_with_parameter_expanded_shapes(self): self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) control_pipe.unload_lora_weights() + self.assertTrue( + control_pipe.transformer.config.in_channels == num_channels_without_control, + f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}", + ) loaded_pipe = FluxPipeline.from_pipe(control_pipe) self.assertTrue( loaded_pipe.transformer.config.in_channels == num_channels_without_control, From 4509f34920216e8d222d21c370bb4d910062cd69 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 13 Dec 2024 14:24:40 +0530 Subject: [PATCH 3/5] minor --- tests/lora/test_lora_layers_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index ec04e98b0c6d..92c9603078b5 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -446,7 +446,7 @@ def test_lora_unload_with_parameter_expanded_shapes(self): f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}", ) - # This should be initialize with a Flux pipeline variant that doesn't accept `control_image`. + # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`. 
components["transformer"] = transformer pipe = FluxPipeline(**components) pipe = pipe.to(torch_device) From 2f0545568b780da7bb34bf25c4285a9465ba55af Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 17 Dec 2024 10:54:23 +0530 Subject: [PATCH 4/5] updates --- src/diffusers/loaders/lora_pipeline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 68b0fec01e05..eed724b3de0e 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2307,7 +2307,6 @@ def unload_lora_weights(self): in_features, out_features, bias=bias, - device=module_weight.device, dtype=module_weight.dtype, ) @@ -2423,6 +2422,8 @@ def _maybe_expand_transformer_param_shape_or_error_( logger.info(f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}.") # For `unload_lora_weights()`. + # TODO: this could lead to more memory overhead if the number of overwritten params + # are large. Should be revisited later and tackled through a `discard_original_layers` arg. overwritten_params[f"{current_module_name}.weight"] = module_weight if module_bias is not None: overwritten_params[f"{current_module_name}.bias"] = module_bias From 6ed1131489636c55bbddd3de1d1816c6f8535886 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 17 Dec 2024 11:11:01 +0530 Subject: [PATCH 5/5] meta device fixes. --- src/diffusers/loaders/lora_pipeline.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index eed724b3de0e..0d26d7685652 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2310,12 +2310,14 @@ def unload_lora_weights(self): dtype=module_weight.dtype, ) - original_module.weight.data.copy_(current_param_weight) + tmp_state_dict = {"weight": current_param_weight} if module_bias is not None: - original_module.bias.data.copy_(overwritten_params[f"{name}.bias"]) - + tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]}) + original_module.load_state_dict(tmp_state_dict, assign=True, strict=True) setattr(parent_module, current_module_name, original_module) + del tmp_state_dict + if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX: attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name] new_value = int(current_param_weight.shape[1])