Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 49 additions & 4 deletions src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,6 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
class_embed_type = "projection"
assert "adm_in_channels" in unet_params
projection_class_embeddings_input_dim = unet_params.adm_in_channels
else:
raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")

config = {
"sample_size": image_size // vae_scale_factor,
Expand All @@ -323,6 +321,12 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
"transformer_layers_per_block": transformer_layers_per_block,
}

if "disable_self_attentions" in unet_params:
config["only_cross_attention"] = unet_params.disable_self_attentions

if "num_classes" in unet_params and type(unet_params.num_classes) == int:
config["num_class_embeds"] = unet_params.num_classes

if controlnet:
config["conditioning_channels"] = unet_params.hint_channels
else:
Expand Down Expand Up @@ -441,6 +445,10 @@ def convert_ldm_unet_checkpoint(
new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]

# Relevant to StableDiffusionUpscalePipeline
if "num_class_embeds" in config:
new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]

new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

Expand Down Expand Up @@ -496,6 +504,7 @@ def convert_ldm_unet_checkpoint(

if len(attentions):
paths = renew_attention_paths(attentions)

meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
Expand Down Expand Up @@ -1210,6 +1219,7 @@ def download_from_original_stable_diffusion_ckpt(
StableDiffusionControlNetPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
StableDiffusionXLImg2ImgPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
Expand Down Expand Up @@ -1256,6 +1266,8 @@ def download_from_original_stable_diffusion_ckpt(
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
is_upscale = pipeline_class == StableDiffusionUpscalePipeline

config_url = None

# model_type = "v1"
Expand Down Expand Up @@ -1285,6 +1297,10 @@ def download_from_original_stable_diffusion_ckpt(
original_config_file = config_files["xl_refiner"]
else:
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"

if is_upscale:
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"

if config_url is not None:
original_config_file = BytesIO(requests.get(config_url).content)

Expand All @@ -1308,6 +1324,8 @@ def download_from_original_stable_diffusion_ckpt(

if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
num_in_channels = 9
if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline:
num_in_channels = 7
elif num_in_channels is None:
num_in_channels = 4

Expand Down Expand Up @@ -1391,9 +1409,13 @@ def download_from_original_stable_diffusion_ckpt(
else:
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")

if pipeline_class == StableDiffusionUpscalePipeline:
image_size = original_config.model.params.unet_config.params.image_size

# Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet_config["upcast_attention"] = upcast_attention

path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
checkpoint, unet_config, path=path, extract_ema=extract_ema
Expand Down Expand Up @@ -1458,8 +1480,29 @@ def download_from_original_stable_diffusion_ckpt(
controlnet=controlnet,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
if hasattr(pipe, "requires_safety_checker"):
pipe.requires_safety_checker = False

elif pipeline_class == StableDiffusionUpscalePipeline:
scheduler = DDIMScheduler.from_pretrained(
"stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler"
)
low_res_scheduler = DDPMScheduler.from_pretrained(
"stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler"
)

pipe = pipeline_class(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
low_res_scheduler=low_res_scheduler,
safety_checker=None,
feature_extractor=None,
)

else:
pipe = pipeline_class(
vae=vae,
Expand All @@ -1469,8 +1512,10 @@ def download_from_original_stable_diffusion_ckpt(
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
if hasattr(pipe, "requires_safety_checker"):
pipe.requires_safety_checker = False

else:
image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
original_config, clip_stats_path=clip_stats_path, device=device
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from transformers import CLIPTextModel, CLIPTokenizer

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import EulerDiscreteScheduler
from ...utils import deprecate, logging
Expand Down Expand Up @@ -59,7 +60,7 @@ def preprocess(image):
return image


class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, FromSingleFileMixin):
r"""
Pipeline for upscaling Stable Diffusion output image resolution by a factor of 2.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
Expand Down Expand Up @@ -67,7 +67,9 @@ def preprocess(image):
return image


class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
class StableDiffusionUpscalePipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-guided image super-resolution using Stable Diffusion 2.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
floats_tensor,
load_image,
load_numpy,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
Expand Down Expand Up @@ -479,3 +480,36 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
mem_bytes = torch.cuda.max_memory_allocated()
# make sure that less than 2.9 GB is allocated
assert mem_bytes < 2.9 * 10**9

def test_download_ckpt_diff_format_is_same(self):
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/sd2-upscale/low_res_cat.png"
)

prompt = "a cat sitting on a park bench"
model_id = "stabilityai/stable-diffusion-x4-upscaler"
pipe = StableDiffusionUpscalePipeline.from_pretrained(model_id)
pipe.enable_model_cpu_offload()

generator = torch.Generator("cpu").manual_seed(0)
output = pipe(prompt=prompt, image=image, generator=generator, output_type="np", num_inference_steps=3)
image_from_pretrained = output.images[0]

single_file_path = (
"https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/blob/main/x4-upscaler-ema.safetensors"
)
pipe_from_single_file = StableDiffusionUpscalePipeline.from_single_file(single_file_path)
pipe_from_single_file.enable_model_cpu_offload()

generator = torch.Generator("cpu").manual_seed(0)
output_from_single_file = pipe_from_single_file(
prompt=prompt, image=image, generator=generator, output_type="np", num_inference_steps=3
)
image_from_single_file = output_from_single_file.images[0]

assert image_from_pretrained.shape == (512, 512, 3)
assert image_from_single_file.shape == (512, 512, 3)
assert (
numpy_cosine_similarity_distance(image_from_pretrained.flatten(), image_from_single_file.flatten()) < 1e-3
)