From f03ea106813ea930b6ed51c3423e0765fcc591a0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 14 Mar 2024 15:33:40 +0530 Subject: [PATCH 1/2] refactor unet single file loading a bit. --- src/diffusers/loaders/__init__.py | 4 +- src/diffusers/loaders/single_file_utils.py | 55 +++++++++++ src/diffusers/loaders/unet.py | 92 ++++++++++--------- .../models/unets/unet_2d_condition.py | 6 +- 4 files changed, 111 insertions(+), 46 deletions(-) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 4da047435d8e..bd2f10e786a9 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -57,7 +57,7 @@ def text_encoder_attn_modules(text_encoder): _import_structure["autoencoder"] = ["FromOriginalVAEMixin"] _import_structure["controlnet"] = ["FromOriginalControlNetMixin"] - _import_structure["unet"] = ["UNet2DConditionLoadersMixin"] + _import_structure["unet"] = ["FromOriginalUNetMixin", "UNet2DConditionLoadersMixin"] _import_structure["utils"] = ["AttnProcsLayers"] if is_transformers_available(): _import_structure["single_file"] = ["FromSingleFileMixin"] @@ -72,7 +72,7 @@ def text_encoder_attn_modules(text_encoder): if is_torch_available(): from .autoencoder import FromOriginalVAEMixin from .controlnet import FromOriginalControlNetMixin - from .unet import UNet2DConditionLoadersMixin + from .unet import FromOriginalUNetMixin, UNet2DConditionLoadersMixin from .utils import AttnProcsLayers if is_transformers_available(): diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index cdaa0802a2fa..610f1dfc5914 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -50,6 +50,8 @@ if is_accelerate_available(): from accelerate import init_empty_weights + from ..models.modeling_utils import load_model_dict_into_meta + logger = logging.get_logger(__name__) # pylint: disable=invalid-name CONFIG_URLS = { @@ -1274,6 +1276,59 @@ 
def create_text_encoder_from_open_clip_checkpoint( return text_model +def create_diffusers_unet_from_stable_cascade( + cls, + pretrained_model_link_or_path, + config, + resume_download, + force_download, + proxies, + token, + cache_dir, + local_files_only, + revision, + torch_dtype, + **kwargs, +): + checkpoint = load_single_file_model_checkpoint( + pretrained_model_link_or_path, + resume_download=resume_download, + force_download=force_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) + + if config is None: + config = infer_stable_cascade_single_file_config(checkpoint) + model_config = cls.load_config(**config, **kwargs) + else: + model_config = config + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + model = cls.from_config(model_config, **kwargs) + + diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint) + + if is_accelerate_available(): + unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + + else: + model.load_state_dict(diffusers_format_checkpoint) + + if torch_dtype is not None: + model.to(torch_dtype) + + return model + + def create_diffusers_unet_model_from_ldm( pipeline_class_name, original_config, diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 0a9544d0dbbe..6b36e1baa7ed 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -43,9 +43,9 @@ set_weights_and_activate_adapters, ) from .single_file_utils import ( - convert_stable_cascade_unet_single_file_to_diffusers, - infer_stable_cascade_single_file_config, - load_single_file_model_checkpoint, + create_diffusers_unet_from_stable_cascade, +
create_diffusers_unet_model_from_ldm, + fetch_ldm_config_and_checkpoint, ) from .utils import AttnProcsLayers @@ -66,6 +66,8 @@ CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin" CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors" +COMPATIBLE_SINGLE_FILE_CLASSES = ["StableCascadeUNet", "UNet2DConditionModel"] + class UNet2DConditionLoadersMixin: """ @@ -912,8 +914,9 @@ class FromOriginalUNetMixin: @validate_hf_hub_args def from_single_file(cls, pretrained_model_link_or_path, **kwargs): r""" - Instantiate a [`StableCascadeUNet`] from pretrained StableCascadeUNet weights saved in the original `.ckpt` or - `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default. + Instantiate a UNet from pretrained weights saved in the original `.ckpt`, `.bin`, or + `.safetensors` format. The model is set in evaluation mode (`model.eval()`) by default. + Currently supported checkpoints: StableCascade, SDXL, SD, Playground v2.5, etc. 
Parameters: pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): @@ -952,8 +955,10 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): """ class_name = cls.__name__ - if class_name != "StableCascadeUNet": - raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet") + if class_name not in COMPATIBLE_SINGLE_FILE_CLASSES: + raise ValueError( + f"FromOriginalUNetMixin is currently only compatible with {', '.join(COMPATIBLE_SINGLE_FILE_CLASSES)}" + ) config = kwargs.pop("config", None) resume_download = kwargs.pop("resume_download", False) @@ -965,39 +970,42 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): revision = kwargs.pop("revision", None) torch_dtype = kwargs.pop("torch_dtype", None) - checkpoint = load_single_file_model_checkpoint( - pretrained_model_link_or_path, - resume_download=resume_download, - force_download=force_download, - proxies=proxies, - token=token, - cache_dir=cache_dir, - local_files_only=local_files_only, - revision=revision, - ) - - if config is None: - config = infer_stable_cascade_single_file_config(checkpoint) - model_config = cls.load_config(**config, **kwargs) - else: - model_config = config - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - model = cls.from_config(model_config, **kwargs) - - diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint) - if is_accelerate_available(): - unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) - if len(unexpected_keys) > 0: - logger.warn( - f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" - ) - + if class_name == "StableCascadeUNet": + return create_diffusers_unet_from_stable_cascade( + cls, + pretrained_model_link_or_path, + config, + resume_download, + force_download, + proxies, + token, + cache_dir, + 
local_files_only, + revision, + torch_dtype, + **kwargs, + ) else: - model.load_state_dict(diffusers_format_checkpoint) - - if torch_dtype is not None: - model.to(torch_dtype) - - return model + original_config, checkpoint = fetch_ldm_config_and_checkpoint( + pretrained_model_link_or_path=pretrained_model_link_or_path, + class_name=kwargs.get("pipeline_class_name", None), + original_config_file=kwargs.get("original_config_file", None), + resume_download=resume_download, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + local_files_only=local_files_only, + cache_dir=cache_dir, + ) + return create_diffusers_unet_model_from_ldm( + pipeline_class_name=kwargs.get("pipeline_class_name", None), + original_config=original_config, + checkpoint=checkpoint, + num_in_channels=kwargs.get("num_in_channels", 4), + upcast_attention=kwargs.get("upcast_attention", None), + extract_ema=kwargs.get("extract_ema", False), + image_size=kwargs.get("image_size", None), + torch_dtype=torch_dtype, + model_type=kwargs.pop("model_type", None), + ) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 9f69b03462dc..81c06f66f325 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -19,7 +19,7 @@ import torch.utils.checkpoint from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin +from ...loaders import FromOriginalUNetMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers from ..activations import get_activation from ..attention_processor import ( @@ -66,7 +66,9 @@ class UNet2DConditionOutput(BaseOutput): sample: torch.FloatTensor = None -class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
+class UNet2DConditionModel( + ModelMixin, ConfigMixin, FromOriginalUNetMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin +): r""" A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample shaped output. From bfaa0d8ab45e1fbf497abfbaa56155120bbe4ea9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 14 Mar 2024 15:52:44 +0530 Subject: [PATCH 2/2] retrieve the unet from create_diffusers_unet_model_from_ldm --- src/diffusers/loaders/unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 6b36e1baa7ed..bcaf95416e9a 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -1008,4 +1008,4 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): image_size=kwargs.get("image_size", None), torch_dtype=torch_dtype, model_type=kwargs.pop("model_type", None), - ) + )["unet"]