@@ -477,18 +500,21 @@ Once you've generated the embeddings, pass them to the `prompt_embeds` (and `neg
```py
import torch
from diffusers import AutoPipelineForInpainting
+from diffusers.utils import make_image_grid
pipeline = AutoPipelineForInpainting.from_pretrained(
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16,
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove the following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
-image = pipeline(prompt_emebds=prompt_embeds, # generated from Compel
- negative_prompt_embeds, # generated from Compel
+image = pipeline(prompt_embeds=prompt_embeds, # generated from Compel
+ negative_prompt_embeds=negative_prompt_embeds, # generated from Compel
image=init_image,
mask_image=mask_image
).images[0]
+make_image_grid([init_image, mask_image, image], rows=1, cols=3)
```
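For reference, the `prompt_embeds` and `negative_prompt_embeds` used above can be created with the [Compel](https://github.com/damian0815/compel) library. The snippet below is a minimal sketch, assuming Compel is installed; the weighted prompts are placeholders and the processor reuses the pipeline's own tokenizer and text encoder.

```py
from compel import Compel

# build a Compel processor from the pipeline's tokenizer and text encoder
compel_proc = Compel(tokenizer=pipeline.tokenizer, text_encoder=pipeline.text_encoder)

# "++" upweights a term in Compel's weighting syntax (placeholder prompts)
prompt_embeds = compel_proc("concept art of a castle++, highly detailed")
negative_prompt_embeds = compel_proc("blurry, low quality")
```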
### ControlNet
@@ -501,7 +527,7 @@ For example, let's condition an image with a ControlNet pretrained on inpaint im
import torch
import numpy as np
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
-from diffusers.utils import load_image
+from diffusers.utils import load_image, make_image_grid
# load ControlNet
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16, variant="fp16")
@@ -511,11 +537,12 @@ pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove the following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# load base and mask image
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
# prepare control image
def make_inpaint_condition(init_image, mask_image):
@@ -536,7 +563,7 @@ Now generate an image from the base, mask and control images. You'll notice feat
```py
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, control_image=control_image).images[0]
-image
+make_image_grid([init_image, mask_image, PIL.Image.fromarray(np.uint8(control_image[0][0])).convert('RGB'), image], rows=2, cols=2)
```
You can take this a step further and chain it with an image-to-image pipeline to apply a new [style](https://huggingface.co/nitrosocke/elden-ring-diffusion):
@@ -548,13 +575,14 @@ pipeline = AutoPipelineForImage2Image.from_pretrained(
"nitrosocke/elden-ring-diffusion", torch_dtype=torch.float16,
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove the following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
prompt = "elden ring style castle" # include the token "elden ring style" in the prompt
negative_prompt = "bad architecture, deformed, disfigured, poor details"
-image = pipeline(prompt, negative_prompt=negative_prompt, image=image).images[0]
-image
+image_elden_ring = pipeline(prompt, negative_prompt=negative_prompt, image=image).images[0]
+make_image_grid([init_image, mask_image, image, image_elden_ring], rows=2, cols=2)
```
@@ -576,17 +604,17 @@ image
It can be difficult and slow to run diffusion models if you're resource constrained, but it doesn't have to be with a few optimization tricks. One of the biggest (and easiest) optimizations you can enable is switching to memory-efficient attention. If you're using PyTorch 2.0, [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention) is automatically enabled and you don't need to do anything else. For non-PyTorch 2.0 users, you can install and use [xFormers](../optimization/xformers)'s implementation of memory-efficient attention. Both options reduce memory usage and accelerate inference.
-You can also offload the model to the GPU to save even more memory:
+You can also offload the model to the CPU to save even more memory:
```diff
+ pipeline.enable_xformers_memory_efficient_attention()
+ pipeline.enable_model_cpu_offload()
```
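If you are on PyTorch 2.0 or newer, scaled-dot product attention is already the default and no extra call is needed. If you want to set it explicitly, a sketch like the following (assuming a `pipeline` object as above) opts into the PyTorch 2.0 attention processor:

```py
from diffusers.models.attention_processor import AttnProcessor2_0

# explicitly select PyTorch 2.0 scaled-dot product attention
# (already the default attention processor on torch >= 2.0)
pipeline.unet.set_attn_processor(AttnProcessor2_0())
```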
-To speed-up your inference code even more, use [`torch_compile`](../optimization/torch2.0#torch.compile). You should wrap `torch.compile` around the most intensive component in the pipeline which is typically the UNet:
+To speed up your inference code even more, use [`torch.compile`](../optimization/torch2.0#torchcompile). You should wrap `torch.compile` around the most intensive component in the pipeline, which is typically the UNet:
```py
-pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```
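Note that the first call to a compiled UNet pays the compilation cost and is slow. A rough timing check, assuming the inpainting pipeline and inputs from the earlier example, could look like this:

```py
import time

# the first call triggers compilation; later calls reuse the compiled graph
for _ in range(2):
    start = time.perf_counter()
    image = pipeline(prompt="concept art of a castle", image=init_image, mask_image=mask_image).images[0]
    print(f"inference took {time.perf_counter() - start:.1f}s")
```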
-Learn more in the [Reduce memory usage](../optimization/memory) and [Torch 2.0](../optimization/torch2.0) guides.
\ No newline at end of file
+Learn more in the [Reduce memory usage](../optimization/memory) and [Torch 2.0](../optimization/torch2.0) guides.
diff --git a/docs/source/en/using-diffusers/unconditional_image_generation.md b/docs/source/en/using-diffusers/unconditional_image_generation.md
index 3893f7cce276..c055bc75c5a4 100644
--- a/docs/source/en/using-diffusers/unconditional_image_generation.md
+++ b/docs/source/en/using-diffusers/unconditional_image_generation.md
@@ -23,16 +23,16 @@ You can use any of the 🧨 Diffusers [checkpoints](https://huggingface.co/model
-💡 Want to train your own unconditional image generation model? Take a look at the training [guide](training/unconditional_training) to learn how to generate your own images.
+💡 Want to train your own unconditional image generation model? Take a look at the training [guide](../training/unconditional_training) to learn how to generate your own images.
In this guide, you'll use [`DiffusionPipeline`] for unconditional image generation with [DDPM](https://arxiv.org/abs/2006.11239):
```python
->>> from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline
->>> generator = DiffusionPipeline.from_pretrained("anton-l/ddpm-butterflies-128", use_safetensors=True)
+generator = DiffusionPipeline.from_pretrained("anton-l/ddpm-butterflies-128", use_safetensors=True)
```
The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components.
@@ -40,13 +40,14 @@ Because the model consists of roughly 1.4 billion parameters, we strongly recomm
You can move the generator object to a GPU, just like you would in PyTorch:
```python
->>> generator.to("cuda")
+generator.to("cuda")
```
Now you can use the `generator` to generate an image:
```python
->>> image = generator().images[0]
+image = generator().images[0]
+image
```
The output is by default wrapped into a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object.
@@ -54,7 +55,7 @@ The output is by default wrapped into a [`PIL.Image`](https://pillow.readthedocs
You can save the image by calling:
```python
->>> image.save("generated_image.png")
+image.save("generated_image.png")
```
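If you want reproducible samples, you can pass a seeded `torch.Generator` to the pipeline call. A minimal sketch, assuming the `generator` pipeline object from above has been moved to a CUDA device:

```python
import torch

# seed the sampling noise so repeated calls produce the same image
seed = torch.Generator(device="cuda").manual_seed(0)
image = generator(generator=seed).images[0]
image.save("generated_image_seed0.png")
```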
Try out the Spaces below, and feel free to play around with the inference steps parameter to see how it affects the image quality!
@@ -65,5 +66,3 @@ Try out the Spaces below, and feel free to play around with the inference steps
width="850"
height="500"
>
-
-
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index e1ae2c1c064b..1234dbd2d5ce 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -378,7 +378,7 @@ def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False)
_remove_lora (`bool`, *optional*, defaults to `False`):
Set to `True` to remove LoRA layers from the model.
"""
- if hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
+ if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
deprecate(
"set_processor to offload LoRA",
"0.26.0",
@@ -879,6 +879,9 @@ def __call__(
scale: float = 1.0,
) -> torch.Tensor:
residual = hidden_states
+
+ args = () if USE_PEFT_BACKEND else (scale,)
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
@@ -891,17 +894,17 @@ def __call__(
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
- query = attn.to_q(hidden_states, scale=scale)
+ query = attn.to_q(hidden_states, *args)
query = attn.head_to_batch_dim(query)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, scale=scale)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, scale=scale)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
if not attn.only_cross_attention:
- key = attn.to_k(hidden_states, scale=scale)
- value = attn.to_v(hidden_states, scale=scale)
+ key = attn.to_k(hidden_states, *args)
+ value = attn.to_v(hidden_states, *args)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
@@ -915,7 +918,7 @@ def __call__(
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
- hidden_states = attn.to_out[0](hidden_states, scale=scale)
+ hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -946,6 +949,9 @@ def __call__(
scale: float = 1.0,
) -> torch.Tensor:
residual = hidden_states
+
+ args = () if USE_PEFT_BACKEND else (scale,)
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
@@ -958,7 +964,7 @@ def __call__(
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
- query = attn.to_q(hidden_states, scale=scale)
+ query = attn.to_q(hidden_states, *args)
query = attn.head_to_batch_dim(query, out_dim=4)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
@@ -967,8 +973,8 @@ def __call__(
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
if not attn.only_cross_attention:
- key = attn.to_k(hidden_states, scale=scale)
- value = attn.to_v(hidden_states, scale=scale)
+ key = attn.to_k(hidden_states, *args)
+ value = attn.to_v(hidden_states, *args)
key = attn.head_to_batch_dim(key, out_dim=4)
value = attn.head_to_batch_dim(value, out_dim=4)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
@@ -985,7 +991,7 @@ def __call__(
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
# linear proj
- hidden_states = attn.to_out[0](hidden_states, scale=scale)
+ hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -1177,6 +1183,8 @@ def __call__(
) -> torch.FloatTensor:
residual = hidden_states
+ args = () if USE_PEFT_BACKEND else (scale,)
+
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1207,12 +1215,8 @@ def __call__(
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- key = (
- attn.to_k(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_k(encoder_hidden_states)
- )
- value = (
- attn.to_v(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_v(encoder_hidden_states)
- )
+ key = attn.to_k(encoder_hidden_states, *args)
+ value = attn.to_v(encoder_hidden_states, *args)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
@@ -1232,9 +1236,7 @@ def __call__(
hidden_states = hidden_states.to(query.dtype)
# linear proj
- hidden_states = (
- attn.to_out[0](hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_out[0](hidden_states)
- )
+ hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
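The recurring `args = () if USE_PEFT_BACKEND else (scale,)` change collapses the previous per-call branching into a single tuple that is unpacked into every projection call. A standalone sketch of the pattern (the projection here is a stand-in, not the diffusers LoRA layer):

```py
import torch

USE_PEFT_BACKEND = False  # assumption for this sketch; diffusers derives it from the installed backends
scale = 1.0


def to_k(hidden_states, scale=1.0):
    # stand-in for a LoRA-compatible projection that accepts an optional `scale`
    return hidden_states * scale


# with the PEFT backend the projection takes no extra positional argument;
# without it, `scale` is forwarded once via tuple unpacking
args = () if USE_PEFT_BACKEND else (scale,)
key = to_k(torch.randn(1, 4, 8), *args)
```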
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 8fe66aacf5db..868e2e5fae2c 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -778,16 +778,22 @@ class Conv1dBlock(nn.Module):
out_channels (`int`): Number of output channels.
kernel_size (`int` or `tuple`): Size of the convolving kernel.
n_groups (`int`, default `8`): Number of groups to separate the channels into.
+ activation (`str`, defaults to `"mish"`): Name of the activation function.
"""
def __init__(
- self, inp_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], n_groups: int = 8
+ self,
+ inp_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]],
+ n_groups: int = 8,
+ activation: str = "mish",
):
super().__init__()
self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
self.group_norm = nn.GroupNorm(n_groups, out_channels)
- self.mish = nn.Mish()
+ self.mish = get_activation(activation)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
intermediate_repr = self.conv1d(inputs)
@@ -808,16 +814,22 @@ class ResidualTemporalBlock1D(nn.Module):
out_channels (`int`): Number of output channels.
embed_dim (`int`): Embedding dimension.
kernel_size (`int` or `tuple`): Size of the convolving kernel.
+ activation (`str`, defaults to `"mish"`): Name of the activation function to use.
"""
def __init__(
- self, inp_channels: int, out_channels: int, embed_dim: int, kernel_size: Union[int, Tuple[int, int]] = 5
+ self,
+ inp_channels: int,
+ out_channels: int,
+ embed_dim: int,
+ kernel_size: Union[int, Tuple[int, int]] = 5,
+ activation: str = "mish",
):
super().__init__()
self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
- self.time_emb_act = nn.Mish()
+ self.time_emb_act = get_activation(activation)
self.time_emb = nn.Linear(embed_dim, out_channels)
self.residual_conv = (
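With the new `activation` argument, callers can swap the previously hard-coded Mish for any activation known to `get_activation`. A rough usage sketch, assuming this change is installed:

```py
import torch
from diffusers.models.resnet import ResidualTemporalBlock1D

# pick SiLU instead of the default Mish for the temporal residual block
block = ResidualTemporalBlock1D(inp_channels=8, out_channels=8, embed_dim=16, activation="silu")
sample = torch.randn(1, 8, 32)  # (batch, channels, horizon)
temb = torch.randn(1, 16)       # (batch, embed_dim)
print(block(sample, temb).shape)  # torch.Size([1, 8, 32])
```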
diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py
index 7c0cd12d1c67..24abf54d6da7 100644
--- a/src/diffusers/models/transformer_2d.py
+++ b/src/diffusers/models/transformer_2d.py
@@ -339,6 +339,7 @@ def forward(
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
elif self.is_input_patches:
+ height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
hidden_states = self.pos_embed(hidden_states)
if self.adaln_single is not None:
@@ -425,7 +426,8 @@ def forward(
hidden_states = hidden_states.squeeze(1)
# unpatchify
- height = width = int(hidden_states.shape[1] ** 0.5)
+ if self.adaln_single is None:
+ height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
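The `height`/`width` change matters for non-square inputs: the number of patch tokens equals `height * width`, so a square-root guess only recovers the grid for square images. A standalone sketch of the unpatchify reshape (shapes are illustrative):

```py
import torch

patch_size, out_channels = 2, 4
height, width = 32 // patch_size, 48 // patch_size  # a 16 x 24 patch grid for a 32x48 latent

# (batch, num_patches, patch_size * patch_size * channels) -> patch grid -> image layout
tokens = torch.randn(1, height * width, patch_size * patch_size * out_channels)
x = tokens.reshape(-1, height, width, patch_size, patch_size, out_channels)
x = torch.einsum("nhwpqc->nchpwq", x)
x = x.reshape(-1, out_channels, height * patch_size, width * patch_size)
print(x.shape)  # torch.Size([1, 4, 32, 48]); only the true height and width recover this
```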
diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py
index 08ad122c3891..0c93b9142bea 100644
--- a/src/diffusers/models/vq_model.py
+++ b/src/diffusers/models/vq_model.py
@@ -162,8 +162,8 @@ def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[
If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple`
is returned.
"""
- x = sample
- h = self.encode(x).latents
+
+ h = self.encode(sample).latents
dec = self.decode(h).sample
if not return_dict:
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
index 49947f9dbf32..b63acb9a5f30 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -498,7 +498,7 @@ def prepare_latents(
@torch.no_grad()
def __call__(
self,
- prompt: Union[str, List[str]],
+ prompt: Union[str, List[str]] = None,
num_frames: Optional[int] = 16,
height: Optional[int] = None,
width: Optional[int] = None,
diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
index ccc84e22c252..679415db7f3a 100644
--- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
+++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
@@ -60,7 +60,7 @@ def retrieve_latents(encoder_output, generator):
>>> import torch
>>> import PIL
- >>> pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7")
+ >>> pipe = AutoPipelineForImage2Image.from_pretrained("SimianLuo/LCM_Dreamshaper_v7")
>>> # To save GPU memory, torch.float16 can be used, but it may compromise image quality.
>>> pipe.to(torch_device="cuda", torch_dtype=torch.float32)
@@ -738,7 +738,7 @@ def __call__(
if original_inference_steps is not None
else self.scheduler.config.original_inference_steps
)
- latent_timestep = torch.tensor(int(strength * original_inference_steps))
+ latent_timestep = timesteps[:1]
latents = self.prepare_latents(
image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
)
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 9e019794692e..6437732d0315 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -158,9 +158,9 @@ def is_safetensors_compatible(filenames, variant=None, passed_components=None) -
continue
if extension == ".bin":
- pt_filenames.append(filename)
+ pt_filenames.append(os.path.normpath(filename))
elif extension == ".safetensors":
- sf_filenames.add(filename)
+ sf_filenames.add(os.path.normpath(filename))
for filename in pt_filenames:
# filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extension = '.bam'
@@ -172,9 +172,8 @@ def is_safetensors_compatible(filenames, variant=None, passed_components=None) -
else:
filename = filename
- expected_sf_filename = os.path.join(path, filename)
+ expected_sf_filename = os.path.normpath(os.path.join(path, filename))
expected_sf_filename = f"{expected_sf_filename}.safetensors"
-
if expected_sf_filename not in sf_filenames:
logger.warning(f"{expected_sf_filename} not found")
return False
@@ -1774,7 +1773,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
)
):
raise EnvironmentError(
- f"Could not found the necessary `safetensors` weights in {model_filenames} (variant={variant})"
+ f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
)
if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
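The added `os.path.normpath` calls make the `.bin` versus `.safetensors` comparison robust to superficially different spellings of the same path. A tiny illustration (output shown for a POSIX system):

```py
import os

# equivalent paths only compare equal as plain strings after normalization
print(os.path.normpath("unet//./diffusion_pytorch_model.bin"))  # unet/diffusion_pytorch_model.bin
print(os.path.normpath("unet/diffusion_pytorch_model.bin"))     # unet/diffusion_pytorch_model.bin
```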
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
index 6f2bed9ce2cf..386971b2ac3d 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -20,6 +20,7 @@
import torch
from transformers import T5EncoderModel, T5Tokenizer
+from torchvision import transforms as T
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, Transformer2DModel
@@ -61,6 +62,20 @@
"""
+ASPECT_RATIO_1024_TEST = {
+ '0.25': [512., 2048.], '0.28': [512., 1856.],
+ '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
+ '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
+ '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
+ '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
+ '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
+ '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
+ '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
+ '2.5': [1600., 640.], '3.0': [1728., 576.],
+ '4.0': [2048., 512.],
+}
+
+
class PixArtAlphaPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using PixArt-Alpha.
@@ -116,6 +131,13 @@ def mask_text_embeddings(self, emb, mask):
masked_feature = emb * mask[:, None, :, None]
return masked_feature, emb.shape[2]
+ @staticmethod
+ def classify_height_width_bin(height: int, width: int, ratios: dict):
+ ar = float(height / width)
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
+ default_hw = ratios[closest_ratio]
+ return int(default_hw[0]), int(default_hw[1])
+
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
def encode_prompt(
self,
@@ -156,6 +178,8 @@ def encode_prompt(
mask_feature: (bool, defaults to `True`):
If `True`, the function will mask the text embeddings.
"""
+ embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
+
if device is None:
device = self._execution_device
@@ -253,7 +277,7 @@ def encode_prompt(
negative_prompt_embeds = None
# Perform additional masking.
- if mask_feature and prompt_embeds is None and negative_prompt_embeds is None:
+ if mask_feature and not embeds_initially_provided:
prompt_embeds = prompt_embeds.unsqueeze(1)
masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
@@ -493,6 +517,24 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
latents = latents * self.scheduler.init_noise_sigma
return latents
+ @staticmethod
+ def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int):
+ orig_hw = torch.tensor([samples.shape[2], samples.shape[3]])
+ custom_hw = torch.tensor([new_height, new_width])
+
+ if (orig_hw != custom_hw).all():
+ ratio = max(custom_hw[0] / orig_hw[0], custom_hw[1] / orig_hw[1])
+ resized_width = int(orig_hw[1] * ratio)
+ resized_height = int(orig_hw[0] * ratio)
+
+ transform = T.Compose([
+ T.Resize((resized_height, resized_width)),
+ T.CenterCrop(custom_hw.tolist())
+ ])
+ return transform(samples)
+ else:
+ return samples
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -516,6 +558,7 @@ def __call__(
callback_steps: int = 1,
clean_caption: bool = True,
mask_feature: bool = True,
+ use_bin_classifier: bool = True,
) -> Union[ImagePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@@ -589,6 +632,9 @@ def __call__(
# 1. Check inputs. Raise error if not correct
height = height or self.transformer.config.sample_size * self.vae_scale_factor
width = width or self.transformer.config.sample_size * self.vae_scale_factor
+ if use_bin_classifier:
+ orig_height, orig_width = height, width
+ height, width = self.classify_height_width_bin(height, width, ratios=ASPECT_RATIO_1024_TEST)
self.check_inputs(
prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds
)
@@ -707,6 +753,8 @@ def __call__(
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ if use_bin_classifier:
+ image = self.resize_and_crop_tensor(image, orig_width, orig_height)
else:
image = latents
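A quick way to see what the `use_bin_classifier` path does is to run the binning logic by hand. The sketch below mirrors `classify_height_width_bin` with a trimmed-down ratio table (values taken from `ASPECT_RATIO_1024_TEST` above):

```py
# trimmed-down copy of the ratio table, enough to show the snapping behaviour
ASPECT_RATIO_SUBSET = {"0.5": [704.0, 1408.0], "1.0": [1024.0, 1024.0], "2.0": [1408.0, 704.0]}


def classify_height_width_bin(height, width, ratios):
    ar = height / width
    closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
    default_hw = ratios[closest_ratio]
    return int(default_hw[0]), int(default_hw[1])


# a 900x1600 request is snapped to the 704x1408 bin, then resized/cropped back after decoding
print(classify_height_width_bin(900, 1600, ASPECT_RATIO_SUBSET))  # (704, 1408)
```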
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
index c81dd85f0e46..eb4542888c1f 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
@@ -206,17 +206,15 @@ def _encode_prior_prompt(
prior_text_encoder_output = self.prior_text_encoder(text_input_ids.to(device))
prompt_embeds = prior_text_encoder_output.text_embeds
- prior_text_encoder_hidden_states = prior_text_encoder_output.last_hidden_state
+ text_enc_hid_states = prior_text_encoder_output.last_hidden_state
else:
batch_size = text_model_output[0].shape[0]
- prompt_embeds, prior_text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+ prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1]
text_mask = text_attention_mask
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
- prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.repeat_interleave(
- num_images_per_prompt, dim=0
- )
+ text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0)
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
@@ -235,9 +233,7 @@ def _encode_prior_prompt(
)
negative_prompt_embeds = negative_prompt_embeds_prior_text_encoder_output.text_embeds
- uncond_prior_text_encoder_hidden_states = (
- negative_prompt_embeds_prior_text_encoder_output.last_hidden_state
- )
+ uncond_text_enc_hid_states = negative_prompt_embeds_prior_text_encoder_output.last_hidden_state
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
@@ -245,11 +241,9 @@ def _encode_prior_prompt(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
- seq_len = uncond_prior_text_encoder_hidden_states.shape[1]
- uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.repeat(
- 1, num_images_per_prompt, 1
- )
- uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.view(
+ seq_len = uncond_text_enc_hid_states.shape[1]
+ uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_enc_hid_states = uncond_text_enc_hid_states.view(
batch_size * num_images_per_prompt, seq_len, -1
)
uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
@@ -260,13 +254,11 @@ def _encode_prior_prompt(
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
- prior_text_encoder_hidden_states = torch.cat(
- [uncond_prior_text_encoder_hidden_states, prior_text_encoder_hidden_states]
- )
+ text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states])
text_mask = torch.cat([uncond_text_mask, text_mask])
- return prompt_embeds, prior_text_encoder_hidden_states, text_mask
+ return prompt_embeds, text_enc_hid_states, text_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py
index c4a25c865d88..7bebed73c106 100644
--- a/src/diffusers/pipelines/unclip/pipeline_unclip.py
+++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py
@@ -156,15 +156,15 @@ def _encode_prompt(
text_encoder_output = self.text_encoder(text_input_ids.to(device))
prompt_embeds = text_encoder_output.text_embeds
- text_encoder_hidden_states = text_encoder_output.last_hidden_state
+ text_enc_hid_states = text_encoder_output.last_hidden_state
else:
batch_size = text_model_output[0].shape[0]
- prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+ prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1]
text_mask = text_attention_mask
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
- text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0)
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
@@ -181,7 +181,7 @@ def _encode_prompt(
negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
- uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+ uncond_text_enc_hid_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
@@ -189,9 +189,9 @@ def _encode_prompt(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
- seq_len = uncond_text_encoder_hidden_states.shape[1]
- uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
- uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+ seq_len = uncond_text_enc_hid_states.shape[1]
+ uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_enc_hid_states = uncond_text_enc_hid_states.view(
batch_size * num_images_per_prompt, seq_len, -1
)
uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
@@ -202,11 +202,11 @@ def _encode_prompt(
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
- text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+ text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states])
text_mask = torch.cat([uncond_text_mask, text_mask])
- return prompt_embeds, text_encoder_hidden_states, text_mask
+ return prompt_embeds, text_enc_hid_states, text_mask
@torch.no_grad()
def __call__(
@@ -293,7 +293,7 @@ def __call__(
do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
- prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
+ prompt_embeds, text_enc_hid_states, text_mask = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
)
@@ -321,7 +321,7 @@ def __call__(
latent_model_input,
timestep=t,
proj_embedding=prompt_embeds,
- encoder_hidden_states=text_encoder_hidden_states,
+ encoder_hidden_states=text_enc_hid_states,
attention_mask=text_mask,
).predicted_image_embedding
@@ -352,10 +352,10 @@ def __call__(
# decoder
- text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
+ text_enc_hid_states, additive_clip_time_embeddings = self.text_proj(
image_embeddings=image_embeddings,
prompt_embeds=prompt_embeds,
- text_encoder_hidden_states=text_encoder_hidden_states,
+ text_encoder_hidden_states=text_enc_hid_states,
do_classifier_free_guidance=do_classifier_free_guidance,
)
@@ -377,7 +377,7 @@ def __call__(
decoder_latents = self.prepare_latents(
(batch_size, num_channels_latents, height, width),
- text_encoder_hidden_states.dtype,
+ text_enc_hid_states.dtype,
device,
generator,
decoder_latents,
@@ -391,7 +391,7 @@ def __call__(
noise_pred = self.decoder(
sample=latent_model_input,
timestep=t,
- encoder_hidden_states=text_encoder_hidden_states,
+ encoder_hidden_states=text_enc_hid_states,
class_labels=additive_clip_time_embeddings,
attention_mask=decoder_text_mask,
).sample
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index 32147ffa455b..60ea3d814b3a 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -1494,7 +1494,6 @@ def forward(self, input_tensor, temb):
return output_tensor
-# Copied from diffusers.models.unet_2d_blocks.DownBlock2D with DownBlock2D->DownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
class DownBlockFlat(nn.Module):
def __init__(
self,
@@ -1583,7 +1582,6 @@ def custom_forward(*inputs):
return hidden_states, output_states
-# Copied from diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D with CrossAttnDownBlock2D->CrossAttnDownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
class CrossAttnDownBlockFlat(nn.Module):
def __init__(
self,
diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py
index 8e2627b6f477..adcc092a816f 100644
--- a/src/diffusers/schedulers/scheduling_lcm.py
+++ b/src/diffusers/schedulers/scheduling_lcm.py
@@ -182,6 +182,10 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
timestep_spacing (`str`, defaults to `"leading"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ timestep_scaling (`float`, defaults to 10.0):
+ The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
+ `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
+ error at the default of `10.0` is already pretty small).
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
@@ -208,6 +212,7 @@ def __init__(
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
timestep_spacing: str = "leading",
+ timestep_scaling: float = 10.0,
rescale_betas_zero_snr: bool = False,
):
if trained_betas is not None:
@@ -380,12 +385,12 @@ def set_timesteps(
self._step_index = None
- def get_scalings_for_boundary_condition_discrete(self, t):
+ def get_scalings_for_boundary_condition_discrete(self, timestep):
self.sigma_data = 0.5 # Default: 0.5
+ scaled_timestep = timestep * self.config.timestep_scaling
- # By dividing 0.1: This is almost a delta function at t=0.
- c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)
- c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5
+ c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2)
+ c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5
return c_skip, c_out
def step(
@@ -466,9 +471,12 @@ def step(
denoised = c_out * predicted_original_sample + c_skip * sample
# 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
- # Noise is not used for one-step sampling.
- if len(self.timesteps) > 1:
- noise = randn_tensor(model_output.shape, generator=generator, device=model_output.device)
+ # Noise is not used on the final timestep of the timestep schedule.
+ # This also means that noise is not used for one-step sampling.
+ if self.step_index != self.num_inference_steps - 1:
+ noise = randn_tensor(
+ model_output.shape, generator=generator, device=model_output.device, dtype=denoised.dtype
+ )
prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
else:
prev_sample = denoised
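The `timestep_scaling` change replaces the hard-coded `t / 0.1` with a configurable factor. Re-deriving `c_skip` and `c_out` outside the scheduler shows how the scaling keeps `c_skip` close to a delta at `t = 0` (a standalone sketch, not the scheduler API):

```py
sigma_data = 0.5
timestep_scaling = 10.0  # the new config default


def boundary_conditions(timestep):
    scaled = timestep * timestep_scaling
    c_skip = sigma_data**2 / (scaled**2 + sigma_data**2)
    c_out = scaled / (scaled**2 + sigma_data**2) ** 0.5
    return c_skip, c_out


for t in (0, 1, 20, 999):
    c_skip, c_out = boundary_conditions(t)
    print(f"t={t:>4}  c_skip={c_skip:.6f}  c_out={c_out:.6f}")
```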
diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py
index baba8ba4d655..3c9390f2d1b6 100644
--- a/tests/pipelines/animatediff/test_animatediff.py
+++ b/tests/pipelines/animatediff/test_animatediff.py
@@ -220,6 +220,17 @@ def test_to_dtype(self):
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
+ def test_prompt_embeds(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs.pop("prompt")
+ inputs["prompt_embeds"] = torch.randn((1, 4, 32), device=torch_device)
+ pipe(**inputs)
+
@slow
@require_torch_gpu
diff --git a/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py b/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py
index 82a2944aeda4..53702925534d 100644
--- a/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py
+++ b/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py
@@ -133,7 +133,7 @@ def test_lcm_onestep(self):
assert image.shape == (1, 32, 32, 3)
image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.5865, 0.2854, 0.2828, 0.7473, 0.6006, 0.4580, 0.4397, 0.6415, 0.6069])
+ expected_slice = np.array([0.4388, 0.3717, 0.2202, 0.7213, 0.6370, 0.3664, 0.5815, 0.6080, 0.4977])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_lcm_multistep(self):
@@ -150,7 +150,7 @@ def test_lcm_multistep(self):
assert image.shape == (1, 32, 32, 3)
image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.4903, 0.3304, 0.3503, 0.5241, 0.5153, 0.4585, 0.3222, 0.4764, 0.4891])
+ expected_slice = np.array([0.4150, 0.3719, 0.2479, 0.6333, 0.6024, 0.3778, 0.5036, 0.5420, 0.4678])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_inference_batch_single_identical(self):
@@ -237,7 +237,7 @@ def test_lcm_onestep(self):
assert image.shape == (1, 512, 512, 3)
image_slice = image[0, -3:, -3:, -1].flatten()
- expected_slice = np.array([0.1025, 0.0911, 0.0984, 0.0981, 0.0901, 0.0918, 0.1055, 0.0940, 0.0730])
+ expected_slice = np.array([0.1950, 0.1961, 0.2308, 0.1786, 0.1837, 0.2320, 0.1898, 0.1885, 0.2309])
assert np.abs(image_slice - expected_slice).max() < 1e-3
def test_lcm_multistep(self):
@@ -253,5 +253,5 @@ def test_lcm_multistep(self):
assert image.shape == (1, 512, 512, 3)
image_slice = image[0, -3:, -3:, -1].flatten()
- expected_slice = np.array([0.01855, 0.01855, 0.01489, 0.01392, 0.01782, 0.01465, 0.01831, 0.02539, 0.0])
+ expected_slice = np.array([0.3756, 0.3816, 0.3767, 0.3718, 0.3739, 0.3735, 0.3863, 0.3803, 0.3563])
assert np.abs(image_slice - expected_slice).max() < 1e-3
diff --git a/tests/pipelines/pixart/test_pixart.py b/tests/pipelines/pixart/test_pixart.py
index 2e033896d900..a04f4e1a8804 100644
--- a/tests/pipelines/pixart/test_pixart.py
+++ b/tests/pipelines/pixart/test_pixart.py
@@ -120,7 +120,6 @@ def test_save_load_optional_components(self):
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
- "mask_feature": False,
}
# set all optional components to None
@@ -155,7 +154,6 @@ def test_save_load_optional_components(self):
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
- "mask_feature": False,
}
output_loaded = pipe_loaded(**inputs)[0]
@@ -174,13 +172,29 @@ def test_inference(self):
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
- print(torch.from_numpy(image_slice.flatten()))
self.assertEqual(image.shape, (1, 8, 8, 3))
expected_slice = np.array([0.5303, 0.2658, 0.7979, 0.1182, 0.3304, 0.4608, 0.5195, 0.4261, 0.4675])
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
+ def test_inference_non_square_images(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs, height=32, width=48).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ self.assertEqual(image.shape, (1, 32, 48, 3))
+ expected_slice = np.array([0.3859, 0.2987, 0.2333, 0.5243, 0.6721, 0.4436, 0.5292, 0.5373, 0.4416])
+ max_diff = np.abs(image_slice.flatten() - expected_slice).max()
+ self.assertLessEqual(max_diff, 1e-3)
+
def test_inference_with_embeddings_and_multiple_images(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
diff --git a/tests/pipelines/shap_e/test_shap_e.py b/tests/pipelines/shap_e/test_shap_e.py
index 7b95fdd9e669..c7792f097ed5 100644
--- a/tests/pipelines/shap_e/test_shap_e.py
+++ b/tests/pipelines/shap_e/test_shap_e.py
@@ -160,7 +160,7 @@ def get_dummy_inputs(self, device, seed=0):
"generator": generator,
"num_inference_steps": 1,
"frame_size": 32,
- "output_type": "np",
+ "output_type": "latent",
}
return inputs
@@ -176,24 +176,12 @@ def test_shap_e(self):
output = pipe(**self.get_dummy_inputs(device))
image = output.images[0]
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (20, 32, 32, 3)
-
- expected_slice = np.array(
- [
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- ]
- )
+ image = image.cpu().numpy()
+ image_slice = image[-3:, -3:]
+
+ assert image.shape == (32, 16)
+ expected_slice = np.array([-1.0000, -0.6241, 1.0000, -0.8978, -0.6866, 0.7876, -0.7473, -0.2874, 0.6103])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_inference_batch_consistent(self):
diff --git a/tests/pipelines/shap_e/test_shap_e_img2img.py b/tests/pipelines/shap_e/test_shap_e_img2img.py
index 055dbe7a97d4..ee8d9d07cd77 100644
--- a/tests/pipelines/shap_e/test_shap_e_img2img.py
+++ b/tests/pipelines/shap_e/test_shap_e_img2img.py
@@ -181,7 +181,7 @@ def get_dummy_inputs(self, device, seed=0):
"generator": generator,
"num_inference_steps": 1,
"frame_size": 32,
- "output_type": "np",
+ "output_type": "latent",
}
return inputs
@@ -197,22 +197,12 @@ def test_shap_e(self):
output = pipe(**self.get_dummy_inputs(device))
image = output.images[0]
- image_slice = image[0, -3:, -3:, -1]
+ image_slice = image[-3:, -3:].cpu().numpy()
- assert image.shape == (20, 32, 32, 3)
+ assert image.shape == (32, 16)
expected_slice = np.array(
- [
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- ]
+ [-1.0, 0.40668195, 0.57322013, -0.9469888, 0.4283227, 0.30348337, -0.81094897, 0.74555075, 0.15342723]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 1795c83b58a1..b9fe4d190f23 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -493,7 +493,7 @@ def _test_inference_batch_single_identical(
assert output_batch[0].shape[0] == batch_size
- max_diff = np.abs(output_batch[0][0] - output[0][0]).max()
+ max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
assert max_diff < expected_max_diff
def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
@@ -702,7 +702,7 @@ def _test_attention_slicing_forward_pass(
self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results")
if test_mean_pixel_difference:
- assert_mean_pixel_difference(output_with_slicing[0], output_without_slicing[0])
+ assert_mean_pixel_difference(to_np(output_with_slicing[0]), to_np(output_without_slicing[0]))
@unittest.skipIf(
torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
diff --git a/tests/schedulers/test_scheduler_lcm.py b/tests/schedulers/test_scheduler_lcm.py
index 48b68fa47ddc..f7d511ff0573 100644
--- a/tests/schedulers/test_scheduler_lcm.py
+++ b/tests/schedulers/test_scheduler_lcm.py
@@ -230,7 +230,7 @@ def test_full_loop_onestep(self):
result_mean = torch.mean(torch.abs(sample))
# TODO: get expected sum and mean
- assert abs(result_sum.item() - 18.7097) < 1e-2
+ assert abs(result_sum.item() - 18.7097) < 1e-3
assert abs(result_mean.item() - 0.0244) < 1e-3
def test_full_loop_multistep(self):
@@ -240,5 +240,5 @@ def test_full_loop_multistep(self):
result_mean = torch.mean(torch.abs(sample))
# TODO: get expected sum and mean
- assert abs(result_sum.item() - 280.5618) < 1e-2
- assert abs(result_mean.item() - 0.3653) < 1e-3
+ assert abs(result_sum.item() - 197.7616) < 1e-3
+ assert abs(result_mean.item() - 0.2575) < 1e-3