From d4b7c8108c4b1a7fa94ed30240680dabccf15dbf Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Sun, 10 Mar 2024 14:13:43 +0000 Subject: [PATCH 1/6] update --- src/diffusers/loaders/single_file_utils.py | 33 ++++++ .../models/unets/unet_stable_cascade.py | 109 +++++++++++++++++- 2 files changed, 140 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 085c3c12cdd5..d4ad4a551b3f 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -257,6 +257,39 @@ def fetch_ldm_config_and_checkpoint( return original_config, checkpoint +def load_single_file_model_checkpoint(pretrained_model_link_or_path, **kwargs): + resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + cache_dir = kwargs.pop("cache_dir", None) + local_files_only = kwargs.pop("local_files_only", None) + revision = kwargs.pop("revision", None) + + if os.path.isfile(pretrained_model_link_or_path): + checkpoint = load_state_dict(pretrained_model_link_or_path) + else: + repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path) + checkpoint_path = _get_model_file( + repo_id, + weights_name=weights_name, + force_download=force_download, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + ) + checkpoint = load_state_dict(checkpoint_path) + + # some checkpoints contain the model state dict under a "state_dict" key + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + return checkpoint + + def infer_original_config_file(class_name, checkpoint): if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024: config_url = CONFIG_URLS["v2"] diff --git 
a/src/diffusers/models/unets/unet_stable_cascade.py b/src/diffusers/models/unets/unet_stable_cascade.py index b8833779ba9f..8296ff52eded 100644 --- a/src/diffusers/models/unets/unet_stable_cascade.py +++ b/src/diffusers/models/unets/unet_stable_cascade.py @@ -13,6 +13,7 @@ # limitations under the License. import math +from contextlib import nullcontext from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -21,9 +22,94 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import BaseOutput +from ...loaders.single_file_utils import load_single_file_model_checkpoint +from ...utils import BaseOutput, is_accelerate_available from ..attention_processor import Attention -from ..modeling_utils import ModelMixin +from ..modeling_utils import ModelMixin, load_model_dict_into_meta + + +if is_accelerate_available(): + from accelerate import init_empty_weights + + +DEFAULT_CONFIGS = { + "stage_c": {"pretrained_model_name_or_path": "stabilityai/stable-cascade-prior", "subfolder": "prior"}, + "stage_c_lite": {"pretrained_model_name_or_path": "stabilityai/stable-cascade-prior", "subfolder": "prior_lite"}, + "stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"}, + "stage_b_lite": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder_lite"}, +} + + +def convert_single_file_to_diffusers(original_state_dict): + is_stage_c = "clip_txt_mapper.weight" in original_state_dict + + if is_stage_c: + state_dict = {} + for key in original_state_dict.keys(): + if key.endswith("in_proj_weight"): + weights = original_state_dict[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0] + state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1] + state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2] + elif key.endswith("in_proj_bias"): + weights = 
original_state_dict[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0] + state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1] + state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2] + elif key.endswith("out_proj.weight"): + weights = original_state_dict[key] + state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights + elif key.endswith("out_proj.bias"): + weights = original_state_dict[key] + state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights + else: + state_dict[key] = original_state_dict[key] + else: + state_dict = {} + for key in original_state_dict.keys(): + if key.endswith("in_proj_weight"): + weights = original_state_dict[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0] + state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1] + state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2] + elif key.endswith("in_proj_bias"): + weights = original_state_dict[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0] + state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1] + state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2] + elif key.endswith("out_proj.weight"): + weights = original_state_dict[key] + state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights + elif key.endswith("out_proj.bias"): + weights = original_state_dict[key] + state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights + # rename clip_mapper to clip_txt_pooled_mapper + elif key.endswith("clip_mapper.weight"): + weights = original_state_dict[key] + state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights + elif key.endswith("clip_mapper.bias"): + weights = original_state_dict[key] + state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights + 
else: + state_dict[key] = original_state_dict[key] + + return state_dict + + +def infer_single_file_config(checkpoint): + is_stage_c = "clip_txt_mapper.weight" in checkpoint + is_stage_b = "down_blocks.1.0.channelwise.0.weight" in checkpoint + + if is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 1536): + config_type = "stage_c_lite" + elif is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 2048): + config_type = "stage_c" + elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 576: + config_type = "stage_b_lite" + elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 640: + config_type = "stage_b" + + return DEFAULT_CONFIGS[config_type] # Copied from diffusers.pipelines.wuerstchen.modeling_wuerstchen_common.WuerstchenLayerNorm with WuerstchenLayerNorm -> SDCascadeLayerNorm @@ -387,6 +473,25 @@ def get_block(block_type, in_channels, nhead, c_skip=0, dropout=0, self_attn=Tru self.gradient_checkpointing = False + @classmethod + def from_single_file(cls, pretrained_model_link_or_path, **kwargs): + config = kwargs.pop("config", None) + torch_dtype = kwargs.pop("torch_dtype", None) + + checkpoint = load_single_file_model_checkpoint(pretrained_model_link_or_path, **kwargs) + if config is None: + config = infer_single_file_config(checkpoint) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + model_config = cls.load_config(**config, **kwargs) + model = cls.from_config(model_config, **kwargs) + + diffusers_format_checkpoint = convert_single_file_to_diffusers(checkpoint) + load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) + + return model + def _set_gradient_checkpointing(self, value=False): self.gradient_checkpointing = value From c30c00a1d154ba598158e04f0d3e47a8e568aedb Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 11 Mar 2024 05:54:58 +0000 Subject: [PATCH 2/6] update --- 
src/diffusers/models/unets/unet_stable_cascade.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/unets/unet_stable_cascade.py b/src/diffusers/models/unets/unet_stable_cascade.py index 8296ff52eded..5031fd94a34f 100644 --- a/src/diffusers/models/unets/unet_stable_cascade.py +++ b/src/diffusers/models/unets/unet_stable_cascade.py @@ -33,10 +33,10 @@ DEFAULT_CONFIGS = { - "stage_c": {"pretrained_model_name_or_path": "stabilityai/stable-cascade-prior", "subfolder": "prior"}, - "stage_c_lite": {"pretrained_model_name_or_path": "stabilityai/stable-cascade-prior", "subfolder": "prior_lite"}, - "stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"}, - "stage_b_lite": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder_lite"}, + "stage_c": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior"}, + "stage_c_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior_lite"}, + "stage_b": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder"}, + "stage_b_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder_lite"}, } @@ -481,10 +481,12 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): checkpoint = load_single_file_model_checkpoint(pretrained_model_link_or_path, **kwargs) if config is None: config = infer_single_file_config(checkpoint) + model_config = cls.load_config(**config, **kwargs) + else: + model_config = config ctx = init_empty_weights if is_accelerate_available() else nullcontext with ctx(): - model_config = cls.load_config(**config, **kwargs) model = cls.from_config(model_config, **kwargs) diffusers_format_checkpoint = convert_single_file_to_diffusers(checkpoint) From 8cd0025b674bc969feac836727c5cb20bb62d27d Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 11 
Mar 2024 07:09:12 +0000 Subject: [PATCH 3/6] update --- .../unets/test_models_unet_stable_cascade.py | 158 ++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 tests/models/unets/test_models_unet_stable_cascade.py diff --git a/tests/models/unets/test_models_unet_stable_cascade.py b/tests/models/unets/test_models_unet_stable_cascade.py new file mode 100644 index 000000000000..51945201fbd2 --- /dev/null +++ b/tests/models/unets/test_models_unet_stable_cascade.py @@ -0,0 +1,158 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gc +import unittest + +import torch + +from diffusers import StableCascadeUNet +from diffusers.utils import logging +from diffusers.utils.testing_utils import ( + enable_full_determinism, + numpy_cosine_similarity_distance, + require_torch_gpu, + slow, +) +from diffusers.utils.torch_utils import randn_tensor + + +logger = logging.get_logger(__name__) + +enable_full_determinism() + + +@slow +class StableCascadeUNetModelSlowTests(unittest.TestCase): + def tearDown(self) -> None: + gc.collect() + torch.cuda.empty_cache() + super().tearDown() + + def test_stable_cascade_prior_unet_single_file_components(self): + single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_c_bf16.safetensors" + single_file_unet = StableCascadeUNet.from_single_file(single_file_url) + + unet = StableCascadeUNet.from_pretrained( + "stabilityai/stable-cascade-prior", subfolder="prior", revision="refs/pr/2", variant="bf16" + ) + PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"] + for param_name, param_value in single_file_unet.config.items(): + if param_name in PARAMS_TO_IGNORE: + continue + + assert unet.config[param_name] == param_value + + def test_stable_cascade_config_loading(self): + config = StableCascadeUNet.load_config( + pretrained_model_name_or_path="diffusers/stable-cascade-configs", subfolder="prior" + ) + single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_c_bf16.safetensors" + + single_file_unet = StableCascadeUNet.from_single_file(single_file_url, config=config) + single_file_unet_config = single_file_unet.config + + PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"] + for param_name, param_value in config.items(): + if param_name in PARAMS_TO_IGNORE: + continue + + assert single_file_unet_config[param_name] == param_value + + @require_torch_gpu + def test_stable_cascade_unet_single_file_prior_forward_pass(self): + dtype = torch.bfloat16 + generator = 
torch.Generator("cpu") + + model_inputs = { + "sample": randn_tensor((1, 16, 24, 24), generator=generator.manual_seed(0)).to("cuda", dtype), + "timestep_ratio": torch.tensor([1]).to("cuda", dtype), + "clip_text_pooled": randn_tensor((1, 1, 1280), generator=generator.manual_seed(0)).to("cuda", dtype), + "clip_text": randn_tensor((1, 77, 1280), generator=generator.manual_seed(0)).to("cuda", dtype), + "clip_img": randn_tensor((1, 1, 768), generator=generator.manual_seed(0)).to("cuda", dtype), + "pixels": randn_tensor((1, 3, 8, 8), generator=generator.manual_seed(0)).to("cuda", dtype), + } + + unet = StableCascadeUNet.from_pretrained( + "stabilityai/stable-cascade-prior", + subfolder="prior", + revision="refs/pr/2", + variant="bf16", + torch_dtype=dtype, + ) + unet.to("cuda") + with torch.no_grad(): + prior_output = unet(**model_inputs).sample.float().cpu().numpy() + + # Remove UNet from GPU memory before loading the single file UNet model + del unet + gc.collect() + torch.cuda.empty_cache() + + single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_c_bf16.safetensors" + single_file_unet = StableCascadeUNet.from_single_file(single_file_url, torch_dtype=dtype) + single_file_unet.to("cuda") + with torch.no_grad(): + prior_single_file_output = single_file_unet(**model_inputs).sample.float().cpu().numpy() + + # Remove UNet from GPU memory before loading the single file UNet model + del single_file_unet + gc.collect() + torch.cuda.empty_cache() + + max_diff = numpy_cosine_similarity_distance(prior_output.flatten(), prior_single_file_output.flatten()) + assert max_diff < 8e-3 + + @require_torch_gpu + def test_stable_cascade_unet_single_file_decoder_forward_pass(self): + dtype = torch.float32 + generator = torch.Generator("cpu") + + model_inputs = { + "sample": randn_tensor((1, 4, 256, 256), generator=generator.manual_seed(0)).to("cuda", dtype), + "timestep_ratio": torch.tensor([1]).to("cuda", dtype), + "clip_text": randn_tensor((1, 77, 1280), 
generator=generator.manual_seed(0)).to("cuda", dtype), + "clip_text_pooled": randn_tensor((1, 1, 1280), generator=generator.manual_seed(0)).to("cuda", dtype), + "pixels": randn_tensor((1, 3, 8, 8), generator=generator.manual_seed(0)).to("cuda", dtype), + } + + unet = StableCascadeUNet.from_pretrained( + "stabilityai/stable-cascade", + subfolder="decoder", + revision="refs/pr/44", + torch_dtype=dtype, + ) + unet.to("cuda") + with torch.no_grad(): + prior_output = unet(**model_inputs).sample.float().cpu().numpy() + + # Remove UNet from GPU memory before loading the single file UNet model + del unet + gc.collect() + torch.cuda.empty_cache() + + single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b.safetensors" + single_file_unet = StableCascadeUNet.from_single_file(single_file_url, torch_dtype=dtype) + single_file_unet.to("cuda") + with torch.no_grad(): + prior_single_file_output = single_file_unet(**model_inputs).sample.float().cpu().numpy() + + # Remove UNet from GPU memory before loading the single file UNet model + del single_file_unet + gc.collect() + torch.cuda.empty_cache() + + max_diff = numpy_cosine_similarity_distance(prior_output.flatten(), prior_single_file_output.flatten()) + assert max_diff < 1e-4 From 117dd8e75a80b0e59aa3d30f3bba38fa4debfeb8 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 11 Mar 2024 12:52:58 +0000 Subject: [PATCH 4/6] update --- src/diffusers/loaders/single_file_utils.py | 51 ++++++++----------- .../models/unets/unet_stable_cascade.py | 19 ++++++- .../unets/test_models_unet_stable_cascade.py | 47 +++++++++++++++-- 3 files changed, 80 insertions(+), 37 deletions(-) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index d4ad4a551b3f..f6ed6ad700a2 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -230,42 +230,31 @@ def fetch_ldm_config_and_checkpoint( local_files_only=None, revision=None, ): - 
if os.path.isfile(pretrained_model_link_or_path): - checkpoint = load_state_dict(pretrained_model_link_or_path) - - else: - repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path) - checkpoint_path = _get_model_file( - repo_id, - weights_name=weights_name, - force_download=force_download, - cache_dir=cache_dir, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - ) - checkpoint = load_state_dict(checkpoint_path) - - # some checkpoints contain the model state dict under a "state_dict" key - while "state_dict" in checkpoint: - checkpoint = checkpoint["state_dict"] - + checkpoint = load_single_file_model_checkpoint( + pretrained_model_link_or_path, + resume_download=resume_download, + force_download=force_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) original_config = fetch_original_config(class_name, checkpoint, original_config_file) return original_config, checkpoint -def load_single_file_model_checkpoint(pretrained_model_link_or_path, **kwargs): - resume_download = kwargs.pop("resume_download", False) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - token = kwargs.pop("token", None) - cache_dir = kwargs.pop("cache_dir", None) - local_files_only = kwargs.pop("local_files_only", None) - revision = kwargs.pop("revision", None) - +def load_single_file_model_checkpoint( + pretrained_model_link_or_path, + resume_download=False, + force_download=False, + proxies=None, + token=None, + cache_dir=None, + local_files_only=None, + revision=None, +): if os.path.isfile(pretrained_model_link_or_path): checkpoint = load_state_dict(pretrained_model_link_or_path) else: diff --git a/src/diffusers/models/unets/unet_stable_cascade.py b/src/diffusers/models/unets/unet_stable_cascade.py index 5031fd94a34f..c5b53e5128ed 100644 --- 
a/src/diffusers/models/unets/unet_stable_cascade.py +++ b/src/diffusers/models/unets/unet_stable_cascade.py @@ -476,9 +476,26 @@ def get_block(block_type, in_channels, nhead, c_skip=0, dropout=0, self_attn=Tru @classmethod def from_single_file(cls, pretrained_model_link_or_path, **kwargs): config = kwargs.pop("config", None) + resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + cache_dir = kwargs.pop("cache_dir", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) torch_dtype = kwargs.pop("torch_dtype", None) - checkpoint = load_single_file_model_checkpoint(pretrained_model_link_or_path, **kwargs) + checkpoint = load_single_file_model_checkpoint( + pretrained_model_link_or_path, + resume_download=resume_download, + force_download=force_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) + if config is None: config = infer_single_file_config(checkpoint) model_config = cls.load_config(**config, **kwargs) diff --git a/tests/models/unets/test_models_unet_stable_cascade.py b/tests/models/unets/test_models_unet_stable_cascade.py index 51945201fbd2..fad1dcf2448c 100644 --- a/tests/models/unets/test_models_unet_stable_cascade.py +++ b/tests/models/unets/test_models_unet_stable_cascade.py @@ -37,25 +37,59 @@ @slow class StableCascadeUNetModelSlowTests(unittest.TestCase): def tearDown(self) -> None: + super().tearDown() gc.collect() torch.cuda.empty_cache() - super().tearDown() - def test_stable_cascade_prior_unet_single_file_components(self): + def test_stable_cascade_unet_prior_single_file_components(self): single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_c_bf16.safetensors" single_file_unet = StableCascadeUNet.from_single_file(single_file_url) + single_file_unet_config = 
single_file_unet.config + del single_file_unet + gc.collect() + torch.cuda.empty_cache() + unet = StableCascadeUNet.from_pretrained( "stabilityai/stable-cascade-prior", subfolder="prior", revision="refs/pr/2", variant="bf16" ) + unet_config = unet.config + del unet + gc.collect() + torch.cuda.empty_cache() + + PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"] + for param_name, param_value in single_file_unet_config.items(): + if param_name in PARAMS_TO_IGNORE: + continue + + assert unet_config[param_name] == param_value + + def test_stable_cascade_unet_decoder_single_file_components(self): + single_file_url = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_bf16.safetensors" + single_file_unet = StableCascadeUNet.from_single_file(single_file_url) + + single_file_unet_config = single_file_unet.config + del single_file_unet + gc.collect() + torch.cuda.empty_cache() + + unet = StableCascadeUNet.from_pretrained( + "stabilityai/stable-cascade", subfolder="decoder", revision="refs/pr/44", variant="bf16" + ) + unet_config = unet.config + del unet + gc.collect() + torch.cuda.empty_cache() + PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"] - for param_name, param_value in single_file_unet.config.items(): + for param_name, param_value in single_file_unet_config.items(): if param_name in PARAMS_TO_IGNORE: continue - assert unet.config[param_name] == param_value + assert unet_config[param_name] == param_value - def test_stable_cascade_config_loading(self): + def test_stable_cascade_unet_config_loading(self): config = StableCascadeUNet.load_config( pretrained_model_name_or_path="diffusers/stable-cascade-configs", subfolder="prior" ) @@ -63,6 +97,9 @@ def test_stable_cascade_config_loading(self): single_file_unet = StableCascadeUNet.from_single_file(single_file_url, config=config) single_file_unet_config = single_file_unet.config + del single_file_unet + gc.collect() + torch.cuda.empty_cache() 
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"] for param_name, param_value in config.items(): From bf06b93a3541e7d6954966449e4c668272d9285a Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 11 Mar 2024 13:30:18 +0000 Subject: [PATCH 5/6] update --- src/diffusers/models/unets/unet_stable_cascade.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unets/unet_stable_cascade.py b/src/diffusers/models/unets/unet_stable_cascade.py index c5b53e5128ed..6e839cdd6c56 100644 --- a/src/diffusers/models/unets/unet_stable_cascade.py +++ b/src/diffusers/models/unets/unet_stable_cascade.py @@ -23,11 +23,13 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders.single_file_utils import load_single_file_model_checkpoint -from ...utils import BaseOutput, is_accelerate_available +from ...utils import BaseOutput, is_accelerate_available, logging from ..attention_processor import Attention from ..modeling_utils import ModelMixin, load_model_dict_into_meta +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + if is_accelerate_available(): from accelerate import init_empty_weights @@ -507,7 +509,15 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): model = cls.from_config(model_config, **kwargs) diffusers_format_checkpoint = convert_single_file_to_diffusers(checkpoint) - load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) + if is_accelerate_available(): + unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) + if len(unexpected_keys) > 0: + logger.warn( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + + else: + model.load_state_dict(diffusers_format_checkpoint) return model From dca175beac8260afbcc7e3c8513c99b0e8b169d3 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 12 Mar 2024 
11:14:41 +0000 Subject: [PATCH 6/6] update --- src/diffusers/loaders/single_file_utils.py | 81 ++++++++++ src/diffusers/loaders/unet.py | 105 +++++++++++++ .../models/unets/unet_stable_cascade.py | 141 +----------------- 3 files changed, 190 insertions(+), 137 deletions(-) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index e1ab7fe8a8a7..cd0483b863be 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -81,6 +81,87 @@ "timestep_spacing": "leading", } + +STABLE_CASCADE_DEFAULT_CONFIGS = { + "stage_c": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior"}, + "stage_c_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior_lite"}, + "stage_b": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder"}, + "stage_b_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder_lite"}, +} + + +def convert_stable_cascade_unet_single_file_to_diffusers(original_state_dict): + is_stage_c = "clip_txt_mapper.weight" in original_state_dict + + if is_stage_c: + state_dict = {} + for key in original_state_dict.keys(): + if key.endswith("in_proj_weight"): + weights = original_state_dict[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0] + state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1] + state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2] + elif key.endswith("in_proj_bias"): + weights = original_state_dict[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0] + state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1] + state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2] + elif key.endswith("out_proj.weight"): + weights = original_state_dict[key] + 
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights + elif key.endswith("out_proj.bias"): + weights = original_state_dict[key] + state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights + else: + state_dict[key] = original_state_dict[key] + else: + state_dict = {} + for key in original_state_dict.keys(): + if key.endswith("in_proj_weight"): + weights = original_state_dict[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0] + state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1] + state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2] + elif key.endswith("in_proj_bias"): + weights = original_state_dict[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0] + state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1] + state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2] + elif key.endswith("out_proj.weight"): + weights = original_state_dict[key] + state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights + elif key.endswith("out_proj.bias"): + weights = original_state_dict[key] + state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights + # rename clip_mapper to clip_txt_pooled_mapper + elif key.endswith("clip_mapper.weight"): + weights = original_state_dict[key] + state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights + elif key.endswith("clip_mapper.bias"): + weights = original_state_dict[key] + state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights + else: + state_dict[key] = original_state_dict[key] + + return state_dict + + +def infer_stable_cascade_single_file_config(checkpoint): + is_stage_c = "clip_txt_mapper.weight" in checkpoint + is_stage_b = "down_blocks.1.0.channelwise.0.weight" in checkpoint + + if is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 1536): + 
+ config_type = "stage_c_lite" + elif is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 2048): + config_type = "stage_c" + elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 576: + config_type = "stage_b_lite" + elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 640: + config_type = "stage_b" + + return STABLE_CASCADE_DEFAULT_CONFIGS[config_type] + + DIFFUSERS_TO_LDM_MAPPING = { "unet": { "layers": { diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 9d8e2666c518..9445f147a610 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -42,6 +42,11 @@ set_adapter_layers, set_weights_and_activate_adapters, ) +from .single_file_utils import ( + convert_stable_cascade_unet_single_file_to_diffusers, + infer_stable_cascade_single_file_config, + load_single_file_model_checkpoint, +) from .utils import AttnProcsLayers @@ -896,3 +901,103 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False): self.config.encoder_hid_dim_type = "ip_image_proj" self.to(dtype=self.dtype, device=self.device) + + +class FromOriginalUNetMixin: + """ + Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`StableCascadeUNet`]. + """ + + @classmethod + @validate_hf_hub_args + def from_single_file(cls, pretrained_model_link_or_path, **kwargs): + r""" + Instantiate a [`StableCascadeUNet`] from pretrained UNet weights saved in the original `.ckpt` or + `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + - A link to the `.ckpt` file (for example + `"https://huggingface.co/<repo_id>/blob/main/<checkpoint_name>.ckpt"`) on the Hub. + - A path to a *file* containing all pipeline weights. 
+ config: (`dict`, *optional*): + Dictionary containing the configuration of the model: + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to True, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables of the model. 
+ + """ + config = kwargs.pop("config", None) + resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + cache_dir = kwargs.pop("cache_dir", None) + local_files_only = kwargs.pop("local_files_only", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + + class_name = cls.__name__ + if class_name != "StableCascadeUNet": + raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet") + + checkpoint = load_single_file_model_checkpoint( + pretrained_model_link_or_path, + resume_download=resume_download, + force_download=force_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) + + if config is None: + config = infer_stable_cascade_single_file_config(checkpoint) + model_config = cls.load_config(**config, **kwargs) + else: + model_config = config + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + model = cls.from_config(model_config, **kwargs) + + diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint) + if is_accelerate_available(): + unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) + if len(unexpected_keys) > 0: + logger.warn( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + + else: + model.load_state_dict(diffusers_format_checkpoint) + + if torch_dtype is not None: + model.to(torch_dtype) + + return model diff --git a/src/diffusers/models/unets/unet_stable_cascade.py b/src/diffusers/models/unets/unet_stable_cascade.py index 6e839cdd6c56..9f81e50241a9 100644 --- a/src/diffusers/models/unets/unet_stable_cascade.py +++ b/src/diffusers/models/unets/unet_stable_cascade.py @@ -13,7 +13,6 
@@ # limitations under the License. import math -from contextlib import nullcontext from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -22,96 +21,10 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders.single_file_utils import load_single_file_model_checkpoint -from ...utils import BaseOutput, is_accelerate_available, logging +from ...loaders.unet import FromOriginalUNetMixin +from ...utils import BaseOutput from ..attention_processor import Attention -from ..modeling_utils import ModelMixin, load_model_dict_into_meta - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -if is_accelerate_available(): - from accelerate import init_empty_weights - - -DEFAULT_CONFIGS = { - "stage_c": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior"}, - "stage_c_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior_lite"}, - "stage_b": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder"}, - "stage_b_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder_lite"}, -} - - -def convert_single_file_to_diffusers(original_state_dict): - is_stage_c = "clip_txt_mapper.weight" in original_state_dict - - if is_stage_c: - state_dict = {} - for key in original_state_dict.keys(): - if key.endswith("in_proj_weight"): - weights = original_state_dict[key].chunk(3, 0) - state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0] - state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1] - state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2] - elif key.endswith("in_proj_bias"): - weights = original_state_dict[key].chunk(3, 0) - state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0] - state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1] - 
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2] - elif key.endswith("out_proj.weight"): - weights = original_state_dict[key] - state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights - elif key.endswith("out_proj.bias"): - weights = original_state_dict[key] - state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights - else: - state_dict[key] = original_state_dict[key] - else: - state_dict = {} - for key in original_state_dict.keys(): - if key.endswith("in_proj_weight"): - weights = original_state_dict[key].chunk(3, 0) - state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0] - state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1] - state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2] - elif key.endswith("in_proj_bias"): - weights = original_state_dict[key].chunk(3, 0) - state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0] - state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1] - state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2] - elif key.endswith("out_proj.weight"): - weights = original_state_dict[key] - state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights - elif key.endswith("out_proj.bias"): - weights = original_state_dict[key] - state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights - # rename clip_mapper to clip_txt_pooled_mapper - elif key.endswith("clip_mapper.weight"): - weights = original_state_dict[key] - state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights - elif key.endswith("clip_mapper.bias"): - weights = original_state_dict[key] - state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights - else: - state_dict[key] = original_state_dict[key] - - return state_dict - - -def infer_single_file_config(checkpoint): - is_stage_c = "clip_txt_mapper.weight" in checkpoint - is_stage_b 
= "down_blocks.1.0.channelwise.0.weight" in checkpoint - - if is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 1536): - config_type = "stage_c_lite" - elif is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 2048): - config_type = "stage_c" - elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 576: - config_type = "stage_b_lite" - elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 640: - config_type = "stage_b" - - return DEFAULT_CONFIGS[config_type] +from ..modeling_utils import ModelMixin # Copied from diffusers.pipelines.wuerstchen.modeling_wuerstchen_common.WuerstchenLayerNorm with WuerstchenLayerNorm -> SDCascadeLayerNorm @@ -222,7 +135,7 @@ class StableCascadeUNetOutput(BaseOutput): sample: torch.FloatTensor = None -class StableCascadeUNet(ModelMixin, ConfigMixin): +class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin): _supports_gradient_checkpointing = True @register_to_config @@ -475,52 +388,6 @@ def get_block(block_type, in_channels, nhead, c_skip=0, dropout=0, self_attn=Tru self.gradient_checkpointing = False - @classmethod - def from_single_file(cls, pretrained_model_link_or_path, **kwargs): - config = kwargs.pop("config", None) - resume_download = kwargs.pop("resume_download", False) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - token = kwargs.pop("token", None) - cache_dir = kwargs.pop("cache_dir", None) - local_files_only = kwargs.pop("local_files_only", False) - revision = kwargs.pop("revision", None) - torch_dtype = kwargs.pop("torch_dtype", None) - - checkpoint = load_single_file_model_checkpoint( - pretrained_model_link_or_path, - resume_download=resume_download, - force_download=force_download, - proxies=proxies, - token=token, - cache_dir=cache_dir, - local_files_only=local_files_only, - revision=revision, - ) - - if config is None: - config = infer_single_file_config(checkpoint) - 
model_config = cls.load_config(**config, **kwargs) - else: - model_config = config - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - model = cls.from_config(model_config, **kwargs) - - diffusers_format_checkpoint = convert_single_file_to_diffusers(checkpoint) - if is_accelerate_available(): - unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) - if len(unexpected_keys) > 0: - logger.warn( - f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" - ) - - else: - model.load_state_dict(diffusers_format_checkpoint) - - return model - def _set_gradient_checkpointing(self, value=False): self.gradient_checkpointing = value