diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index 32f47bb23fdb..ff8c6b64cba3 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -81,6 +81,87 @@
     "timestep_spacing": "leading",
 }
 
+
+STABLE_CASCADE_DEFAULT_CONFIGS = {
+    "stage_c": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior"},
+    "stage_c_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior_lite"},
+    "stage_b": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder"},
+    "stage_b_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder_lite"},
+}
+
+
+def convert_stable_cascade_unet_single_file_to_diffusers(original_state_dict):
+    is_stage_c = "clip_txt_mapper.weight" in original_state_dict
+
+    if is_stage_c:
+        state_dict = {}
+        for key in original_state_dict.keys():
+            if key.endswith("in_proj_weight"):
+                weights = original_state_dict[key].chunk(3, 0)
+                state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
+                state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
+                state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
+            elif key.endswith("in_proj_bias"):
+                weights = original_state_dict[key].chunk(3, 0)
+                state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
+                state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
+                state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
+            elif key.endswith("out_proj.weight"):
+                weights = original_state_dict[key]
+                state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
+            elif key.endswith("out_proj.bias"):
+                weights = original_state_dict[key]
+                state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
+            else:
+                state_dict[key] = original_state_dict[key]
+    else:
+        state_dict = {}
+        for key in original_state_dict.keys():
+            if key.endswith("in_proj_weight"):
+                weights = original_state_dict[key].chunk(3, 0)
+                state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
+                state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
+                state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
+            elif key.endswith("in_proj_bias"):
+                weights = original_state_dict[key].chunk(3, 0)
+                state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
+                state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
+                state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
+            elif key.endswith("out_proj.weight"):
+                weights = original_state_dict[key]
+                state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
+            elif key.endswith("out_proj.bias"):
+                weights = original_state_dict[key]
+                state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
+            # rename clip_mapper to clip_txt_pooled_mapper
+            elif key.endswith("clip_mapper.weight"):
+                weights = original_state_dict[key]
+                state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
+            elif key.endswith("clip_mapper.bias"):
+                weights = original_state_dict[key]
+                state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
+            else:
+                state_dict[key] = original_state_dict[key]
+
+    return state_dict
+
+
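+# A note on the heuristic used by `infer_stable_cascade_single_file_config` below:
+# "clip_txt_mapper.weight" only exists in stage C (prior) checkpoints, while
+# "down_blocks.1.0.channelwise.0.weight" identifies stage B (decoder) checkpoints.
+# The lite variants are then told apart by layer shapes observed in the released
+# checkpoints: 1536 vs. 2048 `clip_txt_mapper` output rows for stage C, and a
+# channelwise width of 576 vs. 640 for stage B.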
+def infer_stable_cascade_single_file_config(checkpoint):
+    is_stage_c = "clip_txt_mapper.weight" in checkpoint
+    is_stage_b = "down_blocks.1.0.channelwise.0.weight" in checkpoint
+
+    if is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 1536):
+        config_type = "stage_c_lite"
+    elif is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 2048):
+        config_type = "stage_c"
+    elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 576:
+        config_type = "stage_b_lite"
+    elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 640:
+        config_type = "stage_b"
+    else:
+        raise ValueError("Unable to infer a Stable Cascade config type from the given checkpoint.")
+
+    return STABLE_CASCADE_DEFAULT_CONFIGS[config_type]
+
+
 DIFFUSERS_TO_LDM_MAPPING = {
     "unet": {
         "layers": {
@@ -229,10 +310,34 @@ def fetch_ldm_config_and_checkpoint(
     cache_dir=None,
     local_files_only=None,
     revision=None,
+):
+    checkpoint = load_single_file_model_checkpoint(
+        pretrained_model_link_or_path,
+        resume_download=resume_download,
+        force_download=force_download,
+        proxies=proxies,
+        token=token,
+        cache_dir=cache_dir,
+        local_files_only=local_files_only,
+        revision=revision,
+    )
+    original_config = fetch_original_config(class_name, checkpoint, original_config_file)
+
+    return original_config, checkpoint
+
+
+def load_single_file_model_checkpoint(
+    pretrained_model_link_or_path,
+    resume_download=False,
+    force_download=False,
+    proxies=None,
+    token=None,
+    cache_dir=None,
+    local_files_only=None,
+    revision=None,
 ):
     if os.path.isfile(pretrained_model_link_or_path):
         checkpoint = load_state_dict(pretrained_model_link_or_path)
-
     else:
         repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
         checkpoint_path = _get_model_file(
@@ -252,9 +357,7 @@ def fetch_ldm_config_and_checkpoint(
     while "state_dict" in checkpoint:
         checkpoint = checkpoint["state_dict"]
 
-    original_config = fetch_original_config(class_name, checkpoint, original_config_file)
-
-    return original_config, checkpoint
+    return checkpoint
 
 
 def infer_original_config_file(class_name, checkpoint):
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index e9f2cb2ed1f1..f89e004261f2 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -42,6 +42,11 @@
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from .single_file_utils import (
+    convert_stable_cascade_unet_single_file_to_diffusers,
+    infer_stable_cascade_single_file_config,
+    load_single_file_model_checkpoint,
+)
 from .utils import AttnProcsLayers
 
 
@@ -896,3 +901,103 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
         self.config.encoder_hid_dim_type = "ip_image_proj"
 
         self.to(dtype=self.dtype, device=self.device)
+
+
+class FromOriginalUNetMixin:
+    """
+    Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`StableCascadeUNet`].
+    """
+
+    @classmethod
+    @validate_hf_hub_args
+    def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
+        r"""
+        Instantiate a [`StableCascadeUNet`] from pretrained weights saved in the original `.ckpt` or
+        `.safetensors` format. The model is set in evaluation mode (`model.eval()`) by default.
+
+        Parameters:
+            pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+                    - A link to the `.ckpt` file (for example
+                      `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
+                    - A path to a *file* containing all model weights.
+            config (`dict`, *optional*):
+                Dictionary containing the configuration of the model. If not provided, the configuration is
+                inferred from the checkpoint.
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed,
+                the dtype is automatically derived from the model's weights.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding
+                the cached versions if they exist.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard
+                cache is not used.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to resume downloading the model weights and configuration files. If set to `False`,
+                any incompletely downloaded files are deleted.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the
+                model won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any
+                identifier allowed by Git.
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to overwrite load and saveable variables of the model.
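+
+        Example (a minimal usage sketch; the URL is the stage C checkpoint exercised in the tests below):
+
+        ```py
+        from diffusers import StableCascadeUNet
+
+        url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_c_bf16.safetensors"
+        unet = StableCascadeUNet.from_single_file(url)
+        ```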
+ + """ + config = kwargs.pop("config", None) + resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + cache_dir = kwargs.pop("cache_dir", None) + local_files_only = kwargs.pop("local_files_only", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + + class_name = cls.__name__ + if class_name != "StableCascadeUNet": + raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet") + + checkpoint = load_single_file_model_checkpoint( + pretrained_model_link_or_path, + resume_download=resume_download, + force_download=force_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) + + if config is None: + config = infer_stable_cascade_single_file_config(checkpoint) + model_config = cls.load_config(**config, **kwargs) + else: + model_config = config + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + model = cls.from_config(model_config, **kwargs) + + diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint) + if is_accelerate_available(): + unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) + if len(unexpected_keys) > 0: + logger.warn( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + + else: + model.load_state_dict(diffusers_format_checkpoint) + + if torch_dtype is not None: + model.to(torch_dtype) + + return model diff --git a/src/diffusers/models/unets/unet_stable_cascade.py b/src/diffusers/models/unets/unet_stable_cascade.py index b8833779ba9f..9f81e50241a9 100644 --- a/src/diffusers/models/unets/unet_stable_cascade.py +++ b/src/diffusers/models/unets/unet_stable_cascade.py @@ -21,6 +21,7 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders.unet import FromOriginalUNetMixin from ...utils import BaseOutput from ..attention_processor import Attention from ..modeling_utils import ModelMixin @@ -134,7 +135,7 @@ class StableCascadeUNetOutput(BaseOutput): sample: torch.FloatTensor = None -class StableCascadeUNet(ModelMixin, ConfigMixin): +class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin): _supports_gradient_checkpointing = True @register_to_config diff --git a/tests/models/unets/test_models_unet_stable_cascade.py b/tests/models/unets/test_models_unet_stable_cascade.py new file mode 100644 index 000000000000..fad1dcf2448c --- /dev/null +++ b/tests/models/unets/test_models_unet_stable_cascade.py @@ -0,0 +1,195 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import gc
+import unittest
+
+import torch
+
+from diffusers import StableCascadeUNet
+from diffusers.utils import logging
+from diffusers.utils.testing_utils import (
+    enable_full_determinism,
+    numpy_cosine_similarity_distance,
+    require_torch_gpu,
+    slow,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__)
+
+enable_full_determinism()
+
+
+@slow
+class StableCascadeUNetModelSlowTests(unittest.TestCase):
+    def tearDown(self) -> None:
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_stable_cascade_unet_prior_single_file_components(self):
+        single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_c_bf16.safetensors"
+        single_file_unet = StableCascadeUNet.from_single_file(single_file_url)
+
+        single_file_unet_config = single_file_unet.config
+        del single_file_unet
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        unet = StableCascadeUNet.from_pretrained(
+            "stabilityai/stable-cascade-prior", subfolder="prior", revision="refs/pr/2", variant="bf16"
+        )
+        unet_config = unet.config
+        del unet
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
+        for param_name, param_value in single_file_unet_config.items():
+            if param_name in PARAMS_TO_IGNORE:
+                continue
+
+            assert unet_config[param_name] == param_value
+
+    def test_stable_cascade_unet_decoder_single_file_components(self):
+        single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_bf16.safetensors"
+        single_file_unet = StableCascadeUNet.from_single_file(single_file_url)
+
+        single_file_unet_config = single_file_unet.config
+        del single_file_unet
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        unet = StableCascadeUNet.from_pretrained(
+            "stabilityai/stable-cascade", subfolder="decoder", revision="refs/pr/44", variant="bf16"
+        )
+        unet_config = unet.config
+        del unet
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
+        for param_name, param_value in single_file_unet_config.items():
+            if param_name in PARAMS_TO_IGNORE:
+                continue
+
+            assert unet_config[param_name] == param_value
+
+    def test_stable_cascade_unet_config_loading(self):
+        config = StableCascadeUNet.load_config(
+            pretrained_model_name_or_path="diffusers/stable-cascade-configs", subfolder="prior"
+        )
+        single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_c_bf16.safetensors"
+
+        single_file_unet = StableCascadeUNet.from_single_file(single_file_url, config=config)
+        single_file_unet_config = single_file_unet.config
+        del single_file_unet
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
+        for param_name, param_value in config.items():
+            if param_name in PARAMS_TO_IGNORE:
+                continue
+
+            assert single_file_unet_config[param_name] == param_value
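+
+    # The forward-pass tests below need a GPU: they run identical random inputs
+    # through both loading paths and compare outputs via cosine-similarity distance.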
+    @require_torch_gpu
+    def test_stable_cascade_unet_single_file_prior_forward_pass(self):
+        dtype = torch.bfloat16
+        generator = torch.Generator("cpu")
+
+        model_inputs = {
+            "sample": randn_tensor((1, 16, 24, 24), generator=generator.manual_seed(0)).to("cuda", dtype),
+            "timestep_ratio": torch.tensor([1]).to("cuda", dtype),
+            "clip_text_pooled": randn_tensor((1, 1, 1280), generator=generator.manual_seed(0)).to("cuda", dtype),
+            "clip_text": randn_tensor((1, 77, 1280), generator=generator.manual_seed(0)).to("cuda", dtype),
+            "clip_img": randn_tensor((1, 1, 768), generator=generator.manual_seed(0)).to("cuda", dtype),
+            "pixels": randn_tensor((1, 3, 8, 8), generator=generator.manual_seed(0)).to("cuda", dtype),
+        }
+
+        unet = StableCascadeUNet.from_pretrained(
+            "stabilityai/stable-cascade-prior",
+            subfolder="prior",
+            revision="refs/pr/2",
+            variant="bf16",
+            torch_dtype=dtype,
+        )
+        unet.to("cuda")
+        with torch.no_grad():
+            prior_output = unet(**model_inputs).sample.float().cpu().numpy()
+
+        # Remove the UNet from GPU memory before loading the single file UNet model
+        del unet
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_c_bf16.safetensors"
+        single_file_unet = StableCascadeUNet.from_single_file(single_file_url, torch_dtype=dtype)
+        single_file_unet.to("cuda")
+        with torch.no_grad():
+            prior_single_file_output = single_file_unet(**model_inputs).sample.float().cpu().numpy()
+
+        # Free the single file UNet from GPU memory after the forward pass
+        del single_file_unet
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        max_diff = numpy_cosine_similarity_distance(prior_output.flatten(), prior_single_file_output.flatten())
+        assert max_diff < 8e-3
+
+    @require_torch_gpu
+    def test_stable_cascade_unet_single_file_decoder_forward_pass(self):
+        dtype = torch.float32
+        generator = torch.Generator("cpu")
+
+        model_inputs = {
+            "sample": randn_tensor((1, 4, 256, 256), generator=generator.manual_seed(0)).to("cuda", dtype),
+            "timestep_ratio": torch.tensor([1]).to("cuda", dtype),
+            "clip_text": randn_tensor((1, 77, 1280), generator=generator.manual_seed(0)).to("cuda", dtype),
+            "clip_text_pooled": randn_tensor((1, 1, 1280), generator=generator.manual_seed(0)).to("cuda", dtype),
+            "pixels": randn_tensor((1, 3, 8, 8), generator=generator.manual_seed(0)).to("cuda", dtype),
+        }
+
+        unet = StableCascadeUNet.from_pretrained(
+            "stabilityai/stable-cascade",
+            subfolder="decoder",
+            revision="refs/pr/44",
+            torch_dtype=dtype,
+        )
+        unet.to("cuda")
+        with torch.no_grad():
+            decoder_output = unet(**model_inputs).sample.float().cpu().numpy()
+
+        # Remove the UNet from GPU memory before loading the single file UNet model
+        del unet
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b.safetensors"
+        single_file_unet = StableCascadeUNet.from_single_file(single_file_url, torch_dtype=dtype)
+        single_file_unet.to("cuda")
+        with torch.no_grad():
+            decoder_single_file_output = single_file_unet(**model_inputs).sample.float().cpu().numpy()
+
+        # Free the single file UNet from GPU memory after the forward pass
+        del single_file_unet
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        max_diff = numpy_cosine_similarity_distance(decoder_output.flatten(), decoder_single_file_output.flatten())
+        assert max_diff < 1e-4