12 changes: 10 additions & 2 deletions examples/community/README.md
@@ -734,7 +734,7 @@ image = pipe(prompt, generator=generator, num_inference_steps=50).images[0]
![diffusers_euler](https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/k_diffusion/astronaut_euler_k_diffusion.png)

### Checkpoint Merger Pipeline
-Based on the AUTOMATIC1111/webui for checkpoint merging. This is a custom pipeline that merges upto 3 pretrained model checkpoints as long as they are in the HuggingFace model_index.json format.
+Based on the AUTOMATIC1111/webui for checkpoint merging. This is a custom pipeline that merges up to 3 pretrained model checkpoints as long as they are in the HuggingFace model_index.json format.

The checkpoint merging is currently memory intensive, as it modifies the weights of a DiffusionPipeline object in place. Expect at least 13GB of RAM usage on Kaggle GPU kernels; on Colab you may run out of the 12GB memory even while merging just two checkpoints.
@@ -758,7 +758,15 @@ merged_pipe = pipe.merge(["CompVis/stable-diffusion-v1-4","CompVis/stable-diffus
merged_pipe_1 = pipe.merge(["CompVis/stable-diffusion-v1-4","hakurei/waifu-diffusion"], force = True, interp = "sigmoid", alpha = 0.4)

#Three checkpoint merging. Only the "add_difference" method works on all three checkpoints; using any other option will ignore the 3rd checkpoint.
-merged_pipe_2 = pipe.merge(["CompVis/stable-diffusion-v1-4","hakurei/waifu-diffusion","prompthero/openjourney"], force = True, interp = "add_difference", alpha = 0.4)
+merged_pipe_2 = pipe.merge(["CompVis/stable-diffusion-v1-4","hakurei/waifu-diffusion","prompthero/openjourney"], force = True, interp = "add_diff", alpha = 0.4)

#Merging with different weights for unet and text_encoder
merged_pipe_3 = pipe.merge(["CompVis/stable-diffusion-v2-1","IlluminatiAI/Illuminati_Diffusion_v1.0"], force = True, interp = "weighted_sum",
    alpha = 0.5, module_override_alphas = {'unet': 0.2, 'text_encoder': 0.6})

#Merging with different weights for different layers in the unet (12 weights for the down layers, 1 weight for the middle layer, 12 weights for the up layers)
merged_pipe_4 = pipe.merge(["CompVis/stable-diffusion-v2-1","IlluminatiAI/Illuminati_Diffusion_v1.0"], force = True, interp = "weighted_sum",
    alpha = 0.5, block_weights = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
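#Equivalent shorthand for the 25 weights above (a readability note, not in the original example): block_weights = [0]*12 + [0.5] + [1]*12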

prompt = "An astronaut riding a horse on Mars"

101 changes: 91 additions & 10 deletions examples/community/checkpoint_merger.py
@@ -12,11 +12,78 @@

from huggingface_hub import snapshot_download

-from diffusers import DiffusionPipeline, __version__
+from diffusers import DiffusionPipeline, UNet2DConditionModel, __version__
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME


NUM_INPUT_BLOCKS = 12
NUM_MID_BLOCK = 1
NUM_OUTPUT_BLOCKS = 12
NUM_TOTAL_BLOCKS = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS

DIFFUSERS_KEY_PREFIX_TO_WEIGHT_INDEX = {
"time_embedding.": 0,
"conv_in.": 0,
"down_blocks.0.resnets.0": 1,
"down_blocks.0.attentions.0": 1,
"down_blocks.0.resnets.1": 2,
"down_blocks.0.attentions.1": 2,
"down_blocks.0.downsamplers": 3,
"down_blocks.1.resnets.0": 4,
"down_blocks.1.attentions.0": 4,
"down_blocks.1.resnets.1": 5,
"down_blocks.1.attentions.1": 5,
"down_blocks.1.downsamplers": 6,
"down_blocks.2.resnets.0": 7,
"down_blocks.2.attentions.0": 7,
"down_blocks.2.resnets.1": 8,
"down_blocks.2.attentions.1": 8,
"down_blocks.2.downsamplers": 9,
"down_blocks.3.resnets.0": 10,
"down_blocks.3.resnets.1": 11,
"mid_block": 12,
"up_blocks.0.resnets.0": 13,
"up_blocks.0.resnets.1": 14,
"up_blocks.0.resnets.2": 15,
"up_blocks.0.upsamplers.0": 15,
"up_blocks.1.resnets.0": 16,
"up_blocks.1.attentions.0": 16,
"up_blocks.1.resnets.1": 17,
"up_blocks.1.attentions.1": 17,
"up_blocks.1.resnets.2": 18,
"up_blocks.1.upsamplers.0": 18,
"up_blocks.1.attentions.2": 18,
"up_blocks.2.resnets.0": 19,
"up_blocks.2.attentions.0": 19,
"up_blocks.2.resnets.1": 20,
"up_blocks.2.attentions.1": 20,
"up_blocks.2.resnets.2": 21,
"up_blocks.2.upsamplers.0": 21,
"up_blocks.2.attentions.2": 21,
"up_blocks.3.resnets.0": 22,
"up_blocks.3.attentions.0": 22,
"up_blocks.3.resnets.1": 23,
"up_blocks.3.attentions.1": 23,
"up_blocks.3.resnets.2": 24,
"up_blocks.3.attentions.2": 24,
"conv_norm_out.": 24,
"conv_out.": 24,
}
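
# The indices above follow the 25-slot convention from the README's reference:
# slots 0-11 cover the down blocks, slot 12 the mid block, and slots 13-24 the
# up blocks; boundary modules (time_embedding/conv_in and conv_norm_out/conv_out)
# share the first and last slots.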


def get_weight_index(key: str) -> int:
    for k, v in DIFFUSERS_KEY_PREFIX_TO_WEIGHT_INDEX.items():
        if key.startswith(k):
            return v
    raise ValueError(f"Unknown unet key: {key}")


def get_block_alpha(block_weights: list, key: str) -> float:
    weight_index = get_weight_index(key)
    return block_weights[weight_index]
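
# Quick illustration of the two helpers above (hypothetical state-dict key
# names, chosen to match the prefixes in DIFFUSERS_KEY_PREFIX_TO_WEIGHT_INDEX):
#
#   weights = [0.0] * NUM_INPUT_BLOCKS + [0.5] * NUM_MID_BLOCK + [1.0] * NUM_OUTPUT_BLOCKS
#   get_weight_index("down_blocks.1.attentions.0.proj_in.weight")    # -> 4
#   get_weight_index("mid_block.resnets.0.conv1.weight")             # -> 12
#   get_block_alpha(weights, "up_blocks.3.resnets.2.conv1.weight")   # -> 1.0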


class CheckpointMergerPipeline(DiffusionPipeline):
"""
A class that that supports merging diffusion models based on the discussion here:
@@ -85,6 +152,10 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]

        force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.

        block_weights - A list of 25 floats for per-block weighting of the unet. See https://rentry.org/Merge_Block_Weight_-china-_v1_Beta#3-basic-theory-explanation for background.

        module_override_alphas - A dict of str -> float giving per-module alpha overrides, e.g. {'unet': 0.2, 'text_encoder': 0.8}.

"""
# Default kwargs from DiffusionPipeline
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
@@ -99,9 +170,13 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]

        alpha = kwargs.pop("alpha", 0.5)
        interp = kwargs.pop("interp", None)
        block_weights: list[float] = kwargs.pop("block_weights", None)
        module_override_alphas: dict[str, float] = kwargs.pop("module_override_alphas", {})

        print("Received list", pretrained_model_name_or_path_list)
        print(f"Combining with alpha={alpha}, interpolation mode={interp}")
        if block_weights is not None:
            print(f"Merging unet using block weights {block_weights}")

        checkpoint_count = len(pretrained_model_name_or_path_list)
        # Ignore result from model_index.json comparison of the two checkpoints
@@ -122,6 +197,7 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]
        # Step 1: Load the model config and compare the checkpoints. We'll compare the model_index.json first while ignoring the keys starting with '_'
        config_dicts = []
        for pretrained_model_name_or_path in pretrained_model_name_or_path_list:
            print(f"loading DiffusionPipeline from {pretrained_model_name_or_path}...")
            config_dict = DiffusionPipeline.load_config(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
@@ -138,12 +214,13 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]
        for idx in range(1, len(config_dicts)):
            comparison_result &= self._compare_model_configs(config_dicts[idx - 1], config_dicts[idx])
        if not force and comparison_result is False:
            print(config_dicts[0], config_dicts[1])
            raise ValueError("Incompatible checkpoints. Please check model_index.json for the models.")
        print("Compatible model_index.json files found")
        # Step 2: Basic Validation has succeeded. Let's download the models and save them into our local files.
        cached_folders = []
        for pretrained_model_name_or_path, config_dict in zip(pretrained_model_name_or_path_list, config_dicts):
            print(f"Loading {pretrained_model_name_or_path}...")
            folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
            allow_patterns = [os.path.join(k, "*") for k in folder_names]
            allow_patterns += [
@@ -181,10 +258,6 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]
        )
        final_pipe.to(self.device)

-        checkpoint_path_2 = None
-        if len(cached_folders) > 2:
-            checkpoint_path_2 = os.path.join(cached_folders[2])

if interp == "sigmoid":
theta_func = CheckpointMergerPipeline.sigmoid
elif interp == "inv_sigmoid":
@@ -221,7 +294,7 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]
            # For an attr, if both checkpoint_path_1 and checkpoint_path_2 are None, ignore it.
            # If at least one is present, deal with it according to the interp method, provided the state_dict keys match.
            if checkpoint_path_1 is None and checkpoint_path_2 is None:
-                print(f"Skipping {attr}: not present in 2nd or 3d model")
+                print(f"Skipping {attr}: not present in 2nd or 3rd model")
                continue
            try:
                module = getattr(final_pipe, attr)
@@ -243,7 +316,6 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]
                    if (is_safetensors_available() and checkpoint_path_2.endswith(".safetensors"))
                    else torch.load(checkpoint_path_2, map_location="cpu")
                )
                if not theta_0.keys() == theta_1.keys():
                    print(f"Skipping {attr}: key mismatch")
                    continue
@@ -253,12 +325,21 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]
print(f"Skipping {attr} do to an unexpected error: {str(e)}")
continue
print(f"MERGING {attr}")
if block_weights is not None and type(module) is UNet2DConditionModel:
print(f" - using block weights {block_weights}")
elif module_override_alphas.get(attr, None) is not None:
print(f" - using override alpha {module_override_alphas[attr]}")

            for key in theta_0.keys():
                this_alpha = (
                    get_block_alpha(block_weights, key)
                    if block_weights is not None and type(module) is UNet2DConditionModel
                    else module_override_alphas.get(attr, alpha)
                )
                if theta_2:
-                    theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key], alpha)
+                    theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key], this_alpha)
                else:
-                    theta_0[key] = theta_func(theta_0[key], theta_1[key], None, alpha)
+                    theta_0[key] = theta_func(theta_0[key], theta_1[key], None, this_alpha)

            del theta_1
            del theta_2
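
For reference, the static interpolation methods selected via interp above (weighted_sum, sigmoid, inv_sigmoid, add_difference) sit outside the shown hunks. A minimal sketch of the two most common modes, assuming the standard AUTOMATIC1111-style merge math the README cites (not the PR's verbatim code):

def weighted_sum(theta0, theta1, theta2, alpha):
    # Linear blend of two checkpoints; theta2 is unused in this mode.
    return (1 - alpha) * theta0 + alpha * theta1

def add_difference(theta0, theta1, theta2, alpha):
    # Add the (theta1 - theta2) delta onto theta0, scaled by alpha; per the
    # README, this is the only mode that uses all three checkpoints.
    return theta0 + (theta1 - theta2) * alpha

With this_alpha substituted per key, block-weighted merging simply varies alpha across the 25 unet slots while the formula itself stays fixed.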