@@ -477,18 +500,21 @@ Once you've generated the embeddings, pass them to the `prompt_embeds` (and `neg
```py
import torch
from diffusers import AutoPipelineForInpainting
+from diffusers.utils import make_image_grid
pipeline = AutoPipelineForInpainting.from_pretrained(
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16,
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove the following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
-image = pipeline(prompt_emebds=prompt_embeds, # generated from Compel
- negative_prompt_embeds, # generated from Compel
+image = pipeline(prompt_embeds=prompt_embeds, # generated from Compel
+ negative_prompt_embeds=negative_prompt_embeds, # generated from Compel
image=init_image,
mask_image=mask_image
).images[0]
+make_image_grid([init_image, mask_image, image], rows=1, cols=3)
```
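For reference, the `prompt_embeds` and `negative_prompt_embeds` tensors above come from Compel. A minimal sketch of generating them, assuming the `compel` package is installed (the prompts below just reuse this guide's example prompts):

```py
from compel import Compel

compel = Compel(tokenizer=pipeline.tokenizer, text_encoder=pipeline.text_encoder)
prompt_embeds = compel("concept art digital painting of an elven castle++")  # "++" up-weights the preceding word
negative_prompt_embeds = compel("bad architecture, deformed, disfigured, poor details")
```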
### ControlNet
@@ -501,7 +527,7 @@ For example, let's condition an image with a ControlNet pretrained on inpaint im
import torch
import numpy as np
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
-from diffusers.utils import load_image
+from diffusers.utils import load_image, make_image_grid
# load ControlNet
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16, variant="fp16")
@@ -511,11 +537,12 @@ pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove the following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# load base and mask image
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
# prepare control image
def make_inpaint_condition(init_image, mask_image):
@@ -536,7 +563,7 @@ Now generate an image from the base, mask and control images. You'll notice feat
```py
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, control_image=control_image).images[0]
-image
+make_image_grid([init_image, mask_image, PIL.Image.fromarray(np.uint8(control_image[0][0])).convert('RGB'), image], rows=2, cols=2)
```
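The grid call above relies on `PIL`, which isn't imported in this snippet, so make sure Pillow is available and imported before displaying the control image. A minimal sketch of just that conversion (`control_image` is the tensor returned by `make_inpaint_condition`):

```py
import PIL.Image
import numpy as np

control_image_pil = PIL.Image.fromarray(np.uint8(control_image[0][0])).convert("RGB")
```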
You can take this a step further and chain it with an image-to-image pipeline to apply a new [style](https://huggingface.co/nitrosocke/elden-ring-diffusion):
@@ -548,13 +575,14 @@ pipeline = AutoPipelineForImage2Image.from_pretrained(
"nitrosocke/elden-ring-diffusion", torch_dtype=torch.float16,
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove the following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
prompt = "elden ring style castle" # include the token "elden ring style" in the prompt
negative_prompt = "bad architecture, deformed, disfigured, poor details"
-image = pipeline(prompt, negative_prompt=negative_prompt, image=image).images[0]
-image
+image_elden_ring = pipeline(prompt, negative_prompt=negative_prompt, image=image).images[0]
+make_image_grid([init_image, mask_image, image, image_elden_ring], rows=2, cols=2)
```
@@ -576,17 +604,17 @@ image
It can be difficult and slow to run diffusion models if you're resource constrained, but it doesn't have to be with a few optimization tricks. One of the biggest (and easiest) optimizations you can enable is switching to memory-efficient attention. If you're using PyTorch 2.0, [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention) is automatically enabled and you don't need to do anything else. For non-PyTorch 2.0 users, you can install and use [xFormers](../optimization/xformers)'s implementation of memory-efficient attention. Both options reduce memory usage and accelerate inference.
-You can also offload the model to the GPU to save even more memory:
+You can also offload the model to the CPU to save even more memory:
```diff
+ pipeline.enable_xformers_memory_efficient_attention()
+ pipeline.enable_model_cpu_offload()
```
-To speed-up your inference code even more, use [`torch_compile`](../optimization/torch2.0#torch.compile). You should wrap `torch.compile` around the most intensive component in the pipeline which is typically the UNet:
+To speed up your inference code even more, use [`torch.compile`](../optimization/torch2.0#torchcompile). You should wrap `torch.compile` around the most intensive component in the pipeline, which is typically the UNet:
```py
-pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```
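If you're on PyTorch 2.0, a minimal end-to-end sketch of the `torch.compile` tip looks like this (the prompt is just a placeholder, and the first call is slower while the UNet graph is compiled):

```py
import torch
from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image

pipeline = AutoPipelineForInpainting.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
).to("cuda")
# compile the most intensive component once; subsequent calls reuse the compiled graph
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)

init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
image = pipeline(prompt="concept art of an elven castle, highly detailed", image=init_image, mask_image=mask_image).images[0]
```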
-Learn more in the [Reduce memory usage](../optimization/memory) and [Torch 2.0](../optimization/torch2.0) guides.
\ No newline at end of file
+Learn more in the [Reduce memory usage](../optimization/memory) and [Torch 2.0](../optimization/torch2.0) guides.
diff --git a/docs/source/en/using-diffusers/unconditional_image_generation.md b/docs/source/en/using-diffusers/unconditional_image_generation.md
index 3893f7cce276..c055bc75c5a4 100644
--- a/docs/source/en/using-diffusers/unconditional_image_generation.md
+++ b/docs/source/en/using-diffusers/unconditional_image_generation.md
@@ -23,16 +23,16 @@ You can use any of the 🧨 Diffusers [checkpoints](https://huggingface.co/model
-💡 Want to train your own unconditional image generation model? Take a look at the training [guide](training/unconditional_training) to learn how to generate your own images.
+💡 Want to train your own unconditional image generation model? Take a look at the training [guide](../training/unconditional_training) to learn how to generate your own images.
In this guide, you'll use [`DiffusionPipeline`] for unconditional image generation with [DDPM](https://arxiv.org/abs/2006.11239):
```python
->>> from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline
->>> generator = DiffusionPipeline.from_pretrained("anton-l/ddpm-butterflies-128", use_safetensors=True)
+generator = DiffusionPipeline.from_pretrained("anton-l/ddpm-butterflies-128", use_safetensors=True)
```
The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components.
@@ -40,13 +40,14 @@ Because the model consists of roughly 1.4 billion parameters, we strongly recomm
You can move the generator object to a GPU, just like you would in PyTorch:
```python
->>> generator.to("cuda")
+generator.to("cuda")
```
Now you can use the `generator` to generate an image:
```python
->>> image = generator().images[0]
+image = generator().images[0]
+image
```
The output is by default wrapped into a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object.
@@ -54,7 +55,7 @@ The output is by default wrapped into a [`PIL.Image`](https://pillow.readthedocs
You can save the image by calling:
```python
->>> image.save("generated_image.png")
+image.save("generated_image.png")
```
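The number of denoising steps is the main speed/quality trade-off; a short sketch, assuming the checkpoint loads as a `DDPMPipeline` whose call accepts `num_inference_steps` (the default is 1000):

```python
# fewer steps runs faster but tends to produce noisier butterflies
image = generator(num_inference_steps=100).images[0]
image.save("generated_image_100_steps.png")
```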
Try out the Spaces below, and feel free to play around with the inference steps parameter to see how it affects the image quality!
@@ -65,5 +66,3 @@ Try out the Spaces below, and feel free to play around with the inference steps
width="850"
height="500"
>
-
-
diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py
index d60fa19e8a7f..76fcb547b6f9 100644
--- a/examples/controlnet/train_controlnet.py
+++ b/examples/controlnet/train_controlnet.py
@@ -56,7 +56,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.22.0.dev0")
+check_min_version("0.23.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py
index 68162d7824ab..47483883824e 100644
--- a/examples/controlnet/train_controlnet_flax.py
+++ b/examples/controlnet/train_controlnet_flax.py
@@ -59,7 +59,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.22.0.dev0")
+check_min_version("0.23.0.dev0")
logger = logging.getLogger(__name__)
diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py
index 04290885cf4b..5f745966c9d4 100644
--- a/examples/controlnet/train_controlnet_sdxl.py
+++ b/examples/controlnet/train_controlnet_sdxl.py
@@ -58,7 +58,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.22.0.dev0")
+check_min_version("0.23.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py
index 4773446a615b..894fb39deeb8 100644
--- a/examples/custom_diffusion/train_custom_diffusion.py
+++ b/examples/custom_diffusion/train_custom_diffusion.py
@@ -62,7 +62,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.22.0.dev0")
+check_min_version("0.23.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 6d59ee4de383..88d05be4561d 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -61,7 +61,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.22.0.dev0")
+check_min_version("0.23.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py
index a436d36cebfd..d2c0f8697baa 100644
--- a/examples/dreambooth/train_dreambooth_flax.py
+++ b/examples/dreambooth/train_dreambooth_flax.py
@@ -35,7 +35,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.22.0.dev0")
+check_min_version("0.23.0.dev0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 493430cadbdf..953d8e637d1e 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -68,7 +68,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.22.0.dev0")
+check_min_version("0.23.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index b729f7e1896d..e8dd6777f32c 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -58,7 +58,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.22.0.dev0")
+check_min_version("0.23.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py
index 5f8a2d9ee150..58baca312ce2 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py
@@ -52,7 +52,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.22.0.dev0")
+check_min_version("0.23.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
index e2d9b2105160..288404b4728c 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
@@ -55,7 +55,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.22.0.dev0")
+check_min_version("0.23.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
index 4ca95ecebea9..9ad01357c1f5 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
@@ -52,7 +52,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.23.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
index 19245724ecf5..472010320d73 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
@@ -46,7 +46,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.23.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
index 3656b480e9bb..a007d8c74b0c 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
@@ -46,7 +46,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.23.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
index d21eaf3dd0b0..799f9fbcb3ac 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
@@ -51,7 +51,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.23.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/t2i_adapter/train_t2i_adapter_sdxl.py b/examples/t2i_adapter/train_t2i_adapter_sdxl.py
index e23be2d754fe..783678cd346b 100644
--- a/examples/t2i_adapter/train_t2i_adapter_sdxl.py
+++ b/examples/t2i_adapter/train_t2i_adapter_sdxl.py
@@ -58,7 +58,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.22.0.dev0")
+check_min_version("0.23.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index e216529b2f54..89e154ef8825 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -53,7 +53,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.22.0.dev0")
+check_min_version("0.23.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py
index 63ea53c52a11..64b71b4f83ae 100644
--- a/examples/text_to_image/train_text_to_image_flax.py
+++ b/examples/text_to_image/train_text_to_image_flax.py
@@ -33,7 +33,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.22.0.dev0")
+check_min_version("0.23.0.dev0")
logger = logging.getLogger(__name__)
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index eac0f18f49f4..de4076a2ceaf 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -49,7 +49,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.22.0.dev0")
+check_min_version("0.23.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index 6fbeae8b1f93..f0d83d55e9bf 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -58,7 +58,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.22.0.dev0")
+check_min_version("0.23.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py
index 4a3048a0ba23..a385795b1a4f 100644
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -57,7 +57,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.22.0.dev0")
+check_min_version("0.23.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 01830751ffe2..55c907663249 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -79,7 +79,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.22.0.dev0")
+check_min_version("0.23.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py
index 224c1147be9f..938454eecb6e 100644
--- a/examples/textual_inversion/textual_inversion_flax.py
+++ b/examples/textual_inversion/textual_inversion_flax.py
@@ -56,7 +56,7 @@
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.22.0.dev0")
+check_min_version("0.23.0.dev0")
logger = logging.getLogger(__name__)
diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index 74b8ed106834..99c858778259 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -29,7 +29,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.22.0.dev0")
+check_min_version("0.23.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
index 5235fa99cfdd..48e5b96087de 100644
--- a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
@@ -50,7 +50,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.22.0")
+check_min_version("0.23.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py
index 92f63c93fc1a..b1e5abaaa278 100644
--- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py
+++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py
@@ -51,7 +51,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.22.0")
+check_min_version("0.23.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/scripts/convert_pixart_alpha_to_diffusers.py b/scripts/convert_pixart_alpha_to_diffusers.py
new file mode 100644
index 000000000000..fc037c87f5d5
--- /dev/null
+++ b/scripts/convert_pixart_alpha_to_diffusers.py
@@ -0,0 +1,198 @@
+import argparse
+import os
+
+import torch
+from transformers import T5EncoderModel, T5Tokenizer
+
+from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtAlphaPipeline, Transformer2DModel
+
+
+ckpt_id = "PixArt-alpha/PixArt-alpha"
+# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/scripts/inference.py#L125
+interpolation_scale = {512: 1, 1024: 2}
+
+
+def main(args):
+ all_state_dict = torch.load(args.orig_ckpt_path)
+ state_dict = all_state_dict.pop("state_dict")
+ converted_state_dict = {}
+
+ # Patch embeddings.
+ converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
+ converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")
+
+ # Caption projection.
+ converted_state_dict["caption_projection.y_embedding"] = state_dict.pop("y_embedder.y_embedding")
+ converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
+ converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
+ converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
+ converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
+
+ # AdaLN-single LN
+ converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
+ "t_embedder.mlp.0.weight"
+ )
+ converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
+ converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
+ "t_embedder.mlp.2.weight"
+ )
+ converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
+
+ if args.image_size == 1024:
+ # Resolution.
+ converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.weight"] = state_dict.pop(
+ "csize_embedder.mlp.0.weight"
+ )
+ converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.bias"] = state_dict.pop(
+ "csize_embedder.mlp.0.bias"
+ )
+ converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.weight"] = state_dict.pop(
+ "csize_embedder.mlp.2.weight"
+ )
+ converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.bias"] = state_dict.pop(
+ "csize_embedder.mlp.2.bias"
+ )
+ # Aspect ratio.
+ converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.weight"] = state_dict.pop(
+ "ar_embedder.mlp.0.weight"
+ )
+ converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.bias"] = state_dict.pop(
+ "ar_embedder.mlp.0.bias"
+ )
+ converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.weight"] = state_dict.pop(
+ "ar_embedder.mlp.2.weight"
+ )
+ converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.bias"] = state_dict.pop(
+ "ar_embedder.mlp.2.bias"
+ )
+ # Shared norm.
+ converted_state_dict["adaln_single.linear.weight"] = state_dict.pop("t_block.1.weight")
+ converted_state_dict["adaln_single.linear.bias"] = state_dict.pop("t_block.1.bias")
+
+ for depth in range(28):
+ # Transformer blocks.
+ converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
+ f"blocks.{depth}.scale_shift_table"
+ )
+
+ # Attention is all you need 🤘
+
+ # Self attention.
+ q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
+ q_bias, k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.bias"), 3, dim=0)
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias
+ # Projection.
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
+ f"blocks.{depth}.attn.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
+ f"blocks.{depth}.attn.proj.bias"
+ )
+
+ # Feed-forward.
+ converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict.pop(
+ f"blocks.{depth}.mlp.fc1.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict.pop(
+ f"blocks.{depth}.mlp.fc1.bias"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict.pop(
+ f"blocks.{depth}.mlp.fc2.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict.pop(
+ f"blocks.{depth}.mlp.fc2.bias"
+ )
+
+ # Cross-attention.
+ q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
+ q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
+ k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
+ k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)
+
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
+
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
+ f"blocks.{depth}.cross_attn.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
+ f"blocks.{depth}.cross_attn.proj.bias"
+ )
+
+ # Final block.
+ converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
+ converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
+ converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")
+
+ # DiT XL/2
+ transformer = Transformer2DModel(
+ sample_size=args.image_size // 8,
+ num_layers=28,
+ attention_head_dim=72,
+ in_channels=4,
+ out_channels=8,
+ patch_size=2,
+ attention_bias=True,
+ num_attention_heads=16,
+ cross_attention_dim=1152,
+ activation_fn="gelu-approximate",
+ num_embeds_ada_norm=1000,
+ norm_type="ada_norm_single",
+ norm_elementwise_affine=False,
+ norm_eps=1e-6,
+ caption_channels=4096,
+ )
+ transformer.load_state_dict(converted_state_dict, strict=True)
+
+ assert transformer.pos_embed.pos_embed is not None
+ state_dict.pop("pos_embed")
+ assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"
+
+ num_model_params = sum(p.numel() for p in transformer.parameters())
+ print(f"Total number of transformer parameters: {num_model_params}")
+
+ if args.only_transformer:
+ transformer.save_pretrained(os.path.join(args.dump_path, "transformer"))
+ else:
+ scheduler = DPMSolverMultistepScheduler()
+
+ vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="sd-vae-ft-ema")
+
+ tokenizer = T5Tokenizer.from_pretrained(ckpt_id, subfolder="t5-v1_1-xxl")
+ text_encoder = T5EncoderModel.from_pretrained(ckpt_id, subfolder="t5-v1_1-xxl")
+
+ pipeline = PixArtAlphaPipeline(
+ tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
+ )
+
+ pipeline.save_pretrained(args.dump_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
+ )
+ parser.add_argument(
+ "--image_size",
+ default=1024,
+ type=int,
+ choices=[512, 1024],
+ required=False,
+ help="Image size of pretrained model, either 512 or 1024.",
+ )
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
+ parser.add_argument("--only_transformer", default=True, type=bool, required=True)
+
+ args = parser.parse_args()
+ main(args)
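The script above is driven by its CLI flags: `--orig_ckpt_path`, `--image_size` (512 or 1024), `--dump_path`, and `--only_transformer` (note that argparse's `type=bool` treats any non-empty value, including the string `"False"`, as `True`). When a full pipeline is dumped rather than only the transformer, it can be reloaded like any other diffusers pipeline; a hedged sketch with a placeholder path and prompt:

```py
import torch
from diffusers import PixArtAlphaPipeline

pipe = PixArtAlphaPipeline.from_pretrained("path/to/dump_path", torch_dtype=torch.float16).to("cuda")
image = pipe("a small cactus with a happy face in the Sahara desert").images[0]
image.save("pixart_sample.png")
```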
diff --git a/setup.py b/setup.py
index 7ad5646d4fca..c2c8e75c24ae 100644
--- a/setup.py
+++ b/setup.py
@@ -244,7 +244,7 @@ def run(self):
setup(
name="diffusers",
- version="0.22.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+ version="0.23.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index c970128fdf16..4291e911ac74 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.22.0.dev0"
+__version__ = "0.23.0.dev0"
from typing import TYPE_CHECKING
@@ -235,6 +235,7 @@
"LDMTextToImagePipeline",
"MusicLDMPipeline",
"PaintByExamplePipeline",
+ "PixArtAlphaPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
@@ -579,6 +580,7 @@
LDMTextToImagePipeline,
MusicLDMPipeline,
PaintByExamplePipeline,
+ PixArtAlphaPipeline,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 87e0e164026f..2fa1c61fd809 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -2390,7 +2390,7 @@ def unfuse_text_encoder_lora(text_encoder):
def set_adapters_for_text_encoder(
self,
adapter_names: Union[List[str], str],
- text_encoder: Optional[PreTrainedModel] = None,
+ text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
text_encoder_weights: List[float] = None,
):
"""
@@ -2429,7 +2429,7 @@ def process_weights(adapter_names, weights):
)
set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
- def disable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
+ def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
"""
Disables the LoRA layers for the text encoder.
@@ -2446,7 +2446,7 @@ def disable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel]
raise ValueError("Text Encoder not found.")
set_adapter_layers(text_encoder, enabled=False)
- def enable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
+ def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
"""
Enables the LoRA layers for the text encoder.
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index cb2f24a52786..0c4c5de6e31a 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -117,7 +117,8 @@ def __init__(
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
- norm_type: str = "layer_norm",
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
+ norm_eps: float = 1e-5,
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
@@ -128,6 +129,8 @@ def __init__(
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+ self.use_layer_norm = norm_type == "layer_norm"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
@@ -152,7 +155,8 @@ def __init__(
elif self.use_ada_layer_norm_zero:
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
else:
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
@@ -171,7 +175,7 @@ def __init__(
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
- else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
)
self.attn2 = Attention(
query_dim=dim,
@@ -187,13 +191,19 @@ def __init__(
self.attn2 = None
# 3. Feed-forward
- self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+ if not self.use_ada_layer_norm_single:
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
+ # 5. Scale-shift for PixArt-Alpha.
+ if self.use_ada_layer_norm_single:
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
@@ -215,14 +225,25 @@ def forward(
) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
+ batch_size = hidden_states.shape[0]
+
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
- else:
+ elif self.use_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states)
+ elif self.use_ada_layer_norm_single:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+ norm_hidden_states = norm_hidden_states.squeeze(1)
+ else:
+ raise ValueError("Incorrect norm used")
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
@@ -242,19 +263,31 @@ def forward(
)
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
+ elif self.use_ada_layer_norm_single:
+ attn_output = gate_msa * attn_output
+
hidden_states = attn_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
# 2.5 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
- # 2.5 ends
# 3. Cross-Attention
if self.attn2 is not None:
- norm_hidden_states = (
- self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
- )
- if self.pos_embed is not None:
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm2(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
+ norm_hidden_states = self.norm2(hidden_states)
+ elif self.use_ada_layer_norm_single:
+ # For PixArt norm2 isn't applied here:
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+ norm_hidden_states = hidden_states
+ else:
+ raise ValueError("Incorrect norm")
+
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn2(
@@ -266,11 +299,16 @@ def forward(
hidden_states = attn_output + hidden_states
# 4. Feed-forward
- norm_hidden_states = self.norm3(hidden_states)
+ if not self.use_ada_layer_norm_single:
+ norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ if self.use_ada_layer_norm_single:
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
@@ -291,8 +329,12 @@ def forward(
if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
+ elif self.use_ada_layer_norm_single:
+ ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
return hidden_states
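To make the new `ada_norm_single` path concrete: the block keeps a learned `(6, dim)` `scale_shift_table`, adds it to the projected timestep embedding, and splits the sum into shift/scale/gate triples for the attention and feed-forward branches. A standalone shape sketch with toy sizes (not the real model dimensions):

```py
import torch

batch, seq_len, dim = 2, 16, 32
hidden_states = torch.randn(batch, seq_len, dim)
scale_shift_table = torch.randn(6, dim) / dim**0.5
timestep = torch.randn(batch, 6 * dim)  # stand-in for the AdaLayerNormSingle projection

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
    scale_shift_table[None] + timestep.reshape(batch, 6, -1)
).chunk(6, dim=1)  # each chunk is (batch, 1, dim)

norm = torch.nn.LayerNorm(dim, elementwise_affine=False)
modulated = norm(hidden_states) * (1 + scale_msa) + shift_msa  # modulate before attention
hidden_states = gate_msa * modulated + hidden_states           # gate the (stand-in) attention output
```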
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index efed305a0e96..1234dbd2d5ce 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -378,7 +378,7 @@ def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False)
_remove_lora (`bool`, *optional*, defaults to `False`):
Set to `True` to remove LoRA layers from the model.
"""
- if hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
+ if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
deprecate(
"set_processor to offload LoRA",
"0.26.0",
@@ -879,6 +879,9 @@ def __call__(
scale: float = 1.0,
) -> torch.Tensor:
residual = hidden_states
+
+ args = () if USE_PEFT_BACKEND else (scale,)
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
@@ -891,17 +894,17 @@ def __call__(
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
- query = attn.to_q(hidden_states, scale=scale)
+ query = attn.to_q(hidden_states, *args)
query = attn.head_to_batch_dim(query)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, scale=scale)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, scale=scale)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
if not attn.only_cross_attention:
- key = attn.to_k(hidden_states, scale=scale)
- value = attn.to_v(hidden_states, scale=scale)
+ key = attn.to_k(hidden_states, *args)
+ value = attn.to_v(hidden_states, *args)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
@@ -915,7 +918,7 @@ def __call__(
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
- hidden_states = attn.to_out[0](hidden_states, scale=scale)
+ hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -946,6 +949,9 @@ def __call__(
scale: float = 1.0,
) -> torch.Tensor:
residual = hidden_states
+
+ args = () if USE_PEFT_BACKEND else (scale,)
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
@@ -958,7 +964,7 @@ def __call__(
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
- query = attn.to_q(hidden_states, scale=scale)
+ query = attn.to_q(hidden_states, *args)
query = attn.head_to_batch_dim(query, out_dim=4)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
@@ -967,8 +973,8 @@ def __call__(
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
if not attn.only_cross_attention:
- key = attn.to_k(hidden_states, scale=scale)
- value = attn.to_v(hidden_states, scale=scale)
+ key = attn.to_k(hidden_states, *args)
+ value = attn.to_v(hidden_states, *args)
key = attn.head_to_batch_dim(key, out_dim=4)
value = attn.head_to_batch_dim(value, out_dim=4)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
@@ -985,7 +991,7 @@ def __call__(
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
# linear proj
- hidden_states = attn.to_out[0](hidden_states, scale=scale)
+ hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -1177,6 +1183,8 @@ def __call__(
) -> torch.FloatTensor:
residual = hidden_states
+ args = () if USE_PEFT_BACKEND else (scale,)
+
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1207,12 +1215,8 @@ def __call__(
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- key = (
- attn.to_k(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_k(encoder_hidden_states)
- )
- value = (
- attn.to_v(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_v(encoder_hidden_states)
- )
+ key = attn.to_k(encoder_hidden_states, *args)
+ value = attn.to_v(encoder_hidden_states, *args)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
@@ -1232,9 +1236,7 @@ def __call__(
hidden_states = hidden_states.to(query.dtype)
# linear proj
- hidden_states = (
- attn.to_out[0](hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_out[0](hidden_states)
- )
+ hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -1361,6 +1363,7 @@ def __call__(
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
+
return hidden_states
@@ -1433,8 +1436,11 @@ def __call__(
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
if self.train_kv:
- key = self.to_k_custom_diffusion(encoder_hidden_states)
- value = self.to_v_custom_diffusion(encoder_hidden_states)
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
+ key = key.to(attn.to_q.weight.dtype)
+ value = value.to(attn.to_q.weight.dtype)
+
else:
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index f1128e518e2a..a377ae267411 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -66,17 +66,22 @@ def get_timestep_embedding(
return emb
-def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
+def get_2d_sincos_pos_embed(
+ embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
+):
"""
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
- grid_h = np.arange(grid_size, dtype=np.float32)
- grid_w = np.arange(grid_size, dtype=np.float32)
+ if isinstance(grid_size, int):
+ grid_size = (grid_size, grid_size)
+
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
- grid = grid.reshape([2, 1, grid_size, grid_size])
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
@@ -129,6 +134,7 @@ def __init__(
layer_norm=False,
flatten=True,
bias=True,
+ interpolation_scale=1,
):
super().__init__()
@@ -144,16 +150,41 @@ def __init__(
else:
self.norm = None
- pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
+ self.patch_size = patch_size
+ # See:
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
+ self.height, self.width = height // patch_size, width // patch_size
+ self.base_size = height // patch_size
+ self.interpolation_scale = interpolation_scale
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
+ )
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
def forward(self, latent):
+ height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
+
latent = self.proj(latent)
if self.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
if self.layer_norm:
latent = self.norm(latent)
- return latent + self.pos_embed
+
+ # Interpolate positional embeddings if needed.
+ # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
+ if self.height != height or self.width != width:
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim=self.pos_embed.shape[-1],
+ grid_size=(height, width),
+ base_size=self.base_size,
+ interpolation_scale=self.interpolation_scale,
+ )
+ pos_embed = torch.from_numpy(pos_embed)
+ pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
+ else:
+ pos_embed = self.pos_embed
+
+ return (latent + pos_embed).to(latent.dtype)
class TimestepEmbedding(nn.Module):
@@ -683,3 +714,79 @@ def forward(
objs = torch.cat([objs_text, objs_image], dim=1)
return objs
+
+
+class CombinedTimestepSizeEmbeddings(nn.Module):
+ """
+ For PixArt-Alpha.
+
+ Reference:
+ https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
+ """
+
+ def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
+ super().__init__()
+
+ self.outdim = size_emb_dim
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.use_additional_conditions = use_additional_conditions
+ if use_additional_conditions:
+ self.use_additional_conditions = True
+ self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
+ self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
+
+ def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
+ if size.ndim == 1:
+ size = size[:, None]
+
+ if size.shape[0] != batch_size:
+ size = size.repeat(batch_size // size.shape[0], 1)
+ if size.shape[0] != batch_size:
+ raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.")
+
+ current_batch_size, dims = size.shape[0], size.shape[1]
+ size = size.reshape(-1)
+ size_freq = self.additional_condition_proj(size).to(size.dtype)
+
+ size_emb = embedder(size_freq)
+ size_emb = size_emb.reshape(current_batch_size, dims * self.outdim)
+ return size_emb
+
+ def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
+
+ if self.use_additional_conditions:
+ resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
+ aspect_ratio = self.apply_condition(
+ aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
+ )
+ conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
+ else:
+ conditioning = timesteps_emb
+
+ return conditioning
+
+
+class CaptionProjection(nn.Module):
+ """
+ Projects caption embeddings. Also handles dropout for classifier-free guidance.
+
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
+ """
+
+ def __init__(self, in_features, hidden_size, num_tokens=120):
+ super().__init__()
+ self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
+ self.act_1 = nn.GELU(approximate="tanh")
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
+ self.register_buffer("y_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features**0.5))
+
+ def forward(self, caption, force_drop_ids=None):
+ hidden_states = self.linear_1(caption)
+ hidden_states = self.act_1(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
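The `PatchEmbed` change above amounts to recomputing the sinusoidal table with `get_2d_sincos_pos_embed` whenever the incoming latent resolution differs from the one the module was built for, using the new `(height, width)` grid, `base_size`, and `interpolation_scale` arguments. A small shape sketch of the helper with toy sizes:

```py
import torch
from diffusers.models.embeddings import get_2d_sincos_pos_embed

# a 96x64 latent with patch_size 2 gives a 48x32 patch grid
pos_embed = get_2d_sincos_pos_embed(embed_dim=64, grid_size=(48, 32), base_size=32, interpolation_scale=1.0)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
print(pos_embed.shape)  # torch.Size([1, 1536, 64]), i.e. one embedding per patch
```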
diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py
index dd451b5f3bfc..cedeff18f351 100644
--- a/src/diffusers/models/normalization.py
+++ b/src/diffusers/models/normalization.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional, Tuple
+from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from .activations import get_activation
-from .embeddings import CombinedTimestepLabelEmbeddings
+from .embeddings import CombinedTimestepLabelEmbeddings, CombinedTimestepSizeEmbeddings
class AdaLayerNorm(nn.Module):
@@ -77,6 +77,39 @@ def forward(
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+class AdaLayerNormSingle(nn.Module):
+ r"""
+ Norm layer adaptive layer norm single (adaLN-single).
+
+ As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ use_additional_conditions (`bool`): To use additional conditions for normalization or not.
+ """
+
+ def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
+ super().__init__()
+
+ self.emb = CombinedTimestepSizeEmbeddings(
+ embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
+ )
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
+ batch_size: int = None,
+ hidden_dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # No modulation happening here.
+ embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
+ return self.linear(self.silu(embedded_timestep)), embedded_timestep
+
+
class AdaGroupNorm(nn.Module):
r"""
GroupNorm layer modified to incorporate timestep embeddings.
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 8fe66aacf5db..868e2e5fae2c 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -778,16 +778,22 @@ class Conv1dBlock(nn.Module):
out_channels (`int`): Number of output channels.
kernel_size (`int` or `tuple`): Size of the convolving kernel.
n_groups (`int`, default `8`): Number of groups to separate the channels into.
+ activation (`str`, defaults `mish`): Name of the activation function.
"""
def __init__(
- self, inp_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], n_groups: int = 8
+ self,
+ inp_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]],
+ n_groups: int = 8,
+ activation: str = "mish",
):
super().__init__()
self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
self.group_norm = nn.GroupNorm(n_groups, out_channels)
- self.mish = nn.Mish()
+ self.mish = get_activation(activation)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
intermediate_repr = self.conv1d(inputs)
@@ -808,16 +814,22 @@ class ResidualTemporalBlock1D(nn.Module):
out_channels (`int`): Number of output channels.
embed_dim (`int`): Embedding dimension.
kernel_size (`int` or `tuple`): Size of the convolving kernel.
+        activation (`str`, defaults `mish`): Name of the activation function.
"""
def __init__(
- self, inp_channels: int, out_channels: int, embed_dim: int, kernel_size: Union[int, Tuple[int, int]] = 5
+ self,
+ inp_channels: int,
+ out_channels: int,
+ embed_dim: int,
+ kernel_size: Union[int, Tuple[int, int]] = 5,
+ activation: str = "mish",
):
super().__init__()
self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
- self.time_emb_act = nn.Mish()
+ self.time_emb_act = get_activation(activation)
self.time_emb = nn.Linear(embed_dim, out_channels)
self.residual_conv = (
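Both 1D blocks above now accept an `activation` name that is resolved through `get_activation`. A quick sketch exercising the new argument with toy tensors, assuming the class is importable from `diffusers.models.resnet`:

```py
import torch
from diffusers.models.resnet import ResidualTemporalBlock1D

block = ResidualTemporalBlock1D(inp_channels=8, out_channels=8, embed_dim=32, activation="silu")
inputs = torch.randn(2, 8, 16)  # (batch, channels, horizon)
t_emb = torch.randn(2, 32)      # (batch, embed_dim)
out = block(inputs, t_emb)
print(out.shape)  # torch.Size([2, 8, 16])
```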
diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py
index 0f00932f3014..24abf54d6da7 100644
--- a/src/diffusers/models/transformer_2d.py
+++ b/src/diffusers/models/transformer_2d.py
@@ -22,9 +22,10 @@
from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate
from .attention import BasicTransformerBlock
-from .embeddings import PatchEmbed
+from .embeddings import CaptionProjection, PatchEmbed
from .lora import LoRACompatibleConv, LoRACompatibleLinear
from .modeling_utils import ModelMixin
+from .normalization import AdaLayerNormSingle
@dataclass
@@ -92,7 +93,9 @@ def __init__(
upcast_attention: bool = False,
norm_type: str = "layer_norm",
norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
attention_type: str = "default",
+ caption_channels: int = None,
):
super().__init__()
self.use_linear_projection = use_linear_projection
@@ -164,12 +167,15 @@ def __init__(
self.width = sample_size
self.patch_size = patch_size
+ interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
+ interpolation_scale = max(interpolation_scale, 1)
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
+ interpolation_scale=interpolation_scale,
)
# 3. Define transformers blocks
@@ -189,6 +195,7 @@ def __init__(
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
attention_type=attention_type,
)
for d in range(num_layers)
@@ -206,10 +213,27 @@ def __init__(
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
- elif self.is_input_patches:
+ elif self.is_input_patches and norm_type != "ada_norm_single":
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
+ elif self.is_input_patches and norm_type == "ada_norm_single":
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+ self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
+
+ # 5. PixArt-Alpha blocks.
+ self.adaln_single = None
+ self.use_additional_conditions = False
+ if norm_type == "ada_norm_single":
+ self.use_additional_conditions = self.config.sample_size == 128
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
+ # additional conditions until we find better name
+ self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
+
+ self.caption_projection = None
+ if caption_channels is not None:
+ self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
self.gradient_checkpointing = False
@@ -218,6 +242,7 @@ def forward(
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
@@ -314,9 +339,25 @@ def forward(
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
elif self.is_input_patches:
+ height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
hidden_states = self.pos_embed(hidden_states)
+ if self.adaln_single is not None:
+ if self.use_additional_conditions and added_cond_kwargs is None:
+ raise ValueError(
+ "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
+ )
+ batch_size = hidden_states.shape[0]
+ timestep, embedded_timestep = self.adaln_single(
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+ )
+
# 2. Blocks
+ if self.caption_projection is not None:
+ batch_size = hidden_states.shape[0]
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
for block in self.transformer_blocks:
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
@@ -367,17 +408,26 @@ def forward(
# log(p(x_0))
output = F.log_softmax(logits.double(), dim=1).float()
- elif self.is_input_patches:
- # TODO: cleanup!
- conditioning = self.transformer_blocks[0].norm1.emb(
- timestep, class_labels, hidden_dtype=hidden_states.dtype
- )
- shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
- hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
- hidden_states = self.proj_out_2(hidden_states)
+
+ if self.is_input_patches:
+ if self.config.norm_type != "ada_norm_single":
+ conditioning = self.transformer_blocks[0].norm1.emb(
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+ hidden_states = self.proj_out_2(hidden_states)
+ elif self.config.norm_type == "ada_norm_single":
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states)
+ # Modulation
+ hidden_states = hidden_states * (1 + scale) + shift
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.squeeze(1)
# unpatchify
- height = width = int(hidden_states.shape[1] ** 0.5)
+ if self.adaln_single is None:
+ height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
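The `ada_norm_single` branch added above conditions the final layer with a learned scale/shift table added to the embedded timestep, instead of the DiT-style `proj_out_1`/`proj_out_2` path. A standalone sketch of that modulation step, with hypothetical sizes rather than library code, to show the broadcasting involved:

```py
import torch
import torch.nn as nn

inner_dim, batch, num_patches, patch_out = 1152, 2, 256, 2 * 2 * 4  # hypothetical sizes

scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
proj_out = nn.Linear(inner_dim, patch_out)

hidden_states = torch.randn(batch, num_patches, inner_dim)  # transformer output tokens
embedded_timestep = torch.randn(batch, inner_dim)           # from AdaLayerNormSingle

# (1, 2, D) table + (B, 1, D) timestep embedding -> (B, 2, D) -> two (B, 1, D) chunks
shift, scale = (scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = norm_out(hidden_states) * (1 + scale) + shift  # per-sample modulation
hidden_states = proj_out(hidden_states)                        # (B, num_patches, patch_out)
print(hidden_states.shape)  # torch.Size([2, 256, 16])
```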
diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py
index 08ad122c3891..0c93b9142bea 100644
--- a/src/diffusers/models/vq_model.py
+++ b/src/diffusers/models/vq_model.py
@@ -162,8 +162,8 @@ def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[
If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple`
is returned.
"""
- x = sample
- h = self.encode(x).latents
+
+ h = self.encode(sample).latents
dec = self.decode(h).sample
if not return_dict:
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 851f516da7cd..879bd6d98aa6 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -117,6 +117,7 @@
_import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"])
_import_structure["musicldm"] = ["MusicLDMPipeline"]
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
+ _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline"]
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
_import_structure["stable_diffusion"].extend(
@@ -341,6 +342,7 @@
from .latent_diffusion import LDMTextToImagePipeline
from .musicldm import MusicLDMPipeline
from .paint_by_example import PaintByExamplePipeline
+ from .pixart_alpha import PixArtAlphaPipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_diffusion import (
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
index 49947f9dbf32..b63acb9a5f30 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -498,7 +498,7 @@ def prepare_latents(
@torch.no_grad()
def __call__(
self,
- prompt: Union[str, List[str]],
+ prompt: Union[str, List[str]] = None,
num_frames: Optional[int] = 16,
height: Optional[int] = None,
width: Optional[int] = None,
diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index b6e6f48126bd..3144956ee6d4 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -43,6 +43,7 @@
KandinskyV22Pipeline,
)
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
+from .pixart_alpha import PixArtAlphaPipeline
from .stable_diffusion import (
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
@@ -67,6 +68,7 @@
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
("wuerstchen", WuerstchenCombinedPipeline),
("lcm", LatentConsistencyModelPipeline),
+ ("pixart", PixArtAlphaPipeline),
]
)
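With the `("pixart", PixArtAlphaPipeline)` entry registered above, the text-to-image auto pipeline should be able to dispatch PixArt checkpoints directly. A rough usage sketch, reusing the checkpoint id from the example docstring added later in this diff:

```py
import torch

from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

# resolves to PixArtAlphaPipeline via the new "pixart" mapping entry
image = pipe("A small cactus with a happy face in the Sahara desert.").images[0]
```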
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
index e595b3423995..8380dd210d9c 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
@@ -1109,8 +1109,6 @@ def __call__(
nsfw_detected = None
watermark_detected = None
- if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
- self.unet_offload_hook.offload()
else:
# 10. Post-processing
image = (image / 2 + 0.5).clamp(0, 1)
@@ -1119,9 +1117,7 @@ def __call__(
# 11. Run safety checker
image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
- # Offload last model to CPU
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
- self.final_offload_hook.offload()
+ self.maybe_free_model_hooks()
if not return_dict:
return (image, nsfw_detected, watermark_detected)
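`maybe_free_model_hooks()` is the generic replacement for the manual `unet_offload_hook`/`final_offload_hook` bookkeeping removed above: it re-offloads every model after the call when `enable_model_cpu_offload()` is active and is a no-op otherwise. A minimal sketch of the usage pattern this simplifies (the checkpoint id is only illustrative):

```py
import torch

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()  # installs CPU-offload hooks on each sub-model

# every __call__ now ends with maybe_free_model_hooks(), so repeated calls
# keep GPU memory low without any per-pipeline hook handling
for prompt in ["a red fox in the snow", "a lighthouse at dawn"]:
    image = pipe(prompt, num_inference_steps=20).images[0]
```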
diff --git a/src/diffusers/pipelines/dit/pipeline_dit.py b/src/diffusers/pipelines/dit/pipeline_dit.py
index 022aa1202603..f22d429d7c66 100644
--- a/src/diffusers/pipelines/dit/pipeline_dit.py
+++ b/src/diffusers/pipelines/dit/pipeline_dit.py
@@ -166,7 +166,6 @@ def __call__(
# set step values
self.scheduler.set_timesteps(num_inference_steps)
-
for t in self.progress_bar(self.scheduler.timesteps):
if guidance_scale > 1:
half = latent_model_input[: len(latent_model_input) // 2]
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
index 5c78b0dce87e..5e7a69e756ce 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
@@ -388,6 +388,8 @@ def __call__(
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
+ self.maybe_free_model_hooks()
+
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
index 25508e1e080f..eff8af4c723e 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
@@ -321,6 +321,9 @@ def __call__(
callback_steps=callback_steps,
return_dict=return_dict,
)
+
+ self.maybe_free_model_hooks()
+
return outputs
@@ -558,6 +561,9 @@ def __call__(
callback_steps=callback_steps,
return_dict=return_dict,
)
+
+ self.maybe_free_model_hooks()
+
return outputs
@@ -593,7 +599,7 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline):
"""
_load_connected_pipes = True
- model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->" "text_encoder->unet->movq"
+ model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq"
def __init__(
self,
@@ -802,4 +808,7 @@ def __call__(
callback_steps=callback_steps,
return_dict=return_dict,
)
+
+ self.maybe_free_model_hooks()
+
return outputs
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
index a22823aadef4..c5e7af270906 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
@@ -481,6 +481,8 @@ def __call__(
# 7. post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
+ self.maybe_free_model_hooks()
+
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
index 144e3ce585af..e9b5eb5cdd70 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
@@ -616,6 +616,8 @@ def __call__(
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
+ self.maybe_free_model_hooks()
+
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
index c9a6019a8eac..a9c12b258974 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
@@ -527,7 +527,7 @@ def __call__(
if negative_prompt is None:
zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
- self.maybe_free_model_hooks
+ self.maybe_free_model_hooks()
else:
image_embeddings, zero_embeds = image_embeddings.chunk(2)
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
index 097673d904f5..2c7caa6214e5 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
@@ -326,6 +326,8 @@ def __call__(
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
+ self.maybe_free_model_hooks()
+
return outputs
@@ -572,6 +574,8 @@ def __call__(
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
+
+ self.maybe_free_model_hooks()
return outputs
@@ -842,4 +846,6 @@ def __call__(
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
**kwargs,
)
+ self.maybe_free_model_hooks()
+
return outputs
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
index 345b3ae65721..8d0e788b9dd9 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
@@ -531,14 +531,10 @@ def __call__(
# if negative prompt has been defined, we retrieve split the image embedding into two
if negative_prompt is None:
zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
-
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
- self.final_offload_hook.offload()
else:
image_embeddings, zero_embeds = image_embeddings.chunk(2)
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
- self.prior_hook.offload()
+ self.maybe_free_model_hooks()
if output_type not in ["pt", "np"]:
raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
index b4a6a64137ec..bef70821c605 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
@@ -545,12 +545,10 @@ def __call__(
# if negative prompt has been defined, we retrieve split the image embedding into two
if negative_prompt is None:
zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
- self.final_offload_hook.offload()
else:
image_embeddings, zero_embeds = image_embeddings.chunk(2)
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
- self.prior_hook.offload()
+
+ self.maybe_free_model_hooks()
if output_type not in ["pt", "np"]:
raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")
diff --git a/src/diffusers/pipelines/latent_consistency_models/__init__.py b/src/diffusers/pipelines/latent_consistency_models/__init__.py
index 14002058cdfd..8f79d3c4773f 100644
--- a/src/diffusers/pipelines/latent_consistency_models/__init__.py
+++ b/src/diffusers/pipelines/latent_consistency_models/__init__.py
@@ -1,19 +1,40 @@
from typing import TYPE_CHECKING
from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
_LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
)
-_import_structure = {
- "pipeline_latent_consistency_img2img": ["LatentConsistencyModelImg2ImgPipeline"],
- "pipeline_latent_consistency_text2img": ["LatentConsistencyModelPipeline"],
-}
+_dummy_objects = {}
+_import_structure = {}
-if TYPE_CHECKING:
- from .pipeline_latent_consistency_img2img import LatentConsistencyModelImg2ImgPipeline
- from .pipeline_latent_consistency_text2img import LatentConsistencyModelPipeline
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_latent_consistency_img2img"] = ["LatentConsistencyModelImg2ImgPipeline"]
+ _import_structure["pipeline_latent_consistency_text2img"] = ["LatentConsistencyModelPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_latent_consistency_img2img import LatentConsistencyModelImg2ImgPipeline
+ from .pipeline_latent_consistency_text2img import LatentConsistencyModelPipeline
else:
import sys
@@ -24,3 +45,6 @@
_import_structure,
module_spec=__spec__,
)
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
index ccc84e22c252..679415db7f3a 100644
--- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
+++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
@@ -60,7 +60,7 @@ def retrieve_latents(encoder_output, generator):
>>> import torch
>>> import PIL
- >>> pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7")
+ >>> pipe = AutoPipelineForImage2Image.from_pretrained("SimianLuo/LCM_Dreamshaper_v7")
>>> # To save GPU memory, torch.float16 can be used, but it may compromise image quality.
>>> pipe.to(torch_device="cuda", torch_dtype=torch.float32)
@@ -738,7 +738,7 @@ def __call__(
if original_inference_steps is not None
else self.scheduler.config.original_inference_steps
)
- latent_timestep = torch.tensor(int(strength * original_inference_steps))
+ latent_timestep = timesteps[:1]
latents = self.prepare_latents(
image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
)
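The fix above replaces a hand-computed index with `timesteps[:1]`: after strength-based truncation, the first retained timestep is already the noise level the init image has to be diffused to before denoising starts. A generic sketch of that truncation, mirroring the `get_timesteps` helper used by the img2img pipelines (the LCM scheduler spaces its timesteps differently, but the principle is the same):

```py
import torch

def truncate_timesteps(timesteps: torch.Tensor, num_inference_steps: int, strength: float) -> torch.Tensor:
    # keep only the last `strength` fraction of the schedule
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return timesteps[t_start:]

timesteps = torch.linspace(999, 0, steps=8).long()  # stand-in for scheduler.timesteps
kept = truncate_timesteps(timesteps, num_inference_steps=8, strength=0.5)
latent_timestep = kept[:1]  # what prepare_latents() should noise the image to
print(kept.tolist(), latent_timestep.tolist())
```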
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 8baafbaef115..6437732d0315 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -158,9 +158,9 @@ def is_safetensors_compatible(filenames, variant=None, passed_components=None) -
continue
if extension == ".bin":
- pt_filenames.append(filename)
+ pt_filenames.append(os.path.normpath(filename))
elif extension == ".safetensors":
- sf_filenames.add(filename)
+ sf_filenames.add(os.path.normpath(filename))
for filename in pt_filenames:
# filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extension = '.bam'
@@ -172,9 +172,8 @@ def is_safetensors_compatible(filenames, variant=None, passed_components=None) -
else:
filename = filename
- expected_sf_filename = os.path.join(path, filename)
+ expected_sf_filename = os.path.normpath(os.path.join(path, filename))
expected_sf_filename = f"{expected_sf_filename}.safetensors"
-
if expected_sf_filename not in sf_filenames:
logger.warning(f"{expected_sf_filename} not found")
return False
@@ -353,13 +352,18 @@ def _get_pipeline_class(
else:
file_name = CUSTOM_PIPELINE_FILE_NAME
+ if repo_id is not None and hub_revision is not None:
+ # if we load the pipeline code from the Hub
+ # make sure to overwrite the `revision`
+ revision = hub_revision
+
return get_class_from_dynamic_module(
custom_pipeline,
module_file=file_name,
class_name=class_name,
repo_id=repo_id,
cache_dir=cache_dir,
- revision=revision if hub_revision is None else hub_revision,
+ revision=revision,
)
if class_obj != DiffusionPipeline:
@@ -1769,7 +1773,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
)
):
raise EnvironmentError(
- f"Could not found the necessary `safetensors` weights in {model_filenames} (variant={variant})"
+ f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
)
if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
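The `os.path.normpath` calls added to `is_safetensors_compatible` above make the comparison separator-agnostic: hub file listings always use `/`, while `os.path.join` uses the host separator, so the expected `.safetensors` name could fail the membership test on Windows. A small illustration of the normalized check:

```py
import os

# hub listings use forward slashes regardless of platform
sf_filenames = {os.path.normpath("unet/diffusion_pytorch_model.safetensors")}

# os.path.join would produce a backslash on Windows without normpath
expected = os.path.normpath(os.path.join("unet", "diffusion_pytorch_model")) + ".safetensors"
print(expected in sf_filenames)  # True on every platform
```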
diff --git a/src/diffusers/pipelines/pixart_alpha/__init__.py b/src/diffusers/pipelines/pixart_alpha/__init__.py
new file mode 100644
index 000000000000..0bfa28fcde50
--- /dev/null
+++ b/src/diffusers/pipelines/pixart_alpha/__init__.py
@@ -0,0 +1,48 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_pixart_alpha"] = ["PixArtAlphaPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_pixart_alpha import PixArtAlphaPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
new file mode 100644
index 000000000000..386971b2ac3d
--- /dev/null
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -0,0 +1,770 @@
+# Copyright 2023 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+import re
+import urllib.parse as ul
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+from transformers import T5EncoderModel, T5Tokenizer
+from torchvision import transforms as T
+
+from ...image_processor import VaeImageProcessor
+from ...models import AutoencoderKL, Transformer2DModel
+from ...schedulers import DPMSolverMultistepScheduler
+from ...utils import (
+ BACKENDS_MAPPING,
+ is_bs4_available,
+ is_ftfy_available,
+ logging,
+ replace_example_docstring,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import PixArtAlphaPipeline
+
+ >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
+ >>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
+ >>> # Enable memory optimizations.
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> prompt = "A small cactus with a happy face in the Sahara desert."
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+
+ASPECT_RATIO_1024_TEST = {
+ '0.25': [512., 2048.], '0.28': [512., 1856.],
+ '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
+ '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
+ '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
+ '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
+ '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
+ '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
+ '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
+ '2.5': [1600., 640.], '3.0': [1728., 576.],
+ '4.0': [2048., 512.],
+}
+
+
+class PixArtAlphaPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using PixArt-Alpha.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. PixArt-Alpha uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+ tokenizer (`T5Tokenizer`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`Transformer2DModel`]):
+ A text conditioned `Transformer2DModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ """
+ bad_punct_regex = re.compile(
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
+ ) # noqa
+
+ _optional_components = ["tokenizer", "text_encoder"]
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ vae: AutoencoderKL,
+ transformer: Transformer2DModel,
+ scheduler: DPMSolverMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
+ def mask_text_embeddings(self, emb, mask):
+ if emb.shape[0] == 1:
+ keep_index = mask.sum().item()
+ return emb[:, :, :keep_index, :], keep_index
+ else:
+ masked_feature = emb * mask[:, None, :, None]
+ return masked_feature, emb.shape[2]
+
+ @staticmethod
+ def classify_height_width_bin(height: int, width: int, ratios: dict):
+ ar = float(height / width)
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
+ default_hw = ratios[closest_ratio]
+ return int(default_hw[0]), int(default_hw[1])
+
+ # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: str = "",
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ clean_caption: bool = False,
+ mask_feature: bool = True,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+ PixArt-Alpha, this should be "".
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Alpha, this should be the embeddings of the ""
+ string.
+ clean_caption (bool, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ mask_feature: (bool, defaults to `True`):
+ If `True`, the function will mask the text embeddings.
+ """
+ embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
+
+ if device is None:
+ device = self._execution_device
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # See Section 3.1. of the paper.
+ max_length = 120
+
+ if prompt_embeds is None:
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {max_length} tokens: {removed_text}"
+ )
+
+ attention_mask = text_inputs.attention_mask.to(device)
+ prompt_embeds_attention_mask = attention_mask
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds_attention_mask = torch.ones_like(prompt_embeds)
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ elif self.transformer is not None:
+ dtype = self.transformer.dtype
+ else:
+ dtype = None
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1)
+ prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens = [negative_prompt] * batch_size
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ attention_mask = uncond_input.attention_mask.to(device)
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ else:
+ negative_prompt_embeds = None
+
+ # Perform additional masking.
+ if mask_feature and not embeds_initially_provided:
+ prompt_embeds = prompt_embeds.unsqueeze(1)
+ masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
+ masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
+ masked_negative_prompt_embeds = (
+ negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None
+ )
+ return masked_prompt_embeds, masked_negative_prompt_embeds
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_steps,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warn("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warn("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("
", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip adresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @staticmethod
+ def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int):
+ orig_hw = torch.tensor([samples.shape[2], samples.shape[3]])
+ custom_hw = torch.tensor([new_height, new_width])
+
+ if (orig_hw != custom_hw).all():
+ ratio = max(custom_hw[0] / orig_hw[0], custom_hw[1] / orig_hw[1])
+ resized_width = int(orig_hw[1] * ratio)
+ resized_height = int(orig_hw[0] * ratio)
+
+ transform = T.Compose([
+ T.Resize((resized_height, resized_width)),
+ T.CenterCrop(custom_hw.tolist())
+ ])
+ return transform(samples)
+ else:
+ return samples
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: str = "",
+ num_inference_steps: int = 20,
+ timesteps: List[int] = None,
+ guidance_scale: float = 4.5,
+ num_images_per_prompt: Optional[int] = 1,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ clean_caption: bool = True,
+ mask_feature: bool = True,
+ use_bin_classifier: bool = True,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_inference_steps (`int`, *optional*, defaults to 20):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
+ timesteps are used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 4.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ mask_feature (`bool`, defaults to `True`): If set to `True`, the text embeddings will be masked.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images
+ """
+ # 1. Check inputs. Raise error if not correct
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor
+ if use_bin_classifier:
+ orig_height, orig_width = height, width
+ height, width = self.classify_height_width_bin(height, width, ratios=ASPECT_RATIO_1024_TEST)
+ self.check_inputs(
+ prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Determine batch size
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ clean_caption=clean_caption,
+ mask_feature=mask_feature,
+ )
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ latent_channels,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 6.1 Prepare micro-conditions.
+ added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+ if self.transformer.config.sample_size == 128:
+ resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
+ aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
+ resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
+ aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
+ added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
+
+ # 7. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ current_timestep = t
+ if not torch.is_tensor(current_timestep):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = latent_model_input.device.type == "mps"
+ if isinstance(current_timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
+ elif len(current_timestep.shape) == 0:
+ current_timestep = current_timestep[None].to(latent_model_input.device)
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ current_timestep = current_timestep.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=current_timestep,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # learned sigma
+ if self.transformer.config.out_channels // 2 == latent_channels:
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
+ else:
+ noise_pred = noise_pred
+
+ # compute previous image: x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ if use_bin_classifier:
+ image = self.resize_and_crop_tensor(image, orig_width, orig_height)
+ else:
+ image = latents
+
+ if not output_type == "latent":
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
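Beyond the docstring example, a slightly fuller usage sketch for the new pipeline: when `use_bin_classifier=True` (the default), a non-square request is first snapped to the closest `ASPECT_RATIO_1024_TEST` bin, generated at that binned size, then resized and center-cropped back to the requested resolution. The checkpoint id is the one from the example docstring above:

```py
import torch

from diffusers import PixArtAlphaPipeline

pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe(
    prompt="A small cactus with a happy face in the Sahara desert.",
    negative_prompt="",      # PixArt-Alpha expects "" as the negative prompt
    height=768,
    width=1280,              # 0.6 aspect ratio maps to the [768, 1280] bin
    guidance_scale=4.5,
    num_inference_steps=20,
    generator=generator,
).images[0]
image.save("cactus_wide.png")
```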
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
index 8a5eb066f4fa..9bdb6d824f99 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
@@ -918,6 +918,7 @@ def __call__(
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+ self.maybe_free_model_hooks()
if not return_dict:
return (image, has_nsfw_concept)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
index f897b51941a6..2e040306abfd 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
@@ -1027,6 +1027,7 @@ def __call__(
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+ self.maybe_free_model_hooks()
if not return_dict:
return (image, has_nsfw_concept)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
index 7f6845128f6c..36efb01f23ef 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
@@ -846,6 +846,7 @@ def __call__(
image = latents
image = self.image_processor.postprocess(image, output_type=output_type)
+ self.maybe_free_model_hooks()
if not return_dict:
return (image,)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
index c6797a0693cc..e8f48a163066 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
@@ -439,6 +439,8 @@ def __call__(
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+ self.maybe_free_model_hooks()
+
if not return_dict:
return (image, has_nsfw_concept)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
index 1e8c98c44750..4cde54ac587a 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
@@ -511,6 +511,8 @@ def __call__(
image = self.image_processor.postprocess(image, output_type=output_type)
+ self.maybe_free_model_hooks()
+
if not return_dict:
return (image,)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
index f53e34e9259a..ce3e694e7e32 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
@@ -802,6 +802,8 @@ def __call__(
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+ self.maybe_free_model_hooks()
+
if not return_dict:
return (image, has_nsfw_concept)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
index 80f1d49ae297..56eb38c653ba 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
@@ -741,6 +741,8 @@ def get_map_size(module, input, output):
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+ self.maybe_free_model_hooks()
+
if not return_dict:
return (image, has_nsfw_concept)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
index c81dd85f0e46..eb4542888c1f 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
@@ -206,17 +206,15 @@ def _encode_prior_prompt(
prior_text_encoder_output = self.prior_text_encoder(text_input_ids.to(device))
prompt_embeds = prior_text_encoder_output.text_embeds
- prior_text_encoder_hidden_states = prior_text_encoder_output.last_hidden_state
+ text_enc_hid_states = prior_text_encoder_output.last_hidden_state
else:
batch_size = text_model_output[0].shape[0]
- prompt_embeds, prior_text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+ prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1]
text_mask = text_attention_mask
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
- prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.repeat_interleave(
- num_images_per_prompt, dim=0
- )
+ text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0)
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
@@ -235,9 +233,7 @@ def _encode_prior_prompt(
)
negative_prompt_embeds = negative_prompt_embeds_prior_text_encoder_output.text_embeds
- uncond_prior_text_encoder_hidden_states = (
- negative_prompt_embeds_prior_text_encoder_output.last_hidden_state
- )
+ uncond_text_enc_hid_states = negative_prompt_embeds_prior_text_encoder_output.last_hidden_state
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
@@ -245,11 +241,9 @@ def _encode_prior_prompt(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
- seq_len = uncond_prior_text_encoder_hidden_states.shape[1]
- uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.repeat(
- 1, num_images_per_prompt, 1
- )
- uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.view(
+ seq_len = uncond_text_enc_hid_states.shape[1]
+ uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_enc_hid_states = uncond_text_enc_hid_states.view(
batch_size * num_images_per_prompt, seq_len, -1
)
uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
@@ -260,13 +254,11 @@ def _encode_prior_prompt(
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
- prior_text_encoder_hidden_states = torch.cat(
- [uncond_prior_text_encoder_hidden_states, prior_text_encoder_hidden_states]
- )
+ text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states])
text_mask = torch.cat([uncond_text_mask, text_mask])
- return prompt_embeds, prior_text_encoder_hidden_states, text_mask
+ return prompt_embeds, text_enc_hid_states, text_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py
index c4a25c865d88..7bebed73c106 100644
--- a/src/diffusers/pipelines/unclip/pipeline_unclip.py
+++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py
@@ -156,15 +156,15 @@ def _encode_prompt(
text_encoder_output = self.text_encoder(text_input_ids.to(device))
prompt_embeds = text_encoder_output.text_embeds
- text_encoder_hidden_states = text_encoder_output.last_hidden_state
+ text_enc_hid_states = text_encoder_output.last_hidden_state
else:
batch_size = text_model_output[0].shape[0]
- prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+ prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1]
text_mask = text_attention_mask
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
- text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0)
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
@@ -181,7 +181,7 @@ def _encode_prompt(
negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
- uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+ uncond_text_enc_hid_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
@@ -189,9 +189,9 @@ def _encode_prompt(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
- seq_len = uncond_text_encoder_hidden_states.shape[1]
- uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
- uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+ seq_len = uncond_text_enc_hid_states.shape[1]
+ uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_enc_hid_states = uncond_text_enc_hid_states.view(
batch_size * num_images_per_prompt, seq_len, -1
)
uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
@@ -202,11 +202,11 @@ def _encode_prompt(
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
- text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+ text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states])
text_mask = torch.cat([uncond_text_mask, text_mask])
- return prompt_embeds, text_encoder_hidden_states, text_mask
+ return prompt_embeds, text_enc_hid_states, text_mask
@torch.no_grad()
def __call__(
@@ -293,7 +293,7 @@ def __call__(
do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
- prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
+ prompt_embeds, text_enc_hid_states, text_mask = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
)
@@ -321,7 +321,7 @@ def __call__(
latent_model_input,
timestep=t,
proj_embedding=prompt_embeds,
- encoder_hidden_states=text_encoder_hidden_states,
+ encoder_hidden_states=text_enc_hid_states,
attention_mask=text_mask,
).predicted_image_embedding
@@ -352,10 +352,10 @@ def __call__(
# decoder
- text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
+ text_enc_hid_states, additive_clip_time_embeddings = self.text_proj(
image_embeddings=image_embeddings,
prompt_embeds=prompt_embeds,
- text_encoder_hidden_states=text_encoder_hidden_states,
+ text_encoder_hidden_states=text_enc_hid_states,
do_classifier_free_guidance=do_classifier_free_guidance,
)
@@ -377,7 +377,7 @@ def __call__(
decoder_latents = self.prepare_latents(
(batch_size, num_channels_latents, height, width),
- text_encoder_hidden_states.dtype,
+ text_enc_hid_states.dtype,
device,
generator,
decoder_latents,
@@ -391,7 +391,7 @@ def __call__(
noise_pred = self.decoder(
sample=latent_model_input,
timestep=t,
- encoder_hidden_states=text_encoder_hidden_states,
+ encoder_hidden_states=text_enc_hid_states,
class_labels=additive_clip_time_embeddings,
attention_mask=decoder_text_mask,
).sample
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index 32147ffa455b..60ea3d814b3a 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -1494,7 +1494,6 @@ def forward(self, input_tensor, temb):
return output_tensor
-# Copied from diffusers.models.unet_2d_blocks.DownBlock2D with DownBlock2D->DownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
class DownBlockFlat(nn.Module):
def __init__(
self,
@@ -1583,7 +1582,6 @@ def custom_forward(*inputs):
return hidden_states, output_states
-# Copied from diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D with CrossAttnDownBlock2D->CrossAttnDownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
class CrossAttnDownBlockFlat(nn.Module):
def __init__(
self,
diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py
index 8e2627b6f477..adcc092a816f 100644
--- a/src/diffusers/schedulers/scheduling_lcm.py
+++ b/src/diffusers/schedulers/scheduling_lcm.py
@@ -182,6 +182,10 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
timestep_spacing (`str`, defaults to `"leading"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ timestep_scaling (`float`, defaults to 10.0):
+ The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
+ `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
+ error at the default of `10.0` is already pretty small).
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
@@ -208,6 +212,7 @@ def __init__(
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
timestep_spacing: str = "leading",
+ timestep_scaling: float = 10.0,
rescale_betas_zero_snr: bool = False,
):
if trained_betas is not None:
@@ -380,12 +385,12 @@ def set_timesteps(
self._step_index = None
- def get_scalings_for_boundary_condition_discrete(self, t):
+ def get_scalings_for_boundary_condition_discrete(self, timestep):
self.sigma_data = 0.5 # Default: 0.5
+ scaled_timestep = timestep * self.config.timestep_scaling
- # By dividing 0.1: This is almost a delta function at t=0.
- c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)
- c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5
+ c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2)
+ c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5
return c_skip, c_out
def step(
@@ -466,9 +471,12 @@ def step(
denoised = c_out * predicted_original_sample + c_skip * sample
# 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
- # Noise is not used for one-step sampling.
- if len(self.timesteps) > 1:
- noise = randn_tensor(model_output.shape, generator=generator, device=model_output.device)
+ # Noise is not used on the final timestep of the timestep schedule.
+ # This also means that noise is not used for one-step sampling.
+ if self.step_index != self.num_inference_steps - 1:
+ noise = randn_tensor(
+ model_output.shape, generator=generator, device=model_output.device, dtype=denoised.dtype
+ )
prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
else:
prev_sample = denoised
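
For reference, a self-contained sketch of the new boundary-condition scalings (the defaults mirror the diff: `timestep_scaling=10.0` and `sigma_data=0.5`, which reproduces the previous hard-coded `t / 0.1` behaviour); `denoised` is then `c_out * predicted_original_sample + c_skip * sample` as in the surrounding code:

```py
def boundary_scalings(timestep: float, timestep_scaling: float = 10.0, sigma_data: float = 0.5):
    # same formulas as LCMScheduler.get_scalings_for_boundary_condition_discrete above
    scaled_timestep = timestep * timestep_scaling
    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
    return c_skip, c_out

print(boundary_scalings(0))    # (1.0, 0.0): at t=0 the sample is passed through unchanged
print(boundary_scalings(999))  # (~0.0, ~1.0): at large t the model prediction dominates
```
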
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 132d76dc57cd..d6200bcaf122 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -572,6 +572,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class PixArtAlphaPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class SemanticStableDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
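
The dummy object above registers the new `PixArtAlphaPipeline` in the public API so that importing it without `torch`/`transformers` installed fails with a helpful message at instantiation time rather than at import. With the real backends available, usage follows the integration tests further down (a sketch, not part of this patch):

```py
import torch
from diffusers import PixArtAlphaPipeline

pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt).images[0]
```
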
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 80c97978723c..e4ecb59121f7 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -293,7 +293,16 @@ def test_set_xformers_attn_processor_for_determinism(self):
with torch.no_grad():
output_2 = model(**inputs_dict)[0]
+ model.set_attn_processor(XFormersAttnProcessor())
+ assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
+ with torch.no_grad():
+ output_3 = model(**inputs_dict)[0]
+
+ torch.use_deterministic_algorithms(True)
+
assert torch.allclose(output, output_2, atol=self.base_precision)
+ assert torch.allclose(output, output_3, atol=self.base_precision)
+ assert torch.allclose(output_2, output_3, atol=self.base_precision)
@require_torch_gpu
def test_set_attn_processor_for_determinism(self):
@@ -315,11 +324,6 @@ def test_set_attn_processor_for_determinism(self):
with torch.no_grad():
output_2 = model(**inputs_dict)[0]
- model.enable_xformers_memory_efficient_attention()
- assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
- with torch.no_grad():
- model(**inputs_dict)[0]
-
model.set_attn_processor(AttnProcessor2_0())
assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
with torch.no_grad():
@@ -330,18 +334,12 @@ def test_set_attn_processor_for_determinism(self):
with torch.no_grad():
output_5 = model(**inputs_dict)[0]
- model.set_attn_processor(XFormersAttnProcessor())
- assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
- with torch.no_grad():
- output_6 = model(**inputs_dict)[0]
-
torch.use_deterministic_algorithms(True)
# make sure that outputs match
assert torch.allclose(output_2, output_1, atol=self.base_precision)
assert torch.allclose(output_2, output_4, atol=self.base_precision)
assert torch.allclose(output_2, output_5, atol=self.base_precision)
- assert torch.allclose(output_2, output_6, atol=self.base_precision)
def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
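
The xFormers determinism check now lives entirely in `test_set_xformers_attn_processor_for_determinism`, which also compares the output of setting `XFormersAttnProcessor` explicitly against `enable_xformers_memory_efficient_attention()`. A sketch of the explicit form on a model (assumes xFormers is installed; the checkpoint is only an example):

```py
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import XFormersAttnProcessor

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
).to("cuda")

# effectively equivalent to unet.enable_xformers_memory_efficient_attention(),
# but with the processor set explicitly
unet.set_attn_processor(XFormersAttnProcessor())
assert all(type(proc) == XFormersAttnProcessor for proc in unet.attn_processors.values())
```
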
diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py
index baba8ba4d655..3c9390f2d1b6 100644
--- a/tests/pipelines/animatediff/test_animatediff.py
+++ b/tests/pipelines/animatediff/test_animatediff.py
@@ -220,6 +220,17 @@ def test_to_dtype(self):
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
+ def test_prompt_embeds(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs.pop("prompt")
+ inputs["prompt_embeds"] = torch.randn((1, 4, 32), device=torch_device)
+ pipe(**inputs)
+
@slow
@require_torch_gpu
diff --git a/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py b/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py
index 82a2944aeda4..53702925534d 100644
--- a/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py
+++ b/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py
@@ -133,7 +133,7 @@ def test_lcm_onestep(self):
assert image.shape == (1, 32, 32, 3)
image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.5865, 0.2854, 0.2828, 0.7473, 0.6006, 0.4580, 0.4397, 0.6415, 0.6069])
+ expected_slice = np.array([0.4388, 0.3717, 0.2202, 0.7213, 0.6370, 0.3664, 0.5815, 0.6080, 0.4977])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_lcm_multistep(self):
@@ -150,7 +150,7 @@ def test_lcm_multistep(self):
assert image.shape == (1, 32, 32, 3)
image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.4903, 0.3304, 0.3503, 0.5241, 0.5153, 0.4585, 0.3222, 0.4764, 0.4891])
+ expected_slice = np.array([0.4150, 0.3719, 0.2479, 0.6333, 0.6024, 0.3778, 0.5036, 0.5420, 0.4678])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_inference_batch_single_identical(self):
@@ -237,7 +237,7 @@ def test_lcm_onestep(self):
assert image.shape == (1, 512, 512, 3)
image_slice = image[0, -3:, -3:, -1].flatten()
- expected_slice = np.array([0.1025, 0.0911, 0.0984, 0.0981, 0.0901, 0.0918, 0.1055, 0.0940, 0.0730])
+ expected_slice = np.array([0.1950, 0.1961, 0.2308, 0.1786, 0.1837, 0.2320, 0.1898, 0.1885, 0.2309])
assert np.abs(image_slice - expected_slice).max() < 1e-3
def test_lcm_multistep(self):
@@ -253,5 +253,5 @@ def test_lcm_multistep(self):
assert image.shape == (1, 512, 512, 3)
image_slice = image[0, -3:, -3:, -1].flatten()
- expected_slice = np.array([0.01855, 0.01855, 0.01489, 0.01392, 0.01782, 0.01465, 0.01831, 0.02539, 0.0])
+ expected_slice = np.array([0.3756, 0.3816, 0.3767, 0.3718, 0.3739, 0.3735, 0.3863, 0.3803, 0.3563])
assert np.abs(image_slice - expected_slice).max() < 1e-3
diff --git a/tests/pipelines/pixart/__init__.py b/tests/pipelines/pixart/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/pixart/test_pixart.py b/tests/pipelines/pixart/test_pixart.py
new file mode 100644
index 000000000000..a04f4e1a8804
--- /dev/null
+++ b/tests/pipelines/pixart/test_pixart.py
@@ -0,0 +1,341 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKL,
+ DDIMScheduler,
+ PixArtAlphaPipeline,
+ Transformer2DModel,
+)
+from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = PixArtAlphaPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+
+ required_optional_params = PipelineTesterMixin.required_optional_params
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = Transformer2DModel(
+ sample_size=8,
+ num_layers=2,
+ patch_size=2,
+ attention_head_dim=8,
+ num_attention_heads=3,
+ caption_channels=32,
+ in_channels=4,
+ cross_attention_dim=24,
+ out_channels=8,
+ attention_bias=True,
+ activation_fn="gelu-approximate",
+ num_embeds_ada_norm=1000,
+ norm_type="ada_norm_single",
+ norm_elementwise_affine=False,
+ norm_eps=1e-6,
+ )
+ vae = AutoencoderKL()
+ scheduler = DDIMScheduler()
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ components = {
+ "transformer": transformer.eval(),
+ "vae": vae.eval(),
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "output_type": "numpy",
+ }
+ return inputs
+
+ def test_sequential_cpu_offload_forward_pass(self):
+ # TODO(PVP, Sayak) need to fix later
+ return
+
+ def test_save_load_optional_components(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ prompt = inputs["prompt"]
+ generator = inputs["generator"]
+ num_inference_steps = inputs["num_inference_steps"]
+ output_type = inputs["output_type"]
+
+ prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt, mask_feature=False)
+
+ # inputs with prompt converted to embeddings
+ inputs = {
+ "prompt_embeds": prompt_embeds,
+ "negative_prompt": None,
+ "negative_prompt_embeds": negative_prompt_embeds,
+ "generator": generator,
+ "num_inference_steps": num_inference_steps,
+ "output_type": output_type,
+ }
+
+ # set all optional components to None
+ for optional_component in pipe._optional_components:
+ setattr(pipe, optional_component, None)
+
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ for optional_component in pipe._optional_components:
+ self.assertTrue(
+ getattr(pipe_loaded, optional_component) is None,
+ f"`{optional_component}` did not stay set to None after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ generator = inputs["generator"]
+ num_inference_steps = inputs["num_inference_steps"]
+ output_type = inputs["output_type"]
+
+ # inputs with prompt converted to embeddings
+ inputs = {
+ "prompt_embeds": prompt_embeds,
+ "negative_prompt": None,
+ "negative_prompt_embeds": negative_prompt_embeds,
+ "generator": generator,
+ "num_inference_steps": num_inference_steps,
+ "output_type": output_type,
+ }
+
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+ self.assertLess(max_diff, 1e-4)
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ self.assertEqual(image.shape, (1, 8, 8, 3))
+ expected_slice = np.array([0.5303, 0.2658, 0.7979, 0.1182, 0.3304, 0.4608, 0.5195, 0.4261, 0.4675])
+ max_diff = np.abs(image_slice.flatten() - expected_slice).max()
+ self.assertLessEqual(max_diff, 1e-3)
+
+ def test_inference_non_square_images(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs, height=32, width=48).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ self.assertEqual(image.shape, (1, 32, 48, 3))
+ expected_slice = np.array([0.3859, 0.2987, 0.2333, 0.5243, 0.6721, 0.4436, 0.5292, 0.5373, 0.4416])
+ max_diff = np.abs(image_slice.flatten() - expected_slice).max()
+ self.assertLessEqual(max_diff, 1e-3)
+
+ def test_inference_with_embeddings_and_multiple_images(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ prompt = inputs["prompt"]
+ generator = inputs["generator"]
+ num_inference_steps = inputs["num_inference_steps"]
+ output_type = inputs["output_type"]
+
+ prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt)
+
+ # inputs with prompt converted to embeddings
+ inputs = {
+ "prompt_embeds": prompt_embeds,
+ "negative_prompt": None,
+ "negative_prompt_embeds": negative_prompt_embeds,
+ "generator": generator,
+ "num_inference_steps": num_inference_steps,
+ "output_type": output_type,
+ "num_images_per_prompt": 2,
+ }
+
+ # set all optional components to None
+ for optional_component in pipe._optional_components:
+ setattr(pipe, optional_component, None)
+
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ for optional_component in pipe._optional_components:
+ self.assertTrue(
+ getattr(pipe_loaded, optional_component) is None,
+ f"`{optional_component}` did not stay set to None after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ generator = inputs["generator"]
+ num_inference_steps = inputs["num_inference_steps"]
+ output_type = inputs["output_type"]
+
+ # inputs with prompt converted to embeddings
+ inputs = {
+ "prompt_embeds": prompt_embeds,
+ "negative_prompt": None,
+ "negative_prompt_embeds": negative_prompt_embeds,
+ "generator": generator,
+ "num_inference_steps": num_inference_steps,
+ "output_type": output_type,
+ "num_images_per_prompt": 2,
+ }
+
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+ self.assertLess(max_diff, 1e-4)
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(expected_max_diff=1e-3)
+
+
+@slow
+@require_torch_gpu
+class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_pixart_1024_fast(self):
+ generator = torch.manual_seed(0)
+
+ pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
+ pipe.enable_model_cpu_offload()
+
+ prompt = "A small cactus with a happy face in the Sahara desert."
+
+ image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1323])
+
+ max_diff = np.abs(image_slice.flatten() - expected_slice).max()
+ self.assertLessEqual(max_diff, 1e-3)
+
+ def test_pixart_512_fast(self):
+ generator = torch.manual_seed(0)
+
+ pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-512x512", torch_dtype=torch.float16)
+ pipe.enable_model_cpu_offload()
+
+ prompt = "A small cactus with a happy face in the Sahara desert."
+
+ image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0266])
+
+ max_diff = np.abs(image_slice.flatten() - expected_slice).max()
+ self.assertLessEqual(max_diff, 1e-3)
+
+ def test_pixart_1024(self):
+ generator = torch.manual_seed(0)
+
+ pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
+ pipe.enable_model_cpu_offload()
+ prompt = "A small cactus with a happy face in the Sahara desert."
+
+ image = pipe(prompt, generator=generator, output_type="np").images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ expected_slice = np.array([0.1501, 0.1755, 0.1877, 0.1445, 0.1665, 0.1763, 0.1389, 0.176, 0.2031])
+
+ max_diff = np.abs(image_slice.flatten() - expected_slice).max()
+ self.assertLessEqual(max_diff, 1e-3)
+
+ def test_pixart_512(self):
+ generator = torch.manual_seed(0)
+
+ pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-512x512", torch_dtype=torch.float16)
+ pipe.enable_model_cpu_offload()
+
+ prompt = "A small cactus with a happy face in the Sahara desert."
+
+ image = pipe(prompt, generator=generator, output_type="np").images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ expected_slice = np.array([0.2515, 0.2593, 0.2593, 0.2544, 0.2759, 0.2788, 0.2812, 0.3169, 0.332])
+
+ max_diff = np.abs(image_slice.flatten() - expected_slice).max()
+ self.assertLessEqual(max_diff, 1e-3)
diff --git a/tests/pipelines/shap_e/test_shap_e.py b/tests/pipelines/shap_e/test_shap_e.py
index 7b95fdd9e669..c7792f097ed5 100644
--- a/tests/pipelines/shap_e/test_shap_e.py
+++ b/tests/pipelines/shap_e/test_shap_e.py
@@ -160,7 +160,7 @@ def get_dummy_inputs(self, device, seed=0):
"generator": generator,
"num_inference_steps": 1,
"frame_size": 32,
- "output_type": "np",
+ "output_type": "latent",
}
return inputs
@@ -176,24 +176,12 @@ def test_shap_e(self):
output = pipe(**self.get_dummy_inputs(device))
image = output.images[0]
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (20, 32, 32, 3)
-
- expected_slice = np.array(
- [
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- ]
- )
+ image = image.cpu().numpy()
+ image_slice = image[-3:, -3:]
+
+ assert image.shape == (32, 16)
+ expected_slice = np.array([-1.0000, -0.6241, 1.0000, -0.8978, -0.6866, 0.7876, -0.7473, -0.2874, 0.6103])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_inference_batch_consistent(self):
diff --git a/tests/pipelines/shap_e/test_shap_e_img2img.py b/tests/pipelines/shap_e/test_shap_e_img2img.py
index 055dbe7a97d4..ee8d9d07cd77 100644
--- a/tests/pipelines/shap_e/test_shap_e_img2img.py
+++ b/tests/pipelines/shap_e/test_shap_e_img2img.py
@@ -181,7 +181,7 @@ def get_dummy_inputs(self, device, seed=0):
"generator": generator,
"num_inference_steps": 1,
"frame_size": 32,
- "output_type": "np",
+ "output_type": "latent",
}
return inputs
@@ -197,22 +197,12 @@ def test_shap_e(self):
output = pipe(**self.get_dummy_inputs(device))
image = output.images[0]
- image_slice = image[0, -3:, -3:, -1]
+ image_slice = image[-3:, -3:].cpu().numpy()
- assert image.shape == (20, 32, 32, 3)
+ assert image.shape == (32, 16)
expected_slice = np.array(
- [
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- ]
+ [-1.0, 0.40668195, 0.57322013, -0.9469888, 0.4283227, 0.30348337, -0.81094897, 0.74555075, 0.15342723]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 1795c83b58a1..b9fe4d190f23 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -493,7 +493,7 @@ def _test_inference_batch_single_identical(
assert output_batch[0].shape[0] == batch_size
- max_diff = np.abs(output_batch[0][0] - output[0][0]).max()
+ max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
assert max_diff < expected_max_diff
def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
@@ -702,7 +702,7 @@ def _test_attention_slicing_forward_pass(
self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results")
if test_mean_pixel_difference:
- assert_mean_pixel_difference(output_with_slicing[0], output_without_slicing[0])
+ assert_mean_pixel_difference(to_np(output_with_slicing[0]), to_np(output_without_slicing[0]))
@unittest.skipIf(
torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
diff --git a/tests/schedulers/test_scheduler_lcm.py b/tests/schedulers/test_scheduler_lcm.py
index 48b68fa47ddc..f7d511ff0573 100644
--- a/tests/schedulers/test_scheduler_lcm.py
+++ b/tests/schedulers/test_scheduler_lcm.py
@@ -230,7 +230,7 @@ def test_full_loop_onestep(self):
result_mean = torch.mean(torch.abs(sample))
# TODO: get expected sum and mean
- assert abs(result_sum.item() - 18.7097) < 1e-2
+ assert abs(result_sum.item() - 18.7097) < 1e-3
assert abs(result_mean.item() - 0.0244) < 1e-3
def test_full_loop_multistep(self):
@@ -240,5 +240,5 @@ def test_full_loop_multistep(self):
result_mean = torch.mean(torch.abs(sample))
# TODO: get expected sum and mean
- assert abs(result_sum.item() - 280.5618) < 1e-2
- assert abs(result_mean.item() - 0.3653) < 1e-3
+ assert abs(result_sum.item() - 197.7616) < 1e-3
+ assert abs(result_mean.item() - 0.2575) < 1e-3