diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 0c05f0ef7ffa..7535caba586c 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -329,6 +329,8 @@
title: Latent upscaler
- local: api/pipelines/stable_diffusion/upscale
title: Super-resolution
+ - local: api/pipelines/stable_diffusion/k_diffusion
+ title: K-Diffusion
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
title: LDM3D Text-to-(RGB, Depth), Text-to-(RGB-pano, Depth-pano), LDM3D Upscaler
- local: api/pipelines/stable_diffusion/adapter
diff --git a/docs/source/en/api/pipelines/stable_diffusion/k_diffusion.md b/docs/source/en/api/pipelines/stable_diffusion/k_diffusion.md
new file mode 100644
index 000000000000..2dd6a8dc9c53
--- /dev/null
+++ b/docs/source/en/api/pipelines/stable_diffusion/k_diffusion.md
@@ -0,0 +1,27 @@
+
+
+# K-Diffusion
+
+[k-diffusion](https://github.com/crowsonkb/k-diffusion) is a popular library created by [Katherine Crowson](https://github.com/crowsonkb/). We provide `StableDiffusionKDiffusionPipeline` and `StableDiffusionXLKDiffusionPipeline` that allow you to run Stable DIffusion with samplers from k-diffusion.
+
+Note that most the samplers from k-diffusion are implemented in Diffusers and we recommend using existing schedulers. You can find a mapping between k-diffusion samplers and schedulers in Diffusers [here](https://huggingface.co/docs/diffusers/api/schedulers/overview)
+
+
+## StableDiffusionKDiffusionPipeline
+
+[[autodoc]] StableDiffusionKDiffusionPipeline
+
+
+## StableDiffusionXLKDiffusionPipeline
+
+[[autodoc]] StableDiffusionXLKDiffusionPipeline
\ No newline at end of file
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 180b210953c1..012845a6594a 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -316,7 +316,7 @@
]
else:
- _import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline"])
+ _import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline"])
try:
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
@@ -668,7 +668,7 @@
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
else:
- from .pipelines import StableDiffusionKDiffusionPipeline
+ from .pipelines import StableDiffusionKDiffusionPipeline, StableDiffusionXLKDiffusionPipeline
try:
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 2b456f4c3d08..13163f96e059 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -265,7 +265,10 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
else:
- _import_structure["stable_diffusion_k_diffusion"] = ["StableDiffusionKDiffusionPipeline"]
+ _import_structure["stable_diffusion_k_diffusion"] = [
+ "StableDiffusionKDiffusionPipeline",
+ "StableDiffusionXLKDiffusionPipeline",
+ ]
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
@@ -491,7 +494,10 @@
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
else:
- from .stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline
+ from .stable_diffusion_k_diffusion import (
+ StableDiffusionKDiffusionPipeline,
+ StableDiffusionXLKDiffusionPipeline,
+ )
try:
if not is_flax_available():
diff --git a/src/diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py
index 6c4bd0047f02..7eb5bf8c229b 100644
--- a/src/diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py
+++ b/src/diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py
@@ -30,6 +30,7 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
else:
_import_structure["pipeline_stable_diffusion_k_diffusion"] = ["StableDiffusionKDiffusionPipeline"]
+ _import_structure["pipeline_stable_diffusion_xl_k_diffusion"] = ["StableDiffusionXLKDiffusionPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -45,6 +46,7 @@
from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
else:
from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline
+ from .pipeline_stable_diffusion_xl_k_diffusion import StableDiffusionXLKDiffusionPipeline
else:
import sys
diff --git a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
index 53e5a34a3b33..71b2ef27e07d 100755
--- a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
@@ -134,7 +134,15 @@ def __init__(
def set_scheduler(self, scheduler_type: str):
library = importlib.import_module("k_diffusion")
sampling = getattr(library, "sampling")
- self.sampler = getattr(sampling, scheduler_type)
+ try:
+ self.sampler = getattr(sampling, scheduler_type)
+ except Exception:
+ valid_samplers = []
+ for s in dir(sampling):
+ if "sample_" in s:
+ valid_samplers.append(s)
+
+ raise ValueError(f"Invalid scheduler type {scheduler_type}. Please choose one of {valid_samplers}.")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
diff --git a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py
new file mode 100644
index 000000000000..39f7f6f7a4f0
--- /dev/null
+++ b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py
@@ -0,0 +1,1012 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import inspect
+from typing import List, Optional, Tuple, Union
+
+import torch
+from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
+from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras
+from transformers import (
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+)
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import (
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models.attention_processor import (
+ AttnProcessor2_0,
+ FusedAttnProcessor2_0,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ XFormersAttnProcessor,
+)
+from ...models.lora import adjust_lora_scale_text_encoder
+from ...schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler
+from ...utils import (
+ USE_PEFT_BACKEND,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionXLKDiffusionPipeline
+
+ >>> pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+ >>> pipe.set_scheduler("sample_dpmpp_2m_sde")
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.ModelWrapper
+class ModelWrapper:
+ def __init__(self, model, alphas_cumprod):
+ self.model = model
+ self.alphas_cumprod = alphas_cumprod
+
+ def apply_model(self, *args, **kwargs):
+ if len(args) == 3:
+ encoder_hidden_states = args[-1]
+ args = args[:2]
+ if kwargs.get("cond", None) is not None:
+ encoder_hidden_states = kwargs.pop("cond")
+ return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample
+
+
+class StableDiffusionXLKDiffusionPipeline(
+ DiffusionPipeline,
+ FromSingleFileMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+ IPAdapterMixin,
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion XL and k-diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([` CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+ `stabilityai/stable-diffusion-xl-base-1-0`.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
+ _optional_components = [
+ "tokenizer",
+ "tokenizer_2",
+ "text_encoder",
+ "text_encoder_2",
+ "feature_extractor",
+ ]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ force_zeros_for_empty_prompt: bool = True,
+ ):
+ super().__init__()
+
+ # get correct sigmas from LMS
+ scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=scheduler,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ self.default_sample_size = self.unet.config.sample_size
+
+ model = ModelWrapper(unet, scheduler.alphas_cumprod)
+ if scheduler.config.prediction_type == "v_prediction":
+ self.k_diffusion_model = CompVisVDenoiser(model)
+ else:
+ self.k_diffusion_model = CompVisDenoiser(model)
+
+ # Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.set_scheduler
+ def set_scheduler(self, scheduler_type: str):
+ library = importlib.import_module("k_diffusion")
+ sampling = getattr(library, "sampling")
+ try:
+ self.sampler = getattr(sampling, scheduler_type)
+ except Exception:
+ valid_samplers = []
+ for s in dir(sampling):
+ if "sample_" in s:
+ valid_samplers.append(s)
+
+ raise ValueError(f"Invalid scheduler type {scheduler_type}. Please choose one of {valid_samplers}.")
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: procecss multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if self.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if self.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ return latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
+ def _get_add_time_ids(
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
+ ):
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ return add_time_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ FusedAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stages where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ if not hasattr(self, "unet"):
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
+ def disable_freeu(self):
+ """Disables the FreeU mechanism if enabled."""
+ self.unet.disable_freeu()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.fuse_qkv_projections
+ def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ Args:
+ unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+ """
+ self.fusing_unet = False
+ self.fusing_vae = False
+
+ if unet:
+ self.fusing_unet = True
+ self.unet.fuse_qkv_projections()
+ self.unet.set_attn_processor(FusedAttnProcessor2_0())
+
+ if vae:
+ if not isinstance(self.vae, AutoencoderKL):
+ raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
+
+ self.fusing_vae = True
+ self.vae.fuse_qkv_projections()
+ self.vae.set_attn_processor(FusedAttnProcessor2_0())
+
+ def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+ """Disable QKV projection fusion if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ Args:
+ unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+
+ """
+ if unet:
+ if not self.fusing_unet:
+ logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.unet.unfuse_qkv_projections()
+ self.fusing_unet = False
+
+ if vae:
+ if not self.fusing_vae:
+ logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.vae.unfuse_qkv_projections()
+ self.fusing_vae = False
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ use_karras_sigmas: Optional[bool] = False,
+ noise_sampler_seed: Optional[int] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a target image resolution. It should be as same
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ # 0. Default height and width to unet
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ )
+
+ if guidance_scale <= 1.0:
+ raise ValueError("has to use guidance_scale")
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ lora_scale = None
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device)
+
+ # 5. Prepare sigmas
+ if use_karras_sigmas:
+ sigma_min: float = self.k_diffusion_model.sigmas[0].item()
+ sigma_max: float = self.k_diffusion_model.sigmas[-1].item()
+ sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max)
+ sigmas = sigmas.to(device)
+ else:
+ sigmas = self.scheduler.sigmas
+ sigmas = sigmas.to(prompt_embeds.dtype)
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ latents = latents * sigmas[0]
+
+ self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
+ self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ if negative_original_size is not None and negative_target_size is not None:
+ negative_add_time_ids = self._get_add_time_ids(
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ else:
+ negative_add_time_ids = add_time_ids
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ # 8. Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ # 9. Define model function
+ def model_fn(x, t):
+ latent_model_input = torch.cat([x] * 2)
+ t = torch.cat([t] * 2)
+
+ noise_pred = self.k_diffusion_model(
+ latent_model_input,
+ t,
+ cond=prompt_embeds,
+ timestep_cond=timestep_cond,
+ added_cond_kwargs=added_cond_kwargs,
+ )
+
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ return noise_pred
+
+ # 10. Run k-diffusion solver
+ sampler_kwargs = {}
+
+ if "noise_sampler" in inspect.signature(self.sampler).parameters:
+ min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max()
+ noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed)
+ sampler_kwargs["noise_sampler"] = noise_sampler
+
+ if "generator" in inspect.signature(self.sampler).parameters:
+ sampler_kwargs["generator"] = generator
+
+ latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs)
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+
+ if not output_type == "latent":
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py
index 56836f0b6d77..2ab00c54ce2d 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py
@@ -15,3 +15,18 @@ def from_config(cls, *args, **kwargs):
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers", "k_diffusion"])
+
+
+class StableDiffusionXLKDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers", "k_diffusion"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers", "k_diffusion"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "k_diffusion"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "k_diffusion"])
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_k_diffusion.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_k_diffusion.py
new file mode 100644
index 000000000000..158385db0205
--- /dev/null
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_k_diffusion.py
@@ -0,0 +1,138 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import numpy as np
+import torch
+
+from diffusers import StableDiffusionXLKDiffusionPipeline
+from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device
+
+
+enable_full_determinism()
+
+
+@slow
+@require_torch_gpu
+class StableDiffusionXLKPipelineIntegrationTests(unittest.TestCase):
+ dtype = torch.float16
+
+ def tearDown(self):
+ # clean up the VRAM after each test
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_stable_diffusion_xl(self):
+ sd_pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=self.dtype
+ )
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ sd_pipe.set_scheduler("sample_euler")
+
+ prompt = "A painting of a squirrel eating a burger"
+ generator = torch.manual_seed(0)
+ output = sd_pipe(
+ [prompt],
+ generator=generator,
+ guidance_scale=9.0,
+ num_inference_steps=20,
+ height=512,
+ width=512,
+ output_type="np",
+ )
+
+ image = output.images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 512, 512, 3)
+ expected_slice = np.array(
+ [0.79804534, 0.7981539, 0.8019961, 0.7936565, 0.7892033, 0.7914713, 0.7792827, 0.77754563, 0.7836789]
+ )
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ def test_stable_diffusion_karras_sigmas(self):
+ sd_pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=self.dtype
+ )
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ sd_pipe.set_scheduler("sample_dpmpp_2m")
+
+ prompt = "A painting of a squirrel eating a burger"
+ generator = torch.manual_seed(0)
+ output = sd_pipe(
+ [prompt],
+ generator=generator,
+ guidance_scale=7.5,
+ num_inference_steps=15,
+ output_type="np",
+ use_karras_sigmas=True,
+ height=512,
+ width=512,
+ )
+
+ image = output.images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 512, 512, 3)
+ expected_slice = np.array(
+ [0.9704869, 0.9714559, 0.9693254, 0.96892524, 0.9685236, 0.9659081, 0.9666761, 0.9619067, 0.961759]
+ )
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ def test_stable_diffusion_noise_sampler_seed(self):
+ sd_pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=self.dtype
+ )
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ sd_pipe.set_scheduler("sample_dpmpp_sde")
+
+ prompt = "A painting of a squirrel eating a burger"
+ seed = 0
+ images1 = sd_pipe(
+ [prompt],
+ generator=torch.manual_seed(seed),
+ noise_sampler_seed=seed,
+ guidance_scale=9.0,
+ num_inference_steps=20,
+ output_type="np",
+ height=512,
+ width=512,
+ ).images
+ images2 = sd_pipe(
+ [prompt],
+ generator=torch.manual_seed(seed),
+ noise_sampler_seed=seed,
+ guidance_scale=9.0,
+ num_inference_steps=20,
+ output_type="np",
+ height=512,
+ width=512,
+ ).images
+ assert images1.shape == (1, 512, 512, 3)
+ assert images2.shape == (1, 512, 512, 3)
+ assert np.abs(images1.flatten() - images2.flatten()).max() < 1e-2