diff --git a/examples/community/README.md b/examples/community/README.md
index 905f7b887b46..c68dd1161b47 100644
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -734,7 +734,7 @@ image = pipe(prompt, generator=generator, num_inference_steps=50).images[0]
 
 ![diffusers_euler](https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/k_diffusion/astronaut_euler_k_diffusion.png)
 
 ### Checkpoint Merger Pipeline
-Based on the AUTOMATIC1111/webui for checkpoint merging. This is a custom pipeline that merges upto 3 pretrained model checkpoints as long as they are in the HuggingFace model_index.json format.
+Based on the AUTOMATIC1111/webui for checkpoint merging. This is a custom pipeline that merges up to 3 pretrained model checkpoints, as long as they are in the Hugging Face model_index.json format.
 
 The checkpoint merging is currently memory intensive as it modifies the weights of a DiffusionPipeline object in place. Expect at least 13GB of RAM usage on Kaggle GPU kernels; on Colab you might run out of the 12GB of memory even while merging two checkpoints.
 
@@ -758,7 +758,15 @@ merged_pipe = pipe.merge(["CompVis/stable-diffusion-v1-4","CompVis/stable-diffus
 merged_pipe_1 = pipe.merge(["CompVis/stable-diffusion-v1-4","hakurei/waifu-diffusion"], force = True, interp = "sigmoid", alpha = 0.4)
 
-#Three checkpoint merging. Only the "add_difference" method actually works on all three checkpoints. Using any other option will ignore the 3rd checkpoint.
-merged_pipe_2 = pipe.merge(["CompVis/stable-diffusion-v1-4","hakurei/waifu-diffusion","prompthero/openjourney"], force = True, interp = "add_difference", alpha = 0.4)
+#Three checkpoint merging. Only the "add_diff" method actually works on all three checkpoints. Using any other option will ignore the 3rd checkpoint.
+merged_pipe_2 = pipe.merge(["CompVis/stable-diffusion-v1-4","hakurei/waifu-diffusion","prompthero/openjourney"], force = True, interp = "add_diff", alpha = 0.4)
+
+#Merging with different alphas for the unet and the text_encoder
+merged_pipe_3 = pipe.merge(["stabilityai/stable-diffusion-2-1","IlluminatiAI/Illuminati_Diffusion_v1.0"], force = True, interp = "weighted_sum",
+    alpha = 0.5, module_override_alphas = {'unet': 0.2, 'text_encoder': 0.6})
+
+#Merging with different weights per unet block (12 weights for the down blocks, 1 weight for the mid block, 12 weights for the up blocks)
+merged_pipe_4 = pipe.merge(["stabilityai/stable-diffusion-2-1","IlluminatiAI/Illuminati_Diffusion_v1.0"], force = True, interp = "weighted_sum",
+    alpha = 0.5, block_weights = [0,0,0,0,0,0,0,0,0,0,0,0,0.5,1,1,1,1,1,1,1,1,1,1,1,1])
 
 prompt = "An astronaut riding a horse on Mars"
 
diff --git a/examples/community/checkpoint_merger.py b/examples/community/checkpoint_merger.py
index 24f187b41c07..504155d8d0d3 100644
--- a/examples/community/checkpoint_merger.py
+++ b/examples/community/checkpoint_merger.py
@@ -12,11 +12,78 @@
 
 from huggingface_hub import snapshot_download
 
-from diffusers import DiffusionPipeline, __version__
+from diffusers import DiffusionPipeline, UNet2DConditionModel, __version__
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME
 
+NUM_INPUT_BLOCKS = 12
+NUM_MID_BLOCK = 1
+NUM_OUTPUT_BLOCKS = 12
+NUM_TOTAL_BLOCKS = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS
+
+DIFFUSERS_KEY_PREFIX_TO_WEIGHT_INDEX = {
+    "time_embedding.": 0,
+    "conv_in.": 0,
+    "down_blocks.0.resnets.0": 1,
+    "down_blocks.0.attentions.0": 1,
+    "down_blocks.0.resnets.1": 2,
+    "down_blocks.0.attentions.1": 2,
+    "down_blocks.0.downsamplers": 3,
+    "down_blocks.1.resnets.0": 4,
"down_blocks.1.attentions.0": 4, + "down_blocks.1.resnets.1": 5, + "down_blocks.1.attentions.1": 5, + "down_blocks.1.downsamplers": 6, + "down_blocks.2.resnets.0": 7, + "down_blocks.2.attentions.0": 7, + "down_blocks.2.resnets.1": 8, + "down_blocks.2.attentions.1": 8, + "down_blocks.2.downsamplers": 9, + "down_blocks.3.resnets.0": 10, + "down_blocks.3.resnets.1": 11, + "mid_block": 12, + "up_blocks.0.resnets.0": 13, + "up_blocks.0.resnets.1": 14, + "up_blocks.0.resnets.2": 15, + "up_blocks.0.upsamplers.0": 15, + "up_blocks.1.resnets.0": 16, + "up_blocks.1.attentions.0": 16, + "up_blocks.1.resnets.1": 17, + "up_blocks.1.attentions.1": 17, + "up_blocks.1.resnets.2": 18, + "up_blocks.1.upsamplers.0": 18, + "up_blocks.1.attentions.2": 18, + "up_blocks.2.resnets.0": 19, + "up_blocks.2.attentions.0": 19, + "up_blocks.2.resnets.1": 20, + "up_blocks.2.attentions.1": 20, + "up_blocks.2.resnets.2": 21, + "up_blocks.2.upsamplers.0": 21, + "up_blocks.2.attentions.2": 21, + "up_blocks.3.resnets.0": 22, + "up_blocks.3.attentions.0": 22, + "up_blocks.3.resnets.1": 23, + "up_blocks.3.attentions.1": 23, + "up_blocks.3.resnets.2": 24, + "up_blocks.3.attentions.2": 24, + "conv_norm_out.": 24, + "conv_out.": 24, +} + + +def get_weight_index(key: str) -> int: + for k, v in DIFFUSERS_KEY_PREFIX_TO_WEIGHT_INDEX.items(): + if key.startswith(k): + return v + raise ValueError(f"Unknown unet key: {key}") + + +def get_block_alpha(block_weights: list, key: str) -> float: + weight_index = get_weight_index(key) + return block_weights[weight_index] + + class CheckpointMergerPipeline(DiffusionPipeline): """ A class that that supports merging diffusion models based on the discussion here: @@ -85,6 +152,10 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike] force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False. + block_weights - list of 25 floats for per-block weighting. ref https://rentry.org/Merge_Block_Weight_-china-_v1_Beta#3-basic-theory-explanation + + module_override_alphas - dict of str -> float for per-module alpha overrides eg {'unet': 0.2, 'text_encoder': 0.8} + """ # Default kwargs from DiffusionPipeline cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) @@ -99,9 +170,13 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike] alpha = kwargs.pop("alpha", 0.5) interp = kwargs.pop("interp", None) + block_weights: list[float] = kwargs.pop("block_weights", None) + module_override_alphas: dict[str, float] = kwargs.pop("module_override_alphas", {}) print("Received list", pretrained_model_name_or_path_list) print(f"Combining with alpha={alpha}, interpolation mode={interp}") + if block_weights is not None: + print(f"Merging unet using block weights {block_weights}") checkpoint_count = len(pretrained_model_name_or_path_list) # Ignore result from model_index_json comparision of the two checkpoints @@ -122,6 +197,7 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike] # Step 1: Load the model config and compare the checkpoints. 
         config_dicts = []
         for pretrained_model_name_or_path in pretrained_model_name_or_path_list:
+            print(f"Loading DiffusionPipeline config from {pretrained_model_name_or_path}...")
             config_dict = DiffusionPipeline.load_config(
                 pretrained_model_name_or_path,
                 cache_dir=cache_dir,
@@ -138,12 +214,13 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]
         for idx in range(1, len(config_dicts)):
             comparison_result &= self._compare_model_configs(config_dicts[idx - 1], config_dicts[idx])
             if not force and comparison_result is False:
-                raise ValueError("Incompatible checkpoints. Please check model_index.json for the models.")
                 print(config_dicts[0], config_dicts[1])
+                raise ValueError("Incompatible checkpoints. Please check model_index.json for the models.")
         print("Compatible model_index.json files found")
         # Step 2: Basic Validation has succeeded. Let's download the models and save them into our local files.
         cached_folders = []
         for pretrained_model_name_or_path, config_dict in zip(pretrained_model_name_or_path_list, config_dicts):
+            print(f"Loading {pretrained_model_name_or_path}...")
             folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
             allow_patterns = [os.path.join(k, "*") for k in folder_names]
             allow_patterns += [
@@ -181,10 +258,6 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]
             )
         final_pipe.to(self.device)
 
-        checkpoint_path_2 = None
-        if len(cached_folders) > 2:
-            checkpoint_path_2 = os.path.join(cached_folders[2])
-
         if interp == "sigmoid":
             theta_func = CheckpointMergerPipeline.sigmoid
         elif interp == "inv_sigmoid":
@@ -221,7 +294,7 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]
             # For an attr, if both checkpoint_path_1 and checkpoint_path_2 are None, ignore it.
             # If at least one is present, deal with it according to the interp method, of course only if the state_dict keys match.
             if checkpoint_path_1 is None and checkpoint_path_2 is None:
-                print(f"Skipping {attr}: not present in 2nd or 3d model")
+                print(f"Skipping {attr}: not present in 2nd or 3rd model")
                 continue
             try:
                 module = getattr(final_pipe, attr)
@@ -243,7 +316,6 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]
                     if (is_safetensors_available() and checkpoint_path_2.endswith(".safetensors"))
                     else torch.load(checkpoint_path_2, map_location="cpu")
                 )
-
                 if not theta_0.keys() == theta_1.keys():
                     print(f"Skipping {attr}: key mismatch")
                     continue
@@ -253,12 +325,21 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]
             except Exception as e:
                 print(f"Skipping {attr} due to an unexpected error: {str(e)}")
                 continue
             print(f"MERGING {attr}")
+            if block_weights is not None and isinstance(module, UNet2DConditionModel):
+                print(f" - using block weights {block_weights}")
+            elif attr in module_override_alphas:
+                print(f" - using override alpha {module_override_alphas[attr]}")
 
             for key in theta_0.keys():
+                this_alpha = (
+                    get_block_alpha(block_weights, key)
+                    if block_weights is not None and isinstance(module, UNet2DConditionModel)
+                    else module_override_alphas.get(attr, alpha)
+                )
                 if theta_2:
-                    theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key], alpha)
+                    theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key], this_alpha)
                 else:
-                    theta_0[key] = theta_func(theta_0[key], theta_1[key], None, alpha)
+                    theta_0[key] = theta_func(theta_0[key], theta_1[key], None, this_alpha)
             del theta_1
             del theta_2
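To make the per-block weighting easier to review, here is a minimal standalone sketch (not part of the PR) of the lookup the new code performs for each unet state-dict key. The prefix table is a trimmed-down copy of `DIFFUSERS_KEY_PREFIX_TO_WEIGHT_INDEX` from the diff above, and the example keys are illustrative Stable Diffusion unet key names:

```python
# Standalone sketch of the per-block alpha lookup. This uses a trimmed table;
# the full 25-index DIFFUSERS_KEY_PREFIX_TO_WEIGHT_INDEX lives in checkpoint_merger.py.
PREFIX_TO_WEIGHT_INDEX = {
    "time_embedding.": 0,            # grouped with the first down block
    "down_blocks.0.resnets.0": 1,
    "down_blocks.1.attentions.0": 4,
    "mid_block": 12,                 # the single middle block
    "up_blocks.3.attentions.2": 24,
    "conv_out.": 24,                 # grouped with the last up block
}


def get_weight_index(key: str) -> int:
    # Every unet state-dict key starts with exactly one prefix in the table.
    for prefix, index in PREFIX_TO_WEIGHT_INDEX.items():
        if key.startswith(prefix):
            return index
    raise ValueError(f"Unknown unet key: {key}")


# 25 per-block alphas: keep model A in the down blocks, blend the mid block
# 50/50, and take model B in the up blocks (same shape as the README example).
block_weights = [0.0] * 12 + [0.5] + [1.0] * 12

for key in (
    "time_embedding.linear_1.weight",    # index 0  -> alpha 0.0
    "mid_block.resnets.0.conv1.weight",  # index 12 -> alpha 0.5
    "conv_out.weight",                   # index 24 -> alpha 1.0
):
    print(key, "->", block_weights[get_weight_index(key)])
```

Folding `time_embedding.`/`conv_in.` into the first down block's index and `conv_norm_out.`/`conv_out.` into the last up block's keeps the table at exactly `NUM_TOTAL_BLOCKS` (25) indices, matching the 25-float `block_weights` argument.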