From 2e1920ef5dc9ed875dad05fd3f46ed4bc9393e0a Mon Sep 17 00:00:00 2001
From: Anonym0u3 <306794924@qq.com>
Date: Tue, 14 Jan 2025 20:20:27 +0800
Subject: [PATCH 1/7] add pipeline_stable_diffusion_xl_attentive_eraser
---
examples/community/README.md | 86 +
...ne_stable_diffusion_xl_attentive_eraser.py | 2249 +++++++++++++++++
2 files changed, 2335 insertions(+)
mode change 100755 => 100644 examples/community/README.md
create mode 100644 examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
diff --git a/examples/community/README.md b/examples/community/README.md
old mode 100755
new mode 100644
index c7c40c46ef2d..077363f11a2a
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -77,6 +77,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
| [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [](https://huggingface.co/spaces/pcuenq/mdm) [](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
+| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3)|
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
@@ -4634,6 +4635,91 @@ make_image_grid(image, rows=1, cols=len(image))
# 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.
```
+### Stable Diffusion XL Attentive Eraser Pipeline
+
+
+**Stable Diffusion XL Attentive Eraser Pipeline** is an advanced object removal pipeline that leverages SDXL for precise content suppression and seamless region completion. This pipeline uses **self-attention redirection guidance** to modify the model’s self-attention mechanism, allowing for effective removal and inpainting across various levels of mask precision, including semantic segmentation masks, bounding boxes, and hand-drawn masks. If you are interested in more detailed information and have any questions, please refer to the [paper](https://arxiv.org/abs/2412.12974) and [official implementation](https://github.com/Anonym0u3/AttentiveEraser).
+
+#### Key Features
+
+- **Tuning-Free**: No additional training is required, making it easy to integrate and use.
+- **Flexible Mask Support**: Works with different types of masks for targeted object removal.
+- **High-Quality Results**: Utilizes the inherent generative power of diffusion models for realistic content completion.
+
+#### How to Use
+To use the Stable Diffusion XL Attentive Eraser Pipeline, you can initialize it as follows:
+```py
+import torch
+from diffusers import DDIMScheduler, DiffusionPipeline
+from diffusers.utils import load_image
+import torch.nn.functional as F
+from torchvision.transforms.functional import to_tensor, gaussian_blur
+
+dtype = torch.float16
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser",
+ scheduler=scheduler,
+ variant="fp16",
+ use_safetensors=True,
+ torch_dtype=dtype,
+).to(device)
+
+
+def preprocess_image(image_path, device):
+ image = to_tensor((load_image(image_path)))
+ image = image.unsqueeze_(0).float() * 2 - 1 # [0,1] --> [-1,1]
+ if image.shape[1] != 3:
+ image = image.expand(-1, 3, -1, -1)
+ image = F.interpolate(image, (1024, 1024))
+ image = image.to(dtype).to(device)
+ return image
+
+def preprocess_mask(mask_path, device):
+ mask = to_tensor((load_image(mask_path, convert_method=lambda img: img.convert('L'))))
+ mask = mask.unsqueeze_(0).float() # 0 or 1
+ mask = F.interpolate(mask, (1024, 1024))
+ mask = gaussian_blur(mask, kernel_size=(77, 77))
+ mask[mask < 0.1] = 0
+ mask[mask >= 0.1] = 1
+ mask = mask.to(dtype).to(device)
+ return mask
+
+prompt = "" # Set prompt to null
+seed=123
+generator = torch.Generator(device=device).manual_seed(seed)
+source_image_path = "./path-to-image.png"
+mask_path = "./path-to-mask.png"
+source_image = preprocess_image(source_image_path, device)
+mask = preprocess_mask(mask_path, device)
+
+image = pipeline(
+ prompt=prompt,
+ image=source_image,
+ mask_image=mask,
+ height=1024,
+ width=1024,
+ AAS=True, # enable AAS
+ strength=0.8, # inpainting strength
+ rm_guidance_scale=9, # removal guidance scale
+ ss_steps = 9, # similarity suppression steps
+ ss_scale = 0.3, # similarity suppression scale
+ AAS_start_step=0, # AAS start step
+ AAS_start_layer=34, # AAS start layer
+ AAS_end_layer=70, # AAS end layer
+ num_inference_steps=50, # number of inference steps # AAS_end_step = int(strength*num_inference_steps)
+ generator=generator,
+ guidance_scale=1,
+).images[0]
+image.save('./removed_img.png')
+print("Object removal completed")
+```
+
+
+
# Perturbed-Attention Guidance
[Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)
diff --git a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
new file mode 100644
index 000000000000..cebd80c8bfa1
--- /dev/null
+++ b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
@@ -0,0 +1,2249 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from tqdm import tqdm
+import numpy as np
+import PIL.Image
+from PIL import Image
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from diffusers.models.attention_processor import (
+ AttnProcessor2_0,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_invisible_watermark_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+
+
+if is_invisible_watermark_available():
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+from einops import rearrange, repeat
+import torch.nn as nn
+import torch.nn.functional as F
+import os
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionXLInpaintPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0",
+ ... torch_dtype=torch.float16,
+ ... variant="fp16",
+ ... use_safetensors=True,
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+ >>> init_image = load_image(img_url).convert("RGB")
+ >>> mask_image = load_image(mask_url).convert("RGB")
+
+ >>> prompt = "A majestic tiger sitting on a bench"
+ >>> image = pipe(
+ ... prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80
+ ... ).images[0]
+ ```
+"""
+
+class AttentionBase:
+ def __init__(self):
+ self.cur_step = 0
+ self.num_att_layers = -1
+ self.cur_att_layer = 0
+
+ def after_step(self):
+ pass
+
+ def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+ out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+ self.cur_att_layer += 1
+ if self.cur_att_layer == self.num_att_layers:
+ self.cur_att_layer = 0
+ self.cur_step += 1
+ # after step
+ self.after_step()
+ return out
+
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+ out = torch.einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
+ return out
+
+ def reset(self):
+ self.cur_step = 0
+ self.cur_att_layer = 0
+
+class AAS_XL(AttentionBase):
+ MODEL_TYPE = {
+ "SD": 16,
+ "SDXL": 70
+ }
+ def __init__(self, start_step=4, end_step= 50, start_layer=10, end_layer=16,layer_idx=None, step_idx=None, total_steps=50, mask=None, model_type="SD",ss_steps=9,ss_scale=1.0):
+ """
+ Args:
+ start_step: the step to start AAS
+ start_layer: the layer to start AAS
+ layer_idx: list of the layers to apply AAS
+ step_idx: list the steps to apply AAS
+ total_steps: the total number of steps
+ mask: source mask with shape (h, w)
+ model_type: the model type, SD or SDXL
+ """
+ super().__init__()
+ self.total_steps = total_steps
+ self.total_layers = self.MODEL_TYPE.get(model_type, 16)
+ self.start_step = start_step
+ self.end_step = end_step
+ self.start_layer = start_layer
+ self.end_layer = end_layer
+ self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, end_layer))
+ self.step_idx = step_idx if step_idx is not None else list(range(start_step, end_step))
+ self.mask = mask # mask with shape (1, 1 ,h, w)
+ self.ss_steps = ss_steps
+ self.ss_scale = ss_scale
+ print("AAS at denoising steps: ", self.step_idx)
+ print("AAS at U-Net layers: ", self.layer_idx)
+ print("start AAS")
+ self.mask_16 = F.max_pool2d(mask,(1024//16,1024//16)).round().squeeze().squeeze()
+ self.mask_32 = F.max_pool2d(mask,(1024//32,1024//32)).round().squeeze().squeeze()
+ self.mask_64 = F.max_pool2d(mask,(1024//64,1024//64)).round().squeeze().squeeze()
+ self.mask_128 = F.max_pool2d(mask,(1024//128,1024//128)).round().squeeze().squeeze()
+
+ def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads,is_mask_attn, mask, **kwargs):
+ B = q.shape[0] // num_heads
+ if is_mask_attn:
+ mask_flatten = mask.flatten(0)
+ if self.cur_step <= self.ss_steps:
+ # background
+ sim_bg = sim + mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+
+ #object
+ sim_fg = self.ss_scale*sim
+ sim_fg += mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+ sim = torch.cat([sim_fg, sim_bg], dim=0)
+ else:
+ sim += mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+
+ attn = sim.softmax(-1)
+ if len(attn) == 2 * len(v):
+ v = torch.cat([v] * 2)
+ out = torch.einsum("h i j, h j d -> h i d", attn, v)
+ out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
+ return out
+
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+ """
+ Attention forward function
+ """
+ if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
+ return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+ B = q.shape[0] // num_heads // 2
+ H = W = int(np.sqrt(q.shape[1]))
+ if H == 16:
+ mask = self.mask_16.to(sim.device)
+ elif H == 32:
+ mask = self.mask_32.to(sim.device)
+ elif H == 64:
+ mask = self.mask_64.to(sim.device)
+ else:
+ mask = self.mask_128.to(sim.device)
+
+
+ q_wo, q_w = q.chunk(2)
+ k_wo, k_w = k.chunk(2)
+ v_wo, v_w = v.chunk(2)
+ sim_wo, sim_w = sim.chunk(2)
+ attn_wo, attn_w = attn.chunk(2)
+
+ out_source = self.attn_batch(q_wo, k_wo, v_wo, sim_wo, attn_wo, is_cross, place_in_unet, num_heads,is_mask_attn=False,mask=None,**kwargs)
+ out_target = self.attn_batch(q_w, k_w, v_w, sim_w, attn_w, is_cross, place_in_unet, num_heads, is_mask_attn=True, mask = mask,**kwargs)
+
+ if self.mask is not None:
+ if out_target.shape[0] == 2:
+ out_target_fg, out_target_bg = out_target.chunk(2, 0)
+ mask = mask.reshape(-1, 1) # (hw, 1)
+ out_target = out_target_fg * mask + out_target_bg * (1 - mask)
+ else:
+ out_target = out_target
+
+ out = torch.cat([out_source, out_target], dim=0)
+ return out
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+def mask_pil_to_torch(mask, height, width):
+ # preprocess mask
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+ mask = [mask]
+
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+ mask = mask.astype(np.float32) / 255.0
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+ mask = torch.from_numpy(mask)
+ return mask
+
+
+def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
+ """
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+ ``image`` and ``1`` for the ``mask``.
+
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+
+ Args:
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+ mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+
+
+ Raises:
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
+ (ot the other way around).
+
+ Returns:
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+ dimensions: ``batch x channels x height x width``.
+ """
+
+ # checkpoint. TOD(Yiyi) - need to clean this up later
+ deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
+ deprecate(
+ "prepare_mask_and_masked_image",
+ "0.30.0",
+ deprecation_message,
+ )
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ if mask is None:
+ raise ValueError("`mask_image` input cannot be undefined.")
+
+ if isinstance(image, torch.Tensor):
+ if not isinstance(mask, torch.Tensor):
+ mask = mask_pil_to_torch(mask, height, width)
+
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+
+ # Batch and add channel dim for single mask
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0).unsqueeze(0)
+
+ # Batch single mask or add channel dim
+ if mask.ndim == 3:
+ # Single batched mask, no channel dim or single mask not batched but channel dim
+ if mask.shape[0] == 1:
+ mask = mask.unsqueeze(0)
+
+ # Batched masks no channel dim
+ else:
+ mask = mask.unsqueeze(1)
+
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+ # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+ # Check image is in [-1, 1]
+ # if image.min() < -1 or image.max() > 1:
+ # raise ValueError("Image should be in [-1, 1] range")
+
+ # Check mask is in [0, 1]
+ if mask.min() < 0 or mask.max() > 1:
+ raise ValueError("Mask should be in [0, 1] range")
+
+ # Binarize mask
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ # Image as float32
+ image = image.to(dtype=torch.float32)
+ elif isinstance(mask, torch.Tensor):
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ # resize all images w.r.t passed height an width
+ image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ mask = mask_pil_to_torch(mask, height, width)
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ if image.shape[1] == 4:
+ # images are in latent space and thus can't
+ # be masked set masked_image to None
+ # we assume that the checkpoint is not an inpainting
+ # checkpoint. TOD(Yiyi) - need to clean this up later
+ masked_image = None
+ else:
+ masked_image = image * (mask < 0.5)
+
+ # n.b. ensure backwards compatibility as old function does not return image
+ if return_image:
+ return mask, masked_image, image
+
+ return mask, masked_image
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class StableDiffusionXL_AE_Pipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ FromSingleFileMixin,
+ IPAdapterMixin,
+):
+ r"""
+ Pipeline for object removal using Stable Diffusion XL.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([` CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
+ Whether the `unet` requires a aesthetic_score condition to be passed during inference. Also see the config
+ of `stabilityai/stable-diffusion-xl-refiner-1-0`.
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+ `stabilityai/stable-diffusion-xl-base-1-0`.
+ add_watermarker (`bool`, *optional*):
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+ watermarker will be used.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+
+ _optional_components = [
+ "tokenizer",
+ "tokenizer_2",
+ "text_encoder",
+ "text_encoder_2",
+ "image_encoder",
+ "feature_extractor",
+ ]
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ "add_text_embeds",
+ "add_time_ids",
+ "negative_pooled_prompt_embeds",
+ "add_neg_time_ids",
+ "mask",
+ "masked_image_latents",
+ ]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ requires_aesthetics_score: bool = False,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ scheduler=scheduler,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+ )
+
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+ if add_watermarker:
+ self.watermark = StableDiffusionXLWatermarker()
+ else:
+ self.watermark = None
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+ ):
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack(
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
+ )
+
+ if do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+ else:
+ repeat_dims = [1]
+ image_embeds = []
+ for single_image_embeds in ip_adapter_image_embeds:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+ single_image_embeds = single_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+ )
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+ )
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ else:
+ single_image_embeds = single_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+ )
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if self.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if self.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ image,
+ mask_image,
+ height,
+ width,
+ strength,
+ callback_steps,
+ output_type,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ip_adapter_image=None,
+ ip_adapter_image_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ padding_mask_crop=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if padding_mask_crop is not None:
+ if not isinstance(image, PIL.Image.Image):
+ raise ValueError(
+ f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ )
+ if not isinstance(mask_image, PIL.Image.Image):
+ raise ValueError(
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+ f" {type(mask_image)}."
+ )
+ if output_type != "pil":
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ )
+
+ if ip_adapter_image_embeds is not None:
+ if not isinstance(ip_adapter_image_embeds, list):
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+ )
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+ )
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ image=None,
+ timestep=None,
+ is_strength_max=True,
+ add_noise=True,
+ return_noise=False,
+ return_image_latents=False,
+ ):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if (image is None or timestep is None) and not is_strength_max:
+ raise ValueError(
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
+ "However, either the image or the noise timestep has not been provided."
+ )
+
+ if image.shape[1] == 4:
+ image_latents = image.to(device=device, dtype=dtype)
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+ elif return_image_latents or (latents is None and not is_strength_max):
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+
+ if latents is None and add_noise:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+ elif add_noise:
+ noise = latents.to(device)
+ latents = noise * self.scheduler.init_noise_sigma
+ else:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = image_latents.to(device)
+
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_image_latents:
+ outputs += (image_latents,)
+
+ return outputs
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ dtype = image.dtype
+ if self.vae.config.force_upcast:
+ image = image.float()
+ self.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ if self.vae.config.force_upcast:
+ self.vae.to(dtype)
+
+ image_latents = image_latents.to(dtype)
+ image_latents = self.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ def prepare_mask_latents(
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ #mask = torch.nn.functional.interpolate(
+ # mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+ #)
+ mask = torch.nn.functional.max_pool2d(mask, (8,8)).round()
+ mask = mask.to(device=device, dtype=dtype)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+
+ if masked_image is not None and masked_image.shape[1] == 4:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = None
+
+ if masked_image is not None:
+ if masked_image_latents is None:
+ masked_image = masked_image.to(device=device, dtype=dtype)
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(
+ batch_size // masked_image_latents.shape[0], 1, 1, 1
+ )
+
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ return mask, masked_image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
+ # get the original timestep using init_timestep
+ if denoising_start is None:
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+ t_start = max(num_inference_steps - init_timestep, 0)
+ else:
+ t_start = 0
+
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ # Strength is irrelevant if we directly request a timestep to start at;
+ # that is, strength is determined by the denoising_start instead.
+ if denoising_start is not None:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_start * self.scheduler.config.num_train_timesteps)
+ )
+ )
+
+ num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+ if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
+ # if the scheduler is a 2nd order scheduler we might have to do +1
+ # because `num_inference_steps` might be even given that every timestep
+ # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+ # mean that we cut the timesteps in the middle of the denoising step
+ # (between 1st and 2nd devirative) which leads to incorrect results. By adding 1
+ # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
+ num_inference_steps = num_inference_steps + 1
+
+ # because t_n+1 >= t_n, we slice the timesteps starting from the end
+ timesteps = timesteps[-num_inference_steps:]
+ return timesteps, num_inference_steps
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
+ def _get_add_time_ids(
+ self,
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype,
+ text_encoder_projection_dim=None,
+ ):
+ if self.config.requires_aesthetics_score:
+ add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+ add_neg_time_ids = list(
+ negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
+ )
+ else:
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if (
+ expected_add_embed_dim > passed_add_embed_dim
+ and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
+ )
+ elif (
+ expected_add_embed_dim < passed_add_embed_dim
+ and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
+ )
+ elif expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
+
+ return add_time_ids, add_neg_time_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ timesteps (`torch.Tensor`):
+ generate embedding vectors at these timesteps
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+
+ @property
+ def do_self_attention_redirection_guidance(self): #SARG
+ return self._rm_guidance_scale > 1 and self._AAS
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None and self.do_self_attention_redirection_guidance==False #CFG was disabled when SARG was used, and experiments proved that there was little difference in the effect of whether CFG was used or not
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def denoising_end(self):
+ return self._denoising_end
+
+ @property
+ def denoising_start(self):
+ return self._denoising_start
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ def image2latent(self, image: torch.Tensor, generator: torch.Generator):
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ if type(image) is Image:
+ image = np.array(image)
+ image = torch.from_numpy(image).float() / 127.5 - 1
+ image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
+ # input image density range [-1, 1]
+ #latents = self.vae.encode(image)['latent_dist'].mean
+ latents = self._encode_vae_image(image, generator)
+ #latents = retrieve_latents(self.vae.encode(image))
+ #latents = latents * self.vae.config.scaling_factor
+ return latents
+
+ def next_step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ x: torch.FloatTensor,
+ eta=0.,
+ verbose=False
+ ):
+ """
+ Inverse sampling for DDIM Inversion
+ """
+ if verbose:
+ print("timestep: ", timestep)
+ next_step = timestep
+ timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
+ alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
+ beta_prod_t = 1 - alpha_prod_t
+ pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+ pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
+ x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
+ return x_next, pred_x0
+
+ @torch.no_grad()
+ def invert(
+ self,
+ image: torch.Tensor,
+ prompt,
+ num_inference_steps=50,
+ eta=0.0,
+ original_size: Tuple[int, int] = None,
+ target_size: Tuple[int, int] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ aesthetic_score: float = 6.0,
+ negative_aesthetic_score: float = 2.5,
+ return_intermediates=False,
+ **kwds):
+ """
+ invert a real image into noise map with determinisc DDIM inversion
+ """
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ batch_size = image.shape[0]
+ if isinstance(prompt, list):
+ if batch_size == 1:
+ image = image.expand(len(prompt), -1, -1, -1)
+ elif isinstance(prompt, str):
+ if batch_size > 1:
+ prompt = [prompt] * batch_size
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ prompt_2 = prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(DEVICE), output_hidden_states=True)
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ prompt_embeds_list.append(prompt_embeds)
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=DEVICE)
+
+ # define initial latents
+ latents = self.image2latent(image,generator=None)
+
+ start_latents = latents
+ height, width = latents.shape[-2:]
+ height = height * self.vae_scale_factor
+ width = width * self.vae_scale_factor
+
+ original_size = (height, width)
+ target_size = (height, width)
+ negative_original_size = original_size
+ negative_target_size = target_size
+
+ add_text_embeds = pooled_prompt_embeds
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+
+ add_time_ids = add_time_ids.repeat(batch_size, 1).to(DEVICE)
+
+ # interative sampling
+ self.scheduler.set_timesteps(num_inference_steps)
+ latents_list = [latents]
+ pred_x0_list = []
+ #for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
+ for i, t in enumerate(reversed(self.scheduler.timesteps)):
+ model_inputs = latents
+
+ # predict the noise
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = self.unet(model_inputs, t, encoder_hidden_states=prompt_embeds,added_cond_kwargs=added_cond_kwargs).sample
+
+ # compute the previous noise sample x_t-1 -> x_t
+ latents, pred_x0 = self.next_step(noise_pred, t, latents)
+ """
+ if t >= 1 and t < 41:
+ latents, pred_x0 = self.next_step_degrade(noise_pred, t, latents, mask)
+ else:
+ latents, pred_x0 = self.next_step(noise_pred, t, latents) """
+
+ latents_list.append(latents)
+ pred_x0_list.append(pred_x0)
+
+ if return_intermediates:
+ # return the intermediate laters during inversion
+ #pred_x0_list = [self.latent2image(img, return_type="np") for img in pred_x0_list]
+ #latents_list = [self.latent2image(img, return_type="np") for img in latents_list]
+ return latents, latents_list, pred_x0_list
+ return latents, start_latents
+
+ def opt(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ x: torch.FloatTensor,
+ ):
+ """
+ predict the sampe the next step in the denoise process.
+ """
+ ref_noise = model_output[:1,:,:,:].expand(model_output.shape)
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+ beta_prod_t = 1 - alpha_prod_t
+ pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+ x_opt = alpha_prod_t**0.5 * pred_x0 + (1 - alpha_prod_t)**0.5 * ref_noise
+ return x_opt, pred_x0
+
+ def regiter_attention_editor_diffusers(self, unet, editor: AttentionBase):
+ """
+ Register a attention editor to Diffuser Pipeline, refer from [Prompt-to-Prompt]
+ """
+ def ca_forward(self, place_in_unet):
+ def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
+ """
+ The attention is similar to the original implementation of LDM CrossAttention class
+ except adding some modifications on the attention
+ """
+ if encoder_hidden_states is not None:
+ context = encoder_hidden_states
+ if attention_mask is not None:
+ mask = attention_mask
+
+ to_out = self.to_out
+ if isinstance(to_out, nn.modules.container.ModuleList):
+ to_out = self.to_out[0]
+ else:
+ to_out = self.to_out
+
+ h = self.heads
+ q = self.to_q(x)
+ is_cross = context is not None
+ context = context if is_cross else x
+ k = self.to_k(context)
+ v = self.to_v(context)
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ if mask is not None:
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ mask = mask[:, None, :].repeat(h, 1, 1)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ attn = sim.softmax(dim=-1)
+ # the only difference
+ out = editor(
+ q, k, v, sim, attn, is_cross, place_in_unet,
+ self.heads, scale=self.scale)
+
+ return to_out(out)
+
+ return forward
+
+ def register_editor(net, count, place_in_unet):
+ for name, subnet in net.named_children():
+ if net.__class__.__name__ == 'Attention': # spatial Transformer layer
+ net.forward = ca_forward(net, place_in_unet)
+ return count + 1
+ elif hasattr(net, 'children'):
+ count = register_editor(subnet, count, place_in_unet)
+ return count
+
+ cross_att_count = 0
+ for net_name, net in unet.named_children():
+ if "down" in net_name:
+ cross_att_count += register_editor(net, 0, "down")
+ elif "mid" in net_name:
+ cross_att_count += register_editor(net, 0, "mid")
+ elif "up" in net_name:
+ cross_att_count += register_editor(net, 0, "up")
+ editor.num_att_layers = cross_att_count
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
+ masked_image_latents: torch.FloatTensor = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ padding_mask_crop: Optional[int] = None,
+ strength: float = 0.9999,
+ AAS: bool = True, # AE parameter
+ rm_guidance_scale: float = 7.0, # AE parameter
+ ss_steps: int = 9, # AE parameter
+ ss_scale: float = 0.3, # AE parameter
+ AAS_start_step: int = 0, # AE parameter
+ AAS_start_layer: int = 34, # AE parameter
+ AAS_end_layer: int = 70, # AE parameter
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ denoising_start: Optional[float] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Tuple[int, int] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Tuple[int, int] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ aesthetic_score: float = 6.0,
+ negative_aesthetic_score: float = 2.5,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
+ The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If
+ `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and
+ contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on
+ the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
+ and contain information inreleant for inpainging, such as background.
+ strength (`float`, *optional*, defaults to 0.9999):
+ Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
+ between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
+ `strength`. The number of denoising steps depends on the amount of noise initially added. When
+ `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
+ iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
+ portion of the reference `image`. Note that in the case of `denoising_start` being declared as an
+ integer, the value of `strength` will be ignored.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ denoising_start (`float`, *optional*):
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
+ bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
+ it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
+ strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
+ is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
+ denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
+ final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
+ forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
+ Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
+ if `do_classifier_free_guidance` is set to `True`.
+ If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a target image resolution. It should be as same
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ aesthetic_score (`float`, *optional*, defaults to 6.0):
+ Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
+ simulate an aesthetic score of the generated image by influencing the negative text condition.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple. `tuple. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ image,
+ mask_image,
+ height,
+ width,
+ strength,
+ callback_steps,
+ output_type,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ callback_on_step_end_tensor_inputs,
+ padding_mask_crop,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
+ self._denoising_start = denoising_start
+ self._interrupt = False
+
+ ########### AE parameters
+ self._num_timesteps = num_inference_steps
+ self._rm_guidance_scale = rm_guidance_scale
+ self._AAS = AAS
+ self._ss_steps = ss_steps
+ self._ss_scale = ss_scale
+ self._AAS_start_step = AAS_start_step
+ self._AAS_start_layer = AAS_start_layer
+ self._AAS_end_layer = AAS_end_layer
+ ###########
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 4. set timesteps
+ def denoising_value_valid(dnv):
+ return isinstance(dnv, float) and 0 < dnv < 1
+
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps,
+ strength,
+ device,
+ denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
+ )
+ # check that number of inference steps is not < 1 - as this doesn't make sense
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ is_strength_max = strength == 1.0
+
+ # 5. Preprocess mask and image
+ if padding_mask_crop is not None:
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+ resize_mode = "fill"
+ else:
+ crops_coords = None
+ resize_mode = "default"
+
+ original_image = image
+ init_image = self.image_processor.preprocess(
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+ )
+ init_image = init_image.to(dtype=torch.float32)
+
+ mask = self.mask_processor.preprocess(
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+ )
+
+ if masked_image_latents is not None:
+ masked_image = masked_image_latents
+ elif init_image.shape[1] == 4:
+ # if images are in latent space, we can't mask it
+ masked_image = None
+ else:
+ masked_image = init_image * (mask < 0.5)
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_unet = self.unet.config.in_channels
+ return_image_latents = num_channels_unet == 4
+
+ add_noise = True if self.denoising_start is None else False
+ latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image=init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ add_noise=add_noise,
+ return_noise=True,
+ return_image_latents=return_image_latents,
+ )
+
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+
+ # 7. Prepare mask latent variables
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask,
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ self.do_classifier_free_guidance,
+ )
+
+ # 8. Check that sizes of mask, masked image and latents match
+ if num_channels_unet == 9:
+ # default case for runwayml/stable-diffusion-inpainting
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+ elif num_channels_unet != 4:
+ raise ValueError(
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+ )
+ # 8.1 Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ height, width = latents.shape[-2:]
+ height = height * self.vae_scale_factor
+ width = width * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 10. Prepare added time ids & embeddings
+ if negative_original_size is None:
+ negative_original_size = original_size
+ if negative_target_size is None:
+ negative_target_size = target_size
+
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+
+ ###########
+ if self.do_self_attention_redirection_guidance:
+ prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0)
+ add_neg_time_ids = add_neg_time_ids.repeat(2, 1)
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+ ############
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device)
+
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ # apply AAS to modify the attention module
+ if self.do_self_attention_redirection_guidance:
+ self._AAS_end_step = int(strength * self._num_timesteps)
+ layer_idx=list(range(self._AAS_start_layer, self._AAS_end_layer))
+ editor = AAS_XL(self._AAS_start_step, self._AAS_end_step, self._AAS_start_layer, self._AAS_end_layer, layer_idx= layer_idx, mask=mask_image,model_type="SDXL",ss_steps=self._ss_steps,ss_scale=self._ss_scale)
+ self.regiter_attention_editor_diffusers(self.unet, editor)
+
+
+ # 11. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ if (
+ self.denoising_end is not None
+ and self.denoising_start is not None
+ and denoising_value_valid(self.denoising_end)
+ and denoising_value_valid(self.denoising_start)
+ and self.denoising_start >= self.denoising_end
+ ):
+ raise ValueError(
+ f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
+ + f" {self.denoising_end} when using type float."
+ )
+ elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ # 11.1 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+ #removal guidance
+ latent_model_input = torch.cat([latents]*2) if self.do_self_attention_redirection_guidance else latents #CFG was disabled when SARG was used, and experiments proved that there was little difference in the effect of whether CFG was used or not
+ #latent_model_input_rm = torch.cat([latents]*2) if self.do_self_attention_redirection_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ #latent_model_input = self.scheduler.scale_model_input(latent_model_input_rm, t)
+
+ if num_channels_unet == 9:
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ added_cond_kwargs["image_embeds"] = image_embeds
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform SARG
+ if self.do_self_attention_redirection_guidance:
+ noise_pred_wo, noise_pred_w = noise_pred.chunk(2)
+ delta = noise_pred_w - noise_pred_wo
+ noise_pred = noise_pred_wo + self._rm_guidance_scale * delta
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if num_channels_unet == 4:
+ init_latents_proper = image_latents
+ if self.do_classifier_free_guidance:
+ init_mask, _ = mask.chunk(2)
+ else:
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper, noise, torch.tensor([noise_timestep])
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+ add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
+ mask = callback_outputs.pop("mask", mask)
+ masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+ latents = latents[-1:]
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ # unscale/denormalize the latents
+ # denormalize with the mean and std if available and not None
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+ if has_latents_mean and has_latents_std:
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+ else:
+ latents = latents / self.vae.config.scaling_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ return StableDiffusionXLPipelineOutput(images=latents)
+
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ if padding_mask_crop is not None:
+ image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
From a2f805d6c9a6881c025e5a9657c68c4717f7aac7 Mon Sep 17 00:00:00 2001
From: Anonym0u3 <306794924@qq.com>
Date: Tue, 14 Jan 2025 23:22:17 +0800
Subject: [PATCH 2/7] add
pipeline_stable_diffusion_xl_attentive_eraser_make_style
---
...ne_stable_diffusion_xl_attentive_eraser.py | 1014 ++++++++++++-----
1 file changed, 718 insertions(+), 296 deletions(-)
diff --git a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
index cebd80c8bfa1..764a8c1defcb 100644
--- a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
+++ b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
@@ -14,52 +14,41 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-from tqdm import tqdm
+
import numpy as np
import PIL.Image
-from PIL import Image
import torch
-from transformers import (
- CLIPImageProcessor,
- CLIPTextModel,
- CLIPTextModelWithProjection,
- CLIPTokenizer,
- CLIPVisionModelWithProjection,
-)
+from PIL import Image
+from transformers import (CLIPImageProcessor, CLIPTextModel,
+ CLIPTextModelWithProjection, CLIPTokenizer,
+ CLIPVisionModelWithProjection)
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
-from diffusers.loaders import (
- FromSingleFileMixin,
- IPAdapterMixin,
- StableDiffusionXLLoraLoaderMixin,
- TextualInversionLoaderMixin,
-)
-from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
-from diffusers.models.attention_processor import (
- AttnProcessor2_0,
- LoRAAttnProcessor2_0,
- LoRAXFormersAttnProcessor,
- XFormersAttnProcessor,
-)
+from diffusers.loaders import (FromSingleFileMixin, IPAdapterMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin)
+from diffusers.models import (AutoencoderKL, ImageProjection,
+ UNet2DConditionModel)
+from diffusers.models.attention_processor import (AttnProcessor2_0,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ XFormersAttnProcessor)
from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import (DiffusionPipeline,
+ StableDiffusionMixin)
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import \
+ StableDiffusionXLPipelineOutput
from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.utils import (
- USE_PEFT_BACKEND,
- deprecate,
- is_invisible_watermark_available,
- is_torch_xla_available,
- logging,
- replace_example_docstring,
- scale_lora_layers,
- unscale_lora_layers,
-)
+from diffusers.utils import (USE_PEFT_BACKEND, deprecate,
+ is_invisible_watermark_available,
+ is_torch_xla_available, logging,
+ replace_example_docstring, scale_lora_layers,
+ unscale_lora_layers)
from diffusers.utils.torch_utils import randn_tensor
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
-from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
-
if is_invisible_watermark_available():
- from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+ from diffusers.pipelines.stable_diffusion_xl.watermark import \
+ StableDiffusionXLWatermarker
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
@@ -68,10 +57,10 @@
else:
XLA_AVAILABLE = False
-from einops import rearrange, repeat
+
import torch.nn as nn
import torch.nn.functional as F
-import os
+from einops import rearrange, repeat
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -104,6 +93,7 @@
```
"""
+
class AttentionBase:
def __init__(self):
self.cur_step = 0
@@ -113,8 +103,12 @@ def __init__(self):
def after_step(self):
pass
- def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
- out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+ def __call__(
+ self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs
+ ):
+ out = self.forward(
+ q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs
+ )
self.cur_att_layer += 1
if self.cur_att_layer == self.num_att_layers:
self.cur_att_layer = 0
@@ -123,21 +117,34 @@ def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwa
self.after_step()
return out
- def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
- out = torch.einsum('b i j, b j d -> b i d', attn, v)
- out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
+ def forward(self, q, k, v, sim, attn, is_cross,
+ place_in_unet, num_heads, **kwargs):
+ out = torch.einsum("b i j, b j d -> b i d", attn, v)
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=num_heads)
return out
def reset(self):
self.cur_step = 0
self.cur_att_layer = 0
+
class AAS_XL(AttentionBase):
- MODEL_TYPE = {
- "SD": 16,
- "SDXL": 70
- }
- def __init__(self, start_step=4, end_step= 50, start_layer=10, end_layer=16,layer_idx=None, step_idx=None, total_steps=50, mask=None, model_type="SD",ss_steps=9,ss_scale=1.0):
+ MODEL_TYPE = {"SD": 16, "SDXL": 70}
+
+ def __init__(
+ self,
+ start_step=4,
+ end_step=50,
+ start_layer=10,
+ end_layer=16,
+ layer_idx=None,
+ step_idx=None,
+ total_steps=50,
+ mask=None,
+ model_type="SD",
+ ss_steps=9,
+ ss_scale=1.0,
+ ):
"""
Args:
start_step: the step to start AAS
@@ -155,49 +162,92 @@ def __init__(self, start_step=4, end_step= 50, start_layer=10, end_layer=16,lay
self.end_step = end_step
self.start_layer = start_layer
self.end_layer = end_layer
- self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, end_layer))
- self.step_idx = step_idx if step_idx is not None else list(range(start_step, end_step))
+ self.layer_idx = (
+ layer_idx if layer_idx is not None else list(
+ range(start_layer, end_layer))
+ )
+ self.step_idx = (
+ step_idx if step_idx is not None else list(
+ range(start_step, end_step))
+ )
self.mask = mask # mask with shape (1, 1 ,h, w)
self.ss_steps = ss_steps
self.ss_scale = ss_scale
print("AAS at denoising steps: ", self.step_idx)
print("AAS at U-Net layers: ", self.layer_idx)
print("start AAS")
- self.mask_16 = F.max_pool2d(mask,(1024//16,1024//16)).round().squeeze().squeeze()
- self.mask_32 = F.max_pool2d(mask,(1024//32,1024//32)).round().squeeze().squeeze()
- self.mask_64 = F.max_pool2d(mask,(1024//64,1024//64)).round().squeeze().squeeze()
- self.mask_128 = F.max_pool2d(mask,(1024//128,1024//128)).round().squeeze().squeeze()
-
- def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads,is_mask_attn, mask, **kwargs):
+ self.mask_16 = (
+ F.max_pool2d(mask, (1024 // 16, 1024 // 16)
+ ).round().squeeze().squeeze()
+ )
+ self.mask_32 = (
+ F.max_pool2d(mask, (1024 // 32, 1024 // 32)
+ ).round().squeeze().squeeze()
+ )
+ self.mask_64 = (
+ F.max_pool2d(mask, (1024 // 64, 1024 // 64)
+ ).round().squeeze().squeeze()
+ )
+ self.mask_128 = (
+ F.max_pool2d(mask, (1024 // 128, 1024 // 128)
+ ).round().squeeze().squeeze()
+ )
+
+ def attn_batch(
+ self,
+ q,
+ k,
+ v,
+ sim,
+ attn,
+ is_cross,
+ place_in_unet,
+ num_heads,
+ is_mask_attn,
+ mask,
+ **kwargs,
+ ):
B = q.shape[0] // num_heads
if is_mask_attn:
mask_flatten = mask.flatten(0)
- if self.cur_step <= self.ss_steps:
+ if self.cur_step <= self.ss_steps:
# background
- sim_bg = sim + mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+ sim_bg = sim + mask_flatten.masked_fill(
+ mask_flatten == 1, torch.finfo(sim.dtype).min
+ )
- #object
- sim_fg = self.ss_scale*sim
- sim_fg += mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+ # object
+ sim_fg = self.ss_scale * sim
+ sim_fg += mask_flatten.masked_fill(
+ mask_flatten == 1, torch.finfo(sim.dtype).min
+ )
sim = torch.cat([sim_fg, sim_bg], dim=0)
else:
- sim += mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+ sim += mask_flatten.masked_fill(
+ mask_flatten == 1, torch.finfo(sim.dtype).min
+ )
attn = sim.softmax(-1)
if len(attn) == 2 * len(v):
v = torch.cat([v] * 2)
out = torch.einsum("h i j, h j d -> h i d", attn, v)
- out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
+ out = rearrange(
+ out,
+ "(h1 h) (b n) d -> (h1 b) n (h d)",
+ b=B,
+ h=num_heads)
return out
- def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+ def forward(self, q, k, v, sim, attn, is_cross,
+ place_in_unet, num_heads, **kwargs):
"""
Attention forward function
"""
if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
- B = q.shape[0] // num_heads // 2
- H = W = int(np.sqrt(q.shape[1]))
+ # B = q.shape[0] // num_heads // 2
+ # H = W = int(np.sqrt(q.shape[1]))
+ H = int(np.sqrt(q.shape[1]))
if H == 16:
mask = self.mask_16.to(sim.device)
elif H == 32:
@@ -207,15 +257,38 @@ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwar
else:
mask = self.mask_128.to(sim.device)
-
q_wo, q_w = q.chunk(2)
k_wo, k_w = k.chunk(2)
v_wo, v_w = v.chunk(2)
sim_wo, sim_w = sim.chunk(2)
attn_wo, attn_w = attn.chunk(2)
- out_source = self.attn_batch(q_wo, k_wo, v_wo, sim_wo, attn_wo, is_cross, place_in_unet, num_heads,is_mask_attn=False,mask=None,**kwargs)
- out_target = self.attn_batch(q_w, k_w, v_w, sim_w, attn_w, is_cross, place_in_unet, num_heads, is_mask_attn=True, mask = mask,**kwargs)
+ out_source = self.attn_batch(
+ q_wo,
+ k_wo,
+ v_wo,
+ sim_wo,
+ attn_wo,
+ is_cross,
+ place_in_unet,
+ num_heads,
+ is_mask_attn=False,
+ mask=None,
+ **kwargs,
+ )
+ out_target = self.attn_batch(
+ q_w,
+ k_w,
+ v_w,
+ sim_w,
+ attn_w,
+ is_cross,
+ place_in_unet,
+ num_heads,
+ is_mask_attn=True,
+ mask=mask,
+ **kwargs,
+ )
if self.mask is not None:
if out_target.shape[0] == 2:
@@ -224,11 +297,13 @@ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwar
out_target = out_target_fg * mask + out_target_bg * (1 - mask)
else:
out_target = out_target
-
+
out = torch.cat([out_source, out_target], dim=0)
return out
-
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+
+
+# Copied from
+# diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -259,7 +334,9 @@ def mask_pil_to_torch(mask, height, width):
return mask
-def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
+def prepare_mask_and_masked_image(
+ image, mask, height, width, return_image: bool = False
+):
"""
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
@@ -314,7 +391,8 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
# Batch single mask or add channel dim
if mask.ndim == 3:
- # Single batched mask, no channel dim or single mask not batched but channel dim
+ # Single batched mask, no channel dim or single mask not batched
+ # but channel dim
if mask.shape[0] == 1:
mask = mask.unsqueeze(0)
@@ -322,9 +400,13 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
else:
mask = mask.unsqueeze(1)
- assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+ assert (
+ image.ndim == 4 and mask.ndim == 4
+ ), "Image and Mask must have 4 dimensions"
# assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
- assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+ assert (
+ image.shape[0] == mask.shape[0]
+ ), "Image and Mask must have the same batch size"
# Check image is in [-1, 1]
# if image.min() < -1 or image.max() > 1:
@@ -341,14 +423,18 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
# Image as float32
image = image.to(dtype=torch.float32)
elif isinstance(mask, torch.Tensor):
- raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
+ raise TypeError(
+ f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not"
+ )
else:
# preprocess image
if isinstance(image, (PIL.Image.Image, np.ndarray)):
image = [image]
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
# resize all images w.r.t passed height an width
- image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
+ image = [
+ i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image
+ ]
image = [np.array(i.convert("RGB"))[None, :] for i in image]
image = np.concatenate(image, axis=0)
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
@@ -377,9 +463,12 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
return mask, masked_image
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+# Copied from
+# diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
- encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+ encoder_output: torch.Tensor,
+ generator: Optional[torch.Generator] = None,
+ sample_mode: str = "sample",
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
@@ -388,10 +477,12 @@ def retrieve_latents(
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
- raise AttributeError("Could not access latents of provided encoder_output")
+ raise AttributeError(
+ "Could not access latents of provided encoder_output")
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+# Copied from
+# diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -421,7 +512,9 @@ def retrieve_timesteps(
second element is the number of inference steps.
"""
if timesteps is not None:
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ accepts_timesteps = "timesteps" in set(
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
+ )
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
@@ -542,55 +635,86 @@ def __init__(
feature_extractor=feature_extractor,
scheduler=scheduler,
)
- self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(
+ force_zeros_for_empty_prompt=force_zeros_for_empty_prompt
+ )
+ self.register_to_config(
+ requires_aesthetics_score=requires_aesthetics_score)
+ self.vae_scale_factor = 2 ** (
+ len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
- vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+ vae_scale_factor=self.vae_scale_factor,
+ do_normalize=False,
+ do_binarize=True,
+ do_convert_grayscale=True,
)
- add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+ add_watermarker = (
+ add_watermarker
+ if add_watermarker is not None
+ else is_invisible_watermark_available()
+ )
if add_watermarker:
self.watermark = StableDiffusionXLWatermarker()
else:
self.watermark = None
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
- def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ # Copied from
+ # diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(
+ self, image, device, num_images_per_prompt, output_hidden_states=None
+ ):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
- image = self.feature_extractor(image, return_tensors="pt").pixel_values
+ image = self.feature_extractor(
+ image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
if output_hidden_states:
- image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
- image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ image_enc_hidden_states = self.image_encoder(
+ image, output_hidden_states=True
+ ).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
- uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
- num_images_per_prompt, dim=0
+ uncond_image_enc_hidden_states = (
+ uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ image_embeds = image_embeds.repeat_interleave(
+ num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ # Copied from
+ # diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
- self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+ self,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
):
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
- if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ if len(ip_adapter_image) != len(
+ self.unet.encoder_hid_proj.image_projection_layers
+ ):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
@@ -634,7 +758,8 @@ def prepare_ip_adapter_image_embeds(
return image_embeds
- # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+ # Copied from
+ # diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(
self,
prompt: str,
@@ -697,19 +822,23 @@ def encode_prompt(
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
- if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ if lora_scale is not None and isinstance(
+ self, StableDiffusionXLLoraLoaderMixin
+ ):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None:
if not USE_PEFT_BACKEND:
- adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ adjust_lora_scale_text_encoder(
+ self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None:
if not USE_PEFT_BACKEND:
- adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ adjust_lora_scale_text_encoder(
+ self.text_encoder_2, lora_scale)
else:
scale_lora_layers(self.text_encoder_2, lora_scale)
@@ -721,9 +850,15 @@ def encode_prompt(
batch_size = prompt_embeds.shape[0]
# Define tokenizers and text encoders
- tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ tokenizers = (
+ [self.tokenizer, self.tokenizer_2]
+ if self.tokenizer is not None
+ else [self.tokenizer_2]
+ )
text_encoders = (
- [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ [self.text_encoder, self.text_encoder_2]
+ if self.text_encoder is not None
+ else [self.text_encoder_2]
)
if prompt_embeds is None:
@@ -733,7 +868,9 @@ def encode_prompt(
# textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt_2]
- for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ for prompt, tokenizer, text_encoder in zip(
+ prompts, tokenizers, text_encoders
+ ):
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, tokenizer)
@@ -746,26 +883,34 @@ def encode_prompt(
)
text_input_ids = text_inputs.input_ids
- untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
- text_input_ids, untruncated_ids
- ):
- removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ untruncated_ids = tokenizer(
+ prompt, padding="longest", return_tensors="pt"
+ ).input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[
+ -1
+ ] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = tokenizer.batch_decode(
+ untruncated_ids[:, tokenizer.model_max_length - 1: -1]
+ )
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
- prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+ prompt_embeds = text_encoder(
+ text_input_ids.to(device), output_hidden_states=True
+ )
- # We are only ALWAYS interested in the pooled output of the final text encoder
+ # We are only ALWAYS interested in the pooled output of the
+ # final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
# "2" because SDXL always indexes from the penultimate layer.
- prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+ prompt_embeds = prompt_embeds.hidden_states[-(
+ clip_skip + 2)]
prompt_embeds_list.append(prompt_embeds)
@@ -819,70 +964,99 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
- # We are only ALWAYS interested in the pooled output of the final text encoder
+ # We are only ALWAYS interested in the pooled output of the
+ # final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
- negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+ negative_prompt_embeds = torch.concat(
+ negative_prompt_embeds_list, dim=-1)
if self.text_encoder_2 is not None:
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ prompt_embeds = prompt_embeds.to(
+ dtype=self.text_encoder_2.dtype, device=device
+ )
else:
- prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+ prompt_embeds = prompt_embeds.to(
+ dtype=self.unet.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
- # duplicate text embeddings for each generation per prompt, using mps friendly method
+ # duplicate text embeddings for each generation per prompt, using mps
+ # friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ prompt_embeds = prompt_embeds.view(
+ bs_embed * num_images_per_prompt, seq_len, -1
+ )
if do_classifier_free_guidance:
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ # duplicate unconditional embeddings for each generation per
+ # prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
if self.text_encoder_2 is not None:
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ negative_prompt_embeds = negative_prompt_embeds.to(
+ dtype=self.text_encoder_2.dtype, device=device
+ )
else:
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+ negative_prompt_embeds = negative_prompt_embeds.to(
+ dtype=self.unet.dtype, device=device
+ )
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
+ 1, num_images_per_prompt, 1
+ )
+ negative_prompt_embeds = negative_prompt_embeds.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
- bs_embed * num_images_per_prompt, -1
- )
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(
+ 1, num_images_per_prompt
+ ).view(bs_embed * num_images_per_prompt, -1)
if do_classifier_free_guidance:
- negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
- bs_embed * num_images_per_prompt, -1
- )
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
+ 1, num_images_per_prompt
+ ).view(bs_embed * num_images_per_prompt, -1)
if self.text_encoder is not None:
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ if isinstance(
+ self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None:
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ if isinstance(
+ self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2, lora_scale)
- return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+ return (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ )
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ # Copied from
+ # diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ accepts_eta = "eta" in set(
+ inspect.signature(self.scheduler.step).parameters.keys()
+ )
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ accepts_generator = "generator" in set(
+ inspect.signature(self.scheduler.step).parameters.keys()
+ )
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
@@ -908,19 +1082,26 @@ def check_inputs(
padding_mask_crop=None,
):
if strength < 0 or strength > 1:
- raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+ raise ValueError(
+ f"The value of strength should in [0.0, 1.0] but is {strength}"
+ )
if height % 8 != 0 or width % 8 != 0:
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+ raise ValueError(
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
+ )
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if callback_steps is not None and (
+ not isinstance(callback_steps, int) or callback_steps <= 0
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ k in self._callback_tensor_inputs
+ for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
@@ -940,10 +1121,18 @@ def check_inputs(
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
- elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
- raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+ elif prompt is not None and (
+ not isinstance(prompt, str) and not isinstance(prompt, list)
+ ):
+ raise ValueError(
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
+ )
+ elif prompt_2 is not None and (
+ not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
+ ):
+ raise ValueError(
+ f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}"
+ )
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
@@ -966,7 +1155,8 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type"
+ f" {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -974,7 +1164,10 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(
+ f"The output type should be PIL when inpainting mask crop, but is"
+ f" {output_type}."
+ )
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
@@ -1008,7 +1201,12 @@ def prepare_latents(
return_noise=False,
return_image_latents=False,
):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -1023,23 +1221,46 @@ def prepare_latents(
if image.shape[1] == 4:
image_latents = image.to(device=device, dtype=dtype)
- image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+ image_latents = image_latents.repeat(
+ batch_size // image_latents.shape[0], 1, 1, 1
+ )
elif return_image_latents or (latents is None and not is_strength_max):
image = image.to(device=device, dtype=dtype)
- image_latents = self._encode_vae_image(image=image, generator=generator)
- image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+ image_latents = self._encode_vae_image(
+ image=image, generator=generator)
+ image_latents = image_latents.repeat(
+ batch_size // image_latents.shape[0], 1, 1, 1
+ )
if latents is None and add_noise:
- noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
- # if strength is 1. then initialise the latents to noise, else initial to image + noise
- latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
- # if pure noise then scale the initial latents by the Scheduler's init sigma
- latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+ noise = randn_tensor(
+ shape,
+ generator=generator,
+ device=device,
+ dtype=dtype)
+ # if strength is 1. then initialise the latents to noise, else
+ # initial to image + noise
+ latents = (
+ noise
+ if is_strength_max
+ else self.scheduler.add_noise(image_latents, noise, timestep)
+ )
+ # if pure noise then scale the initial latents by the Scheduler's
+ # init sigma
+ latents = (
+ latents * self.scheduler.init_noise_sigma
+ if is_strength_max
+ else latents
+ )
elif add_noise:
noise = latents.to(device)
latents = noise * self.scheduler.init_noise_sigma
else:
- noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ noise = randn_tensor(
+ shape,
+ generator=generator,
+ device=device,
+ dtype=dtype)
latents = image_latents.to(device)
outputs = (latents,)
@@ -1052,7 +1273,8 @@ def prepare_latents(
return outputs
- def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ def _encode_vae_image(self, image: torch.Tensor,
+ generator: torch.Generator):
dtype = image.dtype
if self.vae.config.force_upcast:
image = image.float()
@@ -1060,12 +1282,16 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
- retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ retrieve_latents(
+ self.vae.encode(image[i: i + 1]), generator=generator[i]
+ )
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
- image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+ image_latents = retrieve_latents(
+ self.vae.encode(image), generator=generator
+ )
if self.vae.config.force_upcast:
self.vae.to(dtype)
@@ -1076,18 +1302,28 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
return image_latents
def prepare_mask_latents(
- self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
- #mask = torch.nn.functional.interpolate(
+ # mask = torch.nn.functional.interpolate(
# mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
- #)
- mask = torch.nn.functional.max_pool2d(mask, (8,8)).round()
+ # )
+ mask = torch.nn.functional.max_pool2d(mask, (8, 8)).round()
mask = mask.to(device=device, dtype=dtype)
- # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ # duplicate mask and masked_image_latents for each generation per
+ # prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
@@ -1107,7 +1343,9 @@ def prepare_mask_latents(
if masked_image is not None:
if masked_image_latents is None:
masked_image = masked_image.to(device=device, dtype=dtype)
- masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+ masked_image_latents = self._encode_vae_image(
+ masked_image, generator=generator
+ )
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
@@ -1121,16 +1359,23 @@ def prepare_mask_latents(
)
masked_image_latents = (
- torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ torch.cat([masked_image_latents] * 2)
+ if do_classifier_free_guidance
+ else masked_image_latents
)
- # aligning device to prevent device errors when concating it with the latent model input
- masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ # aligning device to prevent device errors when concating it with
+ # the latent model input
+ masked_image_latents = masked_image_latents.to(
+ device=device, dtype=dtype)
return mask, masked_image_latents
- # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
- def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
+ # Copied from
+ # diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
+ def get_timesteps(
+ self, num_inference_steps, strength, device, denoising_start=None
+ ):
# get the original timestep using init_timestep
if denoising_start is None:
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
@@ -1138,7 +1383,7 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
else:
t_start = 0
- timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
# Strength is irrelevant if we directly request a timestep to start at;
# that is, strength is determined by the denoising_start instead.
@@ -1157,16 +1402,19 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
# (except the highest one) is duplicated. If `num_inference_steps` is even it would
# mean that we cut the timesteps in the middle of the denoising step
# (between 1st and 2nd devirative) which leads to incorrect results. By adding 1
- # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
+ # we ensure that the denoising process always ends after the
+ # 2nd derivate step of the scheduler
num_inference_steps = num_inference_steps + 1
- # because t_n+1 >= t_n, we slice the timesteps starting from the end
+ # because t_n+1 >= t_n, we slice the timesteps starting from the
+ # end
timesteps = timesteps[-num_inference_steps:]
return timesteps, num_inference_steps
return timesteps, num_inference_steps - t_start
- # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
+ # Copied from
+ # diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
def _get_add_time_ids(
self,
original_size,
@@ -1218,7 +1466,8 @@ def _get_add_time_ids(
return add_time_ids, add_neg_time_ids
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ # Copied from
+ # diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
dtype = self.vae.dtype
self.vae.to(dtype=torch.float32)
@@ -1238,8 +1487,10 @@ def upcast_vae(self):
self.vae.decoder.conv_in.to(dtype)
self.vae.decoder.mid_block.to(dtype)
- # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
- def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ # Copied from
+ # diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(
+ self, w, embedding_dim=512, dtype=torch.float32):
"""
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
@@ -1279,18 +1530,21 @@ def guidance_rescale(self):
def clip_skip(self):
return self._clip_skip
-
@property
- def do_self_attention_redirection_guidance(self): #SARG
+ def do_self_attention_redirection_guidance(self): # SARG
return self._rm_guidance_scale > 1 and self._AAS
-
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
- return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None and self.do_self_attention_redirection_guidance==False #CFG was disabled when SARG was used, and experiments proved that there was little difference in the effect of whether CFG was used or not
-
+ return (
+ self._guidance_scale > 1
+ and self.unet.config.time_cond_proj_dim is None
+ and not self.do_self_attention_redirection_guidance
+ ) # CFG was disabled when SARG was used, and experiments proved that there was little difference in the effect of whether CFG was used or not
+
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@@ -1313,25 +1567,28 @@ def interrupt(self):
@torch.no_grad()
def image2latent(self, image: torch.Tensor, generator: torch.Generator):
- DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ DEVICE = (
+ torch.device("cuda") if torch.cuda.is_available(
+ ) else torch.device("cpu")
+ )
if type(image) is Image:
image = np.array(image)
image = torch.from_numpy(image).float() / 127.5 - 1
image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
# input image density range [-1, 1]
- #latents = self.vae.encode(image)['latent_dist'].mean
+ # latents = self.vae.encode(image)['latent_dist'].mean
latents = self._encode_vae_image(image, generator)
- #latents = retrieve_latents(self.vae.encode(image))
- #latents = latents * self.vae.config.scaling_factor
+ # latents = retrieve_latents(self.vae.encode(image))
+ # latents = latents * self.vae.config.scaling_factor
return latents
-
+
def next_step(
self,
model_output: torch.FloatTensor,
timestep: int,
x: torch.FloatTensor,
- eta=0.,
- verbose=False
+ eta=0.0,
+ verbose=False,
):
"""
Inverse sampling for DDIM Inversion
@@ -1339,15 +1596,24 @@ def next_step(
if verbose:
print("timestep: ", timestep)
next_step = timestep
- timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
- alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
+ timestep = min(
+ timestep
+ - self.scheduler.config.num_train_timesteps
+ // self.scheduler.num_inference_steps,
+ 999,
+ )
+ alpha_prod_t = (
+ self.scheduler.alphas_cumprod[timestep]
+ if timestep >= 0
+ else self.scheduler.final_alpha_cumprod
+ )
alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
beta_prod_t = 1 - alpha_prod_t
pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
- pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
+ pred_dir = (1 - alpha_prod_t_next) ** 0.5 * model_output
x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
return x_next, pred_x0
-
+
@torch.no_grad()
def invert(
self,
@@ -1362,11 +1628,15 @@ def invert(
aesthetic_score: float = 6.0,
negative_aesthetic_score: float = 2.5,
return_intermediates=False,
- **kwds):
+ **kwds,
+ ):
"""
invert a real image into noise map with determinisc DDIM inversion
"""
- DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ DEVICE = (
+ torch.device("cuda") if torch.cuda.is_available(
+ ) else torch.device("cpu")
+ )
batch_size = image.shape[0]
if isinstance(prompt, list):
if batch_size == 1:
@@ -1376,9 +1646,15 @@ def invert(
prompt = [prompt] * batch_size
# Define tokenizers and text encoders
- tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ tokenizers = (
+ [self.tokenizer, self.tokenizer_2]
+ if self.tokenizer is not None
+ else [self.tokenizer_2]
+ )
text_encoders = (
- [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ [self.text_encoder, self.text_encoder_2]
+ if self.text_encoder is not None
+ else [self.text_encoder_2]
)
prompt_2 = prompt
@@ -1387,7 +1663,8 @@ def invert(
# textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt_2]
- for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ for prompt, tokenizer, text_encoder in zip(
+ prompts, tokenizers, text_encoders):
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, tokenizer)
@@ -1400,20 +1677,27 @@ def invert(
)
text_input_ids = text_inputs.input_ids
- untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
- text_input_ids, untruncated_ids
- ):
- removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ untruncated_ids = tokenizer(
+ prompt, padding="longest", return_tensors="pt"
+ ).input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[
+ -1
+ ] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = tokenizer.batch_decode(
+ untruncated_ids[:, tokenizer.model_max_length - 1: -1]
+ )
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
- prompt_embeds = text_encoder(text_input_ids.to(DEVICE), output_hidden_states=True)
+ prompt_embeds = text_encoder(
+ text_input_ids.to(DEVICE), output_hidden_states=True
+ )
- # We are only ALWAYS interested in the pooled output of the final text encoder
+ # We are only ALWAYS interested in the pooled output of the final
+ # text encoder
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds_list.append(prompt_embeds)
@@ -1421,7 +1705,7 @@ def invert(
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=DEVICE)
# define initial latents
- latents = self.image2latent(image,generator=None)
+ latents = self.image2latent(image, generator=None)
start_latents = latents
height, width = latents.shape[-2:]
@@ -1454,32 +1738,41 @@ def invert(
self.scheduler.set_timesteps(num_inference_steps)
latents_list = [latents]
pred_x0_list = []
- #for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
+ # for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps),
+ # desc="DDIM Inversion")):
for i, t in enumerate(reversed(self.scheduler.timesteps)):
model_inputs = latents
# predict the noise
- added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
- noise_pred = self.unet(model_inputs, t, encoder_hidden_states=prompt_embeds,added_cond_kwargs=added_cond_kwargs).sample
+ added_cond_kwargs = {
+ "text_embeds": add_text_embeds,
+ "time_ids": add_time_ids,
+ }
+ noise_pred = self.unet(
+ model_inputs,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ added_cond_kwargs=added_cond_kwargs,
+ ).sample
# compute the previous noise sample x_t-1 -> x_t
latents, pred_x0 = self.next_step(noise_pred, t, latents)
- """
+ """
if t >= 1 and t < 41:
latents, pred_x0 = self.next_step_degrade(noise_pred, t, latents, mask)
else:
latents, pred_x0 = self.next_step(noise_pred, t, latents) """
-
+
latents_list.append(latents)
pred_x0_list.append(pred_x0)
if return_intermediates:
# return the intermediate laters during inversion
- #pred_x0_list = [self.latent2image(img, return_type="np") for img in pred_x0_list]
- #latents_list = [self.latent2image(img, return_type="np") for img in latents_list]
+ # pred_x0_list = [self.latent2image(img, return_type="np") for img in pred_x0_list]
+ # latents_list = [self.latent2image(img, return_type="np") for img in latents_list]
return latents, latents_list, pred_x0_list
return latents, start_latents
-
+
def opt(
self,
model_output: torch.FloatTensor,
@@ -1489,19 +1782,27 @@ def opt(
"""
predict the sampe the next step in the denoise process.
"""
- ref_noise = model_output[:1,:,:,:].expand(model_output.shape)
+ ref_noise = model_output[:1, :, :, :].expand(model_output.shape)
alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
beta_prod_t = 1 - alpha_prod_t
pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
- x_opt = alpha_prod_t**0.5 * pred_x0 + (1 - alpha_prod_t)**0.5 * ref_noise
+ x_opt = alpha_prod_t**0.5 * pred_x0 + \
+ (1 - alpha_prod_t) ** 0.5 * ref_noise
return x_opt, pred_x0
-
+
def regiter_attention_editor_diffusers(self, unet, editor: AttentionBase):
"""
Register a attention editor to Diffuser Pipeline, refer from [Prompt-to-Prompt]
"""
+
def ca_forward(self, place_in_unet):
- def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
+ def forward(
+ x,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ context=None,
+ mask=None,
+ ):
"""
The attention is similar to the original implementation of LDM CrossAttention class
except adding some modifications on the attention
@@ -1523,22 +1824,33 @@ def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, ma
context = context if is_cross else x
k = self.to_k(context)
v = self.to_v(context)
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+ q, k, v = map(
+ lambda t: rearrange(
+ t, "b n (h d) -> (b h) n d", h=h), (q, k, v)
+ )
- sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
+ sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
if mask is not None:
- mask = rearrange(mask, 'b ... -> b (...)')
+ mask = rearrange(mask, "b ... -> b (...)")
max_neg_value = -torch.finfo(sim.dtype).max
- mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ mask = repeat(mask, "b j -> (b h) () j", h=h)
mask = mask[:, None, :].repeat(h, 1, 1)
sim.masked_fill_(~mask, max_neg_value)
attn = sim.softmax(dim=-1)
# the only difference
out = editor(
- q, k, v, sim, attn, is_cross, place_in_unet,
- self.heads, scale=self.scale)
+ q,
+ k,
+ v,
+ sim,
+ attn,
+ is_cross,
+ place_in_unet,
+ self.heads,
+ scale=self.scale,
+ )
return to_out(out)
@@ -1546,10 +1858,10 @@ def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, ma
def register_editor(net, count, place_in_unet):
for name, subnet in net.named_children():
- if net.__class__.__name__ == 'Attention': # spatial Transformer layer
+ if net.__class__.__name__ == "Attention": # spatial Transformer layer
net.forward = ca_forward(net, place_in_unet)
return count + 1
- elif hasattr(net, 'children'):
+ elif hasattr(net, "children"):
count = register_editor(subnet, count, place_in_unet)
return count
@@ -1577,12 +1889,12 @@ def __call__(
padding_mask_crop: Optional[int] = None,
strength: float = 0.9999,
AAS: bool = True, # AE parameter
- rm_guidance_scale: float = 7.0, # AE parameter
+ rm_guidance_scale: float = 7.0, # AE parameter
ss_steps: int = 9, # AE parameter
- ss_scale: float = 0.3, # AE parameter
- AAS_start_step: int = 0, # AE parameter
- AAS_start_layer: int = 34, # AE parameter
- AAS_end_layer: int = 70, # AE parameter
+ ss_scale: float = 0.3, # AE parameter
+ AAS_start_step: int = 0, # AE parameter
+ AAS_start_layer: int = 34, # AE parameter
+ AAS_end_layer: int = 70, # AE parameter
num_inference_steps: int = 50,
timesteps: List[int] = None,
denoising_start: Optional[float] = None,
@@ -1592,7 +1904,8 @@ def __call__(
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ generator: Optional[Union[torch.Generator,
+ List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -1613,7 +1926,8 @@ def __call__(
aesthetic_score: float = 6.0,
negative_aesthetic_score: float = 2.5,
clip_skip: Optional[int] = None,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end: Optional[Callable[[
+ int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
@@ -1843,7 +2157,7 @@ def __call__(
self._denoising_start = denoising_start
self._interrupt = False
- ########### AE parameters
+ # AE parameters
self._num_timesteps = num_inference_steps
self._rm_guidance_scale = rm_guidance_scale
self._AAS = AAS
@@ -1866,7 +2180,9 @@ def __call__(
# 3. Encode input prompt
text_encoder_lora_scale = (
- self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ self.cross_attention_kwargs.get("scale", None)
+ if self.cross_attention_kwargs is not None
+ else None
)
(
@@ -1894,27 +2210,39 @@ def __call__(
def denoising_value_valid(dnv):
return isinstance(dnv, float) and 0 < dnv < 1
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps
+ )
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps,
strength,
device,
- denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
+ denoising_start=(
+ self.denoising_start
+ if denoising_value_valid(self.denoising_start)
+ else None
+ ),
)
- # check that number of inference steps is not < 1 - as this doesn't make sense
+ # check that number of inference steps is not < 1 - as this doesn't
+ # make sense
if num_inference_steps < 1:
raise ValueError(
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
)
- # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
- latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
- # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ # at which timestep to set the initial noise (n.b. 50% if strength is
+ # 0.5)
+ latent_timestep = timesteps[:1].repeat(
+ batch_size * num_images_per_prompt)
+ # create a boolean to check if the strength is set to 1. if so then
+ # initialise the latents with pure noise
is_strength_max = strength == 1.0
# 5. Preprocess mask and image
if padding_mask_crop is not None:
- crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+ crops_coords = self.mask_processor.get_crop_region(
+ mask_image, width, height, pad=padding_mask_crop
+ )
resize_mode = "fill"
else:
crops_coords = None
@@ -1922,12 +2250,20 @@ def denoising_value_valid(dnv):
original_image = image
init_image = self.image_processor.preprocess(
- image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+ image,
+ height=height,
+ width=width,
+ crops_coords=crops_coords,
+ resize_mode=resize_mode,
)
init_image = init_image.to(dtype=torch.float32)
mask = self.mask_processor.preprocess(
- mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+ mask_image,
+ height=height,
+ width=width,
+ resize_mode=resize_mode,
+ crops_coords=crops_coords,
)
if masked_image_latents is not None:
@@ -1984,7 +2320,10 @@ def denoising_value_valid(dnv):
# default case for runwayml/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
- if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+ if (
+ num_channels_latents + num_channels_mask + num_channels_masked_image
+ != self.unet.config.in_channels
+ ):
raise ValueError(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
@@ -1999,7 +2338,8 @@ def denoising_value_valid(dnv):
# 8.1 Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
- # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be
+ # moved out of the pipeline
height, width = latents.shape[-2:]
height = height * self.vae_scale_factor
width = width * self.vae_scale_factor
@@ -2031,19 +2371,25 @@ def denoising_value_valid(dnv):
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
- add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+ add_time_ids = add_time_ids.repeat(
+ batch_size * num_images_per_prompt, 1)
if self.do_classifier_free_guidance:
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
- add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
- add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+ prompt_embeds = torch.cat(
+ [negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat(
+ [negative_pooled_prompt_embeds, add_text_embeds], dim=0
+ )
+ add_neg_time_ids = add_neg_time_ids.repeat(
+ batch_size * num_images_per_prompt, 1
+ )
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
-
###########
if self.do_self_attention_redirection_guidance:
prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0)
- add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0)
+ add_text_embeds = torch.cat(
+ [add_text_embeds, add_text_embeds], dim=0)
add_neg_time_ids = add_neg_time_ids.repeat(2, 1)
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
############
@@ -2064,13 +2410,24 @@ def denoising_value_valid(dnv):
# apply AAS to modify the attention module
if self.do_self_attention_redirection_guidance:
self._AAS_end_step = int(strength * self._num_timesteps)
- layer_idx=list(range(self._AAS_start_layer, self._AAS_end_layer))
- editor = AAS_XL(self._AAS_start_step, self._AAS_end_step, self._AAS_start_layer, self._AAS_end_layer, layer_idx= layer_idx, mask=mask_image,model_type="SDXL",ss_steps=self._ss_steps,ss_scale=self._ss_scale)
+ layer_idx = list(range(self._AAS_start_layer, self._AAS_end_layer))
+ editor = AAS_XL(
+ self._AAS_start_step,
+ self._AAS_end_step,
+ self._AAS_start_layer,
+ self._AAS_end_layer,
+ layer_idx=layer_idx,
+ mask=mask_image,
+ model_type="SDXL",
+ ss_steps=self._ss_steps,
+ ss_scale=self._ss_scale,
+ )
self.regiter_attention_editor_diffusers(self.unet, editor)
-
# 11. Denoising loop
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ num_warmup_steps = max(
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
+ )
if (
self.denoising_end is not None
@@ -2083,20 +2440,26 @@ def denoising_value_valid(dnv):
f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
+ f" {self.denoising_end} when using type float."
)
- elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
+ elif self.denoising_end is not None and denoising_value_valid(
+ self.denoising_end
+ ):
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
- (self.denoising_end * self.scheduler.config.num_train_timesteps)
)
)
- num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ num_inference_steps = len(
+ list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))
+ )
timesteps = timesteps[:num_inference_steps]
# 11.1 Optionally get Guidance Scale Embedding
timestep_cond = None
if self.unet.config.time_cond_proj_dim is not None:
- guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
+ batch_size * num_images_per_prompt
+ )
timestep_cond = self.get_guidance_scale_embedding(
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
).to(device=device, dtype=latents.dtype)
@@ -2109,21 +2472,37 @@ def denoising_value_valid(dnv):
continue
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = (
+ torch.cat([latents] * 2)
+ if self.do_classifier_free_guidance
+ else latents
+ )
- #removal guidance
- latent_model_input = torch.cat([latents]*2) if self.do_self_attention_redirection_guidance else latents #CFG was disabled when SARG was used, and experiments proved that there was little difference in the effect of whether CFG was used or not
- #latent_model_input_rm = torch.cat([latents]*2) if self.do_self_attention_redirection_guidance else latents
-
- # concat latents, mask, masked_image_latents in the channel dimension
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
- #latent_model_input = self.scheduler.scale_model_input(latent_model_input_rm, t)
+ # removal guidance
+ latent_model_input = (
+ torch.cat([latents] * 2)
+ if self.do_self_attention_redirection_guidance
+ else latents
+ ) # CFG was disabled when SARG was used, and experiments proved that there was little difference in the effect of whether CFG was used or not
+ # latent_model_input_rm = torch.cat([latents]*2) if self.do_self_attention_redirection_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel
+ # dimension
+ latent_model_input = self.scheduler.scale_model_input(
+ latent_model_input, t
+ )
+ # latent_model_input = self.scheduler.scale_model_input(latent_model_input_rm, t)
if num_channels_unet == 9:
- latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+ latent_model_input = torch.cat(
+ [latent_model_input, mask, masked_image_latents], dim=1
+ )
# predict the noise residual
- added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ added_cond_kwargs = {
+ "text_embeds": add_text_embeds,
+ "time_ids": add_time_ids,
+ }
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
added_cond_kwargs["image_embeds"] = image_embeds
noise_pred = self.unet(
@@ -2145,16 +2524,22 @@ def denoising_value_valid(dnv):
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
+ noise_pred_text - noise_pred_uncond
+ )
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
- noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
-
-
+ noise_pred = rescale_noise_cfg(
+ noise_pred,
+ noise_pred_text,
+ guidance_rescale=self.guidance_rescale,
+ )
# compute the previous noisy sample x_t -> x_t-1
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ latents = self.scheduler.step(
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
+ )[0]
if num_channels_unet == 4:
init_latents_proper = image_latents
@@ -2166,28 +2551,42 @@ def denoising_value_valid(dnv):
if i < len(timesteps) - 1:
noise_timestep = timesteps[i + 1]
init_latents_proper = self.scheduler.add_noise(
- init_latents_proper, noise, torch.tensor([noise_timestep])
+ init_latents_proper, noise, torch.tensor(
+ [noise_timestep])
)
- latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+ latents = (
+ 1 - init_mask
+ ) * init_latents_proper + init_mask * latents
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+ callback_outputs = callback_on_step_end(
+ self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
- add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+ prompt_embeds = callback_outputs.pop(
+ "prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop(
+ "negative_prompt_embeds", negative_prompt_embeds
+ )
+ add_text_embeds = callback_outputs.pop(
+ "add_text_embeds", add_text_embeds
+ )
negative_pooled_prompt_embeds = callback_outputs.pop(
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
)
- add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
- add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
+ add_time_ids = callback_outputs.pop(
+ "add_time_ids", add_time_ids)
+ add_neg_time_ids = callback_outputs.pop(
+ "add_neg_time_ids", add_neg_time_ids
+ )
mask = callback_outputs.pop("mask", mask)
- masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
+ masked_image_latents = callback_outputs.pop(
+ "masked_image_latents", masked_image_latents
+ )
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -2201,25 +2600,42 @@ def denoising_value_valid(dnv):
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
- needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+ needs_upcasting = (
+ self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+ )
latents = latents[-1:]
if needs_upcasting:
self.upcast_vae()
- latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+ latents = latents.to(
+ next(iter(self.vae.post_quant_conv.parameters())).dtype
+ )
# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
- has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
- has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+ has_latents_mean = (
+ hasattr(self.vae.config, "latents_mean")
+ and self.vae.config.latents_mean is not None
+ )
+ has_latents_std = (
+ hasattr(self.vae.config, "latents_std")
+ and self.vae.config.latents_std is not None
+ )
if has_latents_mean and has_latents_std:
latents_mean = (
- torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, 4, 1, 1)
+ .to(latents.device, latents.dtype)
)
latents_std = (
- torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ torch.tensor(self.vae.config.latents_std)
+ .view(1, 4, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents = (
+ latents * latents_std / self.vae.config.scaling_factor
+ + latents_mean
)
- latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
else:
latents = latents / self.vae.config.scaling_factor
@@ -2235,10 +2651,16 @@ def denoising_value_valid(dnv):
if self.watermark is not None:
image = self.watermark.apply_watermark(image)
- image = self.image_processor.postprocess(image, output_type=output_type)
+ image = self.image_processor.postprocess(
+ image, output_type=output_type)
if padding_mask_crop is not None:
- image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+ image = [
+ self.image_processor.apply_overlay(
+ mask_image, original_image, i, crops_coords
+ )
+ for i in image
+ ]
# Offload all models
self.maybe_free_model_hooks()
From 717e2c711963e587678736b9f572317b90ae235d Mon Sep 17 00:00:00 2001
From: Anonym0u3 <306794924@qq.com>
Date: Wed, 15 Jan 2025 17:04:41 +0800
Subject: [PATCH 3/7] make style and add example output
---
examples/community/README.md | 92 +-
...ne_stable_diffusion_xl_attentive_eraser.py | 2285 +++++++++++++++++
2 files changed, 2375 insertions(+), 2 deletions(-)
mode change 100755 => 100644 examples/community/README.md
create mode 100644 examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
diff --git a/examples/community/README.md b/examples/community/README.md
old mode 100755
new mode 100644
index c7c40c46ef2d..02777636d2d3
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -77,6 +77,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
| [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [](https://huggingface.co/spaces/pcuenq/mdm) [](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
+| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3)|
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
@@ -4585,8 +4586,8 @@ image = pipe(
```
|  |  |  |
-| ------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- |
-| Gradient | Input | Output |
+| -------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ |
+| Gradient | Input | Output |
A colab notebook demonstrating all results can be found [here](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing). Depth Maps have also been added in the same colab.
@@ -4634,6 +4635,93 @@ make_image_grid(image, rows=1, cols=len(image))
# 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.
```
+### Stable Diffusion XL Attentive Eraser Pipeline
+
+
+**Stable Diffusion XL Attentive Eraser Pipeline** is an advanced object removal pipeline that leverages SDXL for precise content suppression and seamless region completion. This pipeline uses **self-attention redirection guidance** to modify the model’s self-attention mechanism, allowing for effective removal and inpainting across various levels of mask precision, including semantic segmentation masks, bounding boxes, and hand-drawn masks. If you are interested in more detailed information and have any questions, please refer to the [paper](https://arxiv.org/abs/2412.12974) and [official implementation](https://github.com/Anonym0u3/AttentiveEraser).
+
+#### Key features
+
+- **Tuning-Free**: No additional training is required, making it easy to integrate and use.
+- **Flexible Mask Support**: Works with different types of masks for targeted object removal.
+- **High-Quality Results**: Utilizes the inherent generative power of diffusion models for realistic content completion.
+
+#### Usage example
+To use the Stable Diffusion XL Attentive Eraser Pipeline, you can initialize it as follows:
+```py
+import torch
+from diffusers import DDIMScheduler, DiffusionPipeline
+from diffusers.utils import load_image
+import torch.nn.functional as F
+from torchvision.transforms.functional import to_tensor, gaussian_blur
+
+dtype = torch.float16
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser",
+ scheduler=scheduler,
+ variant="fp16",
+ use_safetensors=True,
+ torch_dtype=dtype,
+).to(device)
+
+
+def preprocess_image(image_path, device):
+ image = to_tensor((load_image(image_path)))
+ image = image.unsqueeze_(0).float() * 2 - 1 # [0,1] --> [-1,1]
+ if image.shape[1] != 3:
+ image = image.expand(-1, 3, -1, -1)
+ image = F.interpolate(image, (1024, 1024))
+ image = image.to(dtype).to(device)
+ return image
+
+def preprocess_mask(mask_path, device):
+ mask = to_tensor((load_image(mask_path, convert_method=lambda img: img.convert('L'))))
+ mask = mask.unsqueeze_(0).float() # 0 or 1
+ mask = F.interpolate(mask, (1024, 1024))
+ mask = gaussian_blur(mask, kernel_size=(77, 77))
+ mask[mask < 0.1] = 0
+ mask[mask >= 0.1] = 1
+ mask = mask.to(dtype).to(device)
+ return mask
+
+prompt = "" # Set prompt to null
+seed=123
+generator = torch.Generator(device=device).manual_seed(seed)
+source_image_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png"
+mask_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png"
+source_image = preprocess_image(source_image_path, device)
+mask = preprocess_mask(mask_path, device)
+
+image = pipeline(
+ prompt=prompt,
+ image=source_image,
+ mask_image=mask,
+ height=1024,
+ width=1024,
+ AAS=True, # enable AAS
+ strength=0.8, # inpainting strength
+ rm_guidance_scale=9, # removal guidance scale
+ ss_steps = 9, # similarity suppression steps
+ ss_scale = 0.3, # similarity suppression scale
+ AAS_start_step=0, # AAS start step
+ AAS_start_layer=34, # AAS start layer
+ AAS_end_layer=70, # AAS end layer
+ num_inference_steps=50, # number of inference steps # AAS_end_step = int(strength*num_inference_steps)
+ generator=generator,
+ guidance_scale=1,
+).images[0]
+image.save('./removed_img.png')
+print("Object removal completed")
+```
+
+| Source Image | Mask | Output |
+| ---------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- |
+|  |  |  |
+
# Perturbed-Attention Guidance
[Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)
diff --git a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
new file mode 100644
index 000000000000..48c01318991a
--- /dev/null
+++ b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
@@ -0,0 +1,2285 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from PIL import Image
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from diffusers.models.attention_processor import (
+ AttnProcessor2_0,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_invisible_watermark_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+if is_invisible_watermark_available():
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionXLInpaintPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0",
+ ... torch_dtype=torch.float16,
+ ... variant="fp16",
+ ... use_safetensors=True,
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+ >>> init_image = load_image(img_url).convert("RGB")
+ >>> mask_image = load_image(mask_url).convert("RGB")
+
+ >>> prompt = "A majestic tiger sitting on a bench"
+ >>> image = pipe(
+ ... prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80
+ ... ).images[0]
+ ```
+"""
+
+
+class AttentionBase:
+ def __init__(self):
+ self.cur_step = 0
+ self.num_att_layers = -1
+ self.cur_att_layer = 0
+
+ def after_step(self):
+ pass
+
+ def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+ out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+ self.cur_att_layer += 1
+ if self.cur_att_layer == self.num_att_layers:
+ self.cur_att_layer = 0
+ self.cur_step += 1
+ # after step
+ self.after_step()
+ return out
+
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+ out = torch.einsum("b i j, b j d -> b i d", attn, v)
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=num_heads)
+ return out
+
+ def reset(self):
+ self.cur_step = 0
+ self.cur_att_layer = 0
+
+
+class AAS_XL(AttentionBase):
+ MODEL_TYPE = {"SD": 16, "SDXL": 70}
+
+ def __init__(
+ self,
+ start_step=4,
+ end_step=50,
+ start_layer=10,
+ end_layer=16,
+ layer_idx=None,
+ step_idx=None,
+ total_steps=50,
+ mask=None,
+ model_type="SD",
+ ss_steps=9,
+ ss_scale=1.0,
+ ):
+ """
+ Args:
+ start_step: the step to start AAS
+ start_layer: the layer to start AAS
+ layer_idx: list of the layers to apply AAS
+ step_idx: list the steps to apply AAS
+ total_steps: the total number of steps
+ mask: source mask with shape (h, w)
+ model_type: the model type, SD or SDXL
+ """
+ super().__init__()
+ self.total_steps = total_steps
+ self.total_layers = self.MODEL_TYPE.get(model_type, 16)
+ self.start_step = start_step
+ self.end_step = end_step
+ self.start_layer = start_layer
+ self.end_layer = end_layer
+ self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, end_layer))
+ self.step_idx = step_idx if step_idx is not None else list(range(start_step, end_step))
+ self.mask = mask # mask with shape (1, 1 ,h, w)
+ self.ss_steps = ss_steps
+ self.ss_scale = ss_scale
+ print("AAS at denoising steps: ", self.step_idx)
+ print("AAS at U-Net layers: ", self.layer_idx)
+ print("start AAS")
+ self.mask_16 = F.max_pool2d(mask, (1024 // 16, 1024 // 16)).round().squeeze().squeeze()
+ self.mask_32 = F.max_pool2d(mask, (1024 // 32, 1024 // 32)).round().squeeze().squeeze()
+ self.mask_64 = F.max_pool2d(mask, (1024 // 64, 1024 // 64)).round().squeeze().squeeze()
+ self.mask_128 = F.max_pool2d(mask, (1024 // 128, 1024 // 128)).round().squeeze().squeeze()
+
+ def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, is_mask_attn, mask, **kwargs):
+ B = q.shape[0] // num_heads
+ if is_mask_attn:
+ mask_flatten = mask.flatten(0)
+ if self.cur_step <= self.ss_steps:
+ # background
+ sim_bg = sim + mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+
+ # object
+ sim_fg = self.ss_scale * sim
+ sim_fg += mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+ sim = torch.cat([sim_fg, sim_bg], dim=0)
+ else:
+ sim += mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+
+ attn = sim.softmax(-1)
+ if len(attn) == 2 * len(v):
+ v = torch.cat([v] * 2)
+ out = torch.einsum("h i j, h j d -> h i d", attn, v)
+ out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
+ return out
+
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+ """
+ Attention forward function
+ """
+ if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
+ return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+ # B = q.shape[0] // num_heads // 2
+ H = int(np.sqrt(q.shape[1]))
+ # H = W = int(np.sqrt(q.shape[1]))
+ if H == 16:
+ mask = self.mask_16.to(sim.device)
+ elif H == 32:
+ mask = self.mask_32.to(sim.device)
+ elif H == 64:
+ mask = self.mask_64.to(sim.device)
+ else:
+ mask = self.mask_128.to(sim.device)
+
+ q_wo, q_w = q.chunk(2)
+ k_wo, k_w = k.chunk(2)
+ v_wo, v_w = v.chunk(2)
+ sim_wo, sim_w = sim.chunk(2)
+ attn_wo, attn_w = attn.chunk(2)
+
+ out_source = self.attn_batch(
+ q_wo,
+ k_wo,
+ v_wo,
+ sim_wo,
+ attn_wo,
+ is_cross,
+ place_in_unet,
+ num_heads,
+ is_mask_attn=False,
+ mask=None,
+ **kwargs,
+ )
+ out_target = self.attn_batch(
+ q_w, k_w, v_w, sim_w, attn_w, is_cross, place_in_unet, num_heads, is_mask_attn=True, mask=mask, **kwargs
+ )
+
+ if self.mask is not None:
+ if out_target.shape[0] == 2:
+ out_target_fg, out_target_bg = out_target.chunk(2, 0)
+ mask = mask.reshape(-1, 1) # (hw, 1)
+ out_target = out_target_fg * mask + out_target_bg * (1 - mask)
+ else:
+ out_target = out_target
+
+ out = torch.cat([out_source, out_target], dim=0)
+ return out
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+def mask_pil_to_torch(mask, height, width):
+ # preprocess mask
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+ mask = [mask]
+
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+ mask = mask.astype(np.float32) / 255.0
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+ mask = torch.from_numpy(mask)
+ return mask
+
+
+def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
+ """
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+ ``image`` and ``1`` for the ``mask``.
+
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+
+ Args:
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+ mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+
+
+ Raises:
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
+ (ot the other way around).
+
+ Returns:
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+ dimensions: ``batch x channels x height x width``.
+ """
+
+ # checkpoint. TOD(Yiyi) - need to clean this up later
+ deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
+ deprecate(
+ "prepare_mask_and_masked_image",
+ "0.30.0",
+ deprecation_message,
+ )
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ if mask is None:
+ raise ValueError("`mask_image` input cannot be undefined.")
+
+ if isinstance(image, torch.Tensor):
+ if not isinstance(mask, torch.Tensor):
+ mask = mask_pil_to_torch(mask, height, width)
+
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+
+ # Batch and add channel dim for single mask
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0).unsqueeze(0)
+
+ # Batch single mask or add channel dim
+ if mask.ndim == 3:
+ # Single batched mask, no channel dim or single mask not batched but channel dim
+ if mask.shape[0] == 1:
+ mask = mask.unsqueeze(0)
+
+ # Batched masks no channel dim
+ else:
+ mask = mask.unsqueeze(1)
+
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+ # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+ # Check image is in [-1, 1]
+ # if image.min() < -1 or image.max() > 1:
+ # raise ValueError("Image should be in [-1, 1] range")
+
+ # Check mask is in [0, 1]
+ if mask.min() < 0 or mask.max() > 1:
+ raise ValueError("Mask should be in [0, 1] range")
+
+ # Binarize mask
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ # Image as float32
+ image = image.to(dtype=torch.float32)
+ elif isinstance(mask, torch.Tensor):
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ # resize all images w.r.t passed height an width
+ image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ mask = mask_pil_to_torch(mask, height, width)
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ if image.shape[1] == 4:
+ # images are in latent space and thus can't
+ # be masked set masked_image to None
+ # we assume that the checkpoint is not an inpainting
+ # checkpoint. TOD(Yiyi) - need to clean this up later
+ masked_image = None
+ else:
+ masked_image = image * (mask < 0.5)
+
+ # n.b. ensure backwards compatibility as old function does not return image
+ if return_image:
+ return mask, masked_image, image
+
+ return mask, masked_image
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class StableDiffusionXL_AE_Pipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ FromSingleFileMixin,
+ IPAdapterMixin,
+):
+ r"""
+ Pipeline for object removal using Stable Diffusion XL.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([` CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
+ Whether the `unet` requires a aesthetic_score condition to be passed during inference. Also see the config
+ of `stabilityai/stable-diffusion-xl-refiner-1-0`.
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+ `stabilityai/stable-diffusion-xl-base-1-0`.
+ add_watermarker (`bool`, *optional*):
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+ watermarker will be used.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+
+ _optional_components = [
+ "tokenizer",
+ "tokenizer_2",
+ "text_encoder",
+ "text_encoder_2",
+ "image_encoder",
+ "feature_extractor",
+ ]
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ "add_text_embeds",
+ "add_time_ids",
+ "negative_pooled_prompt_embeds",
+ "add_neg_time_ids",
+ "mask",
+ "masked_image_latents",
+ ]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ requires_aesthetics_score: bool = False,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ scheduler=scheduler,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+ )
+
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+ if add_watermarker:
+ self.watermark = StableDiffusionXLWatermarker()
+ else:
+ self.watermark = None
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+ ):
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack(
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
+ )
+
+ if do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+ else:
+ repeat_dims = [1]
+ image_embeds = []
+ for single_image_embeds in ip_adapter_image_embeds:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+ single_image_embeds = single_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+ )
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+ )
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ else:
+ single_image_embeds = single_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+ )
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if self.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if self.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ image,
+ mask_image,
+ height,
+ width,
+ strength,
+ callback_steps,
+ output_type,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ip_adapter_image=None,
+ ip_adapter_image_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ padding_mask_crop=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if padding_mask_crop is not None:
+ if not isinstance(image, PIL.Image.Image):
+ raise ValueError(
+ f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ )
+ if not isinstance(mask_image, PIL.Image.Image):
+ raise ValueError(
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+ f" {type(mask_image)}."
+ )
+ if output_type != "pil":
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ )
+
+ if ip_adapter_image_embeds is not None:
+ if not isinstance(ip_adapter_image_embeds, list):
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+ )
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+ )
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ image=None,
+ timestep=None,
+ is_strength_max=True,
+ add_noise=True,
+ return_noise=False,
+ return_image_latents=False,
+ ):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if (image is None or timestep is None) and not is_strength_max:
+ raise ValueError(
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
+ "However, either the image or the noise timestep has not been provided."
+ )
+
+ if image.shape[1] == 4:
+ image_latents = image.to(device=device, dtype=dtype)
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+ elif return_image_latents or (latents is None and not is_strength_max):
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+
+ if latents is None and add_noise:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+ elif add_noise:
+ noise = latents.to(device)
+ latents = noise * self.scheduler.init_noise_sigma
+ else:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = image_latents.to(device)
+
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_image_latents:
+ outputs += (image_latents,)
+
+ return outputs
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ dtype = image.dtype
+ if self.vae.config.force_upcast:
+ image = image.float()
+ self.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ if self.vae.config.force_upcast:
+ self.vae.to(dtype)
+
+ image_latents = image_latents.to(dtype)
+ image_latents = self.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ def prepare_mask_latents(
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ # mask = torch.nn.functional.interpolate(
+ # mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+ # )
+ mask = torch.nn.functional.max_pool2d(mask, (8, 8)).round()
+ mask = mask.to(device=device, dtype=dtype)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+
+ if masked_image is not None and masked_image.shape[1] == 4:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = None
+
+ if masked_image is not None:
+ if masked_image_latents is None:
+ masked_image = masked_image.to(device=device, dtype=dtype)
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(
+ batch_size // masked_image_latents.shape[0], 1, 1, 1
+ )
+
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ return mask, masked_image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
+ # get the original timestep using init_timestep
+ if denoising_start is None:
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+ t_start = max(num_inference_steps - init_timestep, 0)
+ else:
+ t_start = 0
+
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ # Strength is irrelevant if we directly request a timestep to start at;
+ # that is, strength is determined by the denoising_start instead.
+ if denoising_start is not None:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_start * self.scheduler.config.num_train_timesteps)
+ )
+ )
+
+ num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+ if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
+ # if the scheduler is a 2nd order scheduler we might have to do +1
+ # because `num_inference_steps` might be even given that every timestep
+ # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+ # mean that we cut the timesteps in the middle of the denoising step
+ # (between 1st and 2nd devirative) which leads to incorrect results. By adding 1
+ # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
+ num_inference_steps = num_inference_steps + 1
+
+ # because t_n+1 >= t_n, we slice the timesteps starting from the end
+ timesteps = timesteps[-num_inference_steps:]
+ return timesteps, num_inference_steps
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
+ def _get_add_time_ids(
+ self,
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype,
+ text_encoder_projection_dim=None,
+ ):
+ if self.config.requires_aesthetics_score:
+ add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+ add_neg_time_ids = list(
+ negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
+ )
+ else:
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if (
+ expected_add_embed_dim > passed_add_embed_dim
+ and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
+ )
+ elif (
+ expected_add_embed_dim < passed_add_embed_dim
+ and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
+ )
+ elif expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
+
+ return add_time_ids, add_neg_time_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ timesteps (`torch.Tensor`):
+ generate embedding vectors at these timesteps
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ @property
+ def do_self_attention_redirection_guidance(self): # SARG
+ return self._rm_guidance_scale > 1 and self._AAS
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return (
+ self._guidance_scale > 1
+ and self.unet.config.time_cond_proj_dim is None
+ and not self.do_self_attention_redirection_guidance
+ ) # CFG was disabled when SARG was used, and experiments proved that there was little difference in the effect of whether CFG was used or not
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def denoising_end(self):
+ return self._denoising_end
+
+ @property
+ def denoising_start(self):
+ return self._denoising_start
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ def image2latent(self, image: torch.Tensor, generator: torch.Generator):
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ if type(image) is Image:
+ image = np.array(image)
+ image = torch.from_numpy(image).float() / 127.5 - 1
+ image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
+ # input image density range [-1, 1]
+ # latents = self.vae.encode(image)['latent_dist'].mean
+ latents = self._encode_vae_image(image, generator)
+ # latents = retrieve_latents(self.vae.encode(image))
+ # latents = latents * self.vae.config.scaling_factor
+ return latents
+
+ def next_step(self, model_output: torch.FloatTensor, timestep: int, x: torch.FloatTensor, eta=0.0, verbose=False):
+ """
+ Inverse sampling for DDIM Inversion
+ """
+ if verbose:
+ print("timestep: ", timestep)
+ next_step = timestep
+ timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
+ alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
+ beta_prod_t = 1 - alpha_prod_t
+ pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+ pred_dir = (1 - alpha_prod_t_next) ** 0.5 * model_output
+ x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
+ return x_next, pred_x0
+
+ @torch.no_grad()
+ def invert(
+ self,
+ image: torch.Tensor,
+ prompt,
+ num_inference_steps=50,
+ eta=0.0,
+ original_size: Tuple[int, int] = None,
+ target_size: Tuple[int, int] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ aesthetic_score: float = 6.0,
+ negative_aesthetic_score: float = 2.5,
+ return_intermediates=False,
+ **kwds,
+ ):
+ """
+ invert a real image into noise map with determinisc DDIM inversion
+ """
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ batch_size = image.shape[0]
+ if isinstance(prompt, list):
+ if batch_size == 1:
+ image = image.expand(len(prompt), -1, -1, -1)
+ elif isinstance(prompt, str):
+ if batch_size > 1:
+ prompt = [prompt] * batch_size
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ prompt_2 = prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(DEVICE), output_hidden_states=True)
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ prompt_embeds_list.append(prompt_embeds)
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=DEVICE)
+
+ # define initial latents
+ latents = self.image2latent(image, generator=None)
+
+ start_latents = latents
+ height, width = latents.shape[-2:]
+ height = height * self.vae_scale_factor
+ width = width * self.vae_scale_factor
+
+ original_size = (height, width)
+ target_size = (height, width)
+ negative_original_size = original_size
+ negative_target_size = target_size
+
+ add_text_embeds = pooled_prompt_embeds
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+
+ add_time_ids = add_time_ids.repeat(batch_size, 1).to(DEVICE)
+
+ # interative sampling
+ self.scheduler.set_timesteps(num_inference_steps)
+ latents_list = [latents]
+ pred_x0_list = []
+ # for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
+ for i, t in enumerate(reversed(self.scheduler.timesteps)):
+ model_inputs = latents
+
+ # predict the noise
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = self.unet(
+ model_inputs, t, encoder_hidden_states=prompt_embeds, added_cond_kwargs=added_cond_kwargs
+ ).sample
+
+ # compute the previous noise sample x_t-1 -> x_t
+ latents, pred_x0 = self.next_step(noise_pred, t, latents)
+ """
+ if t >= 1 and t < 41:
+ latents, pred_x0 = self.next_step_degrade(noise_pred, t, latents, mask)
+ else:
+ latents, pred_x0 = self.next_step(noise_pred, t, latents) """
+
+ latents_list.append(latents)
+ pred_x0_list.append(pred_x0)
+
+ if return_intermediates:
+ # return the intermediate laters during inversion
+ # pred_x0_list = [self.latent2image(img, return_type="np") for img in pred_x0_list]
+ # latents_list = [self.latent2image(img, return_type="np") for img in latents_list]
+ return latents, latents_list, pred_x0_list
+ return latents, start_latents
+
+ def opt(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ x: torch.FloatTensor,
+ ):
+ """
+ predict the sampe the next step in the denoise process.
+ """
+ ref_noise = model_output[:1, :, :, :].expand(model_output.shape)
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+ beta_prod_t = 1 - alpha_prod_t
+ pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+ x_opt = alpha_prod_t**0.5 * pred_x0 + (1 - alpha_prod_t) ** 0.5 * ref_noise
+ return x_opt, pred_x0
+
+ def regiter_attention_editor_diffusers(self, unet, editor: AttentionBase):
+ """
+ Register a attention editor to Diffuser Pipeline, refer from [Prompt-to-Prompt]
+ """
+
+ def ca_forward(self, place_in_unet):
+ def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
+ """
+ The attention is similar to the original implementation of LDM CrossAttention class
+ except adding some modifications on the attention
+ """
+ if encoder_hidden_states is not None:
+ context = encoder_hidden_states
+ if attention_mask is not None:
+ mask = attention_mask
+
+ to_out = self.to_out
+ if isinstance(to_out, nn.modules.container.ModuleList):
+ to_out = self.to_out[0]
+ else:
+ to_out = self.to_out
+
+ h = self.heads
+ q = self.to_q(x)
+ is_cross = context is not None
+ context = context if is_cross else x
+ k = self.to_k(context)
+ v = self.to_v(context)
+ # q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+ q, k, v = (rearrange(t, "b n (h d) -> (b h) n d", h=h) for t in (q, k, v))
+
+ sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
+
+ if mask is not None:
+ mask = rearrange(mask, "b ... -> b (...)")
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, "b j -> (b h) () j", h=h)
+ mask = mask[:, None, :].repeat(h, 1, 1)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ attn = sim.softmax(dim=-1)
+ # the only difference
+ out = editor(q, k, v, sim, attn, is_cross, place_in_unet, self.heads, scale=self.scale)
+
+ return to_out(out)
+
+ return forward
+
+ def register_editor(net, count, place_in_unet):
+ for name, subnet in net.named_children():
+ if net.__class__.__name__ == "Attention": # spatial Transformer layer
+ net.forward = ca_forward(net, place_in_unet)
+ return count + 1
+ elif hasattr(net, "children"):
+ count = register_editor(subnet, count, place_in_unet)
+ return count
+
+ cross_att_count = 0
+ for net_name, net in unet.named_children():
+ if "down" in net_name:
+ cross_att_count += register_editor(net, 0, "down")
+ elif "mid" in net_name:
+ cross_att_count += register_editor(net, 0, "mid")
+ elif "up" in net_name:
+ cross_att_count += register_editor(net, 0, "up")
+ editor.num_att_layers = cross_att_count
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
+ masked_image_latents: torch.FloatTensor = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ padding_mask_crop: Optional[int] = None,
+ strength: float = 0.9999,
+ AAS: bool = True, # AE parameter
+ rm_guidance_scale: float = 7.0, # AE parameter
+ ss_steps: int = 9, # AE parameter
+ ss_scale: float = 0.3, # AE parameter
+ AAS_start_step: int = 0, # AE parameter
+ AAS_start_layer: int = 34, # AE parameter
+ AAS_end_layer: int = 70, # AE parameter
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ denoising_start: Optional[float] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Tuple[int, int] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Tuple[int, int] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ aesthetic_score: float = 6.0,
+ negative_aesthetic_score: float = 2.5,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
+ The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If
+ `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and
+ contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on
+ the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
+ and contain information inreleant for inpainging, such as background.
+ strength (`float`, *optional*, defaults to 0.9999):
+ Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
+ between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
+ `strength`. The number of denoising steps depends on the amount of noise initially added. When
+ `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
+ iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
+ portion of the reference `image`. Note that in the case of `denoising_start` being declared as an
+ integer, the value of `strength` will be ignored.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ denoising_start (`float`, *optional*):
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
+ bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
+ it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
+ strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
+ is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
+ denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
+ final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
+ forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
+ Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
+ if `do_classifier_free_guidance` is set to `True`.
+ If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a target image resolution. It should be as same
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ aesthetic_score (`float`, *optional*, defaults to 6.0):
+ Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
+ simulate an aesthetic score of the generated image by influencing the negative text condition.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple. `tuple. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ image,
+ mask_image,
+ height,
+ width,
+ strength,
+ callback_steps,
+ output_type,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ callback_on_step_end_tensor_inputs,
+ padding_mask_crop,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
+ self._denoising_start = denoising_start
+ self._interrupt = False
+
+ ########### AE parameters
+ self._num_timesteps = num_inference_steps
+ self._rm_guidance_scale = rm_guidance_scale
+ self._AAS = AAS
+ self._ss_steps = ss_steps
+ self._ss_scale = ss_scale
+ self._AAS_start_step = AAS_start_step
+ self._AAS_start_layer = AAS_start_layer
+ self._AAS_end_layer = AAS_end_layer
+ ###########
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 4. set timesteps
+ def denoising_value_valid(dnv):
+ return isinstance(dnv, float) and 0 < dnv < 1
+
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps,
+ strength,
+ device,
+ denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
+ )
+ # check that number of inference steps is not < 1 - as this doesn't make sense
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ is_strength_max = strength == 1.0
+
+ # 5. Preprocess mask and image
+ if padding_mask_crop is not None:
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+ resize_mode = "fill"
+ else:
+ crops_coords = None
+ resize_mode = "default"
+
+ original_image = image
+ init_image = self.image_processor.preprocess(
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+ )
+ init_image = init_image.to(dtype=torch.float32)
+
+ mask = self.mask_processor.preprocess(
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+ )
+
+ if masked_image_latents is not None:
+ masked_image = masked_image_latents
+ elif init_image.shape[1] == 4:
+ # if images are in latent space, we can't mask it
+ masked_image = None
+ else:
+ masked_image = init_image * (mask < 0.5)
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_unet = self.unet.config.in_channels
+ return_image_latents = num_channels_unet == 4
+
+ add_noise = True if self.denoising_start is None else False
+ latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image=init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ add_noise=add_noise,
+ return_noise=True,
+ return_image_latents=return_image_latents,
+ )
+
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+
+ # 7. Prepare mask latent variables
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask,
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ self.do_classifier_free_guidance,
+ )
+
+ # 8. Check that sizes of mask, masked image and latents match
+ if num_channels_unet == 9:
+ # default case for runwayml/stable-diffusion-inpainting
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+ elif num_channels_unet != 4:
+ raise ValueError(
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+ )
+ # 8.1 Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ height, width = latents.shape[-2:]
+ height = height * self.vae_scale_factor
+ width = width * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 10. Prepare added time ids & embeddings
+ if negative_original_size is None:
+ negative_original_size = original_size
+ if negative_target_size is None:
+ negative_target_size = target_size
+
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+ ###########
+ if self.do_self_attention_redirection_guidance:
+ prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0)
+ add_neg_time_ids = add_neg_time_ids.repeat(2, 1)
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+ ############
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device)
+
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ # apply AAS to modify the attention module
+ if self.do_self_attention_redirection_guidance:
+ self._AAS_end_step = int(strength * self._num_timesteps)
+ layer_idx = list(range(self._AAS_start_layer, self._AAS_end_layer))
+ editor = AAS_XL(
+ self._AAS_start_step,
+ self._AAS_end_step,
+ self._AAS_start_layer,
+ self._AAS_end_layer,
+ layer_idx=layer_idx,
+ mask=mask_image,
+ model_type="SDXL",
+ ss_steps=self._ss_steps,
+ ss_scale=self._ss_scale,
+ )
+ self.regiter_attention_editor_diffusers(self.unet, editor)
+
+ # 11. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ if (
+ self.denoising_end is not None
+ and self.denoising_start is not None
+ and denoising_value_valid(self.denoising_end)
+ and denoising_value_valid(self.denoising_start)
+ and self.denoising_start >= self.denoising_end
+ ):
+ raise ValueError(
+ f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
+ + f" {self.denoising_end} when using type float."
+ )
+ elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ # 11.1 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+ # removal guidance
+ latent_model_input = (
+ torch.cat([latents] * 2) if self.do_self_attention_redirection_guidance else latents
+ ) # CFG was disabled when SARG was used, and experiments proved that there was little difference in the effect of whether CFG was used or not
+ # latent_model_input_rm = torch.cat([latents]*2) if self.do_self_attention_redirection_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ # latent_model_input = self.scheduler.scale_model_input(latent_model_input_rm, t)
+
+ if num_channels_unet == 9:
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ added_cond_kwargs["image_embeds"] = image_embeds
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform SARG
+ if self.do_self_attention_redirection_guidance:
+ noise_pred_wo, noise_pred_w = noise_pred.chunk(2)
+ delta = noise_pred_w - noise_pred_wo
+ noise_pred = noise_pred_wo + self._rm_guidance_scale * delta
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if num_channels_unet == 4:
+ init_latents_proper = image_latents
+ if self.do_classifier_free_guidance:
+ init_mask, _ = mask.chunk(2)
+ else:
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper, noise, torch.tensor([noise_timestep])
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+ add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
+ mask = callback_outputs.pop("mask", mask)
+ masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+ latents = latents[-1:]
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ # unscale/denormalize the latents
+ # denormalize with the mean and std if available and not None
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+ if has_latents_mean and has_latents_std:
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+ else:
+ latents = latents / self.vae.config.scaling_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ return StableDiffusionXLPipelineOutput(images=latents)
+
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ if padding_mask_crop is not None:
+ image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
From 1b47ff7e00d77822cdda353c09f60170de3805d7 Mon Sep 17 00:00:00 2001
From: Anonym0u3 <306794924@qq.com>
Date: Thu, 16 Jan 2025 22:00:09 +0800
Subject: [PATCH 4/7] update Docs
Co-authored-by: Other Contributor
---
examples/community/README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/examples/community/README.md b/examples/community/README.md
index 02777636d2d3..a62d49f0d2c0 100644
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -77,7 +77,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
| [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [](https://huggingface.co/spaces/pcuenq/mdm) [](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
-| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3)|
+| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3) and [Benlei Cui](https://github.com/Benny079)|
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
From 72ee35e4be022d1f8f287480ab31bfe09022b67a Mon Sep 17 00:00:00 2001
From: Anonym0u3 <306794924@qq.com>
Date: Sat, 18 Jan 2025 19:21:43 +0800
Subject: [PATCH 5/7] add Oral
Co-authored-by: Other Contributor
---
examples/community/README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/examples/community/README.md b/examples/community/README.md
index a62d49f0d2c0..4c593a004893 100644
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -77,7 +77,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
| [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [](https://huggingface.co/spaces/pcuenq/mdm) [](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
-| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3) and [Benlei Cui](https://github.com/Benny079)|
+| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025 Oral] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3) and [Benlei Cui](https://github.com/Benny079)|
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
From e5fc2e40e43bd132b70d1f236dc0c9cfbd6fd2bc Mon Sep 17 00:00:00 2001
From: Anonym0u3 <306794924@qq.com>
Date: Fri, 24 Jan 2025 20:46:31 +0800
Subject: [PATCH 6/7] update_review
Co-authored-by: Other Contributor
---
...ne_stable_diffusion_xl_attentive_eraser.py | 96 ++++++++++++-------
1 file changed, 64 insertions(+), 32 deletions(-)
diff --git a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
index 48c01318991a..f78cedb9965a 100644
--- a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
+++ b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
@@ -81,27 +81,72 @@
Examples:
```py
>>> import torch
- >>> from diffusers import StableDiffusionXLInpaintPipeline
+ >>> from diffusers import DDIMScheduler, DiffusionPipeline
>>> from diffusers.utils import load_image
-
- >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
- ... "stabilityai/stable-diffusion-xl-base-1.0",
- ... torch_dtype=torch.float16,
- ... variant="fp16",
- ... use_safetensors=True,
- ... )
- >>> pipe.to("cuda")
-
- >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
- >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
-
- >>> init_image = load_image(img_url).convert("RGB")
- >>> mask_image = load_image(mask_url).convert("RGB")
-
- >>> prompt = "A majestic tiger sitting on a bench"
- >>> image = pipe(
- ... prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80
+ >>> import torch.nn.functional as F
+ >>> from torchvision.transforms.functional import to_tensor, gaussian_blur
+
+ >>> dtype = torch.float16
+ >>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ >>> scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+
+ >>> pipeline = DiffusionPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0",
+ ... custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser",
+ ... scheduler=scheduler,
+ ... variant="fp16",
+ ... use_safetensors=True,
+ ... torch_dtype=dtype,
+ ... ).to(device)
+
+
+ >>> def preprocess_image(image_path, device):
+ ... image = to_tensor((load_image(image_path)))
+ ... image = image.unsqueeze_(0).float() * 2 - 1 # [0,1] --> [-1,1]
+ ... if image.shape[1] != 3:
+ ... image = image.expand(-1, 3, -1, -1)
+ ... image = F.interpolate(image, (1024, 1024))
+ ... image = image.to(dtype).to(device)
+ ... return image
+
+ >>> def preprocess_mask(mask_path, device):
+ ... mask = to_tensor((load_image(mask_path, convert_method=lambda img: img.convert('L'))))
+ ... mask = mask.unsqueeze_(0).float() # 0 or 1
+ ... mask = F.interpolate(mask, (1024, 1024))
+ ... mask = gaussian_blur(mask, kernel_size=(77, 77))
+ ... mask[mask < 0.1] = 0
+ ... mask[mask >= 0.1] = 1
+ ... mask = mask.to(dtype).to(device)
+ ... return mask
+
+ >>> prompt = "" # Set prompt to null
+ >>> seed=123
+ >>> generator = torch.Generator(device=device).manual_seed(seed)
+ >>> source_image_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png"
+ >>> mask_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png"
+ >>> source_image = preprocess_image(source_image_path, device)
+ >>> mask = preprocess_mask(mask_path, device)
+
+ >>> image = pipeline(
+ ... prompt=prompt,
+ ... image=source_image,
+ ... mask_image=mask,
+ ... height=1024,
+ ... width=1024,
+ ... AAS=True, # enable AAS
+ ... strength=0.8, # inpainting strength
+ ... rm_guidance_scale=9, # removal guidance scale
+ ... ss_steps = 9, # similarity suppression steps
+ ... ss_scale = 0.3, # similarity suppression scale
+ ... AAS_start_step=0, # AAS start step
+ ... AAS_start_layer=34, # AAS start layer
+ ... AAS_end_layer=70, # AAS end layer
+ ... num_inference_steps=50, # number of inference steps # AAS_end_step = int(strength*num_inference_steps)
+ ... generator=generator,
+ ... guidance_scale=1,
... ).images[0]
+ >>> image.save('./removed_img.png')
+ >>> print("Object removal completed")
```
"""
@@ -174,9 +219,6 @@ def __init__(
self.mask = mask # mask with shape (1, 1 ,h, w)
self.ss_steps = ss_steps
self.ss_scale = ss_scale
- print("AAS at denoising steps: ", self.step_idx)
- print("AAS at U-Net layers: ", self.layer_idx)
- print("start AAS")
self.mask_16 = F.max_pool2d(mask, (1024 // 16, 1024 // 16)).round().squeeze().squeeze()
self.mask_32 = F.max_pool2d(mask, (1024 // 32, 1024 // 32)).round().squeeze().squeeze()
self.mask_64 = F.max_pool2d(mask, (1024 // 64, 1024 // 64)).round().squeeze().squeeze()
@@ -209,10 +251,7 @@ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwar
Attention forward function
"""
if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
- return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
- # B = q.shape[0] // num_heads // 2
H = int(np.sqrt(q.shape[1]))
- # H = W = int(np.sqrt(q.shape[1]))
if H == 16:
mask = self.mask_16.to(sim.device)
elif H == 32:
@@ -317,13 +356,6 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
dimensions: ``batch x channels x height x width``.
"""
- # checkpoint. TOD(Yiyi) - need to clean this up later
- deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
- deprecate(
- "prepare_mask_and_masked_image",
- "0.30.0",
- deprecation_message,
- )
if image is None:
raise ValueError("`image` input cannot be undefined.")
From bb52aeb717c7dd8d85e249fbf57564687e1f3333 Mon Sep 17 00:00:00 2001
From: Anonym0u3 <306794924@qq.com>
Date: Fri, 24 Jan 2025 21:34:26 +0800
Subject: [PATCH 7/7] update_review_ms
Co-authored-by: Other Contributor
---
.../pipeline_stable_diffusion_xl_attentive_eraser.py | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
index f78cedb9965a..1269a69f0dc3 100644
--- a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
+++ b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
@@ -87,7 +87,7 @@
>>> from torchvision.transforms.functional import to_tensor, gaussian_blur
>>> dtype = torch.float16
- >>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ >>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
>>> scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
>>> pipeline = DiffusionPipeline.from_pretrained(
@@ -120,7 +120,7 @@
... return mask
>>> prompt = "" # Set prompt to null
- >>> seed=123
+ >>> seed=123
>>> generator = torch.Generator(device=device).manual_seed(seed)
>>> source_image_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png"
>>> mask_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png"
@@ -128,7 +128,7 @@
>>> mask = preprocess_mask(mask_path, device)
>>> image = pipeline(
- ... prompt=prompt,
+ ... prompt=prompt,
... image=source_image,
... mask_image=mask,
... height=1024,
@@ -251,6 +251,7 @@ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwar
Attention forward function
"""
if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
+ return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
H = int(np.sqrt(q.shape[1]))
if H == 16:
mask = self.mask_16.to(sim.device)